from __future__ import print_function
import datetime
import json
import logging
import time
import requests
import os

from dateutil.tz import tzlocal

from sandbox.common.errors import TaskFailure
from sandbox.common.types.resource import State
from sandbox.sdk2 import (
    Task,
    ResourceData,
    parameters,
)
from sandbox.sdk2.helpers import subprocess, ProcessLog

from sandbox.projects.yabs.base_bin_task import BaseBinTask
from sandbox.projects.yabs.cpm_multiplier.resources import (
    get_last_resource,
    YabsCpmMultiplierBinary,
    YabsCpmMultiplier2Binary,
    YabsCpmMultiplierConfig,
    YabsCpmMultiplierYtBackup,
)


class YabsCpmMultiplierExecutor(BaseBinTask):
    '''Updates CPMMultiplier values

    In normal mode the cpm_multiplier binary runs on one YT cluster, the
    resulting CPMMultiplier table is linked to ``link_path``, results are
    copied to the remaining ``clusters`` via TransferManager and, optionally,
    a statistics-delay metric is pushed to Solomon.

    When ``backup_resource`` is set, the run is replayed from the saved
    parameters instead, writing under ``<yt_prefix>/<task id>`` with an
    expiration time of ``result_ttl`` days.
    '''
    class Requirements(Task.Requirements):
        cores = 1
        ram = 4096
        disk_space = 4096

        class Caches(Task.Requirements.Caches):
            pass

    class Parameters(BaseBinTask.Parameters):
        description = 'Update CPMMultiplier values'

        with BaseBinTask.Parameters.version_and_task_resource() as version_and_task_resource:
            resource_attrs = parameters.Dict('Filter resource by', default={'name': 'YabsCpmMultiplier'})

        with parameters.Group('Resources') as resource_params:
            binary_resource = parameters.Resource(
                'Binary resource',
                resource_type=YabsCpmMultiplierBinary,
                state=State.READY,
            )
            binary_resource2 = parameters.Resource(
                'New cpm_multiplier2 binary resource',
                resource_type=YabsCpmMultiplier2Binary,
                state=State.READY,
            )
            config_resource = parameters.Resource(
                'Parameters resource',
                resource_type=YabsCpmMultiplierConfig,
                state=State.READY,
            )

        with parameters.Group('YT parameters') as yt_params:
            yt_token_secret_id = parameters.YavSecret(
                label="YT token secret id",
                required=True,
                description='secret should contain keys: YT_TOKEN',
                default="sec-01d6dyn0qa3xds1mp820ssgbez",
            )
            yt_pool = parameters.String('Yt pool')
            yt_log_level = parameters.String(
                'YT_LOG_LEVEL',
                default='INFO',
            )
            tm_queue = parameters.String(
                'TransferManager queue',
            )
            tm_pool = parameters.String(
                'TransferManager YT pool',
            )

        with parameters.Group('Execution parameters') as exec_params:
            clusters = parameters.List(
                'YT clusters',
                required=True,
                default=['hahn', 'arnold'],
            )
            stat_path = parameters.String(
                'Path to RtbPageStat table',
                required=True,
                default='//home/yabs/stat/RtbPageStat',
            )
            dict_path = parameters.String(
                'Dictionaries path',
                required=True,
                default='//home/yabs/dict',
            )
            yt_prefix = parameters.String(
                'Output prefix',
                required=True,
                default='//home/yabs/stat/cpm_multiplier',
            )
            link_path = parameters.String(
                'Result link path',
                required=True,
                default='//home/yabs/stat/CPMMultiplier',
            )
            use_new_binary = parameters.Bool(
                'Use new cpm_multiplier2 binary',
                default=False,
            )

        with parameters.Group('Solomon parameters') as statistics:
            send_to_solomon = parameters.Bool('Send metrics to solomon', default=False)
            solomon_token = parameters.YavSecret(
                'Solomon token secret',
                description='Required key: solomon_token',
            )
            solomon_api_url = parameters.String('Solomon API URL', default='https://solomon.yandex.net/api/v2/push')
            solomon_project = parameters.String('Solomon project', default='yabs')
            solomon_cluster = parameters.String('Solomon cluster', default='yabs')
            solomon_service = parameters.String('Solomon service', default='cpm_multiplier')

        with parameters.Group('Testing parameters') as test_params:
            current_time = parameters.Integer(
                'Current timestamp',
                required=False,
            )
            stat_time = parameters.Integer(
                'Statistics timestamp',
                required=False,
            )
            backup_resource = parameters.Resource(
                'Backup resource (overwrites execution parameters)',
                resource_type=YabsCpmMultiplierYtBackup,
                state=State.READY,
            )
            result_ttl = parameters.Integer(
                'Output ttl (days)',
                default=14,
            )

    def _tm_task_params(self, copy_attrs):
        '''Build the ``params`` dict for a single TransferManager copy task.

        :param copy_attrs: extra table attributes TransferManager should copy
        :returns: a fresh dict with the optional pool and queue settings applied
        '''
        params = {'additional_attributes': copy_attrs}
        pool = self.Parameters.tm_pool
        if pool:
            params['copy_spec'] = {'pool': pool}
            params['postprocess_spec'] = {'pool': pool}
        if self.Parameters.tm_queue:
            params['queue_name'] = self.Parameters.tm_queue
        return params

    @staticmethod
    def _wait_tm_task(tmc, task_id):
        '''Poll TransferManager every 10 seconds until ``task_id`` leaves 'pending'/'running'.

        :returns: the final task state string (e.g. 'completed')
        '''
        state = 'pending'
        while state in ('pending', 'running'):
            time.sleep(10)
            state = tmc.get_task_info(task_id)['state']
            logging.info('Waiting for task %s. Current state: %s.', task_id, state)
        return state

    def do_copy_results(self, yt_token, source_cluster, target_clusters, tables, link_dest, copy_attrs=()):
        '''Copy ``tables`` from ``source_cluster`` to every cluster in ``target_clusters``.

        Each table is copied to the same path on the destination. On every
        cluster where all copy tasks completed, ``link_path`` is repointed to
        ``link_dest``. A cluster with a failed copy keeps its old link.

        :param yt_token: YT token used for TransferManager and YT clients
        :param source_cluster: cluster the tables are read from
        :param target_clusters: iterable of destination cluster names
        :param tables: YT paths to copy
        :param link_dest: link target to set on successfully updated clusters
        :param copy_attrs: extra table attributes passed to TransferManager
        :raises TaskFailure: if copying failed for at least one cluster
        '''
        from yabs.yt_wrapper import YtClient
        from yt.transfer_manager.client import TransferManager

        tmc = TransferManager(token=yt_token)

        # Submit all copy tasks first so they run concurrently, then wait for them.
        tm_tasks = {}
        for cluster in target_clusters:
            tm_tasks[cluster] = []
            for table in tables:
                tm_task_id = tmc.add_task(source_cluster, table, cluster, table, params=self._tm_task_params(copy_attrs))
                logging.info('Created TM task #%s: %s.[%s] -> %s.[%s]', tm_task_id, source_cluster, table, cluster, table)
                tm_tasks[cluster].append(tm_task_id)

        failed_clusters = []
        for cluster, task_ids in tm_tasks.items():
            logging.info('Waiting for cluster %s.', cluster)
            for task_id in task_ids:
                if self._wait_tm_task(tmc, task_id) != 'completed':
                    failed_clusters.append(cluster)
                    logging.error('Failed to copy to cluster %s because of task %s', cluster, task_id)
                    break
            else:
                # Every table for this cluster arrived, so it is safe to switch its link.
                ytc = YtClient(token=yt_token, proxy=cluster)
                ytc.link(link_dest, self.Parameters.link_path, force=True)

        if failed_clusters:
            raise TaskFailure('Failed to copy results to clusters: {}'.format(', '.join(sorted(failed_clusters))))

    def set_product_type_mapping(self, ytc, conf_file, result_table):
        '''Store the config's ``ProductTypeToBlock`` mapping on ``result_table``.

        The mapping is written as a JSON string in the
        ``product_type_to_block`` attribute. It is ``{}`` when the key is
        absent from the config.
        '''
        with open(conf_file, 'r') as f:
            config = json.load(f)

        mapping = config.get('ProductTypeToBlock', {})
        ytc.set_attribute(result_table, 'product_type_to_block', json.dumps(mapping))

    def do_run(self, exec_cluster, current_time, stat_time, stat_path, dict_path, prev_path, out_path, ytc):
        '''Run the cpm_multiplier binary and annotate its CPMMultiplier output table.

        Explicit resource parameters are used when set. Otherwise the latest
        binary of the selected flavour and the latest config are taken.

        :param exec_cluster: YT proxy the binary runs against
        :param current_time: timestamp passed as ``--time``
        :param stat_time: timestamp passed as ``--stat-time``
        :param stat_path: statistics table path
        :param dict_path: dictionaries path
        :param prev_path: previous output prefix, or a falsy value for none
        :param out_path: output prefix for this run
        :param ytc: YT client used to set attributes on the result table
        '''
        use_new = self.Parameters.use_new_binary
        bin_res = self.Parameters.binary_resource2 if use_new else self.Parameters.binary_resource
        if not bin_res:
            bin_res = get_last_resource(YabsCpmMultiplier2Binary if use_new else YabsCpmMultiplierBinary)
            logging.info('Found binary resource %s', bin_res.id)

        conf_res = self.Parameters.config_resource
        if not conf_res:
            conf_res = get_last_resource(YabsCpmMultiplierConfig)
            logging.info('Found config resource %s', conf_res.id)

        binary = ResourceData(bin_res)
        config = ResourceData(conf_res)

        command = [
            binary.path,
            '--out-prefix', out_path,
            '--constants', config.path,
            '--time', current_time,
            '--stat-time', stat_time,
            '--proxy', exec_cluster,
        ]
        # Pass --pool only when configured: str(None) would otherwise become a literal 'None' pool name.
        if self.Parameters.yt_pool:
            command += ['--pool', self.Parameters.yt_pool]
        command += [
            '--dict-path', dict_path,
            '--stat-path', stat_path,
            '--regulator', '--decay', '--multiplier',
        ]
        if prev_path:
            command += ['--prev-prefix', prev_path]
        with ProcessLog(self, logger='cpm_multiplier') as pl:
            subprocess.check_call([str(arg) for arg in command], stdout=pl.stdout, stderr=pl.stderr)

        self.set_product_type_mapping(ytc, str(config.path), '{}/{}'.format(out_path, 'CPMMultiplier'))

    def _previous_output_path(self, ytc):
        '''Return the output prefix that ``link_path`` currently points into, or None.

        None is returned when the link does not exist or exposes no
        ``target_path``.
        '''
        link_path = self.Parameters.link_path
        if not ytc.exists(link_path):
            return None
        # The trailing '&' makes YT read the attribute of the link node itself, not its target.
        target_path = ytc.get_attribute(link_path + '&', 'target_path', None)
        if not target_path:
            logging.warning('Link %s has no target_path; running without previous results', link_path)
            return None
        # The link targets <prefix>/<timestamp>/CPMMultiplier; drop the table name to get the prefix.
        return '/'.join(target_path.split('/')[:-1])

    def _send_solomon_metrics(self, current_time, stat_time):
        '''Push the statistics delay (``current_time - stat_time``) to Solomon.

        :raises TaskFailure: if no Solomon token secret is configured
        :raises requests.HTTPError: if Solomon rejects the push
        '''
        if self.Parameters.solomon_token is None:
            raise TaskFailure('send_to_solomon is enabled but solomon_token is not set')
        solomon_token = self.Parameters.solomon_token.data()['solomon_token']
        headers = {'Authorization': 'OAuth ' + solomon_token}
        json_data = {
            'commonLabels': {},
            'sensors': [{
                'labels': {'sensor': 'statistics_delay'},
                'ts': current_time,
                'value': current_time - stat_time,
            }],
        }
        query = {
            'project': self.Parameters.solomon_project,
            'cluster': self.Parameters.solomon_cluster,
            'service': self.Parameters.solomon_service,
        }
        # Bounded timeout so an unresponsive Solomon cannot hang the task indefinitely.
        resp = requests.post(
            self.Parameters.solomon_api_url,
            params=query,
            json=json_data,
            headers=headers,
            timeout=60,
        )
        resp.raise_for_status()

    def execute(self, yt_token):
        '''Production run: compute multipliers on one cluster and publish them to all clusters.

        The execution cluster is chosen by ``ClusterChoiceCriterion`` among
        ``clusters``. Output goes to ``<yt_prefix>/<current_time>``. The
        CPMMultiplier table is linked to ``link_path`` and, together with the
        Regulator table, copied to the remaining clusters.
        '''
        from yabs.yt_wrapper import YtClient, ClusterChoiceCriterion

        criterion = ClusterChoiceCriterion(
            clusters=self.Parameters.clusters,
            collector_name=['RtbPageChevent', 'RtbPageDspChecked', 'RtbPageSsp'],
        )
        ytc = YtClient(yt_token, criterion=criterion)
        exec_cluster = ytc.cluster
        other_clusters = set(self.Parameters.clusters) - {exec_cluster}

        if self.Parameters.stat_time:
            stat_time = self.Parameters.stat_time
        else:
            # Step back 30 minutes from the last statistics update, then round down to a whole hour.
            stat_time = ytc.update_time - 1800
            stat_time -= stat_time % 3600

        current_time = self.Parameters.current_time or int(time.time())

        out_path = os.path.join(self.Parameters.yt_prefix, str(current_time))
        multiplier_path = os.path.join(out_path, 'CPMMultiplier')
        regulator_path = os.path.join(out_path, 'Regulator')

        prev_path = self._previous_output_path(ytc)
        self.do_run(
            exec_cluster,
            current_time,
            stat_time,
            self.Parameters.stat_path,
            self.Parameters.dict_path,
            prev_path,
            out_path,
            ytc,
        )
        ytc.link(multiplier_path, self.Parameters.link_path, force=True)

        if other_clusters:
            tables = (multiplier_path, regulator_path)
            self.do_copy_results(yt_token, exec_cluster, other_clusters, tables, multiplier_path)

        # The delay metric is only meaningful when neither timestamp was overridden by hand.
        if self.Parameters.send_to_solomon and not self.Parameters.current_time and not self.Parameters.stat_time:
            self._send_solomon_metrics(current_time, stat_time)

    def execute_test(self, yt_token):
        '''Replay a run from ``backup_resource`` into ``<yt_prefix>/<task id>``.

        The backup JSON supplies current_time, stat_time, stat_path,
        dict_path and previous_path. The output directory gets an
        ``expiration_time`` of now plus ``result_ttl`` days.
        '''
        import yt.wrapper as yt

        cluster = self.Parameters.clusters[0]
        ytc = yt.YtClient(token=yt_token, proxy=cluster)

        backup_file = ResourceData(self.Parameters.backup_resource)
        with open(str(backup_file.path), 'r') as f:
            backup_data = json.load(f)

        out_path = '{}/{}'.format(self.Parameters.yt_prefix, self.id)
        self.do_run(
            cluster,
            backup_data['current_time'],
            backup_data['stat_time'],
            backup_data['stat_path'],
            backup_data['dict_path'],
            backup_data['previous_path'],
            out_path,
            ytc,
        )

        expires_at = datetime.datetime.now(tzlocal()) + datetime.timedelta(days=self.Parameters.result_ttl)
        ytc.set_attribute(out_path, 'expiration_time', expires_at.isoformat())

    def on_execute(self):
        '''Sandbox entry point: export YT settings to the environment and dispatch the run.'''
        yt_token = self.Parameters.yt_token_secret_id.data()["YT_TOKEN"]
        # Exported so child processes (the cpm_multiplier binary) pick up the same credentials/log level.
        os.environ['YT_TOKEN'] = yt_token
        os.environ['YT_LOG_LEVEL'] = self.Parameters.yt_log_level or ''

        if self.Parameters.backup_resource:
            self.execute_test(yt_token)
        else:
            self.execute(yt_token)
