# -*- coding: utf-8 -*-

# Example command to upload this task:
#  ./CompareNewsAnnotatorResponses upload --verbose --owner NEWS --attr ttl=inf --attr task_type=COMPARE_NEWS_ANNOTATOR_RESPONSES --enable-taskbox

import collections
import itertools
import os
import re
import shutil
import difflib
import json
import logging

from sandbox.sdk2.helpers import ProgressMeter
from sandbox import sdk2
import sandbox.common.types.client as ctc
import sandbox.common.types.misc as ctm

from sandbox.projects.common import binary_task
from sandbox.projects.common import utils2
from sandbox.projects.news import resources

from .templates import Templates


class NEWS_ANNOTATOR_RESPONSES_COMPARE_RESULT(sdk2.Resource):
    """
        Report produced by comparing two NEWS_NEWS_DOC_INFO_DUMP resources
        (per-document HTML diffs plus an index.html summary).
    """
    # Plain data: not runnable, never released, usable on any architecture.
    releasable = False
    executable = False
    any_arch = True
    auto_backup = True


# One loaded response: the parsed TNewsDocInfo message (Info) and its raw serialized bytes (Raw).
TResponsePair = collections.namedtuple("TResponsePair", ["Info", "Raw"])


class CompareNewsAnnotatorResponses(binary_task.LastBinaryTaskRelease, sdk2.Task):
    '''
        Compares the outputs of two annotators.
    '''

    class Parameters(sdk2.Task.Parameters):
        responses1 = sdk2.parameters.Resource('annotator responses 1', resource_type=resources.NEWS_NEWS_DOC_INFO_DUMP, required=True)
        responses2 = sdk2.parameters.Resource('annotator responses 2', resource_type=resources.NEWS_NEWS_DOC_INFO_DUMP, required=True)
        ext_params = binary_task.binary_release_parameters(stable=True)
        output_unchanged = sdk2.parameters.Bool('add unchanged documents to report', required=False, default=False)
        output_debug_data = sdk2.parameters.Bool('add debug data from various stages', required=False, default=False)
        ignore_fetch_time = sdk2.parameters.Bool('compare documents base on url only', required=False, default=False)

    class Requirements(sdk2.Task.Requirements):
        client_tags = ctc.Tag.Group.LINUX
        disk_space = 5 * 1024
        cores = 1
        ram = 8 * 1024

        class Caches(sdk2.Requirements.Caches):
            pass

    def _register_out_resource(self):
        """Create the output resource once and remember its id in the task context."""
        if self.Context.out_resource_id is ctm.NotExists:
            self.Context.out_resource_id = NEWS_ANNOTATOR_RESPONSES_COMPARE_RESULT(self, "compare annotator responses", "compare").id

    def on_enqueue(self):
        self._register_out_resource()

    def on_execute(self):
        """Compare the two response dumps and publish an HTML report.

        Documents are grouped by (Url, FetchTime), or by Url only when
        ``ignore_fetch_time`` is set. Documents sharing a key are paired by
        position. Each record is classified as ``equal``, ``approximate_equal``,
        ``diff`` or ``add_del``. An HTML diff is written for every non-identical
        record (and for identical ones when ``output_unchanged`` is set), plus
        an ``index.html`` summary. The outcome is stored in ``Context.has_diff``
        and ``Context.changed_features`` and reported via ``set_info``.
        """
        binary_task.LastBinaryTaskRelease.on_execute(self)

        # Set before importing the compare extension; presumably it affects how
        # the native code (de)serializes counters -- keep this ordering.
        os.environ["DJ_USE_LEGACY_COUNTERS_SERIALIZATION"] = "1"

        import yweb.news.annotator.helpers.compare_news_doc_info.cython.compare as compare

        self._register_out_resource()
        out_resource = sdk2.ResourceData(sdk2.Resource[self.Context.out_resource_id])

        output_unchanged = self.Parameters.output_unchanged
        output_debug_data = self.Parameters.output_debug_data
        ignore_fetch_time = self.Parameters.ignore_fetch_time

        responses1 = self._load_responses(str(sdk2.ResourceData(self.Parameters.responses1).path), 1)
        responses2 = self._load_responses(str(sdk2.ResourceData(self.Parameters.responses2).path), 2)

        responses1_group = self._group_responses(responses1, ignore_fetch_time)
        responses2_group = self._group_responses(responses2, ignore_fetch_time)

        responses_keys = frozenset(responses1_group) | frozenset(responses2_group)

        # The comparison thresholds never change, so build them once for all pairs.
        thresholds = {
            "defaultFloatThreshold": compare.FloatComparisonThreshold(Fraction=1.0e-5, Margin=1.0e-5),
            "erfFloatThreshold": compare.FloatComparisonThreshold(Fraction=5.0e-3, Margin=1.0e-5),
        }

        diff_cnt, add_cnt, del_cnt, total_cnt = 0, 0, 0, 0

        out_resource_path = str(out_resource.path)
        shutil.rmtree(out_resource_path, ignore_errors=True)
        os.makedirs(out_resource_path)

        if output_debug_data:
            for additional_dir in ("responses1", "responses2", "diff_log", "source_info1", "source_info2"):
                os.makedirs(os.path.join(out_resource_path, additional_dir))

        index_items = []
        global_changed_paths = collections.Counter()

        with ProgressMeter("Comparing responses", maxval=len(responses_keys)) as comparison_progress:
            for query_id, key in enumerate(sorted(responses_keys)):
                comparison_progress.add(1)
                doc_url = key[0]
                g1 = responses1_group.get(key, [])
                g2 = responses2_group.get(key, [])
                # izip_longest pads the shorter group with None; both sides are
                # never None at once because every key comes from a non-empty group.
                for sub_id, (response1, response2) in enumerate(itertools.izip_longest(g1, g2)):
                    total_cnt += 1

                    record_class = None
                    cmp_result = None
                    if response1 is None:
                        # Present only in the second dump: render it in full so the
                        # report shows it as added (both text sides must be set here).
                        add_cnt += 1
                        text_response1 = ""
                        text_response2 = self._render_text(compare, response2.Raw, thresholds)
                        equal = False
                        record_class = "add_del"
                    elif response2 is None:
                        # Present only in the first dump: render it as deleted.
                        del_cnt += 1
                        text_response1 = self._render_text(compare, response1.Raw, thresholds)
                        text_response2 = ""
                        equal = False
                        record_class = "add_del"
                    else:
                        cmp_result = compare.CompareNewsDocInfo(response1.Raw, response2.Raw, **thresholds)
                        text_response1 = cmp_result.Message1Text
                        text_response2 = cmp_result.Message2Text
                        equal = cmp_result.Equal
                        if not equal:
                            diff_cnt += 1
                            record_class = "diff"

                    id_path = "{:04d}{}".format(query_id, '' if sub_id == 0 else '_{:02d}'.format(sub_id))

                    if output_debug_data:
                        # Unpaired records have no comparison, hence an empty diff log.
                        diff_text = cmp_result.Diff if cmp_result is not None else u""
                        self._write_debug_data(out_resource_path, id_path, response1, response2,
                                               text_response1, text_response2, diff_text)

                    lines_response1 = self._cut_volatile_data(text_response1.split('\n'))
                    lines_response2 = self._cut_volatile_data(text_response2.split('\n'))

                    identical = (lines_response1 == lines_response2)
                    if equal:
                        assert record_class is None
                        record_class = "equal" if identical else "approximate_equal"
                    assert record_class is not None

                    target_fn_short = None
                    if not identical or output_unchanged:
                        if identical:
                            target_fn_prefix = '_identical_'
                        elif equal:
                            target_fn_prefix = '_approximate-equal_'
                        else:
                            target_fn_prefix = ''
                        target_fn_short = 'z{}{}.html'.format(target_fn_prefix, id_path)
                        # Render the HTML diff only when it is actually written.
                        html, _ = self._make_html_diff(lines_response1, lines_response2)
                        with open(os.path.join(out_resource_path, target_fn_short), 'wb') as f:
                            f.write(html.encode('utf-8'))

                    if cmp_result is not None:
                        changed_paths, joined_changed_paths, path_stat = self._extract_changed_paths(cmp_result.Diff, ignore_indices=True)
                    else:
                        changed_paths, joined_changed_paths, path_stat = [], [], {}
                    global_changed_paths.update(joined_changed_paths)

                    index_items.append({
                        "path": target_fn_short,
                        "changed_paths": changed_paths,
                        "joined_changed_paths": self._shorten_changed_paths(joined_changed_paths, path_stat),
                        "equal": equal,
                        "id_path": id_path,
                        "doc_url": doc_url,
                        "identical": identical,
                        "record_class": record_class,
                    })

        summary = {
            "total_cnt": total_cnt,
            "add_cnt": add_cnt,
            "del_cnt": del_cnt,
            "diff_cnt": diff_cnt,
        }
        stat_html = Templates.index.render(summary=summary, index_items=index_items, most_common_paths=global_changed_paths.most_common(10))
        if isinstance(stat_html, unicode):
            stat_html = stat_html.encode('utf-8')
        with open(os.path.join(out_resource_path, 'index.html'), 'wb') as f:
            f.write(stat_html)

        out_resource.ready()

        has_diff = any((add_cnt, del_cnt, diff_cnt))
        # Sorted so the context value does not depend on set iteration order.
        top_level_changed_paths = sorted(frozenset(path.split('.', 1)[0] for path in global_changed_paths))

        self.Context.has_diff = has_diff
        self.Context.changed_features = (top_level_changed_paths or True) if has_diff else []

        self.set_info("add: {0[add_cnt]}, del: {0[del_cnt]}, diff: {0[diff_cnt]}, total: {0[total_cnt]}".format(summary))
        self.set_info(utils2.resource_redirect_link(self.Context.out_resource_id, title="Comparison results"), do_escape=False)
        self.set_info(utils2.resource_redirect_link(self.Context.out_resource_id, title="Comparison index", path="index.html"), do_escape=False)

    @staticmethod
    def _render_text(compare, raw, thresholds):
        """Return the text rendering of one serialized TNewsDocInfo.

        The document is compared with itself and ``Message1Text`` is taken, so
        unpaired (added/deleted) records use the same text format as paired ones.
        """
        return compare.CompareNewsDocInfo(raw, raw, **thresholds).Message1Text

    @staticmethod
    def _write_debug_data(out_resource_path, id_path, response1, response2, text_response1, text_response2, diff_text):
        """Dump raw inputs, rendered texts and the diff log of one record.

        Files are written into the ``source_info1/2``, ``responses1/2`` and
        ``diff_log`` subdirectories, which must already exist. Raw sources are
        written only for the sides that are present.
        """
        def dump(subdir, extension, data):
            with open(os.path.join(out_resource_path, subdir, '{}.{}'.format(id_path, extension)), "wb") as f:
                f.write(data)

        if response1 is not None:
            dump('source_info1', 'info', response1.Raw)
        if response2 is not None:
            dump('source_info2', 'info', response2.Raw)
        dump('responses1', 'txt', text_response1.encode('utf-8'))
        dump('responses2', 'txt', text_response2.encode('utf-8'))
        dump('diff_log', 'txt', diff_text.encode('utf-8'))

    @staticmethod
    def _shorten_changed_paths(joined_changed_paths, path_stat, limit=8):
        """Select at most ``limit`` changed paths for the index page.

        If there are more than ``limit`` paths, the shallowest ones are kept
        (fewest dots, then shortest, then alphabetical). They are re-sorted
        alphabetically and followed by a literal ``"..."`` entry. Each result
        item is ``{"path": ..., "stat": <HTML counters>}``.
        """
        paths = sorted(joined_changed_paths)
        if len(paths) > limit:
            shallowest = sorted(paths, key=lambda path: (path.count('.'), len(path), path))[:limit]
            paths = sorted(shallowest) + ["..."]
        return [
            {"path": path, "stat": path_stat.get(path, TPathStat()).format(html=True, skip_zero=True)}
            for path in paths
        ]

    @staticmethod
    def _load_responses(path, tag):
        """Read a YSON list-fragment dump and return a list of TResponsePair.

        Each record's ``doc`` field is parsed as TNewsDocInfo. The raw bytes are
        kept alongside it. ``tag`` is only used in the progress label.
        """
        from yweb.news.proto.storage.info_db_pb2 import TNewsDocInfo
        from yt.wrapper import yson
        result = []
        with ProgressMeter("Extracting responses #{}".format(tag)):
            with open(path, 'rb') as f:
                for record in yson.load(f, yson_type="list_fragment"):
                    doc = bytes(record['doc'])
                    info = TNewsDocInfo.FromString(doc)
                    result.append(TResponsePair(info, doc))
        return result

    @staticmethod
    def _group_responses(responses, ignore_fetch_time):
        """Group responses by (Url, FetchTime), or by (Url,) when ``ignore_fetch_time``.

        Returns a dict of key -> list of responses in input order. Its default
        factory is cleared, so missing keys raise KeyError.
        """
        group = collections.defaultdict(list)
        for r in responses:
            key = (r.Info.Url,) if ignore_fetch_time else (r.Info.Url, r.Info.FetchTime)
            group[key].append(r)
        group.default_factory = None
        return group

    @staticmethod
    def _cut_volatile_data(lines):
        """Mask run-dependent values so they do not show up as differences.

        Lines of the form ``<indent>CalcTime: <digits>`` or
        ``<indent>LastUpdateTime: <digits>`` have their value replaced by
        ``<skipped>``. All other lines are returned unchanged.
        """
        volatile_re = re.compile(r'^(\s*(?:CalcTime|LastUpdateTime): )[0-9]+$')
        result = []
        for line in lines:
            m = volatile_re.match(line)
            result.append(m.expand(r'\1<skipped>') if m else line)
        return result

    @staticmethod
    def _calc_distance(e1, e2):
        """Return (l1, l2, l_inf, cos) for two equal-length numeric vectors.

        l1, l2 and l_inf are norms of the element-wise difference. cos is the
        cosine similarity, or NaN when either vector has zero norm. Empty
        vectors yield zero distances.
        """
        def dot(u, v):
            return sum(x * y for (x, y) in zip(u, v))

        d = [(x - y) for (x, y) in zip(e1, e2)]
        l1 = sum(abs(z) for z in d)
        l2 = sum(z ** 2 for z in d) ** 0.5
        li = max(abs(z) for z in d) if d else 0
        norm = (dot(e1, e1) * dot(e2, e2)) ** 0.5
        cosphi = dot(e1, e2) / norm if norm else float('nan')
        return l1, l2, li, cosphi

    @classmethod
    def _enrich_embeddings_diff(cls, diff_lines):
        """Annotate changed embedding values in an ndiff listing.

        The trigger is a ``-`` line immediately followed by a ``+`` line, where
        both carry an ``UnpackedValue...: [...]`` JSON list. In that case a
        ``? `` comment line is inserted after the pair. It holds either the
        length change or the l_1/l_2/l_inf/cos metrics. Pairs whose payload
        cannot be parsed or measured are passed through unchanged.

        :param diff_lines: ndiff output lines (not modified)
        :return: a new list of lines
        """
        removed_re = re.compile(r'^\-.*UnpackedValue.*:\s+(\[.*\])\s*$')
        added_re = re.compile(r'^\+.*UnpackedValue.*:\s+(\[.*\])\s*$')
        # A trailing sentinel lets the one-line lookahead reach the last real line.
        padded = list(diff_lines) + [""]
        res = []
        i = 0
        while i + 1 < len(padded):
            upd_lines = []
            consume_lines = 1
            m1 = removed_re.match(padded[i])
            m2 = added_re.match(padded[i + 1])
            if m1 and m2:
                try:
                    e1 = json.loads(m1.group(1))
                    e2 = json.loads(m2.group(1))
                    if len(e1) != len(e2):
                        upd_lines.append("? Length changed: -{} +{}".format(len(e1), len(e2)))
                    else:
                        metrics = cls._calc_distance(e1, e2)
                        codes = ["l_1", "l_2", "l_inf", "cos"]
                        msg = " ".join("{}={:.15f}".format(code, value) for (code, value) in zip(codes, metrics))
                        upd_lines.append("? " + "      " + msg)
                except (ValueError, TypeError, ZeroDivisionError, OverflowError):
                    # Not a numeric vector after all: leave the lines as they are.
                    upd_lines = []
                else:
                    consume_lines = 2
            res.extend(padded[i:i + consume_lines])
            res.extend(upd_lines)
            i += consume_lines
        return res

    @classmethod
    def _make_html_diff(cls, lines1, lines2):
        """Render a line diff of two texts as HTML.

        Long runs of common lines (at least 2 * context_size + min_hidden) are
        split so the middle part becomes a hidden block.

        :return: (rendered html, dict mapping line class -> line count)
        """
        gen = list(difflib.ndiff(lines1, lines2))

        try:
            gen = cls._enrich_embeddings_diff(gen)
        except Exception:
            # Best effort: enrichment is optional, the plain diff is still useful.
            logging.exception("Failed to enrich diff")

        class_dict = {
            "  ": "common",
            "? ": "comment",
            "- ": "del",
            "+ ": "add",
        }

        classes_stat = collections.defaultdict(int)

        idx = 0
        res = []
        for l in gen:
            action = l[:2]
            text = l[2:].rstrip('\n\r')  # remove trailing newline from comments
            ln_class = class_dict.get(action, "unknown")
            comment = (ln_class == "comment")
            if not comment:
                idx += 1
            res.append({
                "number": "" if comment else idx,
                "action": action,
                "text": text,
                "class": ln_class,
            })
            classes_stat[ln_class] += 1

        context_size = 10
        min_hidden = 10

        blocks = []

        def append_block(lines, hidden=False):
            blocks.append({
                "hidden": hidden,
                "lines": lines,
                "block_id": "block{}".format(len(blocks)),
                "length": len(lines),
                "classes": ("hidden" if hidden else ""),
            })

        for is_common, lines in itertools.groupby(res, key=lambda ln: ln["class"] == "common"):
            lines = list(lines)
            if is_common and len(lines) >= 2 * context_size + min_hidden:
                append_block(lines[:context_size])
                append_block(lines[context_size:-context_size], hidden=True)
                append_block(lines[-context_size:])
            else:
                append_block(lines)

        classes_stat.default_factory = None

        return Templates.diff.render(blocks=blocks), classes_stat

    @staticmethod
    def _extract_changed_paths(diff_text, ignore_indices=False):
        """Collect field paths mentioned in a comparison diff log.

        Lines of the form ``<action>: <path>: ...`` are parsed, with action in
        added/deleted/modified/moved. When ``ignore_indices`` is set,
        ``[N]`` index suffixes are stripped from the paths.

        :return: (dict action -> set of paths,
                  frozenset of all paths,
                  dict path -> TPathStat)
        """
        res = collections.defaultdict(set)
        path_stat = collections.defaultdict(TPathStat)
        for line in diff_text.split('\n'):
            m = re.match(r"(modified|added|deleted|moved):\s+([^:]+):\s", line)
            if m is None:
                continue
            action, path = m.groups()
            if ignore_indices:
                path = re.sub(r"\[[0-9]+\]", "", path)
            res[action].add(path)
            path_stat[path].update(action)
        joined = frozenset(itertools.chain.from_iterable(res.values()))
        return res, joined, path_stat


