#!/usr/bin/python
# -*- coding: utf-8 -*-

import os
import logging
import random

from sandbox import sdk2
from sandbox import sandboxsdk
from sandbox.common.errors import TaskFailure
from sandbox.sdk2.helpers import subprocess
from sandbox.common.types.misc import NotExists
from sandbox.common.types.task import Status
from sandbox.projects.websearch.begemot import resources
from sandbox.projects.websearch.begemot.tasks.BegemotYT.common import CommonYtParameters, utc_from_now
from sandbox.projects.websearch.begemot.tasks.BegemotYT.BegemotMapper import BegemotMapper
from sandbox.projects.websearch.begemot.tasks.BegemotYT.paths import BegemotYtPaths


# Upper bound of the `[#0:#N]` row range that BegemotReducer.on_enqueue appends
# to the last eventlog table path when `eventlog_table` is set to 'action'.
REQUESTS_COUNT = 3000000


class BEGEMOT_YT_MERGER(sdk2.Resource):
    """
        Begemot workers answers merger.

        BegemotReducer.start_merger looks up a resource of this type and runs it
        as a binary with one `--output` table and an `--input` table per worker.
    """
    # NOTE(review): presumably the Arcadia path the binary is built from -- confirm.
    arcadia_build_path = 'tools/wizard_yt/merger'
    # NOTE(review): presumably marks the resource file as executable -- confirm.
    executable = True


