import base64
import logging
import os

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.projects import resource_types
from sandbox.projects.common import binary_task
from sandbox.common.utils import Enum


class CompareMode(Enum):
    # Values of the "Mode" parameter of CompareModelsServiceResponses.
    # ANY: the task fails on any diff reported by `check_responses`; embeddings
    #      tolerances are forced to 0.
    # SIGNIFICANT (the parameter's default): the task fails only on diffs reported
    #      by `check_responses_stripped`, using the embeddings tolerances from the
    #      task parameters.
    ANY = "any"
    SIGNIFICANT = "significant"


class CompareModelsServiceResponses(binary_task.LastBinaryTaskRelease, sdk2.Task):
    """Compare two dumps of Begemot responses pairwise and fail on a diff.

    Each input resource is read as a text file with one base64-encoded
    ``begemot_pb2.TBegemotResponse`` per line. The i-th line of the first
    resource is compared with the i-th line of the second. Diffs are appended to
    per-chunk files under the ``diffs`` log directory.
    """

    __logger = logging.getLogger("TASK_LOGGER")
    __logger.setLevel(logging.DEBUG)

    # Number of consecutive responses whose diffs are written to the same file.
    _DIFF_CHUNK_SIZE = 50

    class Parameters(sdk2.Task.Parameters):
        first_responses = sdk2.parameters.Resource(
            "1st resource with responses",
            required=True,
            resource_type=resource_types.EXECUTOR_DUMP
        )
        second_responses = sdk2.parameters.Resource(
            "2nd resource with responses",
            required=True,
            resource_type=resource_types.EXECUTOR_DUMP
        )
        tasks_archive_resource = binary_task.binary_release_parameters(stable=True)

        with sdk2.parameters.String("Mode") as compare_mode:
            compare_mode.values[CompareMode.ANY] = compare_mode.Value(value=CompareMode.ANY)
            compare_mode.values[CompareMode.SIGNIFICANT] = compare_mode.Value(value=CompareMode.SIGNIFICANT, default=True)

        # Embeddings tolerances are only used (and only shown) in SIGNIFICANT mode;
        # ANY mode has no extra parameters.
        with compare_mode.value[CompareMode.SIGNIFICANT]:
            embeddings_sum_error = sdk2.parameters.Float("Max embeddings compare error summary", default=1e-2)
            embeddings_diffs_count = sdk2.parameters.Integer("Max count of different embeddings values", default=1)

        # `names_to_ignore` is a single parameter, so it has one default for both modes.
        # NOTE(review): an ANY-specific list (SubSources only) was apparently intended
        # but never took effect; confirm whether per-mode defaults are wanted.
        default_names_to_ignore = ["NBg.NProto.TBegemotResponse.SubSources", "NBg.NProto.TWebResponse.AppliedModels"]

        names_to_ignore = sdk2.parameters.List(
            "Fields names",
            description="Names of proto fields, that will be ignored due compare",
            value_type=sdk2.parameters.String,
            default=default_names_to_ignore
        )
        precision = sdk2.parameters.Float("Float compare precision", default=1e-5)

    class Requirements(sdk2.Requirements):
        disk_space = 40 * 1024
        ram = 40 * 1024

    def _create_compare_context(self):
        """Build a ``ComparePrecisionContext`` from the task parameters.

        In ANY mode the embeddings tolerances are 0. In SIGNIFICANT mode they
        come from ``embeddings_sum_error`` and ``embeddings_diffs_count``.
        """
        from search.daemons.models_proxy.tests.lib.diff_responses import ComparePrecisionContext

        compare_ctx = ComparePrecisionContext()
        compare_ctx.fields_to_ignore = set(self.Parameters.names_to_ignore)
        compare_ctx.precision = self.Parameters.precision
        if self.Parameters.compare_mode == CompareMode.ANY:
            compare_ctx.embeddings_sum_error = 0
            compare_ctx.max_embeddings_diffs_count = 0
        elif self.Parameters.compare_mode == CompareMode.SIGNIFICANT:
            compare_ctx.embeddings_sum_error = self.Parameters.embeddings_sum_error
            compare_ctx.max_embeddings_diffs_count = self.Parameters.embeddings_diffs_count
        return compare_ctx

    def _check_input_resources(self):
        """Raise ``TaskFailure`` unless both input resources are comparable dumps.

        Both resources must have a truthy ``with_responses`` attribute and the
        same ``service_type`` attribute value.
        """
        if (
            not getattr(self.Parameters.first_responses, "with_responses", False) or
            not getattr(self.Parameters.second_responses, "with_responses", False)
        ):
            raise TaskFailure(
                "Resource with requests must have `with_responses` attribute value == True"
            )
        if self.Parameters.first_responses.service_type != self.Parameters.second_responses.service_type:
            raise TaskFailure(
                "Resources with responses must have the same `service_type` attribute value"
            )

    @classmethod
    def _diff_file_path(cls, diffs_dir, index, suffix=""):
        """Return the diff file path for response ``index``.

        Responses are grouped into chunks of ``_DIFF_CHUNK_SIZE``. Each chunk
        writes to ``<first>-<last><suffix>.diff``, e.g. ``0-49.diff`` or
        ``0-49_stripped.diff``.
        """
        chunk_start = index // cls._DIFF_CHUNK_SIZE * cls._DIFF_CHUNK_SIZE
        chunk_end = chunk_start + cls._DIFF_CHUNK_SIZE - 1
        return os.path.join(diffs_dir, "%d-%d%s.diff" % (chunk_start, chunk_end, suffix))

    @staticmethod
    def _iter_response_pairs(first_lines, second_lines):
        """Yield ``(first_line, second_line)`` pairs from two line iterables in lockstep.

        Unlike ``zip``, which silently stops at the shorter input, this raises
        ``TaskFailure`` once one input runs out before the other. A dump with
        missing or extra responses is therefore reported instead of passing.
        Every pair before the mismatch is still yielded, and hence compared.
        """
        first_iter = iter(first_lines)
        second_iter = iter(second_lines)
        pairs_count = 0
        while True:
            first_line = next(first_iter, None)
            second_line = next(second_iter, None)
            if first_line is None and second_line is None:
                return
            if first_line is None or second_line is None:
                raise TaskFailure(
                    "Resources with responses contain different numbers of responses: "
                    "one of them ended after %d responses" % pairs_count
                )
            yield first_line, second_line
            pairs_count += 1

    def on_execute(self):
        """Compare the two response dumps and fail the task if a diff is detected.

        For every response pair, ``check_responses`` writes into the chunk's
        diff file. Only when it reports a diff does ``check_responses_stripped``
        run for that pair and write into the chunk's ``_stripped`` diff file.

        In ANY mode the task fails if any ``check_responses`` call reported a
        diff. In SIGNIFICANT mode it fails only if a ``check_responses_stripped``
        call did.

        Raises:
            TaskFailure: invalid input resources, different numbers of
                responses, or a detected diff.
        """
        from search.daemons.models_proxy.tests.lib.diff_responses import check_responses
        from search.daemons.models_proxy.tests.lib.diff_responses import check_responses_stripped
        from search.begemot.server.proto import begemot_pb2

        self._check_input_resources()

        first_responses_path = str(sdk2.ResourceData(self.Parameters.first_responses).path)
        second_responses_path = str(sdk2.ResourceData(self.Parameters.second_responses).path)
        check_responses_result = True
        check_responses_stripped_result = True

        diffs_dir = str(self.log_path("diffs"))
        os.mkdir(diffs_dir)
        compare_ctx = self._create_compare_context()
        with open(first_responses_path, "r") as first_responses, open(second_responses_path, "r") as second_responses:
            response_pairs = self._iter_response_pairs(first_responses, second_responses)
            for i, (first_response_encoded, second_response_encoded) in enumerate(response_pairs):
                first_response = begemot_pb2.TBegemotResponse()
                second_response = begemot_pb2.TBegemotResponse()
                first_response.ParseFromString(base64.b64decode(first_response_encoded.strip()))
                second_response.ParseFromString(base64.b64decode(second_response_encoded.strip()))
                with open(self._diff_file_path(diffs_dir, i), "a") as current_diff_file:
                    if not check_responses(first_response, second_response, current_diff_file, compare_ctx, i):
                        check_responses_result = False
                        stripped_diff_filename = self._diff_file_path(diffs_dir, i, "_stripped")
                        with open(stripped_diff_filename, "a") as current_stripped_diff_file:
                            if not check_responses_stripped(first_response, second_response, current_stripped_diff_file, compare_ctx, i):
                                check_responses_stripped_result = False
        if self.Parameters.compare_mode == CompareMode.ANY:
            diff_detected = not check_responses_result
        else:
            diff_detected = not check_responses_stripped_result
        if diff_detected:
            raise TaskFailure("Diff detected, see logs for details")
