# coding=utf-8
import json
import hashlib
import logging
import os
from datetime import datetime, timedelta
from string import Template

import sandbox.projects.sandbox_ci.pulse.utils.yql as yql_utils
from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.common.types.task import Status
from sandbox.projects.sandbox_ci import WebMicroPackage
from sandbox.projects.sandbox_ci.pulse import parameters as pulse_params, const as pulse_const
from sandbox.projects.sandbox_ci.pulse.measure_wizard_weights import measurer
from sandbox.projects.sandbox_ci.pulse.pulse_shooter_custom import PulseShooterCustom
from sandbox.projects.sandbox_ci.pulse.pulse_shooting_basket import sql as common_sql
from sandbox.projects.sandbox_ci.pulse.resources import PulseShooterWizardWeights
from sandbox.projects.sandbox_ci.task.ManagersTaskMixin import ManagersTaskMixin
from sandbox.projects.sandbox_ci.task.binary_task import TasksResourceRequirement
from sandbox.projects.yql.RunYQL2 import RunYQL2

# Tag attached to every PulseShooterCustom subtask spawned by this task.
CUSTOM_TAG = 'WIZARD_WEIGHTS'

# Default dataset window: from 7 days before the shared pulse default date up to that date.
DEFAULT_DATE_TO = pulse_params.DEFAULT_DATE
DEFAULT_DATE_FROM = (
    datetime.strptime(DEFAULT_DATE_TO, '%Y-%m-%d') - timedelta(days=7)
).strftime('%Y-%m-%d')


