# -*- coding: utf-8 -*-
import six
from six.moves import range

import base64
import copy
import logging
import math
import re
import sys
import traceback

from sandbox.projects.common.search.response.diff import custom_proto_comparators as cpc
from sandbox.projects.common.search import factors_decompressor
from sandbox.projects.common.differ import printers
from sandbox.projects.common.differ import differ
from sandbox.projects.common import string


class Statement(object):
    """Names of the matching rules used as keys in the skip-fields configuration."""
    MATCH = "match"  # key equals the pattern exactly (see Protodiff._is_stable)
    STARTS = "starts"  # key starts with the pattern
    ENDS = "ends"  # key ends with the pattern
    CONTAINS = "contains"  # pattern is a substring of the key
    SIMPLE = "simple"  # top-level section listing plain fields skipped by compare_simple_field


# Default skip fields.
# Layout: {property name: {Statement rule: [patterns]}}. A key/value entry of the
# given property is excluded from comparison when its key satisfies any rule
# (exact match, prefix, suffix or substring) - see Protodiff._is_stable.
DEFAULT_SKIP_FIELDS = {
    "SearcherProp": {
        Statement.MATCH: [
            # Timings
            "WaitInfo.debug",
            "WaitInfo2.debug",
            "La",

            # Topology
            "BsTouched",
            "BinaryBuildTag",  # SEARCH-2056
            "ReqError.debug",

            # machine-dependent things
            "HostName",
            "Ncpu",
            "Nctx",

            "Acc",
            "Al.Prior",
            "Cached",  # SEARCH-4748

            # Unstable (based on group count, etc)
            "GroupRepresentative.DepletedGroups.debug",
            "CalcMetaRelevanceMcs",  # SEARCH-9579
            "DocumentsStats",  # SEARCH-11681
        ],
        Statement.CONTAINS: [
            # FRESHNESS-1913
            "DaterAge_7_100_",
        ],
        Statement.STARTS: [
            "DebugConfigInfo_",
            "ConfigInfo_",
            "ClientGroup_",
            "Lua",
            "Cached_",
        ],
        Statement.ENDS: [
            "_TimeDistr",
        ],
    },
    "GtaRelatedAttribute": {
        Statement.MATCH: [
            "_SearcherHostname",
            "_MetaSearcherHostname",
            "_CompressedFactors",
            # enable comparing huff compressed factors (TODO: disable if tests become flaky)
            # "_HuffCompressedFactors",
        ]
    },
    "FirstStageAttribute": {
        Statement.MATCH: [
            "_SearcherHostname",
        ],
    }
}


class ProtoDiffException(Exception):
    """Raised when a pair of reports cannot be compared (a report is missing or is compressed)."""
    pass


class IProtoItem(object):
    """
        Base class for everything that Protodiff.cmp_items can compare.

        Keeps a reference to the owning differ in ``Protodiff``; subclasses add a ``Key``
        used for sorting/matching and implement ``cmp_values``.
    """
    # One-element tuple rather than a bare string: same single slot, but consistent with
    # the subclasses and safe against accidental per-character interpretation when edited.
    __slots__ = ("Protodiff",)

    def __init__(self, protodiff):
        """
            :param protodiff: differ instance that owns this item
        """
        self.Protodiff = protodiff

    def cmp_values(self, other):
        """
            Compare this item with ``other`` (an item with the same key).
            :raises NotImplementedError: always, subclasses must override
        """
        raise NotImplementedError("Item should be comparable. Please, define 'cmp_values' method!")


class KeyVal(IProtoItem):
    """
        Comparable key/value pair.

        The key is normalised with ``string.all_to_str``; the value is kept as is.
    """
    __slots__ = ("Key", "Value")

    def __init__(self, protodiff, key, value=None):
        super(KeyVal, self).__init__(protodiff)
        self.Value = value
        self.Key = string.all_to_str(key)

    def cmp_values(self, other):
        """
            Report the difference between the values of two items with the same key.
            A comparator registered for the key in ``cpc.CMP_FUNCTIONS`` takes precedence
            over the plain inequality check.
        """
        comparator = cpc.CMP_FUNCTIONS.get(self.Key)
        if comparator:
            comparator(self.Value, other.Value, self.Protodiff._output_printer)
            return
        if self.Value != other.Value:
            message = "{} -> {}".format(self.Value, other.Value)
            self.Protodiff._output_printer(message, printers.DiffType.CHANGED)

    def __repr__(self):
        return str(self.Value)


