# This file is part of the faebryk project
# SPDX-License-Identifier: MIT

"""
"""

import logging
import re

import black

from faebryk.libs.pycodegen import sanitize_name
from faebryk.libs.util import NotNone

# Module-level logger for this exporter.
logger = logging.getLogger("export_faebryk")

# Skeleton of the generated Python script. from_t1_netlist fills it in with
# str.format(), so any literal "{" or "}" added to this text must be doubled.
# Placeholders:
#   header, docstring                 - top-of-file text (passed as "" today)
#   function_imports, library_imports - extra import lines (passed as "" today)
#   definitions - the rendered comp_template instances, one per component;
#                 they land inside run_experiment(), hence the 4-space indent
#   connections - `<a>.connect(<b>)` statements between component pins
#   comps       - comma-separated list of the component variable names
# The filled-in text is reformatted with black afterwards (see
# from_t1_netlist), so it must be valid Python but need not be pretty.
template = """
# This file was generated by faebryk netlist exporter
{header}

\"""
{docstring}
\"""
from pathlib import Path
import logging

logger = logging.getLogger("main")


# function imports
from faebryk.exporters.netlist.kicad.netlist_kicad import from_faebryk_t2_netlist
from faebryk.exporters.netlist.netlist import make_t2_netlist_from_t1
from faebryk.exporters.netlist.graph import (
    make_graph_from_components,
    make_t1_netlist_from_graph,
)
{function_imports}

# library imports
from faebryk.library.core import Component
from faebryk.library.library.parameters import Constant
from faebryk.library.library.interfaces import Electrical
from faebryk.library.trait_impl.component import (
    has_defined_footprint,
    has_defined_footprint_pinmap,
    has_defined_type_description,
)
from faebryk.library.kicad import (
    has_defined_kicad_ref,
    KicadFootprint,
)
{library_imports}

from faebryk.libs.experiments.buildutil import export_graph, export_netlist

def run_experiment():

    {definitions}

    {connections}

    comps = [
        {comps}
    ]

    t1_ = make_t1_netlist_from_graph(make_graph_from_components(comps))

    netlist = from_faebryk_t2_netlist(make_t2_netlist_from_t1(t1_))
    assert netlist is not None

    export_netlist(netlist)
    export_graph(t1_, show=True)

# Boilerplate -----------------------------------------------------------------
import sys


def main(argc, argv, argi):
    logging.basicConfig(level=logging.INFO)

    logger.info("Running experiment")
    run_experiment()


if __name__ == "__main__":
    import os
    import sys

    root = os.path.join(os.path.dirname(__file__), "..")
    sys.path.append(root)
    main(len(sys.argv), sys.argv, iter(sys.argv))

"""

# Per-component class skeleton; from_t1_netlist renders one per component and
# joins them into the {definitions} placeholder of `template`. Filled in with
# str.format(), so literal braces must be doubled. Placeholders:
#   name          - unique variable name; the class is named `_<name>CLS`
#   unnamed_ifs   - expression for the unnamed interface list ("[]" today)
#   named_if_expr - one `P<pin> = Electrical()` line per pin
#   trait_expr    - `self.add_trait(...)` statements
#   ifs           - extra statements before instantiation ("" today)
# Lines start with a 4-space indent because they end up inside
# run_experiment() of the generated script.
comp_template = """
    class _{name}CLS(Component):
        def __init__(self):
            super().__init__()

            class _IFs(Component.InterfacesCls()):
                unnamed = {unnamed_ifs}
                {named_if_expr}

            self.IFs = _IFs(self)

            {trait_expr}


    {ifs}
    {name} = _{name}CLS()

"""


def dict_to_str(obj):
    """Render *obj* as the source text of a Python dict literal.

    Keys are converted to ``str`` and emitted as properly escaped string
    literals, so keys containing quotes or backslashes still produce valid
    code. Values are inserted verbatim and must therefore already be valid
    Python expressions (e.g. ``"self.IFs.P1"``). Every entry is followed by
    a trailing comma: ``{"1": "x"}`` -> ``"{'1':x,}"``.
    """
    entries = "".join(f"{str(key)!r}:{value}," for key, value in obj.items())
    return "{" + entries + "}"


def str_to_str(obj):
    """Render ``str(obj)`` as a double-quoted Python string literal.

    Double quotes, backslashes and control characters are escaped, so the
    result is valid Python that evaluates back to ``str(obj)``. Plain text
    such as ``R_0603`` is emitted unchanged between the quotes.
    """
    import json

    # A JSON string is also a valid double-quoted Python string literal.
    # ensure_ascii=False keeps non-ASCII text (e.g. an Ohm sign) as-is instead
    # of emitting \uXXXX surrogate pairs, which Python would not recombine.
    return json.dumps(str(obj), ensure_ascii=False)


