# -*- coding: utf-8 -*-

import logging
from typing import Tuple, Callable

from travel.hotels.lib.python3.yt import ytlib
from yt.wrapper import YtClient

from travel.hotels.tools.dataset_curator.data import DatasetVersion
from travel.hotels.tools.dataset_curator.yt_storage import StorageKey, BaseValidationInfoStorage

LOG = logging.getLogger(__name__)


class BaseValidationSchedule:
    """Interface for deciding when a validation of a dataset version has to be (re)run.

    Both methods must be overridden; the base implementations only raise
    ``NotImplementedError``.
    """

    def need_validation(self, cluster: str, dataset_version: DatasetVersion, validation_id: str, yt_client: YtClient) -> Tuple[bool, Callable[[], None]]:
        """Decide whether ``validation_id`` should run for ``dataset_version`` on ``cluster``.

        :returns: a pair ``(need, callback)`` where ``need`` tells whether the
            validation should run and ``callback`` is a no-argument callable
            (presumably invoked by the caller after the validation finishes).
        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def force_mark_done(self, cluster: str, dataset_version: DatasetVersion, validation_id: str, yt_client: YtClient):
        """Record ``validation_id`` as done for ``dataset_version`` without running it.

        :raises NotImplementedError: always, in this base class.
        """
        raise NotImplementedError


class OnChangeValidationSchedule(BaseValidationSchedule):
    """Schedule that requests a validation only when the dataset's YT content changed.

    The "revision tree" of the dataset version's YT path is built from the
    ``content_revision`` of the node and, recursively, of every child of a
    ``map_node``. It is compared with the tree last saved in
    ``validation_info_storage`` for the same cluster/dataset/version/validation key.
    """

    def __init__(self, validation_info_storage: BaseValidationInfoStorage):
        # Persistent store of the revision tree observed at the last completed validation.
        self.validation_info_storage = validation_info_storage

    def need_validation(self, cluster: str, dataset_version: DatasetVersion, validation_id: str, yt_client: YtClient) -> Tuple[bool, Callable[[], None]]:
        """Decide whether ``validation_id`` must run for ``dataset_version`` on ``cluster``.

        :returns: ``(True, on_done)`` when the current revision tree differs from
            the stored one. ``on_done`` saves the tree captured during *this* call,
            so changes that land while the validation runs still differ on the next
            check. Otherwise returns ``(False, no-op callable)``.
        """
        key = self._make_storage_key(cluster, dataset_version, validation_id)
        # NOTE(review): the value returned for a never-saved key depends on the
        # storage implementation; any value unequal to a real tree triggers validation.
        previous_tree = self.validation_info_storage.load(key)
        current_tree = self._get_yt_directory_content_revision_tree(yt_client, dataset_version.path)
        if current_tree != previous_tree:
            def on_validation_done():
                self.validation_info_storage.save(key, current_tree)

            return True, on_validation_done
        return False, lambda: None

    def force_mark_done(self, cluster: str, dataset_version: DatasetVersion, validation_id: str, yt_client: YtClient):
        """Save the current revision tree as validated, without running the validation."""
        key = self._make_storage_key(cluster, dataset_version, validation_id)
        current_tree = self._get_yt_directory_content_revision_tree(yt_client, dataset_version.path)
        self.validation_info_storage.save(key, current_tree)

    @staticmethod
    def _make_storage_key(cluster: str, dataset_version: DatasetVersion, validation_id: str) -> StorageKey:
        """Build the storage key under which the revision tree of this validation is kept."""
        return StorageKey(cluster, dataset_version.get_dataset_id(), dataset_version.get_id(), validation_id)

    def _get_yt_directory_content_revision_tree(self, yt_client: YtClient, path: str, content_revision=None, node_type=None):
        """Build a nested dict of YT content revisions rooted at ``path``.

        A table yields ``{'content_revision': rev}``. A map_node also gets
        ``'children': {child_name: subtree}``, built recursively.

        :param content_revision: already known ``content_revision`` of ``path``;
            fetched from YT when None.
        :param node_type: already known YT ``type`` of ``path``; fetched from YT when None.
        :raises NotImplementedError: for node types other than ``table`` and ``map_node``.
        """
        if node_type is None:
            node_type = yt_client.get_attribute(path, "type")
        if content_revision is None:
            content_revision = yt_client.get_attribute(path, 'content_revision')
        res = {'content_revision': content_revision}
        if node_type == 'table':
            return res
        if node_type == 'map_node':
            # Fetch the children's type together with content_revision in one
            # request, so the recursion needs no extra get_attribute call per child.
            children = yt_client.list(path, attributes=['type', 'content_revision'])
            res['children'] = {
                str(child): self._get_yt_directory_content_revision_tree(
                    yt_client,
                    str(ytlib.join(path, str(child))),
                    child.attributes['content_revision'],
                    child.attributes['type'],
                )
                for child in children
            }
            return res
        raise NotImplementedError(f'Unknown node type {node_type}')