class KeyValFloat(KeyVal):
    """
        Key/value pair whose values are compared as floats with an absolute tolerance.
    """
    # Largest absolute difference that is still treated as "no change".
    EPSILON = 0.000001

    def cmp_values(self, other):
        """
            Report a change when the values differ by more than ``EPSILON``.

            A NaN on exactly one side is reported as a change: every comparison with NaN
            is false, so a plain ``abs(a - b) > eps`` check would silently hide it.
            Two NaN values are still treated as equal.

            :return: True if a difference was printed, False otherwise
        """
        old_value = float(self.Value)
        new_value = float(other.Value)
        nan_mismatch = math.isnan(old_value) != math.isnan(new_value)
        if nan_mismatch or math.fabs(old_value - new_value) > self.EPSILON:
            self.Protodiff._output_printer("{} -> {}".format(self.Value, other.Value), printers.DiffType.CHANGED)
            return True
        return False


class Grouping(IProtoItem):
    """
        Comparable wrapper around one grouping of a report, keyed by its Attr and Mode.
    """
    __slots__ = ("Key", "Grouping", "__values_simple")

    def __init__(self, protodiff, grouping):
        super(Grouping, self).__init__(protodiff)
        self.Key = "Grouping: (Attr.Mode)=({}.{})".format(grouping.Attr, grouping.Mode)
        self.Grouping = grouping
        # Plain fields compared directly.
        # "NumDocs" and "NumGroups" are left out on purpose: they are unstable.
        self.__values_simple = ["IsFlat"]

    def cmp_values(self, other):
        """
            Compare simple fields, then the groups, then all documents of both groupings.
        """
        differ_obj = self.Protodiff
        for field_name in self.__values_simple:
            differ_obj.compare_simple_field("Grouping", field_name, self.Grouping, other.Grouping)

        # Both group lists and both document lists are built before any comparison starts
        own_groups = self.get_groups_list()
        other_groups = other.get_groups_list()
        own_docs = self.get_documents(own_groups)
        other_docs = other.get_documents(other_groups)

        differ_obj.cmp_items(own_groups, other_groups)
        differ_obj.cmp_items(own_docs, other_docs)

    def __repr__(self):
        # TODO: improve this (less size, more info!)
        return "{}".format(to_utf(self.Grouping))

    def get_groups_list(self):
        """
            Wrap every group of this grouping into a comparable Group item.
        """
        return [Group(self.Protodiff, raw_group) for raw_group in self.Grouping.Group]

    def get_documents(self, group_list):
        """
            Consider that documents are the part of grouping, not group.
            Allows to eliminate diff when docs are moving to another group with same CategoryName
            SEARCH-1736
        """
        documents = []
        for group_item in group_list:
            for raw_doc in group_item.Group.Document:
                documents.append(Document(self.Protodiff, raw_doc))
        return documents


class Group(IProtoItem):
    """
        Comparable wrapper around one group, keyed by its CategoryName.
    """
    __slots__ = ("Key", "Group", "doc_ids", "__values_simple")

    def __init__(self, protodiff, group):
        super(Group, self).__init__(protodiff)
        self.Key = "Group: CategoryName='{}'".format(group.CategoryName)
        self.Group = group
        self.doc_ids = []
        for doc in group.Document:
            self.doc_ids.append(get_id_from_doc(doc))
        # Plain fields compared directly.
        # "RelevStat" is left out on purpose: unstable (See SEARCH-556)
        self.__values_simple = ["Relevance", "Priority", "InternalPriority"]

    def cmp_values(self, other):
        """
            Compare simple fields and the sets of document ids of two groups.
        """
        for field_name in self.__values_simple:
            self.Protodiff.compare_simple_field("Group", field_name, self.Group, other.Group)
        own_ids = self.get_document_ids()
        other_ids = other.get_document_ids()
        self.Protodiff.cmp_items(own_ids, other_ids)

    def __repr__(self):
        field_lines = "\n".join(
            "{}={}".format(name, getattr(self.Group, name, None)) for name in self.__values_simple
        )
        id_list = ", ".join(str(doc_id) for doc_id in self.doc_ids)
        return "{}\nDoc ids: {}".format(field_lines, id_list)

    def get_document_ids(self):
        """
            Here we consider only adding and removing doc inside of current group by its DocId
        """
        id_items = []
        for doc_id in self.doc_ids:
            id_items.append(KeyVal(self.Protodiff, "DocId: {}".format(doc_id)))
        return id_items


