# SPDX-License-Identifier: GPL-2.0-or-later OR AGPL-3.0-or-later OR CERN-OHL-S-2.0+
from textwrap import dedent
from itertools import product
from xml.etree import ElementTree as ET
from typing import Any, Tuple, Dict, Set, Generator, Optional, Union, cast, overload

from ... import _util, dispatch as dsp
from ...typing import GDSLayerSpecDict
from ...technology import (
    property_ as prp, rule as rle, wafer_ as wfr, mask as msk, edge as edg,
    geometry as geo, primitive as prm, net as _net, technology_ as tch,
)
from ...design import layout as lay, library as lbry

pya: Any # Silence pylance
import pya


# Names exported when this module is star-imported.
__all__= ["FileExporter", "export2db"]


class _MaskConverter(dsp.MaskDispatcher):
    """Dispatcher turning mask objects into KLayout script expressions.

    The technology is stored because the `_Wafer` handler needs the substrate
    enclosure values of the technology's WaferWire primitives.
    """
    def __init__(self, *, tech: tch.Technology) -> None:
        super().__init__()

        self.tech = tech

    def __call__(self, mask: msk._Mask) -> str:
        # Only narrows the signature; the base class does the dispatching.
        return super().__call__(mask)

    @staticmethod
    def __legalized_maskname(name: str) -> str:
        """Return `name` rewritten to a script-friendly identifier.

        A name starting with a digit gets a "_" prefix; "." is replaced
        by "_" and ":" by "__".
        """
        prefix = "_" if name[0] in "0123456789" else ""
        legal = prefix + name
        return legal.replace(".", "_").replace(":", "__")

    def DesignMask(self, mask: msk.DesignMask) -> str:
        legal_name = self.__legalized_maskname(mask.name)
        return legal_name

    def _MaskAlias(self, mask: msk._MaskAlias) -> str:
        legal_name = self.__legalized_maskname(mask.name)
        return legal_name

    # Handled on higher level
    # def _PartsWith(self, pw: msk._PartsWith):
    #     ...

    def Join(self, join: msk.Join) -> str:
        operands = "+".join(self(m) for m in join.masks)
        return f"({operands})"

    def Intersect(self, intersect: msk.Intersect) -> str:
        operands = "&".join(self(m) for m in intersect.masks)
        return f"({operands})"

    def _MaskRemove(self, mr: msk._MaskRemove):
        minuend = self(mr.from_)
        subtrahend = self(mr.what)
        return f"({minuend}-{subtrahend})"

    def _Wafer(self, wafer: wfr._Wafer, *args, **kwargs):
        # Size box to be sure it is big enough for substrate enclosure check
        enclosures = (
            enc
            for wire in self.tech.primitives.__iter_type__(prm.WaferWire)
            for enc in (
                wire.min_substrate_enclosure,
                wire.min_substrate_enclosure_same_type,
            )
            if enc is not None
        )
        # Largest enclosure value found, never less than 0.0
        bias = max([0.0, *(enc.max() for enc in enclosures)])

        if bias < self.tech.grid/100:
            return "extent"
        return f"extent.sized({bias:.6})"
# Mask converter shared by the helpers in this module. It stays None until
# FileExporter.__init__ binds it to a _MaskConverter for its technology.
_mask_conv: Any = None


class _EdgeConverter(dsp.EdgeDispatcher):
    """Dispatcher turning edge objects into KLayout script expressions.

    Operands that are masks instead of edges are converted with the
    module-level `_mask_conv`.
    """
    def __call__(self, edge: edg._Edge) -> str:
        # Only narrows the signature; the base class does the dispatching.
        return super().__call__(edge)

    def MaskEdge(self, edge: edg.MaskEdge) -> str:
        mask_expr = _mask_conv(edge.mask)
        return f"{mask_expr}.edges"

    def _DualEdgeOperation(self, op: edg._DualEdgeOperation) -> str:
        first = self(op.edge1)

        other = op.edge2
        if isinstance(other, msk._Mask):
            second = _mask_conv(other)
        elif isinstance(other, edg._Edge):
            second = self(other)
        else: # pragma: no cover
            raise TypeError(f"[Internal error]Unexpected type for edge2 of {str(op)}")

        if op.operation != "interact_with": # pragma: no cover
            raise NotImplementedError(f"[Internal error]Operation: {op.operation}")
        return f"{first}.interacting({second})"

    def Join(self, join: edg.Join) -> str:
        operands = [
            _mask_conv(e) if isinstance(e, msk._Mask) else self(e)
            for e in join.edges
        ]
        return "(" + "+".join(operands) + ")"

    def Intersect(self, intersect: edg.Intersect) -> str:
        operands = [
            _mask_conv(e) if isinstance(e, msk._Mask) else self(e)
            for e in intersect.edges
        ]
        return "(" + "&".join(operands) + ")"
# Module-level edge converter instance; used by the rule converter below.
_edge_conv = _EdgeConverter()


def _str_designmask(mask: msk.DesignMask, *, gds_layers: GDSLayerSpecDict):
    """Return the script line that loads `mask` from its GDS layer.

    `gds_layers` is indexed with the mask name; an entry that is not a tuple
    is used as layer number with 0 as second element.
    """
    spec = gds_layers[mask.name]
    layer_spec = spec if isinstance(spec, tuple) else (spec, 0)
    # The tuple is formatted as-is and so becomes the argument list of input()
    return f"{_mask_conv(mask)} = input{layer_spec}\n"


