import logging
import json
import random
import sandbox.common.types.task as ctt
from time import time
from sandbox import sdk2
from sandbox.sdk2.vcs.svn import Arcadia
from sandbox.common import errors as ce
from sandbox.projects.mediaanalyst.libs.common.base import MaSandboxBaseTask
from sandbox.common.types import misc as ctm
from sandbox.projects.common.juggler import jclient


class MaYtToChUpload(MaSandboxBaseTask):
    """Upload a YT table into a sharded ClickHouse (MDB) cluster.

    ``on_execute`` runs these memoized stages:

    1. create ``<ch_table>__temp_<unix_ts>`` as a ``ReplicatedMergeTree`` table
       ON CLUSTER, with the schema read from Arcadia (``ch_table_schema``) or
       inferred from the YT table;
    2. append the YT table into it through Transfer Manager and poll the TM
       task every ``tm_task_poll_interval`` seconds;
    3. drop ``<ch_table>_sharded`` on every alive host and rename the temp
       table to it;
    4. drop ``<ch_table>`` and recreate it as a ``Distributed`` table over
       ``<ch_table>_sharded``;
    5. check that ``<ch_table>`` exists on every alive ClickHouse host.

    On FAILURE/BREAK a CRIT event is sent to juggler (when
    ``enable_monitoring`` is set) and the temp table, if still present, is
    dropped.
    """

    BINARY_TASK_ATTR_TARGET = "mediaanalyst/tasks/MaYtToChUpload"

    class Requirements(sdk2.Requirements):
        # Acquire the 'media_analyst_data_uploads' semaphore (presumably shared
        # with other upload tasks to limit concurrency) and release it once the
        # task breaks or finishes.
        semaphores = ctt.Semaphores(
            acquires=[
                ctt.Semaphores.Acquire(
                    name='media_analyst_data_uploads')
            ],
            release=(ctt.Status.Group.BREAK, ctt.Status.Group.FINISH)
        )

    class Parameters(MaSandboxBaseTask.Parameters):

        with sdk2.parameters.Group("YT settings", collapse=True) as source_settings:
            yt_cluster = sdk2.parameters.String(
                "YT cluster",
                default="hahn",
                description="YT source cluster to copy from",
                required=True,
            )

            yt_pool = sdk2.parameters.String(
                "YT pool",
                default="",
                required=False
            )

            yt_table = sdk2.parameters.String(
                "YT source table",
                default="//home/media-analyst/example",
                required=True
            )

        with sdk2.parameters.Group("Clickhouse settings", collapse=True) as destination_settings:
            ch_cluster_id = sdk2.parameters.String(
                "CH MDB cluster id",
                default="0f4cbcdb-cfbf-486d-a89c-2b22439bf811",
                required=True
            )

            ch_database = sdk2.parameters.String(
                "CH database",
                required=True
            )

            ch_table = sdk2.parameters.String(
                "CH table name",
                required=True
            )

            ch_table_schema = sdk2.parameters.String(
                "CH schema",
                default="",
                required=False,
                description="Arc path to CH table schema"
            )

            ch_port = sdk2.parameters.Integer(
                "CH port",
                default=8443,
                required=True
            )

            ch_user = sdk2.parameters.String(
                "CH user",
                default="ottuser",
                required=True
            )

            ch_cluster_name = sdk2.parameters.String(
                "CH cluster name",
                default="ott_analytics",
                required=True
            )

            ch_primary_key = sdk2.parameters.String(
                "CH table primary key",
                required=False,
                default=""
            )

            ch_table_sharding_key = sdk2.parameters.String(
                "CH Distributed table sharding key",
                required=False,
                default=""
            )

        with sdk2.parameters.Group("Service account settings (public cloud)", collapse=True) as sa_settings:
            service_account_id = sdk2.parameters.String(
                "Service account id",
                required=False,
                default=""
            )

            service_account_key_id = sdk2.parameters.String(
                "Service account key id",
                required=False,
                default=""
            )

            service_account_private_key = sdk2.parameters.YavSecret(
                "Service account private key",
                required=False
            )

        with sdk2.parameters.Group("Other settings", collapse=True) as other_settings:
            yav_secret = sdk2.parameters.YavSecret("Yav secret", required=True)

            use_public_cloud = sdk2.parameters.Bool(
                "Use public cloud auth",
                required=False,
                default=False)

            tm_task_poll_interval = sdk2.parameters.Integer(
                "How often check tm job status (in seconds)",
                default=180
            )

            enable_monitoring = sdk2.parameters.Bool(
                "Enable juggler alerts in case task fails",
                required=False,
                default=False
            )

    # ------------------------------------------------------------------ juggler

    @property
    def juggler_report_is_enabled(self):
        return self.Parameters.enable_monitoring

    @property
    def juggler_host_name(self):
        # NOTE(review): juggler_host_name is not declared in Parameters above;
        # presumably it comes from MaSandboxBaseTask.Parameters — confirm.
        return self.Parameters.juggler_host_name

    @property
    def juggler_service_name(self):
        # NOTE(review): see juggler_host_name — presumably a base-class parameter.
        return self.Parameters.juggler_service_name

    @property
    def juggler_sandbox_url(self):
        return 'https://sandbox.yandex-team.ru/task/{}'.format(self.id)

    @property
    def juggler_author_staff_url(self):
        return 'https://staff.yandex-team.ru/{}'.format(self.author)

    @property
    def juggler_task_creation_date(self):
        return self.created.strftime('%Y-%m-%d %H:%M:%S')

    @property
    def juggler_task_description(self):
        """Human-readable one-line summary of this task used in juggler events."""
        return (
            '{sandbox_type} {sandbox_url} (by {staff_url}) '
            'created at {create_date} is in {sandbox_status} status'.format(
                sandbox_type=self.type,
                sandbox_url=self.juggler_sandbox_url,
                staff_url=self.juggler_author_staff_url,
                create_date=self.juggler_task_creation_date,
                sandbox_status=self.status))

    # ---------------------------------------------------------- sandbox hooks

    def on_success(self, prev_status):
        self.send_status_to_juggler(status='OK')

    def on_enqueue(self):
        # Request DNS64 resolution for the task container.
        self.Requirements.dns = ctm.DnsType.DNS64

    def on_terminate(self):
        """Abort the Transfer Manager task started by this run, if any."""
        if self.Context.tm_task_id:
            client = self.get_tm_client()
            client.abort_task(self.Context.tm_task_id)
        else:
            logging.error("task id is empty: can't terminate")

    def on_save(self):
        """Pin the tasks binary: latest released one, or an explicit resource."""
        if self.Parameters.UseLastBinary:
            self.Requirements.tasks_resource = sdk2.service_resources.SandboxTasksBinary.find(
                attrs={"target": self.BINARY_TASK_ATTR_TARGET,
                       "release": self.Parameters.ReleaseType or "stable"}
            ).first().id
        else:
            self.Requirements.tasks_resource = self.Parameters.custom_tasks_archive_resource

    def on_break(self, prev_status, status):
        description = (
            '{task_description}: '
            'previous {previous_status} -> current {current_status}'.format(
                task_description=self.juggler_task_description,
                previous_status=prev_status, current_status=status))
        self.send_status_to_juggler(
            status='CRIT',
            description=description)
        self.cleanup_temp_table()

    def on_failure(self, prev_status):
        description = (
            '{task_description}: '
            'previous {previous_status} -> current FAILURE'.format(
                task_description=self.juggler_task_description,
                previous_status=prev_status))
        self.send_status_to_juggler(
            status='CRIT',
            description=description)
        self.cleanup_temp_table()

    def send_status_to_juggler(self, status, description=None):
        """Send a juggler event for ``<ch_database>.<ch_table>`` when monitoring is on.

        :param status: juggler status string ('OK', 'CRIT', ...).
        :param description: event text; defaults to ``juggler_task_description``.
        """
        logging.info('sending event to juggler')
        if self.juggler_report_is_enabled:
            juggler_upload_id = self.Parameters.ch_database + '.' + self.Parameters.ch_table
            jclient.send_events_to_juggler(
                host='media-dwh.uploads.' + juggler_upload_id,
                service='status',
                status=status,
                description=(
                    description
                    if description is not None
                    else self.juggler_task_description))

    # ------------------------------------------------------- clickhouse helpers

    def cleanup_temp_table(self):
        """Best-effort drop of this run's temp table on every alive CH host.

        A failure on one host is logged and does not stop cleanup on the
        others, so this is safe to call from the failure/break hooks.
        """
        logging.info("cleaning up temp tables..")
        temp_table = self.Context.ch_temp_table
        if not temp_table:
            logging.info("temp table doesn't exist")
            logging.info('cleanup finished')
            return

        db = self.Parameters.ch_database
        drop_query = "DROP TABLE IF EXISTS {db}.{table} NO DELAY".format(db=db, table=temp_table)
        for hostname in self.get_ch_hosts():
            logging.info("checking table on %s", hostname)
            try:
                if self.check_table_exists(host=hostname, db=db, table=temp_table):
                    logging.info("dropping table: %s.%s on host %s", db, temp_table, hostname)
                    self._execute_on_host(hostname, drop_query)
                else:
                    logging.info("table: %s.%s doesn't exist on host %s", db, temp_table, hostname)
            except Exception:
                # Keep going: one unreachable host must not block cleanup elsewhere.
                logging.exception("failed to clean up %s.%s on host %s", db, temp_table, hostname)
        logging.info('cleanup finished')

    def get_arcadia_table_schema(self, path):
        """Read a JSON schema file from Arcadia and render it as a CH column list.

        The file is expected to contain ``{"schema": {"<column>": "<CH type>", ...}}``.

        :param path: Arcadia path of the schema file.
        :return: ``"col1 Type1, col2 Type2"`` in the file's column order.
        """
        import collections

        logging.info("retrieving schema from arcadia: %s", path)
        schema = Arcadia.cat(':'.join([Arcadia.ARCADIA_SCHEME, path]))
        # OrderedDict keeps the column order from the file on any Python version.
        table_info = json.loads(schema, object_pairs_hook=collections.OrderedDict)
        return ", ".join(
            "%s %s" % (field_name, field_type)
            for field_name, field_type in table_info['schema'].items()
        )

    def _execute_on_host(self, host, query):
        """Run ``query`` on one ClickHouse ``host`` and return the driver result.

        The connection is always closed, even when the query raises.
        """
        client = self.get_ch_client(host)
        try:
            logging.info("executing query on %s\n\t%s", host, query)
            return client.execute(query)
        finally:
            client.disconnect()

    def _random_ch_host(self):
        """Return a random alive ClickHouse host (used for ON CLUSTER queries).

        :raises ce.TaskError: if the cluster has no alive ClickHouse hosts.
        """
        hosts = self.get_ch_hosts()
        if not hosts:
            raise ce.TaskError(
                "no alive CLICKHOUSE hosts in cluster %s" % self.Parameters.ch_cluster_id)
        return random.choice(hosts)

    def rename_table(self, old_db, old_table, new_db, new_table, on_cluster=True, host=None):
        """RENAME ``old_db.old_table`` to ``new_db.new_table``.

        :param on_cluster: append ``ON CLUSTER <ch_cluster_name>``.
        :param host: host to connect to; a random alive host when not given.
        """
        q = "RENAME TABLE {old_db}.{old_table} TO {new_db}.{new_table}".format(
            old_db=old_db,
            new_db=new_db,
            old_table=old_table,
            new_table=new_table
        )
        if on_cluster:
            q += " ON CLUSTER {c}".format(c=self.Parameters.ch_cluster_name)
        self._execute_on_host(host or self._random_ch_host(), q)

    def drop_table(self, db, table, no_delay=False):
        """DROP TABLE IF EXISTS ``db.table``.

        :param no_delay: when true, run ``... NO DELAY`` separately on every
            alive host; otherwise run a single ``ON CLUSTER`` query.
        """
        q = "DROP TABLE IF EXISTS {db}.{table}".format(db=db, table=table)
        if no_delay:
            q += " NO DELAY"
            for host in self.get_ch_hosts():
                self._execute_on_host(host, q)
        else:
            q += " ON CLUSTER {c}".format(c=self.Parameters.ch_cluster_name)
            self._execute_on_host(self._random_ch_host(), q)

    def create_replicated_merge_tree_table(self, db, table, schema, pk):
        """CREATE a ReplicatedMergeTree table ON CLUSTER.

        :param schema: column list, e.g. ``"a Int64, b String"``.
        :param pk: ORDER BY expression; when empty, ``ORDER BY tuple()`` is used.
        """
        logging.info("creating replicated merge tree table")
        engine_clause = "ReplicatedMergeTree('/clickhouse/tables/{{shard}}/{table}', '{{replica}}')".format(
            table=table
        )
        # tuple() is ClickHouse's documented "no sorting key" form.
        order_by = "({})".format(pk) if pk else "tuple()"

        q = (
            "CREATE TABLE {database}.{table} ON CLUSTER `{ch_cluster}` ( {schema} ) "
            "ENGINE = {engine} ORDER BY {order_by}"
        ).format(
            database=db,
            table=table,
            ch_cluster=self.Parameters.ch_cluster_name,
            schema=schema,
            engine=engine_clause,
            order_by=order_by
        )
        self._execute_on_host(self._random_ch_host(), q)

    def create_distributed_table(self, cluster, db, distr_table, rmt_table, sharding_key=None):
        """CREATE ``db.distr_table`` as a Distributed table over ``db.rmt_table``.

        The query is executed on every alive host (no ON CLUSTER).

        :param sharding_key: optional Distributed sharding-key expression.
        """
        logging.info("creating distributed table")
        if sharding_key:
            engine_clause = "ENGINE = Distributed(%s, %s, %s, %s)" % (cluster, db, rmt_table, sharding_key)
        else:
            engine_clause = "ENGINE = Distributed(%s, %s, %s)" % (cluster, db, rmt_table)

        q = "CREATE TABLE {db}.{distrib} AS {db}.{rmt_table} {engine}".format(
            db=db,
            distrib=distr_table,
            rmt_table=rmt_table,
            engine=engine_clause
        )
        for host in self.get_ch_hosts():
            self._execute_on_host(host, q)

    def start_checks(self):
        # TODO: add check_acl()
        logging.info("ACL: checking")
        logging.info("ACL: pass")

    def _use_service_account(self):
        """True when service account id, key id and private key are all set."""
        return bool(
            self.Parameters.service_account_id
            and self.Parameters.service_account_key_id
            and self.Parameters.service_account_private_key
        )

    # -------------------------------------------------------- transfer manager

    def get_default_tm_config(self):
        """Base Transfer Manager params, without secrets (safe to log)."""
        return {
            "clickhouse_copy_options": {
                "command": "append",
                "use_local_uploader": True
            },
            "clickhouse_credentials": {
                "password": None,
                "user": self.Parameters.ch_user,
            },
            "mdb_auth": {
                "use_public_cloud": self.Parameters.use_public_cloud
            },
            "mdb_cluster_address": {
                "cluster_id": self.Parameters.ch_cluster_id,
            },
            "clickhouse_copy_tool_settings_patch": {
                "clickhouse_client": {
                    "per_shard_quorum": "majority"
                },
            }
        }

    def get_tm_client(self):
        """Build a Transfer Manager client (testing backend for public cloud)."""
        from yt.transfer_manager.client import TransferManager
        token = self.Parameters.yav_secret.data()["yt-token"]
        if self.Parameters.use_public_cloud:
            return TransferManager(token=token, url="http://tm-testing.yt.yandex.net")
        return TransferManager(token=token)

    def get_mdb_client(self):
        """Build an MDB ClickHouse client using service-account or token auth."""
        from afisha.infra.libs.mdb import MdbClientClickhouse
        from afisha.infra.libs.mdb.iam_client import MdbServiceAccountCredentials
        if self._use_service_account():
            logging.info("using service account auth")
            creds = MdbServiceAccountCredentials(
                account_id=self.Parameters.service_account_id,
                key_id=self.Parameters.service_account_key_id,
                private_key=self.Parameters.service_account_private_key.data()["private_key"]
            )
            return MdbClientClickhouse(creds, use_public_cloud=self.Parameters.use_public_cloud)
        logging.info("using token auth")
        return MdbClientClickhouse(self.Parameters.yav_secret.data()["mdb-token"],
                                   use_public_cloud=self.Parameters.use_public_cloud)

    def start_transfer(self):
        """Submit an async TM task appending the YT table into the CH temp table.

        Stores ``tm_task_id``, ``tm_task_url`` and ``transfer_state`` in Context.
        """
        client = self.get_tm_client()
        params = self.get_default_tm_config()
        copy_options = params["clickhouse_copy_options"]
        settings_patch = params["clickhouse_copy_tool_settings_patch"]

        primary_key = self.Parameters.ch_primary_key
        if primary_key:
            # A multi-column key must be passed as a parenthesized tuple.
            copy_options["primary_key"] = "(" + primary_key + ")" if "," in primary_key else primary_key
        if self.Parameters.ch_table_sharding_key:
            copy_options["sharding_key"] = self.Parameters.ch_table_sharding_key
        if self.Parameters.yt_pool:
            settings_patch["shard_uploader"] = {"job_executor": {"pool": self.Parameters.yt_pool}}
            settings_patch["preprocessing"] = {"yt_pool": self.Parameters.yt_pool}

        # Logged before secrets are added below.
        logging.info("TM parameters:")
        logging.info(json.dumps(params, indent=4))
        table_path = "%s.%s" % (self.Parameters.ch_database, self.Context.ch_temp_table)
        logging.info("appending data to table %s", table_path)

        params["clickhouse_credentials"]["password"] = self.Parameters.yav_secret.data()["ch-password"]
        if self._use_service_account():
            params["mdb_auth"]["sa_private_key"] = self.Parameters.service_account_private_key.data()["private_key"]
            params["mdb_auth"]["sa_key_id"] = self.Parameters.service_account_key_id
            params["mdb_auth"]["sa_id"] = self.Parameters.service_account_id
        else:
            params["mdb_auth"]["oauth_token"] = self.Parameters.yav_secret.data()["mdb-token"]
        logging.info("sending task to transfer manager")

        task_id = client.add_task(
            self.Parameters.yt_cluster,
            self.Parameters.yt_table,
            "mdb-clickhouse",
            table_path,
            params=params,
            sync=False
        )
        self.Context.tm_task_id = task_id

        backend = "testing" if self.Parameters.use_public_cloud else "production"
        self.Context.tm_task_url = (
            "https://transfer-manager.yt.yandex-team.ru/task?id=%s&backend=%s" % (task_id, backend)
        )

        task_info = client.get_task_info(task_id)
        self.Context.transfer_state = task_info["state"]

        logging.info("added transfer_manager task: %s", self.Context.tm_task_id)

    def check_tm_task_status(self):
        """Poll the TM task once.

        :raises sdk2.WaitTime: while the task is ``pending``/``running``.
        :raises ce.TaskError: if there is no task id or the task ended in any
            state other than ``completed``.
        """
        logging.info("checking transfer manager status")

        if not self.Context.tm_task_id:
            raise ce.TaskError("can't check status: empty tm task_id")
        if self.Context.transfer_finished:
            return

        client = self.get_tm_client()
        task_info = client.get_task_info(self.Context.tm_task_id)
        state = task_info["state"]
        self.Context.transfer_state = state

        # Exact comparison: a substring test would also accept e.g. "" or "complete".
        if state == "completed":
            self.Context.transfer_finished = True
            logging.info("task %s completed", self.Context.tm_task_id)
            logging.info("-" * 100)
            logging.info(json.dumps(task_info, indent=4))
            logging.info("-" * 100)
        elif state in ("pending", "running"):
            self.Context.transfer_finished = False
            logging.info("task status: %s", state)
            raise sdk2.WaitTime(self.Parameters.tm_task_poll_interval)
        else:
            msg = "transfer manager: task %s link: %s" % (state, self.Context.tm_task_url)
            raise ce.TaskError(msg)

    # ------------------------------------------------------------- yt schema

    def _transform_column(self, yt_column_type, value):
        """Map a YT column type (plus a sample value) to a ClickHouse type name.

        :param yt_column_type: YT type name as a string, e.g. ``"int64"``.
        :param value: the column's value in a sample row (may be ``None``).
        :return: ClickHouse type string, or ``None`` when there is no mapping
            (the caller then leaves the column out).
        """
        import re
        from yt.yson import yson_types

        type_conversion = {
            "int64": "Int64",
            "int32": "Int32",
            "int16": "Int16",
            "int8": "Int8",
            "uint64": "UInt64",
            "uint32": "UInt32",
            "uint16": "UInt16",
            "uint8": "UInt8",
            "boolean": "UInt8",
            "double": "Float64",
            "string": "String",
            "utf8": "String",
        }

        # Only an exact "YYYY-MM-DD" becomes Date; anchoring at \Z keeps
        # datetimes and other strings that merely start with a date as String.
        if yt_column_type == "string" and isinstance(value, str) \
           and re.match(r"\d{4}-\d{2}-\d{2}\Z", value):
            return "Date"

        if yt_column_type == "any" and isinstance(value, list):
            yson_to_ch = {
                yson_types.YsonInt64: "Int64",
                yson_types.YsonUint64: "UInt64",
                yson_types.YsonBoolean: "UInt8",
                yson_types.YsonDouble: "Float64",
                yson_types.YsonString: "String",
                yson_types.YsonUnicode: "String",
            }
            inner_ch_types = {yson_to_ch.get(x.__class__) for x in value}
            # Only a non-empty, homogeneous list of a known scalar type maps to Array(T).
            if len(inner_ch_types) == 1:
                inner_ch_type = next(iter(inner_ch_types))
                if inner_ch_type is not None:
                    return "Array({})".format(inner_ch_type)
        return type_conversion.get(yt_column_type)

    def get_yt_table_schema(self):
        """Infer a ClickHouse column list for ``yt_table``.

        Column types come from ``infer_yt_table_schema``, refined by the
        table's first row (see ``_transform_column``).

        :raises ce.TaskError: if the YT table has no rows to sample.
        :raises ValueError: if a column name contains a space.
        """
        from copy_yt_to_ch.yt import infer_schema
        import yt.wrapper as yt
        logging.info("retrieving table schema from yt")
        yt_client = yt.YtClient(
            self.Parameters.yt_cluster,
            token=self.Parameters.yav_secret.data()["yt-token"],
        )

        # Read only the first row.
        table_read_path = yt.TablePath(self.Parameters.yt_table, start_index=0, end_index=1, client=yt_client)
        yson_format = yt.YsonFormat(always_create_attributes=True)
        rows = yt_client.read_table(table_read_path, format=yson_format)
        row = next(iter(rows), None)
        if row is None:
            raise ce.TaskError(
                "YT table %s is empty: can't infer CH schema from it" % self.Parameters.yt_table)

        logging.info("attempt to infer yt table schema")
        yt_table_schema = infer_schema.infer_yt_table_schema(yt_client, self.Parameters.yt_table)
        logging.info(yt_table_schema)
        logging.info(yt_table_schema.columns)

        ch_schema_list = []
        for column in yt_table_schema.columns:
            column_name = str(column.name).strip()
            if " " in column_name:
                raise ValueError("there should be no ' ' in column name: %r" % column_name)
            ch_column_type = self._transform_column(str(column.type), row.get(column_name))
            if ch_column_type:
                ch_schema_list.append("{} {}".format(column_name, ch_column_type.strip()))

        yt_schema = ", ".join(ch_schema_list)
        logging.info("schema:\n\t%s", yt_schema)
        return yt_schema

    # ----------------------------------------------------- clickhouse access

    def get_ch_client(self, host):
        """Native-protocol ClickHouse client for ``host`` bound to ``ch_database``.

        NOTE(review): port 9440 is hard-coded (the ``ch_port`` parameter is not
        used here) and TLS certificates are not verified (``verify=False``).
        """
        from clickhouse_driver import Client
        return Client(host=host,
                      database=self.Parameters.ch_database,
                      port=9440,
                      user=self.Parameters.ch_user,
                      password=self.Parameters.yav_secret.data()["ch-password"],
                      secure=True,
                      verify=False)

    def check_table_exists(self, host, db, table):
        """Return True if ``db.table`` exists on ``host``."""
        result = self._execute_on_host(host, "EXISTS TABLE {db}.{table}".format(db=db, table=table))
        return result[0][0] == 1

    @staticmethod
    def _alive_ch_host_names(hosts_info):
        """Names of MDB hosts of type CLICKHOUSE whose health is ALIVE."""
        return [
            host_info["name"]
            for host_info in hosts_info
            if host_info["type"] == "CLICKHOUSE"
            and "health" in host_info
            and host_info["health"] == "ALIVE"
        ]

    def get_ch_hosts(self):
        """List names of alive ClickHouse hosts of ``ch_cluster_id`` via MDB."""
        mdb_client = self.get_mdb_client()
        hosts_info = mdb_client.cluster_hosts_list(self.Parameters.ch_cluster_id)
        hosts = self._alive_ch_host_names(hosts_info)
        logging.info("hosts:\t%s", ",".join(hosts))
        return hosts

    def check_upload(self):
        """Ensure ``ch_database.ch_table`` exists on every alive ClickHouse host.

        :raises Exception: naming the first host where the table is missing.
        """
        logging.info("checking upload...")
        db = self.Parameters.ch_database
        table = self.Parameters.ch_table
        for hostname in self.get_ch_hosts():
            logging.info("checking table: %s", hostname)
            if not self.check_table_exists(host=hostname, db=db, table=table):
                raise Exception(
                    "upload checks: table %s.%s doesn't exist on host %s" % (db, table, hostname))

        logging.info("upload check: ok")

    def on_execute(self):
        database = self.Parameters.ch_database
        sharded_table = self.Parameters.ch_table + "_sharded"

        with self.memoize_stage.start_acl_checks(commit_on_entrance=False):
            self.start_checks()

        with self.memoize_stage.create_ch_table(commit_on_entrance=False):
            self.Context.ch_temp_table = self.Parameters.ch_table + "__temp_%s" % int(time())
            if self.Parameters.ch_table_schema:
                schema = self.get_arcadia_table_schema(self.Parameters.ch_table_schema)
            else:
                schema = self.get_yt_table_schema()

            self.create_replicated_merge_tree_table(
                db=database,
                table=self.Context.ch_temp_table,
                schema=schema,
                pk=self.Parameters.ch_primary_key
            )

        with self.memoize_stage.start_transfer(commit_on_entrance=False):
            self.start_transfer()

        # Raises sdk2.WaitTime while TM is still working; on re-entry the
        # completed memoize stages above are skipped.
        self.check_tm_task_status()

        with self.memoize_stage.recreate_table(commit_on_entrance=False):
            self.drop_table(db=database,
                            table=sharded_table,
                            no_delay=True)
            self.rename_table(
                old_db=database,
                old_table=self.Context.ch_temp_table,
                new_db=database,
                new_table=sharded_table
            )

        with self.memoize_stage.create_distributed_table(commit_on_entrance=False):
            self.drop_table(db=database,
                            table=self.Parameters.ch_table)

            self.create_distributed_table(
                cluster=self.Parameters.ch_cluster_name,
                db=database,
                distr_table=self.Parameters.ch_table,
                rmt_table=sharded_table,
                sharding_key=self.Parameters.ch_table_sharding_key or None
            )

        with self.memoize_stage.upload_checks(commit_on_entrance=False):
            self.check_upload()