# Copyright 2022 Tier IV, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Function to create NetworkX object from CARET architecture.yaml
"""

from __future__ import annotations
import networkx as nx
import matplotlib.pyplot as plt
import yaml
from dear_ros_node_viewer.logger_factory import LoggerFactory

logger = LoggerFactory.create(__name__)


def quote_name(name: str) -> str:
    """
    Wrap a name in double quotes so pydot accepts it as an identifier.

    See https://github.com/pydot/pydot/issues/258 for why quoting is needed.

    Parameters
    ----------
    name : str
        original name

    Returns
    -------
    modified_name : str
        ``name`` enclosed in '"' characters (inner characters are not escaped)
    """
    quote_char = '"'
    return ''.join((quote_char, name, quote_char))


def parse_all_graph(yml, node_name_list, topic_pub_dict, topic_sub_dict):
    """
    Collect every node and its topic connections from ``yml['nodes']``.

    Parameters
    ----------
    yml : dict
        parsed architecture file; ``yml['nodes']`` is iterated and each entry
        must have ``node_name`` and may have ``publishes`` / ``subscribes``
        lists whose items carry ``topic_name``
    node_name_list : list[str]
        output; quoted node names are appended in file order
    topic_pub_dict : dict[str, list[str]]
        output; topic name -> quoted names of the nodes publishing it
    topic_sub_dict : dict[str, list[str]]
        output; topic name -> quoted names of the nodes subscribing to it
    """
    for entry in yml['nodes']:
        quoted = quote_name(entry['node_name'])
        node_name_list.append(quoted)
        # A node without a 'publishes'/'subscribes' key simply contributes nothing.
        for pub in entry.get('publishes', []):
            topic_pub_dict.setdefault(pub['topic_name'], []).append(quoted)
        for sub in entry.get('subscribes', []):
            topic_sub_dict.setdefault(sub['topic_name'], []).append(quoted)


def _append_unique(topic_dict, topic_name, node_name):
    """Add node_name under topic_name, skipping absent/'UNDEFINED' topics and duplicates."""
    if topic_name is None or topic_name == 'UNDEFINED':
        return
    node_list = topic_dict.setdefault(topic_name, [])
    # The same node can appear in several named paths; list it only once.
    if node_name not in node_list:
        node_list.append(node_name)


def parse_target_path(yml, node_name_list, topic_pub_dict, topic_sub_dict):
    """
    Parse the ``named_paths`` section of an architecture file.

    Every node of every ``node_chain`` is recorded. Its publish/subscribe
    topic (unless 'UNDEFINED') is accumulated into the topic dictionaries.
    Entries are appended rather than replaced, so several paths sharing a
    topic keep all of their publishers and subscribers.

    Parameters
    ----------
    yml : dict
        parsed architecture file; ``named_paths`` is a list of dicts with a
        ``node_chain`` list whose items carry ``node_name``,
        ``publish_topic_name`` and ``subscribe_topic_name``
    node_name_list : list[str]
        output; quoted node names are appended in chain order
    topic_pub_dict : dict[str, list[str]]
        output; topic name -> quoted names of publishing nodes
    topic_sub_dict : dict[str, list[str]]
        output; topic name -> quoted names of subscribing nodes
    """
    # A missing or empty section is reported rather than raising KeyError.
    named_paths = yml.get('named_paths') or []
    if not named_paths:
        logger.warning('named_paths not found')
        return
    for named_path in named_paths:
        for node in named_path['node_chain']:
            node_name = quote_name(node['node_name'])
            node_name_list.append(node_name)
            _append_unique(topic_pub_dict, node.get('publish_topic_name'), node_name)
            _append_unique(topic_sub_dict, node.get('subscribe_topic_name'), node_name)


def make_graph_from_topic_association(topic_pub_dict: dict[str, list[str]],
                                      topic_sub_dict: dict[str, list[str]]):
    """
    Build a directed graph linking each publisher to each subscriber of a topic.

    Parameters
    ----------
    topic_pub_dict : dict[str, list[str]]
        topic name -> names of the nodes publishing it
    topic_sub_dict : dict[str, list[str]]
        topic name -> names of the nodes subscribing to it

    Returns
    -------
    graph : nx.DiGraph
        one edge ``publisher -> subscriber`` per shared topic, with the topic
        stored as the ``label`` edge attribute. Topics nobody subscribes to
        add nothing. A DiGraph holds a single edge per ordered node pair, so
        when one pair shares several topics, the label of the last one wins.
    """
    graph = nx.DiGraph()
    for topic_name, publishers in topic_pub_dict.items():
        if topic_name not in topic_sub_dict:
            # Published but never subscribed: there is no edge to draw.
            continue
        subscribers = topic_sub_dict[topic_name]
        for publisher in publishers:
            for subscriber in subscribers:
                graph.add_edge(publisher, subscriber, label=topic_name)
    return graph


def caret2networkx(filename: str, target_path: str = 'all_graph',
                   ignore_unconnected=True) -> nx.classes.digraph.DiGraph:
    """
    Create NetworkX Graph from architecture.yaml generated by CARET

    Parameters
    ----------
    filename : str
        path to architecture file (e.g. '/home/abc/architecture.yaml')
    target_path : str, default all_graph
        'all_graph': create a graph including all node and path
        'all_targets': create a graph including all paths in named_paths
        path name: create a graph including only the entry of named_paths
        whose ``path_name`` equals this value
    ignore_unconnected : bool, default True
        if True, nodes without any edge are left out of the graph;
        if False, every parsed node is added even when unconnected

    Returns
    -------
    graph : nx.classes.digraph.DiGraph
        NetworkX Graph
    """

    node_name_list: list[str] = []

    # "/topic_0": ["/node_0", ], <- publishers of /topic_0 are ["/node_0", ] #
    topic_pub_dict: dict[str, list[str]] = {}

    # "/topic_0": ["/node_1", ], <- subscribers of /topic_0 are ["/node_1", ] #
    topic_sub_dict: dict[str, list[str]] = {}

    with open(filename, encoding='UTF-8') as file:
        yml = yaml.safe_load(file)

    if target_path == 'all_graph':
        parse_all_graph(yml, node_name_list, topic_pub_dict, topic_sub_dict)
    else:
        if target_path != 'all_targets':
            # Narrow named_paths to the requested path; without this, any
            # path name silently produced the graph of all named paths.
            named_paths = yml.get('named_paths') or []
            selected = [path for path in named_paths
                        if path.get('path_name') == target_path]
            if not selected:
                logger.warning('path_name %s not found in named_paths', target_path)
            yml['named_paths'] = selected
        parse_target_path(yml, node_name_list, topic_pub_dict, topic_sub_dict)

    graph = make_graph_from_topic_association(topic_pub_dict, topic_sub_dict)

    if not ignore_unconnected:
        graph.add_nodes_from(node_name_list)

    logger.info('len(connected_nodes) = %d, len(nodes) = %d',
                len(graph.nodes), len(node_name_list))

    return graph


if __name__ == '__main__':
    def local_main():
        """
        Draw the graph of an architecture file with matplotlib.

        The file path is taken from the first command-line argument;
        'architecture.yaml' in the current directory is used when omitted.
        """
        import sys  # local import: only needed when run as a script
        filename = sys.argv[1] if len(sys.argv) > 1 else 'architecture.yaml'
        graph = caret2networkx(filename)
        pos = nx.spring_layout(graph)
        nx.draw_networkx(graph, pos)
        plt.show()

    local_main()
