import logging
import os
import time

import ydb

from crypta.lib.python import (
    templater,
    time_utils,
)


# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# YQL query template used by YtTableUploader.upload_data.
#
# The ``{{name}}`` placeholders are substituted by ``templater.render_template``
# before the query is executed. The rendered query:
#   * attaches the sampler UDF library from ``crypta_sampler_udf_url``;
#   * sets the yt.QueryCacheMode, yt.UserSlots and yt.DefaultMaxJobFails pragmas;
#   * selects every row of the ``raw`` view of ``yt_table`` for which
#     ``CryptaRestSampler::PassesIfEqual(sampling_field, denominator, rest)`` holds;
#   * feeds the selected rows to ``YDB::PushData`` together with the YDB endpoint,
#     database, table path and a token referenced as the secure parameter
#     ``token:<ydb_token_name>``.
#
# NOTE: this text is sent to YQL verbatim after substitution, so comments must
# stay outside the string literal.
UPLOAD_QUERY_TEMPLATE = """
PRAGMA File("libcrypta_sampler_udf.so", '{{crypta_sampler_udf_url}}');
PRAGMA Udf = "libcrypta_sampler_udf.so";

PRAGMA yt.QueryCacheMode = "disable";
PRAGMA yt.UserSlots = "{{yt_user_slots}}";
PRAGMA yt.DefaultMaxJobFails = "{{default_max_job_fails}}";

$ydb_endpoint = "{{ydb_endpoint}}";
$ydb_database = "{{ydb_database}}";
$ydb_table = "{{ydb_table}}";
$denominator = {{denominator}};
$rest = {{rest}};

PROCESS (
    SELECT
        *
    FROM `{{yt_table}}` VIEW raw
    WHERE CryptaRestSampler::PassesIfEqual({{sampling_field}}, $denominator, $rest)
) USING YDB::PushData(
    TableRows(),
    $ydb_endpoint,
    $ydb_database,
    $ydb_table,
    AsTuple("token", SecureParam("token:{{ydb_token_name}}"))
);
"""


