# -*- coding: utf-8 -*-

import logging
import contextlib
import hashlib
from dataclasses import dataclass
from typing import Dict, Any, List

from travel.hotels.lib.python3.yt import ytlib
from yt.wrapper import YtClient

import yt

from travel.hotels.tools.dataset_curator.data import DatasetVersion
from travel.hotels.tools.dataset_curator.validation_results import ValidationResult

# Module-level logger named after this module.
LOG = logging.getLogger(__name__)

# Value passed as ``wait_for`` when taking waitable YT locks: 5 minutes, expressed in milliseconds.
DEFAULT_LOCK_WAIT_TIME_MS = 300_000


@dataclass(frozen=True)
class StorageKey:
    """Immutable, hashable identifier of one validation of one dataset version on one cluster.

    Instances are used as dictionary keys by the storages in this module.
    """

    cluster: str
    dataset_id: str
    dataset_version_id: str
    validation_id: str


class BaseValidationInfoStorage:
    """Interface of a key-value storage for per-validation auxiliary data.

    Every method raises NotImplementedError here; subclasses provide the behaviour.
    """

    def load(self, storage_key: StorageKey) -> Any:
        """Return the data stored under ``storage_key``."""
        raise NotImplementedError

    def save(self, storage_key: StorageKey, data: Any) -> None:
        """Store ``data`` under ``storage_key``."""
        raise NotImplementedError

    def sync(self) -> None:
        """Synchronize the storage with its backing store."""
        raise NotImplementedError


class ValidationInfoInMemoryStorage(BaseValidationInfoStorage):
    """Dict-backed storage that keeps everything in process memory.

    Entries are kept as row-like dicts with the payload under the ``'data'`` key,
    the same shape ``ValidationInfoYtStorage`` uses for its rows, so that ``load``
    returns exactly what was passed to ``save``.
    """

    def __init__(self):
        # Maps a storage key to its row dict; the payload lives under row['data'].
        self.data: Dict[StorageKey, Any] = {}

    def load(self, storage_key: StorageKey) -> Any:
        """Return the payload saved under ``storage_key``, or None when nothing is stored."""
        res = self.data.get(storage_key)
        return res['data'] if res is not None else None

    def save(self, storage_key: StorageKey, data: Any) -> None:
        """Store ``data`` under ``storage_key``, replacing any previous payload."""
        # Wrap the payload in a row: load() reads row['data'], so storing the bare
        # payload would break the save/load round-trip.
        self.data[storage_key] = {
            'cluster': storage_key.cluster,
            'dataset_id': storage_key.dataset_id,
            'dataset_version_id': storage_key.dataset_version_id,
            'validation_id': storage_key.validation_id,
            'data': data,
        }

    def sync(self) -> None:
        """Do nothing: there is no backing store to synchronize with."""
        pass


class ValidationInfoYtStorage(BaseValidationInfoStorage):
    """Validation info storage backed by the ``validation_infos`` YT table under ``data_path``.

    ``save`` only buffers rows in memory. ``sync`` re-reads the table under a lock,
    merges the buffered rows into it and rewrites the table. ``load`` answers from
    the snapshot taken by the most recent ``sync`` (triggering one on first use).
    """

    def __init__(self, data_path: str, yt_proxy: str, yt_token: str):
        self.path = data_path
        self.yt_client: YtClient = YtClient(yt_proxy, yt_token)
        # Snapshot of the table keyed by StorageKey; None until the first sync().
        self.data = None
        # Rows saved since the last successful sync(), keyed by StorageKey.
        self.updates = {}

    def _get_table_path(self):
        """Return the path of the table holding the validation infos."""
        return ytlib.join(self.path, 'validation_infos')

    def _load_table_content(self) -> Dict[StorageKey, Any]:
        """Read the whole table and index its rows by StorageKey."""
        table_content: Dict[StorageKey, Any] = dict()
        for row in self.yt_client.read_table(self._get_table_path()):
            table_content[StorageKey(row['cluster'], row['dataset_id'], row['dataset_version_id'], row['validation_id'])] = row
        return table_content

    def _create_table(self):
        """Create the table with its schema; an already existing table is left as is."""
        schema = ytlib.schema_from_dict({
            'cluster': 'string',
            'dataset_id': 'string',
            'dataset_version_id': 'string',
            'validation_id': 'string',
            'data': 'any',
        })
        self.yt_client.create("table", self._get_table_path(), attributes={"schema": schema}, ignore_existing=True)

    def _prepare(self):
        """Make sure a snapshot of the table has been loaded."""
        if self.data is None:
            self.sync()

    def load(self, storage_key: StorageKey) -> Any:
        """Return the payload for ``storage_key`` from the current snapshot, or None if absent.

        Rows buffered by ``save`` become visible here only after ``sync``.
        """
        self._prepare()
        res = self.data.get(storage_key)
        return res['data'] if res is not None else None

    def save(self, storage_key: StorageKey, data: Any) -> None:
        """Buffer a row for ``storage_key``; nothing is written until ``sync`` is called."""
        self.updates[storage_key] = {
            'cluster': storage_key.cluster,
            'dataset_id': storage_key.dataset_id,
            'dataset_version_id': storage_key.dataset_version_id,
            'validation_id': storage_key.validation_id,
            'data': data,
        }

    def sync(self) -> None:
        """Reload the table and flush the buffered rows into it within one YT transaction.

        The table is locked (waiting up to DEFAULT_LOCK_WAIT_TIME_MS) before it is
        read; it is rewritten only when there are buffered rows.
        """
        with self.yt_client.Transaction():
            self._create_table()
            self.yt_client.lock(self._get_table_path(), waitable=True, wait_for=DEFAULT_LOCK_WAIT_TIME_MS)
            content = self._load_table_content()
            if self.updates:
                content.update(self.updates)
                self.yt_client.write_table(self._get_table_path(), content.values())
        # Reached only when the transaction block exited without raising.
        self.data = content
        # Drop the flushed rows: keeping them would re-write stale values on every
        # later sync() and could overwrite newer rows stored by someone else.
        self.updates.clear()


