# -*- coding: utf-8 -*-
import json
import logging

import sandbox.common.types.client as ctc
from sandbox.sandboxsdk.task import SandboxTask

from sandbox.projects.common import apihelpers

from sandbox.sandboxsdk.channel import channel
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox.projects.common.fusion.short_stats import get_all_diff_stats
from sandbox.projects.common.fusion.params import SAMOHOD_MODE
from sandbox.projects.common.search.components import DefaultMiddlesearchParams
import math


def get_task_out_resource(task_id):
    """Return the value of ``out_resource_id`` from the context of task ``task_id``."""
    task = channel.sandbox.get_task(task_id)
    return task.ctx["out_resource_id"]


def get_resource_task(resource_id):
    """Return the ``task_id`` attribute of the resource with id ``resource_id``."""
    resource = channel.sandbox.get_resource(resource_id)
    return resource.task_id


# Alias (not a copy) of the shared middlesearch binary parameter, marked optional.
# NOTE(review): because this is a plain alias, the next line mutates
# DefaultMiddlesearchParams.Binary itself, so every other user of that parameter
# in the same process also sees required = False. The misspelled name
# ("Middleseearch") is left as-is since other modules may import it.
MiddleseearchBinary = DefaultMiddlesearchParams.Binary
MiddleseearchBinary.required = False


