import datetime
import logging

from sandbox import sdk2
from sandbox.common.errors import TaskFailure
from sandbox.projects.yabs.base_bin_task import BaseBinTask


# Statement template executed on shard hosts; `table` and `partition` are
# substituted with str.format() before the query is sent.
OPTIMIZE_QUERY = "OPTIMIZE TABLE {table} PARTITION '{partition}' FINAL"


# Database passed as `database=` to every ClickHouse client created in this module.
DEFAULT_DATABASE = "system"


def get_secure_by_port(port):
    """Decide whether a connection to ``port`` should be made in secure mode.

    Port 9440 maps to ``True`` and port 9000 maps to ``False``. Any other
    port is reported through ``logging.error`` and treated as non-secure.

    :param port: port number of the ClickHouse host
    :return: ``True`` if the secure flag should be set, ``False`` otherwise
    """
    secure = False
    if port == 9440:
        secure = True
    elif port != 9000:
        # Not one of the two known ports: complain, but fall back to plain mode.
        logging.error("Unknown port %d", port)
    return secure


class YabsClickhouseOptimizeDaily(BaseBinTask):
    """
    Optimize previous daily partitions in clickhouse

    For each of the last ``days`` days, counting back from yesterday, the task
    runs ``OPTIMIZE_QUERY`` for the configured table on every shard of the
    configured cluster. Within a shard, hosts are tried in order until the
    query succeeds on one of them. The task fails if, for any day, some shard
    had no host on which the query succeeded.
    """

    class Requirements(sdk2.Requirements):
        cores = 1
        ram = 4096
        disk_space = 4096

        class Caches(sdk2.Requirements.Caches):
            pass

    class Parameters(BaseBinTask.Parameters):

        with BaseBinTask.Parameters.version_and_task_resource() as version_and_task_resource:
            resource_attrs = sdk2.parameters.Dict(
                "Filter resource by", default={"name": "YabsClickhouseOptimizeDaily"}
            )

        with sdk2.parameters.Group("ClickHouse") as ch_params:
            # The master host is only queried for the cluster layout (see get_shards).
            master_host = sdk2.parameters.String("ClickHouse master host", required=True)
            master_port = sdk2.parameters.Integer("ClickHouse master port", required=True, default=9000)
            master_user = sdk2.parameters.String("ClickHouse username for master host", required=True)
            # The secret is read as a mapping and indexed by the user name (see on_execute).
            master_password = sdk2.parameters.YavSecret("Yav secret with master password", required=True)
            cluster = sdk2.parameters.String("ClickHouse cluster", required=True)
            cluster_user = sdk2.parameters.String("ClickHouse username for cluster", required=True)
            cluster_password = sdk2.parameters.YavSecret("Yav secret with cluster password", required=True)
            table = sdk2.parameters.String("Clickhouse table", required=True)
            days = sdk2.parameters.Integer("Amount of days to optimize from yesterday", default=3)

    def get_shards(self):
        """Fetch the shard layout of the configured cluster from the master host.

        Requires ``self.master_password`` to be set (``on_execute`` does that).

        :return: whatever ``get_shards_for_cluster`` returns; ``on_execute``
            iterates it as a mapping of shard to an iterable of
            ``(host, port)`` pairs.
        """
        from clickhouse_driver import Client as ClickHouseClient
        from yabs.stat.infra.clickhouse.lib import get_shards_for_cluster

        master_client = ClickHouseClient(
            host=self.Parameters.master_host,
            port=self.Parameters.master_port,
            user=self.Parameters.master_user,
            password=self.master_password,
            secure=get_secure_by_port(self.Parameters.master_port),
            database=DEFAULT_DATABASE,
            verify=False,
        )

        try:
            return get_shards_for_cluster(master_client, self.Parameters.cluster)
        finally:
            # Do not leave the connection to the master host open.
            master_client.disconnect()

    def _optimize_shard(self, hosts, query):
        """Run ``query`` on the first host of a shard that accepts it.

        :param hosts: iterable of ``(host, port)`` pairs belonging to one shard
        :param query: SQL statement to execute
        :return: ``True`` if the query succeeded on some host, ``False`` if it
            failed on all of them (or the shard has no hosts)
        """
        from clickhouse_driver import Client as ClickHouseClient

        for host, port in hosts:
            client = ClickHouseClient(
                host=host,
                port=port,
                user=self.Parameters.cluster_user,
                password=self.cluster_password,
                secure=get_secure_by_port(port),
                database=DEFAULT_DATABASE,
                verify=False,
            )

            try:
                client.execute(query)
            except Exception as e:
                # Best effort: any failure on this host just moves us on to the next one.
                logging.warning("Failed to optimize host %s port %s: %s", host, port, e)
            else:
                return True
            finally:
                client.disconnect()

        return False

    def on_execute(self):
        """Optimize the last ``days`` daily partitions on every shard.

        :raises TaskFailure: if, for at least one day, the query did not
            succeed on any host of some shard.
        """
        self.master_password = self.Parameters.master_password.data()[self.Parameters.master_user]
        self.cluster_password = self.Parameters.cluster_password.data()[self.Parameters.cluster_user]

        shards = self.get_shards()

        # Accumulated across ALL days: a failure on any day must fail the task.
        # Initialised before the loop so it is also defined when days == 0.
        success_all = True

        # Naive local time of the machine running the task.
        now = datetime.datetime.now()
        for day in range(self.Parameters.days):
            # day == 0 corresponds to yesterday.
            date = now - datetime.timedelta(days=1 + day)
            partition = date.strftime("%Y-%m-%d")

            query_optimize = OPTIMIZE_QUERY.format(table=self.Parameters.table, partition=partition)
            logging.debug('Query to perform: %s', query_optimize)

            # items() rather than iteritems(): works on both Python 2 and Python 3.
            for shard, hosts in shards.items():
                logging.info("Processing shard %s with hosts %s", shard, hosts)
                if not self._optimize_shard(hosts, query_optimize):
                    success_all = False
                    logging.error("Failed to process shard %s", shard)

        if not success_all:
            raise TaskFailure("Failed to process some shards, see logs")
