from __future__ import print_function, unicode_literals

import itertools as it
import logging
import typing
from six.moves import zip, zip_longest

from sandbox import sdk2
from sandbox.common import enum
from sandbox.common import fs
from sandbox.common import rest
from sandbox.projects.common import binary_task
from sandbox.projects.common import error_handlers as eh
from sandbox.projects.common import file_utils as fu
from sandbox.projects.common import task_env
from sandbox.projects.common import templates
from sandbox.projects.common.differ import printers
from sandbox.projects.common.search import bugbanner2 as bb2
from sandbox.projects.common.search import requester
from sandbox.projects.common.search.response.diff import protodiff
from sandbox.projects import resource_types


class DifferentQueriesBehaviour(enum.Enum):
    # What CompareSearchResponses does when the two responses were produced from
    # queries resources with different content.
    # Declaration order matters: the task parameter uses the first member as its default.
    enum.Enum.preserve_order()

    FAIL_ON_DIFFERENT = None  # fail the task
    COMPARE_AND_SHOW_DIFFERENT = None  # compare as usual, but print both queries in title
    MARK_AS_NO_DIFF = None  # skip comparison and report the task as having no diff


class CheckQueriesResult(enum.Enum):
    # Outcome of CompareSearchResponses._get_queries: either proceed with comparison
    # or stop and mark the task as having no diff.
    enum.Enum.preserve_order()

    CHECK_OK = None  # queries are usable; proceed with the comparison
    MARK_AS_NO_DIFF = None  # queries differ and the configured behaviour is MARK_AS_NO_DIFF