class MeasureWizardWeights(TasksResourceRequirement, ManagersTaskMixin, sdk2.Task):
    """
    Compute the weights of Web wizards for a given platform
    using Pulse Shooter.

    Pipeline (each step is a memoized stage of ``on_execute``):

    1. ``check_dataset`` -- resolve the date range, hash the dataset parameters
       and, when caching is enabled, ask YQL whether that dataset exists;
    2. ``collect_dataset`` -- build the dataset with YQL unless it exists;
    3. ``run_ids_extraction`` -- select wizard IDs (``wizard_limit`` is passed
       to the query as ``limit``);
    4. ``run_shooting`` -- spawn one PulseShooterCustom subtask per wizard;
    5. ``calculate_weights`` -- collect the ``total_template_time`` and
       ``gzipped_size`` deltas and turn them into weights, published as a
       PulseShooterWizardWeights resource.
    """

    class Requirements(sdk2.Requirements):
        disk_space = 1024
        cores = 1

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Parameters):
        kill_timeout = 20 * 3600

        with sdk2.parameters.Group('Dataset Parameters') as dataset_params_block:
            date_from = sdk2.parameters.String('Date from', default=DEFAULT_DATE_FROM, required=False)
            date_to = sdk2.parameters.String('Date to', default=DEFAULT_DATE_TO, required=False)

            with sdk2.parameters.String('Project') as project:
                project.values['web4'] = project.Value('web4', default=True)

            with sdk2.parameters.String('Platform') as platform:
                platform.values['desktop'] = platform.Value('desktop')
                platform.values['touch'] = platform.Value('touch', default=True)
                platform.values['touch-pad'] = platform.Value('touch-pad')

            show_limit = sdk2.parameters.Integer('Wizard show limit', required=True, default=5000)

            use_cache = sdk2.parameters.Bool('Use cache', default=True)

        with sdk2.parameters.Group('Shooting Parameters') as shooting_params_block:
            wizard_limit = sdk2.parameters.Integer('Wizards for measure limit', required=True, default=25)
            request_number = sdk2.parameters.Integer('Request number', required=True, default=1000)
            apphost_mode = sdk2.parameters.Bool('Apphost mode', default=True)

        with sdk2.parameters.Group('Weight params') as weight_params_block:
            template_weight = sdk2.parameters.Integer('Template weight', required=True, default=70)
            size_weight = sdk2.parameters.Integer('Size weight', required=True, default=30)

    class Context(sdk2.Context):
        # sha256 of the dataset parameters; substituted into every SQL template.
        task_hash = None
        # Effective date range (parameter value or module default).
        date_from = None
        date_to = None

        # IDs of spawned subtasks.
        check_dataset_task_id = None
        collect_dataset_task_id = None
        get_wiz_ids_task_id = None
        psc_task_ids = []
        # Shooter task id (stringified) -> wizard id.
        psc_wiz_map = {}

        # Wizard IDs to measure and their per-wizard metric deltas.
        wiz_ids = []
        ps_results = {}

        # Measured weights per wizard and their mean.
        weights = {}
        avg_weight = 0

    def on_execute(self):
        # Every stage runs at most once. Stages that spawn subtasks end by
        # raising sdk2.WaitTask, so the following stage runs after they finish.
        with self.memoize_stage.check_dataset():
            self.setup_run()
            self.run_dataset_check()

        with self.memoize_stage.collect_dataset():
            self.run_dataset_collection()

        with self.memoize_stage.run_ids_extraction():
            self.check_collect_status()
            self.run_wiz_ids_extraction()

        with self.memoize_stage.run_shooting():
            self.download_wiz_ids()
            self.run_shooting()

        with self.memoize_stage.calculate_weights():
            self.get_shooting_results()
            self.measure_weights()

    # ---- helpers -------------------------------------------------------------

    @staticmethod
    def _sql_template(sql_name):
        """Load ``sql/<sql_name>`` located next to this module as a ``string.Template``."""
        tpl_path = '%s/sql/%s' % (os.path.dirname(__file__), sql_name)
        return Template(common_sql.read_file(tpl_path))

    def _enqueue_yql_subtask(self, description_suffix, query, retry_period, with_download):
        """Enqueue a RunYQL2 subtask running ``query`` and return it.

        :param description_suffix: appended to this task's description.
        :param retry_period: forwarded to RunYQL2 as ``retry_period``.
        :param with_download: publish a public JSON download link for the results.
        """
        yql_params = dict(
            owner=yql_utils.YQL_TOKEN_OWNER,
            kill_timeout=self.Parameters.kill_timeout,
            description=self.Parameters.description + description_suffix,
            query=query,
            yql_token_vault_name=yql_utils.YQL_TOKEN_NAME,
            use_v1_syntax=True,
            publish_query=True,
            publish_download_link=with_download,
            public_download_link=with_download,
            trace_query=True,
            retry_period=retry_period,
        )
        if with_download:
            yql_params['download_format'] = 'JSON'
        return RunYQL2(self, **yql_params).enqueue()

    def _wait_subtasks(self, tasks):
        """Suspend this task until ``tasks`` (a task or a list of task ids) finish or break."""
        raise sdk2.WaitTask(
            tasks,
            Status.Group.FINISH | Status.Group.BREAK,
            timeout=self.Parameters.kill_timeout,
            wait_all=True,
        )

    @staticmethod
    def _del_intent_query_params(wiz_id, enabled):
        """Return hamster query params plus the ``del_intent_docs`` blender rearr for ``wiz_id``.

        ``enabled=False`` builds the base variant (``"DeleteIntents":{wiz_id:0}``,
        ``"on":0``). ``enabled=True`` builds the actual one (``...:1``, ``"on":1``).
        NOTE(review): judging by the command name, the actual variant presumably
        removes the wizard's documents. Confirm against the blender docs.
        """
        flag = 1 if enabled else 0
        query_params = pulse_const.HAMSTER_QUERY_PARAM_LIST[:]
        query_params.append(
            ('rearr=scheme_blender/commands/after_blend/del_intent_for_velocity_check='
             '{"Cmd":"del_intent_docs",'
             '"DeleteIntents":{"%s":%d},'
             '"Places":["Wizplace", "Right", "Main"],'
             '"spkey": "del_intent_for_velocity_check",'
             '"on":%d}') % (wiz_id, flag, flag)
        )
        return query_params

    @staticmethod
    def _extract_post_search_deltas(results):
        """Return ``{raw_name: avg delta}`` from a PulseShooterCustom ``results`` list.

        Only the ``total_template_time`` and ``gzipped_size`` rows of tables whose
        subtitle is ``post-search`` are used. For each row, the delta of the first
        ``avg`` value cell is taken. Rows without one, or whose delta is None,
        are skipped. ``results`` may be None, which yields an empty dict.
        """
        deltas = {}
        for table in results or ():
            if table['subtitle'] != 'post-search':
                continue

            for row in table['rows']:
                raw_name = row['raw_name']
                if raw_name not in ('total_template_time', 'gzipped_size'):
                    continue

                val = next(
                    (cell['delta'] for cell in row['values'] if cell['aggr_name'] == 'avg'),
                    None,
                )
                if val is not None:
                    deltas[raw_name] = val
        return deltas

    # ---- stages --------------------------------------------------------------

    def setup_run(self):
        """Resolve the effective date range and compute ``Context.task_hash``.

        The hash covers project, platform, date range and show limit. It is
        substituted into the SQL templates as ``task_hash``.
        """
        if self.Parameters.date_from:
            self.Context.date_from = str(self.Parameters.date_from)
        else:
            self.set_info('Using default date from: %s' % DEFAULT_DATE_FROM)
            self.Context.date_from = DEFAULT_DATE_FROM

        if self.Parameters.date_to:
            self.Context.date_to = str(self.Parameters.date_to)
        else:
            self.set_info('Using default date to: %s' % DEFAULT_DATE_TO)
            self.Context.date_to = DEFAULT_DATE_TO

        hash_base = json.dumps({
            'project': str(self.Parameters.project),
            'platform': str(self.Parameters.platform),
            'date_from': self.Context.date_from,
            'date_to': self.Context.date_to,
            'show_limit': str(self.Parameters.show_limit),
        }, sort_keys=True)
        logging.debug('Hash base: %s', hash_base)
        # json.dumps output is ASCII; encode explicitly because hashlib needs bytes on Python 3.
        self.Context.task_hash = hashlib.sha256(hash_base.encode('utf-8')).hexdigest()

    def run_dataset_check(self):
        """Spawn a YQL subtask checking whether the dataset for ``task_hash`` exists.

        Returns immediately when caching is disabled; otherwise raises ``sdk2.WaitTask``.
        """
        if not self.Parameters.use_cache:
            return

        yql_query = self._sql_template('check-dataset-exists.sql').safe_substitute(
            task_hash=self.Context.task_hash,
        )
        logging.debug('=== YQL query begin\n%s\n=== YQL query end', yql_query)

        subtask = self._enqueue_yql_subtask(' – check dataset', yql_query, retry_period=5, with_download=True)
        self.Context.check_dataset_task_id = subtask.id
        self._wait_subtasks(subtask)

    def need_collect_dataset(self):
        """Return True when the dataset has to be (re)built.

        That is the case when no check was run (caching disabled), or when the
        check result does not say ``dataset_exists`` or cannot be parsed.

        :raises TaskFailure: if the check subtask did not succeed.
        """
        if not self.Context.check_dataset_task_id:
            return True

        check_task = sdk2.Task[self.Context.check_dataset_task_id]
        if check_task.status not in Status.Group.SUCCEED:
            raise TaskFailure('Dataset check ends with status %s' % check_task.status)

        result_url = str(check_task.Parameters.results_sample_download_url)
        result_file = 'dataset_exists_text'
        yql_utils.download_from_yql(result_url, result_file)

        dataset_exists = False
        try:
            with open(result_file) as fp:
                dataset_exists = json.load(fp)['dataset_exists']
        except (ValueError, TypeError, KeyError) as e:
            # Best effort: an unreadable check result just means "collect again".
            logging.warning('Dataset existing read error: %s', e)

        return not dataset_exists

    def run_dataset_collection(self):
        """Spawn the YQL subtask that prepares the dataset, unless it already exists.

        :raises TaskFailure: if no route or handler is configured for project/platform.
        """
        if not self.need_collect_dataset():
            return self.set_info('Using existing dataset for current params')

        project = str(self.Parameters.project)
        platform = str(self.Parameters.platform)

        routes_id = pulse_const.ROUTES_CONFIG.get(project, {}).get(platform)
        if not routes_id:
            raise TaskFailure('There is no route for %s:%s' % (project, platform))

        request_pattern = pulse_const.HANDLERS_MAP.get(project, {}).get(platform)
        if not request_pattern:
            raise TaskFailure('There is no handler for %s:%s' % (project, platform))
        request_pattern += '?'

        yql_query = self._sql_template('%s/prepare-dataset.sql' % project).safe_substitute(
            task_hash=self.Context.task_hash,
            date_from=self.Context.date_from,
            date_to=self.Context.date_to,
            routes_id=routes_id,
            request_pattern=request_pattern,
            show_limit=self.Parameters.show_limit,
        )

        subtask = self._enqueue_yql_subtask(' – collect dataset', yql_query, retry_period=120, with_download=False)
        self.Context.collect_dataset_task_id = subtask.id
        self._wait_subtasks(subtask)

    def check_collect_status(self):
        """Fail if a dataset collection subtask was spawned and did not succeed."""
        if not self.Context.collect_dataset_task_id:
            return

        task = sdk2.Task[self.Context.collect_dataset_task_id]
        if task.status not in Status.Group.SUCCEED:
            raise TaskFailure('Dataset collection ended with status %s' % task.status)

    def run_wiz_ids_extraction(self):
        """Spawn the YQL subtask selecting wizard IDs to measure (``limit`` = ``wizard_limit``)."""
        yql_query = self._sql_template('get-ids.sql').safe_substitute(
            task_hash=self.Context.task_hash,
            limit=int(self.Parameters.wizard_limit),
        )

        subtask = self._enqueue_yql_subtask(' – get ids', yql_query, retry_period=30, with_download=True)
        self.Context.get_wiz_ids_task_id = subtask.id
        self._wait_subtasks(subtask)

    def download_wiz_ids(self):
        """Download the extracted IDs (one JSON object per line) into ``Context.wiz_ids``.

        :raises TaskFailure: if extraction failed or produced no IDs.
        """
        task = sdk2.Task[self.Context.get_wiz_ids_task_id]
        if task.status not in Status.Group.SUCCEED:
            raise TaskFailure('IDS extraction ended with status %s' % task.status)

        dest_file = 'ids_data'
        res_params = task.Parameters
        download_url = res_params.full_results_download_url or res_params.results_sample_download_url
        yql_utils.download_from_yql(download_url, dest_file)

        wiz_ids = []
        with open(dest_file) as fp:
            for line in fp:
                line = line.strip()
                if not line:
                    # Tolerate blank/trailing lines: json.loads('') raises ValueError.
                    continue

                data = json.loads(line) or {}
                wiz_id = data.get('id')
                if wiz_id:
                    wiz_ids.append(wiz_id)

        if not wiz_ids:
            # Without IDs there is nothing to shoot and the weight average would divide by zero.
            raise TaskFailure('No wizard IDs were extracted')

        self.Context.wiz_ids = wiz_ids
        # '%s' formatting also copes with non-string IDs (join alone would raise TypeError).
        self.set_info('Measure wizard with IDs: %s' % ', '.join('%s' % wiz_id for wiz_id in wiz_ids))

    def run_shooting(self):
        """Spawn one PulseShooterCustom subtask per wizard and wait for all of them.

        Both variants shoot the latest ``released=stable`` templates package. They
        differ only in the ``del_intent_docs`` rearr for the wizard.

        :raises TaskFailure: on an unknown project or a missing templates release.
        """
        tpl = self._sql_template('select-ammo-for-id.sql')

        project = str(self.Parameters.project)
        if project == 'web4':
            res_type = WebMicroPackage
        else:
            raise TaskFailure('Unknown project %s' % project)

        # noinspection PyTypeChecker
        templates = sdk2.Resource.find(
            res_type, attrs=dict(released='stable')
        ).order(-sdk2.Resource.id).first()

        if not templates:
            raise TaskFailure('Last templates release is not found')

        requests_number = int(self.Parameters.request_number)
        # round() returns a float on Python 2; these are integer shooter parameters.
        access_log_threshold = int(max(round(requests_number * 0.6), 600))
        plan_threshold = int(max(round(requests_number * 0.5), 500))

        psc_task_ids = []
        psc_wiz_map = {}
        for wiz_id in self.Context.wiz_ids:
            yql_query = tpl.safe_substitute(
                task_hash=self.Context.task_hash,
                wiz_id=wiz_id,
                limit=requests_number,
            )

            subtask = PulseShooterCustom(
                self,
                description=self.Parameters.description + ' – shoot %s' % wiz_id,
                ammo_type='custom',
                yql_query=yql_query,
                build_actual_ammo=True,
                base_query_params=self._del_intent_query_params(wiz_id, enabled=False),
                actual_query_params=self._del_intent_query_params(wiz_id, enabled=True),
                requests_number=max(requests_number, 1000),
                access_log_threshold=access_log_threshold,
                plan_threshold=plan_threshold,
                project=project,
                platform=self.Parameters.platform,
                base_templates_package=templates.id,
                actual_templates_package=templates.id,
                apphost_mode=self.Parameters.apphost_mode,
                tags=(CUSTOM_TAG,),
            ).enqueue()

            psc_task_ids.append(subtask.id)
            # get_shooting_results looks this map up by str(task_id); store string
            # keys so the lookup works whether or not the context was serialized.
            psc_wiz_map[str(subtask.id)] = wiz_id

        self.Context.psc_task_ids = psc_task_ids
        self.Context.psc_wiz_map = psc_wiz_map

        self._wait_subtasks(psc_task_ids)

    def get_shooting_results(self):
        """Collect per-wizard post-search metric deltas into ``Context.ps_results``.

        :raises TaskFailure: if a shooter failed or reported no relevant metrics.
        """
        ps_results = {}
        for task_id in self.Context.psc_task_ids:
            psc_task = sdk2.Task[task_id]
            if psc_task.status not in Status.Group.SUCCEED:
                raise TaskFailure('Pulse Shooter Custom %s failed with status %s' % (task_id, psc_task.status))

            wiz_id = self.Context.psc_wiz_map[str(task_id)]
            ps_local_wiz_result = self._extract_post_search_deltas(psc_task.Parameters.results)
            if not ps_local_wiz_result:
                raise TaskFailure('No results about templating and size for PSC %s' % task_id)

            ps_results[wiz_id] = ps_local_wiz_result

        self.Context.ps_results = ps_results

    def measure_weights(self):
        """Compute wizard weights and publish them as a PulseShooterWizardWeights resource.

        Writes ``weights/weights.json`` (weights and their mean) and
        ``weights/diagram.txt``.

        :raises TaskFailure: if no weights were produced.
        """
        weights = measurer.measure_weights(
            self.Context.ps_results,
            template_weight=int(self.Parameters.template_weight),
            size_weight=int(self.Parameters.size_weight),
        )
        if not weights:
            raise TaskFailure('No wizard weights were measured')

        # values() works on Python 2 and 3; float() avoids Python 2 integer-division truncation.
        avg_weight = float(sum(weights.values())) / len(weights)

        diagram = measurer.draw_diagram(weights)

        self.Context.weights = weights
        self.Context.avg_weight = avg_weight

        weights_dir = 'weights'
        if not os.path.exists(weights_dir):
            os.mkdir(weights_dir)

        weights_file = os.path.join(weights_dir, 'weights.json')
        diagram_file = os.path.join(weights_dir, 'diagram.txt')

        with open(weights_file, 'w') as fw, open(diagram_file, 'w') as fd:
            json.dump({
                'weights': weights,
                'avg_weight': avg_weight,
            }, fw, indent=2, sort_keys=True)

            fd.write(diagram)

        res = PulseShooterWizardWeights(
            self, 'Pulse Shooter wizard weights',
            weights_dir,
            project=str(self.Parameters.project),
            platform=str(self.Parameters.platform),
        )
        sdk2.ResourceData(res).ready()