class YtTableUploader(object):
    """Copies a sampled YT table into a freshly named table of a YDB directory.

    One ``upload`` call fills a temporary table (its name starts with
    ``TMP_TABLE_PREFIX``), copies it to the final name and then drops every
    other entry listed in the directory.
    """

    # Name prefix that marks a YDB table as temporary.
    TMP_TABLE_PREFIX = "."

    def __init__(self, yt_client, yql_executer, ydb_client, ydb_token_name, crypta_sampler_udf_url):
        """Remember the collaborators and settings used by the upload steps.

        :param yt_client: stored as-is; no method of this class reads it.
        :param yql_executer: callable invoked as
            ``yql_executer(query, transaction=..., syntax_version=1)``.
        :param ydb_client: object providing ``list_directory``, ``create_table``,
            ``copy_table``, ``drop_table``, ``get_full_path`` and the
            ``endpoint`` / ``database`` attributes.
        :param ydb_token_name: name substituted into the query's
            ``SecureParam("token:...")`` reference.
        :param crypta_sampler_udf_url: URL substituted into the query's
            ``PRAGMA File`` line for the sampler UDF library.
        """
        self.crypta_sampler_udf_url = crypta_sampler_udf_url
        self.ydb_token_name = ydb_token_name
        self.ydb_client = ydb_client
        self.yql_executer = yql_executer
        self.yt_client = yt_client

    def upload(self, yt_table, ydb_dir, uniform_partitions, yt_user_slots, ydb_table_schema, sampling_field, denominator, rest, tx_id, max_job_fails=5):
        """Run the whole upload of ``yt_table`` into a new table under ``ydb_dir``.

        The destination table is named ``str(time_utils.get_current_time())``;
        the temporary table carries the same name with ``TMP_TABLE_PREFIX``
        prepended.

        :param yt_table: source YT table path placed into the query.
        :param ydb_dir: YDB directory that receives the table.
        :param uniform_partitions: value given to
            ``PartitioningPolicy.with_uniform_partitions`` for the new table.
        :param yt_user_slots: value for the ``yt.UserSlots`` pragma.
        :param ydb_table_schema: object whose ``columns`` and ``primary_key``
            are passed to ``ydb_client.create_table``.
        :param sampling_field: first argument of the sampler UDF call.
        :param denominator: second argument of the sampler UDF call.
        :param rest: third argument of the sampler UDF call.
        :param tx_id: passed to ``yql_executer`` as ``transaction``.
        :param max_job_fails: value for the ``yt.DefaultMaxJobFails`` pragma.
        """
        final_name = str(time_utils.get_current_time())
        staging_name = self.get_tmp_table_name(final_name)

        final_path = os.path.join(ydb_dir, final_name)
        staging_path = os.path.join(ydb_dir, staging_name)

        logger.info("YDB temporary table: %s, destination table: %s", staging_path, final_path)

        # 1. Get rid of temporary tables already present in the directory.
        self.remove_tmp_ydb_tables(ydb_dir)
        # 2. Create and fill the temporary table.
        self.create_ydb_tmp_table(staging_path, ydb_table_schema, uniform_partitions)
        self.upload_data(
            staging_path,
            yt_table,
            yt_user_slots,
            sampling_field,
            denominator,
            rest,
            tx_id,
            max_job_fails,
        )
        # 3. Publish it under the final name, then drop everything else.
        self.move_ydb_table(staging_path, final_path)
        self.remove_old_tables(ydb_dir, final_name)

    @staticmethod
    def get_tmp_table_name(table_name):
        """Return ``table_name`` with ``TMP_TABLE_PREFIX`` prepended."""
        prefix = YtTableUploader.TMP_TABLE_PREFIX
        return prefix + table_name

    @staticmethod
    def is_table_temporary(table_name):
        """Tell whether ``table_name`` starts with ``TMP_TABLE_PREFIX``."""
        prefix = YtTableUploader.TMP_TABLE_PREFIX
        return table_name.startswith(prefix)

    def _drop_tables(self, directory, names):
        """Drop ``directory/<name>`` for each of ``names``, logging each path first."""
        for name in names:
            victim_path = os.path.join(directory, name)
            logger.info("Drop YDB table %s", victim_path)
            self.ydb_client.drop_table(victim_path)

    def remove_tmp_ydb_tables(self, dst_dir):
        """Drop every entry of ``dst_dir`` whose name marks it as temporary."""
        listing = self.ydb_client.list_directory(dst_dir)
        # The name list is built completely before the first drop is issued.
        tmp_names = [entry.name for entry in listing if self.is_table_temporary(entry.name)]
        self._drop_tables(dst_dir, tmp_names)

    def create_ydb_tmp_table(self, table_path, schema, uniform_partitions):
        """Create ``table_path`` from ``schema`` with a uniform-partitions profile.

        :param table_path: path of the table to create.
        :param schema: object exposing ``columns`` and ``primary_key``.
        :param uniform_partitions: value given to ``with_uniform_partitions``.
        """
        logger.info("Create YDB temporary table %s", table_path)

        base_profile = ydb.TableProfile()
        partitioning = ydb.PartitioningPolicy().with_uniform_partitions(uniform_partitions)
        table_profile = base_profile.with_partitioning_policy(partitioning)

        self.ydb_client.create_table(table_path, schema.columns, schema.primary_key, table_profile)

    def upload_data(self, ydb_table, yt_table, yt_user_slots, sampling_field, denominator, rest, tx_id, max_job_fails):
        """Render ``UPLOAD_QUERY_TEMPLATE``, log the query and execute it.

        The query is handed to ``yql_executer`` with ``transaction=tx_id`` and
        ``syntax_version=1``.
        """
        template_vars = {
            "ydb_endpoint": self.ydb_client.endpoint,
            "ydb_database": self.ydb_client.database,
            "ydb_table": self.ydb_client.get_full_path(ydb_table),
            "yt_table": yt_table,
            "ydb_token_name": self.ydb_token_name,
            "yt_user_slots": yt_user_slots,
            "crypta_sampler_udf_url": self.crypta_sampler_udf_url,
            "denominator": denominator,
            "rest": rest,
            "sampling_field": sampling_field,
            # NOTE(review): UPLOAD_QUERY_TEMPLATE has no {{salt}} placeholder;
            # this looks like a leftover -- confirm before removing.
            "salt": int(time.time()),
            "default_max_job_fails": max_job_fails,
        }
        query = templater.render_template(UPLOAD_QUERY_TEMPLATE, strict=True, vars=template_vars)

        logger.info(query)
        self.yql_executer(query, transaction=tx_id, syntax_version=1)

    def remove_old_tables(self, ydb_dir, new_ydb_table_name):
        """Drop every entry of ``ydb_dir`` except the one named ``new_ydb_table_name``."""
        listing = self.ydb_client.list_directory(ydb_dir)
        # The name list is built completely before the first drop is issued.
        stale_names = [entry.name for entry in listing if entry.name != new_ydb_table_name]
        self._drop_tables(ydb_dir, stale_names)

    def move_ydb_table(self, tmp_table_path, dst_table_path):
        """Copy the temporary table to ``dst_table_path`` and drop the temporary one.

        The "move" is a ``copy_table`` followed by a ``drop_table`` of the source.
        """
        client = self.ydb_client

        logger.info("Copy YDB temporary table %s to YDB table %s", tmp_table_path, dst_table_path)
        client.copy_table(tmp_table_path, dst_table_path)

        logger.info("Remove YDB temporary table %s", tmp_table_path)
        client.drop_table(tmp_table_path)
