from yt.wrapper import ypath_join, TablePath
from datacloud.config.yt import PRODUCTION_ROOT, CRYPTA_ROOT, RELIABLE_TMP_FOLDER
from datacloud.dev_utils.time.patterns import FMT_DATE
from datacloud.dev_utils.logging.logger import get_basic_logger
from datetime import datetime, timedelta
from datacloud.dev_utils.yt import features, yt_utils


# Logger obtained for this module's name; used for missing-table warnings below.
logger = get_basic_logger(__name__)


# Location of the DSSM model binary; serves as the default `model_url` of
# DSSMConfig. The value is pinned to an absolute path; the
# PRODUCTION_ROOT-relative variant is kept below for reference:
# YT_PATH2MODEL = 'yt:' + ypath_join(PRODUCTION_ROOT, 'datacloud/bins/model.dssm')
YT_PATH2MODEL = 'yt://hahn/home/x-products/production/datacloud/bins/model.dssm'
# Defaults for DSSMConfig.days_to_take when the caller passes None:
# the first applies when is_retro is true, the second otherwise.
DEFAULT_RETRO_DAYS_TO_TAKE = 175
DEFAULT_PROD_DAYS_TO_TAKE = 7


def get_prod_log_tables(yt_client, config):
    """Collect the existing daily log tables for a production run.

    Walks back `config.days_to_take` days starting from `config.date_str`
    (inclusive) and, for each day and each folder in `config.log_folders`,
    builds the path `<config.grep_root>/<folder>/<date>`. Paths for which
    `yt_utils.check_table_exists` returns a false value are logged as a
    warning and left out of the result.

    Args:
        yt_client: client object forwarded to `yt_utils.check_table_exists`.
        config: object exposing `date_str`, `days_to_take`, `log_folders`
            and `grep_root`.

    Returns:
        List of table paths (strings), newest day first.
    """
    last_day = datetime.strptime(config.date_str, FMT_DATE)
    existing = []
    for days_back in range(config.days_to_take):
        day_name = (last_day - timedelta(days=days_back)).strftime(FMT_DATE)
        for log_folder in config.log_folders:
            path = ypath_join(config.grep_root, log_folder, day_name)
            if not yt_utils.check_table_exists(path, yt_client):
                # A missing day is tolerated: report it and keep going.
                logger.warning('WARNING: table {} does not exists'.format(path))
                continue
            existing.append(path)
    return existing


def get_retro_log_tables(yt_client, config):
    """Collect every log table of a retro run, restricted to the needed columns.

    Lists (with absolute paths) the contents of
    `<config.grep_root>/<folder>` for each folder in `config.log_folders`
    and wraps every listed node in a `TablePath` with a fixed column
    selection.

    Args:
        yt_client: client whose `list(path, absolute=True)` is called.
        config: object exposing `log_folders` and `grep_root`.

    Returns:
        List of `TablePath` objects, in folder order and then listing order.
    """
    log_columns = ('external_id', 'timestamp', 'title', 'url', 'yuid')
    return [
        # A fresh list per table, so no two TablePath objects share one.
        TablePath(node, columns=list(log_columns))
        for folder_name in config.log_folders
        for node in yt_client.list(
            ypath_join(config.grep_root, folder_name), absolute=True)
    ]