def _str_alias(mask: msk._MaskAlias):
    """Return the script line assigning the aliased mask expression to the
    alias name."""
    alias_name = _mask_conv(mask)
    aliased_expr = _mask_conv(mask.mask)
    return f"{alias_name} = {aliased_expr}\n"


def _str_grid(mask: msk._Mask, grid: float):
    """Return the script statement that outputs the `ongrid` check result of
    `mask` for the given `grid`."""
    name = _mask_conv(mask)
    statement = f"""
        {name}.ongrid({grid}).output(
            "{name} grid", "{name} grid: {grid}µm"
        )
    """
    # Skip the newline that follows the opening quotes before dedenting
    return dedent(statement[1:])


class _RuleConverter(dsp.RuleDispatcher):
    """Dispatcher turning technology rules into KLayout script statements.

    Masks and edges inside the rules are converted with the module-level
    `_mask_conv` and `_edge_conv` converters.
    """
    def __call__(self, rule: rle._Rule, *, allow_unimplented: bool=False) -> str:
        """Return the script text for `rule`, preceded by a comment line
        holding the string form of the rule.

        Arguments:
            rule: the rule to convert
            allow_unimplented: if `True` a rule whose conversion raises
                NotImplementedError results in a "# Not supported" line
                instead of propagating the exception.
                (The spelling of the argument name is kept for backwards
                compatibility.)
        """
        if allow_unimplented:
            try:
                s = super().__call__(rule)
            except NotImplementedError:
                s = "# Not supported\n"
        else:
            s = super().__call__(rule)

        return f"# {rule}\n{s}"

    def GreaterEqual(self, ge: prp.Operators.GreaterEqual) -> str:
        """Convert a minimum-value rule; the kind of check is selected on the
        type of the left-hand property and its `prop_name`.

        Raises NotImplementedError for unsupported combinations.
        """
        left = ge.left
        right = ge.right
        if isinstance(left, msk._MaskProperty):
            s_mask = _mask_conv(left.mask)
            prop = left.prop_name
            if prop in {"width", "space"}:
                return dedent(f"""
                    {s_mask}.{prop}({right}).output(
                        "{s_mask} {prop}", "{s_mask} minimum {prop}: {right}µm"
                    )
                """[1:])
            elif prop == "area":
                return dedent(f"""
                    {s_mask}.with_area(nil, {right}).output(
                        "{s_mask} area", "{s_mask} minimum area: {right}µm²"
                    )
                """[1:])
            elif prop == "density":
                assert isinstance(right, float)
                return dedent(f"""
                    {s_mask}_mindens = polygon_layer
                    dens_check({s_mask}_mindens, {s_mask}, {right}, 1)
                    {s_mask}_mindens.output(
                        "{s_mask} density", "{s_mask} minimum density: {round(100*right)}%"
                    )
                """[1:])
        elif isinstance(left, msk._DualMaskProperty):
            prop = left.prop_name
            if (
                (prop == "space")
                and (
                    isinstance(left.mask1, msk._PartsWith)
                    or isinstance(left.mask2, msk._PartsWith)
                )
            ):
                # Special code for handling width based spacing rules
                # Main objective is to support space tables specified for primitives
                # Other application areas are mainly untested
                if isinstance(left.mask1, msk._PartsWith):
                    assert len(left.mask1.condition) == 1
                    cond = left.mask1.condition[0]
                    assert left.mask1.mask == left.mask2
                    s_mask = _mask_conv(left.mask1.mask)
                else:
                    assert isinstance(left.mask2, msk._PartsWith)
                    assert len(left.mask2.condition) == 1
                    cond = left.mask2.condition[0]
                    assert left.mask2.mask == left.mask1
                    s_mask = _mask_conv(left.mask2.mask)
                assert isinstance(cond, prp.Operators.GreaterEqual)
                assert isinstance(cond.left, msk._MaskProperty)
                assert cond.left.prop_name == "width"
                return dedent(f"""
                    space4width_check({s_mask}, {cond.right}, {right}).output(
                        "{s_mask} table spacing",
                        "Minimum {s_mask} spacing for {cond.right}µm width: {right}µm"
                    )
                """[1:])
            s_mask1 = _mask_conv(left.mask1)
            s_mask2 = _mask_conv(left.mask2)
            if prop == "space":
                return dedent(f"""
                    {s_mask1}.separation({s_mask2}, {right}, square).output(
                        "{s_mask1}:{s_mask2} spacing",
                        "Minimum spacing between {s_mask1} and {s_mask2}: {right}µm"
                    )
                """[1:])
            elif prop == "overlapwidth":
                return dedent(f"""
                    ({s_mask1}&{s_mask2}).width({right}).output(
                        "{s_mask1}:{s_mask2} overlap width",
                        "Minimum overlap width of {s_mask1} and {s_mask2}: {right}µm"
                    )
                """[1:])
            elif prop == "extend_over":
                return dedent(f"""
                    extend_check({s_mask2}, {s_mask1}, {right}).output(
                        "{s_mask1}:{s_mask2} extension",
                        "Minimum extension of {s_mask1} of {s_mask2}: {right}µm"
                    )
                """[1:])
        elif isinstance(left, msk._DualMaskEnclosureProperty):
            s_mask1 = _mask_conv(left.mask1)
            s_mask2 = _mask_conv(left.mask2)
            prop = left.prop_name
            if prop == "enclosed_by":
                # TODO: Proper typing for Property
                assert isinstance(right, prp.Enclosure)
                if not right.is_assymetric:
                    return dedent(f"""
                        {s_mask2}.enclosing({s_mask1}, {right.first}).output(
                            "{s_mask2}:{s_mask1} enclosure",
                            "Minimum enclosure of {s_mask2} around {s_mask1}: {right.first}µm"
                        )
                    """[1:])
                else:
                    s_desc = (
                        f"Minimum enclosure of {s_mask2} around {s_mask1}: "
                        f"{right.min()}µm minimum, {right.max()}µm opposite"
                    )
                    return dedent(f"""
                        oppenc_check({s_mask1}, {s_mask2}, {right.min()}, {right.max()}).output(
                            "{s_mask2}:{s_mask1} asymmetric enclosure",
                            "{s_desc}"
                        )
                    """[1:])
        elif isinstance(left, edg._EdgeProperty):
            s_edge = _edge_conv(left.edge)
            prop = left.prop_name
            if prop == "length":
                return dedent(f"""
                    {s_edge}.with_length(nil, {right}).output(
                        "{s_edge} length",
                        "Minimum length of {s_edge}: {right}µm"
                    )
                """[1:])
        elif isinstance(left, edg._DualEdgeProperty):
            s_edge1 = _edge_conv(left.edge1)
            # A mask as second operand is checked through its edges
            s_edge2 = (
                _mask_conv(left.edge2)+".edges" if isinstance(left.edge2, msk._Mask)
                else _edge_conv(left.edge2)
            )
            prop = left.prop_name
            if prop == "enclosed_by":
                return dedent(f"""
                    {s_edge2}.enclosing({s_edge1}, {right}).output(
                        "{s_edge2}:{s_edge1} enclosure",
                        "Minimum enclosure of {s_edge2} around {s_edge1}: {right}µm"
                    )
                """[1:])

        # No branch above produced a check for this rule
        raise NotImplementedError(f"GreaterEqual rule '{ge}'") # pragma: no cover

    def Equal(self, eq: prp.Operators.Equal) -> str:
        """Convert an exact-value rule.

        Supported are an exact mask width, and a mask area or edge length
        that has to be 0.0 (the layer/edges are then output as is, so anything
        present is reported).

        Raises ValueError for an area/length value other than 0.0 and
        NotImplementedError for unsupported properties.
        """
        left = eq.left
        right = eq.right
        # TODO: proper typing of Property
        assert isinstance(right, float)
        if isinstance(left, msk._MaskProperty):
            s_mask = _mask_conv(left.mask)
            prop = left.prop_name
            if prop == "width":
                return dedent(f"""
                    width_check({s_mask}, {right}).output(
                        "{s_mask} width", "{s_mask} width: {right}µm"
                    )
                """[1:])
            elif prop == "area":
                if round(right, 6) != 0.0:
                    raise ValueError("For area equal check value can only be 0.0")
                return f'{s_mask}.output("{s_mask} empty")\n'
        elif isinstance(left, edg._EdgeProperty):
            s_edge = _edge_conv(left.edge)
            prop = left.prop_name
            if prop == "length":
                if round(right, 6) != 0.0:
                    raise ValueError("For length equal check value can only be 0.0")
                return f'{s_edge}.output("{s_edge} empty")\n'

        raise NotImplementedError(f"Equal rule '{eq}'") # pragma: no cover

    def _MaskAlias(self, alias: msk._MaskAlias) -> str:
        """Return the assignment of the aliased mask expression to the alias name."""
        return f"{_mask_conv(alias)} = {_mask_conv(alias.mask)}\n"

    def Connect(self, conn: msk.Connect) -> str:
        """Return one `connect()` line for each pair out of `conn.mask1` and
        `conn.mask2`."""
        return "".join(
            f"connect({_mask_conv(mask1)}, {_mask_conv(mask2)})\n"
            for mask1, mask2 in product(conn.mask1, conn.mask2)
        )
