# -*- coding: utf-8 -*-

from sandbox import sdk2
from sandbox.common.errors import TaskFailure


def try_float(s):
    '''
    Convert `s` to float when possible, otherwise return `s` unchanged.

    Both ValueError (text that is not a number, e.g. 'abc') and TypeError
    (values float() cannot take at all, e.g. None) count as "not a number",
    so the caller always gets either a float or the original value back.
    '''
    try:
        return float(s)
    except (TypeError, ValueError):
        return s


class YqlToSolomon(sdk2.Task):
    '''
        Updates solomon metrics with results of YQL query.

        Flow (see on_execute): submit the query once, poll its status until it
        leaves the in-progress states, then render every result row through the
        `sensors` / `labels` templates and push the values to Solomon.
    '''

    SOLOMON_API_URL = 'http://api.solomon.search.yandex.net'

    class Requirements(sdk2.Task.Requirements):
        cores = 1

        # Empty Caches subclass. NOTE(review): presumably this opts the task out
        # of Sandbox caches; confirm against the sdk2 documentation.
        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(sdk2.Task.Parameters):
        kill_timeout = 60 * 60
        do_not_restart = True

        with sdk2.parameters.Group('YQL parameters'):
            token_vault_name = sdk2.parameters.String(
                'Token secret name',
                default='YQL_TOKEN',
                required=True,
            )
            query = sdk2.parameters.String(
                'YQL Query',
                multiline=True,
                required=True,
            )
            yql_cluster = sdk2.parameters.String(
                'YQL cluster (specify - when to use USE clause)',
                default='hahn',
                required=True,
            )
            yql_provider = sdk2.parameters.String(
                'YQL provider (yt, clickhouse, ...)',
                default='yt',
                required=True,
            )
            use_clickhouse_syntax = sdk2.parameters.String(
                'Run query in Clickhouse syntax (True / False)',
                default='False',
            )
            retry_period = sdk2.parameters.Integer(
                'Time period to check request status (in seconds)',
                default=60 * 5
            )

        with sdk2.parameters.Group('Solomon parameters'):
            project = sdk2.parameters.String(
                'Project',
                required=True,
            )
            solomon_api_token = sdk2.parameters.String(
                'Solomon token secret name',
                default='solomon_token',
                required=True,
            )
            cluster = sdk2.parameters.String(
                'Cluster',
                required=True,
            )
            service = sdk2.parameters.String(
                'Service',
                required=True,
            )
            sensors = sdk2.parameters.JSON(
                'Sensors',
                required=True,
                default_value={
                    'sensor_name': r'sensor_value or {result_column_name}',
                },
            )
            labels = sdk2.parameters.JSON(
                'Labels',
                default_value={
                    'label_name': r'label_value or {result_column_name}',
                },
            )

    def _create_client(self):
        '''
        Build a YqlClient for the configured cluster and provider.

        The YQL token is read with sdk2.Vault.data using the task author as the
        vault owner. A cluster value of '-' means "no default cluster": the
        client gets db=None and the query is expected to choose its cluster
        itself (e.g. with a USE clause).
        '''
        from yql.api.v1.client import YqlClient

        token = sdk2.Vault.data(self.author, self.Parameters.token_vault_name)
        yql_cluster = None if self.Parameters.yql_cluster == '-' else self.Parameters.yql_cluster

        return YqlClient(db=yql_cluster, token=token, provider=self.Parameters.yql_provider)

    def _run_yql(self):
        '''
        Submit the query and store its operation id in self.Context.

        Later steps (_wait_yql, _iter_results) look the operation up by this
        id, so they do not need the original request object.
        '''
        # The parameter is free text; accept any capitalisation / surrounding
        # whitespace of 'true', and treat an unset value as False.
        clickhouse_flag = self.Parameters.use_clickhouse_syntax or ''
        use_clickhouse_syntax = clickhouse_flag.strip().lower() == 'true'
        request = self.yql_client.query(self.Parameters.query, clickhouse_syntax=use_clickhouse_syntax, syntax_version=1)
        request.run()
        self.Context.yql_operation_id = request.operation_id

    def _wait_yql(self):
        '''
        Check the stored operation's status once.

        Raises:
            sdk2.WaitTime: the operation is still in one of the request's
                IN_PROGRESS_STATUSES; asks Sandbox to wait retry_period seconds.
            TaskFailure: the operation finished with any status other than
                'COMPLETED'.
        '''
        from yql.client.operation import YqlOperationStatusRequest

        status = YqlOperationStatusRequest(self.Context.yql_operation_id)
        status.run()
        if status.status in status.IN_PROGRESS_STATUSES:
            raise sdk2.WaitTime(self.Parameters.retry_period)
        if status.status != 'COMPLETED':
            raise TaskFailure('YQL query failed')

    def _iter_results(self):
        '''
        Yield every row of every result table as a {column_name: value} dict.
        '''
        from yql.client.operation import YqlOperationResultsRequest

        results = YqlOperationResultsRequest(self.Context.yql_operation_id)
        results.run()
        for table in results.get_results():
            # table.columns holds pairs whose second item (presumably the
            # column type) is not needed here.
            column_names = [column_name for column_name, _ in table.columns]
            for row in table.rows:
                yield dict(zip(column_names, row))

    @staticmethod
    def _render(template, row):
        '''
        Substitute `{column}` placeholders in `template` with values from `row`.

        Raises:
            TaskFailure: the template references a column that the query result
                does not contain. The message names the template and lists the
                columns that are available, instead of a bare KeyError.
        '''
        try:
            return template.format(**row)
        except KeyError as err:
            raise TaskFailure(
                'Template {!r} references column {} that is missing from the YQL result; '
                'available columns: {}'.format(template, err, sorted(row))
            )

    def _push_results(self):
        '''
        Render the sensor/label templates for each result row and push them to Solomon.

        For every row, each entry of the `sensors` parameter yields one value.
        The value is the rendered template passed through try_float, so numeric
        text becomes a float and anything else stays a string. All values of a
        row share the row's rendered `labels`. The Solomon token is read with
        sdk2.Vault.data using the task author as the vault owner.
        '''
        from solomon import PushApiReporter, OAuthProvider

        solomon_api_token = sdk2.Vault.data(self.author, self.Parameters.solomon_api_token)
        pusher = PushApiReporter(
            project=self.Parameters.project,
            cluster=self.Parameters.cluster,
            service=self.Parameters.service,
            url=self.SOLOMON_API_URL,
            auth_provider=OAuthProvider(solomon_api_token)
        )

        sensor_templates = self.Parameters.sensors
        # `labels` is optional: an unset parameter means "no extra labels"
        # rather than a crash on None.items().
        label_templates = self.Parameters.labels or {}
        # Fix the sensor order once so names and values stay aligned.
        sensor_names = list(sensor_templates)

        for row in self._iter_results():
            sensor_values = [
                try_float(self._render(sensor_templates[name], row))
                for name in sensor_names
            ]
            row_labels = {
                key: self._render(value, row)
                for key, value in label_templates.items()
            }
            pusher.set_value(sensor_names, sensor_values, [row_labels] * len(sensor_names))

    def on_execute(self):
        '''
        Entry point: submit the query once, wait for it, then push results once.

        _wait_yql raises sdk2.WaitTime while the query is still running, so this
        method is expected to be entered more than once per task. The memoize
        stages keep the submission and the push from being repeated on those
        re-entries.
        '''
        self.yql_client = self._create_client()

        # commit_on_entrance=False: NOTE(review): per sdk2 memoize_stage
        # semantics the stage is presumably only recorded after _run_yql
        # returns, so a failed submission is retried; confirm.
        with self.memoize_stage.run_yql(commit_on_entrance=False):
            self._run_yql()

        self._wait_yql()

        with self.memoize_stage.push_results:
            self._push_results()
