import datetime
import logging

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.projects.yabs.base_bin_task import BaseBinTask

# Database that every ClickHouse connection in this module is opened against.
DEFAULT_DATABASE = "system"

# Lists distinct ("<database>.<table>", partition) pairs from system.parts.
# {where} is filled with one or more PARTITIONS_ONE_TABLE clauses joined by OR.
QUERY_PARTITIONS = (
    "\n"
    "SELECT DISTINCT\n"
    "    database || '.' || table,\n"
    "    partition\n"
    "FROM\n"
    "    system.parts\n"
    "WHERE\n"
    "    {where}\n"
)

# Filter for a single table: active parts whose partition value matches the
# date-format regexp and compares below the formatted cutoff date ({limit}).
PARTITIONS_ONE_TABLE = "(\n    " + "\n    AND\n    ".join((
    "database='{database}'",
    "table = '{table}'",
    "partition < '{limit}'",
    "match(partition, '{regexp}')",
    "active = 1",
)) + "\n)"

# Supported partition naming schemes, keyed by the value of the task's
# "Partitions date format" parameter. For each scheme:
#   format               - strftime pattern used to render the cutoff date
#   regexp               - pattern a partition value must match to be considered
#   drop_partition_query - ALTER template; the YYYY-MM-DD variant quotes the
#                          partition value, the YYYYMM variant passes it bare
DATES_FORMAT = {
    "YYYY-MM-DD": dict(
        format="%Y-%m-%d",
        regexp="[0-9]{4}-[0-9]{2}-[0-9]{2}",
        drop_partition_query="ALTER TABLE {table} DROP PARTITION '{partition}'",
    ),
    "YYYYMM": dict(
        format="%Y%m",
        regexp="[0-9]{6}",
        drop_partition_query="ALTER TABLE {table} DROP PARTITION {partition}",
    ),
}


def get_secure_by_port(port):
    """Return the value of the client's ``secure`` flag for a given port.

    Port 9440 maps to ``True`` and port 9000 maps to ``False``. Any other
    port is reported through ``logging.error`` and treated as not secure.
    """
    # Checked in this order; the first matching port decides the result.
    for known_port, secure in ((9440, True), (9000, False)):
        if port == known_port:
            return secure

    logging.error("Unknown port %d", port)
    return False


