# Copyright (C) 2021 Intel Corporation
# SPDX-License-Identifier: GPL-3.0-or-later

import csv
import json
import os
import time
from datetime import datetime
from logging import Logger
from typing import IO, Dict, List, Union

from ..cve_scanner import CVEData
from ..cvedb import CVEDB
from ..error_handler import ErrorHandler, ErrorMode
from ..log import LOGGER
from ..util import ProductInfo, Remarks
from ..version import VERSION
from .console import output_console
from .html import output_html
from .util import (
    add_extension_if_not,
    format_output,
    generate_filename,
    get_cve_summary,
    intermediate_output,
)


def output_json(all_cve_data: Dict[ProductInfo, CVEData], outfile: IO):
    """Write the CVE report to *outfile* as JSON.

    The rows come from ``format_output`` and are dumped with a four-space
    indent.

    Args:
        all_cve_data: scan results keyed by product.
        outfile: writable text stream that receives the JSON document.
    """
    json.dump(format_output(all_cve_data), outfile, indent="    ")


def save_intermediate(
    all_cve_data: Dict[ProductInfo, CVEData],
    filename: str,
    tag: str,
    scanned_dir: str,
    products_with_cve: int,
    products_without_cve: int,
    total_files: int,
):
    """Write an intermediate JSON report to *filename*.

    The document is built by ``intermediate_output`` from the CVE data plus
    the scan metadata (tag, scanned directory and product/file counters).
    It is then dumped with a four-space indent, overwriting *filename*.
    """
    report = intermediate_output(
        all_cve_data,
        tag,
        scanned_dir,
        products_with_cve,
        products_without_cve,
        total_files,
    )
    with open(filename, "w") as report_file:
        json.dump(report, report_file, indent="    ")


def output_csv(all_cve_data: Dict[ProductInfo, CVEData], outfile):
    """Output a CSV of CVEs.

    Writes a header row followed by one row per entry produced by
    ``format_output``.

    Args:
        all_cve_data: scan results keyed by product.
        outfile: writable text stream handed to ``csv.DictWriter``. The csv
            module expects such a file to be opened with ``newline=""``.
    """
    formatted_output = format_output(all_cve_data)

    # Guard against CSV/formula injection: spreadsheet applications treat a
    # cell *starting* with -, =, + or @ as a formula. Only leading characters
    # are removed, so values such as product "gtk+" or version "1.0+" keep
    # their trailing characters. Non-string values are left untouched.
    for cve_entry in formatted_output:
        for key, value in cve_entry.items():
            if isinstance(value, str):
                cve_entry[key] = value.lstrip("-=+@")

    writer = csv.DictWriter(
        outfile,
        fieldnames=[
            "vendor",
            "product",
            "version",
            "cve_number",
            "severity",
            "score",
            "cvss_version",
            "cvss_vector",
            "paths",
            "remarks",
            "comments",
        ],
    )
    writer.writeheader()
    writer.writerows(formatted_output)


