import json

from sandbox import sdk2
from sandbox.common.types.task import Status
from sandbox.common.errors import TaskFailure
from sandbox.projects.common.wizard.printwizard import postprocess
from sandbox.projects.websearch.begemot.tasks.BegemotCreateResponsesDiff import jsondiff
from sandbox.projects.websearch.begemot.tasks.BegemotYT.BegemotReducer import BegemotReducer


class BegemotReducerMultiplier(sdk2.Task):
    """Run BegemotReducer several times and check that every run agrees.

    ``num_iter`` BegemotReducer child tasks are enqueued, the i-th one writing
    to ``<output_path>/<i>``. They are waited for one at a time. Then the
    answers table of each iteration is compared with the first iteration:

    * a reqid whose answer differs from iteration 0 gets the diff recorded;
    * a reqid that appears in iteration i but not in iteration 0, or that is
      in iteration 0 but missing from iteration i, gets a message recorded.

    If anything was recorded, it is stored in ``Context.error_reqid`` and the
    task fails with ``TaskFailure``.
    """

    class Context(sdk2.Context):
        # Index of the next child task in ``tasks`` to wait for.
        offset = 0
        # Ids of the enqueued BegemotReducer child tasks, one per iteration.
        tasks = []
        # reqid -> list of mismatch descriptions (filled only on failure).
        error_reqid = {}

    class Parameters(BegemotReducer.Parameters):
        num_iter = sdk2.parameters.Integer('The number of iterations', default=3)

    def run_task(self, path):
        """Enqueue one BegemotReducer child writing to ``path`` and return its id.

        All child parameters are forwarded from this task's own parameters,
        except ``output_path``, which is set to ``path``.
        """
        return BegemotReducer(
            self,
            description='Collect begemot answers for {}'.format(self.__class__.__name__),
            begemot_mapper=self.Parameters.begemot_mapper,
            shards=self.Parameters.shards,
            fresh=self.Parameters.fresh,
            answers_store_time=self.Parameters.answers_store_time,
            results_store_time=self.Parameters.results_store_time,
            job_count=self.Parameters.job_count,
            eventlog_table=self.Parameters.eventlog_table,
            output_path=path,
            ignore_exist=self.Parameters.ignore_existing,
            wait_time=self.Parameters.wait_time,
            yt_proxy=self.Parameters.yt_proxy,
            yt_token_vault_owner=self.Parameters.yt_token_vault_owner,
            yt_token_vault_name=self.Parameters.yt_token_vault_name,
            yt_pool=self.Parameters.yt_pool,
            kill_timeout=self.Parameters.kill_timeout,
            fail_on_any_error=self.Parameters.fail_on_any_error,
        ).enqueue().id

    @staticmethod
    def _add_error(errors, reqid, item):
        """Append ``item`` to ``errors[reqid]`` unless it is already recorded.

        A list (not a set) is used so that unhashable diff objects can be
        stored and the result stays JSON-serialisable for ``Context``.
        """
        bucket = errors.setdefault(reqid, [])
        if item not in bucket:
            bucket.append(item)

    @staticmethod
    def _iter_answers(yt_client, table):
        """Yield ``(reqid, postprocessed_answer)`` for every row of ``table``.

        Only the first element of the JSON-decoded ``begemot_answer`` column
        is passed to ``postprocess``.
        """
        for row in yt_client.read_table(table, format='json'):
            answer = postprocess(json.loads(row['begemot_answer'])[0], None, None)
            yield row['reqid'], answer

    def on_execute(self):
        """Spawn the child runs, wait for them, then cross-check their answers.

        Raises:
            TaskFailure: if ``num_iter`` is not positive, or if the answers of
                the iterations do not match.
        """
        import yt.wrapper as yt

        if self.Parameters.num_iter <= 0:
            raise TaskFailure('Error: incorrect number of iterations')
        if not self.Context.tasks:
            self.Context.tasks = [
                self.run_task(yt.ypath_join(self.Parameters.output_path, str(i)))
                for i in range(self.Parameters.num_iter)
            ]
        # Wait for the children one by one. Each WaitTask ends this execution;
        # the task is re-entered later with the incremented offset.
        if self.Context.offset < len(self.Context.tasks):
            offset = self.Context.offset
            self.Context.offset += 1
            raise sdk2.WaitTask(self.Context.tasks[offset], Status.Group.FINISH)

        # The client does not depend on the iteration, so create it once.
        token = sdk2.Vault.data(self.Parameters.yt_token_vault_owner, self.Parameters.yt_token_vault_name)
        yt_client = yt.YtClient(self.Parameters.yt_proxy, token)

        reqid_error = {}
        baseline = {}  # reqid -> answer from iteration 0

        for i, task_id in enumerate(self.Context.tasks):
            child = self.find(BegemotReducer, id=task_id).first()
            answers = self._iter_answers(yt_client, child.Parameters.answers)

            if i == 0:
                baseline.update(answers)
                continue

            # Reset per iteration, so that "missing in iteration i" is judged
            # against iteration i only.
            seen = set()
            for reqid, answer in answers:
                seen.add(reqid)
                if reqid not in baseline:
                    self._add_error(
                        reqid_error, reqid,
                        'not found in all iterations: appeared in {}'.format(i))
                    continue
                diff = jsondiff.diff(baseline[reqid], answer)
                # NOTE(review): assumes jsondiff.diff returns None when the
                # answers are equal; verify against the helper's contract.
                if diff is not None:
                    self._add_error(reqid_error, reqid, diff)

            for reqid in baseline:
                if reqid not in seen:
                    self._add_error(
                        reqid_error, reqid,
                        'not found in all iterations: not in {}'.format(i))

        if reqid_error:
            self.Context.error_reqid = reqid_error
            raise TaskFailure('Results did not match')
