from contextlib import ExitStack
from dataclasses import dataclass
import difflib
import os

import library.python.json as lpj

BENCHMARK_UNMATCHED_NODES_FILE_NAME = 'benchmark-unmatched-nodes.txt'
TEST_UNMATCHED_NODES_FILE_NAME = 'test-unmatched-nodes.txt'
UID_ONLY_CHANGES_FILE_NAME = 'uid-only-changes.diff.txt'
SIGNIFICANT_BUT_NO_UID_CHANGES_FILE_NAME = 'significant-but-no-uid-changes.diff.txt'
SIGNIFICANT_CHANGES_FILE_NAME = 'significant-changes.diff.txt'
UID_AND_DEPS_CHANGES_ONLY_FILE_NAME = 'uid-and-deps-changes-only.diff.txt'
INSIGNIFICANT_CHANGES_FILE_NAME = 'insignificant-changes.diff.txt'

SIGNIFICANT_KEYS = {'uid', 'cmds', 'deps', 'outputs', 'env', 'platform', 'requirements', 'tags', 'target_properties'}


@dataclass
class CompareGraphStat:
    benchmark_graph_node_count: int = 0
    test_graph_node_count: int = 0

    benchmark_unmatched_node_count: int = 0
    test_unmatched_node_count: int = 0

    uid_only_changes_count: int = 0
    significant_but_no_uid_changes_count: int = 0
    significant_changes_count: int = 0
    uid_and_deps_only_changes_count: int = 0
    insignificant_changes_count: int = 0

    @property
    def fatal_error_count(self):
        return self.benchmark_unmatched_node_count + \
            self.test_unmatched_node_count + \
            self.uid_only_changes_count + \
            self.significant_but_no_uid_changes_count + \
            self.significant_changes_count

    @property
    def total_error_count(self):
        return self.fatal_error_count + \
            self.uid_and_deps_only_changes_count + \
            self.insignificant_changes_count


def compare_graphs(bm_graph_path, tst_graph_path, dest_dir):
    stat = CompareGraphStat()
    with ExitStack() as stack:
        bm_graph = _load_and_preprocess_graph(bm_graph_path)
        tst_graph = _load_and_preprocess_graph(tst_graph_path)

        bm_unmatched_nodes_path = os.path.join(dest_dir, BENCHMARK_UNMATCHED_NODES_FILE_NAME)
        tst_unmatched_nodes_path = os.path.join(dest_dir, TEST_UNMATCHED_NODES_FILE_NAME)
        uid_only_changes_path = os.path.join(dest_dir, UID_ONLY_CHANGES_FILE_NAME)
        significant_but_no_uid_changes_path = os.path.join(dest_dir, SIGNIFICANT_BUT_NO_UID_CHANGES_FILE_NAME)
        significant_changes_path = os.path.join(dest_dir, SIGNIFICANT_CHANGES_FILE_NAME)
        uid_and_deps_changes_only_path = os.path.join(dest_dir, UID_AND_DEPS_CHANGES_ONLY_FILE_NAME)
        insignificant_changes_path = os.path.join(dest_dir, INSIGNIFICANT_CHANGES_FILE_NAME)

        bm_unmatched_nodes_file = stack.enter_context(open(bm_unmatched_nodes_path, 'w'))
        tst_unmatched_nodes_file = stack.enter_context(open(tst_unmatched_nodes_path, 'w'))
        uid_only_changes_file = stack.enter_context(open(uid_only_changes_path, 'w'))
        significant_but_no_uid_changes_file = stack.enter_context(open(significant_but_no_uid_changes_path, 'w'))
        significant_changes_file = stack.enter_context(open(significant_changes_path, 'w'))
        uid_and_deps_changes_only_file = stack.enter_context(open(uid_and_deps_changes_only_path, 'w'))
        insignificant_changes_file = stack.enter_context(open(insignificant_changes_path, 'w'))

        stat.benchmark_graph_node_count = len(bm_graph.node_by_stats_uid)
        stat.test_graph_node_count = len(tst_graph.node_by_stats_uid)

        differ = difflib.Differ()
        for tst_stats_uid, tst_node in tst_graph.node_by_stats_uid.items():
            if tst_stats_uid in bm_graph.node_by_stats_uid:
                bm_node = bm_graph.node_by_stats_uid[tst_stats_uid]
                diff_keys = _get_diff_keys(bm_node, tst_node)
                if diff_keys:
                    significant_diff_keys = diff_keys & SIGNIFICANT_KEYS
                    if significant_diff_keys:
                        if significant_diff_keys == {'uid'}:
                            diff_file = uid_only_changes_file
                            stat.uid_only_changes_count += 1
                        elif significant_diff_keys == {'uid', 'deps'}:
                            diff_file = uid_and_deps_changes_only_file
                            stat.uid_and_deps_only_changes_count += 1
                        elif 'uid' in significant_diff_keys:
                            diff_file = significant_changes_file
                            stat.significant_changes_count += 1
                        else:
                            diff_file = significant_but_no_uid_changes_file
                            stat.significant_but_no_uid_changes_count += 1
                    else:
                        diff_file = insignificant_changes_file
                        stat.insignificant_changes_count += 1

                    bm_lines = _get_node_as_lines(bm_node)
                    tst_lines = _get_node_as_lines(tst_node)
                    diff = differ.compare(bm_lines, tst_lines)
                    diff_file.writelines(diff)
            else:
                _dump_node(tst_node, tst_unmatched_nodes_file)
                stat.test_unmatched_node_count += 1

        # Dump benchmark unmatched nodes
        if stat.benchmark_graph_node_count > (stat.test_graph_node_count - stat.test_unmatched_node_count):
            for bm_stats_uid, bm_node in bm_graph.node_by_stats_uid.items():
                if bm_stats_uid not in tst_graph.node_by_stats_uid:
                    _dump_node(bm_node, bm_unmatched_nodes_file)
                    stat.benchmark_unmatched_node_count += 1

    return stat


@dataclass
class Graph:
    full_graph: dict  # For future usage (compare context?)
    node_by_stats_uid: dict


def _load_and_preprocess_graph(graph_path):
    with open(graph_path, 'rb') as f:
        graph = lpj.loads(f.read(), intern_keys=True, intern_vals=True)
        node_by_stats_uid = {}
        for node in graph['graph']:
            if node.keys() < {'stats_uid', 'uid', 'deps'}:
                raise Exception(f'Graph {graph_path} has a wrong node with missing one of required keys. The node: {node}')
            # Don't care about inputs
            node.pop('inputs', None)
            # Sort arrays because order is unimportant
            for k in 'outputs', 'deps', 'tags':
                v = node.get(k)
                if v is not None and len(v) > 1:
                    v = sorted(v)
            # Remove insignificant env keys
            env = node.get('env', {})
            for k in list(env.keys()):
                if k.startswith('YA_'):
                    del env[k]
            node_by_stats_uid[node['stats_uid']] = node
    return Graph(full_graph=graph, node_by_stats_uid=node_by_stats_uid)


def _dump_node(node, f):
    lpj.dump(node, f, indent=4)
    f.write('\n')


def _get_node_as_lines(node):
    return (lpj.dumps(node, indent=4, sort_keys=True) + '\n').splitlines(keepends=True)


def _get_diff_keys(left, right):
    result = set()
    if left != right:
        for k in set(left.keys()) | set(right.keys()):
            if left.get(k) != right.get(k):
                result.add(k)
    return result