class YabsClickhouseDropOldPartitions(BaseBinTask):
    """
    Drop old partitions in clickhouse

    The task asks the master host for the shards of the configured cluster,
    then, for every shard, tries its hosts one by one until one of them
    successfully lists and drops all partitions older than the per-table
    limit. The task fails if any shard could not be processed on any host.
    """

    class Parameters(BaseBinTask.Parameters):

        with BaseBinTask.Parameters.version_and_task_resource() as version_and_task_resource:
            resource_attrs = sdk2.parameters.Dict(
                "Filter resource by", default={"name": "YabsClickhouseDropOldPartitions"}
            )

        with sdk2.parameters.Group("ClickHouse") as ch_params:
            # Connection used only to discover the shards of `cluster`.
            master_host = sdk2.parameters.String("ClickHouse master host", required=True)
            master_port = sdk2.parameters.Integer("ClickHouse master port", required=True, default=9000)
            master_user = sdk2.parameters.String("ClickHouse username for master host", required=True)
            # The password is read from the secret under the key equal to master_user.
            master_password = sdk2.parameters.YavSecret("Yav secret with master password", required=True)
            cluster = sdk2.parameters.String("ClickHouse cluster", required=True)
            # Credentials used on every shard host; the secret key is cluster_user.
            cluster_user = sdk2.parameters.String("ClickHouse username for cluster", required=True)
            cluster_password = sdk2.parameters.YavSecret("Yav secret with cluster password", required=True)
            # Selects an entry of DATES_FORMAT.
            with sdk2.parameters.RadioGroup("Partitions date format") as date_format:
                date_format.values["YYYY-MM-DD"] = date_format.Value(value="YYYY-MM-DD", default=True)
                date_format.values["YYYYMM"] = date_format.Value(value="YYYYMM")
            tables = sdk2.parameters.Dict("Tables to cleanup, format: db.table -> days limit", required=True)

    def _open_clickhouse_client(self, host, port, user, password):
        """Build a ClickHouse client for ``host:port`` bound to DEFAULT_DATABASE.

        The ``secure`` flag is derived from the port via ``get_secure_by_port``.
        """
        from clickhouse_driver import Client as ClickHouseClient

        return ClickHouseClient(
            host=host,
            port=port,
            user=user,
            password=password,
            secure=get_secure_by_port(port),
            database=DEFAULT_DATABASE,
            # NOTE(review): certificate verification is disabled for secure
            # connections - confirm this is intentional.
            verify=False,
        )

    def get_shards(self):
        """Return the shards of the configured cluster as reported by the master host.

        Reads ``self.master_password``, which ``on_execute`` sets before
        calling this method. The result is whatever ``get_shards_for_cluster``
        returns; ``on_execute`` iterates it as ``shard -> [(host, port), ...]``.
        """
        from yabs.stat.infra.clickhouse.lib import get_shards_for_cluster

        master_client = self._open_clickhouse_client(
            self.Parameters.master_host,
            self.Parameters.master_port,
            self.Parameters.master_user,
            self.master_password,
        )
        try:
            return get_shards_for_cluster(master_client, self.Parameters.cluster)
        finally:
            # Release the connection to the master host; it is not needed
            # once the shard list has been fetched.
            master_client.disconnect()

    def _build_partitions_query(self, date_format):
        """Build the system.parts query selecting outdated partitions of all tables.

        :param date_format: one entry of DATES_FORMAT
        :return: QUERY_PARTITIONS with one PARTITIONS_ONE_TABLE clause per
                 configured table, joined by OR
        :raises ValueError: if a table key is not of the form ``db.table`` or
                            its limit is not convertible to int
        """
        # NOTE: the SQL is assembled with str.format; table names and limits
        # come straight from the task parameters and are not escaped.
        now = datetime.datetime.now()
        where_clauses = []
        for db_table, period in self.Parameters.tables.items():
            db, table = db_table.split(".")
            limit = now - datetime.timedelta(days=int(period))
            where_clauses.append(
                PARTITIONS_ONE_TABLE.format(
                    database=db,
                    table=table,
                    limit=limit.strftime(date_format['format']),
                    regexp=date_format['regexp'],
                )
            )
        return QUERY_PARTITIONS.format(where=" OR ".join(where_clauses))

    def _drop_partitions_on_host(self, host, port, query_parts, drop_query):
        """List outdated partitions on one host and drop them one by one.

        :param query_parts: query returning (table, partition) rows to drop
        :param drop_query: template with ``{table}`` and ``{partition}`` fields
        :return: True if the listing and every drop succeeded, False on the
                 first failure (the failure is logged as a warning)
        """
        client = self._open_clickhouse_client(
            host, port, self.Parameters.cluster_user, self.cluster_password
        )
        try:
            try:
                partitions = client.execute(query_parts)
            except Exception as e:
                logging.warning("Failed get partitions from host %s port %s: %s", host, port, e)
                return False

            logging.info('Dropping partitions %s', partitions)
            for table, partition in partitions:
                try:
                    client.execute(drop_query.format(table=table, partition=partition))
                except Exception as e:
                    logging.warning("%s: Failed to drop partition %s on table %s: %s", host, partition, table, e)
                    # Stop on the first failure; the caller moves on to the
                    # next host of the shard, which lists partitions afresh.
                    return False
                logging.info("Successfully dropped partition %s on table %s", partition, table)
            return True
        finally:
            # Always close the connection, including on the failure paths,
            # so that no connection is left open per visited host.
            client.disconnect()

    def on_execute(self):
        """Drop outdated partitions on every shard of the cluster.

        :raises TaskFailure: if at least one shard could not be processed on
                             any of its hosts
        """
        # Each secret is a mapping; the password is stored under the user name.
        self.master_password = self.Parameters.master_password.data()[self.Parameters.master_user]
        self.cluster_password = self.Parameters.cluster_password.data()[self.Parameters.cluster_user]

        shards = self.get_shards()

        date_format = DATES_FORMAT[self.Parameters.date_format]
        query_parts = self._build_partitions_query(date_format)
        drop_query = date_format['drop_partition_query']

        success_all = True
        for shard, hosts in shards.items():
            logging.info("Processing shard %s with hosts %s", shard, hosts)
            # One successful host is enough for the shard; the remaining
            # hosts are only tried as fallbacks.
            success_shard = False
            for host, port in hosts:
                if self._drop_partitions_on_host(host, port, query_parts, drop_query):
                    success_shard = True
                    break

            if not success_shard:
                success_all = False
                logging.error("Failed to process shard %s", shard)

        if not success_all:
            raise TaskFailure("Failed to process some shards, see logs")