# Provide PDF output only when the pdfbuilder module (which needs reportlab)
# can be imported. Otherwise define a stub ``output_pdf`` that logs a warning.
try:
    from . import pdfbuilder

    def output_pdf(
        all_cve_data: Dict[ProductInfo, CVEData],
        is_report,
        products_with_cve,
        outfile,
        merge_report,
    ):
        """Output a PDF of CVEs.

        Args:
            all_cve_data: scan results keyed by product. Each value's
                ``["cves"]`` holds the CVE records for that product.
            is_report: when true, add a table listing every scanned product.
            products_with_cve: number of products with vulnerabilities. The
                summary and vulnerability-list sections are only emitted when
                this is non-zero.
            outfile: destination passed to ``PDFBuilder.publish``.
            merge_report: merged intermediate reports, or a falsy value when
                not merging. Its ``intermediate_cve_data`` feeds the
                "Intermediate Reports" table.
        """
        cvedb_data = CVEDB()
        app_version = VERSION
        # Build document
        pdfdoc = pdfbuilder.PDFBuilder()
        cm = pdfdoc.cm
        severity_colour = {
            "UNKNOWN": pdfdoc.grey,
            "LOW": pdfdoc.blue,
            "MEDIUM": pdfdoc.green,
            "HIGH": pdfdoc.orange,
            "CRITICAL": pdfdoc.red,
        }
        severity_levels = ["UNKNOWN", "LOW", "MEDIUM", "HIGH", "CRITICAL"]

        def colour_of(severity):
            # Unrecognised severity labels are drawn grey instead of raising KeyError
            return severity_colour.get(severity.upper(), pdfdoc.grey)

        pdfdoc.front_page("Vulnerability Report")
        pdfdoc.heading(1, "Introduction")
        pdfdoc.paragraph(
            "The identification of vulnerabilities has been performed using cve-bin-tool version "
            + app_version
        )
        if merge_report:
            pdfdoc.paragraph(
                "The report has been generated by merging multiple intermediate reports."
            )
        else:
            # The local database date is only reported here, so look it up
            # only when it is actually needed.
            db_date = time.strftime(
                "%d %B %Y at %H:%M:%S", time.localtime(cvedb_data.get_db_update_date())
            )
            pdfdoc.paragraph(
                "The data used has been obtained from the NVD database which was retrieved on "
                + db_date
                + " and contained "
                + str(cvedb_data.get_cve_count())
                + " entries."
            )

        if is_report:
            pdfdoc.heading(1, "List of All Scanned binaries")
            pdfdoc.createtable(
                "Productlist",
                ["Vendor", "Product", "Version"],
                pdfdoc.tblStyle,
            )
            # One row per scanned product. A "*" in any vendor name triggers
            # the footnote below the table.
            guessed_vendor = False
            for product_info in all_cve_data:
                if "*" in product_info.vendor:
                    guessed_vendor = True
                pdfdoc.addrow(
                    "Productlist",
                    [
                        product_info.vendor,
                        product_info.product,
                        product_info.version,
                    ],
                )
            pdfdoc.showtable("Productlist", widths=[3 * cm, 2 * cm, 2 * cm])
            if guessed_vendor:
                pdfdoc.paragraph("* vendors guessed by the tool")
            pdfdoc.paragraph(
                f"There are {products_with_cve} products with vulnerabilities found."
            )

            pdfdoc.pagebreak()

        if merge_report:
            pdfdoc.heading(1, "Intermediate Reports")
            pdfdoc.paragraph(
                "The following table contains severity levels count of individual intermediate report sorted on the basis of timestamp."
            )
            pdfdoc.createtable(
                "SeverityLevels",
                [
                    "Timestamp",
                    "Tag",
                    "Total\nFiles",
                    "Products\nwith CVE",
                    "Products\nwithout CVE",
                    *severity_levels,
                ],
                pdfdoc.intermediateStyle,
            )

            for inter_file in merge_report.intermediate_cve_data:
                metadata = inter_file["metadata"]
                entry = [
                    datetime.strptime(
                        metadata["timestamp"], "%Y-%m-%d.%H-%M-%S"
                    ).strftime("%Y-%m-%d %H:%M"),
                    metadata["tag"],
                    metadata["total_files"],
                    metadata["products_with_cve"],
                    metadata["products_without_cve"],
                ]
                entry.extend(metadata["severity"][level] for level in severity_levels)
                pdfdoc.addrow(
                    "SeverityLevels",
                    entry,
                )
            pdfdoc.showtable(
                "SeverityLevels",
                widths=[
                    2.5 * cm,
                    3 * cm,
                    1.5 * cm,
                    2.5 * cm,
                    2.5 * cm,
                    None,
                    None,
                    None,
                    None,
                    None,
                ],
            )
            pdfdoc.pagebreak()

        if products_with_cve != 0:
            pdfdoc.heading(1, "Summary of Identified Vulnerabilities")

            pdfdoc.paragraph("A summary of the vulnerabilities found.")
            pdfdoc.createtable("CVESummary", ["Severity", "Count"], pdfdoc.tblStyle)
            summary = get_cve_summary(all_cve_data)
            # Row 0 is the header, so data rows start at 1
            for row, (severity, count) in enumerate(summary.items(), start=1):
                pdfdoc.addrow(
                    "CVESummary",
                    [severity, count],
                    [
                        ("TEXTCOLOR", (0, row), (1, row), colour_of(severity)),
                        ("FONT", (0, row), (1, row), "Helvetica-Bold"),
                    ],
                )
            pdfdoc.showtable("CVESummary", widths=[3 * cm, 2 * cm])

            pdfdoc.heading(1, "List of Identified Vulnerabilities")
            pdfdoc.paragraph(
                "The following vulnerabilities are reported against the identified versions of the libraries."
            )
            pdfdoc.createtable(
                "Productlist",
                ["Vendor", "Product", "Version", "CVE Number", "Severity"],
                pdfdoc.tblStyle,
                [10, 10, None, None, None],
            )
            row = 1
            guessed_vendor = False
            for product_info, cve_data in all_cve_data.items():
                for cve in cve_data["cves"]:
                    # Entries whose CVE number is "UNKNOWN" are not listed
                    if cve.cve_number == "UNKNOWN":
                        continue
                    if "*" in product_info.vendor:
                        guessed_vendor = True
                    entry = [
                        product_info.vendor,
                        product_info.product,
                        product_info.version,
                        cve.cve_number,
                        cve.severity,
                    ]
                    pdfdoc.addrow(
                        "Productlist",
                        entry,
                        [
                            ("TEXTCOLOR", (0, row), (2, row), pdfdoc.black),
                            ("FONT", (0, row), (2, row), "Helvetica"),
                            ("TEXTCOLOR", (3, row), (4, row), colour_of(cve.severity)),
                            ("FONT", (3, row), (4, row), "Helvetica-Bold"),
                        ],
                    )
                    row += 1

            pdfdoc.showtable(
                "Productlist", widths=[3 * cm, 3 * cm, 2 * cm, 4 * cm, 3 * cm]
            )
            # Kept inside this section: the flag only describes the table above
            if guessed_vendor:
                pdfdoc.paragraph("* vendors guessed by the tool")

        pdfdoc.pagebreak()
        pdfdoc.paragraph("END OF DOCUMENT.")
        pdfdoc.publish(outfile)

