# coding: utf8
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import time

from sandbox import sdk2
from sandbox.sandboxsdk.environments import PipEnvironment
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

from sandbox.projects.avia.base import AviaBaseTask
from sandbox.projects.avia.lib.yt_helpers import YtClientFactory
from sandbox.projects.avia.lib.datetime_helpers import _dt_to_unixtime_utc


class SendMetricsFromMetricsToYt(AviaBaseTask):
    """
    Download per-query metric values from Metrics and store them in YT tables.

    For each of the last COMPLETED jobs of the configured Metrics cron, the task
    fetches the values of every configured metric for the job's serp set, joins
    them with the queries of a Metrics basket and writes one YT table per serp
    set into ``Parameters.yt_dir``. The table is named after the serp set id and
    replaces an existing table with that name.
    """

    _yt_client = None
    _yql_client = None
    # Backing storage for the ``label_names`` property. It is created lazily per
    # instance; a class-level ``set()`` would be shared by every task instance.
    _label_names = None

    # Seconds to wait between failed request attempts in ``make_request``.
    request_retry_delay = 10

    cron_jobs_url = 'https://metrics-experiments.metrics.yandex-team.ru/api/cron/{cron_id}/jobs' \
                    '?status=COMPLETED&pageSize={cron_jobs_page_size}'

    # NOTE: the misspelled attribute name is kept for backward compatibility.
    mertics_url = 'https://metrics-calculation.qloud.yandex-team.ru/api/qex/metric-by-queries/' \
                  '?regional=RU&evaluation=WEB&metric={metric_name}' \
                  '&left-serp-set={serp_set_id}&right-serp-set={serp_set_id}&left-serp-set-filter={serp_set_filter}'

    basket_url = 'https://metrics-qgaas.metrics.yandex-team.ru/api/basket/{basket_id}/query'

    class Requirements(sdk2.Task.Requirements):
        cores = 1
        ram = 1024

        class Caches(sdk2.Requirements.Caches):
            pass  # We do not need caches

        environments = (
            PipEnvironment('requests'),
            PipEnvironment('python-dateutil'),
            PipEnvironment('yandex-yt', version='0.10.8'),
            PipEnvironment('yandex-yt-yson-bindings-skynet', version='0.3.32-0'),
        )

    class Parameters(sdk2.Task.Parameters):

        with sdk2.parameters.Group('YT settings') as yt_block:
            yt_cluster = sdk2.parameters.String('YT cluster', default='hahn', required=True)
            vaults_owner = sdk2.parameters.String('YT user', required=True)
            yt_token_vault_name = sdk2.parameters.String('YT Token vault name', required=True, default='YT_TOKEN')
            yt_dir = sdk2.parameters.String('Directory', required=True,
                                            default='//home/travel/analytics/regular/train_metrics_from_metrics')

        with sdk2.parameters.Group('Metrics settings') as metrics_block:
            metrics_token_vault_name = sdk2.parameters.String('Metrics Token vault name',
                                                              required=True, default='METRICS_TOKEN')
            basket_id = sdk2.parameters.Integer('Basket id', required=True, default=353344)
            cron_id = sdk2.parameters.Integer('Cron id', required=True, default=102463)
            metric_names_str = sdk2.parameters.String('Metric names', required=True,
                                                      default='has-wizard-transport has-wizard-raspisanie')
            serp_set_filter = sdk2.parameters.String('Serp set filter', required=True,
                                                      default='skipRightAlign')
            cron_jobs_page_size = sdk2.parameters.Integer('Number of last cron jobs for download', required=True,
                                                          default=3)
            request_attempts_number = sdk2.parameters.Integer('Request attempts number', required=True,
                                                              default=2)

        with sdk2.parameters.Group('Debug settings') as debug_settings:
            debug_run = sdk2.parameters.Bool('Debug run', default=False, required=True)

    @property
    def label_names(self):
        """Set of label column names registered by ``get_labels`` on this task instance."""
        if self._label_names is None:
            self._label_names = set()
        return self._label_names

    @label_names.setter
    def label_names(self, value):
        self._label_names = value

    @property
    def yt_client(self):
        """YT client for ``Parameters.yt_cluster``, created on first access and then reused."""
        if self._yt_client is None:
            self._yt_client = YtClientFactory.create(
                proxy=self.Parameters.yt_cluster,
                token=sdk2.Vault.data(self.Parameters.vaults_owner, self.Parameters.yt_token_vault_name),
            )

        return self._yt_client

    def str_to_unixtime_utc(self, sdt):
        """
        Parse a Metrics timestamp string and convert it with ``_dt_to_unixtime_utc``.

        The string is parsed by ``dateutil.parser.parse``. ``METRICS_DATE_FORMAT``
        is only used to describe the expected shape in the error log.

        :raises ValueError: if ``sdt`` cannot be parsed or converted.
        """
        from dateutil.parser import parse
        METRICS_DATE_FORMAT = '%Y-%m-%dT%H:%M:%S.%f%z'
        try:
            return _dt_to_unixtime_utc(parse(sdt))
        except ValueError:
            logging.exception('time=%s not in format %s', sdt, METRICS_DATE_FORMAT)
            # A bare ``raise`` keeps the original traceback (``raise e`` drops it on Python 2).
            raise

    def get_query_key(self, query):
        """Return the ``(text, device, regionId, country)`` tuple identifying a query dict."""
        return query['text'], query['device'], query['regionId'], query['country']

    def add_metrics(self, token, serp_set_id, metric_name, metrics, serp_set_filter):
        """
        Fetch the values of ``metric_name`` for ``serp_set_id`` and merge them into ``metrics``.

        ``metrics`` is updated in place as ``metrics[query_key][metric_name] = value``,
        where ``query_key`` is built by ``get_query_key``.
        """
        url = self.mertics_url.format(serp_set_id=serp_set_id,
                                      metric_name=metric_name,
                                      serp_set_filter=serp_set_filter)
        logging.info('Metric url = %s', url)
        response = self.make_request(url, token, self.Parameters.request_attempts_number)
        for calculated in response['calculatedQueries']:
            key = self.get_query_key(calculated['query'])
            metrics.setdefault(key, {})[metric_name] = calculated['leftMetricValue']['metricValue']

    def get_serp_set_ids(self, response):
        """
        Extract ``(serp_set_id, started_time)`` pairs from a cron-jobs response.

        For each job in ``response['content']``, the first serp set id of the
        first data item of the first result is used.
        """
        return [(job['results'][0]['data'][0]['serpsetIds'][0], job['startedTime']) for job in response['content']]

    def make_request(self, url, token, attempts=2):
        """
        GET ``url`` with an OAuth ``token`` and return the decoded JSON body.

        The request is tried up to ``attempts`` times, waiting
        ``request_retry_delay`` seconds between attempts. Any exception counts as
        a failed attempt, including network errors, non-2xx statuses and invalid
        JSON.

        :raises SandboxTaskFailureError: when every attempt failed.
        """
        import requests

        headers = {'Authorization': 'OAuth {}'.format(token)}
        last_error = None
        with requests.Session() as session:
            # NOTE(review): TLS certificate verification is disabled even though an
            # OAuth token is sent; consider pointing ``verify`` at the proper CA bundle.
            session.verify = False
            for attempt in range(attempts):
                if attempt:
                    # Wait only between attempts, not after the last one.
                    time.sleep(self.request_retry_delay)
                try:
                    response = session.get(url, headers=headers)
                    response.raise_for_status()
                    return response.json()
                except Exception as e:
                    last_error = e
                    logging.exception('Attempt %d/%d to request %s failed', attempt + 1, attempts, url)
        # The OAuth token is deliberately left out: exception messages may be logged or displayed.
        raise SandboxTaskFailureError(
            'Exceeded number of attempts ({}) requesting url {}: {!r}'.format(attempts, url, last_error)
        )

    def get_labels(self, labels):
        """
        Map a list of labels to ``{'label_<index>': label}`` columns.

        Every generated column name is also registered in ``self.label_names``
        so that it ends up in the output table schema.
        """
        res = {}
        for i, label in enumerate(labels):
            label_name = 'label_{}'.format(i)
            res[label_name] = label
            self.label_names.add(label_name)
        return res

    def get_basket(self, token, basket_id):
        """
        Download the queries of basket ``basket_id``.

        :returns: list of ``(query_key, label_columns)`` pairs. ``label_columns``
            is built by ``get_labels``, or is ``{}`` when a query has no ``labels``.
        """
        url = self.basket_url.format(basket_id=basket_id)
        logging.info('Basket url = %s', url)
        response = self.make_request(url, token, self.Parameters.request_attempts_number)
        basket = []
        for query in response:
            key = self.get_query_key(query)
            labels = self.get_labels(query['labels']) if 'labels' in query else {}
            basket.append((key, labels))

        return basket

    def generate_table(self, basket, metrics, serp_set_time, serp_set_id, metric_names):
        """
        Yield one output row per basket query for the given serp set.

        A metric missing for a query (or a query missing from ``metrics``
        entirely) is written as ``0.0``, matching the ``double`` metric columns
        of the output schema. After all rows have been yielded, the number of
        basket queries that had metric values is logged.
        """
        # The timestamp is the same for every row, so parse it only once.
        serp_set_timestamp = self.str_to_unixtime_utc(serp_set_time)
        linked_queries = 0
        for key, labels in basket:
            text, device, region_id, country = key
            query_metrics = metrics.get(key)
            if query_metrics is None:
                query_metrics = {}
            else:
                linked_queries += 1
            row = {
                'text': text,
                'regionId': region_id,
                'device': device,
                'country': country,
                'serp_set_time': serp_set_time,
                'serp_set_timestamp': serp_set_timestamp,
                'serp_set_id': serp_set_id,
            }
            for metric_name in metric_names:
                row[metric_name] = query_metrics.get(metric_name, 0.0)
            row.update(labels)
            yield row
        logging.info('linked_queries = %d', linked_queries)

    def write_result(self, serp_set_id, serp_set_time, metrics, basket, metric_names, label_names):
        """
        (Re)create the YT table ``<yt_dir>/<serp_set_id>`` and fill it with ``generate_table`` rows.

        ``yt_dir`` is created first if it does not exist. Removing the old table,
        creating the new one and writing the rows all happen inside a single YT
        transaction.
        """
        import yt.wrapper as yt

        logging.info('Writing result')
        output_table = yt.ypath_join(self.Parameters.yt_dir, '{}'.format(serp_set_id))
        logging.info('Output table: %s', output_table)

        if not self.yt_client.exists(self.Parameters.yt_dir):
            self.yt_client.create('map_node', self.Parameters.yt_dir, recursive=True)

        with self.yt_client.Transaction():
            if self.yt_client.exists(output_table):
                self.yt_client.remove(output_table)

            schema = [
                {'type': 'string', 'name': 'text'},
                {'type': 'int64', 'name': 'regionId'},
                {'type': 'string', 'name': 'country'},
                {'type': 'string', 'name': 'device'},
                {'type': 'string', 'name': 'serp_set_time'},
                {'type': 'uint64', 'name': 'serp_set_timestamp'},
                {'type': 'int64', 'name': 'serp_set_id'},
            ]
            schema += [{'type': 'double', 'name': metric_name} for metric_name in metric_names]
            # Sort so the column order is deterministic. Sorting by (len, name)
            # keeps 'label_2' before 'label_10'.
            schema += [{'type': 'string', 'name': label_name}
                       for label_name in sorted(label_names, key=lambda name: (len(name), name))]
            logging.info('Output table schema %s', schema)

            self.yt_client.create(
                'table',
                output_table,
                attributes={
                    'optimize_for': 'scan',
                    'schema': schema,
                },
                recursive=True,
            )

            self.yt_client.write_table(
                output_table,
                self.generate_table(basket, metrics, serp_set_time, serp_set_id, metric_names)
            )

    def on_execute(self):
        """Download the basket and cron jobs, then write one YT table per completed serp set."""
        logging.info('Start')

        logging.info('token %s %s', self.Parameters.vaults_owner, self.Parameters.metrics_token_vault_name)
        token = sdk2.Vault.data(self.Parameters.vaults_owner, self.Parameters.metrics_token_vault_name)

        # split() with no argument ignores repeated and surrounding whitespace,
        # so no empty metric names are produced.
        metric_names = self.Parameters.metric_names_str.split()
        logging.info('metric_names = %s', metric_names)

        basket = self.get_basket(token, self.Parameters.basket_id)
        logging.info('Basket len = %s', len(basket))

        url = self.cron_jobs_url.format(cron_id=self.Parameters.cron_id,
                                        cron_jobs_page_size=self.Parameters.cron_jobs_page_size)
        logging.info('Cron jobs url = %s', url)
        response = self.make_request(url, token, self.Parameters.request_attempts_number)

        for serp_set_id, serp_set_time in self.get_serp_set_ids(response):
            logging.info('Process serp_set %s downloaded at %s', serp_set_id, serp_set_time)
            metrics = {}
            for metric_name in metric_names:
                self.add_metrics(token,
                                 serp_set_id,
                                 metric_name,
                                 metrics,
                                 self.Parameters.serp_set_filter)
            self.write_result(serp_set_id, serp_set_time, metrics, basket, metric_names, self.label_names)

        logging.info('End')