# Module-level rule converter instance; used by FileExporter below.
_rule_conv = _RuleConverter()


def _str_lvsresistor(res: prm.Resistor):
    """Return the device extraction statements for resistor `res`, preceded
    by a comment line with the resistor name."""
    header = f"# {res.name}\n"

    body_layer = _mask_conv(res.mask)
    conn_layer = _mask_conv(res.wire.conn_mask)

    extraction = dedent(f"""
        extract_devices(resistor("{res.name}", {res.sheetres}), {{
            "R" => {body_layer}, "C" => {conn_layer},
        }})
        same_device_classes("{res.name}", "RES")
    """[1:])

    return header + extraction


def _str_lvsdiode(tech: tch.Technology, diode: prm.Diode):
    """Return the device extraction statement for `diode`, preceded by a
    comment line with the diode name.

    The technology substrate is used as second layer when the diode has no well.
    """
    header = f"# {diode.name}\n"

    n_type = diode.implant.type_ == "n"

    diode_layer = _mask_conv(diode.mask)
    conn_layer = _mask_conv(diode.wire.conn_mask)
    if diode.well is not None:
        well_layer = _mask_conv(diode.well.mask)
    else:
        well_layer = _mask_conv(tech.substrate)

    # The diode layer is the N side for an n-type implant, otherwise the P side;
    # the connection terminal name is chosen accordingly.
    p_layer, n_layer, conn_terminal = (
        (well_layer, diode_layer, "tC") if n_type
        else (diode_layer, well_layer, "tA")
    )

    extraction = dedent(f"""
        extract_devices(diode("{diode.model}"), {{
            "P" => {p_layer}, "N" => {n_layer}, "{conn_terminal}" => {conn_layer}
        }})
    """[1:])

    return header + extraction