class CompareSearchResponses(binary_task.LastBinaryTaskRelease, bb2.BugBannerTask):
    """
    Compare two SEARCH_PROTO_RESPONSES resources pair-by-pair with protodiff.

    The html diff is written into a BASESEARCH_RESPONSES_COMPARE_RESULT resource created
    in on_enqueue. The diff count, the overall compare result and the first 300 relevant
    diff heads are published as output parameters.
    """

    class Requirements(task_env.TinyRequirements):
        cores = 4
        disk_space = 10 * 1024  # 10 Gb
        ram = 31 * 1024

    class Parameters(sdk2.Task.Parameters):
        _task_binary = binary_task.binary_release_parameters(stable=True)
        responses1 = sdk2.parameters.Resource(
            "Binary responses #1", resource_type=resource_types.SEARCH_PROTO_RESPONSES, required=True
        )
        responses2 = sdk2.parameters.Resource(
            "Binary responses #2", resource_type=resource_types.SEARCH_PROTO_RESPONSES, required=True
        )
        with sdk2.parameters.Group("Diff options") as diff_options:
            write_compact_diff = sdk2.parameters.Bool("Write diff in compact form", default_value=True)
            skip_fields = sdk2.parameters.Resource(
                "Custom SKIP_FIELDS for protodiff",
                resource_type=resource_types.CUSTOM_PROTODIFF_SKIP_FIELDS,
                required=False,
            )
            fail_if_no_queries = sdk2.parameters.Bool("Fail if no queries", default_value=True)
            with sdk2.parameters.String("Different queries behaviour") as different_queries_behaviour:
                # The first enum member becomes the default value.
                for i, opt in enumerate(DifferentQueriesBehaviour):
                    different_queries_behaviour.values[opt] = different_queries_behaviour.Value(
                        value=opt,
                        default=(i == 0),
                    )
            compare_complete_responses_only = sdk2.parameters.Bool("Compare complete responses only")
            ignore_unanswered = sdk2.parameters.Bool("Compare fully answered only")
            ignore_diff_in_compressed_all_factors = sdk2.parameters.Bool("Do not compare _CompressedAllFactors field")
            ignore_diff_in_doc_ranking_factors = sdk2.parameters.Bool("Do not compare DocRankingFactors field")
            relev_factors_for_soft_check = sdk2.parameters.List(
                "Use soft difference check for given factors in _RelevFactors field",
                sdk2.parameters.Integer,
                default=[450, 451, 1861, 1862, 1920],
            )
            relev_factors_with_slices_for_soft_check = sdk2.parameters.JSON(
                "Use soft difference check for given factors with slices in _RelevFactors field",
                default={"web_production": [450, 451, 1861, 1862, 1920]},
            )

        with sdk2.parameters.Output:
            diff_resource = sdk2.parameters.Resource("Html diff")
            diff_count = sdk2.parameters.Integer("Diff count")
            compare_result = sdk2.parameters.Bool("Compare result")
            relevant_diff = sdk2.parameters.List("Relevant diff (first 300)")

    def on_enqueue(self):
        # Create the output resource up front so on_finish/on_break always have something to fill.
        self.Parameters.diff_resource = resource_types.BASESEARCH_RESPONSES_COMPARE_RESULT(
            self, "{}, html diff".format(self.Parameters.description), "diff"
        ).id

    @staticmethod
    def _same_content(res1, res2):
        """
        Return True if two resources are known to hold identical data.

        Equal ids are trivially identical. Otherwise md5 must be present AND equal, so two
        resources whose checksum is missing/empty are not mistaken for identical ones.
        """
        return res1.id == res2.id or (bool(res1.md5) and res1.md5 == res2.md5)

    def on_execute(self):
        """Compare both responses resources and publish diff output parameters."""
        path_to_diffs = str(sdk2.ResourceData(self.Parameters.diff_resource).path)
        fs.make_folder(path_to_diffs)
        resp1, resp2 = sdk2.Resource[self.Parameters.responses1], sdk2.Resource[self.Parameters.responses2]
        if self._same_content(resp1, resp2):
            logging.info("Responses are identical, skip comparison")
            # Context flag kept for backward compatibility; output parameters were
            # previously left unset on this path.
            self.Context.compare_result = True
            self._mark_as_no_diff()
            return
        logging.info("Compare responses of size %s vs %s", resp1.size, resp2.size)
        check_result, queries = self._get_queries(resp1, resp2)
        if check_result == CheckQueriesResult.MARK_AS_NO_DIFF:
            logging.info("found different queries, mark this task as no diff")
            self._mark_as_no_diff()
            return
        eh.ensure(check_result == CheckQueriesResult.CHECK_OK, "unexpected check_result={}".format(check_result))
        # (response1, response2, query) triples; zip stops at the shortest sequence.
        cmp_data_bundle = zip(
            requester.sequence_binary_data(str(sdk2.ResourceData(resp1).path)),
            requester.sequence_binary_data(str(sdk2.ResourceData(resp2).path)),
            queries,
        )
        printer = printers.PrinterToHtml(
            path_to_diffs,
            write_compact_diff=self.Parameters.write_compact_diff,
            pair_head_template="response {obj_index}",
            obj_head_template="query"
        )
        differ = protodiff.Protodiff(
            printer,
            only_complete=self.Parameters.compare_complete_responses_only,
            ignore_unanswered=self.Parameters.ignore_unanswered,
        )
        factor_names = get_and_check_factor_names(resp1, resp2)
        differ.set_factor_names(factor_names)
        if self.Parameters.skip_fields:
            skip_fields_file = str(sdk2.ResourceData(sdk2.Resource[self.Parameters.skip_fields]).path)
            differ.set_skip_fields(fu.yaml_load(skip_fields_file))
        differ.set_factors_check_params(
            self.Parameters.ignore_diff_in_compressed_all_factors,
            self.Parameters.ignore_diff_in_doc_ranking_factors,
            set(self.Parameters.relev_factors_for_soft_check),
            self.Parameters.relev_factors_with_slices_for_soft_check
        )
        differ.compare_pairs(cmp_data_bundle)

        diff_count = differ.get_diff_count()
        self.Parameters.diff_count = diff_count
        self.Parameters.compare_result = diff_count == 0
        self.Parameters.relevant_diff = find_relevant_diff(printer.get_compact_diff())

    def _get_queries(self, resp1, resp2):
        # type: (sdk2.Resource, sdk2.Resource) -> typing.Tuple[CheckQueriesResult, typing.Optional[typing.Iterable]]
        """
        Find the queries each responses resource was produced from.

        The queries resource id is read from the producing task's `requests` input
        parameter or its `queries_resource_id` context field.

        :return: (CHECK_OK, iterable of query titles) or (MARK_AS_NO_DIFF, None)
        :raises: task failure (via eh.check_failed) when no queries are found and
            fail_if_no_queries is set, when queries differ under FAIL_ON_DIFFERENT,
            or when different_queries_behaviour has an unknown value
        """
        sb_client = rest.Client()
        input_key1 = "input_parameters.requests"
        input_key2 = "context.queries_resource_id"
        q_data = sb_client.task[{
            "fields": [input_key1, input_key2],
            "hidden": True,
            "children": True,
            "limit": 2,
            "id": [resp1.task_id, resp2.task_id],
        }]["items"]
        queries_ids = [i[input_key1] or i[input_key2] for i in q_data]
        queries_resources = [sdk2.Resource[i] for i in queries_ids if i]
        logging.info("Got queries resources: %s", queries_resources)

        if not queries_resources:
            if self.Parameters.fail_if_no_queries:
                eh.check_failed("No queries found")
            logging.info("No queries found, use fake ones")
            return CheckQueriesResult.CHECK_OK, it.repeat("")
        if len(queries_resources) == 1:
            logging.info("Found only one query resource. Use it")
            return CheckQueriesResult.CHECK_OK, fu.read_line_by_line(sdk2.ResourceData(queries_resources[0]).path)

        q1, q2 = queries_resources
        if self._same_content(q1, q2):
            logging.info("Queries for responses are same")
            return CheckQueriesResult.CHECK_OK, fu.read_line_by_line(sdk2.ResourceData(q1).path)

        behaviour = self.Parameters.different_queries_behaviour
        if behaviour == DifferentQueriesBehaviour.FAIL_ON_DIFFERENT:
            eh.check_failed("Queries for responses are different")
        elif behaviour == DifferentQueriesBehaviour.COMPARE_AND_SHOW_DIFFERENT:
            logging.info("Queries for responses are different. Will show them all")
            queries_iter = (
                "query1: {}\nquery2: {}".format(i, j)
                for i, j in zip_longest(
                    fu.read_line_by_line(sdk2.ResourceData(q1).path),
                    fu.read_line_by_line(sdk2.ResourceData(q2).path)
                )
            )
            return CheckQueriesResult.CHECK_OK, queries_iter
        elif behaviour == DifferentQueriesBehaviour.MARK_AS_NO_DIFF:
            return CheckQueriesResult.MARK_AS_NO_DIFF, None
        # Previously an unknown value silently returned CHECK_OK with no queries,
        # which crashed later inside zip().
        eh.check_failed("Unknown different queries behaviour: {}".format(behaviour))

    def on_finish(self, prev_status, status):
        self._save_diff()
        super(CompareSearchResponses, self).on_finish(prev_status, status)

    def on_break(self, prev_status, status):
        self._save_diff()
        super(CompareSearchResponses, self).on_break(prev_status, status)

    def _save_diff(self):
        """Make the diff resource ready, putting a 'no diff' page into it if it is empty."""
        out_resource = sdk2.ResourceData(self.Parameters.diff_resource)
        out_path = out_resource.path
        if not out_path.exists():
            # on_execute may not have reached make_folder (early failure, break), so the
            # directory must exist before the placeholder page is written into it.
            fs.make_folder(str(out_path))
        if next(out_path.iterdir(), None) is None:
            no_diff_filename = "no_diff_detected.html"
            out_path.joinpath(no_diff_filename).write_bytes(templates.get_html_template(no_diff_filename))
        out_resource.ready()

    def _mark_as_no_diff(self):
        """Fill output parameters as for a comparison that found no differences."""
        self.Parameters.diff_count = 0
        self.Parameters.compare_result = True
        # relevant_diff is a List parameter, so an empty list (not "") is the "no diff" value.
        self.Parameters.relevant_diff = []


