import logging

from collections import defaultdict

import networkx as nx
import yt.wrapper as yt
from crypta.graph.export.proto.graph_pb2 import TGraph
from crypta.graph.export.lib.python.edge_weight_helper import get_edge_weight, cleanup_edges_weight

# Module-level logger; log() below writes to it at INFO level.
logger = logging.getLogger(__name__)
# Debug switch read by log(): nothing is logged while this stays False.
print_enabled = False


def log(msg, value=None):
    """Write a debug message to the module logger if ``print_enabled`` is set.

    Args:
        msg: message prefix.
        value: optional object whose ``str()`` is appended to ``msg``.
            Only ``None`` means "no value": falsy values such as ``[]``,
            ``0`` or an empty set are still appended, so that an empty
            result is visible in the log instead of silently disappearing.
    """
    if not print_enabled:
        return
    if value is not None:
        msg += str(value)
    logger.info(msg)


def is_indevice_edge(edge):
    """Tell whether ``edge`` carries an attribute whose ``Name`` is "indevice"."""
    for attribute in edge.Attributes:
        if attribute.Name == "indevice":
            return True
    return False


class GraphSpanningTreeTransformer(object):
    """Reduce a graph to the spanning-tree edges that connect active vertices.

    ``transform`` builds a spanning tree (forest) over the graph edges that
    prefers heavier edges, prunes branches ending in non-active vertices and
    additionally keeps every edge that has an "indevice" attribute.
    """

    # True: the result contains exactly the selected edges.
    # False: the result contains every edge whose both ends are touched by a
    # selected edge.
    filter_by_edges = True
    # True (supported only together with filter_by_edges): nothing is dropped,
    # the selected edges just get the attribute "spanning" = "y".
    only_mark_edges = False

    # Vertex types that count as active even when absent from active_vertices.
    _ALWAYS_ACTIVE_TYPES = frozenset(["email", "phone", "email_md5", "phone_md5"])

    def transform(self, proto_graph, active_vertices):
        """Select the edges worth keeping and build the resulting graph.

        Args:
            proto_graph: graph message with ``Nodes`` (each having ``Id`` and
                ``Type``) and ``Edges`` (``Node1``/``Node2`` are indexes into
                ``Nodes``).
            active_vertices: container of ``(Id, Type)`` pairs; matching nodes
                are never pruned.

        Returns:
            ``proto_graph`` itself with marked edges when ``only_mark_edges``
            is set, otherwise a new ``TGraph`` with re-indexed nodes.

        Raises:
            Exception: if ``only_mark_edges`` is set while ``filter_by_edges``
                is False.
        """
        nx_graph = self._to_nx_graph(proto_graph)
        spanning_edges = [
            (v1, v2) for v1, v2, _ in nx.algorithms.tree.mst.minimum_spanning_edges(nx_graph)
        ]
        log("spanning edges: ", spanning_edges)

        active_indexes = [
            idx
            for idx, vertex in enumerate(proto_graph.Nodes)
            if vertex.Type in self._ALWAYS_ACTIVE_TYPES or (vertex.Id, vertex.Type) in active_vertices
        ]
        log("active indexes: ", active_indexes)

        keep_spanning_edges = self._filter_corner_edges(spanning_edges, active_indexes)
        log("keep spanning edges ", keep_spanning_edges)
        keep_important_edges = list(self._find_important_edges(proto_graph))
        log("keep important edges ", keep_important_edges)

        keep_edges = set(keep_spanning_edges + keep_important_edges)
        log("keep edges ", keep_edges)

        if self.filter_by_edges:
            if self.only_mark_edges:
                return self._mark_graph_edges(proto_graph, keep_edges, "spanning", "y")
            return self._filter_graph_by_edges(proto_graph, keep_edges)

        if self.only_mark_edges:
            raise Exception("only_mark_edges is not supported while filter_by_edges=False")
        keep_vertices = {vertex for pair in keep_edges for vertex in pair}
        return self._filter_graph_by_vertices(proto_graph, keep_vertices)

    @staticmethod
    def _mark_graph_edges(proto_graph, mark_edges, attribute_key, attribute_value):
        """Add an attribute to every edge listed in ``mark_edges``.

        An edge matches when its ``(Node1, Node2)`` pair is listed in either
        direction. ``proto_graph`` is modified in place and returned.
        """
        mark_edges = set(mark_edges)

        for edge in proto_graph.Edges:
            if (edge.Node1, edge.Node2) in mark_edges or (edge.Node2, edge.Node1) in mark_edges:
                log("adding attribute to ", edge)
                edge.Attributes.add(Name=attribute_key, Value=attribute_value)

        return proto_graph

    @staticmethod
    def _filter_corner_edges(edges, keep_vertices):
        """Drop edges that only lead to vertices outside ``keep_vertices``.

        Vertices with at most one remaining edge that are not in
        ``keep_vertices`` are removed repeatedly until none is left.

        Args:
            edges: iterable of ``(v1, v2)`` pairs.
            keep_vertices: iterable of vertices that must never be removed.

        Returns:
            The edges, in input order, whose both ends survived the pruning.
        """
        edges = list(edges)
        # Membership is tested inside loops below; a set keeps each test O(1)
        # (the caller passes a list) and lets any iterable be accepted.
        keep_vertices = set(keep_vertices)

        adjacent_vertices = defaultdict(set)
        vertices_degrees = defaultdict(int)
        for v1, v2 in edges:
            adjacent_vertices[v1].add(v2)
            adjacent_vertices[v2].add(v1)
            vertices_degrees[v1] += 1
            vertices_degrees[v2] += 1

        remove_queue = [
            vertex
            for vertex, degree in vertices_degrees.items()
            if degree <= 1 and vertex not in keep_vertices
        ]

        removable_vertices = set()
        while remove_queue:
            log("remove_queue ", remove_queue)

            vertex = remove_queue.pop()
            removable_vertices.add(vertex)

            # Removing a vertex lowers the degree of its neighbours, which may
            # turn them into removable leaves as well.
            for target_vertex in adjacent_vertices[vertex]:
                vertices_degrees[target_vertex] -= 1

                if (
                    vertices_degrees[target_vertex] <= 1
                    and target_vertex not in keep_vertices
                    and target_vertex not in removable_vertices
                ):
                    remove_queue.append(target_vertex)

        return [(v1, v2) for v1, v2 in edges if v1 not in removable_vertices and v2 not in removable_vertices]

    @staticmethod
    def _find_important_edges(proto_graph):
        """Yield ``(Node1, Node2)`` for every edge that has an "indevice" attribute."""
        for edge in proto_graph.Edges:
            if is_indevice_edge(edge):
                yield (edge.Node1, edge.Node2)

    @staticmethod
    def _select_nodes(proto_graph, keep_vertices):
        """Pick the nodes whose index is in ``keep_vertices``.

        Returns:
            ``(new_nodes, new_indexes)``: the kept nodes in their original
            order and a dict mapping each old index to its position in
            ``new_nodes``.
        """
        new_nodes = []
        new_indexes = {}
        for old_index, node in enumerate(proto_graph.Nodes):
            if old_index in keep_vertices:
                new_indexes[old_index] = len(new_nodes)
                new_nodes.append(node)
        return new_nodes, new_indexes

    @staticmethod
    def _filter_graph_by_edges(proto_graph, keep_edges):
        """Build a new ``TGraph`` containing only ``keep_edges`` and their nodes.

        Among parallel edges between the same pair of nodes only the first one
        in priority order (indevice edges first, then by descending weight)
        is kept.

        NOTE: the kept edge messages of ``proto_graph`` get their ``Node1`` /
        ``Node2`` rewritten in place to the new node indexes.
        """
        keep_edges = set(keep_edges)
        keep_vertices = {vertex for pair in keep_edges for vertex in pair}
        new_nodes, new_indexes = GraphSpanningTreeTransformer._select_nodes(proto_graph, keep_vertices)

        def edge_priority_key(edge):
            return (is_indevice_edge(edge), get_edge_weight(proto_graph, edge))

        new_edges = []
        seen_edges = set()
        for edge in sorted(proto_graph.Edges, key=edge_priority_key, reverse=True):
            pair = (edge.Node1, edge.Node2)
            reversed_pair = (edge.Node2, edge.Node1)
            if pair not in keep_edges and reversed_pair not in keep_edges:
                continue
            if pair in seen_edges or reversed_pair in seen_edges:
                continue

            # Remember the pair under the old indexes before renumbering.
            seen_edges.add(pair)
            edge.Node1 = new_indexes[pair[0]]
            edge.Node2 = new_indexes[pair[1]]
            new_edges.append(edge)

        return TGraph(
            CryptaId=proto_graph.CryptaId, Nodes=new_nodes, Edges=new_edges, Attributes=proto_graph.Attributes
        )

    @staticmethod
    def _filter_graph_by_vertices(proto_graph, keep_vertices):
        """Build a new ``TGraph`` with ``keep_vertices`` and all edges between them.

        NOTE: the kept edge messages of ``proto_graph`` get their ``Node1`` /
        ``Node2`` rewritten in place to the new node indexes.
        """
        new_nodes, new_indexes = GraphSpanningTreeTransformer._select_nodes(proto_graph, keep_vertices)

        new_edges = []
        for edge in proto_graph.Edges:
            if edge.Node1 in keep_vertices and edge.Node2 in keep_vertices:
                edge.Node1 = new_indexes[edge.Node1]
                edge.Node2 = new_indexes[edge.Node2]
                new_edges.append(edge)

        return TGraph(
            CryptaId=proto_graph.CryptaId, Nodes=new_nodes, Edges=new_edges, Attributes=proto_graph.Attributes
        )

    @staticmethod
    def _to_nx_graph(proto_graph):
        """Convert the edges of ``proto_graph`` to an undirected networkx graph.

        Edge weights are negated so that a minimum spanning tree over the
        result prefers the edges with the largest original weight.
        """
        g = nx.Graph()

        weighted_edges = [(get_edge_weight(proto_graph, edge), edge) for edge in proto_graph.Edges]
        # Sort by weight only: comparing whole (weight, edge) tuples would fall
        # back to comparing the edge messages on equal weights, and those are
        # not orderable.
        weighted_edges.sort(key=lambda pair: pair[0])

        # Ascending order: a later add_edge for the same node pair overwrites
        # the weight, so the strongest of the parallel edges wins.
        for weight, edge in weighted_edges:
            g.add_edge(edge.Node1, edge.Node2, weight=-weight)

        return g