def _str_lvsmosfet(tech: tch.Technology, mosfet: prm.MOSFET):
    """Return the device extraction statement for `mosfet`, preceded by a
    comment line with the transistor name.

    The technology substrate is used as bulk layer when the transistor has no well.
    """
    header = f"# {mosfet.name}\n"

    sd_layer = _mask_conv(mosfet.gate.active.conn_mask)
    gate_layer = _mask_conv(mosfet.gate_mask)
    if mosfet.well is not None:
        bulk_layer = _mask_conv(mosfet.well.mask)
    else:
        bulk_layer = _mask_conv(tech.substrate)
    poly_layer = _mask_conv(mosfet.gate.poly.conn_mask)

    extraction = dedent(f"""
        extract_devices(mos4("{mosfet.model}"), {{
            "SD" => {sd_layer}, "G" => {gate_layer}, "tG" => {poly_layer}, "W" => {bulk_layer},
        }})
    """[1:])

    return header + extraction


class FileExporter:
    """Generate the KLayout DRC, extraction and LVS scripts and the technology
    description for a technology.

    Calling an instance returns a dict with the script texts under the keys
    "drc", "extract" and "lvs" and XML elements under the keys "ly_drc",
    "ly_extract" and "ly_tech".

    Arguments:
        tech: the technology to export
        export_name: name used in the generated files; the technology name
            is used when not given.
        gds_layers: lookup from design mask name to GDS layer; either a
            layer number or a (layer, datatype) tuple.
    """
    def __init__(self, *,
        tech: tch.Technology, export_name: Optional[str]=None,
        gds_layers: GDSLayerSpecDict,
    ):
        self.tech = tech
        self.export_name = tech.name if export_name is None else export_name
        self.gds_layers = gds_layers

        # The module-level helpers convert masks through this global, so it
        # is (re)bound to the technology of the exporter created last.
        global _mask_conv
        _mask_conv = _MaskConverter(tech=tech)

    def __call__(self):
        return {
            "drc": self._s_drc(),
            "ly_drc": self._ly_drc(),
            "extract": self._s_extract(),
            "ly_extract": self._ly_extract(),
            "lvs": self._s_lvs(),
            "ly_tech": self._ly_tech(),
        }

    @staticmethod
    def _macro_element(*, category: str, text: str) -> ET.Element:
        """Build the "klayout-macro" XML element shared by the DRC and LVS macros.

        `category` ("drc" or "lvs") selects the category, group name, menu path
        and DSL interpreter name; `text` is the script body.
        """
        macro = ET.Element("klayout-macro")
        ET.SubElement(macro, "description")
        ET.SubElement(macro, "version")
        ET.SubElement(macro, "category").text = category
        ET.SubElement(macro, "prolog")
        ET.SubElement(macro, "epilog")
        ET.SubElement(macro, "doc")
        ET.SubElement(macro, "autorun").text = "false"
        ET.SubElement(macro, "autorun-early").text = "false"
        ET.SubElement(macro, "shortcut")
        ET.SubElement(macro, "show-in-menu").text = "true"
        ET.SubElement(macro, "group-name").text = f"{category}_scripts"
        ET.SubElement(macro, "menu-path").text = f"tools_menu.{category}.end"
        ET.SubElement(macro, "interpreter").text = "dsl"
        ET.SubElement(macro, "dsl-interpreter-name").text = f"{category}-dsl-xml"
        ET.SubElement(macro, "text").text = text

        return macro

    def _s_drc(self) -> str:
        """Return the standalone DRC script; source and report file are taken
        from the SOURCE_FILE and REPORT_FILE environment variables."""
        s = dedent(f"""
            # Autogenerated file. Changes will be overwritten.

            source(ENV["SOURCE_FILE"])
            report("{self.export_name} DRC", ENV["REPORT_FILE"])

        """[1:])

        return s + self._s_drcrules()

    def _ly_drc(self) -> ET.Element:
        """Return the DRC script wrapped in a KLayout macro XML element."""
        s = dedent(f"""
            # Autogenerated file. Changes will be overwritten.

            report("{self.export_name} DRC")

        """[1:]) + self._s_drcrules()

        return self._macro_element(category="drc", text=s)

    def _s_drcrules(self) -> str:
        """Return the part of the DRC script shared by the standalone script
        and the macro: helper functions, layer definitions, grid checks,
        derived layers, connectivity and the remaining rules."""
        s = dedent("""
            def width_check(layer, w)
                small = layer.width(w).polygons
                big = layer.sized(-0.5*w).size(0.5*w)

                small | big
            end

            def space4width_check(layer, w, s)
                big = layer.sized(-0.5*w).size(0.5*w)
                big.edges.separation(layer.edges, s)
            end

            def oppenc_check(inner, outer, min, max)
                toosmall = outer.enclosing(inner, min).second_edges

                smallenc = outer.enclosing(inner, max - 1.dbu, projection).second_edges
                # These edges may not touch each other
                touching = smallenc.width(1.dbu, angle_limit(100)).edges

                inner.interacting(toosmall + touching)
            end

            def extend_check(base, extend, e)
                extend.enclosing(base, e).first_edges.not_interacting(base)
            end

            def dens_check(output, input, min, max)
                tp = RBA::TilingProcessor::new

                tp.output("res", output.data)
                tp.input("input", input.data)
                tp.dbu = 1.dbu  # establish the real database unit
                tp.var("vmin", min)
                tp.var("vmax", max)

                tp.queue("_tile && (var d = to_f(input.area(_tile.bbox)) / to_f(_tile.bbox.area); (d < vmin || d > vmax) && _output(res, _tile.bbox))")
                tp.execute("Density check")
            end
        """[1:])

        s += "\n# Define layers\n"
        dms = tuple(self.tech.rules.__iter_type__(msk.DesignMask))
        s += "".join(_str_designmask(dm, gds_layers=self.gds_layers) for dm in dms)

        s += "\n# Grid check\n"
        gridrules = cast(Tuple[prp.Operators.Equal, ...], tuple(
            rule for rule in self.tech.rules
            if (
                isinstance(rule, prp.Operators.Equal)
                and isinstance(rule.left, msk._MaskProperty)
                and (rule.left.prop_name == "grid")
            )
        ))
        gridspecs = {
            cast(msk._MaskProperty, gridrule.left).mask: gridrule.right
            for gridrule in gridrules
        }
        # The grid of the wafer is the fallback for masks without own grid rule
        globalgrid = gridspecs[wfr.wafer]
        s += "".join(
            _str_grid(dm, cast(float, gridspecs.get(dm, globalgrid)))
            for dm in dms
        )

        s += "\n# Derived layers\n"
        aliases = tuple(self.tech.rules.__iter_type__(msk._MaskAlias))
        s += "".join(_rule_conv(alias) for alias in aliases)

        s += "\n# Connectivity\n"
        conns = tuple(self.tech.rules.__iter_type__(msk.Connect))
        s += "".join(_rule_conv(conn) for conn in conns)

        # Everything not yet written above is a regular DRC rule.
        # The tuple is built once instead of once for each rule.
        handled = dms + gridrules + conns + aliases
        s += "\n# DRC rules\n" + "".join(
            _rule_conv(rule) for rule in self.tech.rules
            if rule not in handled
        )

        return s

    def _s_extract(self) -> str:
        """Return the standalone extraction script; source and netlist file
        are taken from the SOURCE_FILE and SPICE_FILE environment variables."""
        s = dedent("""
            # Autogenerated file. Changes will be overwritten

            source(ENV["SOURCE_FILE"])
            target_netlist(ENV["SPICE_FILE"], write_spice(true, true))

        """[1:])
        s += self._s_extractrules()

        return s

    def _ly_extract(self) -> ET.Element:
        """Return the extraction script wrapped in a KLayout macro XML element."""
        s = dedent("""
            # Autogenerated file. Changes will be overwritten

            report_netlist

        """[1:])
        s += self._s_extractrules()

        return self._macro_element(category="lvs", text=s)

    def _s_lvs(self) -> str:
        """Return the standalone LVS script; the files are taken from the
        SOURCE_FILE, SPICE_FILE and REPORT_FILE environment variables."""
        s = dedent("""
            # Autogenerated file. Changes will be overwritten

            source(ENV["SOURCE_FILE"])
            schematic(ENV["SPICE_FILE"])
            report_lvs(ENV["REPORT_FILE"])

        """[1:])
        s += self._s_extractrules() + dedent("""
            align
            ok = compare
            if ok then
                print("LVS OK\\n")
            else
                abort "LVS Failed!"
            end
        """)

        return s

    def _s_extractrules(self) -> str:
        """Return the extraction part shared by the extraction and LVS
        scripts: layer definitions, connectivity and device extraction."""
        # TODO: bug report for failing LVS on hierarchical LVS and diodes
        s = "flat\n\n# Define layers\n"
        dms = tuple(self.tech.rules.__iter_type__(msk.DesignMask))
        s += "".join(_str_designmask(dm, gds_layers=self.gds_layers) for dm in dms)
        aliases = tuple(self.tech.rules.__iter_type__(msk._MaskAlias))
        s += "".join(_str_alias(alias) for alias in aliases)

        s += "\n# Connectivity\n"
        conns = tuple(self.tech.rules.__iter_type__(msk.Connect))
        s += "".join(_rule_conv(conn) for conn in conns)

        s += "\n# Resistors\n"
        resistors = tuple(self.tech.primitives.__iter_type__(prm.Resistor))
        s += "".join(_str_lvsresistor(res) for res in resistors)

        s += "\n# Diodes\n"
        diodes = tuple(self.tech.primitives.__iter_type__(prm.Diode))
        s += "".join(_str_lvsdiode(self.tech, diode) for diode in diodes)

        s += "\n# Transistors\n"
        mosfets = tuple(self.tech.primitives.__iter_type__(prm.MOSFET))
        s += "".join(_str_lvsmosfet(self.tech, mosfet) for mosfet in mosfets)

        s += "\nnetlist\n"

        return s

    def _ly_tech(self) -> ET.Element:
        """Return the "technology" XML element with name, description, dbu
        and the layer map of the design masks.

        Raises ValueError if a design mask has no entry in `gds_layers`.
        """
        lyt = ET.Element("technology")
        ET.SubElement(lyt, "name").text = self.export_name
        ET.SubElement(lyt, "description").text = (
            f"KLayout generated from {self.tech.name} PDKMaster technology"
        )
        ET.SubElement(lyt, "group")
        ET.SubElement(lyt, "dbu").text = f"{self.tech.dbu}"
        ET.SubElement(lyt, "layer-properties_file")
        ET.SubElement(lyt, "add-other-layers").text = "true"
        ropts = ET.SubElement(lyt, "reader-options")
        roptscom = ET.SubElement(ropts, "common")
        ET.SubElement(roptscom, "create-other-layers").text = "true"
        def s_gds_layer(m: msk.DesignMask) -> str:
            # Format the GDS spec of the mask as "layer/datatype"
            try:
                spec = self.gds_layers[m.name]
            except KeyError: # pragma: no cover
                # The KeyError adds nothing to the message below
                raise ValueError(
                    f"No gds_layer provided for mask '{m.name}'"
                ) from None
            if isinstance(spec, int):
                return f"{spec}/0"
            else:
                assert isinstance(spec, tuple)
                return f"{spec[0]}/{spec[1]}"
        s_map = ";".join(
            f"'{s_gds_layer(mask)} : {mask.name}'"
            for mask in self.tech.designmasks
        )
        ET.SubElement(roptscom, "layer-map").text = f"layer_map({s_map})"
        ET.SubElement(lyt, "writer-options")
        ET.SubElement(lyt, "connectivity")

        return lyt


