# -*- coding: utf-8 -*-
import datetime
import logging
from typing import Dict, List, Tuple, Callable, Generic, TypeVar

from travel.hotels.lib.python3.yt import ytlib
from yt.wrapper import YtClient
from yt.wrapper.errors import YtHttpResponseError

from travel.hotels.tools.dataset_curator.data import Dataset, DatasetType, DatasetVersion, TableConfiguration
from travel.hotels.tools.dataset_curator.schedules import BaseValidationSchedule, OnChangeValidationSchedule
from travel.hotels.tools.dataset_curator.tools import read_yt_table_with_progress
from travel.hotels.tools.dataset_curator.validation_results import ValidationResult, ValidationErrorWithCounts
from travel.hotels.tools.dataset_curator.validators import SchemaValidator, get_dataset_custom_validations, BasePerVersionValidator, BasePerTableValidator, BasePerRowValidator, \
    DatasetDirectoryStructureValidator
from travel.hotels.tools.dataset_curator.yt_storage import StorageKey, CurrentValidationResultsStorage, try_get_dataset_version_lock, BaseValidationInfoStorage

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Type variable for the validator type carried by a ValidationConfiguration.
T = TypeVar('T')


class ValidationConfiguration(Generic[T]):
    """Bundle of a validator, its identifier and the schedule that gates it.

    Attributes are plain public fields:
      * ``id`` - identifier of the validation (used in logs and storage keys).
      * ``validator`` - the validator object of type ``T``.
      * ``schedule`` - schedule consulted to decide whether to run the validator.
    """

    def __init__(self, id: str, validator: T, schedule: BaseValidationSchedule):
        self.id: str = id
        self.validator: T = validator
        self.schedule: BaseValidationSchedule = schedule