class CurrentValidationResultsStorage:
    """Keeps the latest validation outcome per StorageKey in the ``validation_results`` YT table."""

    def __init__(self, data_path: str, yt_proxy: str, yt_token: str):
        self.path: str = data_path
        self.yt_client: YtClient = YtClient(yt_proxy, yt_token)

    def _get_table_path(self):
        """Return the path of the table holding the validation results."""
        return ytlib.join(self.path, 'validation_results')

    def update(self, updates: Dict[StorageKey, ValidationResult]):
        """Merge ``updates`` into the table within one YT transaction.

        The table is locked (waiting up to DEFAULT_LOCK_WAIT_TIME_MS), read in full,
        patched in memory and written back. Rows whose keys are not in ``updates``
        are written back unchanged.
        """
        with self.yt_client.Transaction():
            self._prepare()
            self.yt_client.lock(self._get_table_path(), waitable=True, wait_for=DEFAULT_LOCK_WAIT_TIME_MS)
            table_content: Dict[StorageKey, Any] = dict()
            for row in self.yt_client.read_table(self._get_table_path()):
                table_content[StorageKey(row['cluster'], row['dataset_id'], row['dataset_version_id'], row['validation_id'])] = row
            for key, validation_result in updates.items():
                # Reuse the existing row for this key or start a new one from the key fields.
                row = table_content.setdefault(key, {
                    'cluster': key.cluster,
                    'dataset_id': key.dataset_id,
                    'dataset_version_id': key.dataset_version_id,
                    'validation_id': key.validation_id,
                })
                row['is_valid'] = validation_result.is_ok()
                row['errors'] = [x.dump_as_dict() for x in validation_result.errors.values()]
            self.yt_client.write_table(self._get_table_path(), table_content.values())

    def _prepare(self):
        """Create the results table with its schema; an already existing table is left as is."""
        schema = ytlib.schema_from_dict({
            'cluster': 'string',
            'dataset_id': 'string',
            'dataset_version_id': 'string',
            'validation_id': 'string',
            'is_valid': 'boolean',
            'errors': 'any',
        })
        # ignore_existing replaces a separate exists() check: it saves a request and
        # does not fail when the table shows up between the check and the create.
        # This mirrors ValidationInfoYtStorage._create_table.
        self.yt_client.create("table", self._get_table_path(), attributes={"schema": schema}, ignore_existing=True)


@contextlib.contextmanager
def try_get_dataset_version_lock(yt_client: YtClient, aux_data_yt_dir: str, yt_cluster: str, dataset_version: DatasetVersion):
    """Try to take a lock for ``dataset_version`` on ``yt_cluster`` for the duration of the ``with`` body.

    The lock is a node under ``<aux_data_yt_dir>/locks`` named by the SHA-256 of
    ``cluster|dataset_id|version_id``. It is created inside a YT transaction that
    stays open while the body runs.

    Yields:
        True when the lock node was created, False when its creation raised a
        lock conflict. The body runs in both cases and must check the value.

    Exceptions raised by the body propagate unchanged, including lock conflicts
    raised by the body's own YT calls.
    """
    locks_path = ytlib.join(aux_data_yt_dir, 'locks')
    yt_client.create('map_node', locks_path, ignore_existing=True)

    key = f'{yt_cluster}|{dataset_version.get_dataset_id()}|{dataset_version.get_id()}'
    key_hash = hashlib.sha256(key.encode()).hexdigest()
    lock_path = ytlib.join(aux_data_yt_dir, 'locks', key_hash)

    with yt_client.Transaction():
        # Only the create call is guarded. If the yield sat inside the try, a lock
        # conflict raised by the caller's body would be caught here and the
        # generator would yield a second time, which contextlib reports as a
        # RuntimeError instead of the real error.
        try:
            yt_client.create('table', lock_path)  # creating the node is what takes the lock
        except yt.wrapper.errors.YtCypressTransactionLockConflict:
            locked = False
        else:
            locked = True

        yield locked

        if locked:
            # Reached only when the body finished without raising.
            yt_client.remove(lock_path)


@contextlib.contextmanager
def try_get_dataset_version_lock_composite(yt_client: YtClient, aux_data_yt_dir: str, yt_clusters: List[str], dataset_version: DatasetVersion):
    """Try to take the dataset version lock on every cluster in ``yt_clusters``, in order.

    Yields:
        True when every per-cluster lock was taken (also for an empty cluster
        list), False as soon as one of them could not be taken. Locks taken
        before the failing one stay held while the body runs and are released
        afterwards, in reverse order.
    """
    with contextlib.ExitStack() as held_locks:
        for cluster in yt_clusters:
            acquired = held_locks.enter_context(
                try_get_dataset_version_lock(yt_client, aux_data_yt_dir, cluster, dataset_version)
            )
            if not acquired:
                # Stop at the first cluster that is locked elsewhere.
                yield False
                return
        yield True