class Document(IProtoItem):
    """
        Comparable wrapper around one document, keyed by the id from get_id_from_doc.
    """
    __slots__ = ("Key", "Doc", "__values_simple")

    def __init__(self, protodiff, doc):
        super(Document, self).__init__(protodiff)
        doc_id = get_id_from_doc(doc)
        self.Key = "Document: DocId='{}'".format(doc_id)
        self.Doc = doc
        # Plain fields compared directly.
        # "InternalPriority" and "SInternalPriority" are left out on purpose (SEARCH-1695)
        self.__values_simple = [
            "Relevance", "Priority",
            "PassageBreaks",
            "IndexGeneration", "SourceTimestamp",
            "DocRankingFactorsSliceBorders",
            "ServerDescr", "SearchInfo",
            "SRelevance", "SPriority",
        ]

    def cmp_values(self, other):
        """
            Compare simple fields, ranking factors, bin factors, first stage attributes
            and archive info of two documents (in this order).
        """
        differ_obj = self.Protodiff
        own_doc = self.Doc
        other_doc = other.Doc
        for field_name in self.__values_simple:
            differ_obj.compare_simple_field("Document", field_name, own_doc, other_doc)
        differ_obj.compare_doc_ranking_factors(own_doc, other_doc)
        differ_obj.compare_field("BinFactor", self.get_bin_factors(), other.get_bin_factors())
        differ_obj.compare_field("FirstStageAttribute", self.get_factors(), other.get_factors())
        differ_obj.cmp_archive(own_doc, other_doc)

    def __repr__(self):
        return "{}".format(to_utf(self.Doc))

    def get_bin_factors(self):
        """
            Wrap every BinFactor entry into a float-comparable item.
        """
        return [KeyValFloat(self.Protodiff, factor.Key, factor.Value) for factor in self.Doc.BinFactor]

    def get_factors(self):
        """
            Return the FirstStageAttribute entries as comparable items (skip-fields applied).
        """
        return self.Protodiff.get_keyval_list(self.Doc, "FirstStageAttribute")


def get_id_from_doc(doc):
    """
        Return the identifier of a document: "<Route>/<DocHash>" when both fields
        are set, otherwise the plain DocId.
    """
    has_route_and_hash = doc.HasField("Route") and doc.HasField("DocHash")
    if not has_route_and_hash:
        return doc.DocId
    return '{}/{}'.format(doc.Route, doc.DocHash)