def from_t1_netlist(t1_netlist):
    """Generate the source of a runnable faebryk script from a T1 netlist.

    Each component in *t1_netlist* becomes a generated ``Component`` subclass
    (rendered from ``comp_template``). The subclass gets one ``P<pin>``
    Electrical interface per pin and a ``has_defined_footprint_pinmap`` trait.
    Components whose ``real`` flag is truthy also get
    ``has_defined_type_description`` (from ``value``) and
    ``has_defined_footprint`` (from ``properties["footprint"]``). Every pin
    neighbour becomes one ``<a>.connect(<b>)`` statement. A pin pair reported
    from both ends is emitted only once.

    Args:
        t1_netlist: iterable of component dicts. Keys read here: ``name``,
            ``real``, ``neighbors`` and, for real components, ``value`` and
            ``properties["footprint"]``. ``neighbors`` maps each pin to a
            list of ``{"vertex": ..., "pin": ...}`` entries whose
            ``vertex["name"]`` is the name of another netlist component.

    Returns:
        The generated Python source, formatted with black.

    Raises:
        ValueError: if a component name does not sanitize to a valid Python
            identifier.
    """
    # Per sanitized class name: number of variable names handed out so far.
    comp_names = {}

    # Renders one component; returns (var_name, (class_code, pinmap, component)).
    def comp_to_faebryk(component):
        def get_comp_name(component):
            # Names starting with "COMP[" use the text between "[" and the
            # (last) ":" as class name; anything else is used as-is.
            if component["name"].startswith("COMP["):
                class_name = NotNone(
                    re.search(r"\[(.*):.*\]", component["name"])
                ).group(1)
            else:
                class_name = component["name"]

            class_name = sanitize_name(class_name)
            assert type(class_name) is str
            # fullmatch (unlike match + "$") also rejects a trailing newline;
            # raise instead of assert so the check survives `python -O`.
            if re.fullmatch(r"[a-zA-Z_][a-zA-Z_0-9]*", class_name) is None:
                raise ValueError(
                    f"component class name is not a valid identifier: {class_name!r}"
                )

            # Number instances per class (R_0, R_1, ...) to keep names unique.
            ctr = comp_names.get(class_name, 0)
            comp_names[class_name] = ctr + 1
            return f"{class_name}_{ctr}"

        name = get_comp_name(component)

        # pin -> (attribute path with a "{name}" placeholder, interface name)
        pinmap = {
            pin: (f"{{name}}.IFs.P{pin}", f"P{pin}")
            for pin in component["neighbors"]
        }

        # One `P<pin> = Electrical()` line per pin, indented to match
        # {named_if_expr} in comp_template.
        named_if_expr = ("\n" + "    " * 4).join(
            f"{pin_name} = Electrical()" for _, pin_name in pinmap.values()
        )

        # Traits --------------------------------------------------
        traits = []

        def add_trait(trait_name, trait_args):
            traits.append(f"self.add_trait({trait_name}({trait_args}))")

        add_trait(
            "has_defined_footprint_pinmap",
            dict_to_str({k: v[0].format(name="self") for k, v in pinmap.items()}),
        )

        if component["real"]:
            add_trait("has_defined_type_description", str_to_str(component["value"]))
            add_trait(
                "has_defined_footprint",
                "KicadFootprint({})".format(
                    str_to_str(component["properties"]["footprint"])
                ),
            )

        # Indented to match {trait_expr} in comp_template.
        trait_expr = ("\n" + "    " * 3).join(traits)

        # ---------------------------------------------------------

        comp = comp_template.format(
            ifs="",
            unnamed_ifs="[]",
            name=name,
            named_if_expr=named_if_expr,
            trait_expr=trait_expr,
        )

        return name, (comp, pinmap, component)

    named_comps = dict(map(comp_to_faebryk, t1_netlist))
    # Netlist component name -> generated variable name, to resolve neighbours.
    name_map = {c["name"]: cname for cname, (_, __, c) in named_comps.items()}

    # Real components are visited first (False sorts before True). When a pin
    # pair is reported from both ends, the statement is therefore emitted from
    # the real component's side and the mirrored pair is skipped.
    connections = {}
    for cname, (_, pinmap, comp) in sorted(
        named_comps.items(), key=lambda item: not item[1][2]["real"]
    ):
        for pin, neighbors in comp["neighbors"].items():
            for neighbor in neighbors:
                neighborcname = name_map[neighbor["vertex"]["name"]]
                npin = neighbor["pin"]
                if ((neighborcname, npin), (cname, pin)) in connections:
                    # don't add connection if symmetric connection already exists
                    continue
                src = pinmap[pin][0].format(name=cname)
                dst = named_comps[neighborcname][1][npin][0].format(
                    name=neighborcname
                )
                key = ((cname, pin), (neighborcname, npin))
                connections[key] = f"{src}.connect({dst})"

    project = template.format(
        header="",
        docstring="",
        function_imports="",
        library_imports="",
        definitions="\n        ".join(c[0] for c in named_comps.values()),
        connections="\n    ".join(connections.values()),
        comps=",\n        ".join(named_comps.keys()),
    )

    try:
        project = black.format_file_contents(
            project, fast=False, mode=black.FileMode()
        )
    except black.NothingChanged:
        # black raises instead of returning its input when there is nothing
        # to reformat; the source is already in its final form then.
        pass

    return project