def proto_reducer(max_component_size=0):
    """Create a YT reducer that replaces each graph with its spanning-tree version.

    The reducer takes the serialized graph from the "graph" field of records
    of input table 0 and collects ``(id, id_type)`` pairs from records of all
    other tables as active vertices. Nothing is emitted when there is no graph
    or the transformed graph has no edges.

    Args:
        max_component_size: graphs with more nodes than this go to output
            table 1, the rest to output table 0. With the default of 0 every
            graph that has at least one node goes to table 1.

    Returns:
        The reducer function, wrapped with ``yt.with_context``.
    """

    @yt.with_context
    def _proto_reducer(crypta_id_key, recs, context):
        serialized_graph = None
        active_vertices = set()

        for row in recs:
            # The table index has to be read per record.
            if context.table_index != 0:
                active_vertices.add((row["id"], row["id_type"]))
            else:
                serialized_graph = row.get("graph")

        if not serialized_graph:
            return

        parsed_graph = TGraph()
        parsed_graph.ParseFromString(serialized_graph)

        reduced_graph = GraphSpanningTreeTransformer().transform(parsed_graph, active_vertices)
        reduced_graph = cleanup_edges_weight(reduced_graph)

        if not len(reduced_graph.Edges) > 0:
            return

        out_row = {
            "cryptaId": crypta_id_key["cryptaId"],
            "graph": reduced_graph.SerializeToString(),
        }

        if len(reduced_graph.Nodes) > max_component_size:
            target_table = 1
        else:
            target_table = 0

        yield yt.create_table_switch(target_table)
        yield out_row

    return _proto_reducer