class _ShapeExporter(dsp.ShapeDispatcher):
    """Converts a geo._Shape object to a KLayout database object.

    With `export_fullshape` set, the set of already converted MultiPartShape
    objects is tracked: a part is converted through its multipartshape and
    `None` is returned for a MultiPartShape that was converted before.
    Without it, parts are converted to their own partshape.
    """
    def __init__(self, *, export_fullshape: bool):
        # None means: no tracking, parts are exported as their own shape
        self._mps_exported: Optional[Set[geo.MultiPartShape]] = (
            set() if export_fullshape else None
        )

    def _pointsshapes(self, shape: geo._Shape) -> Generator[Any, None, None]:
        # Helper to convert the individual pointsshapes
        for sub in shape.pointsshapes:
            converted = self(sub)
            assert not _util.is_iterable(converted), "Internal error: unsupported"
            yield converted

    def _Shape(self, shape: geo._Shape): # pragma: no cover
        raise ValueError(f"Unsupported object of type {shape.__class__.__name__}")

    def Point(self, point: geo.Point) -> "pya.DPoint": # type: ignore
        x, y = point.x, point.y
        return pya.DPoint(x, y)

    def Line(self, line: geo.Line) -> "pya.DPath": # type: ignore
        # We represent a Line by a path with zero width
        start = self.Point(line.point1)
        end = self.Point(line.point2)
        return pya.DPath((start, end), 0.0)

    def Polygon(self, polygon: geo.Polygon, **_) -> "pya.DSimplePolygon": # type: ignore
        # In PDKMaster polygon needs last point to be same as first point;
        # in klayout this is not the case.
        open_points = polygon.points[:-1]
        return pya.DSimplePolygon(tuple(map(self.Point, open_points)))

    def Rect(self, rect: geo.Rect) -> "pya.DBox": # type: ignore
        return pya.DBox(rect.left, rect.bottom, rect.right, rect.top)

    def RectRing(self, rs: geo.RectRing) -> Generator[Any, None, None]:
        # TODO: Can repetition information be retained in KLayout ?
        return self._pointsshapes(rs)

    def MultiPartShape(self, mps: geo.MultiPartShape):
        exported = self._mps_exported
        if exported is not None:
            if mps in exported:
                # Full shape has been converted before
                return None
            exported.add(mps)
        return self(mps.fullshape)

    def MultiPartShape__Part(self, part: geo.MultiPartShape._Part):
        if self._mps_exported is not None:
            return self(part.multipartshape)
        return self(part.partshape)

    def MultiShape(self, ms: geo.MultiShape) -> Generator[Any, None, None]:
        for sub in ms.shapes:
            converted = self(sub)
            if not _util.is_iterable(converted):
                yield converted
            else:
                # Flatten the result of shapes that convert to several objects
                yield from converted

    def RepeatedShape(self, rs: geo.RepeatedShape) -> Generator[Any, None, None]:
        # TODO: Can repetition information be retained in KLayout ?
        return self._pointsshapes(rs)

    # TODO: Does KLayout allow more efficient array representation
    # ArrayShape -> RepeatedShape


