# -*- coding: utf-8 -*-

import logging
import traceback
from dataclasses import dataclass
from typing import Any, List, Dict, Tuple, Optional, Callable

from yt.wrapper import YtClient

from travel.hotels.tools.dataset_curator.data import DatasetType, DatasetVersion, TableConfiguration, Dataset
from travel.hotels.tools.dataset_curator.tools import list_recursive
from travel.hotels.tools.dataset_curator.validation_results import ValidationResult, ValidationErrorWithSamples, ValidationErrorWithCounts

# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)

# Default value for the ``max_errors_cnt`` / ``max_invalid_rows`` constructor
# arguments of the per-row validators defined in this module.
DEFAULT_MAX_ERRORS_CNT = 100


class BaseValidator:
    """Common root type of every validator declared in this module.

    Adds no behaviour of its own; it only gives the per-version, per-table
    and per-row validator families a shared base class.
    """


class BasePerVersionValidator(BaseValidator):
    """Base class for validators that check a dataset version as a whole."""

    def validate_version(
        self,
        yt_client: YtClient,
        dataset_version: DatasetVersion,
        res: ValidationResult,
    ):
        """Validate ``dataset_version``; concrete subclasses must override this.

        Args:
            yt_client: YT client available to the validator.
            dataset_version: the dataset version under validation.
            res: result object that validators add their errors to.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError


class BasePerTableValidator(BaseValidator):
    """Base class for validators that check one table at a time."""

    def __init__(self, table_subpath: Optional[str]):
        """Remember which table the validator is bound to.

        Args:
            table_subpath: stored as ``self.table_subpath``; may be ``None``.
        """
        self.table_subpath = table_subpath

    def validate_table(
        self,
        yt_client: YtClient,
        path: str,
        table_configuration: TableConfiguration,
        res: ValidationResult,
    ):
        """Validate the table at ``path``; concrete subclasses must override this.

        Args:
            yt_client: YT client available to the validator.
            path: path of the table to validate.
            table_configuration: configuration describing that table.
            res: result object that validators add their errors to.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError


class BasePerRowValidator(BaseValidator):
    """Base class for validators that are fed table rows one by one.

    ``validate_table_row`` is the per-row hook and ``validate_after_rows`` is
    an optional hook for work that needs the whole table to have been seen.
    """

    def __init__(self, table_subpath: Optional[str], max_errors_cnt: int = DEFAULT_MAX_ERRORS_CNT):
        """Store the validator settings.

        Args:
            table_subpath: stored as ``self.table_subpath``; may be ``None``.
            max_errors_cnt: stored as ``self.max_errors_cnt``. Nothing in this
                class enforces the limit — presumably it is consumed by the
                code that drives the validators (TODO confirm).
        """
        self.table_subpath = table_subpath
        self.max_errors_cnt = max_errors_cnt

    def validate_table_row(
        self,
        ctx: Dict[str, Any],
        row: Any,
        table_configuration: TableConfiguration,
        res: ValidationResult,
    ) -> None:
        """Validate a single ``row``; concrete subclasses must override this.

        Args:
            ctx: mutable dict in which a validator may keep state between rows.
            row: the row being validated.
            table_configuration: configuration of the table the row belongs to.
            res: result object that validators add their errors to.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def validate_after_rows(
        self,
        ctx: Dict[str, Any],
        table_configuration: TableConfiguration,
        res: ValidationResult,
    ) -> None:
        """Hook for whole-table checks; does nothing in this base class.

        Args:
            ctx: the same state dict that was passed to ``validate_table_row``.
            table_configuration: configuration of the validated table.
            res: result object that validators add their errors to.
        """


# Deliberately a plain class: it declares no dataclass fields, so ``@dataclass``
# would only add a field-less ``__eq__`` (every instance compares equal) and
# set ``__hash__`` to ``None``.
class CustomUniquenessValidator(BasePerRowValidator):
    """Per-row validator that reports rows whose key columns repeat.

    The key of a row is the tuple of its values in the ``key`` columns. Keys
    already seen are kept in ``ctx['seen_keys']``.
    """

    def __init__(self, table_subpath: str, key: List[str], max_errors_cnt: int = DEFAULT_MAX_ERRORS_CNT):
        """Create the validator.

        Args:
            table_subpath: passed to the base class as the table binding.
            key: names of the columns that together must be unique.
            max_errors_cnt: passed to the base class as the error limit.
        """
        super().__init__(table_subpath, max_errors_cnt)
        self.key = key

    def validate_table_row(self, ctx: Dict[str, Any], row: Any, table_configuration: TableConfiguration, res: ValidationResult) -> None:
        """Record the key of ``row`` and report it if it was seen before.

        Args:
            ctx: state dict; the set of seen keys lives under ``'seen_keys'``.
            row: mapping from column name to value; must contain every key column.
            table_configuration: unused by this validator.
            res: receives one ``ValidationErrorWithSamples`` per repeated key,
                with the offending key tuple as the sample.

        Raises:
            KeyError: if ``row`` lacks one of the key columns (when ``row`` is a dict).
        """
        seen_keys = ctx.get('seen_keys')
        if seen_keys is None:
            # First row for this context: start with an empty set of keys.
            seen_keys = ctx['seen_keys'] = set()

        curr_key = tuple(row[column] for column in self.key)
        if curr_key in seen_keys:
            res.add_error(ValidationErrorWithSamples(f'Key {self.key} is not unique', 1, [curr_key]))
        else:
            seen_keys.add(curr_key)


def _do_eval(code, gloval_vars, local_vars, res: ValidationResult) -> Any:
    """Run ``code`` via ``eval`` and turn any exception into a validation error.

    NOTE(review): this executes arbitrary Python. The validators in this
    module build ``code`` from validation-config expressions, so those
    expressions must come from a trusted source.

    Args:
        code: source string or code object accepted by ``eval``. The
            validators in this module pass code objects compiled in ``'exec'``
            mode, for which ``eval`` returns ``None``.
        gloval_vars: globals dict for the evaluation.
        local_vars: locals dict for the evaluation.
        res: receives a ``ValidationErrorWithCounts`` carrying the formatted
            traceback when the evaluation raises.

    Returns:
        Whatever ``eval`` returns, or ``None`` if it raised an ``Exception``.
    """
    try:
        outcome = eval(code, gloval_vars, local_vars)
    except Exception:
        # A broken user expression is a validation failure, not a crash.
        failure = ValidationErrorWithCounts(f'Eval failed: {traceback.format_exc()}', 1)
        res.add_error(failure)
        outcome = None
    return outcome


def _build_check_function(res: ValidationResult, row: Any) -> Callable[[bool, Optional[str]], None]:
    """Build the ``check(value, error=None)`` helper exposed to user expressions.

    Args:
        res: result object that failed checks are added to.
        row: row attached as the sample to every failure, or ``None`` to
            report failures without a sample.

    Returns:
        A function that records a failure unless ``value`` is the boolean ``True``.
    """
    def check(value: bool, error: str = None):
        # Only a real boolean True passes; truthy non-bool values count as failures.
        if isinstance(value, bool) and value:
            return

        reason = error if error else '<no error text>'
        message = f'Eval validation failed: {reason}'
        if row is None:
            failure = ValidationErrorWithCounts(message, 1)
        else:
            failure = ValidationErrorWithSamples(message, 1, [row])
        res.add_error(failure)

    return check


class CustomPerRowValidator(BasePerRowValidator):
    """Per-row validator that runs a user-supplied snippet for every row.

    The snippet is compiled once in ``'exec'`` mode. On each row it sees the
    row as ``columns`` and reports failures through ``check(value, error=None)``.
    The globals/locals dicts live in ``ctx`` and are reused across rows.
    """

    def __init__(self, table_subpath: str, is_valid_expr: str, max_invalid_rows: int = DEFAULT_MAX_ERRORS_CNT):
        """Create the validator and compile its snippet.

        Args:
            table_subpath: passed to the base class as the table binding.
            is_valid_expr: Python source executed for every row.
            max_invalid_rows: passed to the base class as the error limit.

        Raises:
            SyntaxError: if ``is_valid_expr`` is not valid Python.
        """
        super().__init__(table_subpath, max_invalid_rows)
        self.is_valid_expr = is_valid_expr
        self.is_valid_expr_code = compile(is_valid_expr, 'main', 'exec')

    def validate_table_row(self, ctx: Dict[str, Any], row: Any, table_configuration: TableConfiguration, res: ValidationResult) -> None:
        """Execute the compiled snippet against ``row``.

        Args:
            ctx: state dict; the evaluation scopes live under ``'globals'`` and ``'locals'``.
            row: the row, exposed to the snippet as ``columns``.
            table_configuration: unused by this validator.
            res: receives failed checks and evaluation errors.
        """
        if 'globals' not in ctx:
            # First row for this context: create both evaluation scopes.
            ctx.update({'globals': {}, 'locals': {}})

        scope = ctx['globals']
        scope['columns'] = row
        scope['check'] = _build_check_function(res, row)
        _do_eval(self.is_valid_expr_code, scope, ctx['locals'], res)


# Deliberately a plain class: it declares no dataclass fields, so ``@dataclass``
# would only add a field-less ``__eq__`` (every instance compares equal) and
# set ``__hash__`` to ``None``.
class CustomPerTableValidator(BasePerRowValidator):
    """Validator that folds user-supplied snippets over all rows of a table.

    Three snippets, each compiled once in ``'exec'`` mode, share one pair of
    globals/locals dicts stored in ``ctx``:

    * ``init_expr`` runs for the first row seen;
    * ``step_expr`` runs for every following row;
    * ``result_expr`` runs once after the last row.

    Each snippet sees the current row as ``columns`` (``None`` when there is
    no current row) and reports failures through ``check(value, error=None)``.

    NOTE(review): the first row is handed to ``init_expr`` only — ``step_expr``
    is not run for it. Confirm that this is the intended contract.
    """

    def __init__(self, table_subpath: str, init_expr: str, step_expr: str, result_expr: str, max_invalid_rows: int = DEFAULT_MAX_ERRORS_CNT):
        """Create the validator and compile its three snippets.

        Args:
            table_subpath: passed to the base class as the table binding.
            init_expr: Python source run once to set up state.
            step_expr: Python source run for each row after the first.
            result_expr: Python source run after all rows.
            max_invalid_rows: passed to the base class as the error limit.

        Raises:
            SyntaxError: if any snippet is not valid Python.
        """
        super().__init__(table_subpath, max_invalid_rows)
        self.init_expr = init_expr
        self.step_expr = step_expr
        self.result_expr = result_expr
        self.init_expr_code = compile(self.init_expr, 'main', 'exec')
        self.step_expr_code = compile(self.step_expr, 'main', 'exec')
        self.result_expr_code = compile(self.result_expr, 'main', 'exec')

    def _run_snippet(self, code: Any, ctx: Dict[str, Any], row: Any, res: ValidationResult) -> None:
        """Expose ``row`` and a fresh ``check`` to the shared scope, then run ``code``."""
        ctx['globals'].update({'columns': row, 'check': _build_check_function(res, row)})
        _do_eval(code, ctx['globals'], ctx['locals'], res)

    def _ensure_initialized(self, ctx: Dict[str, Any], row: Any, res: ValidationResult) -> bool:
        """Create the scopes and run ``init_expr`` if that has not happened yet.

        Returns:
            ``True`` if initialisation was performed by this call.
        """
        if 'globals' in ctx:
            return False
        ctx['globals'] = {}
        ctx['locals'] = {}
        self._run_snippet(self.init_expr_code, ctx, row, res)
        return True

    def validate_table_row(self, ctx: Dict[str, Any], row: Any, table_configuration: TableConfiguration, res: ValidationResult) -> None:
        """Run ``init_expr`` for the first row and ``step_expr`` for later rows.

        Args:
            ctx: state dict; the evaluation scopes live under ``'globals'`` and ``'locals'``.
            row: the row, exposed to the snippet as ``columns``.
            table_configuration: unused by this validator.
            res: receives failed checks and evaluation errors.
        """
        if not self._ensure_initialized(ctx, row, res):
            self._run_snippet(self.step_expr_code, ctx, row, res)

    def validate_after_rows(self, ctx: Dict[str, Any], table_configuration: TableConfiguration, res: ValidationResult) -> None:
        """Run ``result_expr`` once all rows have been processed.

        If no row was processed (empty table) the scopes do not exist yet. They
        are then created and ``init_expr`` is run with ``columns = None``, so
        that ``result_expr`` sees initialised state instead of this method
        failing with ``KeyError`` on the missing context entries.

        Args:
            ctx: the same state dict that was passed to ``validate_table_row``.
            table_configuration: unused by this validator.
            res: receives failed checks and evaluation errors.
        """
        self._ensure_initialized(ctx, None, res)
        self._run_snippet(self.result_expr_code, ctx, None, res)


class SchemaValidator(BasePerRowValidator):
    """Per-row validator that checks rows against the table's configured columns.

    It is not bound to a particular table: ``table_subpath`` is ``None``.
    """

    def __init__(self, max_invalid_rows: int = DEFAULT_MAX_ERRORS_CNT):
        """Create the validator.

        Args:
            max_invalid_rows: passed to the base class as the error limit.
        """
        super().__init__(table_subpath=None, max_errors_cnt=max_invalid_rows)

    def validate_table_row(self, ctx: Dict[str, Any], row: Any, table_configuration: TableConfiguration, res: ValidationResult):
        """Report unknown, missing and badly-typed columns of ``row``.

        Args:
            ctx: unused by this validator.
            row: mapping from column name to value.
            table_configuration: supplies ``columns`` (name -> column config)
                and the ``allow_unknown_columns`` flag.
            res: receives one ``ValidationErrorWithSamples`` per problem, with
                the whole row as the sample.
        """
        reject_unknown = not table_configuration.allow_unknown_columns
        configured = table_configuration.columns

        if reject_unknown:
            for column_name in row:
                if column_name in configured:
                    continue
                res.add_error(ValidationErrorWithSamples(f'Column {column_name} is unknown', 1, [row]))

        for column_name, column in configured.items():
            if column_name in row:
                # The column type yields one message per problem with the value.
                for problem in column.type.validate_value(row[column_name]):
                    res.add_error(ValidationErrorWithSamples(f'Column {column_name}: ' + problem, 1, [row]))
            else:
                res.add_error(ValidationErrorWithSamples(f'Column {column_name} is not found', 1, [row]))


class DatasetDirectoryStructureValidator(BasePerVersionValidator):
    """Checks that a directory dataset holds exactly the configured tables."""

    def validate_version(self, yt_client: YtClient, dataset_version: DatasetVersion, res: ValidationResult):
        """Compare configured table subpaths with what ``list_recursive`` finds.

        Single-table datasets have no directory structure and are skipped.

        Args:
            yt_client: YT client passed to ``list_recursive``.
            dataset_version: version whose ``path`` is listed and whose
                ``dataset`` supplies the configured tables.
            res: receives one ``ValidationErrorWithCounts`` per missing and
                per unexpected subpath.

        Raises:
            Exception: if the dataset type is neither SINGLE_TABLE nor DIRECTORY.
        """
        dataset = dataset_version.dataset
        kind = dataset.dataset_type

        if kind == DatasetType.SINGLE_TABLE:
            return
        if kind != DatasetType.DIRECTORY:
            raise Exception(f'Unknown DatasetType: {kind}')

        expected = {table.subpath for table in dataset.tables}
        present = set(list_recursive(yt_client, dataset_version.path))

        for missing in expected - present:
            res.add_error(ValidationErrorWithCounts(f'Not found table "{missing}"', 1))
        for extra in present - expected:
            res.add_error(ValidationErrorWithCounts(f'Unexpected table "{extra}"', 1))


def _get_table_custom_validations(table: TableConfiguration) -> List[Tuple[str, BaseValidator]]:
    """Build the custom validators configured for one table.

    Args:
        table: table configuration supplying ``uniqueness_validations``,
            ``per_row_validations`` and ``per_table_validations``.

    Returns:
        ``(id, validator)`` pairs, where the id is
        ``'<table.get_id()>.<validation.id>'``. Uniqueness validators come
        first, then per-row, then per-table ones.
    """
    def scoped_id(validation) -> str:
        # Prefix with the table id so ids from different tables do not clash.
        return f'{table.get_id()}.{validation.id}'

    collected = []
    collected.extend(
        (scoped_id(item), CustomUniquenessValidator(table.subpath, item.key))
        for item in table.uniqueness_validations
    )
    collected.extend(
        (scoped_id(item), CustomPerRowValidator(table.subpath, item.is_valid_expr))
        for item in table.per_row_validations
    )
    collected.extend(
        (scoped_id(item), CustomPerTableValidator(table.subpath, item.init_expr, item.step_expr, item.result_expr))
        for item in table.per_table_validations
    )
    return collected


def get_dataset_custom_validations(dataset: Dataset) -> Dict[str, BaseValidator]:
    """Collect every custom validator configured for ``dataset``, keyed by id.

    Args:
        dataset: dataset configuration. A SINGLE_TABLE dataset contributes the
            validations of ``dataset.table``; a DIRECTORY dataset contributes
            those of every entry in ``dataset.tables``.

    Returns:
        Mapping from validation id (``'<table id>.<validation id>'``) to the
        validator instance.

    Raises:
        Exception: if ``dataset.dataset_type`` is neither SINGLE_TABLE nor DIRECTORY.
        ValueError: if two validations end up with the same id.
    """
    if dataset.dataset_type == DatasetType.SINGLE_TABLE:
        validations = _get_table_custom_validations(dataset.table)
    elif dataset.dataset_type == DatasetType.DIRECTORY:
        # Flatten in a single pass; ``sum(lists, [])`` would re-copy the
        # accumulated list once per table.
        validations = [
            pair
            for table in dataset.tables
            for pair in _get_table_custom_validations(table)
        ]
    else:
        raise Exception(f'Unknown DatasetType: {dataset.dataset_type}')

    res = {}
    for validation_id, validation in validations:
        if validation_id in res:
            raise ValueError(f'Duplicate validation id "{validation_id}" in dataset {dataset.name}')
        res[validation_id] = validation
    return res