except ModuleNotFoundError:

    def output_pdf(
        all_cve_data: Dict[ProductInfo, CVEData],
        is_report,
        products_with_cve,
        outfile,
        merge_report,
    ):
        """Fallback used when pdfbuilder cannot be imported: log a warning only.

        The signature mirrors the real ``output_pdf`` so callers are unaffected.
        Nothing is written to *outfile*.
        """
        LOGGER.warning("PDF output requires install of reportlab")


class OutputEngine:
    """Write scan results in the requested format, plus optional side reports.

    Besides the main report (console, JSON, CSV, PDF or HTML), the engine can
    save an intermediate JSON report (when ``append`` is set) and a CycloneDX
    VEX file (when ``vex_filename`` is set).
    """

    def __init__(
        self,
        all_cve_data: Dict[ProductInfo, CVEData],
        scanned_dir: str,
        filename: str,
        themes_dir: str,
        time_of_last_update,
        tag: str,
        logger: Logger = None,
        products_with_cve: int = 0,
        products_without_cve: int = 0,
        total_files: int = 0,
        is_report: bool = False,
        append: Union[str, bool] = False,
        merge_report: Union[None, List[str]] = None,
        affected_versions: int = 0,
        all_cve_version_info=None,
        vex_filename: str = "",
    ):
        self.logger = logger or LOGGER.getChild(self.__class__.__name__)
        self.all_cve_version_info = all_cve_version_info
        self.scanned_dir = scanned_dir
        # Stored as an absolute path; "" means "generate a default name later"
        self.filename = os.path.abspath(filename) if filename else ""
        self.products_with_cve = products_with_cve
        self.products_without_cve = products_without_cve
        self.total_files = total_files
        self.themes_dir = themes_dir
        self.is_report = is_report
        self.time_of_last_update = time_of_last_update
        # str: path for the intermediate report; True: use a default name;
        # False: no intermediate report
        self.append = append
        self.tag = tag
        # NOTE(review): annotated as List[str], but output_pdf reads
        # ``merge_report.intermediate_cve_data`` from it — confirm the real type
        self.merge_report = merge_report
        self.affected_versions = affected_versions
        self.all_cve_data = all_cve_data
        self.vex_filename = vex_filename

    def output_cves(self, outfile, output_type="console"):
        """Render ``self.all_cve_data`` in ``output_type`` format.

        Args:
            outfile: open file object for "json", "csv", "pdf" and "html"
                output. It is not used for console output.
            output_type: "json", "csv", "pdf" or "html". Any other value prints
                to the console.

        Afterwards it also saves the intermediate report when ``self.append``
        is a path string, and the VEX file when ``self.vex_filename`` is set.
        """
        if output_type == "json":
            output_json(self.all_cve_data, outfile)
        elif output_type == "csv":
            output_csv(self.all_cve_data, outfile)
        elif output_type == "pdf":
            output_pdf(
                self.all_cve_data,
                self.is_report,
                self.products_with_cve,
                outfile,
                self.merge_report,
            )
        elif output_type == "html":
            output_html(
                self.all_cve_data,
                self.scanned_dir,
                self.filename,
                self.themes_dir,
                self.total_files,
                self.products_with_cve,
                self.products_without_cve,
                self.merge_report,
                self.logger,
                outfile,
            )
        else:  # console, or anything else that is unrecognised
            output_console(
                self.all_cve_data,
                self.all_cve_version_info,
                self.time_of_last_update,
                self.affected_versions,
            )

        if isinstance(self.append, str):
            save_intermediate(
                self.all_cve_data,
                self.append,
                self.tag,
                self.scanned_dir,
                self.products_with_cve,
                self.products_without_cve,
                self.total_files,
            )
            self.logger.info(f"Output stored at {self.append}")

        # Truthiness check so that None (as well as "") means "no VEX file"
        if self.vex_filename:
            self.generate_vex(self.all_cve_data, self.vex_filename)

    def generate_vex(self, all_cve_data: Dict[ProductInfo, CVEData], filename: str):
        """Write a CycloneDX 1.4 VEX JSON document to *filename*.

        Every CVE of every product becomes one entry in "vulnerabilities".
        Each entry carries:
          - NVD source and rating links;
          - an "analysis" section derived from ``cve.remarks`` and
            ``cve.comments``;
          - an "affects" reference built from the product name and version.

        Args:
            all_cve_data: scan results keyed by product. Each value's
                ``["cves"]`` holds the CVE records.
            filename: output path; an existing file is overwritten.
        """
        # Triage remark -> CycloneDX analysis state
        analysis_state = {
            Remarks.NewFound: "under_review",
            Remarks.Unexplored: "under_review",
            Remarks.Confirmed: "exploitable",
            Remarks.Mitigated: "not_affected",
            Remarks.Ignored: "not_affected",
        }
        # Triage remark -> response text
        response_state = {
            Remarks.NewFound: "Outstanding",
            Remarks.Unexplored: "Not defined",
            Remarks.Confirmed: "Upgrade required",
            Remarks.Mitigated: "Resolved",
            Remarks.Ignored: "No impact",
        }
        vex_output = {"bomFormat": "CycloneDX", "specVersion": "1.4", "version": 1}
        # TODO: document-level fields considered useful but not emitted yet:
        # "creationInfo" (creation timestamp and creators
        # ["Tool: cve_bin_tool", "Version:" + VERSION]), "documentDescribes",
        # and "externalDocumentRefs" pointing at the source SBOM document.

        # The BOM being described is not known here, so placeholders are used
        bom_urn = "NOTKNOWN"
        bom_version = 1
        vuln_entry = []
        for product_info, cve_data in all_cve_data.items():
            for cve in cve_data["cves"]:
                # CVSS v3 ratings link to NVD's v3.1 calculator, all others to v2.0
                if cve.cvss_version == 3:
                    url = f"v3-calculator?name={cve.cve_number}&vector={cve.cvss_vector}&version=3.1"
                else:
                    url = f"v2-calculator?name={cve.cve_number}&vector={cve.cvss_vector}&version=2.0"
                vulnerability = {
                    "id": cve.cve_number,
                    "source": {
                        "name": "NVD",
                        "url": "https://nvd.nist.gov/vuln/detail/" + cve.cve_number,
                    },
                    "ratings": [
                        {
                            "source": {
                                "name": "NVD",
                                "url": "https://nvd.nist.gov/vuln-metrics/cvss/" + url,
                            },
                            "score": str(cve.score),
                            "severity": cve.severity,
                            "method": "CVSSv" + str(cve.cvss_version),
                            "vector": cve.cvss_vector,
                        }
                    ],
                    "cwes": [],
                    "description": cve.description,
                    "recommendation": "",
                    "advisories": [],
                    "created": "NOT_KNOWN",
                    "published": "NOT_KNOWN",
                    "updated": "NOT_KNOWN",
                    "analysis": {
                        "state": analysis_state[cve.remarks],
                        "response": response_state[cve.remarks],
                        "justification": "",
                        "detail": cve.comments,
                    },
                    "affects": [
                        {
                            "ref": f"urn:cdx:{bom_urn}/{bom_version}#{product_info.product}-{product_info.version}",
                        }
                    ],
                }
                vuln_entry.append(vulnerability)

        vex_output["vulnerabilities"] = vuln_entry

        # Generate file
        with open(filename, "w") as outfile:
            json.dump(vex_output, outfile, indent="   ")

    def output_file(self, output_type="console"):
        """Resolve output filenames, then write the report via ``output_cves``.

        If ``self.append`` is truthy, it is first turned into a usable
        intermediate-report path. For non-console output, ``self.filename`` is:
          - generated when empty;
          - otherwise given an extension if missing;
          - replaced by a default name if the file already exists or cannot
            be written.

        Args:
            output_type: "console", "json", "csv", "pdf" or "html".
        """
        if self.append:
            if isinstance(self.append, str):
                self.append = self.check_dir_path(
                    self.append, output_type="json", prefix="intermediate"
                )
                self.append = add_extension_if_not(self.append, "json")
                self.append = self.check_file_path(
                    self.append, output_type="json", prefix="intermediate"
                )
            else:
                # file path for intermediate report not given
                self.append = generate_filename("json", "intermediate")

        if output_type == "console":
            # short circuit file opening logic if we are actually
            # just writing to stdout
            self.output_cves(self.filename, output_type)
            return

        # Check if we need to generate a filename
        if not self.filename:
            self.filename = generate_filename(output_type)
        else:
            # check and add if the filename doesn't contain extension
            self.filename = add_extension_if_not(self.filename, output_type)

            self.filename = self.check_file_path(self.filename, output_type)

            # Probe that the path is writable; on failure fall back to a
            # default-named file instead of aborting.
            with ErrorHandler(mode=ErrorMode.Ignore) as e:
                with open(self.filename, "w") as f:
                    f.write("testing")
                os.remove(self.filename)
            if e.exit_code:
                self.logger.info(
                    f"Exception {e.exc_val} occurred while writing to the file {self.filename} "
                    "Switching Back to Default Naming Convention"
                )
                self.filename = generate_filename(output_type)

        # Log the filename generated
        if output_type == "html" or output_type == "pdf":
            self.logger.info(f"{output_type.upper()} report stored at {self.filename}")
        else:
            self.logger.info(f"Output stored at {self.filename}")

        # PDF output is binary. The csv module needs newline="" so it can emit
        # its own "\r\n" terminators without them being translated again
        # (which would produce blank rows on Windows).
        mode = "wb" if output_type == "pdf" else "w"
        newline = "" if output_type == "csv" else None
        with open(self.filename, mode, newline=newline) as f:
            self.output_cves(f, output_type)

    def check_file_path(self, filepath: str, output_type: str, prefix: str = "output"):
        """Return *filepath*, or a freshly generated name if that file exists.

        Existing files are never overwritten: a warning is logged and a new
        name is generated with ``generate_filename(output_type, prefix)``.
        """
        if os.path.isfile(filepath):
            self.logger.warning(f"Failed to write at '{filepath}'. File already exists")
            self.logger.info("Generating a new filename with Default Naming Convention")
            filepath = generate_filename(output_type, prefix)

        return filepath

    def check_dir_path(
        self, filepath: str, output_type: str, prefix: str = "intermediate"
    ):
        """Return a file path to use for *filepath*.

        If *filepath* is an existing directory, return a default-named file
        inside it. Otherwise return *filepath* unchanged.
        """
        if os.path.isdir(filepath):
            self.logger.info(
                f"Generating a new filename with Default Naming Convention in directory path {filepath}"
            )
            filename = os.path.basename(generate_filename(output_type, prefix))
            filepath = os.path.join(filepath, filename)

        return filepath