def get_and_check_factor_names(resp1, resp2):
    # type: (sdk2.Resource, sdk2.Resource) -> typing.Tuple[list, list]
    """
    Load the factor names of both responses and log how they relate.

    :return: (factor names for resp1, factor names for resp2); either list is empty
        when its factor_names resource cannot be found
    """
    first, second = _get_factor_names(resp1.id), _get_factor_names(resp2.id)
    only_in_one = set(first).symmetric_difference(second)
    if not only_in_one:
        logging.info("Factors for responses are equal")
    else:
        logging.info("Factors for responses have difference: %s", only_in_one)
    if not first or not second:
        logging.warning("Empty factor names!")
    return first, second


def _get_factor_names(resp_id):
    """
    Return factor names from the OTHER_RESOURCE linked to responses `resp_id`.

    The linked resource is located by attribute responses_id="resource:<resp_id>".
    The last whitespace-separated token of each line is taken as the name.
    Returns [] when no such resource exists.
    """
    lookup = sdk2.Resource[resource_types.OTHER_RESOURCE].find(
        attrs={"responses_id": "resource:{}".format(resp_id)}
    )
    names_resource = lookup.first()
    if not names_resource:
        logging.warning("Cannot find factor_names resource for %s", resp_id)
        return []
    lines = fu.read_line_by_line(sdk2.ResourceData(names_resource).path)
    return [line.split()[-1] for line in lines]


def find_relevant_diff(compact_diff_items, paths_to_exclude=("debug",), limit=300):
    """
    Select diff heads that may influence the SERP.

    A head is irrelevant if it contains (case-insensitively) ANY of `paths_to_exclude`.
    The original loop appended a head as soon as one exclusion was absent, which only
    worked while there was a single exclusion.

    :param compact_diff_items: iterable of (head, body) pairs, e.g. printer.get_compact_diff()
    :param paths_to_exclude: substrings marking diffs that do not influence the SERP
    :param limit: keep at most this many heads, taken in input order before sorting
    :return: sorted list of relevant heads
    """
    excluded = [path.lower() for path in paths_to_exclude]
    relevant = (
        head for head, _ in compact_diff_items
        if not any(path in head.lower() for path in excluded)
    )
    # islice keeps the first `limit` matches lazily, matching the old result[:300] semantics.
    return sorted(it.islice(relevant, limit))