class Protodiff(differ.DifferBase):
    def __init__(
        self, output_printer,
        skip_fields=None,
        only_complete=False,
        ignore_unanswered=False,
        meta_pb_pack_path=None,
    ):
        """
            Constructor
            :param output_printer: printer passed to the base differ; receives all diff output
            :param skip_fields: custom skip-fields structure; a private copy of
                DEFAULT_SKIP_FIELDS is used when None
            :param only_complete: compare only if DebugInfo.AnswerIsComplete == "true" for both responses
            :param ignore_unanswered: compare only if DebugInfo.BaseSearchNotRespondCount == 0 for both responses
            :param meta_pb_pack_path: optional directory appended to sys.path
                (presumably the location of a downloaded meta_pb2 module, see get_meta_pb2)
        """
        super(Protodiff, self).__init__(output_printer)
        if skip_fields is None:
            # Work on a private copy: set_factors_check_params edits the structure in place,
            # which must never leak into the module-level defaults shared by all instances.
            skip_fields = copy.deepcopy(DEFAULT_SKIP_FIELDS)
        self._skip_fields = skip_fields
        self.only_complete = only_complete
        self.ignore_unanswered = ignore_unanswered
        # Avoid piling up duplicate sys.path entries when many differs are created
        if meta_pb_pack_path and meta_pb_pack_path not in sys.path:
            sys.path.append(meta_pb_pack_path)
        factors_decompressor.FactorsDecompressor.init(["--output-mode", "json"], new_decompressor=True)
        factors_decompressor.PipedFactorsDecompressor.init()
        # Created lazily by compare_doc_ranking_factors, closed in _do_compare_single
        self._piped_factors_decompressor = None

    def get_skip_fields(self):
        """
            Return the skip-fields structure currently in use (the stored object itself, not a copy).
        """
        return self._skip_fields

    def set_skip_fields(self, skip_fields):
        """
            Replace the skip-fields structure and log the new value.
            :param skip_fields: mapping with the same layout as DEFAULT_SKIP_FIELDS
        """
        self._skip_fields = skip_fields
        # Lazy %-style arguments: the structure is rendered only if the record is emitted
        logging.info("Set custom SKIP_FIELDS: %s", self._skip_fields)

    def set_factor_names(self, factor_names):
        """
             Store factor names for both responses in cpc.FACTOR_NAMES
             :param factor_names: tuple of factor names for each response;
                 nothing is changed when it is empty or None
        """
        if not factor_names:
            return
        for response_index in (0, 1):
            cpc.FACTOR_NAMES[response_index] = factor_names[response_index]

    def set_factors_check_params(
        self,
        ignore_diff_in_compressed_all_factors,
        ignore_diff_in_doc_ranking_factors,
        soft_check_factors_indexes,
        soft_check_factors_offsets_with_slices
    ):
        """
            Set factors check params
            :param ignore_diff_in_compressed_all_factors: if true, skip the "_CompressedAllFactors"
                key of GtaRelatedAttribute
            :param ignore_diff_in_doc_ranking_factors: if true, skip the DocRankingFactors
                field of documents
            :param soft_check_factors_indexes: stored in cpc.SOFT_CHECK_FACTORS_INDEXES
            :param soft_check_factors_offsets_with_slices: stored in
                cpc.SOFT_CHECK_FACTORS_OFFSETS_WITH_SLICES
        """
        def add_skip_field(skip_fields, field_group, field_subgroup, field_name):
            names = skip_fields.setdefault(field_group, {}).setdefault(field_subgroup, [])
            # Repeated calls must not grow the list with duplicates
            if field_name not in names:
                names.append(field_name)

        if ignore_diff_in_compressed_all_factors or ignore_diff_in_doc_ranking_factors:
            # Copy-on-write: the current structure may be the module-level defaults or an
            # object owned by the caller; editing it in place would affect other differs.
            self._skip_fields = copy.deepcopy(self._skip_fields)

        if ignore_diff_in_compressed_all_factors:
            add_skip_field(self._skip_fields, "GtaRelatedAttribute", Statement.MATCH, "_CompressedAllFactors")
        if ignore_diff_in_doc_ranking_factors:
            add_skip_field(self._skip_fields, Statement.SIMPLE, "Document", "DocRankingFactors")
        cpc.SOFT_CHECK_FACTORS_INDEXES = soft_check_factors_indexes
        cpc.SOFT_CHECK_FACTORS_OFFSETS_WITH_SLICES = soft_check_factors_offsets_with_slices

    def _do_compare_single(self, report1, report2, query):
        """
        Main function, compares two single responses

        :param report1: response in google.protobuf.message format, or its serialized string
        :param report2: response in google.protobuf.message format, or its serialized string
        :param query: responses are obtained using this query (not used by the comparison itself)
        :return: True if the comparison printed at least one new diff line
        :raises ProtoDiffException: if a report is missing or uses protobuf compression
        """
        diff_lines_before = self._output_printer.get_diff_lines_count()
        if not report1 or not report2:
            raise ProtoDiffException("There should be two reports to compare. ")

        meta_pb2 = get_meta_pb2()

        def as_report(raw):
            # Serialized responses are parsed into TReport, message objects pass through
            if not isinstance(raw, (six.text_type, six.binary_type)):
                return raw
            report_obj = meta_pb2.TReport()
            report_obj.ParseFromString(raw)
            return report_obj

        report1 = as_report(report1)
        report2 = as_report(report2)

        if report1.HasField("CompressedReport") or report2.HasField("CompressedReport"):
            raise ProtoDiffException("Please disable protobuf compression when comparing responses. ")

        self._piped_factors_decompressor = None

        try:
            # check debug info and decide compare or not:
            if self.cmp_debug_info(report1, report2):
                self.compare_head(report1, report2)

                # timings only, no chance to be equal
                # compare_balancing_info(report1, report2)

                self.compare_error_info(report1, report2)

                # unstable, SEARCH-556
                # compare_simple_field("Root", "TotalDocCount", report1, report2)

                self.compare_simple_field("Root", "FormField", report1, report2)

                self.compare_field(
                    "SearcherProp",
                    self.get_keyval_list(report1, "SearcherProp"),
                    self.get_keyval_list(report2, "SearcherProp"),
                )
                self.cmp_items(self.get_groupings_list(report1), self.get_groupings_list(report2))
        finally:
            # The decompressor pipe is opened lazily during document comparison;
            # close it even when a comparison step raises, otherwise it leaks.
            if self._piped_factors_decompressor is not None:
                self._piped_factors_decompressor.close_pipe()
                self._piped_factors_decompressor = None

        return self._output_printer.get_diff_lines_count() > diff_lines_before

    def compare_simple_field(self, field_group, field_name, v1, v2):
        """
            Compare attribute `field_name` of two objects and print "old -> new" on mismatch.
            Nothing is done when the field is listed for `field_group` in the
            Statement.SIMPLE section of the skip-fields. A missing attribute reads as None.
        """
        skipped_names = self._skip_fields.get(Statement.SIMPLE, {}).get(field_group, ())
        if field_name in skipped_names:
            return

        old_value = getattr(v1, field_name, None)
        new_value = getattr(v2, field_name, None)
        if old_value != new_value:
            printer = self._output_printer
            printer.sched(field_name)
            printer("{} -> {}".format(old_value, new_value), printers.DiffType.CHANGED)
            printer.desched()

    # Matches one slice-border token of the form "<name>[<start>;<end>)":
    # group 1 is the name, groups 2 and 3 are the decimal start and end offsets.
    _slice_regexp = re.compile(r'^([^[]*)\[(\d+);(\d+)\)$')

    def _parse_doc_ranking_factors(self, field, slice_borders):
        """
            Decompress ranking factors and label each value with its slice.

            :param field: raw DocRankingFactors value; it is base64-encoded and passed
                to the piped decompressor, which must already be open
            :param slice_borders: space-separated "<name>[<start>;<end>)" tokens,
                either bytes (ASCII) or text; may be empty
            :return: dict {(slice name, index inside the slice): factor value};
                factors not covered by any slice get the name "" and their absolute index
        """
        encoded = base64.b64encode(field)
        parsed = self._piped_factors_decompressor.extract(encoded)
        factors = parsed.split('\t')
        names = [("", index) for index in range(len(factors))]

        # Only bytes need decoding: the caller's default is a text '' which has no
        # .decode() on Python 3, and a text value may also come from the message itself.
        if isinstance(slice_borders, bytes):
            slice_borders = slice_borders.decode('ascii')

        for single_slice in slice_borders.split(' '):
            matched = self._slice_regexp.match(single_slice)
            if not matched:
                continue
            name = matched.group(1)
            start = int(matched.group(2))
            end = int(matched.group(3))
            # Borders may run past the actual number of factors - clamp to what we have
            for index in range(start, min(len(factors), end)):
                names[index] = (name, index - start)
        return dict(zip(names, factors))

    def compare_doc_ranking_factors(self, v1, v2):
        """
            Compare the DocRankingFactors of two documents factor by factor.
            The factors are decompressed only when the raw values differ; the piped
            decompressor is opened on first use. Added, removed and changed factors are
            printed (in this order) under the 'DocRankingFactors' section.
        """
        skipped_doc_fields = self._skip_fields.get(Statement.SIMPLE, {}).get('Document', {})
        if 'DocRankingFactors' in skipped_doc_fields:
            return

        raw1 = getattr(v1, 'DocRankingFactors', None)
        raw2 = getattr(v2, 'DocRankingFactors', None)
        if raw1 == raw2:
            return

        if self._piped_factors_decompressor is None:
            self._piped_factors_decompressor = factors_decompressor.PipedFactorsDecompressor(input_mode="raw")
            self._piped_factors_decompressor.open_pipe()

        factors1 = self._parse_doc_ranking_factors(raw1, getattr(v1, 'DocRankingFactorsSliceBorders', ''))
        factors2 = self._parse_doc_ranking_factors(raw2, getattr(v2, 'DocRankingFactorsSliceBorders', ''))
        old_keys = set(factors1.keys())
        new_keys = set(factors2.keys())

        printer = self._output_printer
        printer.sched('DocRankingFactors')
        for key in sorted(new_keys - old_keys):
            printer.print_item('{}[{}]'.format(*key), factors2[key], printers.DiffType.ADDED)
        for key in sorted(old_keys - new_keys):
            printer.print_item('{}[{}]'.format(*key), factors1[key], printers.DiffType.REMOVED)
        for key in sorted(old_keys & new_keys):
            old_factor = factors1[key]
            new_factor = factors2[key]
            if old_factor != new_factor:
                printer.print_item(
                    '{}[{}]'.format(*key),
                    '{} -> {}'.format(old_factor, new_factor),
                    printers.DiffType.CHANGED,
                )
        printer.desched()

    def compare_field(self, field_name, list_field1, list_field2):
        """
            Compare two lists of items (see cmp_items) inside a printer section named `field_name`.
        """
        self._output_printer.sched(field_name)
        self.cmp_items(list_field1, list_field2)
        self._output_printer.desched()

    def check_field(self, field_name, field1, field2):
        """
            Check that `field_name` is set in both messages.
            When it is set in only one of them, the field is printed as REMOVED
            (set only in the first) or ADDED (set only in the second).
            :return: True only if both messages have the field
        """
        in_first = field1.HasField(field_name)
        in_second = field2.HasField(field_name)
        if in_first and in_second:
            return True

        if in_first:
            text = "{}:\n{}".format(field_name, to_utf(getattr(field1, field_name)))
            self._output_printer(text, printers.DiffType.REMOVED)
        elif in_second:
            text = "{}:\n{}".format(field_name, to_utf(getattr(field2, field_name)))
            self._output_printer(text, printers.DiffType.ADDED)
        return False

    def compare_head(self, report1, report2):
        """
            Compare the Head sections of two reports: selected simple fields
            and the FactorMapping key/value list.
        """
        printer = self._output_printer
        printer.sched("Head")

        if self.check_field("Head", report1, report2):
            # "SegmentId" is intentionally not compared (no need)
            for simple_name in ("Version", "IndexGeneration", "SearchInfo"):
                self.compare_simple_field("Head", simple_name, report1.Head, report2.Head)
            old_mapping = self.get_keyval_list(report1.Head, "FactorMapping")
            new_mapping = self.get_keyval_list(report2.Head, "FactorMapping")
            self.compare_field("FactorMapping", old_mapping, new_mapping)
        printer.desched()

    def compare_balancing_info(self, report1, report2):
        """
            compares BalancingInfo
            as these fields changes every time,
            usually there is no need to compare them
        """
        printer = self._output_printer
        printer.sched("BalancingInfo")

        if self.check_field("BalancingInfo", report1, report2):
            for simple_name in ("Elapsed", "WaitInQueue"):
                self.compare_simple_field(
                    "BalancingInfo", simple_name, report1.BalancingInfo, report2.BalancingInfo
                )

        printer.desched()

    def cmp_debug_info(self, report1, report2):
        """
            Inspect DebugInfo of both reports and decide whether the pair should be compared.

            When the pair is accepted, AnswerIsComplete and BaseSearchNotRespondCount
            are also compared as simple fields.

            return False - do not compare current responses pair
        """
        compare_this_response = True
        self._output_printer.sched("DebugInfo")

        # check_field prints a diff itself when DebugInfo is present in only one report
        if self.check_field("DebugInfo", report1, report2):
            # NOTE: check_field below runs only when the corresponding option is enabled
            # (short-circuit `and`), and it may print a diff as a side effect.
            if (
                self.only_complete and
                self.check_field("AnswerIsComplete", report1.DebugInfo, report2.DebugInfo) and
                (
                    report1.DebugInfo.AnswerIsComplete == "false" or
                    report2.DebugInfo.AnswerIsComplete == "false"
                )
            ):
                # self._output_printer("Answer is not complete")
                compare_this_response = False
            if (
                self.ignore_unanswered and
                self.check_field("BaseSearchNotRespondCount", report1.DebugInfo, report2.DebugInfo) and
                (
                    report1.DebugInfo.BaseSearchNotRespondCount != 0 or
                    report2.DebugInfo.BaseSearchNotRespondCount != 0
                )
            ):
                # self._output_printer("Response has unanswers")
                compare_this_response = False
            if compare_this_response:
                self.compare_simple_field(
                    "DebugInfo",
                    "AnswerIsComplete",
                    report1.DebugInfo,
                    report2.DebugInfo,
                )
                self.compare_simple_field(
                    "DebugInfo",
                    "BaseSearchNotRespondCount",
                    report1.DebugInfo,
                    report2.DebugInfo,
                )
            # flaky fields
            # compare_simple_field("DebugInfo", "BaseSearchCount", report1.DebugInfo, report2.DebugInfo)
            # compare_simple_field("DebugInfo", "NotRespondSourceName", report1.DebugInfo, report2.DebugInfo)
            # compare_simple_field("DebugInfo", "HostChain", report1.DebugInfo, report2.DebugInfo)
            # SEARCH-2056, unstable as revision often changes
            # compare_simple_field("DebugInfo", "BinaryBuildTag", report1.DebugInfo, report2.DebugInfo)
        self._output_printer.desched()
        return compare_this_response

    def compare_error_info(self, report1, report2):
        """
            Compare GotError, Text and Code of the ErrorInfo sections of two reports.
        """
        printer = self._output_printer
        printer.sched("ErrorInfo")

        if self.check_field("ErrorInfo", report1, report2):
            for simple_name in ("GotError", "Text", "Code"):
                self.compare_simple_field("ErrorInfo", simple_name, report1.ErrorInfo, report2.ErrorInfo)

        printer.desched()

    def cmp_archive_info(self, arch1, arch2):
        """
            Compare two ArchiveInfo messages: the simple fields first, then the
            GtaRelatedAttribute and FloatRelatedAttribute key/value lists.
        """
        simple_names = (
            "Title", "Headline", "IndexGeneration",
            "Url", "Size", "Charset", "Mtime",
            "Passage", "PassageAttr",
        )
        for simple_name in simple_names:
            self.compare_simple_field("ArchiveInfo", simple_name, arch1, arch2)

        for attribute_kind in ("GtaRelatedAttribute", "FloatRelatedAttribute"):
            old_items = self.get_keyval_list(arch1, attribute_kind)
            new_items = self.get_keyval_list(arch2, attribute_kind)
            self.compare_field(attribute_kind, old_items, new_items)

    def cmp_archive(self, doc1, doc2):
        """
            Compare the ArchiveInfo of two documents when both of them have it.
        """
        printer = self._output_printer
        printer.sched("ArchiveInfo")

        both_have_archive = self.check_field("ArchiveInfo", doc1, doc2)
        if both_have_archive:
            self.cmp_archive_info(doc1.ArchiveInfo, doc2.ArchiveInfo)

        printer.desched()

    def get_groupings_list(self, report):
        """
            Wrap every grouping of the report into a comparable Grouping item.
        """
        groupings = []
        for raw_grouping in report.Grouping:
            groupings.append(Grouping(self, raw_grouping))
        return groupings

    def get_keyval_list(self, field, prop_name):
        """
            Return the entries of `field.<prop_name>` as KeyVal items,
            leaving out the keys rejected by _is_stable.
        """
        items = []
        for entry in getattr(field, prop_name):
            if self._is_stable(entry.Key, prop_name):
                items.append(KeyVal(self, entry.Key, entry.Value))
        return items

    def _is_stable(self, key, prop_name):
        """
            Tell whether `key` of property `prop_name` should take part in the comparison.
            :return: False if the key satisfies any skip rule configured for the property,
                True otherwise (rule types other than match/starts/ends/contains are ignored)
        """
        key = string.all_to_unicode(key)
        rule_checks = {
            Statement.MATCH: lambda pattern: key == pattern,
            Statement.STARTS: lambda pattern: key.startswith(pattern),
            Statement.ENDS: lambda pattern: key.endswith(pattern),
            Statement.CONTAINS: lambda pattern: pattern in key,
        }
        property_rules = self._skip_fields.get(prop_name, {})
        for mask_type, fields in six.iteritems(property_rules):
            is_hit = rule_checks.get(mask_type)
            if is_hit is None:
                continue
            if any(is_hit(pattern) for pattern in fields):
                return False
        return True

    def cmp_items(self, items1, items2):
        """
            Print diff to self._output_printer
            input:
                two lists of items
                each item must have attributes "cmp_values", and "__repr__"
            Both lists are sorted in place by item Key and then merged: items with equal
            keys are compared, the others are printed as REMOVED (only in the first list)
            or ADDED (only in the second list).
        """
        printer = self._output_printer
        for item_list in (items1, items2):
            item_list.sort(key=lambda entry: entry.Key)

        left_total = len(items1)
        right_total = len(items2)
        left = 0
        right = 0

        # Merge phase: runs while both lists still have unprocessed items
        while left < left_total and right < right_total:
            old_item = items1[left]
            new_item = items2[right]
            if old_item.Key == new_item.Key:
                printer.sched(old_item.Key)
                old_item.cmp_values(new_item)
                printer.desched()
                left += 1
                right += 1
            elif old_item.Key < new_item.Key:
                printer.print_item(old_item.Key, old_item, printers.DiffType.REMOVED)
                left += 1
            else:
                printer.print_item(new_item.Key, new_item, printers.DiffType.ADDED)
                right += 1

        # Tail phase: at most one of the two lists has leftovers
        for leftover in items2[right:]:
            printer.print_item(leftover.Key, leftover, printers.DiffType.ADDED)
        for leftover in items1[left:]:
            printer.print_item(leftover.Key, leftover, printers.DiffType.REMOVED)

    @staticmethod
    def write_html_diff(f_html, diff_list, query, response_index):
        """
            !!! DEPRECATED !!! try to use projects.common.search.response.diff.printers.PrinterToHtml

            Write the diff of one response as nested HTML blocks.
            :param f_html: object with a `write` method that receives the HTML
            :param diff_list: iterable of (is_head, depth, line, color) tuples
            :param query: optional query text shown at the top of the block
            :param response_index: index used in the block title
        """
        f_html.write(printers.HtmlBlock.start("response {}".format(response_index), is_open=False))

        if query is not None:
            f_html.write(printers.HtmlBlock.simple_block("query", query.strip()))

        depth = 0
        flag = False  # using flag to control indentations (depth in single block)
        for is_head, current_depth, line, color in diff_list:
            if is_head:
                flag = True

                # close the blocks that are deeper than the new head
                for _ in range(depth - current_depth):
                    f_html.write(printers.HtmlBlock.end())

                f_html.write(printers.HtmlBlock.start(line))
                depth = current_depth
            else:
                if flag:
                    depth += 1
                    flag = False

                # Reuse the module helper instead of a second copy of the escape decoding
                line = to_utf(line.strip())

                if color == printers.DiffType.CHANGED:
                    # partition never fails: a line without "->" yields an empty new part
                    # instead of an IndexError
                    old_text, _, new_text = line.partition("->")
                    first_line, second_line = printers.match_string_color_changed(
                        old_text.strip(), new_text.strip()
                    )
                    f_html.write(printers.HtmlBlock.colored_data(first_line))
                    f_html.write(printers.HtmlBlock.colored_data(second_line))
                else:
                    f_html.write(printers.HtmlBlock.simple_data(line, text_bg_class=color))

        for _ in range(depth + 1):
            f_html.write(printers.HtmlBlock.end())


