# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import yt.wrapper as yt

from travel.cpa.lib.lib_logging import get_logger


LOG = get_logger(__name__)


def backup_table(yt_client, table_path, backup_dir):
    """Copy ``table_path`` into ``backup_dir`` and rewrite the copy for scanning.

    The copy is named after the source node's ``@creation_time`` attribute
    with everything from the first ``.`` onwards dropped.  After copying,
    the ``nightly_compression_settings`` attribute is (re)written on
    ``backup_dir``, the copy gets ``optimize_for = scan`` and is merged onto
    itself with ``force_transform`` enabled.

    :param yt_client: YT client used for every Cypress/operation call.
    :param table_path: path of the table to back up.
    :param backup_dir: directory node that receives the copy.
    """
    created_at = yt_client.get(yt.ypath_join(table_path, '@creation_time'))
    # Keep only the part of the timestamp before the first dot.
    snapshot_name = created_at.split('.')[0]
    destination = yt.ypath_join(backup_dir, snapshot_name)
    yt_client.copy(table_path, destination, recursive=True)

    settings_attribute = yt.ypath_join(backup_dir, '/@nightly_compression_settings')
    yt_client.set(
        settings_attribute,
        {
            'enabled': True,
            'compression_codec': 'zstd_9',
            'erasure_codec': 'lrc_12_2_2',
            'min_table_age': 0,
            # The pool is named after the user the client runs as.
            'pool': yt_client.get_user_name()
        },
    )

    yt_client.set(yt.ypath_join(destination, '/@optimize_for'), 'scan')
    # NOTE(review): the in-place merge presumably rewrites existing chunks so
    # the new ``optimize_for`` takes effect - confirm against YT docs.
    merge_spec = {'force_transform': True}
    yt_client.run_merge(destination, destination, spec=merge_spec)


def create_temp_table(yt_client, table, schema):
    """Create an empty table at ``table`` with ``schema``, replacing any old one.

    :param yt_client: YT client used to check, remove and create the node.
    :param table: path of the temporary table.
    :param schema: value stored in the new table's ``schema`` attribute.
    """
    stale_table_present = yt_client.exists(table)
    if stale_table_present:
        # A leftover node is dropped first so that the table is created fresh.
        LOG.info('Temporary table %s already exists, dropping it', table)
        yt_client.remove(table)

    LOG.info('Creating temporary table: %s', table)
    table_attributes = {'schema': schema}
    yt_client.create('table', table, attributes=table_attributes)


def get_table_schema(common_schema, sort_by=None, hide_fields=None):
    """Build a column list from a ``name -> type`` mapping.

    Sort columns come first, in the order given by ``sort_by``, each marked
    with ``'sort_order': 'ascending'``.  The remaining visible columns follow
    in the iteration order of ``common_schema``.

    ``common_schema`` is never modified, so one shared mapping can safely be
    used to build several schemas with different ``hide_fields``.

    :param common_schema: mapping of column name to column type.
    :param sort_by: optional iterable of column names that form the sort key.
    :param hide_fields: optional iterable of column names to leave out.
    :return: list of column dicts (``name``, ``type`` and, for sort columns,
        ``sort_order``).
    :raises KeyError: if a ``sort_by`` column is hidden or is absent from
        ``common_schema``.
    """
    sort_columns = [] if sort_by is None else list(sort_by)
    hidden = set() if hide_fields is None else set(hide_fields)
    sort_keys = set(sort_columns)

    sorted_fields = []
    for name in sort_columns:
        if name in hidden:
            # A hidden column cannot be part of the sort key; fail the same
            # way a lookup of a missing column does.
            raise KeyError(name)
        sorted_fields.append({'name': name, 'type': common_schema[name], 'sort_order': 'ascending'})

    unsorted_fields = [
        {'name': name, 'type': column_type}
        for name, column_type in common_schema.items()
        if name not in hidden and name not in sort_keys
    ]
    return sorted_fields + unsorted_fields


