import datetime
import logging
import itertools

from sandbox import sdk2

import sandbox.common.types.task as ctt
from sandbox.common import errors

from sandbox.projects.common import file_utils as fu

from sandbox.projects.MarketModelWizard.GetDocidsFromXmlSearch import GetDocidsFromXmlSearch

from sandbox.projects.MarketModelWizard import GetDocidsFromXmlSearchResultTsv, GetDocidsFromXmlSearchInputTsv


class CommonXmlSearchParameters(sdk2.Parameters):
    # Parameter set describing how the XML search endpoint is queried:
    # endpoint choice, request volume, scheduling and optional Solomon reporting.
    with sdk2.parameters.Group('Xml search parameters') as xml_search_block:
        # Endpoint to send requests to; the Hamster endpoint is the default choice.
        with sdk2.parameters.RadioGroup('XML origin') as xml_origin:
            xml_origin.values['https://hamster.yandex.ru/search/xml'] = xml_origin.Value(
                value='Hamster',
                default=True,
            )
            xml_origin.values['https://yandex.ru/search/xml'] = xml_origin.Value(
                value='Yandex',
            )

        # How many top documents are requested.
        numdoc = sdk2.parameters.Integer('Top documents count', default=10)

        # Requester identifier (see the wiki page named in the label).
        reqinfo = sdk2.parameters.String(
            'Your id from https://wiki.yandex-team.ru/jandeksxml/internal-users/', required=True
        )

        # Request volume settings.
        task_count = sdk2.parameters.Integer('Task count', default=3)
        max_rps_per_task = sdk2.parameters.Integer('RPS per task', required=True, default=20)

        # Optional cap on the number of input rows; None means no cap is configured.
        url_limit = sdk2.parameters.Integer('First n rows', default=None)

        mean_task_time = sdk2.parameters.Integer(
            'One chunk task time in minutes', required=True, default=15
        )

        run_only_at_night = sdk2.parameters.Bool(
            'Only night requests (from 22 to 10 MSK)', default=True
        )

        rps_per_thread = sdk2.parameters.Integer('RPS per thread', default=1)

        # The Solomon fields below are declared under the "True" value of this switch.
        send_solomon_metrics = sdk2.parameters.Bool('Send rps to solomon', default=False)
        with send_solomon_metrics.value[True]:
            with sdk2.parameters.Group('Solomon push-api parameters') as saas_parameters:
                rps_solomon_project = sdk2.parameters.String('Project', required=True)
                rps_solomon_cluster = sdk2.parameters.String('Cluster', required=True)
                rps_solomon_service = sdk2.parameters.String('Service', required=True)
                rps_solomon_sensor = sdk2.parameters.String('Sensor', required=True)


