# -*- coding: utf-8 -*-
import copy
import datetime
import logging
import uuid
from typing import Dict, List, Any, Set

from travel.hotels.lib.python3.yt import ytlib
from yql.api.v1.client import YqlClient
from yt.wrapper import YtClient

from travel.hotels.tools.dataset_curator.data import Dataset, DatasetVersion, DatasetType
from travel.hotels.tools.dataset_curator.dataset_build_configuration import InnerBuildContext, BuildContext
from travel.hotels.tools.dataset_curator.validation_results import ValidationResult
from travel.hotels.tools.dataset_curator.validation_running import get_dataset_validations, DatasetVersionValidator, ValidationConfiguration
from travel.hotels.tools.dataset_curator.yt_storage import StorageKey, ValidationInfoYtStorage, CurrentValidationResultsStorage, try_get_dataset_version_lock_composite, \
    ValidationInfoInMemoryStorage

LOG = logging.getLogger(__name__)


class DatasetBuilder:
    """Builds a new dataset version in a temporary YT directory, validates it and commits it.

    The user-supplied build function writes into a unique temporary directory under ``tmp_yt_dir``.
    The dataset's validations are run against that temporary version and, only if all of them pass,
    the version is moved into ``dataset.yt_path`` under a per-version lock. Optionally, a ``latest``
    link is updated and the version is transferred to the dataset's other clusters.
    """

    def __init__(self, tmp_yt_dir: str, aux_data_yt_dir: str, aux_data_yt_proxy: str, yt_token: str, yql_token: str, allow_latest_update: bool):
        """
        :param tmp_yt_dir: YT directory under which per-build temporary directories are created.
        :param aux_data_yt_dir: YT directory passed to the validation-info / validation-results storages
            and to the version lock helper.
        :param aux_data_yt_proxy: YT cluster hosting ``aux_data_yt_dir``.
        :param yt_token: token used for every YT client created by the builder.
        :param yql_token: token for the YQL client handed to build functions.
        :param allow_latest_update: global switch; the ``latest`` link is only updated when this is True
            AND the build context requests it (``create_latest``).
        """
        self.tmp_yt_dir = tmp_yt_dir
        self.aux_data_yt_dir = aux_data_yt_dir
        self.aux_data_yt_proxy = aux_data_yt_proxy
        self.yt_token = yt_token
        self.yql_token = yql_token
        self.allow_latest_update = allow_latest_update

    @staticmethod
    def _get_versioned_path_name(yt_client: YtClient, base_path: str) -> str:
        """Return a new version name under ``base_path``.

        The UTC date (``YYYY-MM-DD``) is preferred. If a node with that name already exists,
        the full second-precision UTC timestamp (``YYYY-MM-DDTHH:MM:SSZ``) is used instead.
        """
        # Timezone-aware replacement for the deprecated datetime.utcnow(); strftime yields exactly the
        # "<isoformat without fraction>Z" label the previous code produced.
        time_label = datetime.datetime.now(datetime.timezone.utc).strftime('%Y-%m-%dT%H:%M:%SZ')
        short_name = time_label[:10]
        if not yt_client.exists(ytlib.join(base_path, short_name)):
            return short_name
        return time_label

    @staticmethod
    def _check_skip_settings(dataset: Dataset, inner_build_context: InnerBuildContext) -> None:
        """Raise if ``skip_empty`` / ``skip_unexistent`` is requested for a non-SINGLE_TABLE dataset."""
        if dataset.dataset_type == DatasetType.SINGLE_TABLE:
            return
        if inner_build_context.skip_empty:
            raise Exception(f'Can use skip_empty setting only for SINGLE_TABLE datasets (got {dataset.dataset_type} for {dataset.name})')
        if inner_build_context.skip_unexistent:
            raise Exception(f'Can use skip_unexistent_table setting only for SINGLE_TABLE datasets (got {dataset.dataset_type} for {dataset.name})')

    @staticmethod
    def _get_target_clusters(dataset: Dataset, inner_build_context: InnerBuildContext) -> List[str]:
        """Return the sorted, de-duplicated list of clusters the new version will be committed to."""
        clusters: Set[str] = {inner_build_context.yt_cluster}
        if inner_build_context.transfer_results:
            clusters.update(dataset.yt_clusters)
        return sorted(clusters)

    def build_dataset(self, dataset: Dataset, inner_build_context: InnerBuildContext, args: Dict[str, Any]) -> None:
        """Build, validate and commit one new version of ``dataset``.

        1. Run ``inner_build_context.func`` into a fresh temporary directory on ``inner_build_context.yt_cluster``.
        2. Optionally skip the version if it was not created (``skip_unexistent``) or is empty
           (``skip_empty``). Both settings are allowed for SINGLE_TABLE datasets only.
        3. Run the dataset's validations against the temporary version and abort if any fails.
        4. Under a per-version lock on all target clusters:
           - move the version into ``dataset.yt_path`` and record validation results;
           - optionally update ``latest``;
           - optionally transfer the version to the dataset's other clusters.

        The temporary directory is always removed afterwards.

        :param dataset: dataset description (target path, type, clusters).
        :param inner_build_context: build settings and the build function.
        :param args: arguments forwarded as-is to the build function.
        :raises Exception: on misconfigured skip settings, unexpected tmp directory contents,
            failed validations, or when the version is locked on some cluster.
        """
        # Fail fast on misconfiguration, before running a potentially expensive build.
        self._check_skip_settings(dataset, inner_build_context)

        yt_client = YtClient(inner_build_context.yt_cluster, self.yt_token)

        current_dataset_path = str(ytlib.join(self.tmp_yt_dir, f'{dataset.name}-{str(uuid.uuid4())}'))
        yt_client.create('map_node', current_dataset_path, recursive=True, ignore_existing=True)
        try:
            current_version_name = None
            if inner_build_context.versioned_path:
                current_version_name = self._get_versioned_path_name(yt_client, dataset.yt_path)
                if dataset.dataset_type == DatasetType.DIRECTORY:
                    yt_client.create('map_node', ytlib.join(current_dataset_path, current_version_name))
            yql_client = YqlClient(token=self.yql_token, db=inner_build_context.yt_cluster)
            # NOTE(review): when versioned_path is False, current_version_name is still None here and is passed
            # to ytlib.join unchanged — confirm how ytlib.join treats None.
            build_context = BuildContext(yql_client, inner_build_context, YtClient(inner_build_context.yt_cluster, self.yt_token), dataset.yt_path,
                                         current_dataset_path, str(ytlib.join(current_dataset_path, current_version_name)), [])
            inner_build_context.func(build_context, args)
            build_context._finish()

            subnodes = yt_client.list(current_dataset_path)
            # Nothing is pre-created in the tmp directory for SINGLE_TABLE datasets, so a build that produced
            # no table leaves it empty. This must be handled before the "exactly one node" check below,
            # otherwise that check always fires first and skip_unexistent can never take effect.
            if inner_build_context.skip_unexistent and not subnodes:
                LOG.info(f'Built dataset version is not created, skipping it ({dataset.name})')
                return
            if len(subnodes) != 1:
                raise Exception(f'Expected exactly one node (=one new version) in dataset tmp directory dir after build, got {len(subnodes)}: {subnodes}')
            if current_version_name is None:
                current_version_name = subnodes[0]
            current_tmp_dataset = copy.copy(dataset)
            current_tmp_dataset.yt_path = current_dataset_path
            current_tmp_version = DatasetVersion(current_tmp_dataset, str(ytlib.join(current_dataset_path, current_version_name)))
            # Existence is checked before emptiness, so a missing version is never fed into is_empty.
            if inner_build_context.skip_unexistent and not yt_client.exists(current_tmp_version.path):
                LOG.info(f'Built dataset version is not created, skipping it ({dataset.name})')
                return
            if inner_build_context.skip_empty and yt_client.is_empty(current_tmp_version.path):
                LOG.info(f'Built dataset version is empty, skipping it ({dataset.name})')
                return

            validation_info_yt_storage = ValidationInfoYtStorage(self.aux_data_yt_dir, self.aux_data_yt_proxy, self.yt_token)
            # The temporary version is validated with in-memory validation info. The YT-backed validations
            # are only used later, to force-mark the committed version.
            validations_with_tmp_storage = get_dataset_validations(dataset, ValidationInfoInMemoryStorage())
            validations_with_yt_storage = get_dataset_validations(dataset, validation_info_yt_storage)
            validation_results = DatasetVersionValidator(validations_with_tmp_storage).run_validations_per_version(inner_build_context.yt_cluster, current_tmp_version, yt_client)
            # Re-key by validation id: the original StorageKeys reference the tmp version, which is about to be moved.
            validation_results = {k.validation_id: v for k, v in validation_results.items()}
            failed_validation_ids = [val_id for val_id, res in validation_results.items() if not res.is_ok()]
            if failed_validation_ids:
                LOG.error(f'{dataset.name}: failed validations: {failed_validation_ids}')
                raise Exception('Some validations failed, not committing new dataset version')

            built_version = DatasetVersion(dataset, str(ytlib.join(dataset.yt_path, current_version_name)))
            all_clusters = self._get_target_clusters(dataset, inner_build_context)

            with try_get_dataset_version_lock_composite(YtClient(self.aux_data_yt_proxy, self.yt_token), self.aux_data_yt_dir, all_clusters, built_version) as locked:
                ver_id = built_version.get_id()
                if not locked:
                    raise Exception(f'{dataset.name}: Version "{ver_id}" is locked on some of clusters: {all_clusters}')

                yt_client.create('map_node', dataset.yt_path, recursive=True, ignore_existing=True)
                # Replace the version node inside one transaction, holding a shared lock on the dataset
                # directory scoped to this version's child key.
                with yt_client.Transaction():
                    yt_client.lock(dataset.yt_path, 'shared', child_key=current_version_name)
                    yt_client.remove(built_version.path, recursive=True, force=True)
                    yt_client.move(current_tmp_version.path, built_version.path)
                self._commit_validation_results(inner_build_context.yt_cluster, built_version, yt_client, validation_results, validations_with_yt_storage)
                validation_info_yt_storage.sync()

                latest_path = ytlib.join(dataset.yt_path, 'latest')
                need_latest = inner_build_context.create_latest and self.allow_latest_update
                if need_latest:
                    yt_client.link(built_version.path, latest_path, force=True)
                if inner_build_context.transfer_results:
                    for dst_cluster in dataset.yt_clusters:
                        if dst_cluster == inner_build_context.yt_cluster:
                            continue
                        dst_cluster_client = YtClient(dst_cluster, self.yt_token)
                        dst_cluster_client.remove(built_version.path, recursive=True, force=True)
                        ytlib.transfer_results(path=built_version.path,
                                               source_cluster=inner_build_context.yt_cluster,
                                               destination_cluster=dst_cluster,
                                               yt_token=self.yt_token,
                                               link_to_path=latest_path if need_latest else None,
                                               is_dir=dataset.dataset_type == DatasetType.DIRECTORY)

                        self._commit_validation_results(dst_cluster, built_version, dst_cluster_client, validation_results, validations_with_yt_storage)
                        validation_info_yt_storage.sync()
        finally:
            yt_client.remove(current_dataset_path, recursive=True, force=True)

    def _commit_validation_results(self,
                                   cluster: str,
                                   dataset_version: DatasetVersion,
                                   yt_client: YtClient,
                                   validation_results: Dict[str, ValidationResult],  # validation_id -> result
                                   validations: List[ValidationConfiguration]):
        """Persist validation results for ``dataset_version`` on ``cluster`` and force-mark its validations as done.

        :param cluster: cluster the version has been committed to.
        :param dataset_version: the committed (non-temporary) version.
        :param yt_client: client for ``cluster``, passed through to ``force_mark_done``.
        :param validation_results: validation id -> result, computed on the temporary version.
        :param validations: validations whose schedules are force-marked as done for this version.
        """
        if validation_results:
            validation_results_storage = CurrentValidationResultsStorage(self.aux_data_yt_dir, self.aux_data_yt_proxy, self.yt_token)
            validation_results_storage.update({
                StorageKey(cluster, dataset_version.get_dataset_id(), dataset_version.get_id(), val_id): res
                for val_id, res in validation_results.items()
            })
        for validation in validations:
            LOG.debug(f'Force mark: {cluster}, {dataset_version.dataset.yt_path}, {dataset_version.path}, {validation.id}, {yt_client}')
            validation.schedule.force_mark_done(cluster, dataset_version, validation.id, yt_client)
