# -*- coding: utf-8 -*-

import json
import datetime
import pytz
import logging

from sandbox import sdk2
from sandbox.common.types.task import Status
from sandbox.projects.common import link_builder
from sandbox.projects.websearch.clickdaemon import resources
from sandbox.sandboxsdk import environments

from sandbox.projects.websearch.clickdaemon.tasks.redir_log_loss.CalculateRedirLogLoss import CalculateRedirLogLoss
from sandbox.projects.websearch.clickdaemon.tasks.redir_log_loss import parameters as params


class UpdateRedirLogLoss(sdk2.Task):
    """
        Задача обновляет статистику по потерям redir-лога
    """

    class Requirements(sdk2.Task.Requirements):
        # Resources requested for a single run of this task.
        cores = 1
        disk_space = 50  # 50 Mb
        # Packages listed here back the function-level `yt.wrapper` and
        # `statface_client` imports used by the task's methods (presumably
        # installed via pip before the task runs -- confirm against the SDK).
        # python-statface-client is pinned, together with a pinned `requests`.
        environments = [
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('python-statface-client', version="0.154.0", custom_parameters=["requests==2.18.4"]),
        ]

        class Caches(sdk2.Requirements.Caches):
            pass  # do not use any shared caches

    class Parameters(sdk2.Task.Parameters):
        # Log types to process; the parameter itself is declared in the shared
        # `parameters` module of the redir_log_loss package.
        log_types = params.log_types()

        # How the list of hours is chosen: either the last N hours counted back
        # from "now", or an explicit inclusive [start_hour, end_hour] range.
        # The value keys ('last_hours' / 'range') are compared in get_hours().
        with sdk2.parameters.RadioGroup("Hours range type") as hours_range_type:
            hours_range_type.values['last_hours'] = hours_range_type.Value('Last N hours', default=True)
            hours_range_type.values['range'] = hours_range_type.Value('Select range')

            # Sub-parameters declared under each radio value.
            with hours_range_type.value['last_hours']:
                hours = sdk2.parameters.Integer('Number of last hours to calculate redir-log loss', required=True)
            with hours_range_type.value['range']:
                start_hour = sdk2.parameters.String('Start hour to process in "YYYY-mm-dd HH:00:00" format', required=True)
                end_hour = sdk2.parameters.String('End hour to process in "YYYY-mm-dd HH:00:00" format', required=True)

        # Filters applied in on_execute() before subtasks are launched.
        check_tables_existence = sdk2.parameters.Bool('Check that required tables exist', default=True)
        force_recalculate = sdk2.parameters.Bool('Recalculate already processed hours', default=False)
        # Optional tool resource, passed through to every CalculateRedirLogLoss subtask.
        redir_log_loss_tool = sdk2.parameters.Resource('redir-logs loss tool', resource_type=resources.RedirLogLossTool, required=False)

        # Statface report that is read (to detect already calculated hours) and
        # written (to upload the merged stats).
        statface_report_path = sdk2.parameters.String('Statface report path', required=True)

        # YT connection settings; the *_token_owner values are the Vault owners
        # of the 'YT_TOKEN' and 'STATFACE_TOKEN' secrets respectively.
        yt_cluster = sdk2.parameters.String('YT cluster', default='hahn')
        yt_pool = sdk2.parameters.String('YT pool')
        yt_token_owner = sdk2.parameters.String('YT_TOKEN owner', required=True)
        statface_token_owner = sdk2.parameters.String('STATFACE_TOKEN owner', required=True)

        # Passed through to subtasks; when both are set, old temporary req-id
        # tables are also cleaned up at the end of on_execute().
        process_unclassified = sdk2.parameters.Bool('Use unclassified req-ids table if possible', default=True)
        store_unclassified = sdk2.parameters.Bool('Store unclassified req-ids table in YT if possible', default=True)

        # When non-empty, overrides the built-in 'redir_log' table list in
        # get_log_paths(); keys are paths under //logs/, values are scales.
        redir_logs_dict = sdk2.parameters.Dict(
            "Redir logs list (path: scale)", default={},
            description="scale types: 1d, 1h, 30min. path will be prefixed with //logs/ in code"
        )

    def get_hours(self):
        if self.Parameters.hours_range_type == 'last_hours':
            current_date = datetime.datetime.now(pytz.timezone('Europe/Moscow'))
            return [current_date - datetime.timedelta(hours=t) for t in range(self.Parameters.hours)]
        elif self.Parameters.hours_range_type == 'range':
            start_hour = datetime.datetime.strptime(self.Parameters.start_hour, '%Y-%m-%d %H:%M:%S')
            end_hour = datetime.datetime.strptime(self.Parameters.end_hour, '%Y-%m-%d %H:%M:%S')
            return [start_hour + datetime.timedelta(hours=t) for t in range(int((end_hour - start_hour).total_seconds() / 60 / 60) + 1)]

    def get_statface_report(self, report_path):
        """
            Return the Statface report object for `report_path`.

            The client talks to the production Statface host and authenticates with
            the 'STATFACE_TOKEN' Vault secret owned by `statface_token_owner`.
        """
        # Imported lazily: the package is only listed in the task's Requirements.
        import statface_client

        oauth_token = sdk2.Vault.data(self.Parameters.statface_token_owner, name='STATFACE_TOKEN')
        client = statface_client.StatfaceClient(
            host=statface_client.STATFACE_PRODUCTION,
            oauth_token=oauth_token,
        )
        return client.get_report(report_path)

    def get_data(self, report, date_str):
        """
            Download the hourly report rows for a single hour and strip them down.

            `date_str` may use either 'T' or a space between date and time. Each
            returned dict keeps only the keys from params.REPORT_KEYS plus
            'fielddate' and 'method', and only when the value is not None.
        """
        import statface_client

        hour = date_str.replace('T', ' ')
        rows = report.download_data(scale=statface_client.constants.HOURLY_SCALE, date_min=hour, date_max=hour)

        kept_keys = params.REPORT_KEYS + ['fielddate', 'method']
        return [
            {key: value for key, value in row.items() if key in kept_keys and value is not None}
            for row in rows
        ]

    def merge_data(self, datas):
        result = {}
        for data in datas:
            result.setdefault('{}#{}'.format(data['fielddate'], data['method']), {}).update(data)
        return result.values()

    def get_log_paths(self, log_type, date):
        """
            Return the YT table paths of `log_type` that cover the hour of `date`.

            For 'redir_log' with a non-empty `redir_logs_dict` parameter the table
            list comes from that dict (each path prefixed with '//logs/'); otherwise
            it comes from params.LOG_TYPES[log_type]['logs'].

            Each path is '<cypress_root>/<scale>/<timestamp>': a daily log gives one
            date-only entry, an hourly log one HH:00:00 entry, a half-hourly log both
            the HH:00:00 and HH:30:00 entries. Logs with any other scale give nothing.
        """
        if log_type == 'redir_log' and self.Parameters.redir_logs_dict:
            logs = [
                {'cypress_root': '//logs/{}'.format(path), 'scale': scale}
                for path, scale in self.Parameters.redir_logs_dict.items()
            ]
        else:
            logs = params.LOG_TYPES[log_type]['logs']

        paths = []
        for log in logs:
            scale = log['scale']
            if scale == params.DAILY:
                patterns = ('%Y-%m-%d',)
            elif scale == params.HOURLY:
                patterns = ('%Y-%m-%dT%H:00:00',)
            elif scale == params.HALF_HOURLY:
                patterns = ('%Y-%m-%dT%H:00:00', '%Y-%m-%dT%H:30:00')
            else:
                patterns = ()

            for pattern in patterns:
                paths.append('{}/{}/{}'.format(log['cypress_root'], scale, date.strftime(pattern)))

        return paths

    def get_yt_client(self):
        """
            Create a YT client for the `yt_cluster` parameter, using the 'YT_TOKEN'
            Vault secret owned by `yt_token_owner`.
        """
        # Imported lazily: the package is only listed in the task's Requirements.
        from yt.wrapper import YtClient

        token = sdk2.Vault.data(self.Parameters.yt_token_owner, name='YT_TOKEN')
        cluster = self.Parameters.yt_cluster
        return YtClient(cluster, token)

    def delete_temporary_reqid_tables(self):
        """
            Remove temporary req-id tables that are more than 14 days old.

            Nodes under params.TEMPORARY_REQID_TABLES_PATH are expected to be named
            '<YYYY-mm-ddTHH:MM:SS>-GET' or '<YYYY-mm-ddTHH:MM:SS>-POST'. A node whose
            name does not follow that pattern is logged and skipped, so that a stray
            node cannot fail the whole task during this final cleanup step.
        """
        yt_client = self.get_yt_client()
        root = params.TEMPORARY_REQID_TABLES_PATH
        # NOTE(review): naive local time is compared with the timestamp embedded
        # in the table name -- confirm both are in the same timezone.
        now = datetime.datetime.now()
        max_age = datetime.timedelta(days=14)

        for table_name in yt_client.list(root):
            table_date = None
            for suffix in ('-GET', '-POST'):
                if not table_name.endswith(suffix):
                    continue
                try:
                    table_date = datetime.datetime.strptime(table_name[:-len(suffix)], '%Y-%m-%dT%H:%M:%S')
                except ValueError:
                    # Right suffix but the prefix is not a timestamp; handled below.
                    table_date = None
                break

            if table_date is None:
                logging.warning("Skipping unexpected node in %s: %s", root, table_name)
                continue

            if now - table_date > max_age:
                yt_client.remove('{}/{}'.format(root, table_name), force=True)

    def on_execute(self):
        """
            Run the two memoized stages of the task, then clean up.

            Stage 1 (launch): take the hours from get_hours(), drop the hours whose
            source tables are missing (when `check_tables_existence` is set) and the
            hours whose data is already in the Statface report (unless
            `force_recalculate` is set), then enqueue one CalculateRedirLogLoss
            subtask per remaining hour.

            Stage 2 (wait): on every wake-up, take the subtasks that reached a
            FINISH status, merge their stats with the rows already in the report,
            upload the result, and wait again while other subtasks are neither
            finished nor broken.

            Finally, when both `process_unclassified` and `store_unclassified` are
            set, old temporary req-id tables are removed.
        """
        import statface_client

        logging.info("\nINFO: REDIR_LOG_TABLES:\n{}\n".format("\n".join(["{}={}".format(path, scale) for path, scale in sorted(self.Parameters.redir_logs_dict.items())])))

        # NOTE(review): the stage name is misspelled ("caclulation"); it is left
        # untouched because renaming it presumably changes the memoization key.
        with self.memoize_stage.launch_caclulation:
            hours = self.get_hours()
            yt_client = self.get_yt_client()

            def required_tables_exist(date):
                # Collect the tables of every selected log type for this hour.
                table_paths = []
                for log_type in params.LOG_TYPES:
                    if log_type not in self.Parameters.log_types:
                        continue
                    table_paths.extend(self.get_log_paths(log_type, date))

                logging.info("\nINFO: Table paths existence check: \n{}\n".format("\n".join(sorted(table_paths))))
                # Generator form stops at the first missing table.
                return all(yt_client.exists(table_path) for table_path in table_paths)

            available_hours = [
                hour for hour in hours
                if not self.Parameters.check_tables_existence or required_tables_exist(hour)
            ]
            # Hours are truncated to HH:00:00, de-duplicated and kept as sorted strings.
            self.Context.unprocessed_hours = sorted({hour.strftime('%Y-%m-%dT%H:00:00') for hour in available_hours})

            logging.info("\nINFO: unprocessed hours after required tables exist check: \n{}\n".format("\n".join(self.Context.unprocessed_hours)))

            report = self.get_statface_report(self.Parameters.statface_report_path)

            def data_is_calculated(date_str):
                data = self.get_data(report, date_str)

                # NOTE(review): exactly 3 report rows are expected per hour
                # (presumably one per 'method' value) -- confirm.
                if len(data) != 3:
                    return False

                # Every row must carry every report key of every selected log type.
                return all(
                    item in section
                    for log_type in self.Parameters.log_types
                    for section in data
                    for item in params.LOG_TYPES[log_type]['report_keys']
                )

            self.Context.unprocessed_hours = [
                hour for hour in self.Context.unprocessed_hours
                if self.Parameters.force_recalculate or not data_is_calculated(hour)
            ]

            logging.info("\nINFO: unprocessed hours after data is calculated check: \n{}\n".format("\n".join(self.Context.unprocessed_hours)))

            # Maps hour string -> id of the subtask calculating that hour.
            self.Context.calc_loss_subtasks = dict()
            for hour in self.Context.unprocessed_hours:
                self.Context.calc_loss_subtasks[hour] = CalculateRedirLogLoss(
                    self,
                    description='Calculate loss for {}'.format(hour.replace('T', ' ')),
                    kill_timeout=12 * 3600,
                    redir_log_loss_tool=self.Parameters.redir_log_loss_tool,
                    log_types=self.Parameters.log_types,
                    date=hour,
                    yt_cluster=self.Parameters.yt_cluster,
                    yt_pool=self.Parameters.yt_pool,
                    yt_token_owner=self.Parameters.yt_token_owner,
                    process_unclassified=self.Parameters.process_unclassified,
                    store_unclassified=self.Parameters.store_unclassified,
                    redir_logs_dict=self.Parameters.redir_logs_dict,
                ).enqueue().id

        # The stage is allowed to run once per subtask plus once more, because the
        # task is woken up as soon as any single subtask finishes (wait_all=False).
        with self.memoize_stage.wait_tasks(len(self.Context.calc_loss_subtasks) + 1):
            processed_hours = []
            working_subtask_ids = []

            for hour in self.Context.unprocessed_hours:
                task = sdk2.Task[self.Context.calc_loss_subtasks[hour]]
                if task.status in Status.Group.FINISH:
                    processed_hours.append(hour)
                elif task.status not in Status.Group.BREAK:
                    working_subtask_ids.append(task.id)
                # Subtasks in a BREAK status are neither processed nor waited for.

            for hour in processed_hours:
                self.Context.unprocessed_hours.remove(hour)

            if processed_hours:
                self.set_info('Done: {}'.format(', '.join(map(
                    lambda hour: link_builder.task_link(self.Context.calc_loss_subtasks[hour], hour),
                    processed_hours
                ))), do_escape=False)

                report = self.get_statface_report(self.Parameters.statface_report_path)
                stats = []
                for hour in processed_hours:
                    calc_task_id = self.Context.calc_loss_subtasks[hour]
                    calc_task = sdk2.Task[calc_task_id]
                    stat_file_path = str(sdk2.ResourceData(calc_task.Parameters.stats).path)
                    existing_rows = self.get_data(report, hour)
                    # Close the stats file deterministically instead of leaking the handle.
                    with open(stat_file_path, 'r') as stat_file:
                        calculated_rows = json.load(stat_file)
                    stats += self.merge_data(existing_rows + calculated_rows)
                report.upload_data(scale=statface_client.constants.HOURLY_SCALE, data=stats)

            if working_subtask_ids:
                raise sdk2.WaitTask(working_subtask_ids, Status.Group.FINISH | Status.Group.BREAK, wait_all=False)

        if self.Parameters.process_unclassified and self.Parameters.store_unclassified:
            self.delete_temporary_reqid_tables()
