# coding=utf-8
from __future__ import unicode_literals

import itertools
import logging

from sandbox import sdk2
from sandbox.projects.avia.base import AviaBaseTask
from sandbox.projects.avia.lib import yql_helpers as yqlh
from sandbox.projects.avia.lib.datetime_helpers import (
    get_utc_now, _dt_to_unixtime_utc, _dt_to_string
)
from sandbox.projects.common import binary_task, solomon
from sandbox.sandboxsdk.errors import SandboxTaskFailureError

# Module-level logger for this task module.
logger = logging.getLogger(__name__)


class AviaSendYdbDeletionLagToSolomon(binary_task.LastBinaryTaskRelease, AviaBaseTask):
    """
        Send the lag of YDB TTL-table cleanup to Solomon.

        For every TTL table the task finds the oldest row whose ``expires_at``
        is already in the past and reports ``now - expires_at`` (unixtime
        difference) as the ``ttl_table_lag`` sensor. Tables without expired
        rows report a lag of 0.

        Checked tables (prefixes and counts are task parameters, defaults shown):
            results_expiration_queue[0-9]
            wizard_results_expiration_queue[0-9]
            wizard_results_experimental_expiration_queue[0-9]
    """
    # Lazily created YQL client, cached per task instance by _get_yql_client().
    _yql_client = None

    class Requirements(sdk2.Requirements):
        # The task only issues a few small YQL queries and one Solomon push.
        cores = 1  # exactly 1 core
        disk_space = 128  # 128 Megs or less
        ram = 128  # 128 Megs or less

        class Caches(sdk2.Requirements.Caches):
            pass  # means that task do not use any shared caches

    class Parameters(sdk2.Parameters):

        # binary task release parameters
        ext_params = binary_task.binary_release_parameters(stable=True)

        with sdk2.parameters.Group('YQL settings') as yql_settings:
            # Yav secret with the robot's YQL token, read under the 'token' key.
            yav_yql_token = sdk2.parameters.YavSecret(
                "Yav-secret с YQL-токеном робота",
                default='sec-01e4q88q2jcz84c8dtwrm29z3r'
            )

        with sdk2.parameters.Group('Solomon settings') as solomon_settings:
            solomon_project = sdk2.parameters.String('Solomon project', required=True, default='avia')
            solomon_cluster = sdk2.parameters.String('Solomon cluster', required=True, default='ydb')
            solomon_service = sdk2.parameters.String('Solomon service', required=True, default='metrics')

        with sdk2.parameters.Group('YDB settings') as ydb_ttl_tables_settings:
            # Used both as the YQL TablePathPrefix and as the YQL client's db.
            table_path = sdk2.parameters.String(
                'YDB TablePathPrefix',
                required=True,
                default='/ru/ticket/production/search_results',
            )
            # Each "*_name" is a prefix; tables are "<prefix>0" .. "<prefix>{count-1}".
            search_result_ttl_table_name = sdk2.parameters.String(
                'SearchResults ttl table name',
                required=True,
                default='results_expiration_queue',
            )
            search_result_ttl_tables_count = sdk2.parameters.Integer(
                'SearchResults ttl tables count',
                required=True,
                default=10,
            )
            wizard_result_ttl_table_name = sdk2.parameters.String(
                'WizardResults ttl table name',
                required=True,
                default='wizard_results_expiration_queue',
            )
            wizard_result_ttl_tables_count = sdk2.parameters.Integer(
                'WizardResults ttl tables count',
                required=True,
                default=10,
            )
            wizard_result_experimental_ttl_table_name = sdk2.parameters.String(
                'WizardResults experimental ttl table name',
                required=True,
                default='wizard_results_experimental_expiration_queue',
            )
            wizard_result_experimental_ttl_tables_count = sdk2.parameters.Integer(
                'WizardResults experimental ttl tables count',
                required=True,
                default=10,
            )

    def on_execute(self):
        """Collect the cleanup lag of every TTL table and push it to Solomon."""
        super(AviaSendYdbDeletionLagToSolomon, self).on_execute()
        # One timestamp is used for the query cutoff, the lag calculation and
        # the Solomon point timestamp so that all three agree.
        start_dt = get_utc_now()

        expired_by_table = self.get_last_expired_rows_unixtime(start_dt)
        ttl_tables_lag = self.calculate_lag(expired_by_table, start_dt)
        self.send_data_to_solomon(ttl_tables_lag, start_dt)

    def get_last_expired_rows_unixtime(self, start_dt):
        """Query every configured TTL table for its oldest already-expired row.

        :param start_dt: moment treated as "now"; rows whose ``expires_at`` is
            strictly below its unixtime are considered expired.
        :return: list of ``(ttl_table, expires_at)`` pairs. ``expires_at`` is
            ``None`` when the table has no expired rows. Tables whose query
            failed are logged and left out of the list.
        :raises SandboxTaskFailureError: if no table produced any result.
        """
        yql_client = self._get_yql_client()
        unixtime = _dt_to_unixtime_utc(start_dt)

        results = []
        for ttl_table in self._iter_ttl_tables():
            results.extend(self._fetch_last_expired(yql_client, ttl_table, unixtime))

        if not results:
            raise SandboxTaskFailureError('No results from YQL')

        return results

    def _iter_ttl_tables(self):
        """Return an iterator over the names of all configured TTL tables."""
        params = self.Parameters
        return itertools.chain(
            self._gen_ttl_tables(params.search_result_ttl_table_name,
                                 params.search_result_ttl_tables_count),
            self._gen_ttl_tables(params.wizard_result_ttl_table_name,
                                 params.wizard_result_ttl_tables_count),
            self._gen_ttl_tables(params.wizard_result_experimental_ttl_table_name,
                                 params.wizard_result_experimental_ttl_tables_count),
        )

    def _fetch_last_expired(self, yql_client, ttl_table, unixtime):
        """Run the lag query for a single TTL table and wait for it to finish.

        :return: list of ``(ttl_table, expires_at)`` pairs, one per fetched
            result table; empty if the query failed (errors are logged).
        """
        query = yql_client.query(
            self.get_ttl_table_lag(ttl_table, unixtime),
            syntax_version=1,
            title='[YQL] Get YDB TTL tables lag'
        )

        query_result = query.run()
        logger.info('YQL Operation: %s', yqlh.get_yql_operation_url(query_result))
        query_result.wait_progress()

        if not query_result.is_success:
            # Best effort: a failed table is skipped, the others are still reported.
            yqlh.log_errors(query_result, logger)
            return []

        fetched = []
        for result in query_result.get_results():
            if not result.fetch_full_data():
                continue
            try:
                expires_at = result.rows[0][0]
            except IndexError:
                # No rows: the table currently has no expired entries.
                expires_at = None

            fetched.append((ttl_table, expires_at))
            logger.info(
                'YQL query done for %s table. Last expired unixtime %r',
                ttl_table, expires_at
            )
        return fetched

    @staticmethod
    def calculate_lag(tables_expires_at, start_dt):
        """Convert ``(table, expires_at)`` pairs into ``(table, lag)`` pairs.

        ``lag`` is ``unixtime(start_dt) - expires_at``, or 0 when ``expires_at``
        is ``None`` (nothing expired, hence nothing waiting for deletion).
        """
        start_unixtime = _dt_to_unixtime_utc(start_dt)
        return [
            (t_name, start_unixtime - expires_at if expires_at is not None else 0)
            for t_name, expires_at in tables_expires_at
        ]

    def send_data_to_solomon(self, ttl_tables_lag, start_dt):
        """Push one ``ttl_table_lag`` point per TTL table to Solomon.

        :param ttl_tables_lag: iterable of ``(ttl_table, lag)`` pairs.
        :param start_dt: timestamp attached to every point.
        """
        shard_labels = {
            'project': self.Parameters.solomon_project,
            'cluster': self.Parameters.solomon_cluster,
            'service': self.Parameters.solomon_service,
        }
        logger.info(shard_labels)

        # All points share one timestamp, so format it once. start_dt comes
        # from get_utc_now(), hence the UTC 'Z' suffix.
        ts = _dt_to_string(start_dt) + 'Z'
        sensors = [
            {
                'ts': ts,
                'labels': {'sensor': 'ttl_table_lag', 'ttl_table': ttl_table},
                'value': lag,
            }
            for ttl_table, lag in ttl_tables_lag
        ]

        logger.info('Sending sensors to solomon. %r', sensors)
        # NOTE(review): solomon_token is not defined in this class; presumably
        # provided by AviaBaseTask — confirm.
        solomon.push_to_solomon_v2(self.solomon_token, shard_labels, sensors, common_labels=())

    def get_ttl_table_lag(self, ttl_table, unixtime):
        """Build (but do not run) the YQL query for one TTL table's oldest expired row.

        The query yields at most one row: the smallest ``expires_at`` that is
        strictly below ``unixtime``, or no rows if nothing has expired.
        ``ORDER BY`` is required; without it ``LIMIT 1`` returns an arbitrary
        expired row and the reported lag would understate the real one.
        """
        query = """
            PRAGMA TablePathPrefix("{table_path}");

            SELECT expires_at from {ttl_table} WHERE expires_at < {unixtime} ORDER BY expires_at LIMIT 1;
        """.format(
            table_path=self.Parameters.table_path,
            unixtime=unixtime,
            ttl_table=ttl_table,
        )

        return query

    def _get_yql_client(self):
        """Return the cached YQL client, creating and configuring it on first use."""
        from yql.api.v1.client import YqlClient, config

        if self._yql_client is None:
            token = self.Parameters.yav_yql_token.data()['token']
            self._yql_client = YqlClient(token=token)
            # Module-level YQL client config; set once, together with the client.
            config.single_cluster = True
            config.db = self.Parameters.table_path

        return self._yql_client

    def _gen_ttl_tables(self, table_prefix, count):
        """Return an iterable of names ``<table_prefix>0`` .. ``<table_prefix>{count-1}``."""
        return (table_prefix + str(index) for index in range(count))