class ProcessedTables(object):
    """Track which source tables have not been processed yet.

    A table is identified by a key ``(LockableYtPath, content_revision)``.
    On construction every candidate source table is snapshot-locked; tables
    that could not be locked are skipped.  The keys of the locked tables are
    compared with the keys stored in ``processed_table`` to find new ones.

    Attributes set by ``__init__``:
      * ``existing_table_keys`` - keys of all source tables that were locked;
      * ``new_table_keys`` - existing keys absent from ``processed_table``;
      * ``new_tables`` - node ids (``#<id>``) of the tables in
        ``new_table_keys``.
    """

    def __init__(self, yt_client, processed_table, src_dir=None, src_tables=None):
        """Lock the source tables and work out which of them are new.

        :param yt_client: YT client used for all Cypress and table calls.
        :param processed_table: path of the table holding ``table`` /
            ``content_revision`` rows of already processed tables.
        :param src_dir: optional directory searched for source tables.
        :param src_tables: optional iterable of explicit source table paths.
        :raises ValueError: if neither ``src_dir`` nor ``src_tables`` is set.
        """
        if src_dir is None and src_tables is None:
            raise ValueError('Neither src_dir nor src_tables are set')
        self.yt_client = yt_client
        self.processed_table = processed_table

        # Work on a private copy: the caller's collection must not be
        # mutated, and it may be any iterable (tuple, set, ...), not only a
        # list.
        existing_tables = [] if src_tables is None else list(src_tables)
        if src_dir is not None:
            existing_tables.extend(self.get_existing_tables(src_dir))

        existing_table_keys = set()
        for table in existing_tables:
            path = LockableYtPath(yt_client, table, try_to_lock=True)
            if not path.is_locked():
                # Locking failed (already logged by LockableYtPath); skip.
                continue
            existing_table_keys.add(self.get_table_key(path))

        self.existing_table_keys = existing_table_keys
        self.new_table_keys = self.existing_table_keys - self.get_processed_table_keys()
        # Every new key comes from existing_table_keys, so node_id is set.
        self.new_tables = {key[0].node_id for key in self.new_table_keys}

    def get_content_revision(self, path):
        """Return the ``@content_revision`` attribute of the node at ``path``."""
        return self.yt_client.get(yt.ypath_join(path, '@content_revision'))

    def get_table_key(self, path):
        """Return ``(path, content_revision)`` for a locked ``LockableYtPath``.

        The revision is read through ``path.node_id`` rather than through the
        textual path.
        """
        return path, self.get_content_revision(path.node_id)

    def get_existing_tables(self, src_dir):
        """Return the set of table nodes found by searching ``src_dir``."""
        return set(self.yt_client.search(src_dir, node_type=['table']))

    def get_processed_table_keys(self):
        """Return the keys stored in ``processed_table``.

        An empty set is returned when ``processed_table`` does not exist.
        """
        res = set()
        if not self.yt_client.exists(self.processed_table):
            return res
        for row in self.yt_client.read_table(self.processed_table):
            path = row['table']
            content_revision = row.get('content_revision')
            # this check is for processed tables migration only
            # delete after migration
            if content_revision is None and self.yt_client.exists(path):
                content_revision = self.get_content_revision(path)
            res.add((LockableYtPath(self.yt_client, path), content_revision))
        return res

    def write_processed_tables(self):
        """Write the current table keys to ``processed_table``.

        Rows have the columns ``table`` and ``content_revision`` and are
        written sorted by ``(path, content_revision)``.
        """
        # new_table_keys is derived from existing_table_keys in __init__.
        processed_tables = self.existing_table_keys | self.new_table_keys
        columns = 'table', 'content_revision'
        processed_tables_path = [(key[0].node_path, key[1]) for key in processed_tables]
        rows = (dict(zip(columns, item)) for item in sorted(processed_tables_path))
        self.yt_client.write_table(self.processed_table, rows)


class LockableYtPath(object):
    """A Cypress path that can be pinned with a snapshot lock.

    Equality and hashing use ``node_path`` only, so a locked and an unlocked
    instance for the same path compare equal.

    Attributes:
      * ``node_path`` - the path given to the constructor;
      * ``node_id`` - ``'#<id>'`` of the node once a lock has been taken,
        ``None`` otherwise.
    """

    def __init__(self, yt_client, path, try_to_lock=False):
        """
        :param yt_client: YT client used to read the node id and to lock it.
        :param path: Cypress path of the node.
        :param try_to_lock: when true, attempt to lock the node immediately.
        """
        self.yt_client = yt_client
        self.node_path = path
        self.node_id = None
        if try_to_lock:
            self.try_lock_node()

    def __eq__(self, other):
        if not isinstance(other, LockableYtPath):
            # Let Python fall back to its default handling instead of failing
            # with AttributeError on objects without ``node_path``.
            return NotImplemented
        return self.node_path == other.node_path

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``; keep them in sync.
        equal = self.__eq__(other)
        if equal is NotImplemented:
            return equal
        return not equal

    def __hash__(self):
        return hash(self.node_path)

    def get_node_id(self, path):
        """Return the ``@id`` attribute of ``path`` prefixed with ``#``."""
        node_id = self.yt_client.get(yt.ypath_join(path, '@id'))
        return '#' + node_id

    def try_lock_node(self):
        """Try to take a snapshot lock on the node; never raises.

        On success ``node_id`` is set.  On any failure a warning is logged
        and ``node_id`` stays ``None`` (deliberate best effort).
        """
        try:
            node_id = self.get_node_id(self.node_path)
            self.yt_client.lock(node_id, mode='snapshot', waitable=True)
            self.node_id = node_id
        except Exception as e:
            LOG.warning('Failed to lock node "%s": %s', self.node_path, e)

    def is_locked(self):
        """Return ``True`` if a lock has been taken through this object."""
        return self.node_id is not None