class DatasetVersionValidator:
    """Runs the configured validations against a single dataset version.

    Validations are split by validator kind (per-version, per-table, per-row)
    at construction time and executed in that order by
    ``run_validations_per_version``.
    """

    def __init__(self, validations: List[ValidationConfiguration]):
        """Sort ``validations`` into per-version / per-table / per-row buckets.

        Raises:
            Exception: if a validator is not an instance of any known base class.
        """
        self.per_version_validations: List[ValidationConfiguration[BasePerVersionValidator]] = []
        self.per_table_validations: List[ValidationConfiguration[BasePerTableValidator]] = []
        self.per_row_validations: List[ValidationConfiguration[BasePerRowValidator]] = []
        for validation in validations:
            if isinstance(validation.validator, BasePerVersionValidator):
                self.per_version_validations.append(validation)
            elif isinstance(validation.validator, BasePerTableValidator):
                self.per_table_validations.append(validation)
            elif isinstance(validation.validator, BasePerRowValidator):
                self.per_row_validations.append(validation)
            else:
                raise Exception(f'Unknown validator type: {type(validation.validator)}')

    @staticmethod
    def _get_relevant_validations(yt_cluster: str,
                                  dataset_version: DatasetVersion,
                                  yt_client: YtClient,
                                  validations: List[ValidationConfiguration[T]]) -> List[Tuple[ValidationConfiguration[T], Callable[[], None]]]:
        """Return the validations whose schedule says they need to run.

        Each returned item pairs the validation with the ``on_validation_done``
        callback handed out by its schedule; skipped validations are logged.
        """
        res = []
        for validation in validations:
            need_validation, on_validation_done = validation.schedule.need_validation(yt_cluster, dataset_version, validation.id, yt_client)
            if need_validation:
                res.append((validation, on_validation_done))
            else:
                LOG.debug(f'{dataset_version.get_dataset_id()}|{yt_cluster}|{dataset_version.get_id()}|{validation.id}: Skipping validation')
        return res

    def run_validations_per_version(self, yt_cluster: str, dataset_version: DatasetVersion, yt_client: YtClient) -> Dict[StorageKey, ValidationResult]:
        """Validate ``dataset_version`` and return the results keyed by StorageKey.

        The version path is locked in ``snapshot`` mode inside a YT transaction
        before the validations run. A ``YtHttpResponseError`` that is a resolve
        error - raised either by the lock or by any validation - is converted
        into a 'Dataset version not found' result for every configured
        validation; any other ``YtHttpResponseError`` is re-raised.
        """
        validation_results: Dict[StorageKey, ValidationResult] = {}
        with yt_client.Transaction():
            try:
                yt_client.lock(dataset_version.path, mode='snapshot')
                self._do_run_validations_per_version(yt_cluster, dataset_version, yt_client, validation_results)
            except YtHttpResponseError as e:
                if not e.is_resolve_error():
                    raise
                # Report the failure for every validation, including ones the
                # schedule would have skipped and ones that already reported.
                for validation in self.per_version_validations + self.per_table_validations + self.per_row_validations:
                    res = ValidationResult()
                    res.add_error(ValidationErrorWithCounts('Dataset version not found', 1))
                    self._report_validation_result(yt_cluster, dataset_version, validation.id, res, validation_results)
        return validation_results

    def _do_run_validations_per_version(self, yt_cluster: str, dataset_version: DatasetVersion, yt_client: YtClient, validation_results: Dict[StorageKey, ValidationResult]):
        """Run all scheduled validations and store results in ``validation_results``.

        Order: per-version validations first, then for each table its per-table
        validations followed by its per-row validations. The schedules'
        ``on_validation_done`` callbacks for per-table and per-row validations
        are invoked only after every table has been processed.
        """
        relevant_per_version_validations = self._get_relevant_validations(yt_cluster, dataset_version, yt_client, self.per_version_validations)
        relevant_per_table_validations = self._get_relevant_validations(yt_cluster, dataset_version, yt_client, self.per_table_validations)
        relevant_per_row_validations = self._get_relevant_validations(yt_cluster, dataset_version, yt_client, self.per_row_validations)

        for validation, on_validation_done in relevant_per_version_validations:
            LOG.debug(f'{dataset_version.get_dataset_id()}|{yt_cluster}|{dataset_version.get_id()}|{validation.id}: Running validation')
            res = ValidationResult()
            validation.validator.validate_version(yt_client, dataset_version, res)
            self._report_validation_result(yt_cluster, dataset_version, validation.id, res, validation_results)
            on_validation_done()

        # NOTE(review): results below are stored under a key without a table
        # component, so when one validation applies to several tables the last
        # table's result appears to replace the earlier ones (assuming
        # StorageKey compares by value) - confirm whether this is intended.
        for path, table in self._get_dataset_version_tables(dataset_version):
            for validation, _ in relevant_per_table_validations:
                # A validator bound to a specific subpath only runs on that table.
                if validation.validator.table_subpath is not None and table.subpath != validation.validator.table_subpath:
                    continue
                LOG.debug(f'{dataset_version.get_dataset_id()}|{yt_cluster}|{dataset_version.get_id()}|{validation.id}|{table.get_id()}: Running validation')
                res = ValidationResult()
                validation.validator.validate_table(yt_client, path, table, res)
                self._report_validation_result(yt_cluster, dataset_version, validation.id, res, validation_results)

            curr_relevant_per_row_validations = [
                (val, on_done)
                for val, on_done in relevant_per_row_validations
                if val.validator.table_subpath is None or table.subpath == val.validator.table_subpath
            ]
            self._run_per_row_validations(yt_cluster, dataset_version, yt_client, path, table, curr_relevant_per_row_validations, validation_results)

        for _, on_validation_done in relevant_per_table_validations:
            on_validation_done()

        for _, on_validation_done in relevant_per_row_validations:
            on_validation_done()

    def _run_per_row_validations(self,
                                 yt_cluster: str,
                                 dataset_version: DatasetVersion,
                                 yt_client: YtClient,
                                 path: str,
                                 table: TableConfiguration,
                                 row_validations: List[Tuple[ValidationConfiguration[BasePerRowValidator], Callable[[], None]]],
                                 validation_results: Dict[StorageKey, ValidationResult]):
        """Feed every row of the table at ``path`` to the given per-row validations.

        The table is read once; each validation gets its own context dict and
        ValidationResult. A validation stops receiving rows once its error
        count reaches ``max_errors_cnt``, and reading stops as soon as no
        validation is still accepting rows. Nothing is read when
        ``row_validations`` is empty.
        """
        if not row_validations:
            return

        contexts = [(dict(), ValidationResult(), validation) for validation, _ in row_validations]
        for _, _, validation in contexts:
            LOG.debug(f'{dataset_version.get_dataset_id()}|{yt_cluster}|{dataset_version.get_id()}|{validation.id}|{table.get_id()}: Running validation')

        for row in read_yt_table_with_progress(yt_client, path):
            all_stopped = True
            for ctx, res, validation in contexts:
                max_errors_cnt = validation.validator.max_errors_cnt
                if res.error_cnt >= max_errors_cnt:
                    continue
                validation.validator.validate_table_row(ctx, row, table, res)
                # '>=' rather than '==': a single row may add several errors and
                # jump past the limit, which must still count as stopped.
                if res.error_cnt >= max_errors_cnt:
                    LOG.info(f'Too many errors, stopping validation "{validation.id}"')
                else:
                    all_stopped = False
            if all_stopped:
                break

        for ctx, res, validation in contexts:
            validation.validator.validate_after_rows(ctx, table, res)
            self._report_validation_result(yt_cluster, dataset_version, validation.id, res, validation_results)

    @staticmethod
    def _report_validation_result(yt_cluster: str, dataset_version: DatasetVersion, validation_id: str, result: ValidationResult, validation_results: Dict[StorageKey, ValidationResult]):
        """Log ``result`` and store it in ``validation_results`` under its StorageKey."""
        if result.is_ok():
            LOG.debug(f'({dataset_version.get_dataset_id()}|{yt_cluster}|{dataset_version.get_id()}|{validation_id}): validation is ok')
        else:
            LOG.error(f'({dataset_version.get_dataset_id()}|{yt_cluster}|{dataset_version.get_id()}|{validation_id}): validation is failed: ' + str(result))
        validation_results[StorageKey(yt_cluster, dataset_version.get_dataset_id(), dataset_version.get_id(), validation_id)] = result

    @staticmethod
    def _get_dataset_version_tables(dataset_version: DatasetVersion) -> List[Tuple[str, TableConfiguration]]:
        """Return ``(yt_path, table_configuration)`` pairs for the version's tables.

        A SINGLE_TABLE dataset yields the version path itself; a DIRECTORY
        dataset yields one entry per non-ignored table, joined to the version
        path by the table's subpath.

        Raises:
            Exception: if the dataset type is neither SINGLE_TABLE nor DIRECTORY.
        """
        if dataset_version.dataset.dataset_type == DatasetType.SINGLE_TABLE:
            return [(dataset_version.path, dataset_version.dataset.table)]
        elif dataset_version.dataset.dataset_type == DatasetType.DIRECTORY:
            return [
                (str(ytlib.join(dataset_version.path, table.subpath)), table)
                for table in dataset_version.dataset.tables
                if not table.ignored
            ]
        else:
            raise Exception(f'Unknown DatasetType: {dataset_version.dataset.dataset_type}')