class DSSMConfig(object):
    """Run options and YT table descriptions for one DSSM feature build.

    Every path and `TablePath` attribute is derived in the constructor from
    `date_str`, `base_root` and the optional overrides.
    """

    def __init__(self, date_str, base_root=PRODUCTION_ROOT,
                 days_to_take=None, is_retro=False, model_url=YT_PATH2MODEL,
                 retro_tag='', garbage_collect_on=True, use_cloud_nodes=False,
                 yuid2cid_path='', weekly_dir='', ready_table_path=''):
        """
        Args:
            date_str: date of the run; used in the tag, the tmp folder name
                and the name of the weekly result table.
            base_root: root under which `datacloud` and `datacloud/grep` live.
            days_to_take: number of days of logs to use; when None it falls
                back to DEFAULT_RETRO_DAYS_TO_TAKE (retro) or
                DEFAULT_PROD_DAYS_TO_TAKE (production).
            is_retro: selects retro mode (external_id key, retro log lister)
                instead of production mode (cid key, daily log lister).
            model_url: location of the DSSM model.
            retro_tag: suffix appended to `date_str` in the tag and in the
                tmp folder name; use it to keep your tmp folder apart from
                other runs.
            garbage_collect_on: stored as-is on the instance.
            use_cloud_nodes: passed to `features.cloud_nodes_spec` by the
                `cloud_nodes_spec` property.
            yuid2cid_path: override for the yuid mapping table path.
            weekly_dir: override for the folder of weekly result tables.
            ready_table_path: override for the "ready" features table path.
        """
        run_suffix = date_str + retro_tag

        # --- plain run options -------------------------------------------
        self.is_retro = is_retro
        self.date_str = date_str
        self.tag = 'DSSM-' + run_suffix
        self.model_url = model_url
        self.garbage_collect_on = garbage_collect_on
        self.use_cloud_nodes = use_cloud_nodes
        if days_to_take is not None:
            self.days_to_take = days_to_take
        elif is_retro:
            self.days_to_take = DEFAULT_RETRO_DAYS_TO_TAKE
        else:
            self.days_to_take = DEFAULT_PROD_DAYS_TO_TAKE

        # --- roots and input logs ----------------------------------------
        self.base_root = base_root
        self.root = ypath_join(base_root, 'datacloud')
        self.grep_root = ypath_join(base_root, 'datacloud/grep')
        self.log_folders = ('watch_log_tskv', 'spy_log')

        # --- everything that differs between retro and production --------
        if is_retro:
            self.get_grep_tables = get_retro_log_tables
            self.ext_id_key = 'external_id'
            mapping_columns = ['yuid', self.ext_id_key, 'timestamp']
            if not yuid2cid_path:
                yuid2cid_path = ypath_join(base_root, 'input_yuid')
        else:
            self.get_grep_tables = get_prod_log_tables
            self.ext_id_key = 'cid'
            mapping_columns = ['yuid', self.ext_id_key]
            if not yuid2cid_path:
                yuid2cid_path = ypath_join(CRYPTA_ROOT, 'crypta_db_last/yuid_to_cid')
        self.yuid2cid_table = TablePath(yuid2cid_path, columns=mapping_columns)

        # --- output / scratch folders ------------------------------------
        self.weekly_dir = weekly_dir or ypath_join(self.root, '/aggregates/dssm/weekly')
        self.tmp_dir = ypath_join(RELIABLE_TMP_FOLDER, 'dssm-prod', run_suffix)
        self.ready_dir = ypath_join(self.root, '/aggregates/dssm/ready')
        if not ready_table_path:
            ready_table_path = ypath_join(self.ready_dir, 'features')

        # --- table descriptions ------------------------------------------
        # One attributes dict is handed to every schematized table below.
        compression_params = dict(
            compression_codec='brotli_3',
            optimize_for='scan',
        )

        def in_tmp(name):
            # Path of a scratch node inside this run's tmp folder.
            return ypath_join(self.tmp_dir, name)

        def compressed_table(path, *columns):
            # Build a TablePath from (column name, column type) pairs.
            schema = [{'name': name, 'type': kind} for name, kind in columns]
            return TablePath(path, schema=schema, attributes=compression_params)

        self.yuid2title_url4_table_all = compressed_table(
            in_tmp('yuid2title_url_all'),
            ('title', 'string'),
            ('url', 'string'),
            ('key', 'string'),
            ('hash', 'string'),
            ('timestamp', 'int64'),
        )
        self.yuid2title_url4_table = compressed_table(
            in_tmp('yuid2title-unique'),
            ('title', 'string'),
            ('url', 'string'),
            ('yuids', 'string'),
            ('hash', 'string'),
        )
        # The only scratch table kept as a plain path rather than a TablePath.
        self.vectors_table = in_tmp('vector')
        self.vectors_table_binary = compressed_table(
            in_tmp('bin'),
            ('hash', 'string'),
            ('vector', 'string'),
        )
        self.dot_table = compressed_table(
            in_tmp('dot'),
            ('hash', 'string'),
            ('vector', 'string'),
        )
        self.id2_dot_tmp = compressed_table(
            in_tmp('dot_tmp'),
            ('key', 'string'),
            ('features', 'string'),
        )

        # Final tables are keyed by the mode-specific id column.
        self.result_dssm_table = compressed_table(
            ypath_join(self.weekly_dir, date_str),
            (self.ext_id_key, 'string'),
            ('features', 'string'),
        )
        self.ready_table = compressed_table(
            ready_table_path,
            (self.ext_id_key, 'string'),
            ('features', 'string'),
        )

    @property
    def cloud_nodes_spec(self):
        """Return `features.cloud_nodes_spec` for this config's `use_cloud_nodes`."""
        return features.cloud_nodes_spec(self.use_cloud_nodes)
