# -*- coding: utf-8 -*-

import logging
import re
import copy
import datetime
from typing import Dict, Any, List, Optional
from jinja2 import Environment, meta, Template, StrictUndefined

import yaml

from travel.hotels.tools.dataset_curator.type_schemas import BaseTypeSchema, AnyTypeSchema, PrimitiveTypeSchema, DictTypeSchema, StringTypeSchema, ListTypeSchema, StructTypeSchema
from travel.hotels.tools.dataset_curator.data import VersioningScheme, DatasetType, ColumnConfiguration, TableConfiguration, Dataset, UniquenessValidationConfiguration, \
    PerRowValidationConfiguration, PerTableValidationConfiguration, BuildConfiguration, BuilderConfiguration, BuilderType, YqlBuilderConfiguration, SbPlannerConfiguration, ENV_SUFFIXES, \
    CleanupConfiguration, CleanupAgeDetectionMode

LOG = logging.getLogger(__name__)


class ConfigParser:
    def __init__(self, env):
        self.env = env
        self.types: Dict[str, BaseTypeSchema] = {}

    def _replace_env_overrides(self, data, parent_key: str = None):
        if parent_key in ['sb_planner']:
            return data
        if type(data) is dict:
            if parent_key not in ['columns', 'fields']:
                all_suffixes = ENV_SUFFIXES
                suffix = f'_{self.env}'
                to_replace = set()
                to_del = set()
                for key, value in data.items():
                    if key.endswith(suffix):
                        to_replace.add(key)
                    for other_suffix in all_suffixes:
                        if key.endswith(other_suffix):
                            to_del.add(key)
                for key in to_replace:
                    data[key[:-len(suffix)]] = data[key]
                for key in to_del:
                    del data[key]
            return {k: self._replace_env_overrides(v, k) for k, v in data.items()}
        elif type(data) is list:
            return [self._replace_env_overrides(x) for x in data]
        else:
            return data

    def _init_default_types(self, template_vars):
        """Reset ``self.types`` to the built-in types only.

        A schema is registered for every name returned by
        ``PrimitiveTypeSchema.get_supported_type_names()`` plus ``'string'`` and ``'any'``;
        the registry key is the type name with its first letter upper-cased.

        :param template_vars: template variables forwarded to ``_parse_type_schema``
        """
        builtin_names = PrimitiveTypeSchema.get_supported_type_names() + ['string', 'any']
        registry = {}
        for name in builtin_names:
            alias = f'{name[0].upper()}{name[1:]}'
            registry[alias] = self._parse_type_schema({'primitive_type': name}, '', template_vars)
        self.types = registry

    def load_types(self, config_data):
        """Load named type definitions from a YAML document into ``self.types``.

        The registry is first reset to the built-in types, then every top-level entry
        ``<type name>: <type config>`` of the document is parsed with
        ``_parse_type_schema`` and registered under its name. Types are registered one
        by one in document order, so a type may only refer to types defined above it.

        :param config_data: YAML text (or stream) accepted by ``yaml.safe_load``
        :raises ValueError: if the document is not a mapping, or a type config is
            neither a string nor a mapping, or a type definition is invalid
        """
        template_vars = {}
        self._init_default_types(template_vars)
        types_config = yaml.safe_load(config_data)
        if types_config is None:
            # An empty YAML document is loaded as None: there are no custom types.
            return
        types_config = ConfigParser._expect_types(types_config, [dict], 'types')
        for type_name, type_config in types_config.items():
            LOG.debug('Loading type "%s"...', type_name)
            type_config = ConfigParser._expect_types(type_config, [str, dict], f'types.{type_name}')
            self.types[type_name] = self._parse_type_schema(type_config, f'types.{type_name}', template_vars)

    def parse_datasets(self, config_datas) -> Dict[str, Dataset]:
        """Parse dataset configs into ``Dataset`` objects keyed by dataset name.

        Every element of ``config_datas`` is a YAML document with a top-level mapping
        ``<dataset name>: <dataset config>``. The optional top-level key ``_vars`` holds
        template variables. A dataset name may be a Jinja2 template that refers to
        list-valued variables; such an entry produces one dataset per combination of
        the values of those variables. String fields of the config are rendered as
        Jinja2 templates with the variables of the dataset being built.

        If ``self.env`` is not ``None``, per-environment overrides are applied to every
        dataset config before parsing. A dataset parsed later replaces an earlier one
        with the same name.

        :param config_datas: iterable of YAML documents accepted by ``yaml.safe_load``
        :return: dict mapping the rendered dataset name to its ``Dataset``
        :raises ValueError: on missing required fields, fields of a wrong type, bad time
            intervals or inconsistent cleanup rules
        :raises Exception: if a dataset name refers to unknown or non-list variables,
            or the dataset type is not supported
        """
        seconds_per_unit = {
            'd': 24 * 60 * 60,
            'h': 60 * 60,
            'm': 60,
            's': 1,
        }

        def parse_time_interval(value: str) -> datetime.timedelta:
            # Format: a non-negative integer followed by one unit letter, e.g. '12h'.
            match = re.fullmatch(r'(\d+)([dmhs])', value)
            if not match:
                raise ValueError(f'Can\'t parse time interval: {value}')
            return datetime.timedelta(seconds=int(match.group(1)) * seconds_per_unit[match.group(2)])

        def parse_column(data, path_for_log, template_vars) -> ColumnConfiguration:
            # _get_field() takes path_for_log before template_vars; with the two swapped,
            # a type error is reported with the vars dict in place of the config path.
            type_config = ConfigParser._get_field(data, 'type', [dict, str], path_for_log, template_vars)
            description = ConfigParser._get_field(data, 'description', str, path_for_log, template_vars)
            return ColumnConfiguration(self._parse_type_schema(type_config, f'{path_for_log}.type', template_vars), description)

        def parse_uniqueness_validations(data, path_for_log, template_vars) -> UniquenessValidationConfiguration:
            validation_id = ConfigParser._get_field(data, 'id', str, path_for_log, template_vars)
            key = ConfigParser._get_list_of_type(data, 'key', str, path_for_log, template_vars)
            return UniquenessValidationConfiguration(validation_id, key)

        def parse_per_row_validations(data, path_for_log, template_vars) -> PerRowValidationConfiguration:
            validation_id = ConfigParser._get_field(data, 'id', str, path_for_log, template_vars)
            is_valid_expr = ConfigParser._get_field(data, 'is_valid_expr', str, path_for_log, template_vars)
            return PerRowValidationConfiguration(validation_id, is_valid_expr)

        def parse_per_table_validations(data, path_for_log, template_vars) -> PerTableValidationConfiguration:
            validation_id = ConfigParser._get_field(data, 'id', str, path_for_log, template_vars)
            init_expr = ConfigParser._get_field(data, 'init_expr', str, path_for_log, template_vars)
            step_expr = ConfigParser._get_field(data, 'step_expr', str, path_for_log, template_vars)
            result_expr = ConfigParser._get_field(data, 'result_expr', str, path_for_log, template_vars)
            return PerTableValidationConfiguration(validation_id, init_expr, step_expr, result_expr)

        def parse_table(data, single, path_for_log, template_vars) -> TableConfiguration:
            # 'subpath' and 'ignored' are only read for tables of a directory dataset.
            subpath = ConfigParser._get_field(data, 'subpath', str, path_for_log, template_vars) if not single else None
            ignored = ConfigParser._get_field(data, 'ignored', bool, path_for_log, template_vars, default=False) if not single else False
            allow_unknown_columns = ConfigParser._get_field(data, 'allow_unknown_columns', bool, path_for_log, template_vars, default=False)
            columns = None
            if not ignored:
                raw_columns = ConfigParser._get_field(data, 'columns', dict, path_for_log, template_vars)
                columns = {k: parse_column(v, f'{path_for_log}.columns["{k}"]', template_vars=template_vars) for k, v in raw_columns.items()}

            uniqueness_validations = [parse_uniqueness_validations(x, f'{path_for_log}.uniqueness_validations', template_vars=template_vars) for x in
                                      ConfigParser._get_list_of_type(data, 'uniqueness_validations', dict, path_for_log, template_vars, default=[])]
            per_row_validations = [parse_per_row_validations(x, f'{path_for_log}.per_row_validations', template_vars=template_vars) for x in
                                   ConfigParser._get_list_of_type(data, 'per_row_validations', dict, path_for_log, template_vars, default=[])]
            per_table_validations = [parse_per_table_validations(x, f'{path_for_log}.per_table_validations', template_vars=template_vars) for x in
                                     ConfigParser._get_list_of_type(data, 'per_table_validations', dict, path_for_log, template_vars, default=[])]

            return TableConfiguration(allow_unknown_columns, columns, subpath, ignored, uniqueness_validations, per_row_validations, per_table_validations)

        def parse_builder(data, path_for_log, template_vars) -> Optional[BuilderConfiguration]:
            if data is None:
                return None

            builder_type = BuilderType(ConfigParser._get_field(data, 'type', str, path_for_log, template_vars))
            versioned_process = ConfigParser._get_field(data, 'versioned_process', bool, path_for_log, template_vars, default=False)
            create_latest = ConfigParser._get_field(data, 'create_latest', bool, path_for_log, template_vars, default=False)
            transfer_results = ConfigParser._get_field(data, 'transfer_results', bool, path_for_log, template_vars, default=False)
            skip_empty = ConfigParser._get_field(data, 'skip_empty', bool, path_for_log, template_vars, default=False)
            skip_unexistent = ConfigParser._get_field(data, 'skip_unexistent', bool, path_for_log, template_vars, default=False)

            # Only the YQL builder has a builder-specific section.
            specific_builder_configuration = None
            if builder_type == BuilderType.YQL:
                query_path = ConfigParser._get_field(data, 'query_path', str, path_for_log, template_vars)
                yql_params = ConfigParser._get_field(data, 'yql_params', dict, path_for_log, template_vars)
                specific_builder_configuration = YqlBuilderConfiguration(query_path, yql_params)

            return BuilderConfiguration(builder_type, versioned_process, create_latest, transfer_results, skip_empty, skip_unexistent, specific_builder_configuration)

        def fill_template_vars(data, template_vars):
            # Recursively render every string (dict keys included) as a Jinja2 template;
            # values of any other type are returned unchanged.
            if isinstance(data, str):
                return Template(data, undefined=StrictUndefined).render(template_vars)
            if isinstance(data, list):
                return [fill_template_vars(x, template_vars) for x in data]
            if isinstance(data, dict):
                return {fill_template_vars(k, template_vars): fill_template_vars(v, template_vars) for k, v in data.items()}
            return data

        def parse_sb_planner(data, path_for_log, template_vars) -> Optional[SbPlannerConfiguration]:
            if data is None:
                return None

            raw_planner_args = ConfigParser._get_field(data, 'planner_args', dict, path_for_log, template_vars, required=False)
            # 'builder_args' is kept per environment suffix ('' stands for the unsuffixed key);
            # _replace_env_overrides() leaves the 'sb_planner' subtree untouched for that reason.
            raw_builder_args = {
                suffix: ConfigParser._get_field(data, f'builder_args{suffix}', dict, path_for_log, template_vars, required=False)
                for suffix in ['', *ENV_SUFFIXES]
            }

            sb_planner_args = fill_template_vars(raw_planner_args, template_vars)
            builder_args_per_suffix = {
                suffix: fill_template_vars(args, template_vars)
                for suffix, args in raw_builder_args.items()
                if args is not None
            }
            return SbPlannerConfiguration(sb_planner_args, builder_args_per_suffix)

        def parse_build(data, path_for_log, template_vars) -> Optional[BuildConfiguration]:
            if data is None:
                return None

            builder = parse_builder(ConfigParser._get_field(data, 'builder', dict, path_for_log, template_vars, required=False), f'{path_for_log}.builder', template_vars=template_vars)
            sb_planner = parse_sb_planner(ConfigParser._get_field(data, 'sb_planner', dict, path_for_log, template_vars, required=False), f'{path_for_log}.sb_planner', template_vars=template_vars)

            return BuildConfiguration(builder, sb_planner)

        def parse_cleanup(data, path_for_log, template_vars) -> CleanupConfiguration:
            age_detection_mode = CleanupAgeDetectionMode(ConfigParser._get_field(data, 'age_detection_mode', str, path_for_log, template_vars, default=CleanupAgeDetectionMode.BY_NAME.value))
            yt_clusters = ConfigParser._get_list_of_type(data, 'yt_clusters', str, path_for_log, template_vars, required=False)
            subpaths = ConfigParser._get_list_of_type(data, 'subpaths', str, path_for_log, template_vars, required=False)
            raw_keep_newer_than = ConfigParser._get_field(data, 'keep_newer_than', str, path_for_log, template_vars, required=False)
            keep_newer_than = parse_time_interval(raw_keep_newer_than) if raw_keep_newer_than is not None else None
            keep_last = ConfigParser._get_field(data, 'keep_last', int, path_for_log, template_vars, default=1)
            return CleanupConfiguration(age_detection_mode, yt_clusters, subpaths, keep_newer_than, keep_last)

        def parse_cleanup_rules(data, path_for_log, template_vars) -> List[CleanupConfiguration]:
            if data is None:
                return []
            return [parse_cleanup(rule, f'{path_for_log}[{i}]', template_vars) for i, rule in enumerate(data)]

        def validate_cleanup_rules(cleanup_rules, tables, yt_clusters, path_for_log):
            # 'tables' is None for a single table dataset; such a dataset has no subpaths.
            if tables is None:
                for i, rule in enumerate(cleanup_rules):
                    if rule.subpaths is not None:
                        raise ValueError(f'{path_for_log}[{i}].subpaths: Can\'t use cleanup_rules with subpaths for single table dataset')
            else:
                allowed_subpaths = {x.subpath for x in tables}
                for i, rule in enumerate(cleanup_rules):
                    if rule.subpaths is None:
                        continue
                    unknown_subpaths = [x for x in rule.subpaths if x not in allowed_subpaths]
                    if unknown_subpaths:
                        raise ValueError(f'{path_for_log}[{i}].subpaths: Unknown subpaths: {unknown_subpaths}')

            # A rule may only refer to clusters listed for the dataset itself.
            for i, rule in enumerate(cleanup_rules):
                if rule.yt_clusters is None:
                    continue
                unknown_yt_clusters = [x for x in rule.yt_clusters if x not in yt_clusters]
                if unknown_yt_clusters:
                    raise ValueError(f'{path_for_log}[{i}].yt_clusters: Unknown yt_clusters: {unknown_yt_clusters}')

        def parse_dataset(dataset_name: str, dataset_config: Any, path_for_log: str, template_vars: dict) -> Dataset:
            owner = ConfigParser._get_field(dataset_config, 'owner', str, path_for_log, template_vars)
            yt_path = ConfigParser._get_field(dataset_config, 'yt_path', str, path_for_log, template_vars)
            yt_clusters = ConfigParser._get_list_of_type(dataset_config, 'yt_clusters', str, path_for_log, template_vars)
            dataset_type = DatasetType(ConfigParser._get_field(dataset_config, 'dataset_type', str, path_for_log, template_vars))
            versioning_scheme = VersioningScheme(ConfigParser._get_field(dataset_config, 'versioning_scheme', str, path_for_log, template_vars))
            raw_max_age = ConfigParser._get_field(dataset_config, 'max_age', str, path_for_log, template_vars, required=False)
            max_age = parse_time_interval(raw_max_age) if raw_max_age is not None else None
            # Exactly one of 'table' / 'tables' is filled, depending on the dataset type.
            table = None
            tables = None
            if dataset_type == DatasetType.SINGLE_TABLE:
                table = parse_table(ConfigParser._get_field(dataset_config, 'table', dict, path_for_log, template_vars), single=True, path_for_log=f'{path_for_log}.table', template_vars=template_vars)
            elif dataset_type == DatasetType.DIRECTORY:
                tables = [parse_table(x, single=False, path_for_log=f'{path_for_log}.tables[{i}]', template_vars=template_vars)
                          for i, x in enumerate(ConfigParser._get_field(dataset_config, 'tables', list, path_for_log, template_vars))]
            else:
                raise Exception(f'{path_for_log}: Unknown DatasetType: {dataset_type}')

            build = parse_build(ConfigParser._get_field(dataset_config, 'build', dict, path_for_log, template_vars, required=False), f'{path_for_log}.build', template_vars=template_vars)

            cleanup_rules = parse_cleanup_rules(ConfigParser._get_list_of_type(dataset_config, 'cleanup_rules', dict, path_for_log, template_vars, required=False), f'{path_for_log}.cleanup_rules',
                                                template_vars=template_vars)
            validate_cleanup_rules(cleanup_rules, tables, yt_clusters, f'{path_for_log}.cleanup_rules')
            return Dataset(dataset_name, owner, yt_path, yt_clusters, versioning_scheme, max_age, dataset_type, table, tables, build, template_vars, cleanup_rules)

        def iterate_vars_values(vars_dict_items):
            # Yield a dict for every combination of values of the given (name, values) pairs.
            if len(vars_dict_items) == 0:
                yield {}
            else:
                key, values = vars_dict_items[0]
                for d in iterate_vars_values(vars_dict_items[1:]):
                    for value in values:
                        yield dict(d, **{key: value})

        name_env = Environment()

        def get_vars_in_string(value, known_vars):
            # Names of the template variables used in 'value'; all of them must be
            # known and list-valued, because the name is expanded over their values.
            vars_in_name = meta.find_undeclared_variables(name_env.parse(value))
            unknown_vars_in_name = [x for x in vars_in_name if x not in known_vars]
            if unknown_vars_in_name:
                raise Exception(f'Unknown vars in name {value} ({unknown_vars_in_name})')
            non_list_vars_in_name = [x for x in vars_in_name if not isinstance(known_vars[x], list)]
            if non_list_vars_in_name:
                raise Exception(f'Non-list vars in name {value} ({non_list_vars_in_name})')
            return vars_in_name

        datasets = {}
        for config_data in config_datas:
            # An empty YAML document is loaded as None; treat it as a config without datasets.
            datasets_config = yaml.safe_load(config_data) or {}
            # Same for an explicitly empty '_vars' section.
            config_vars = self._replace_env_overrides(datasets_config.pop('_vars', None) or {})
            for dataset_name, dataset_config in datasets_config.items():
                if self.env is not None:
                    dataset_config = self._replace_env_overrides(dataset_config)
                # The variable names come from a set; sort them so that the expansion
                # order (and hence the order of the result) does not depend on hashing.
                vars_in_string = [(k, config_vars[k]) for k in sorted(get_vars_in_string(dataset_name, config_vars))]
                for vars_values in iterate_vars_values(vars_in_string):
                    true_dataset_name = Template(dataset_name, undefined=StrictUndefined).render(vars_values)
                    LOG.debug('Loading config of "%s"...', true_dataset_name)
                    # Every dataset gets its own copy of the vars, with the list-valued vars
                    # used in its name replaced by the values chosen for this dataset.
                    curr_vars = dict(config_vars)
                    curr_vars.update(vars_values)
                    datasets[true_dataset_name] = parse_dataset(true_dataset_name, dataset_config, f'datasets.{true_dataset_name}', curr_vars)
        return datasets

    def _parse_type_schema(self, data, path_for_log, template_vars) -> BaseTypeSchema:
        if type(data) == str:
            optional_match = re.fullmatch('Optional<(.*)>', data)
            if optional_match is not None:
                inner_type = copy.deepcopy(self._parse_type_schema(optional_match.group(1), path_for_log, template_vars))
                inner_type.nullable = True
                return inner_type

            list_match = re.fullmatch('List<(.*)>', data)
            if list_match is not None:
                return ListTypeSchema(False, None, self._parse_type_schema(list_match.group(1), path_for_log, template_vars))

            dict_match = re.fullmatch('Dict<([^,]+),\\s*([^,]+)>', data)
            if dict_match is not None:
                return DictTypeSchema(False, None, self._parse_type_schema(dict_match.group(1), path_for_log, template_vars), self._parse_type_schema(dict_match.group(2), path_for_log, template_vars))

            if data not in self.types:
                raise ValueError(f'{path_for_log}: Unknown type "{data}"')
            return self.types[data]

        if type(data) != dict:
            raise ValueError(f'Expected str or dict in _parse_type_schema, got {type(data)}')

        def parse_primitive_type_schema(c, template_vars):
            return PrimitiveTypeSchema(
                ConfigParser._get_field(c, 'nullable', bool, path_for_log, template_vars, default=False),
                ConfigParser._get_field(c, 'description', str, path_for_log, template_vars, required=False),
                ConfigParser._get_field(c, 'primitive_type', str, path_for_log, template_vars)
            )

        def parse_dict_type_schema(c, template_vars):
            keys = ConfigParser._get_field(c, 'keys', [dict, str], template_vars, path_for_log, required=False)
            values = ConfigParser._get_field(c, 'values', [dict, str], template_vars, path_for_log, required=False)
            return DictTypeSchema(
                ConfigParser._get_field(c, 'nullable', bool, path_for_log, template_vars, default=False),
                ConfigParser._get_field(c, 'description', str, path_for_log, template_vars, required=False),
                self._parse_type_schema(keys, f'{path_for_log}.keys', template_vars) if keys is not None else None,
                self._parse_type_schema(values, f'{path_for_log}.elements', template_vars) if values is not None else None
            )

        def parse_struct_type_schema(c, template_vars):
            fields = {}
            for k, v in ConfigParser._get_field(c, 'fields', dict, path_for_log, template_vars).items():
                subpath_for_log = f'{path_for_log}.fields["{k}"]'
                k = self._expect_types(k, [str], f'{subpath_for_log}{{key}}')
                v = self._expect_types(v, [dict, str], subpath_for_log)
                fields[k] = self._parse_type_schema(v, subpath_for_log, template_vars)
            return StructTypeSchema(
                ConfigParser._get_field(c, 'nullable', bool, path_for_log, template_vars, default=False),
                ConfigParser._get_field(c, 'description', str, path_for_log, template_vars, required=False),
                ConfigParser._get_field(c, 'allow_unknown_fields', bool, path_for_log, template_vars, default=False),
                fields
            )

        schema_type = {
            'int32': parse_primitive_type_schema,
            'uint32': parse_primitive_type_schema,
            'int64': parse_primitive_type_schema,
            'uint64': parse_primitive_type_schema,
            'string': lambda c, _: StringTypeSchema(
                ConfigParser._get_field(c, 'nullable', bool, path_for_log, template_vars, default=False),
                ConfigParser._get_field(c, 'description', str, path_for_log, template_vars, required=False),
                ConfigParser._get_list_of_type(c, 'values', str, path_for_log, template_vars, required=False),
            ),
            'boolean': parse_primitive_type_schema,
            'double': parse_primitive_type_schema,
            'list': lambda c, _: ListTypeSchema(
                ConfigParser._get_field(c, 'nullable', bool, path_for_log, template_vars, default=False),
                ConfigParser._get_field(c, 'description', str, path_for_log, template_vars, required=False),
                self._parse_type_schema(ConfigParser._get_field(c, 'elements', [dict, str], template_vars, path_for_log), f'{path_for_log}.elements', template_vars)
            ),
            'dict': parse_dict_type_schema,
            'struct': parse_struct_type_schema,
            'any': lambda c, _: AnyTypeSchema(
                ConfigParser._get_field(c, 'nullable', bool, path_for_log, template_vars, default=False),
                ConfigParser._get_field(c, 'description', str, path_for_log, template_vars, required=False),
            ),
        }[ConfigParser._get_field(data, 'primitive_type', str, path_for_log, template_vars)]
        return schema_type(data, template_vars)

    @staticmethod
    def _expect_types(value: Any, expected_types: List[type], path_for_log: str):
        if type(value) not in expected_types:
            raise ValueError(f'{path_for_log}: Expected one of {expected_types}, but got {type(value)}')
        return value

    @staticmethod
    def _get_list_of_type(data, list_name, element_type, path_for_log, template_vars, default=None, required=True):
        l = ConfigParser._get_field(data, list_name, list, path_for_log, template_vars, default=default, required=required)
        if l is None and not required:
            return None
        return [ConfigParser._expect_types(x, [element_type], f'{path_for_log}.{list_name}[{i}]') for i, x in enumerate(l)]

    @staticmethod
    def _get_field(data, field_name, field_type, path_for_log, template_vars, default=None, required=True):
        if type(field_type) != type and (type(field_type) != list or any([type(x) != type for x in field_type])):
            raise ValueError('Expected type or List[type] in field_type argument')
        if type(field_type) == type:
            field_type = [field_type]
        if field_name not in data:
            if default is not None:
                if type(default) not in field_type:
                    raise ValueError(f'{path_for_log}.{field_name}: expected types: {field_type}, but default value has type {type(default)}')
                return default
            if required:
                raise ValueError(f'{path_for_log}: Required field "{field_name}" not found')
            else:
                return None
        value = data[field_name]
        if field_type == [str]:
            value = Template(value, undefined=StrictUndefined).render(**template_vars)
        return ConfigParser._expect_types(value, field_type, f'{path_for_log}.{field_name}')