class CompareFusionResultsTask(SandboxTask):
    """Sandbox task comparing fusion responses of a production and a new build.

    ``get_short_stats`` computes per-factor diff statistics between the two
    response sets and stores them in the context; ``verify_results`` fails the
    task when those statistics look unacceptable.
    """

    client_tags = ctc.Tag.Group.LINUX

    def initCtx(self):
        """Return the initial task context with every key this task reads later.

        ``diffs`` and ``zeroes`` are part of the returned dict (not written to
        ``self.ctx``) so they are present in the initial context like the rest.
        """
        return {
            "new_task_id": None,
            "prod_task_id": None,
            "prod_responses_id": None,
            "new_responses_id": None,
            "compare_task_id": None,
            "wait_tasks": [],
            "prod_doc_size": None,
            "new_doc_size": None,
            "doc_size_diff": None,
            "diffs": None,
            "zeroes": None,
        }

    def create_compare_task(self, new_task, prod_task):
        """Create a COMPARE_BASESEARCH_RESPONSES subtask and return its id.

        :param new_task: id of the task whose ``out_resource_id`` holds the new responses
        :param prod_task: id of the task whose ``out_resource_id`` holds the prod responses
        :return: id of the created subtask
        """
        old_responses = get_task_out_resource(prod_task)
        new_responses = get_task_out_resource(new_task)

        sub_ctx = {
            "basesearch_responses1_resource_id": old_responses,
            "basesearch_responses2_resource_id": new_responses,
            "queries_per_file": 100,
            "fail_on_diff": False,
        }
        return self.create_subtask(
            task_type='COMPARE_BASESEARCH_RESPONSES',
            input_parameters=sub_ctx,
            description="Comparing fusion responses",
        ).id

    def get_compare_responses_task(self):
        """Return the compare subtask referenced by ``ctx["compare_task_id"]``."""
        return channel.sandbox.get_task(self.ctx["compare_task_id"])

    def get_compare_resources(self):
        """List BASESEARCH_RESPONSES_COMPARE_RESULT resources of the compare subtask."""
        return apihelpers.list_task_resources(self.ctx["compare_task_id"], 'BASESEARCH_RESPONSES_COMPARE_RESULT')

    def task_is_completed(self, ctx_name):
        """Return whether the task whose id is stored under ``ctx_name`` is done.

        Returns None (falsy) when the context holds no task id for ``ctx_name``.
        """
        task_id = self.ctx.get(ctx_name, None)
        if task_id:
            return channel.sandbox.get_task(task_id).is_done()
        return None

    def get_factor_names(self, task_id):
        """Return the first resource of ``task_id`` named ``factor_names.txt``.

        :raises IndexError: if the task has no such resource.
        """
        resources = channel.sandbox.list_resources(task_id=task_id)
        return [r for r in resources if r.file_name == "factor_names.txt"][0]

    def get_latest_fstr(self):
        """Return the first READY ``FSTR_INFO`` resource reported by the API."""
        return channel.sandbox.list_resources("FSTR_INFO", status="READY")[0]

    def get_short_stats(self, fusion_type):
        """Compute diff statistics between prod and new responses and store them in ctx.

        Reads ``prod_responses_id`` / ``new_responses_id`` from the context,
        resolves the tasks owning those resources, and writes ``diffs``,
        ``zeroes``, ``fresh_*``, ``mango_*`` and the ``*doc_size*`` keys.

        :param fusion_type: when truthy, selects the grouping passed to
            ``get_all_diff_stats``: "d" for SAMOHOD_MODE, "d:fresh" otherwise.
        """
        old_responses = self.sync_resource(self.ctx["prod_responses_id"])
        new_responses = self.sync_resource(self.ctx["new_responses_id"])
        self.ctx["new_task_id"] = get_resource_task(self.ctx["new_responses_id"])
        self.ctx["prod_task_id"] = get_resource_task(self.ctx["prod_responses_id"])

        factor_names = self.sync_resource(self.get_factor_names(self.ctx["new_task_id"]))

        # Fetched once and reused below for both the fstr id and the doc size.
        new_task = channel.sandbox.get_task(self.ctx["new_task_id"])

        resource_id = new_task.ctx.get("fstr")
        if not resource_id:
            logging.error("No fstr file for %s task", self.ctx["new_task_id"])
            resource_id = self.get_latest_fstr()

        # The ctx value may be a list of resource ids; only the first is used.
        if isinstance(resource_id, list):
            resource_id = resource_id[0]
        with open(self.sync_resource(resource_id)) as f:
            fstr = json.load(f)

        grouping = None
        if fusion_type:
            grouping = "d" if fusion_type == SAMOHOD_MODE else "d:fresh"

        result = get_all_diff_stats(old_responses, new_responses, factor_names, fstr, grouping)

        logging.info("Diffs: %s", result.AllFactorsDiff)
        logging.info("Zeroes: %s", result.AllFactorsZeroed)
        logging.info("Fresh diffs: %s", result.FreshDiff)
        logging.info("Fresh zeroed: %s", result.FreshZeroed)

        self.ctx["diffs"] = result.AllFactorsDiff
        self.ctx["zeroes"] = result.AllFactorsZeroed

        self.ctx["fresh_diffs"] = result.FreshDiff
        self.ctx["fresh_zeroes"] = result.FreshZeroed

        self.ctx["mango_diffs"] = result.MangoDiff
        self.ctx["mango_zeroes"] = result.MangoZeroed

        prod_task = channel.sandbox.get_task(self.ctx["prod_task_id"])
        # "or 0" also covers a doc_size key that is present but None, which
        # would otherwise break the subtraction below.
        self.ctx["prod_doc_size"] = prod_task.ctx.get("doc_size") or 0
        self.ctx["new_doc_size"] = new_task.ctx.get("doc_size") or 0
        self.ctx["doc_size_diff"] = self.ctx["prod_doc_size"] - self.ctx["new_doc_size"]

    def get_zeroes_error(self):
        """Return an error message if any zeroed factor exceeds 99 percent, else None.

        Each entry of ``ctx["zeroes"]`` is expected to carry its percentage as
        the last element; the maximum over all entries is checked so the result
        does not depend on the entries' order.
        """
        zeroes = self.ctx.get("zeroes")
        if not zeroes:
            return None
        worst = max(row[-1] for row in zeroes)
        if worst > 99:
            return "Unacceptable percentage of at least one zeroed factor (%s %%). " \
                   "See 'Zeroed' tab in the task view." % worst
        return None

    def get_doc_size_error(self):
        """Return an error message if the doc size changed too much, else None.

        The change is "too much" when ``|doc_size_diff|`` is at least as large
        as either the new or the prod doc size. The check is skipped when any
        of the three values is missing or zero.
        """
        prod_size = self.ctx.get("prod_doc_size")
        new_size = self.ctx.get("new_doc_size")
        diff = self.ctx.get("doc_size_diff")
        if not (prod_size and new_size and diff):
            return None
        if abs(diff) >= new_size or abs(diff) >= prod_size:
            return "Remarkable size of doc size diff: %s (new) vs %s (old)" % (new_size, prod_size)
        return None

    def verify_results(self):
        """Run every result check and fail the task if any of them reports an error.

        :raises SandboxTaskFailureError: with all error messages, one per line.
        """
        check_functions = (
            self.get_zeroes_error,
            self.get_doc_size_error,
        )
        errors = [error for error in (check() for check in check_functions) if error]

        if errors:
            errors_to_string = "\n".join(errors)
            raise SandboxTaskFailureError("The task has failed due to the following reason(s):\n %s" % errors_to_string)
