import os
import ydb
from concurrent.futures import TimeoutError
from datacloud.dev_utils.logging.logger import get_basic_logger

# Module-level logger, named after this module via ``__name__``.
logger = get_basic_logger(__name__)


# Public API of this module: the names exported by ``from <module> import *``.
# ``is_table_exists`` is listed alongside its sibling ``is_directory_exists``.
__all__ = [
    'YdbPathConfig',
    'get_ydb_driver',
    'ensure_path_exists',
    'is_directory_exists',
    'is_table_exists',
    'create_folder',
    'describe_table',
]


def get_ydb_driver(endpoint, database, auth_token, timeout=5):
    """Create a YDB driver and block until it is ready.

    :param endpoint: YDB endpoint, passed to ``ydb.ConnectionParams``.
    :param database: database path, passed to ``ydb.ConnectionParams``.
    :param auth_token: auth token, passed to ``ydb.ConnectionParams``.
    :param timeout: seconds to wait for the driver to become ready
        (defaults to 5, the previously hard-coded value).
    :return: the ready ``ydb.Driver``.
    :raises RuntimeError: if the driver is not ready within ``timeout``;
        the original ``TimeoutError`` is chained as the cause.
    """
    connection_params = ydb.ConnectionParams(endpoint, database=database, auth_token=auth_token)
    driver = ydb.Driver(connection_params)
    try:
        driver.wait(timeout=timeout)
    except TimeoutError as err:
        # Stop the driver that failed to connect instead of abandoning it;
        # presumably it holds connections/background work that would otherwise leak.
        driver.stop()
        raise RuntimeError('Connect to YDB failed') from err
    return driver


def is_directory_exists(driver, path):
    """Tell whether ``path`` exists in the YDB scheme and is a directory.

    A ``ydb.SchemeError`` raised while describing the path is treated as
    "no such directory" and yields ``False``.
    """
    try:
        description = driver.scheme_client.describe_path(path)
        found = description.is_directory()
    except ydb.SchemeError:
        found = False
    return found


def is_table_exists(driver, path):
    """Tell whether ``path`` exists in the YDB scheme and is a table.

    A ``ydb.SchemeError`` raised while describing the path is treated as
    "no such table" and yields ``False``.
    """
    try:
        description = driver.scheme_client.describe_path(path)
        found = description.is_table()
    except ydb.SchemeError:
        found = False
    return found


# TODO: Remove after test update (kept only as an alias of create_folder).
def ensure_path_exists(driver, database, path):
    """Create directory ``path`` under ``database``, including missing parents.

    This used to be a line-for-line copy of :func:`create_folder`; it now
    delegates there so the two implementations cannot diverge. As a
    consequence, the call is also logged by ``create_folder``.

    :param driver: YDB driver whose ``scheme_client`` is used.
    :param database: database root path that ``path`` is joined onto.
    :param path: directory path to create; trailing slashes are ignored.
    """
    create_folder(driver, database, path)


def create_folder(driver, database_path, path):
    """Make sure directory ``path`` exists under ``database_path``.

    Walks from the deepest directory upward and stops at the first one that
    already exists. Every missing directory found on the way is then created,
    parents before children.

    :param driver: YDB driver whose ``scheme_client`` is used.
    :param database_path: database root that ``path`` is joined onto with
        ``os.path.join`` (so an absolute ``path`` replaces it).
    :param path: directory path to create; trailing slashes are ignored.
    """
    logger.info('Create folder {} {}'.format(database_path, path))
    pending = []
    relative = path.rstrip('/')
    while relative:
        candidate = os.path.join(database_path, relative)
        if is_directory_exists(driver, candidate):
            break
        pending.append(candidate)
        relative = os.path.dirname(relative).rstrip("/")
    # Collected deepest-first; walking in reverse creates each parent before its child.
    for candidate in reversed(pending):
        driver.scheme_client.make_directory(candidate)


def describe_table(session, path, name):
    """Print the columns of table ``name`` located in directory ``path``.

    :param session: YDB session used to call ``describe_table``.
    :param path: directory that contains the table.
    :param name: table name, joined onto ``path`` with ``os.path.join``.
    :return: the object returned by ``session.describe_table``, so callers
        can reuse the description instead of fetching it again.
    """
    result = session.describe_table(os.path.join(path, name))
    # Label the output with the table actually being described.
    print('\n> describe table: {}'.format(name))
    for column in result.columns:
        # Prints the item type of an Optional column type.
        # NOTE(review): confirm how non-Optional (NOT NULL) columns render here.
        print('column, name:', column.name, ',', str(column.type.optional_type.item).strip())
    return result


class YdbPathConfig:
    """Well-known YDB locations derived from a single root directory.

    Every attribute is a plain string built with ``os.path.join``; nothing is
    created or checked in YDB by this class.
    """

    def __init__(self, root_dir):
        join = os.path.join
        config = join(root_dir, 'config')

        self.root_dir = root_dir
        self.config_dir = config

        # Partner/score paths: tables under <root>/config, plus the scores root.
        self.partner_scores_table_path = join(config, 'partner_scores')
        self.score_path_table_path = join(config, 'score_path')
        self.partner_tokens_table_path = join(config, 'partner_tokens')
        self.scores_root_path = join(root_dir, 'scores')

        # Geo paths (tables under <root>/config plus the geo root), then rejected requests.
        self.geo_linear_parts_table_path = join(config, 'partner_linears')
        self.geo_path_table_path = join(config, 'geo_path')
        self.geo_root_path = join(root_dir, 'geo')
        self.rejected_requests_path = join(root_dir, 'rejected/rejected_requests')
