import logging
import sandbox.common.types.task as ctt
import tempfile

from datetime import datetime, timedelta
from sandbox.sandboxsdk.process import run_process
from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.sandboxsdk import environments
from sandbox.projects.EntitySearch import resource_types
from sandbox.projects.websearch.upper.RequestSampler import RequestSampler
from sandbox.projects.websearch.upper import resources as upper_resources
from sandbox.projects.common import apihelpers
from time import time

# YT directory holding the apphost event-log tables (the '1h' partition) that are sampled.
APPHOST_LOGS_DIR = '//logs/apphost-event-log/1h'

# Apphost source name -> name of the sdk2 resource type in which a REQUEST_SAMPLER
# child task stores the sampled requests for that source.
SOURCES_TO_RESOURCE = dict(
    ENTITYSEARCH='ENTITY_SEARCH_JSON_CONTEXTS',
    YANSWER='YANSWER_GRPC_CLIENT_PLAN',
    ANTIOBJECT='ANTI_OBJECT_GRPC_CLIENT_PLAN',
)


def get_request_resource(task, resource_name):
    """Look up one resource of type ``resource_name`` attached to ``task``.

    :param task: sdk2 task whose output resources are searched
    :param resource_name: registered sdk2 resource type name
    :return: the first matching resource, or None when the query finds nothing
    """
    resource_type = sdk2.Resource[resource_name]
    single_result_query = resource_type.find(task=task).limit(1)
    return single_result_query.first()