class DatasetValidationRunner:
    """Drives validation of every version of a dataset on every configured cluster.

    Results of each processed version are written to the current validation
    results storage right after that version finishes.
    """

    def __init__(self, validations: List[ValidationConfiguration], yt_token: str, aux_data_yt_dir: str, aux_data_yt_proxy: str, deadline: datetime.datetime):
        """Store the configuration; ``deadline`` may be None to disable the time limit."""
        self.dataset_version_validator = DatasetVersionValidator(validations)
        # Per-cluster YtClient cache, filled lazily by _get_yt_client.
        self.yt_clients = {}
        self.yt_token = yt_token
        self.aux_data_yt_dir = aux_data_yt_dir
        self.aux_data_yt_proxy = aux_data_yt_proxy
        self.deadline = deadline
        # Results collected since the last successful commit.
        self._all_validation_results: Dict[StorageKey, ValidationResult] = {}

    def _commit_results(self):
        """Push pending results to the results storage and clear the buffer."""
        if not self._all_validation_results:
            return
        storage = CurrentValidationResultsStorage(self.aux_data_yt_dir, self.aux_data_yt_proxy, self.yt_token)
        storage.update(self._all_validation_results)
        self._all_validation_results = {}

    def _get_yt_client(self, yt_cluster: str) -> YtClient:
        """Return the cached client for ``yt_cluster``, creating it on first use."""
        if yt_cluster in self.yt_clients:
            return self.yt_clients[yt_cluster]
        client = YtClient(proxy=yt_cluster, token=self.yt_token)
        self.yt_clients[yt_cluster] = client
        return client

    def _is_deadline_reached(self) -> bool:
        """Tell whether a deadline is set and the current UTC time has reached it."""
        return self.deadline is not None and datetime.datetime.utcnow() >= self.deadline

    def run_validations(self, dataset: Dataset, sync_validation_infos: Callable[[], None]) -> None:
        """Validate all versions of ``dataset`` cluster by cluster.

        Processing stops entirely once the deadline is reached. Versions whose
        lock cannot be acquired are skipped. After each validated version the
        results are committed and ``sync_validation_infos`` is called.
        """
        aux_data_yt_client = YtClient(self.aux_data_yt_proxy, self.yt_token)
        for yt_cluster in dataset.yt_clusters:
            LOG.info(f'{dataset.name}: Processing cluster "{yt_cluster}"')
            yt_client = self._get_yt_client(yt_cluster)
            for dataset_version in dataset.get_dataset_versions(yt_client):
                if self._is_deadline_reached():
                    LOG.info(f'Stopping dataset {dataset.name} because of time limit')
                    return
                with try_get_dataset_version_lock(aux_data_yt_client, self.aux_data_yt_dir, yt_cluster, dataset_version) as locked:
                    ver_id = dataset_version.get_id()
                    if locked:
                        LOG.debug(f'{dataset.name}|{yt_cluster}: Processing version "{ver_id}"')
                        version_results = self.dataset_version_validator.run_validations_per_version(yt_cluster, dataset_version, yt_client)
                        self._all_validation_results.update(version_results)
                        self._commit_results()
                        sync_validation_infos()
                    else:
                        LOG.debug(f'{dataset.name}: Skipping version "{ver_id}" because it is locked')


def get_dataset_validations(dataset: Dataset, validation_info_storage: BaseValidationInfoStorage) -> List[ValidationConfiguration]:
    """Build the list of validations to run for ``dataset``.

    The list starts with the common 'schema' and 'directory_structure'
    validations, followed by the dataset's custom validations. Every
    validation gets its own OnChangeValidationSchedule backed by
    ``validation_info_storage``.
    """

    def on_change(validation_id, validator):
        # Each configuration receives a fresh schedule instance.
        return ValidationConfiguration(validation_id, validator, OnChangeValidationSchedule(validation_info_storage))

    validations = [
        on_change('schema', SchemaValidator()),
        on_change('directory_structure', DatasetDirectoryStructureValidator()),
    ]
    for validation_id, validator in get_dataset_custom_validations(dataset).items():
        validations.append(on_change(validation_id, validator))
    return validations
