import os
from concurrent.futures import TimeoutError
import yt.wrapper as yt_wrapper
import ydb

from datacloud.dev_utils.yql import yql_helpers
from datacloud.dev_utils.id_value.id_value_lib import encode_hexhash_as_uint64, encode_as_uint64
from datacloud.dev_utils.logging.logger import get_basic_logger
from datacloud.dev_utils.transfer.yt_to_ydb import table_description, transfer_table
from datacloud.dev_utils.transfer.yt_to_ydb.ydb_constants import (
    YDB_ENDPOINT, YDB_DATABASE, YDB_ROOT
)
from datacloud.dev_utils.transfer.yt_to_ydb.score_path_table import ScorePathTableDescription
from datacloud.dev_utils.transfer.yt_to_ydb.score_table import ScoreTableDescription
from datacloud.dev_utils.ydb.lib.core import utils as ydb_utils


# Module-level logger, named after this module.
logger = get_basic_logger(__name__)

# YDB folder under which score folders are addressed as
# <YDB_SCORES_ROOT>/<partner_id>/<score_name>/<score folder name>.
YDB_SCORES_ROOT = os.path.join(YDB_ROOT, 'scores')
# YDB path handed to ScorePathTableDescription; that table is queried and
# updated by score name to find/replace the score folder currently in use.
YDB_SCORE_PATH = os.path.join(YDB_ROOT, 'config', 'score_path')


class BaseTransferException(Exception):
    """Common base class for errors raised by the YT -> YDB score transfer.

    Catching this class catches every transfer-specific exception defined
    in this module.
    """


class TargetTableNotFoundException(BaseTransferException):
    """Raised when the target YDB table is missing after a transfer.

    Derives from ``BaseTransferException`` (and therefore still from
    ``Exception``), so existing ``except Exception`` handlers keep working
    while callers gain the ability to catch all transfer errors at once.
    """


@yt_wrapper.with_context
class HashedIdReducer:
    """Callable that emits ``{'hashed_id', 'score'}`` rows for one key group.

    Rows seen while ``context.table_index == 0`` supply the score; rows from
    any other table index supply identifiers. Only identifiers whose
    ``id_type`` is ``phone_md5`` or ``email_md5`` and whose ``id_value`` is
    truthy are emitted. When ``hash_cid`` is enabled, the ``cid`` taken from
    the key is emitted as one more hashed id with the same score.

    Presumably used as a YT reducer over a score table joined with an id
    table by ``cid`` -- confirm against the call site.
    """

    def __init__(self, hash_cid=False):
        # Whether to additionally emit a row for the key's own ``cid``.
        self.hash_cid = hash_cid

    def __call__(self, key, recs, context):
        current_score = None
        for row in recs:
            if context.table_index == 0:
                # Score-table row: remember the (latest) score for this key.
                current_score = row['score']
                continue
            if current_score is None:
                # Identifier row arrived but no score was seen for this key:
                # nothing can be emitted, stop reading the group.
                break
            if row['id_type'] not in ('phone_md5', 'email_md5') or not row['id_value']:
                continue
            yield dict(
                hashed_id=encode_hexhash_as_uint64(row['id_value']),
                score=current_score)

        if not self.hash_cid or current_score is None or not key['cid']:
            return
        yield dict(hashed_id=encode_as_uint64(key['cid']), score=current_score)


def _get_ydb_driver(token=None):
    token = token or os.environ['YDB_TOKEN']
    assert token, 'Please provide YDB token'
    connection_params = ydb.ConnectionParams(
        YDB_ENDPOINT, database=YDB_DATABASE, auth_token=token)
    try:
        driver = ydb.Driver(connection_params)
        driver.wait(timeout=5)
        return driver
    except TimeoutError:
        raise RuntimeError('Connect failed to YDB')