class _MaskLayerDict(Dict[msk.DesignMask, int]):
    """Dict from design mask to the value returned by `layout.layer()`.

    An entry is created on first lookup of a mask, using the layer
    specification found under the mask name in `gds_layers`.
    """
    def __init__(self, *, layout: "pya.Layout", gds_layers: GDSLayerSpecDict): # type: ignore
        self._layout = layout
        self._gds_layers = gds_layers

    def __getitem__(self, mask: msk.DesignMask) -> int:
        if mask in self:
            return super().__getitem__(mask)

        # First lookup of this mask: register the layer in the layout
        spec = self._gds_layers[mask.name]
        layer, datatype = spec if isinstance(spec, tuple) else (spec, 0)
        index = self._layout.layer(layer, datatype, mask.name)
        self[mask] = index
        return index


# Lookup from a geo.Rotation value to the (rot, mirr) pair that is passed as
# the first two arguments to pya.DTrans when a cell instance is placed.
_rotation_to_rot_mirr: Dict[geo.Rotation, Tuple[int, bool]] = {
    geo.Rotation.No: (0, False),
    geo.Rotation.R90: (1, False),
    geo.Rotation.R180: (2, False),
    geo.Rotation.R270: (3, False),
    geo.Rotation.MX: (0, True),
    geo.Rotation.MX90:(1, True),
    geo.Rotation.MY: (2, True),
    geo.Rotation.MY90: (3, True),
}