class BegemotReducer(sdk2.Task):
    """
        Runs a BegemotMapper subtask per worker shard, merges their answers
        tables with the BEGEMOT_YT_MERGER binary and then runs a BegemotMapper
        on the `Merger` shard over the merged table.
    """
    # Class-private logger (name-mangled) shared by all methods of the task.
    __logger = logging.getLogger('TASK_LOGGER')
    __logger.setLevel(logging.DEBUG)
    # Placeholders; on_execute replaces them on the instance with the YT client,
    # the `yt.wrapper` module and the YT token read from the Vault.
    yt_client, yt, token = None, None, None
    # Placeholders; on_execute rebinds them to {shard_name: resource} dicts.
    shards, fresh = dict(), dict()

    # Task inputs. Fields such as output_path, eventlog_table, input_table,
    # yt_proxy, yt_pool, wait_time and the yt_token_vault_* pair are not declared
    # here; presumably they come from CommonYtParameters -- confirm.
    class Parameters(CommonYtParameters):
        # Mapper binary forwarded to every BegemotMapper subtask.
        begemot_mapper = sdk2.parameters.Resource(
            'Begemot mapper binary',
            resource_type=resources.BEGEMOT_YT_MAPPER,
            required=True,
        )
        # Optional; when it is not set on_execute looks up a READY resource
        # released as stable.
        eventlog_mapper = sdk2.parameters.Resource(
            'Begemot eventlog mapper binary',
            resource_type=resources.BEGEMOT_YT_EVENTLOG_MAPPER,
            required=False,
        )
        # One resource per shard; on_execute keys them by `shard_name` and
        # requires a shard named 'Merger' to be present.
        shards = sdk2.parameters.Resource(
            'Begemot shards paths files',
            resource_type=resources.BEGEMOT_CYPRESS_SHARD,
            multiple=True, required=True,
        )
        # Optional per-shard fresh data, matched to shards by `shard_name`.
        fresh = sdk2.parameters.Resource(
            'Begemot fresh paths files',
            resource_type=resources.BEGEMOT_CYPRESS_SHARD,
            multiple=True,
        )
        # Passed to utc_from_now() for the expiration of the output directory and
        # forwarded to the Merger subtask as its results_store_time.
        answers_store_time = sdk2.parameters.Integer('Days to store Merger answers', default=3)
        # Passed to utc_from_now() for the expiration of `merged_answers` and
        # forwarded to the worker subtasks.
        results_store_time = sdk2.parameters.Integer('Days to store intermediate results', default=1)
        # Forwarded unchanged to every BegemotMapper subtask.
        job_count = sdk2.parameters.Integer('Yt job count', default=1)
        # Forwarded to the worker subtasks only (not to the Merger subtask).
        columns = sdk2.parameters.String('Columns for direct mode. Query column, than region column', default='')
        with sdk2.parameters.Output:
            # Set by on_execute from the Merger subtask's `answers` parameter.
            answers = sdk2.parameters.String('Begemot answers')

    # Resources requested for the host that runs this task.
    class Requirements(sdk2.Requirements):
        # NOTE(review): units are presumably MiB for both values -- confirm
        # against the sdk2.Requirements documentation.
        disk_space = 100
        ram = 100
        # Pinned `yandex-yt` pip package; presumably the provider of the
        # `yt.wrapper` module imported in on_enqueue and on_execute -- confirm.
        environments = [sandboxsdk.environments.PipEnvironment('yandex-yt', version='0.10.8')]

    def create_output(self):
        """Create the output directory and its ``merged_answers`` child in YT.

        Both are created as ``map_node`` with an ``expiration_time`` attribute:
        the root uses ``answers_store_time`` and the child uses
        ``results_store_time``. Whether already existing nodes are tolerated is
        controlled by the ``ignore_existing`` parameter.
        """
        params = self.Parameters

        # Root directory; `recursive=True` is passed only for this node.
        root_attributes = {'expiration_time': utc_from_now(params.answers_store_time)}
        self.yt_client.create(
            'map_node',
            params.output_path,
            recursive=True,
            ignore_existing=params.ignore_existing,
            attributes=root_attributes,
        )

        # Directory for the merged workers answers.
        merged_dir = self.yt.ypath_join(params.output_path, 'merged_answers')
        merged_attributes = {'expiration_time': utc_from_now(params.results_store_time)}
        self.yt_client.create(
            'map_node',
            merged_dir,
            ignore_existing=params.ignore_existing,
            attributes=merged_attributes,
        )
    
    def on_enqueue(self):
        """Replace the ``'action'`` placeholder in the table parameters.

        * ``eventlog_table == 'action'`` becomes the last eventlog table path
          with a ``[#0:#REQUESTS_COUNT]`` row range appended.
        * ``output_path == 'action'`` becomes
          ``<testenv path>/<task type>/postcommit_<random number>``.

        Any other value is left untouched.
        """
        params = self.Parameters

        if params.eventlog_table == 'action':
            last_table = BegemotYtPaths.get_last_eventlog_table()
            params.eventlog_table = last_table + '[#0:#%d]' % REQUESTS_COUNT

        if params.output_path != 'action':
            return

        # Imported lazily, as in the rest of the file.
        import yt.wrapper as yt
        testenv_root = BegemotYtPaths.get_testenv_path()
        task_type = str(self.type)
        run_dir = 'postcommit_' + str(random.randint(1, 1000000))
        params.output_path = yt.ypath_join(testenv_root, task_type, run_dir)


    def start_workers(self):
        """Create the output nodes and enqueue one BegemotMapper per worker shard.

        The ``Merger`` shard is skipped here (``start_merger`` runs it) and
        ``Bert`` is skipped with a note in the task info. Ids of the enqueued
        subtasks are stored in ``Context.workers_tasks`` keyed by shard name.

        Raises:
            sdk2.WaitTask: always, to wait until every enqueued worker reaches
                a status from ``Status.Group.FINISH``.
        """
        self.Context.workers_tasks = dict()
        self.create_output()
        # The dicts are passed as lazy logging arguments, so the message needs a
        # `%s` placeholder; without one logging reports a formatting error and
        # the dicts are never written to the log.
        self.__logger.info('Shards dict: %s', self.shards)
        self.__logger.info('Fresh dict: %s', self.fresh)
        self.__logger.info('Running workers')
        for shard_name, shard in self.shards.items():
            if shard_name == 'Merger':
                # Runs later, over the merged answers of the other shards.
                continue
            if shard_name == 'Bert':
                self.set_info('Bert is not supported in this task, skipped')
                continue
            threads = 15 if shard_name in ('Wizard', 'Ethos') else 5
            worker_task = BegemotMapper(
                self,
                description='Begemot mapper %s' % shard_name,
                service='begemot',
                begemot_mapper=self.Parameters.begemot_mapper,
                columns=self.Parameters.columns,
                eventlog_mapper=self.Context.eventlog_mapper,
                shard=shard,
                fresh=self.fresh.get(shard_name),
                input_table=self.Parameters.input_table,
                eventlog_table=self.Parameters.eventlog_table,
                output_path=self.yt.ypath_join(self.Parameters.output_path, shard_name),
                yt_proxy=self.Parameters.yt_proxy,
                yt_token_vault_owner=self.Parameters.yt_token_vault_owner,
                yt_token_vault_name=self.Parameters.yt_token_vault_name,
                yt_pool=self.Parameters.yt_pool,
                wait_time=self.Parameters.wait_time,
                results_store_time=self.Parameters.results_store_time,
                job_count=self.Parameters.job_count,
                threads=threads,
            ).enqueue()
            self.Context.workers_tasks[shard_name] = worker_task.id
            self.__logger.info('Started worker {}, task id = {}'.format(shard.shard_name, worker_task.id))
        self.__logger.info('Started all workers, waiting tasks to finish')
        raise sdk2.WaitTask(self.Context.workers_tasks.values(), Status.Group.FINISH)

    def get_workers_responses(self):
        """Collect answers tables and error details from the worker subtasks.

        Writes to the task context:

        * ``workers_responses`` -- the ``answers`` parameter of every worker;
        * ``workers_with_errors`` -- extended with workers whose
          ``eventlog_contains_errors`` parameter is set;
        * ``rules_errors`` / ``stderrs`` -- per-worker details copied from the
          worker context when present.

        Raises:
            TaskFailure: immediately if a worker subtask cannot be found, or
                after all workers were inspected if any of them did not end in
                ``Status.SUCCESS``. The message names every failed worker.
        """
        self.Context.workers_responses = []
        rules_errors, stderrs = {}, {}
        # All failed workers are collected so that the failure message does not
        # depend on dict iteration order and does not hide earlier failures.
        failed_workers = []
        for worker_name, task_id in self.Context.workers_tasks.items():
            worker_task = self.find(BegemotMapper, id=task_id).first()
            if worker_task is None:
                raise TaskFailure('%s task not found' % worker_name)
            # NOTE(review): a worker with eventlog errors is never checked for
            # a non-SUCCESS status (the branches are exclusive) -- confirm this
            # is intended.
            if worker_task.Parameters.eventlog_contains_errors:
                self.Context.workers_with_errors = (self.Context.workers_with_errors or []) + [worker_name]
                if worker_task.Context.rules_errors is not NotExists:
                    rules_errors[worker_name] = worker_task.Context.rules_errors
            elif worker_task.status != Status.SUCCESS:
                if worker_task.Context.rules_errors is not NotExists:
                    rules_errors[worker_name] = worker_task.Context.rules_errors
                if worker_task.Context.failed_begemot_stderr is not NotExists:
                    stderrs[worker_name] = worker_task.Context.failed_begemot_stderr
                failed_workers.append(worker_name)
            self.Context.workers_responses.append(worker_task.Parameters.answers)
        # Error details are stored before raising so they stay available on the
        # failed task.
        self.Context.rules_errors = rules_errors
        self.Context.stderrs = stderrs
        if failed_workers:
            # For a single failed worker the text is the same as before.
            raise TaskFailure('%s task failed' % ', '.join(sorted(failed_workers)))

    def start_merger(self):
        """Merge the workers answers and enqueue the BegemotMapper for ``Merger``.

        Runs the newest READY BEGEMOT_YT_MERGER binary with one ``--input`` per
        table in ``Context.workers_responses`` and
        ``<output_path>/merged_answers/merged`` as ``--output``. It then enqueues
        a BegemotMapper on the ``Merger`` shard over the merged table and stores
        its id in ``Context.merger_task_id``.

        Raises:
            TaskFailure: if no READY BEGEMOT_YT_MERGER resource exists.
            subprocess.CalledProcessError: if the merger binary exits non-zero.
            sdk2.WaitTask: always on success, to wait for the Merger subtask.
        """
        # The YT credentials and settings reach the merger binary through its
        # environment.
        env = os.environ.copy()
        env['YT_TOKEN'] = self.token
        env['YT_PROXY'] = self.Parameters.yt_proxy
        env['YT_POOL'] = self.Parameters.yt_pool
        # Only READY resources are considered (same filter as the eventlog
        # mapper lookup in on_execute); otherwise the newest id could belong to
        # a broken or unfinished resource.
        reducer = BEGEMOT_YT_MERGER.find(state='READY').order(-sdk2.Resource.id).first()
        if reducer is None:
            raise TaskFailure('Begemot reducer resource not found')
        binary = str(sdk2.ResourceData(reducer).path)
        merged_answers_table = self.yt.ypath_join(self.Parameters.output_path, 'merged_answers', 'merged')
        args = [binary, '--output', merged_answers_table]
        for responses_table in self.Context.workers_responses:
            args.extend(['--input', responses_table])
        # The ProcessLog streams are passed to the child so that the merger
        # output is actually captured by the process log.
        with sdk2.helpers.ProcessLog(self, logger=self.__logger.getChild('binary')) as process_log:
            subprocess.check_call(
                args, env=env, close_fds=False,
                stdout=process_log.stdout, stderr=process_log.stderr,
            )
        self.Context.merger_task_id = BegemotMapper(
            self,
            description='Begemot mapper Merger',
            service='begemot',
            begemot_mapper=self.Parameters.begemot_mapper,
            eventlog_mapper=self.Context.eventlog_mapper,
            shard=self.shards['Merger'],
            fresh=self.fresh.get('Merger'),
            input_table=merged_answers_table,
            output_path=self.yt.ypath_join(self.Parameters.output_path, 'Merger'),
            yt_proxy=self.Parameters.yt_proxy,
            yt_token_vault_owner=self.Parameters.yt_token_vault_owner,
            yt_token_vault_name=self.Parameters.yt_token_vault_name,
            yt_pool=self.Parameters.yt_pool,
            wait_time=self.Parameters.wait_time,
            # The Merger output is the final answer, so it is kept for
            # answers_store_time rather than results_store_time.
            results_store_time=self.Parameters.answers_store_time,
            job_count=self.Parameters.job_count,
            threads=5,
        ).enqueue().id
        self.__logger.info('Started Merger task, id = {}'.format(self.Context.merger_task_id))
        raise sdk2.WaitTask(self.Context.merger_task_id, Status.Group.FINISH)

    def on_execute(self):
        """Drive the workers -> merge -> Merger pipeline.

        ``start_workers`` and ``start_merger`` end by raising ``sdk2.WaitTask``.
        Each stage is therefore guarded by a context key (``workers_tasks``,
        ``workers_responses``, ``merger_task_id``) and runs only while that key
        is still missing; presumably the method is re-entered after each wait.

        On success the Merger answers path is stored in ``Context.answers`` and
        in the ``answers`` output parameter, and ``Context.merger_finished`` is
        set so that ``header`` can render the link.

        Raises:
            TaskFailure: if the ``Merger`` shard is missing, no eventlog mapper
                resource can be resolved, the Merger subtask is missing or not
                successful, or some workers reported eventlog errors.
        """
        self.shards = {s.shard_name: s for s in self.Parameters.shards}
        self.fresh = {s.shard_name: s for s in (self.Parameters.fresh or [])}
        if 'Merger' not in self.shards:
            raise TaskFailure('Merger shard not found')
        # Imported lazily, as in the rest of the file.
        import yt.wrapper as yt
        self.yt = yt
        self.token = sdk2.Vault.data(self.Parameters.yt_token_vault_owner, self.Parameters.yt_token_vault_name)
        self.yt_client = yt.YtClient(self.Parameters.yt_proxy, self.token)

        if self.Parameters.eventlog_mapper:
            # Use the eventlog mapper that was supplied; the previous code
            # stored the id of `begemot_mapper` here by mistake.
            self.Context.eventlog_mapper = self.Parameters.eventlog_mapper.id
        else:
            released_mapper = sdk2.Resource["BEGEMOT_YT_EVENTLOG_MAPPER"].find(
                state='READY', attrs={'released': 'stable'}
            ).first()
            if released_mapper is None:
                # Fail with a clear message instead of an AttributeError on None.
                raise TaskFailure('Stable BEGEMOT_YT_EVENTLOG_MAPPER resource not found')
            self.Context.eventlog_mapper = released_mapper.id

        if self.Context.workers_tasks is NotExists:
            self.start_workers()
        if self.Context.workers_responses is NotExists:
            self.get_workers_responses()
        if self.Context.merger_task_id is NotExists:
            self.start_merger()

        merger_task = self.find(BegemotMapper, id=self.Context.merger_task_id).first()
        if merger_task is None:
            raise TaskFailure('Merger task not found')
        if merger_task.status != Status.SUCCESS:
            raise TaskFailure('Merger task failed')
        self.Context.answers = self.Parameters.answers = merger_task.Parameters.answers
        # `header` reads this flag; nothing else in this file sets it, so
        # without this line the answers link is never shown.
        self.Context.merger_finished = True
        # Answers are published even when workers had eventlog errors, but the
        # task is still reported as failed in that case.
        if self.Context.workers_with_errors is not NotExists:
            raise TaskFailure('Workers with errors in eventlog: %s' % ','.join(self.Context.workers_with_errors))

    def get_table_url(self, table_path):
        """Return the YT web UI navigation URL for ``table_path``.

        The cluster part of the URL is taken from the ``yt_proxy`` parameter.
        """
        url_template = "https://yt.yandex-team.ru/{}/#page=navigation&path={}"
        cluster = self.Parameters.yt_proxy
        return url_template.format(cluster, table_path)

    @sdk2.header()
    def header(self):
        """Return an HTML link to the answers table.

        Returns ``None`` while ``Context.merger_finished`` is not truthy.
        """
        if not self.Context.merger_finished:
            return None
        answers_url = self.get_table_url(self.Parameters.answers)
        return '<a target="_blank" href="{}">Answers</a><br/>'.format(answers_url)