class EntitySearchApphostRequestSampler(sdk2.Task):
    """ A task to sample apphost request via search/tools/request_sample for ENTITYSEARCH, YANSWER, ANTIOBJECT

    Workflow:
      1. List the apphost event-log tables on YT (hahn) and keep those dated one day
         before the newest table.
      2. For every such table enqueue two REQUEST_SAMPLER child tasks (see
         ``start_request_sampler_task``) and wait for all of them.
      3. For each source, concatenate the resources of the successful children, check that
         at least ``requests_count`` lines were collected and publish a random sample of
         ``requests_count`` lines as an ENTITY_SEARCH_APP_HOST_SHOOTER_PLAN resource.
    """

    class Parameters(sdk2.Parameters):
        kill_timeout = 21600  # 6 hours
        requests_count = sdk2.parameters.Integer('Requests count', default=300000)

    class Requirements(sdk2.Requirements):
        ram = 190 * 1024

        environments = [
            environments.PipEnvironment('yandex-yt')
        ]

    def start_request_sampler_task(self, request_count, request_sampler_resource_id, yt_table):
        """Enqueue two REQUEST_SAMPLER child tasks for one apphost log table.

        The first child samples every source except ENTITYSEARCH with 'grpc-client-plan'
        output. The second samples only ENTITYSEARCH with 'app-host-json' output and
        parallel_download disabled. Both child task ids are appended to self.Context.tasks.

        :param request_count: number of requests each child task is asked to sample
        :param request_sampler_resource_id: id of the REQUEST_SAMPLER_EXECUTABLE resource to run
        :param yt_table: YT path of the input apphost log table
        """
        params = {
            'sources': [source for source in SOURCES_TO_RESOURCE.keys() if source != 'ENTITYSEARCH'],
            'request_sampler': request_sampler_resource_id,
            'number_of_requests': request_count,
            'output_type': 'grpc-client-plan',
            'graphs': 'WEB',
            'check_enough_requests': False,
            'input_table': yt_table,
            'output_tables_prefix': '//tmp/serp_object_request_sampler_output_table_{}'.format(time()),
            'yt_token_vault_owner': 'robot-ontodb',
            'yt_token_vault_key': 'robot-ontodb-yt-token',
        }
        task_class = sdk2.Task['REQUEST_SAMPLER']

        logging.info('create task {} with params: {}'.format(RequestSampler.type, str(params)))
        serp_object_task = task_class(
            task_class.current,
            description='request sampler from apphost for serp object',
            owner=self.Parameters.owner,
            priority=self.Parameters.priority,
            **params
        ).enqueue()
        self.Context.tasks.append(serp_object_task.id)

        # Reuse the same params for the ENTITYSEARCH child, overriding what differs.
        params['parallel_download'] = False
        params['output_tables_prefix'] = '//tmp/entity_search_request_sampler_output_table_{}'.format(time())
        params['output_type'] = 'app-host-json'
        params['sources'] = ['ENTITYSEARCH']
        logging.info('create task {} with params: {}'.format(RequestSampler.type, str(params)))
        entity_search_task = task_class(
            task_class.current,
            description='request sampler from apphost for ENTITYSEARCH',
            owner=self.Parameters.owner,
            priority=self.Parameters.priority,
            **params
        ).enqueue()
        self.Context.tasks.append(entity_search_task.id)

    def on_execute(self):
        """Spawn sampler children on the first run; merge their results after they finish.

        :raises TaskFailure: when no input tables are found for the previous day, or when
            fewer than ``requests_count`` requests were collected for some source.
        """
        with self.memoize_stage.on_start():
            tables = self._list_previous_day_tables()
            if not tables:
                raise TaskFailure('No apphost log tables found for the previous day in {}'.format(APPHOST_LOGS_DIR))

            request_sampler_resource_id = apihelpers.get_last_released_resource(upper_resources.REQUEST_SAMPLER_EXECUTABLE).id
            self.Context.tasks = []
            # Request 4x more than needed in total, in case some child tasks fail.
            # Ceiling integer division keeps the value an int on Python 2 and 3 and never
            # lets the per-table count drop to 0.
            count = (4 * self.Parameters.requests_count + len(tables) - 1) // len(tables)
            for table in tables:
                self.start_request_sampler_task(count, request_sampler_resource_id, table)

            logging.info('Wait request sampler task')
            raise sdk2.WaitTask(self.Context.tasks, ctt.Status.Group.FINISH | ctt.Status.Group.BREAK, wait_all=True)

        logging.info('Concatenate result')
        sources_to_files = self._collect_source_resources()
        for src in SOURCES_TO_RESOURCE:
            self._publish_shooter_plan(src, sources_to_files[src])

        logging.info('Task finished, everything is ok')

    def _list_previous_day_tables(self):
        """Return YT paths of the tables dated one day before the newest table in APPHOST_LOGS_DIR.

        Returns an empty list when the directory is empty or no table matches that date.
        """
        # Imported lazily; presumably yt.wrapper only becomes importable once the
        # 'yandex-yt' PipEnvironment from Requirements is installed.
        from yt.wrapper import YtClient, ypath_join
        yt_token = sdk2.Vault.data('robot-ontodb', 'robot-ontodb-yt-token')
        yt = YtClient('hahn', yt_token)
        tables = sorted(yt.list(APPHOST_LOGS_DIR))
        if not tables:
            return []
        # Table names are expected to start with a '%Y-%m-%d' date followed by 'T'.
        newest_date = datetime.strptime(tables[-1].split('T')[0], '%Y-%m-%d')
        prev_date = (newest_date - timedelta(days=1)).strftime('%Y-%m-%d')
        return [ypath_join(APPHOST_LOGS_DIR, table) for table in tables if table.startswith(prev_date)]

    def _collect_source_resources(self):
        """Map each source in SOURCES_TO_RESOURCE to the resources of successful child tasks."""
        sources_to_files = {src: [] for src in SOURCES_TO_RESOURCE}
        for task_id in self.Context.tasks:
            logging.info("Check task with id = {}".format(task_id))
            task = sdk2.Task[task_id]
            if task.status not in ctt.Status.Group.SUCCEED:
                logging.warning('Task %s finished with status %s, skip its results', task_id, task.status)
                continue
            for src, resource_name in SOURCES_TO_RESOURCE.items():
                resource = get_request_resource(task, resource_name)
                if resource is not None:
                    sources_to_files[src].append(resource)
                    logging.info('Add file {} for source = {}'.format(resource, src))
        return sources_to_files

    def _publish_shooter_plan(self, src, file_resources):
        """Concatenate ``file_resources`` and publish ``requests_count`` random lines for ``src``.

        :raises TaskFailure: when the resources contain fewer than ``requests_count`` lines.
        """
        requests_count = self.Parameters.requests_count
        with tempfile.NamedTemporaryFile() as tmp:
            count = 0
            for file_resource in file_resources:
                # Binary mode matches the binary temp file and copies bytes verbatim.
                with open(str(sdk2.ResourceData(file_resource).path), 'rb') as f:
                    for line in f:
                        # Keep one request per line even if a file lacks a trailing newline,
                        # otherwise it would merge with the next file's first line.
                        if not line.endswith(b'\n'):
                            line += b'\n'
                        tmp.write(line)
                        count += 1
            # shuf reads the file by name, so buffered data must be on disk before it runs.
            tmp.flush()

            if count < requests_count:
                raise TaskFailure('Not enough requests for {}, need {}, in real {}'.format(src, requests_count, count))

            name = 'requests_{}'.format(src)
            desc = 'Random apphost request for {}'.format(src)
            out_res = resource_types.ENTITY_SEARCH_APP_HOST_SHOOTER_PLAN(self, desc, name, source=src)
            with open(str(sdk2.ResourceData(out_res).path), 'w') as f:
                # `shuf -n N` writes N randomly chosen lines, equivalent to `shuf | head -n N`
                # without an intermediate temp file.
                run_process(['shuf', '-n', str(requests_count), tmp.name], stdout=f)