class _LayoutExporter:
    """Export mask shapes, layouts or a whole library to a `pya.Layout`.

    The object keeps the state of the export in progress in its attributes;
    that state is reset at the end of each call so an instance can be reused.
    """
    def __init__(self):
        self._clear()

    def _clear(self):
        """Reset all state of an export."""
        self.layout = None
        self.layerdict = None
        # name -> (PDKMaster cell or None for the anonymous top cell, KLayout cell)
        self.cell_lookup: Dict[str, Tuple[Optional[lbry._Cell], "pya.Cell"]] = {} # type: ignore
        self.cells_todo: Set[lbry._Cell] = set()
        self.cells_done: Set[lbry._Cell] = set()
        self.cell = None
        self.shapeexporter = None
        self.pinmasks: Optional[Tuple[msk._Mask, ...]] = None

    def __call__(self, *,
        obj: Union[geo.MaskShape, geo.MaskShapes, lay._Layout, lbry.Library],
        gds_layers: GDSLayerSpecDict, cell_name: Optional[str],
        merge: bool, add_pin_label: bool, dbu: float=0.001,
    ) -> "pya.Layout": # type: ignore
        """Export `obj` and return the resulting layout.

        Arguments:
            obj: the object to export
            gds_layers: lookup from mask name to GDS layer specification
            cell_name: name of the cell the shapes are put in; has to be
                `None` for a library, defaults to "anon" otherwise.
            merge: if `True` the shapes on each layer of each cell are merged;
                texts on the layer are retained.
            add_pin_label: if `True` a text with the net name is added on
                shapes of pin masks.
            dbu: database unit set on the layout
        """
        # The state is reset in the finally clause so that an exception during
        # export does not leave stale cells behind that would break the next
        # export done with the same object.
        try:
            self.layout = layout = pya.Layout()
            layout.dbu = dbu
            self.layerdict = _MaskLayerDict(layout=layout, gds_layers=gds_layers)
            self.shapeexporter = _ShapeExporter(export_fullshape=True)

            if isinstance(obj, lbry.Library):
                assert cell_name is None
                # A cell will be created in _add() for each cell of the library
            else:
                if cell_name is None:
                    cell_name = "anon"
                self.cell = self._create_layout(cell_name)

            self._add(obj, add_pin_label=add_pin_label, net=None)
            # Cells registered while adding are exported until none is left
            while len(self.cells_todo) > 0:
                cell = self.cells_todo.pop()
                self.cells_done.add(cell)
                self.cell = self.cell_lookup[cell.name][1]
                # Don't reuse ShapeExporter between cells, otherwise MultiPartShapes
                # can be wrongly marked as written when they are not.
                self.shapeexporter = _ShapeExporter(export_fullshape=True)
                self._add(cell.layout, add_pin_label=add_pin_label, net=None)

            if merge:
                for cell in layout.each_cell():
                    for layer_idx in self.layerdict.values():
                        # https://www.klayout.de/forum/discussion/697/merge-all-shapes-of-a-certain-layer-of-a-cell
                        old_shapes = cell.shapes(layer_idx)
                        region = pya.Region(old_shapes)
                        region.merge()

                        new_shapes = pya.Shapes()
                        new_shapes.insert(region)
                        # Copy texts over to avoid they are lost
                        for shape in old_shapes.each():
                            if shape.is_text():
                                new_shapes.insert(shape)

                        cell.shapes(layer_idx).assign(new_shapes)

            return layout
        finally:
            self._clear()

    def _create_layout(self, name: str):
        """Create the top cell for a non-library export and return it."""
        assert self.layout is not None
        assert len(self.cell_lookup) == 0
        cell = self.layout.create_cell(name)
        self.cell_lookup[name] = (None, cell)

        return cell

    def _register_cell(self, cell: lbry._Cell):
        """Return the KLayout cell for `cell`; on first use the KLayout cell
        is created and `cell` is queued for export.

        Raises ValueError if another cell is already registered under the
        same name.
        """
        assert self.layout is not None
        assert self.cell_lookup is not None

        name = cell.name
        if name in self.cell_lookup:
            # Check if no two cells from different libraries with the same name are used
            if self.cell_lookup[name][0] != cell: # pragma: no cover
                raise ValueError(
                    "Export of hierarchy with two cells from different libraries with same name not supported"
                )
        else:
            self.cell_lookup[name] = (cell, self.layout.create_cell(name))
            if cell not in self.cells_done:
                self.cells_todo.add(cell)
        return self.cell_lookup[name][1]

    def _add(self,
        o: Union[geo.MaskShape, geo.MaskShapes, lay._Layout, lbry._Cell, lbry.Library],
        add_pin_label: bool, net: Optional[_net.Net],
    ) -> None:
        """Add `o` to the export; shapes are put in the current cell.

        `net` is the net the shapes belong to; its name is used for the pin
        labels. Raises TypeError for an unsupported type of `o`.
        """
        assert self.shapeexporter is not None
        if isinstance(o, geo.MaskShape):
            assert (self.cell is not None) and (self.layerdict is not None)
            layer = self.layerdict[o.mask]
            shapes = self.cell.shapes(layer)
            exp = self.shapeexporter(o.shape)
            if _util.is_iterable(exp):
                for s in exp:
                    # s is None if MultiPartShape is already added
                    if s is not None:
                        shapes.insert(s)
            else:
                # exp is None if MultiPartShape has already been converted
                if exp is not None:
                    shapes.insert(exp)
            if add_pin_label:
                assert self.pinmasks is not None
                if o.mask in self.pinmasks:
                    assert net is not None
                    # Label position: center of a Rect, otherwise the first point
                    for ps in o.shape.pointsshapes:
                        if isinstance(ps, geo.Rect):
                            point = ps.center
                        else:
                            point = _util.first(ps.points)
                        shapes.insert(pya.DText(net.name, point.x, point.y))
        elif isinstance(o, geo.MaskShapes):
            for ms in o:
                self._add(ms, add_pin_label=add_pin_label, net=net)
        elif isinstance(o, lay._Layout):
            # Collect the masks of the pins of the technology's primitives
            prims = o.fab.tech.primitives
            pinmasks = []
            for prim in prims.__iter_type__(prm._PinAttribute):
                if prim.pin is not None:
                    for pin in prim.pin:
                        pinmasks.append(pin.mask)
            self.pinmasks = tuple(pinmasks)
            for sl in o.sublayouts:
                if isinstance(sl, lay.MaskShapesSubLayout):
                    self._add(sl.shapes, add_pin_label=add_pin_label, net=sl.net)
                elif isinstance(sl, lay._InstanceSubLayout):
                    assert self.cell is not None
                    instcell = self._register_cell(sl.inst.cell)
                    rot, mirr = _rotation_to_rot_mirr[sl.rotation]
                    dtrans = pya.DTrans(rot, mirr, sl.origin.x, sl.origin.y)
                    self.cell.insert(pya.DCellInstArray(instcell.cell_index(), dtrans))
                else: # pragma: no cover
                    raise RuntimeError(
                        "Internal error: unsupported",
                    )
        elif isinstance(o, lbry._Cell):
            assert net is None
            # Register cell to be exported
            self._register_cell(o)
            # Make use of self.cell after this an error
            self.cell = None
        elif isinstance(o, lbry.Library):
            # Initialize list of cells to tape-out, each of the cells will be
            # added in the loop of __call__().
            assert net is None
            self.cells_todo.update(o.cells)
            # Register all cells to be exported
            for lbrycell in o.cells:
                self._register_cell(lbrycell)
            # Make use of self.cell after this an error
            self.cell = None
        else:
            raise TypeError(f"No support for exporting a '{type(o)}' object")


