import logging
import os
import time
from datetime import datetime

import sandbox.common.types.task as ctt
from sandbox import sdk2
from sandbox.projects.yql.RunYQL2 import RunYQL2
from sandbox.sandboxsdk import environments
from sandbox.sandboxsdk import errors
from sandbox.sdk2.vcs.svn import Arcadia

from sandbox.projects.cloud.analytics.common.analytics_task import AnalyticsTask


class QuorumType:
    """String constants naming the supported per-shard quorum modes."""

    # Every configured host has to be available.
    ALL = 'all'
    # Strictly more than half of the configured hosts have to be available.
    MAJORITY = 'majority'
    # A single available host is sufficient.
    AT_LEAST_ONE = 'at_least_one'

    # Every known mode, in the order it is offered as a parameter choice.
    _ALL = [ALL, MAJORITY, AT_LEAST_ONE]


class CloudYtToCH(AnalyticsTask):
    """Task that runs a YQL query from a file in Arcadia
    and then copies its result to the analytics ClickHouse (CH)."""

    class NotSyncedTableException(Exception):
        """Raised when a CH host holds a different row count than the YT table.

        Attributes:
            ch_host: name of the ClickHouse host that was checked.
            ch_count: number of rows found on that host.
            yt_count: number of rows expected (row count of the YT table).
        """

        def __init__(self, ch_host, ch_count, yt_count):
            # Forward the values to the base class so that ``args`` is
            # populated: without it repr() shows no data and the exception
            # cannot be re-created from ``args`` (e.g. by pickle or copy).
            Exception.__init__(self, ch_host, ch_count, yt_count)
            self.ch_host = ch_host
            self.ch_count = ch_count
            self.yt_count = yt_count

        def __str__(self):
            return ('Incorrect table rows count on host {}: '
                    'should be {}, but only {} found'
                    ''.format(self.ch_host, self.yt_count, self.ch_count))

    class Requirements(AnalyticsTask.Requirements):
        # Pip packages requested for the task. on_execute imports
        # yt.wrapper, yt.transfer_manager.client and clickhouse.client lazily,
        # inside the method body -- presumably because these packages are
        # only installed once the task environment is prepared (confirm).
        #
        # The right-hand side below still refers to the module-level
        # ``environments`` import: the class attribute of the same name does
        # not exist until the tuple has been fully evaluated.
        # TODO(syndicut): Use prebuilt wheels here
        environments = (
            environments.PipEnvironment('yandex-yt'),
            environments.PipEnvironment('yandex-yt-yson-bindings-skynet', version="0.3.32-0"),
            environments.PipEnvironment('yandex-yt-transfer-manager-client'),
            environments.PipEnvironment('python-clickhouse-client'),
        )

    class Parameters(AnalyticsTask.Parameters):
        # --- YQL stage -------------------------------------------------------
        # Forwarded as ``retry_period`` to the RunYQL2 child task.
        retry_period = sdk2.parameters.Integer(
            'Time period to check request status (in seconds)',
            default=60
        )
        # Arcadia path of the query; its content is read with Arcadia.cat and
        # passed to the RunYQL2 child task as ``query``.
        yql_file = sdk2.parameters.String(
            'Path to yql file',
            default='/arc/trunk/arcadia/cloud/analytics/yql/example.yql',
            required=True
        )

        # --- YT source -------------------------------------------------------
        # Used both as the YT proxy URL and as the transfer manager source
        # cluster.
        yt_cluster = sdk2.parameters.String(
            'YT source cluster to copy from',
            default='hahn',
            required=True
        )
        # Table that is sorted with yt.run_sort and then handed to the
        # transfer manager.
        source_table = sdk2.parameters.String(
            'YT source table to copy from',
            default='//home/cloud_analytics/example',
            required=True
        )
        # Split on ',' (no whitespace trimming) and passed as ``sort_by`` to
        # yt.run_sort.
        sort_by = sdk2.parameters.String(
            'Comma-separated columns to sort by '
            '(table should be sorted to be imported to CH)',
            default='timestamp',
            required=True
        )

        # --- ClickHouse destination -----------------------------------------
        # Final table name. Data is first loaded into
        # ``<destination_table>_<timestamp>`` and renamed to this name at the
        # end of on_execute.
        destination_table = sdk2.parameters.String(
            'CH destination table',
            default='cloud_analytics.example',
            required=True
        )
        # Arcadia path of the CREATE statement; the file content is
        # %-formatted with a ``table_name`` key before being executed.
        destination_table_schema = sdk2.parameters.String(
            'CH destination table schema',
            default='/arc/trunk/arcadia/cloud/analytics/clickhouse/example.sql',
            required=True
        )
        # Sent to the transfer manager as ``mdb_cluster_address.cluster_id``.
        ch_cluster_id = sdk2.parameters.String(
            'CH MDB cluster id',
            default='07bc5e8c-c4a7-4c26-b668-5a1503d858b9',
            required=True
        )
        # Each listed host is connected to directly (port 8443, SSL) for the
        # health check, DDL statements and row-count verification.
        ch_hosts = sdk2.parameters.String(
            'CH hosts (comma separated)',
            default='sas-tt9078df91ipro7e.db.yandex.net,vla-2z4ktcci90kq2bu2.db.yandex.net',
            required=True
        )
        ch_user = sdk2.parameters.String(
            'CH user',
            default='admin',
            required=True
        )
        # Name of the Vault record (looked up with the task owner) that holds
        # the CH password.
        ch_password_key = sdk2.parameters.String(
            'Vault key for CH password',
            default='yt_to_ch-ch-password',
            required=True
        )
        # How many non-empty ``<destination_table>_<digits>`` tables are kept
        # on each host when old tables are removed.
        ch_old_base_count = sdk2.parameters.Integer(
            'Count of old nonempty CH databases for saving',
            default=2,
            required=False
        )
        # Name of the Vault record with the token sent to the transfer manager
        # as ``mdb_auth.oauth_token``.
        mdb_auth_token_key = sdk2.parameters.String(
            'Vault key for mdb auth token',
            default='yt_to_ch-mdb_auth_token',
            required=True
        )
        # Passed as ``read_timeout`` to the ClickHouse client connection
        # (presumably seconds -- confirm against the client library).
        ch_read_timeout = sdk2.parameters.Integer(
            'ClickHouse readtimeout for queries',
            default=180
        )
        # Name of the Vault record with the YT token stored in
        # yt.config['token'].
        yt_token_name = sdk2.parameters.String(
            'Vault key for YT Token',
            default='yt_to_ch-yt-token',
            required=True
        )

        # --- Behaviour switches ---------------------------------------------
        # When False the RunYQL2 child task is skipped and the existing source
        # table is copied as is.
        execute_yql = sdk2.parameters.Bool(
            'Execute YQL script before moving table from YT to CH',
            default=True
        )
        # Forwarded as ``use_v1_syntax`` to the RunYQL2 child task.
        use_v1_syntax = sdk2.parameters.Bool(
            'Use SQLv1 syntax',
            default=False
        )
        # When True, every available CH host must report exactly the YT row
        # count for the new table, otherwise the task fails before the rename.
        enable_count_check = sdk2.parameters.Bool(
            'Check row counts in YT and CH are equal. Turn off if CH is multishard.',
            default=True,
        )

        # How many of the configured hosts must pass the health check; the
        # effective value is also sent to the transfer manager in
        # ``clickhouse_copy_tool_settings_patch``.
        per_shard_quorum = sdk2.parameters.String(
            'Type of quorum',
            description='''
                'all' - at least all hosts
                'majority' - at least more than half of the hosts
                'at_least_one' - at least one host
            ''',
            choices=[(key, key) for key in QuorumType._ALL],
            default=QuorumType.ALL,
        )

    @staticmethod
    def ch_execute(query, host, parameters=None):
        cur = host.cursor()
        cur.execute(query, parameters=parameters)

    @staticmethod
    def ch_host_healthcheck(host):
        cur = host.cursor()
        cur.execute('SELECT 1')

    @staticmethod
    def ch_execute_fetch(query, host, parameters=None):
        cur = host.cursor()
        cur.execute(query, parameters=parameters)
        return cur.fetchall()

    @staticmethod
    def ch_execute_each(query, hosts, parameters=None):
        for host in hosts.values():
            CloudYtToCH.ch_execute(query, host, parameters)

    @staticmethod
    def ch_execute_fetch_each(query, hosts, parameters=None):
        results = {}
        for name, host in hosts.items():
            results[name] = CloudYtToCH.ch_execute_fetch(query, host, parameters)
        return results

    def on_execute(self):
        """Optionally run the YQL query, then publish the YT table to ClickHouse.

        Steps, in order:

        1. Connect to every host from ``ch_hosts`` and keep those that pass
           the health check; validate the result against ``per_shard_quorum``.
        2. If ``execute_yql`` is set, start a RunYQL2 child task with the
           content of ``yql_file`` and wait for it. The code after the
           memoized block only runs once that block is skipped, i.e. when the
           method is entered again after the wait (so step 1 is repeated).
        3. Sort the YT source table by ``sort_by``.
        4. On each available host drop ``<destination_table>_<digits>`` tables,
           keeping at most ``ch_old_base_count`` non-empty ones.
        5. Create ``<destination_table>_<timestamp>`` on each available host
           from the schema file and fill it through a transfer manager task,
           polling every 5 seconds until the task leaves pending/running.
        6. If ``enable_count_check`` is set, require every available host to
           report the same row count as the YT table.
        7. Drop ``destination_table`` and rename the new table to that name.

        Raises:
            ConnectionError: no host is available, or the requested quorum is
                not met. NOTE(review): ``ConnectionError`` is a builtin only
                on Python 3; on Python 2 these statements would fail with
                NameError instead -- confirm which interpreter runs this task.
            sdk2.WaitTask: raised once to wait for the RunYQL2 child task.
            errors.SandboxTaskFailureError: a child task did not succeed.
            Exception: the transfer manager task ended in a state other than
                ``completed``.
            NotSyncedTableException: a host reports a row count different from
                the YT table.
        """
        import yt.transfer_manager.client as tm
        import yt.wrapper as yt
        from clickhouse.client import connect
        from clickhouse.errors import OperationalError

        logging.info('Check CH hosts health')

        ch_user = self.Parameters.ch_user
        ch_password = sdk2.Vault.data(
            self.owner,
            self.Parameters.ch_password_key
        )

        # Normalise the comma-separated host list: trim whitespace and skip
        # empty items (trailing or doubled commas) and duplicates. ch_hosts
        # below is keyed by host name, so a duplicate would otherwise make
        # the "all hosts available" comparison fail even when every host is
        # healthy, and an empty item would be treated as a host named ''.
        initial_ch_hosts_input = []
        for raw_host in self.Parameters.ch_hosts.split(','):
            host = raw_host.strip()
            if host and host not in initial_ch_hosts_input:
                initial_ch_hosts_input.append(host)

        # host name -> open connection, only for hosts that answered the probe.
        ch_hosts = {}

        for host in initial_ch_hosts_input:
            ch_conn = connect(
                host=host,
                port='8443',
                username=ch_user,
                password=ch_password,
                ssl=True,
                read_timeout=self.Parameters.ch_read_timeout
            )

            try:
                self.ch_host_healthcheck(ch_conn)
            except OperationalError:
                # An unavailable host is tolerated here; whether the task may
                # continue without it is decided by the quorum check below.
                logging.exception('Cannot connect to "%s"', host)
            else:
                ch_hosts[host] = ch_conn

        per_shard_quorum = self.Parameters.per_shard_quorum

        if not ch_hosts:
            raise ConnectionError("Available hosts are empty")

        if len(ch_hosts) != len(initial_ch_hosts_input):
            if self.Parameters.per_shard_quorum == QuorumType.ALL:
                raise ConnectionError("Not all hosts are available")
            # Majority means strictly more than half of the configured hosts.
            if self.Parameters.per_shard_quorum == QuorumType.MAJORITY \
                    and len(initial_ch_hosts_input) // 2 >= len(ch_hosts):
                raise ConnectionError("Not majority hosts are available")
        else:
            # Every host is up, so the strictest quorum is requested from the
            # copy tool regardless of the configured value.
            per_shard_quorum = QuorumType.ALL

        logging.info("Available hosts: %s", ", ".join(ch_hosts))
        logging.info("Quorum: %s", per_shard_quorum)

        if self.Parameters.execute_yql:
            with self.memoize_stage.create_children:
                query = Arcadia.cat(':'.join([Arcadia.ARCADIA_SCHEME,
                                              self.Parameters.yql_file]))
                sub_task = RunYQL2(
                    self,
                    description='Child of task {}'.format(self.id),
                    create_sub_task=False,
                    query=query,
                    trace_query=True,
                    owner=self.Parameters.owner,
                    publish_query=True,
                    use_v1_syntax=self.Parameters.use_v1_syntax,
                    retry_period=self.Parameters.retry_period
                ).enqueue()

                raise sdk2.WaitTask(
                    sub_task,
                    (ctt.Status.Group.FINISH, ctt.Status.Group.BREAK)
                )

            logging.info('Checking child tasks status')
            child_tasks = self.find()
            for task in child_tasks:
                if task.status not in ctt.Status.Group.SUCCEED:
                    raise errors.SandboxTaskFailureError('Child task is failed.')

        logging.info('Sorting table')

        os.environ['YT_LOG_LEVEL'] = 'INFO'

        yt.config['token'] = sdk2.Vault.data(
            self.owner,
            self.Parameters.yt_token_name
        )
        yt.config['proxy']['url'] = self.Parameters.yt_cluster
        yt.run_sort(
            source_table=self.Parameters.source_table,
            sort_by=self.Parameters.sort_by.split(',')
        )

        logging.info('Creating table')

        dbaas_token = sdk2.Vault.data(
            self.owner,
            self.Parameters.mdb_auth_token_key
        )

        # TODO: Abstract everything to MDBClickHouseCluster client and use MDB API
        ch_cluster_id = self.Parameters.ch_cluster_id
        # Temporary table name: "<destination_table>_<epoch seconds>".
        # time.time() is used instead of utcnow().strftime('%s'): '%s' is not
        # a documented Python directive (platform-specific extension) and,
        # given a naive UTC value, interprets it as local time.
        ch_table = '_'.join([self.Parameters.destination_table,
                             str(int(time.time()))])
        ch_table_schema = Arcadia.cat(
            ':'.join([Arcadia.ARCADIA_SCHEME,
                      self.Parameters.destination_table_schema]))

        logging.info('Removing old tables')

        # Names of previously created timestamp-suffixed tables on each host.
        # The ORDER BY clause is descending by name.
        # NOTE(review): destination_table is inserted into the regex
        # unescaped, so its '.' matches any character -- confirm acceptable.
        ch_tables = CloudYtToCH.ch_execute_fetch_each(
            "SELECT database || '.' || name AS dbname FROM system.tables "
            "WHERE match(dbname, '^{}_[0-9]{{10,}}$') ORDER BY dbname DESC".format(
                self.Parameters.destination_table
            ), ch_hosts
        )

        # The parameter is optional; treat an unset value as "keep nothing"
        # instead of comparing an int with None.
        keep_count = self.Parameters.ch_old_base_count or 0

        for host_name, tables in ch_tables.items():
            ch_conn = ch_hosts[host_name]
            kept_tables = 0

            for row in tables:
                table = row[0]
                table_size = int(CloudYtToCH.ch_execute_fetch(
                    'SELECT count(*) FROM {}'.format(table),
                    ch_conn
                )[0][0])
                # Keep the first keep_count non-empty tables in query order;
                # empty tables and everything beyond the limit are dropped.
                if kept_tables < keep_count and table_size > 0:
                    kept_tables += 1
                    continue
                CloudYtToCH.ch_execute(
                    'DROP TABLE IF EXISTS {}'.format(table),
                    ch_conn
                )

        # TODO(syndicut): Potential security issue here, need to move it
        # eventually to something like https://github.com/cloudflare/sqlalchemy-clickhouse
        CloudYtToCH.ch_execute_each(
            ch_table_schema % {'table_name': ch_table},
            ch_hosts
        )

        logging.info('Sending task to transfer manager')
        params = {
            'clickhouse_copy_options': {
                'command': 'append',
            },
            'clickhouse_credentials': {
                'password': ch_password,
                'user': ch_user,
            },
            'mdb_auth': {
                'oauth_token': dbaas_token,
            },
            'mdb_cluster_address': {
                'cluster_id': ch_cluster_id,
            },
            'clickhouse_copy_tool_settings_patch': {
                'clickhouse_client': {
                    'per_shard_quorum': per_shard_quorum,
                },
            }
        }
        task = tm.add_task(
            self.Parameters.yt_cluster,
            self.Parameters.source_table,
            'mdb-clickhouse',
            ch_table,
            params=params,
            sync=False
        )
        task_info = tm.get_task_info(task)
        # NOTE(review): if the returned task info echoes the submitted params,
        # this log line would contain the CH password and the MDB token --
        # confirm what get_task_info returns.
        logging.info('Added transfer_manager '
                     'task: %s', task_info)

        # The task was added with sync=False, so poll until it leaves the
        # in-progress states.
        while task_info['state'] in ('pending', 'running'):
            time.sleep(5)
            task_info = tm.get_task_info(task)

        logging.info('Task execution completed with '
                     'the following state: %s', task_info)

        if task_info['state'] != 'completed':
            raise Exception(
                'Transfer manager task failed with '
                'the following state: %s' % task_info['state']
            )

        if self.Parameters.enable_count_check:
            logging.info('Check count of records in DB')
            yt_count = int(yt.row_count(self.Parameters.source_table))
            ch_counts = CloudYtToCH.ch_execute_fetch_each(
                'SELECT count(*) FROM {}'.format(ch_table),
                ch_hosts
            )

            for ch_host, ch_results in ch_counts.items():
                ch_count = int(ch_results[0][0])
                if ch_count != yt_count:
                    raise self.NotSyncedTableException(
                        ch_host, ch_count, yt_count
                    )

        logging.info('Moving new table in-place')

        # Two separate statements per host: the destination table does not
        # exist between the DROP and the RENAME.
        CloudYtToCH.ch_execute_each(
            'DROP TABLE IF EXISTS {}'.format(
                self.Parameters.destination_table),
            ch_hosts
        )
        CloudYtToCH.ch_execute_each(
            'RENAME TABLE {} TO {}'.format(
                ch_table, self.Parameters.destination_table),
            ch_hosts
        )