def get_meta_pb2():
    """
        Import and return the meta_pb2 module, trying in order:
        search.idl.meta_pb2, a top-level meta_pb2, and finally the copy under
        sandbox.projects.common.base_search_quality.tree.
        An ImportError of the last attempt is not caught.
    """
    try:
        from search.idl import meta_pb2 as bundled_meta_pb2
    except ImportError:
        logging.info("Unable to use meta_pb2 from ya.make. Will try to use downloaded one.")
    else:
        return bundled_meta_pb2

    try:
        import meta_pb2 as downloaded_meta_pb2
    except ImportError:
        logging.info("Unable to use downloaded meta_pb2, use obsolete one: %s", traceback.format_exc())
    else:
        return downloaded_meta_pb2

    from sandbox.projects.common.base_search_quality.tree import meta_pb2 as obsolete_meta_pb2
    return obsolete_meta_pb2


def to_utf(message):
    """
        Return str(message) with backslash escapes (e.g. "\\n", "\\320\\277") turned
        into the characters/bytes they stand for.

        Python 2: the result is a byte string (same as before).
        Python 3: the unescaped bytes are decoded as UTF-8, undecodable bytes are replaced.
    """
    text = str(message)
    if isinstance(text, bytes):
        # Python 2: str is bytes and the 'string_escape' codec is available
        return text.decode('string_escape')
    # Python 3 has no 'string_escape' codec and no str.decode.
    # 'unicode_escape' reads the bytes as Latin-1, so every escaped or raw byte becomes
    # one code point <= 0xFF; encoding back to Latin-1 restores the byte sequence.
    # Escapes above 0xFF (\uXXXX) cannot be bytes and are kept in escaped form.
    raw = text.encode('utf-8').decode('unicode_escape').encode('latin-1', 'backslashreplace')
    return raw.decode('utf-8', 'replace')