@overload
def export2db(
    obj: geo._Shape, *,
    export_fullshape: Optional[bool]=None,
    add_pin_label: bool=False,
    gds_layers: None=None,
    cell_name: None=None,
    merge: bool=False,
) -> Any:
    ... # pragma: no cover
@overload
def export2db(
    obj: Union[geo.MaskShape, geo.MaskShapes, lay._Layout, lbry.Library], *,
    export_fullshape: None=None,
    add_pin_label: bool=False,
    gds_layers: GDSLayerSpecDict,
    cell_name: Optional[str]=None,
    merge: bool=False,
) -> "pya.Layout": # type: ignore
    ... # pragma: no cover
def export2db(
    obj: Union[geo._Shape, geo.MaskShape, geo.MaskShapes, lay._Layout, lbry.Library], *,
    export_fullshape: Optional[bool]=None, add_pin_label: bool=False,
    gds_layers: Optional[GDSLayerSpecDict]=None,
    cell_name: Optional[str]=None,
    merge: bool=False,
):
    """Export PDKMaster geometry/layout objects to KLayout database objects.

    Arguments:
        obj: This is the object to export. There are two different call types.
            1) a _Shape geometry object without mask provided
            2) a MaskShape geometry or an object that contains maskshapes.
        export_fullshape: only to be specified when obj is a _Shape object.
            If `True` the full MultiPartShape shape will be exported when a
            MultiPartShape._Part object is met. It defaults to `False`
        add_pin_label: If `True` a label will be exported on top of layers that
            are pin layers of one of the technology's primitives.
            It has to be left `False` when `obj` is a _Shape object.
        gds_layers: Has to be specified when `obj` is a MaskShape or a collection of them
            and should not be provided otherwise.
            It contains the lookup table to get the corresponding KLayout layer information
            for PDKMaster _DesignMask objects.
        cell_name: Only to be provided when `obj` is a MaskShape or a collection of them but
            not a Library.
            If specified the name of the cell in which the MaskShape(s) will be exported.
            By default 'anon' will be used.
        merge: Whether to merge the exported shapes or not.
            This argument is not used when `obj` is a _Shape object.

    Returns:
        An equivalent KLayout database object if obj is a _Shape geometry or a KLayout
        Layout object when obj is a MaskShape or a collection of MaskShapes. If obj is
        a Library a Cell will be added for each _Cell in the Library. If there are
        cell instances in the PDKMaster _Layout object these cells will also be exported
        to the output KLayout Layout object as a cell even if it is from another PDKMaster
        Library. An exception will be generated when two instances need to be exported to
        the same cell name but in a different Library.

    Raises:
        AssertionError: if the combination of arguments does not match one of
            the two call types described above.
    """
    # TODO: Provide facility to also export netlist information

    # The argument checks below raise AssertionError explicitly instead of
    # using `assert` statements: `assert` is stripped when Python runs with
    # `-O`, which would let wrong argument combinations through unnoticed.
    # The exception type is kept so existing callers see no difference.
    if isinstance(obj, geo._Shape):
        if (gds_layers is not None) or (cell_name is not None) or add_pin_label:
            raise AssertionError(
                "export2db(): gds_layers, cell_name and add_pin_label may not be"
                " specified when exporting a _Shape object",
            )
        if export_fullshape is None:
            export_fullshape = False
        _exporter = _ShapeExporter(export_fullshape=export_fullshape)
        return _exporter(obj)
    else:
        if export_fullshape is not None:
            raise AssertionError(
                "export2db(): export_fullshape may only be specified when"
                " exporting a _Shape object",
            )
        if gds_layers is None:
            raise AssertionError(
                "export2db(): gds_layers has to be specified when exporting an"
                " object that is not a _Shape object",
            )
        return _LayoutExporter()(
            obj=obj, gds_layers=gds_layers, cell_name=cell_name, merge=merge,
            add_pin_label=add_pin_label,
        )