def remove_prev_tables(yql_client, cur_score_fname, prev_score_fname, partner_id, score_name):
    """Remove outdated score folders of one score from YDB.

    Every child of ``<YDB_SCORES_ROOT>/<partner_id>/<score_name>`` whose name
    is neither ``cur_score_fname`` nor ``prev_score_fname`` is deleted: its
    ``score`` table is dropped and then the folder itself is removed.

    Before deleting anything, the score path stored in the score-path table
    for ``score_name`` must equal
    ``<partner_id>/<score_name>/<cur_score_fname>``.

    :param yql_client: client passed to ``ScorePathTableDescription.select``.
    :param cur_score_fname: name of the folder holding the current score.
    :param prev_score_fname: name of the folder holding the previous score.
    :param partner_id: partner folder name under ``YDB_SCORES_ROOT``.
    :param score_name: score folder name under the partner folder.
    :raises AssertionError: if the stored score path does not match the
        expected current folder.
    """
    driver = _get_ydb_driver()
    s_client = driver.scheme_client
    session = driver.table_client.session().create()

    score_folder = os.path.join(YDB_SCORES_ROOT, partner_id, score_name)
    names = [ent.name for ent in s_client.list_directory(score_folder).children]

    ydb_connection_params = table_description.YdbConnectionParams(
        endpoint=YDB_ENDPOINT,
        database=YDB_DATABASE
    )
    score_path_table_description = ScorePathTableDescription(YDB_SCORE_PATH, ydb_connection_params)
    cur_score_folder_in_prod = score_path_table_description.select(yql_client, score_name)[0].score_path
    expected_score_folder = os.path.join(partner_id, score_name, cur_score_fname)
    if cur_score_folder_in_prod != expected_score_folder:
        # This guard protects a destructive operation, so it must not be an
        # ``assert`` (asserts are removed under ``python -O``). The exception
        # type is kept as AssertionError for backward compatibility.
        raise AssertionError(
            'Refusing to remove tables: score path in use is {!r}, expected {!r}'.format(
                cur_score_folder_in_prod, expected_score_folder))

    folders_to_keep = (prev_score_fname, cur_score_fname)
    for name in names:
        if name in folders_to_keep:
            continue
        folder2delete = os.path.join(score_folder, name)
        logger.info('Removing %s...', folder2delete)
        # The table has to go first; the folder is removed once it is empty.
        session.drop_table(os.path.join(folder2delete, 'score'))
        s_client.remove_directory(folder2delete)
        logger.info('Removed %s!', folder2delete)


def transfer_score_to_ydb(yt_client, model_by_hashed_id, partner_id, score_name, date_str):
    """Transfer a score table from YT to YDB and switch the score path to it.

    Steps, in order:

    1. validate that YQL and YDB tokens are available;
    2. create the target table
       ``<YDB_SCORES_ROOT>/<partner_id>/<score_name>/date-<date_str>/score``
       and transfer the source table into it;
    3. verify the target table exists in YDB;
    4. replace the stored score path for ``score_name`` with the new folder;
    5. remove score folders other than the new and the previous one.

    :param yt_client: YT client used to build the YQL client.
    :param model_by_hashed_id: source YT table; must provide ``to_yson_type()``.
    :param partner_id: partner folder name under ``YDB_SCORES_ROOT``.
    :param score_name: score folder name under the partner folder.
    :param date_str: date suffix of the new score folder (``date-<date_str>``).
    :raises AssertionError: if the YQL or the YDB token is missing or empty.
    :raises TargetTableNotFoundException: if the target table does not exist
        after the transfer.
    """
    yql_token = os.environ.get('YQL_TOKEN') or yt_wrapper.config['token']
    if not yql_token:
        # Explicit raise instead of ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for backward compatibility.
        raise AssertionError('Please provide YQL token')

    # The YDB token is only used after the (very long) transfer, so check it
    # up front. ``.get`` makes an unset variable produce the message below
    # instead of a bare KeyError.
    ydb_token = os.environ.get('YDB_TOKEN')
    if not ydb_token:
        raise AssertionError('Please provide YDB token')

    yql_client = yql_helpers.create_yql_client(yt_client, token=yql_token)

    ydb_connection_params = table_description.YdbConnectionParams(
        endpoint=YDB_ENDPOINT,
        database=YDB_DATABASE)

    # NOTE(review): ``str.capitalize()`` lower-cases every character after the
    # first one; if the YT path can contain upper-case letters this changes
    # the path -- confirm this call is intended.
    source_table = table_description.YtTableDescription(
        table_path=model_by_hashed_id.to_yson_type().capitalize(),
        yt_cluster=yt_wrapper.config['proxy']['url'])

    cur_score_fname = 'date-' + date_str
    cur_score_folder_rel = os.path.join(partner_id, score_name, cur_score_fname)
    cur_score_table = os.path.join(YDB_SCORES_ROOT, cur_score_folder_rel, 'score')
    logger.info('Transfering score to %s', cur_score_table)
    target_table = ScoreTableDescription(
        table_path=cur_score_table, ydb_connection_params=ydb_connection_params)

    transfer_table.create_table(yql_client, target_table)
    transfer_table.transfer(yql_client, source_table, target_table, yql_token)

    # Do not switch the score path unless the new table is really there.
    driver = _get_ydb_driver(ydb_token)
    if not ydb_utils.is_table_exists(driver, cur_score_table):
        raise TargetTableNotFoundException(
            'Prevent score switch, table not found: {}'.format(cur_score_table))

    score_path_table_description = ScorePathTableDescription(
        YDB_SCORE_PATH, ydb_connection_params)
    # NOTE(review): the return value is treated as the previously stored
    # relative folder (a '/'-separated string); confirm what ``replace_path``
    # returns when no path was stored before.
    prev_score_folder_rel = score_path_table_description.replace_path(
        yql_client, score_name, cur_score_folder_rel)
    remove_prev_tables(
        yql_client=yql_client,
        cur_score_fname=cur_score_fname,
        prev_score_fname=prev_score_folder_rel.rsplit('/')[-1],
        partner_id=partner_id,
        score_name=score_name)