class TPathStat(object):
    """Per-path counters of diff actions.

    ``actions`` lists the counted action names and ``symbols`` the marker
    shown for each of them when formatting, matched by position. Any action
    that is not in ``actions`` is counted as ``other``.
    """
    symbols = ["+", "-", "⇋", "⇱", "?"]
    actions = ["added", "deleted", "modified", "moved", "other"]

    def __init__(self):
        self.modified = 0
        self.added = 0
        self.deleted = 0
        self.moved = 0
        self.other = 0

    def update(self, action, weight=1):
        """Add ``weight`` to the counter of ``action`` (unknown actions -> ``other``).

        :raises ValueError: if ``weight`` is not positive
        """
        if weight <= 0:
            raise ValueError("weight must be positive, got {!r}".format(weight))
        if action not in self.actions:
            action = 'other'
        setattr(self, action, getattr(self, action) + weight)

    def __str__(self):
        text = self.format(html=False)
        # Python 2 requires __str__ to return a byte string; Python 3 a text string.
        if not isinstance(text, str):
            text = text.encode('utf-8')
        return text

    def format(self, html=True, skip_zero=False):
        """Return the counters as a space-separated text string.

        Each entry is ``<symbol><count>``. With ``html`` set, each entry is wrapped in
        ``<span class="path_stat <action>" title="<action>">``. With
        ``skip_zero`` set, zero counters are omitted.
        """
        res = []
        for sym, action in zip(self.symbols, self.actions):
            n = getattr(self, action)
            if skip_zero and not n:
                continue
            text = "{}{}".format(sym, n)
            if html:
                text = '<span class="path_stat {0}" title="{0}">{1}</span>'.format(action, text)
            res.append(text)
        joined = ' '.join(res)
        # Under Python 2 the symbols are UTF-8 byte strings, so decode to text.
        if isinstance(joined, bytes):
            joined = joined.decode('utf-8')
        return joined
