import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from pyjjasim.variable_row_array import VarRowArray


"""
Embedded Graph Module
"""

class NotSingleComponentError(Exception):
    """Raised when a graph must be single-component but has disconnected parts."""

class NotPlanarEmbeddingError(Exception):
    """Raised when a graph must be a planar embedding but is not."""

class NonSimpleError(Exception):
    """Raised when a graph contains duplicate edges."""

class SelfLoopError(Exception):
    """Raised when an edge has identical end nodes."""

class NodeNotExistError(Exception):
    """Raised when an edge refers to a node index outside range(N)."""

class EdgeNotExistError(Exception):
    """Raised when a queried edge is not present in the graph."""


class EmbeddedGraph:

    """
    Class for embedded 2D graphs. Can be used to construct faces and check planarity.

    Definitions:
     - l-cycle:
        cycle generated by traversing the embedded graph always taking the leftmost turn at
        every node.
     - face-cycle:
        An l-cycle with positive oriented area (using shoelace formula). If the graph is planar,
        face cycles are the boundaries of faces.

    Requirements:
     - No self-loops
     - Simple (no multigraph)

    """

    def __init__(self, x, y, node1, node2, require_single_component=False,
                 require_planar_embedding=False, _edges_are_sorted=None):
        """
        Construct an embedded graph from node coordinates and edge endpoints.

        Parameters
        ----------
        x, y: (N,) array
            coordinates of nodes of embedded graph
        node1, node2: (E,) int array in range(N)
            endpoint nodes of edges in embedded graph. Nodes are referred to by their index in
            the coordinate arrays.
        require_single_component=False:
            If True, an error is raised if the graph is not single-component
        require_planar_embedding=False:
            If True, an error is raised if the graph is not a planar embedding
        _edges_are_sorted=None:
            Private. If None (or False) the edges are sorted internally. If True, node1 and
            node2 are taken to be already in the internal sorted form and the identity
            permutation without flips is stored. Otherwise it must be a pair
            (edge_permute, edge_flip) describing how the given, already sorted, edges
            map back to the external edge order.

        Raises
        ------
        ValueError
            If x and y, or node1 and node2, differ in size
        NodeNotExistError
            If an edge refers to a node outside range(N)
        NonSimpleError
            If graph is not simple
        SelfLoopError
            if graph contains self-loops
        NotSingleComponentError
            if graph has disconnected components and require_single_component=True
        NotPlanarEmbeddingError
            If graph in not a planar embedding and require_planar_embedding=True
        """

        self.x = np.array(x, dtype=np.double).ravel()
        self.y = np.array(y, dtype=np.double).ravel()
        if len(self.x) != len(self.y):
            raise ValueError("x and y must be same size")
        self.N = len(self.x)
        self.node1, self.node2 = np.array(node1, dtype=int).ravel(), np.array(node2, dtype=int).ravel()
        self._assert_edges_correct_shape()
        self.E = len(self.node1)
        self._assert_edges_contain_existing_nodes()
        self._assert_no_self_loop()

        if _edges_are_sorted is None or _edges_are_sorted is False:
            # _sort_edges() assigns edge_flip, edge_permute and edge_permute_inverse.
            self._sort_edges()
        else:
            if _edges_are_sorted is True:
                # Already sorted input: external order equals internal order.
                self.edge_permute = np.arange(self.E)
                self.edge_flip = np.zeros(self.E, dtype=bool)
            else:
                edge_permute, edge_flip = _edges_are_sorted
                self.edge_permute = np.array(edge_permute, dtype=int).ravel()
                self.edge_flip = np.array(edge_flip, dtype=bool).ravel()
            # The inverse is needed by _de_sort_edges() and remove_edges_by_ids().
            self.edge_permute_inverse = np.argsort(self.edge_permute)

        # Duplicate detection compares neighbouring entries, so it needs sorted edges.
        self._assert_nonsimple()
        self.edge_v_array = self._assign_edge_v_array()

        # quantities are computed and stored only when a method needs them.
        self.F = None
        self.face_permutation = None
        self.faces_v_array = None
        self.face_edges = None
        self.face_nodes = None
        self.face_lengths = None
        self.determinant = None
        self.areas = None
        self.roll_ids = None
        self.boundary_face_indices = None

        if require_single_component:
            self._assert_single_component()
        if require_planar_embedding:
            self._assert_planar_embedding()

    def get_edges(self):
        """
        Return node endpoints of all edges.

        Returns
        -------
        node1, node2: (E,) arrays
            node endpoints of all edges.
        """
        return self._de_sort_edges()

    def get_edge_ids(self, node1, node2):
        """
        Return ids of edges with given endpoint nodes

        Parameters
        ----------
        node1, node2: arrays in range(N)
            endpoints of edges

        Returns
        -------
        edge_ids: array with same size as node1 in range(E)
            ids of edges

        Raises
        ------
        EdgeNotExistError
            If a queried node-pair does not exist.

        """
        node1, node2 = np.array(node1, dtype=int).ravel(), np.array(node2, dtype=int).ravel()
        mask = node1 < node2
        node1, node2 = np.where(mask, node1, node2), np.where(mask, node2, node1)
        if np.any(node1 > self.node1[-1]):
            raise EdgeNotExistError("queried edge that does not exist")
        starts, ends = self.edge_v_array.row_ranges()
        node1_idx = np.searchsorted(self.node1[starts], node1)
        starts, ends = starts[node1_idx], ends[node1_idx]
        while not np.all(node2 <= self.node2[starts]):
            starts[node2 > self.node2[starts]] += 1
            if not np.all(starts < ends):
                raise EdgeNotExistError("queried edge that does not exist")
        if not np.all((self.node1[starts] == node1) & (self.node2[starts] == node2)):
            raise EdgeNotExistError("queried edge that does not exist")
        return self.edge_permute[starts]

    def add_nodes(self, x, y):
        """
        Return a new graph with extra nodes appended after the existing ones.

        The edges of the new graph are taken from the internally stored
        endpoint arrays of this graph.

        Parameters
        ----------
        x, y: arrays
            coordinates of the nodes to be added to the graph

        Returns
        -------
        new_graph: EmbeddedGraph
            new graph with nodes added to it.
        """
        extra_x = np.array(x, dtype=np.double).ravel()
        extra_y = np.array(y, dtype=np.double).ravel()
        all_x = np.append(self.x, extra_x)
        all_y = np.append(self.y, extra_y)
        return EmbeddedGraph(all_x, all_y, self.node1, self.node2,
                             require_single_component=False)

    def add_edges(self, node1, node2, require_single_component=False, require_planar_embedding=False):
        """
        Return a new graph with extra edges appended.

        The existing edges are taken from the internally stored endpoint arrays
        and the new edges are appended behind them.

        Parameters
        ----------
        node1, node2: arrays in range(N)
            endpoints of edges
        require_single_component=False:
            If True; raises error if the resulting graph is not single-component
        require_planar_embedding=False:
            If True; raises error if the resulting graph is not a planar embedding

        Returns
        -------
        new_graph: EmbeddedGraph
            new graph with edges added to it.

        """
        extra_n1 = np.array(node1, dtype=int).ravel()
        extra_n2 = np.array(node2, dtype=int).ravel()
        all_n1 = np.append(self.node1, extra_n1)
        all_n2 = np.append(self.node2, extra_n2)
        return EmbeddedGraph(self.x, self.y, all_n1, all_n2,
                             require_single_component=require_single_component,
                             require_planar_embedding=require_planar_embedding)

    def add_nodes_and_edges(self, x, y, node1, node2, require_single_component=False,
                            require_planar_embedding=False):
        """
        Return a new graph with both extra nodes and extra edges appended.

        Parameters
        ----------
        x, y: arrays
            coordinates of the nodes to be added to the graph
        node1, node2: arrays in range(N + x.size)
            endpoints of edges. The i-th new node must be referred to by index N + i.
        require_single_component=False:
            If True; raises error if the resulting graph is not single-component
        require_planar_embedding=False:
            If True; raises error if the resulting graph is not a planar embedding

        Returns
        -------
        new_graph: EmbeddedGraph
            new graph with nodes and edges added to it.

        """
        all_x = np.append(self.x, np.array(x, dtype=np.double).ravel())
        all_y = np.append(self.y, np.array(y, dtype=np.double).ravel())
        all_n1 = np.append(self.node1, np.array(node1, dtype=int).ravel())
        all_n2 = np.append(self.node2, np.array(node2, dtype=int).ravel())
        return EmbeddedGraph(all_x, all_y, all_n1, all_n2,
                             require_single_component=require_single_component,
                             require_planar_embedding=require_planar_embedding)

    def remove_nodes(self, node_ids, require_single_component=False, require_planar_embedding=False):
        """
        Return a new graph without the given nodes and without any edge touching them.

        The remaining nodes are renumbered consecutively, keeping their relative order.

        Parameters
        ----------
        node_ids: int array in range(N)
            ids of nodes to be removed
        require_single_component=False:
            If True; raises error if the resulting graph is not single-component
        require_planar_embedding=False:
            If True; raises error if the resulting graph is not a planar embedding

        Returns
        -------
        new_graph: EmbeddedGraph
            new graph with nodes removed from it.
        """
        removed = np.array(node_ids, dtype=int).ravel()

        # new_index[i] is the number of kept nodes before node i, or -1 if i is removed.
        kept = np.ones(self.node_count(), dtype=int)
        kept[removed] = 0
        new_index = np.cumsum(kept) - kept
        new_index[removed] = -1

        mapped1 = new_index[self.node1]
        mapped2 = new_index[self.node2]
        surviving = (mapped1 >= 0) & (mapped2 >= 0)
        remaining_x = np.delete(self.x, removed)
        remaining_y = np.delete(self.y, removed)
        return EmbeddedGraph(remaining_x, remaining_y, mapped1[surviving], mapped2[surviving],
                             require_single_component=require_single_component,
                             require_planar_embedding=require_planar_embedding)

    def remove_edges_by_ids(self, edge_ids, require_single_component=False, require_planar_embedding=False):
        """
        Remove edges from graph with edge ids as input.

        The remaining edges are taken from the internally stored (sorted)
        endpoint arrays.

        Parameters
        ----------
        edge_ids: int array in range(E)
            ids of edges to be removed
        require_single_component=False:
            If True; raises error if the resulting graph is not single-component
        require_planar_embedding=False:
            If True; raises error if the resulting graph is not a planar embedding

        Returns
        -------
        new_graph: EmbeddedGraph
            new graph with edges removed from it.
        """
        # Translate external edge ids to positions in the sorted endpoint arrays.
        ids_internal = self.edge_permute_inverse[edge_ids]
        # Let the constructor sort again: the arrays are already sorted, so this is
        # cheap and it leaves the new graph with a complete, consistent edge
        # permutation (passing `_edges_are_sorted=True` made the constructor fail).
        return EmbeddedGraph(self.x, self.y, np.delete(self.node1, ids_internal),
                             np.delete(self.node2, ids_internal),
                             require_single_component=require_single_component,
                             require_planar_embedding=require_planar_embedding)

    def remove_edges(self, node1, node2, require_single_component=False, require_planar_embedding=False):
        """
        Return a new graph without the edges identified by their endpoint nodes.

        Parameters
        ----------
        node1, node2: int arrays in range(N)
            ids of node endpoints of edges to be removed
        require_single_component=False:
            If True; raises error if the resulting graph is not single-component
        require_planar_embedding=False:
            If True; raises error if the resulting graph is not a planar embedding

        Returns
        -------
        new_graph: EmbeddedGraph
            new graph with edges removed from it.

        Raises
        ------
        EdgeNotExistError
            If a queried node-pair does not exist.
        """
        ids_to_remove = self.get_edge_ids(node1, node2)
        return self.remove_edges_by_ids(ids_to_remove,
                                        require_single_component=require_single_component,
                                        require_planar_embedding=require_planar_embedding)

    def coo(self):
        """
        Returns coordinates of nodes

        Returns
        -------
        x, y: (N,) arrays
            coordinates of nodes in graph
        """
        return self.x, self.y

    def node_count(self):
        """
        Return the number of nodes in the graph (abbreviated by N).
        """
        n_nodes = self.N
        return n_nodes

    def edge_count(self):
        """
        Return the number of edges in the graph (abbreviated by E).
        """
        n_edges = self.E
        return n_edges

    def get_l_cycles(self, to_list=True):
        """
        Returns for all l-cycles the nodes it traverses in order.

        Parameters
        ----------
        to_list=False:
            If true, output is in form of list-of-lists, otherwise concatenated array.

        Returns
        -------
        nodes:  list-of-lists or array
             for all l-cycles the nodes it traverses in order.
        lengths: (FR,) int array
            length of every l-cycle
        """
        self._assign_faces()
        self._assign_face_nodes()
        nodes, lengths = self.face_nodes, self.face_lengths
        if to_list:
            return VarRowArray(lengths).to_list(nodes), lengths
        return nodes, lengths

    def get_face_cycles(self, to_list=True):
        """
        Return, for every face-cycle, the nodes it traverses in order. If the graph
        is planar, these enclose individual faces.

        Face-cycles are the l-cycles that are not listed as boundary faces.

        Parameters
        ----------
        to_list=True:
            If true, output is in form of list-of-lists, otherwise concatenated array.

        Returns
        -------
        nodes:  list-of-lists or array
             for all face-cycles the nodes it traverses in order.
        lengths: (F,) int array
            length of every face-cycle
        """
        self._assign_faces()
        self._assign_face_nodes()
        self._assign_boundary_faces()
        boundary = self.boundary_face_indices
        cycle_nodes = self.faces_v_array.delete_rows(self.face_nodes, boundary)
        cycle_lengths = np.delete(self.face_lengths, boundary)
        if not to_list:
            return cycle_nodes, cycle_lengths
        return VarRowArray(cycle_lengths).to_list(cycle_nodes), cycle_lengths

    def face_count(self):
        """
        Returns number of face cycles in graph (abbreviated by F). In a planar
        graph this equals the number of faces.
        """
        self._assign_faces()
        self._assign_boundary_faces()
        return self.F - len(self.get_boundary_faces())

    def l_cycle_count(self):
        """
        Returns number of l-cycles in graph. (abbreviated by FR)
        """
        self._assign_faces()
        return self.F

    def is_planar_embedding(self):
        """
        Returns if graph is planar embedding, which is true if edges only intersect at their endpoints.
        """
        self._assign_faces()
        return self.node_count() + self.face_count() == self.edge_count() + 1

    def get_face_areas(self):
        """
        Returns areas of face-cycles. These correspond to face-areas is the graph
        is planar.

        Returns
        -------
        areas:  (F,) array
            Areas of face-cycles
        """
        self._assign_areas()
        self._assign_boundary_faces()
        return np.delete(self.areas, self.boundary_face_indices)

    def get_l_cycle_areas(self):
        """
        Returns signed areas of l-cycles.

        Returns
        -------
        areas:  (FR,) array
            signed areas of l-cycles
        """
        self._assign_areas()
        return self.areas

    def get_l_cycle_centroids(self):
        """
        Return the centroids of all l-cycles (boundary faces included).

        For cycles whose area is close to zero the centroid is taken as the mean
        of the node coordinates along the cycle; otherwise the determinant-weighted
        sum is divided by six times the signed area.

        Returns
        -------
        x, y:  (FR,) arrays
            coordinates of the centroids of the l-cycles
        """
        denominator = 6.0 * self.get_l_cycle_areas()
        degenerate = np.isclose(denominator, 0.0)
        # Each node appears twice in the summed pairs below, hence 2 * length.
        denominator[degenerate] = 2.0 * self.face_lengths[degenerate]
        degenerate_entries = self.faces_v_array.at_out_index(degenerate)

        pair_x = self.x[self.face_nodes] + self.x[self.roll_ids]
        pair_y = self.y[self.face_nodes] + self.y[self.roll_ids]
        term_x = np.where(degenerate_entries, pair_x, self.determinant * pair_x)
        term_y = np.where(degenerate_entries, pair_y, self.determinant * pair_y)

        centroid_x = self.faces_v_array.sum(term_x) / denominator
        centroid_y = self.faces_v_array.sum(term_y) / denominator
        return centroid_x, centroid_y

    def get_face_centroids(self):
        """
        Return the centroids of the face-cycles, i.e. the l-cycle centroids with
        the boundary faces left out.

        Returns
        -------
        x, y:  (F,) arrays
            coordinates of the centroids of the face-cycles
        """
        all_x, all_y = self.get_l_cycle_centroids()
        self._assign_boundary_faces()
        boundary = self.boundary_face_indices
        face_x = np.delete(all_x, boundary)
        face_y = np.delete(all_y, boundary)
        return face_x, face_y

    def get_num_components(self):
        return scipy.sparse.csgraph.connected_components(
            self.adjacency_matrix(), directed=False, return_labels=False)

    def get_components(self):
        _, components = scipy.sparse.csgraph.connected_components(
            self.adjacency_matrix(), directed=False, return_labels=True)
        return components

    def split_components(self):
        components = self.get_components()
        num_components = np.max(components) + 1
        return [self.remove_nodes(np.flatnonzero(components != c)) for c in range(num_components)]

    def get_boundary_faces(self):
        self._assign_boundary_faces()
        return self.boundary_face_indices

    def l_cycle_matrix(self, _permute=True):
        # TODO
        self._assign_faces()
        return self._cycle_matrix(self.face_edges, self.face_lengths,
                                  self.edge_permute if _permute else np.arange(self.edge_count()))

    def _cycle_matrix(self, face_edges, face_lengths, permute):
        E, F = self.edge_count(), len(face_lengths)
        indptr = np.append([0], np.cumsum(face_lengths))
        indices, data = permute[face_edges % E], 1 - 2 * (face_edges // E)
        return scipy.sparse.csr_matrix((data, indices, indptr), shape=(F, E)).tocsc()

    def face_cycle_matrix(self):
        """
        Return the sparse (F, E) cycle matrix restricted to the rows selected by
        _non_boundary_mask(), with columns following the stored edge permutation.
        """
        self._assign_faces()
        self._assign_boundary_faces()
        selected_entries = self.faces_v_array.get_item(rows=self._non_boundary_mask())
        selected_edges = self.face_edges[selected_entries]
        selected_lengths = self.face_lengths[self._non_boundary_mask()]
        return self._cycle_matrix(selected_edges, selected_lengths, self.edge_permute)

    def cut_space_matrix(self):
        E, N = self.edge_count(), self.node_count()
        row = np.concatenate((self.node1, self.node2))
        col = self.edge_permute[np.concatenate((np.arange(E), np.arange(E)))]
        data = np.concatenate((-np.ones(E), np.ones(E)))
        return scipy.sparse.coo_matrix((data, (row, col)), shape=(N, E)).tocsc()

    def adjacency_matrix(self):
        E, N = self.edge_count(), self.node_count()
        data = np.ones(2 * E)
        row = np.append(self.node1, self.node2)
        col = np.append(self.node2, self.node1)
        return scipy.sparse.coo_matrix((data, (row, col)), shape=(N, N)).tocsc()

    def get_common_edge_of_l_cycles(self, cycle1, cycle2, return_orientation=False):
        """
        Find, for pairs of l-cycles, an edge that occurs in both of them.

        Parameters
        ----------
        cycle1, cycle2: int arrays of equal shape
            indices of l-cycles
        return_orientation=False:
            If True, also return a boolean orientation per pair.

        Returns
        -------
        edge_ids: int array with the shape of cycle1
            id of an edge occurring in both cycles, or -1 if none exists. If
            several exist, the one that comes first in the internal (sorted)
            edge order is returned.
        orientation: bool array with the shape of cycle1 (only if return_orientation)
            orientation flag of the found edge with respect to cycle1; False
            where no common edge exists.
            NOTE(review): the original comment describes this as "True if cycle1
            passes edge counterclockwise first encountering its node with lowest
            index" - confirm against callers.
        """
        cycle1, cycle2 = np.array(cycle1, dtype=int), np.array(cycle2, dtype=int)
        input_shape = cycle1.shape
        cycle1, cycle2 = cycle1.ravel(), cycle2.ravel()
        E = self.edge_count()
        if E == 0:
            # No edges at all: nothing can be shared.
            no_edge = -np.ones(input_shape, dtype=int)
            if return_orientation:
                return no_edge, np.zeros(input_shape, dtype=bool)
            return no_edge

        self._assign_faces()
        # f[d] is the l-cycle that contains directed edge d in range(2*E).
        f = self.faces_v_array.rows()[np.argsort(self.face_edges)].astype(np.int64)
        f1, f2 = f[:E], f[E:2 * E]
        mask = f1 < f2
        f1, f2 = np.where(mask, f1, f2), np.where(mask, f2, f1)

        # Sort the (low cycle, high cycle) pair of every edge; the sort is stable,
        # so equal pairs keep their internal edge order.
        sorter = np.lexsort((f2, f1))
        f1, f2 = f1[sorter], f2[sorter]

        mask2 = cycle1 < cycle2
        fmin = np.where(mask2, cycle1, cycle2).astype(np.int64)
        fmax = np.where(mask2, cycle2, cycle1).astype(np.int64)

        # Encode each pair as one integer so a plain binary search can be used.
        n_cycles = len(self.face_lengths)
        pair_keys = f1 * n_cycles + f2
        query_keys = fmin * n_cycles + fmax
        positions = np.minimum(np.searchsorted(pair_keys, query_keys), E - 1)

        # Comparing the pairs themselves also rejects out-of-range cycle indices.
        not_found = (f1[positions] != fmin) | (f2[positions] != fmax)
        internal_ids = sorter[positions]
        # Map to external edge ids first and only then insert the -1 marker, so
        # that -1 is never used as an index into edge_permute.
        edge_ids = np.where(not_found, -1, self.edge_permute[internal_ids])

        if return_orientation:
            orientation = mask[internal_ids] ^ ~mask2
            orientation[not_found] = False
            return edge_ids.reshape(input_shape), orientation.reshape(input_shape)
        return edge_ids.reshape(input_shape)

    def permute_nodes(self, permutation):
        """
        Permute node order. Because edges refer to nodes by their position,
        this also changes node1 and node2.

        Parameters
        ----------
        permutation: (N,) int array
            permutation of range(N); the node at new position i is the node that
            was at position permutation[i].

        Raises
        ------
        ValueError
            If permutation is not a permutation of range(N).
        """
        permutation = np.array(permutation, dtype=int).ravel()
        if not np.array_equal(np.sort(permutation), np.arange(self.node_count())):
            raise ValueError("invalid permutation")
        self.x = self.x[permutation]
        self.y = self.y[permutation]
        inv_perm = np.argsort(permutation)
        self.node1 = inv_perm[self.node1]
        self.node2 = inv_perm[self.node2]
        # Restore the external edge order and orientation (now with the new node
        # labels) and sort again, because relabelling breaks the sorted order.
        self.node1, self.node2 = self._de_sort_edges()
        self._sort_edges()
        # The row structure over node1 depends on the sorted edges; rebuild it.
        self._assign_edge_v_array()
        if self.face_permutation is None:
            # Faces were never computed, so there is no face order to carry over;
            # dropping the lazily computed quantities is sufficient.
            self._reset_precomputed_quantities()
        else:
            self._recompute_faces()

    def permute_faces(self, permutation):
        """
        Permute face order.
        """
        permutation = np.array(permutation, dtype=int)
        if not np.all(np.sort(permutation) == np.arange(self.l_cycle_count())):
            raise ValueError("invalid permutation")
        self.face_permutation = self.face_permutation[permutation]
        self._recompute_faces()

    def face_dual_graph(self):
        """
        Return a graph whose nodes are the face centroids, with an edge between
        every two distinct faces whose rows of the face cycle matrix have a
        nonzero product.
        """
        cycle_matrix = self.face_cycle_matrix()
        overlap = (cycle_matrix @ cycle_matrix.T).tocoo()
        # Keep each unordered pair once and drop the diagonal.
        upper = overlap.row < overlap.col
        centroid_x, centroid_y = self.get_face_centroids()
        return EmbeddedGraph(centroid_x, centroid_y, overlap.row[upper], overlap.col[upper])

    def locate_faces(self, x, y):
        """
        Get faces whose centroids are closest to queried coordinate.
        Graph must be planar embedding.

        Attributes
        ----------
        x, y: arrays:
            Coordinates at which one wants to locate faces

        Returns
        -------
        face_ids: int array with same size as x in range(Nf)
            ids of located faces
        """
        if not self.is_planar_embedding():
            raise ValueError("Only works for planar embedding")
        locator = scipy.spatial.KDTree(np.stack(self.get_face_centroids(), axis=-1))
        _, face_ids = locator.query(np.stack(np.broadcast_arrays(x, y), axis=-1), k=1)
        return face_ids, locator

    def plot(self, fig=None, ax=None, show_cycles=True, cycles="face_cycles", figsize=[5, 5],
             show_node_ids=False, show_edge_ids=False, show_face_ids=False,
             face_shrink_factor=0.9, markersize=5, linewidth=1):
        """
        Draw the graph with matplotlib.

        Parameters
        ----------
        fig, ax: matplotlib figure and axes, optional
            Where to draw. If both are None a new figure is created with the
            given figsize; if only one is given the other is derived from it.
        show_cycles=True:
            If True, also draw cycles.
        cycles="face_cycles":
            "face_cycles" draws the face-cycles; "l_cycles" draws all l-cycles,
            with boundary cycles drawn unshrunk in blue. Any other value draws
            no cycles.
        figsize=[5, 5]:
            size of a newly created figure
        show_node_ids, show_edge_ids, show_face_ids=False:
            If True, annotate nodes, edges or cycles with their index.
        face_shrink_factor=0.9:
            factor by which drawn (non-boundary) cycles are shrunk towards their centroid
        markersize=5, linewidth=1:
            marker size of the nodes and line width of edges and cycles

        Returns
        -------
        fig, ax:
            the figure and axes that were drawn on (also stored on the instance)
        """
        orange = [1, 0.5, 0.2]
        blue = [0.2, 0.5, 1]

        def _face_line(x, y, n, xcn, ycn, f_shrink):
            # Closed outline of cycle n, contracted towards its centroid (xcn, ycn).
            closed = np.append(n, n[0])
            xp = f_shrink * x[closed] + (1 - f_shrink) * xcn
            yp = f_shrink * y[closed] + (1 - f_shrink) * ycn
            return np.stack((xp, yp), axis=1)

        if fig is None and ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        if ax is None:
            # add_axes() without a rectangle does not create axes; gca() returns
            # the current axes of the figure, creating them when there are none.
            ax = fig.gca()
        if fig is None:
            fig = ax.get_figure()
        self.fig, self.ax = fig, ax

        x, y = self.coo()
        n1, n2 = self.get_edges()
        lines = [((xa, ya), (xb, yb)) for xa, ya, xb, yb in zip(x[n1], y[n1], x[n2], y[n2])]
        lc = LineCollection(lines, colors=[0.5, 0.5, 0.5], linewidths=linewidth)
        self.ax.add_collection(lc)
        self.ax.plot(x, y, color=[0, 0, 0], marker="o", markerfacecolor=[0, 0, 0],
                     linestyle="None", markersize=markersize)

        if show_node_ids:
            for i, (xn, yn) in enumerate(zip(x, y)):
                self.ax.text(xn, yn, str(i))
        if show_edge_ids:
            x_mid, y_mid = 0.5 * (x[n1] + x[n2]), 0.5 * (y[n1] + y[n2])
            for i, (xn, yn) in enumerate(zip(x_mid, y_mid)):
                self.ax.annotate(str(i), (xn, yn), color=[0.3, 0.5, 0.9], ha='center', va='center')
        if show_cycles:
            lines = []
            if cycles == "face_cycles":
                xc, yc = self.get_face_centroids()
                pn, _ = self.get_face_cycles(to_list=True)
                for i, (xcn, ycn, n) in enumerate(zip(xc, yc, pn)):
                    lines.append(_face_line(x, y, n, xcn, ycn, face_shrink_factor))
                    if show_face_ids:
                        self.ax.annotate(str(i), (xcn, ycn), color=orange, ha='center', va='center')
            elif cycles == "l_cycles":
                xc, yc = self.get_l_cycle_centroids()
                pn, _ = self.get_l_cycles(to_list=True)
                b_mask = ~self._non_boundary_mask()
                for i, (xcn, ycn, n) in enumerate(zip(xc, yc, pn)):
                    if b_mask[i]:
                        # Boundary cycles are drawn at full size, directly on the axes.
                        xp, yp = x[n], y[n]
                        self.ax.plot(np.append(xp, xp[0]), np.append(yp, yp[0]),
                                     color=blue, linewidth=linewidth)
                        label_color = blue
                    else:
                        lines.append(_face_line(x, y, n, xcn, ycn, face_shrink_factor))
                        label_color = orange
                    if show_face_ids:
                        self.ax.annotate(str(i), (xcn, ycn), color=label_color, ha='center', va='center')

            lc1 = LineCollection(lines, colors=orange, linewidths=linewidth)
            self.ax.add_collection(lc1)
        return self.fig, self.ax

    def _assert_edges_correct_shape(self):
        if len(self.node1) != len(self.node2):
            raise ValueError("node1 and node2 must be same size")

    def _assert_edges_contain_existing_nodes(self):
        if np.any((self.node1 < 0) | (self.node1 >= self.N) | (self.node2 < 0) | (self.node2 >= self.N)):
            raise NodeNotExistError("node1,2 values must be in range(N)")

    def _assert_single_component(self):
        if self.get_num_components() != 1:
            raise NotSingleComponentError()

    def _assert_planar_embedding(self):
        if not self.is_planar_embedding():
            raise NotPlanarEmbeddingError()

    def _assert_no_self_loop(self):
        if np.any(self.node1 == self.node2):
            raise SelfLoopError("no edge is allowed to have identical end nodes.")

    def _assert_nonsimple(self):
        if np.any((self.node1[:-1] == self.node1[1:]) & (self.node2[:-1] == self.node2[1:])):
            raise NonSimpleError("no duplicate edges allowed.")

    def _assign_edge_v_array(self):
        """
        Build and store the variable-row structure that groups the stored edges
        by their node1 value: one row per run of equal consecutive node1 entries.

        Returns
        -------
        edge_v_array: VarRowArray
            constructed from the length of every run
        """
        n_edges = len(self.node1)
        # A run starts at position 0 and wherever node1 differs from its
        # predecessor. Position 0 is marked explicitly rather than by comparing
        # with the last entry, which would miss it whenever the first and last
        # node1 are equal (e.g. a single edge, or all edges sharing node1).
        starts_run = np.ones(n_edges, dtype=bool)
        starts_run[1:] = self.node1[1:] != self.node1[:-1]
        run_starts = np.flatnonzero(starts_run)
        self.edge_v_array = VarRowArray(np.diff(np.append(run_starts, n_edges)))
        return self.edge_v_array

    def _sort_edges(self):
        self.edge_flip = self.node1 > self.node2
        nn1 = np.where(self.edge_flip, self.node2, self.node1)
        nn2 = np.where(self.edge_flip ,self.node1, self.node2)
        self.edge_permute = np.lexsort((nn2, nn1))
        self.edge_permute_inverse = np.argsort(self.edge_permute)
        self.node1, self.node2 = nn1[self.edge_permute], nn2[self.edge_permute]

    def _de_sort_edges(self):
        nn1, nn2 = self.node1[self.edge_permute_inverse], self.node2[self.edge_permute_inverse]
        return np.where(self.edge_flip, nn2, nn1), np.where(self.edge_flip, nn1, nn2)

    def _reset_precomputed_quantities(self):
        self.F = None
        self.faces_v_array = None
        self.face_edges = None
        self.face_nodes = None
        self.face_lengths = None
        self.determinant = None
        self.areas = None
        self.roll_ids = None
        self.boundary_face_indices = None

    def _assign_faces(self):
        if self.face_edges is None:
            self.face_edges, self.face_lengths = self._construct_l_cycles()
            self.faces_v_array = VarRowArray(self.face_lengths)
            self.F = len(self.face_lengths)
        if self.face_permutation is None:
            self.face_permutation = np.arange(self.F)

    def _assign_face_nodes(self):
        self._assign_faces()
        if self.face_nodes is None:
            e_idx = self.face_edges % self.edge_count()
            self.face_nodes = np.where(self.face_edges < self.edge_count(), self.node1[e_idx], self.node2[e_idx])

    def _assign_roll_ids(self):
        self._assign_face_nodes()
        if self.roll_ids is None:
            self.roll_ids = self.face_nodes[self.faces_v_array.roll(1)]

    def _assign_determinant(self):
        self._assign_roll_ids()
        if self.determinant is None:
            self.determinant = self.x[self.roll_ids] * self.y[self.face_nodes] - self.x[self.face_nodes] * self.y[self.roll_ids]

    def _assign_areas(self):
        self._assign_determinant()
        if self.areas is None:
            self.areas = 0.5 * self.faces_v_array.sum(self.determinant)

    def _assign_boundary_faces(self):
        self._assign_areas()
        if self.boundary_face_indices is None:
            self.boundary_face_indices = np.flatnonzero((self.areas < 0) | np.isclose(self.areas, 0.0))

    def _recompute_faces(self):
        """
        Recompute all face quantities from scratch and reorder them according
        to the stored face permutation.

        If no face permutation has been stored yet, the identity is used.
        """
        self._reset_precomputed_quantities()
        # Recompute everything in construction order. _assign_boundary_faces()
        # pulls in face nodes, roll ids, determinants and areas as well.
        self._assign_faces()
        self._assign_boundary_faces()

        # Read the permutation only now: _assign_faces() initialises it to the
        # identity when it was still None.
        permutation = self.face_permutation
        inv_permutation = np.argsort(permutation)

        self.face_lengths = self.face_lengths[permutation]
        self.face_edges = self.faces_v_array.permute_rows(permutation, self.face_edges)
        self.areas = self.areas[permutation]
        self.determinant = self.faces_v_array.permute_rows(permutation, self.determinant)
        self.roll_ids = self.faces_v_array.permute_rows(permutation, self.roll_ids)
        # A boundary cycle at old position j ends up at position inv_permutation[j].
        self.boundary_face_indices = inv_permutation[self.boundary_face_indices]
        self.face_nodes = self.faces_v_array.permute_rows(permutation, self.face_nodes)
        # Replace the row structure last; the calls above need the old one.
        self.faces_v_array = VarRowArray(self.faces_v_array.counts[permutation])

    def _get_edge_map(self):
        """
        Computes a one-to-one map over directed edges, where the image contains the edge
        which is encountered next when traversing the graph moving counter-clockwise.

        Used to generate faces by repeating indexing: e_(i+1) = map[e_i]

        map: np.arange(2*E) -> sorted_out_edge_directed

        Directed edges are encoded as a single "combined index" in range(2*E) by the
        local helper to_directed(): edge_nr if its direction flag is True, edge_nr + E
        otherwise.

        Output:
        sorted_out_edge_directed    (2*E,) in range(2*E)
        """

        # Fixed traversal direction; selects the roll direction used further below.
        counter_clockwise = True
        edge_count = self.edge_count()
        # Every edge appears twice: once for its node1 and once for its node2.
        ns = np.append(self.node1, self.node2)

        # construct neighbour structure
        # One row per node that has at least one edge; the row length is its degree.
        nodes, count = np.unique(ns, return_counts=True)
        neighbours = VarRowArray(count)
        # Edge number of every (node, incident edge) entry, grouped by node.
        neighbour_edges = np.tile(np.arange(edge_count), 2)[np.argsort(ns)]
        # NOTE(review): assumes VarRowArray.rows() gives the row index of every entry - confirm.
        neighbour_node_self = nodes[neighbours.rows()]
        # The endpoint of the edge that is not the node itself.
        neighbour_node_other = np.where(self.node1[neighbour_edges] == neighbour_node_self,
                                        self.node2[neighbour_edges], self.node1[neighbour_edges])

        # sort neighbour in-dimension by ascending angle
        angles = np.arctan2(self.y[neighbour_node_other] - self.y[neighbour_node_self],
                            self.x[neighbour_node_other] - self.x[neighbour_node_self])
        # arctan2 lies in [-pi, pi], so an offset of 3*pi per node keeps the entries of
        # different nodes apart while ordering by angle within each node.
        angle_arg_sort = np.argsort(3 * np.pi * neighbour_node_self.astype(np.double) + angles)
        neighbour_edges = neighbour_edges[angle_arg_sort]
        neighbour_node_other = neighbour_node_other[angle_arg_sort]

        def to_directed(edge_nr, edge_direction, edge_count):
            # edge_direction True -> edge_nr, False -> edge_nr + edge_count
            return edge_nr + edge_count * (1 - edge_direction.astype(int))

        # create combined-index for input edges of the map
        neighbour_edges_direction = neighbour_node_other < neighbour_node_self
        in_edge_combined_index = to_directed(neighbour_edges, neighbour_edges_direction, edge_count)

        # find the map from every (combined index) edge to the next edge in (counter-)-clockwise direction
        # NOTE(review): presumably roll() shifts entries cyclically within each row - confirm.
        edges_map = neighbours.roll() if counter_clockwise else neighbours.roll(-1)

        # find the combined index of the output edges of the map
        # The direction flag is inverted with respect to the input edges.
        out_edge_directed = to_directed(neighbour_edges, ~neighbour_edges_direction, edge_count)[edges_map]

        # sort the edge-map based on the (combined index of the) input edges
        sorted_out_edge_directed = out_edge_directed[np.argsort(in_edge_combined_index)]

        return sorted_out_edge_directed

    def _construct_l_cycles(self):
        """
        Computes l-cycles of the embedded graph. This is done using repeated iteration over the
        one-to-one map over directed edges computed with _get_edge_map(). The walks start
        simultaneously from the directed edges 0..E-1 (every edge pointing from lowest to
        highest node) to ensure all cycles are found.

        Each step applies e_(i+1) = map[e_i] to all walks at once; a walk terminates when it
        returns to its starting edge. Terminated walks are handed to
        _store_terminated_cycles(), which removes walks that traced the same cycle.

        Output:
        out_cycles          (sum(out_cycle_lengths),) int array
            directed-edge indices of all cycles, concatenated; cycles appear in order of
            non-decreasing length because they are stored as they terminate.
        out_cycle_lengths   (C,) int array
            number of edges in each of the C stored cycles.
        """

        edge_map = self._get_edge_map()
        edge_count = self.edge_count()

        # one column per walk that is still running; row i holds the i-th edge of the walk
        edge_ids = np.arange(edge_count)
        cycles = -np.ones((3, edge_count), dtype=int)
        cycles[0, :] = edge_ids
        current_cycle_length = 1
        out_cycles, out_cycle_lengths = np.zeros(0, dtype=int), np.zeros(0, dtype=int)

        # advance all running walks one edge at a time until every walk has closed
        while len(edge_ids) > 0:
            edge_ids = edge_map[edge_ids]
            is_terminated = edge_ids == cycles[0, :]
            if np.any(is_terminated):
                out_cycle_lengths, out_cycles = self._store_terminated_cycles(
                    cycles[:current_cycle_length, is_terminated].T, out_cycle_lengths, out_cycles)
                # drop the closed walks so only running ones are advanced further
                cycles = cycles[:, ~is_terminated]
                edge_ids = edge_ids[~is_terminated]
            cycles[current_cycle_length, :] = edge_ids
            current_cycle_length += 1
            # double the row capacity once the buffer is full
            if cycles.shape[0] == current_cycle_length:
                cycles = np.append(cycles, -np.ones(cycles.shape, dtype=int), axis=0)
        return out_cycles, out_cycle_lengths

    def _store_terminated_cycles(self, terminated_cycles, out_cycle_lengths, out_cycles):
        """
        Appends a batch of equally long, freshly closed cycles to the output buffers.

        Every row of terminated_cycles is rolled so that its smallest entry comes first;
        rows that then share the same first entry are treated as duplicates and only the
        first such row is kept. The surviving rows end up ordered by that first entry.

        Input:
        terminated_cycles   (cycle_count, cycle_length) int array
        out_cycle_lengths   (C,) int array of lengths stored so far
        out_cycles          flat int array of cycles stored so far

        Output:
        out_cycle_lengths, out_cycles   the input buffers with the new cycles appended
        """
        _, length = terminated_cycles.shape

        # rotate each row so it starts at the position of its minimum
        start = np.argmin(terminated_cycles, axis=1)
        columns = (start[:, None] + np.arange(length)[None, :]) % length
        rolled = np.take_along_axis(terminated_cycles, columns, axis=1)

        # keep one representative per distinct leading entry
        _, keep = np.unique(rolled[:, 0], return_index=True)
        rolled = rolled[keep, :]

        new_lengths = np.full(rolled.shape[0], length, dtype=int)
        return np.append(out_cycle_lengths, new_lengths), np.append(out_cycles, rolled.flatten())

    def _non_boundary_mask(self):
        """
        Returns a boolean mask over all l-cycles that is False at the entries listed in
        self.boundary_face_indices and True everywhere else.

        Calls _assign_boundary_faces() first, before boundary_face_indices is read.

        Output:
        mask    (l_cycle_count,) bool array
        """
        self._assign_boundary_faces()
        is_interior = np.full(self.l_cycle_count(), True)
        is_interior[self.boundary_face_indices] = False
        return is_interior

    def _cycle_space_solve_for_integral_x(self, b):
        """
        Solves the equation: A @ x = b (where A = cycle_matrix, without boundary faces).
        If b is integral (contain only integers), the output array x will also be integral,
        because x is assembled purely from sums of entries of b multiplied by +1 or -1.

        The solution is built on a depth-first spanning tree over the l-cycles, rooted at
        the boundary face: walking the tree from the leaves towards the root, the value
        still owed to each l-cycle is placed on the edge it shares with its parent and
        added to the parent's tally.

        input:  b (..., F)
        output: x (..., E)

        Raises
        ------
        ValueError
            If the graph is not a single-component planar embedding, or if the
            constructed x fails the final check A @ x == b on the non-boundary rows.

        Notes:
            - The equation is underdetermined, so the solution x is not unique.
            - b is converted to a double array; x is returned as double as well.
        """
        self._assign_boundary_faces()
        if (self.get_num_components() != 1) or (not self.is_planar_embedding()):
            raise ValueError("only implemented for single component planar embedding")

        E, F = self.edge_count(), self.face_count()
        b = np.array(b, dtype=np.double)
        b_shape = list(b.shape)     # remembered to restore the leading dimensions of the output
        b = b.reshape(-1, F)        # flatten all leading dimensions into one batch axis
        b_tally = b.copy()

        # insert boundary face in b_tally (as a zero column at its l-cycle index), so the
        # columns of b_tally line up with the l-cycle indices used below
        # NOTE(review): only the first boundary face is used; assumes a single-component
        # graph has exactly one boundary face, giving F + 1 l-cycles -- confirm
        b_idx = self.boundary_face_indices[0]
        b_tally = np.concatenate((b_tally[:, :b_idx], np.zeros((b_tally.shape[0], 1)), b_tally[:, b_idx:]), axis=1)

        # do depth first search, resulting in cur (current node of tree) and prev (parent node of cur)
        # the search starts at the boundary face and runs on the graph with adjacency A @ A.T
        # (presumably non-zero exactly where two l-cycles share an edge -- verify against
        # l_cycle_matrix)
        A = self.l_cycle_matrix()
        cur, predecessor = scipy.sparse.csgraph.depth_first_order(A @ A.T, b_idx)
        prev = predecessor[cur]
        prev[0] = -1                # the root (boundary face) has no parent

        # find map of edge between pair of faces. signs returns +1 if face1 counterclockwise passes resulting edge in its own direction.
        # (juncs holds the index of the edge shared by cur[i] and prev[i]; semantics are those
        # of get_common_edge_of_l_cycles, which is defined elsewhere)
        juncs, orientation_in_face_1 = self.get_common_edge_of_l_cycles(cur, prev, return_orientation=True)
        sgns = (-1 + 2 * orientation_in_face_1.astype(np.double))   # boolean orientation -> -1.0 / +1.0

        # construct x at each edge by passing through the tree in reverse.
        # children are visited before their parent, so a parent's tally already contains
        # everything pushed up from its subtree when it is processed itself
        x = np.zeros((b.shape[0], E), dtype=b.dtype)
        for i in reversed(range(F+1)):
            if prev[i] >= 0:    # skip the root, which has nothing to push upwards
                b_tally[:, prev[i]] += b_tally[:, cur[i]]
                x[:, juncs[i]] += b_tally[:, cur[i]] * sgns[i]
                b_tally[:, cur[i]] = 0

        # check if resulting x solves A @ x == b
        # (the boundary-face row of A is excluded from the comparison)
        b_shape[-1] = E
        mask = self._non_boundary_mask()
        if not np.allclose((A @ x.T)[mask, :], b.T):
            raise ValueError("failed integral solve of cycle space linear problem")

        # return x with the leading dimensions of b restored and last dimension E
        return x.reshape(tuple(b_shape))

class EmbeddedSquareGraph(EmbeddedGraph):

    """
    Embedded graph on a square lattice of count_x by count_y nodes.

    Nodes sit at integer grid positions (column, row), multiplied by x_scale and y_scale,
    and are numbered row by row. Edges connect horizontal neighbours (listed first) and
    vertical neighbours (listed second).
    """

    def __init__(self, count_x, count_y, x_scale=1.0, y_scale=1.0):
        col, row = np.meshgrid(np.arange(count_x), np.arange(count_y))
        node_id = np.arange(count_x * count_y).reshape(count_y, count_x)

        # (start, end) node grids for each family of bonds
        horizontal = (node_id[:, :-1], node_id[:, 1:])
        vertical = (node_id[:-1, :], node_id[1:, :])

        node1 = np.concatenate([horizontal[0].ravel(), vertical[0].ravel()])
        node2 = np.concatenate([horizontal[1].ravel(), vertical[1].ravel()])
        super().__init__(col * x_scale, row * y_scale, node1, node2)


class EmbeddedHoneycombGraph(EmbeddedGraph):

    """
    Embedded graph on a honeycomb lattice built from count_x by count_y unit cells.

    Each unit cell has size 3 by sqrt(3) (before scaling with x_scale and y_scale) and
    contributes four nodes, at offsets (0, 0), (0.5, sqrt(0.75)), (1.5, sqrt(0.75)) and
    (2, 0) from the cell origin. Nodes are numbered sublattice by sublattice, so the k-th
    sublattice uses the indices k * count_x * count_y + cell index.

    After construction two nodes are removed with remove_nodes(): node 0 and the node of
    the fourth sublattice in the last cell of the first row (presumably the dangling
    corner nodes of the lattice -- confirm).
    """

    def __init__(self, count_x, count_y, x_scale=1.0, y_scale=1.0):
        col, row = np.meshgrid(np.arange(count_x), np.arange(count_y))
        origin_x, origin_y = 3.0 * col, np.sqrt(3.0) * row
        rise = np.sqrt(0.75)
        nodes_x = np.vstack((origin_x, origin_x + 0.5, origin_x + 1.5, origin_x + 2)).ravel()
        nodes_y = np.vstack((origin_y, origin_y + rise, origin_y + rise, origin_y)).ravel()

        # node indices of the four sublattices, each laid out on the (count_y, count_x) cell grid
        cell_count = count_x * count_y
        sub_a = np.arange(cell_count).reshape(count_y, count_x)
        sub_b = sub_a + cell_count
        sub_c = sub_a + 2 * cell_count
        sub_d = sub_a + 3 * cell_count

        # (start, end) node grids for each family of bonds
        bonds = [
            (sub_a, sub_b),
            (sub_b[:-1, :], sub_a[1:, :]),
            (sub_b, sub_c),
            (sub_c[:-1, :], sub_d[1:, :]),
            (sub_c, sub_d),
            (sub_d[:, :-1], sub_a[:, 1:]),
        ]
        nodes1 = np.concatenate([start.ravel() for start, _ in bonds])
        nodes2 = np.concatenate([end.ravel() for _, end in bonds])

        remove_node_ids = [0, sub_d[0, -1]]

        super().__init__(nodes_x * x_scale, nodes_y * y_scale, nodes1, nodes2)
        self.remove_nodes(remove_node_ids)


class EmbeddedTriangularGraph(EmbeddedGraph):

    """
    Embedded graph on a triangular lattice built from count_x by count_y unit cells.

    Each unit cell has size 1 by sqrt(3) (before scaling with x_scale and y_scale) and
    contributes two nodes, at offsets (0, 0) and (0.5, sqrt(0.75)) from the cell origin.
    Nodes of the first sublattice use the cell index, nodes of the second sublattice use
    count_x * count_y + cell index.
    """

    def __init__(self, count_x, count_y, x_scale=1.0, y_scale=1.0):
        col, row = np.meshgrid(np.arange(count_x), np.arange(count_y))
        origin_x, origin_y = col, np.sqrt(3.0) * row
        nodes_x = np.vstack((origin_x, origin_x + 0.5))
        nodes_y = np.vstack((origin_y, origin_y + np.sqrt(0.75)))

        # node indices of the two sublattices, each laid out on the (count_y, count_x) cell grid
        cell_count = count_x * count_y
        sub_a = np.arange(cell_count).reshape(count_y, count_x)
        sub_b = sub_a + cell_count

        # (start, end) node grids for each family of bonds
        bonds = [
            (sub_a, sub_b),
            (sub_a[:, :-1], sub_a[:, 1:]),
            (sub_b[:-1, :], sub_a[1:, :]),
            (sub_b[:-1, :-1], sub_a[1:, 1:]),
            (sub_b[:, :-1], sub_b[:, 1:]),
            (sub_a[:, 1:], sub_b[:, :-1]),
        ]
        nodes1 = np.concatenate([start.ravel() for start, _ in bonds])
        nodes2 = np.concatenate([end.ravel() for _, end in bonds])
        super().__init__(nodes_x * x_scale, nodes_y * y_scale, nodes1, nodes2)