class GetDocidsFromXmlSearchByChunk(sdk2.Task):
    """Grep top docid for text in specific geoid by child tasks"""

    class Parameters(sdk2.Task.Parameters):
        # Input table whose rows are split into chunks for the child tasks.
        input_tsv = sdk2.parameters.Resource(
            'Input tsv table (id, geoid, text)', resource_type=GetDocidsFromXmlSearchInputTsv
        )

        # Endpoint to send requests to; the Hamster endpoint is the default choice.
        with sdk2.parameters.RadioGroup('XML origin') as xml_origin:
            xml_origin.values['https://hamster.yandex.ru/search/xml'] = xml_origin.Value(
                value='Hamster',
                default=True,
            )
            xml_origin.values['https://yandex.ru/search/xml'] = xml_origin.Value(
                value='Yandex',
            )

        # How many top documents are requested.
        numdoc = sdk2.parameters.Integer('Top documents count', default=10)

        # Requester identifier (see the wiki page named in the label).
        reqinfo = sdk2.parameters.String(
            'Your id from https://wiki.yandex-team.ru/jandeksxml/internal-users/', required=True
        )

        # Upper bound on simultaneously queued/executing child tasks (used by on_run_tasks).
        task_count = sdk2.parameters.Integer('Task count', default=3)

        # Together with mean_task_time this defines the chunk size:
        # 60 * mean_task_time * max_rps_per_task rows per child task.
        max_rps_per_task = sdk2.parameters.Integer('RPS per task', required=True, default=20)

        # Optional cap on the number of input rows processed; None disables the cap.
        url_limit = sdk2.parameters.Integer('First n rows', default=None)

        mean_task_time = sdk2.parameters.Integer(
            'One chunk task time in minutes', required=True, default=15
        )

        # When set, get_wait_time postpones work while UTC time is within [07:00, 19:00).
        run_only_at_night = sdk2.parameters.Bool(
            'Only night requests (from 22 to 10 MSK)', default=True
        )

        rps_per_thread = sdk2.parameters.Integer('RPS per thread', default=1)

        # The Solomon fields below are declared under the "True" value of this switch.
        send_solomon_metrics = sdk2.parameters.Bool('Send rps to solomon', default=False)
        with send_solomon_metrics.value[True]:
            with sdk2.parameters.Group('Solomon push-api parameters') as saas_parameters:
                solomon_project = sdk2.parameters.String('Project', required=True)
                solomon_cluster = sdk2.parameters.String('Cluster', required=True)
                solomon_service = sdk2.parameters.String('Service', required=True)
                solomon_sensor = sdk2.parameters.String('Sensor', required=True)

        # Output: id of the joined result resource, assigned in on_join_results.
        with sdk2.parameters.Output:
            result_resource = sdk2.parameters.Integer('Result resource id', required=True)

    class Requirements(sdk2.Task.Requirements):
        # No overrides are declared here: everything is inherited from sdk2.Task.Requirements.
        pass

    def create_chunk_task(self, shift, count, thread_count, rps_per_thread, solomon_parameters):
        """Create a GetDocidsFromXmlSearch child task for one chunk of the input table.

        The child is only constructed here; it is enqueued later by on_run_tasks.

        :param shift: index of the first input row of the chunk
        :param count: nominal number of rows in a chunk
        :param thread_count: value passed to the child as ``thread_count``
        :param rps_per_thread: value passed to the child as ``rps_per_thread``
        :param solomon_parameters: dict of extra keyword arguments forwarded to the child
        :return: the created child task
        """
        row_cap = self.Parameters.url_limit
        if row_cap is None:
            chunk_rows = count
        else:
            # Trim the chunk so that it never reaches past the configured row cap.
            chunk_rows = min(count, row_cap - shift)
        chunk_end = shift + chunk_rows

        # Kill timeout in seconds: twice the expected chunk time plus ten minutes.
        timeout_minutes = 2 * int(self.Parameters.mean_task_time) + 10

        child = GetDocidsFromXmlSearch(
            self,
            description='from {shift} to {to}'.format(shift=shift, to=chunk_end),
            kill_timeout=timeout_minutes * 60,
            create_sub_task=False,
            input_tsv=self.Parameters.input_tsv.id,
            xml_origin=self.Parameters.xml_origin,
            numdoc=self.Parameters.numdoc,
            reqinfo=self.Parameters.reqinfo,
            thread_count=thread_count,
            rps_per_thread=rps_per_thread,
            url_shift=shift,
            url_limit=chunk_rows,
            **solomon_parameters
        )
        logging.info('create task from {} to {}'.format(shift, chunk_end))
        return child

    def on_prepare_tasks(self):
        """Split the input table into chunks and create one child task per chunk.

        A chunk holds ``60 * mean_task_time * max_rps_per_task`` rows. The first
        child (starting at row 0) is always created; another child is created
        each time the row index reaches the next chunk boundary. Counting stops
        at ``url_limit`` rows when that parameter is set.

        Children are only created here; on_run_tasks enqueues them.
        """
        input_path = str(sdk2.ResourceData(self.Parameters.input_tsv).path)

        url_limit = self.Parameters.url_limit
        if url_limit is not None:
            url_limit = int(url_limit)

        # A missing or zero value falls back to 1 so the division below cannot fail.
        rps_per_thread = int(self.Parameters.rps_per_thread or 1)
        # Floor division keeps the thread count an integer under Python 2 and Python 3.
        thread_count = int(self.Parameters.max_rps_per_task) // rps_per_thread

        count_per_task = 60 * self.Parameters.mean_task_time * self.Parameters.max_rps_per_task

        send_metrics = bool(self.Parameters.send_solomon_metrics)
        if send_metrics:
            base_solomon_parameters = dict(
                send_solomon_metrics=True,
                solomon_project=self.Parameters.solomon_project,
                solomon_cluster=self.Parameters.solomon_cluster,
                solomon_service=self.Parameters.solomon_service,
                solomon_sensor=self.Parameters.solomon_sensor,
            )
        else:
            base_solomon_parameters = dict(
                send_solomon_metrics=False,
            )

        def spawn_chunk(index, shift):
            # Every child gets its own copy of the parameters; the ordinal of the
            # child is used as the Solomon hostname, and only when metrics are on.
            parameters = dict(base_solomon_parameters)
            if send_metrics:
                parameters['solomon_hostname'] = str(index)
            self.create_chunk_task(shift, count_per_task, thread_count, rps_per_thread, parameters)

        tasks_count = 0
        current_shift = 0

        # The first chunk is created unconditionally, even for an empty input.
        spawn_chunk(tasks_count, current_shift)
        tasks_count += 1
        current_shift += count_per_task

        # The file is read only to count rows; the row contents are not needed here.
        for i, _ in enumerate(fu.read_line_by_line(input_path)):
            if url_limit is not None and i >= url_limit:
                break
            if i == current_shift:
                spawn_chunk(tasks_count, current_shift)
                tasks_count += 1
                current_shift += count_per_task

    def get_wait_time(self):
        """Return how many seconds to wait before requests may be sent.

        Returns 0 when the night-only restriction is off, or when the current
        UTC time is outside [07:00, 19:00) (i.e. outside 10 to 22 MSK).
        Otherwise returns the time left until 19:00 UTC, computed from the
        hour and minute components only (seconds are not taken into account).
        """
        if not self.Parameters.run_only_at_night:
            return 0

        now_utc = datetime.datetime.utcnow().time()
        day_begin = datetime.time(hour=7)
        day_end = datetime.time(hour=19)

        # Outside the daytime window (from 10 to 22 MSK) there is nothing to wait for.
        if now_utc < day_begin or now_utc >= day_end:
            return 0

        hours_left = day_end.hour - now_utc.hour
        minutes_left = day_end.minute - now_utc.minute
        return hours_left * 3600 + minutes_left * 60

    def on_run_tasks(self):
        """Start as many pending child tasks as the concurrency limit allows, then wait.

        At most ``task_count`` children are kept queued or executing. Free slots
        are filled with children that are still drafts or that ended in TIMEOUT
        or EXCEPTION (the latter two are retries and are counted per child).

        Returns normally only when there is nothing running and nothing left to
        start; otherwise it always leaves by raising one of the wait exceptions.

        :raises errors.TaskError: when one child has been picked for retry 5 times
        :raises sdk2.WaitTime: when get_wait_time reports a positive delay
        :raises sdk2.WaitTask: to wait until any of the active children finishes
        """
        max_failures = 5

        running = list(self.find(GetDocidsFromXmlSearch, status=(ctt.Status.Group.QUEUE, ctt.Status.Group.EXECUTE)))
        free_slots = int(self.Parameters.task_count) - len(running)

        wait_task = []
        # Skip the query entirely when every slot is busy, rather than issuing
        # it with a zero or negative limit.
        if free_slots > 0:
            candidates = self.find(
                GetDocidsFromXmlSearch,
                status=(ctt.Status.Group.DRAFT, ctt.Status.TIMEOUT, ctt.Status.EXCEPTION)
            ).limit(free_slots)
            for i, task in enumerate(candidates):
                if i >= free_slots:
                    break
                if task.status == ctt.Status.TIMEOUT or task.status == ctt.Status.EXCEPTION:
                    # Context keys are kept as strings, hence the str() of the id.
                    task_id = str(task.id)
                    failures = self.Context.failed_tasks.get(task_id, 0) + 1
                    self.Context.failed_tasks[task_id] = failures
                    if failures >= max_failures:
                        raise errors.TaskError('to many fail on one child task')
                wait_task.append(task)

        if not (wait_task or running):
            return

        wait_time = self.get_wait_time()
        if wait_time > 0:
            logging.info('wait {seconds}s'.format(seconds=wait_time))
            raise sdk2.WaitTime(wait_time)

        for task in wait_task:
            logging.info('start task {}'.format(task.id))
            task.enqueue()

        # EXCEPTION is included because children in that status are retried above;
        # without it the parent would not be woken up when a child ends that way.
        raise sdk2.WaitTask(
            [task.id for task in itertools.chain(running, wait_task)],
            (ctt.Status.SUCCESS, ctt.Status.FAILURE, ctt.Status.TIMEOUT, ctt.Status.EXCEPTION),
            wait_all=False,
        )

    def on_join_results(self):
        """Concatenate the result tables of all child tasks into one resource.

        Creates a GetDocidsFromXmlSearchResultTsv resource named ``result.tsv``,
        appends the lines of every child's result resource to it, stores the
        summed ``line_count`` on it and publishes its id through the
        ``result_resource`` output parameter.

        NOTE(review): chunks are appended in the order returned by ``self.find``;
        confirm that this order matches what consumers of the result expect.

        :raises errors.TaskError: when a child task is not in SUCCESS status or
            has no result resource
        """
        result_resource = GetDocidsFromXmlSearchResultTsv(
            self,
            'result tsv from task {task_id} resource {resource_id}'.format(
                task_id=self.id, resource_id=self.Parameters.input_tsv.id,
            ),
            'result.tsv',
            input_resource_id=self.Parameters.input_tsv.id,
            line_shift=0,
        )
        result_resource_data = sdk2.ResourceData(result_resource)
        result_path = str(result_resource_data.path)

        count = 0
        for task in self.find(GetDocidsFromXmlSearch):
            # Explicit check instead of ``assert``: asserts are stripped under -O.
            if task.status != ctt.Status.SUCCESS:
                raise errors.TaskError(
                    'child task {} has status {} instead of SUCCESS'.format(task.id, task.status)
                )
            resource = GetDocidsFromXmlSearchResultTsv.find(
                task=task,
            ).first()
            if resource is None:
                raise errors.TaskError('child task {} has no result resource'.format(task.id))
            count += resource.line_count
            resource_data = sdk2.ResourceData(resource)
            fu.append_lines(result_path, fu.read_lines(str(resource_data.path)))

        result_resource.line_count = count
        result_resource_data.ready()

        self.Parameters.result_resource = result_resource.id

    def on_execute(self):
        """Run the task as four ordered stages: prepare, init context, run, join.

        NOTE(review): presumably ``memoize_stage`` skips a stage that was already
        completed during an earlier execution of this task, so that re-entering
        after a wait resumes from the unfinished stage - confirm against the
        sdk2 documentation, including the exact meaning of the commit_* flags.
        """
        # Stage 1: create the child tasks, one per chunk of the input table.
        with self.memoize_stage.prepare_tasks:
            self.on_prepare_tasks()

        # Stage 2: initialise the per-child retry counters used by on_run_tasks.
        with self.memoize_stage.init_context:
            self.Context.failed_tasks = {}

        # Stage 3: start/monitor children; on_run_tasks leaves via a wait exception
        # until there is nothing left to run.
        with self.memoize_stage.run_tasks(commit_on_entrance=False, commit_on_wait=False):
            self.on_run_tasks()

        # Stage 4: merge the children's results into a single resource.
        with self.memoize_stage.join_results(commit_on_entrance=False):
            self.on_join_results()
