# -*- coding: utf-8 -*-
from __future__ import annotations

from argparse import Namespace
from collections import OrderedDict
from copy import copy
from datetime import datetime, timedelta, timezone
from enum import Enum
from inspect import isclass
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, get_args, get_origin, get_type_hints
import dataclasses
import json
import os
import types

from travel.library.python import yandex_vault


# Type variable for the dataclass type accepted and returned by dc_from_dict.
D = TypeVar('D', bound='dataclasses.dataclass')
# Element type produced by the optional converter of str_to_set.
T = TypeVar('T')


# Python type name (as returned by get_field_type) -> YT schema column type.
DC_YT_TYPE_MAP = dict(
    str='string',
    int='int64',
    float='double',
    uint='uint64',
    bool='boolean',
    dict='any',
    list='any',
)

# Python type name (as returned by get_field_type) -> YQL column type.
DC_YQL_TYPE_MAP = dict(
    str='String',
    int='Int64',
    uint='UInt64',
    float='Double',
    bool='Bool',
    dict='Any',
    list='Any',
)


def get_field_type(field_type: Type) -> str:
    """Reduce a field annotation to the bare type name used as a key in the DC_*_TYPE_MAP dicts.

    ``Optional[X]`` (also ``Union[None, X]`` and ``X | None``) is unwrapped to ``X``,
    parameterized generics such as ``List[int]`` are reduced to their origin
    (``list``), and Enum subclasses are reported as ``'str'``.

    :param field_type: a resolved type annotation (e.g. a value from ``get_type_hints``).
    :return: the ``__name__`` of the reduced type, or ``'str'`` for enums.
    """
    # Only treat a genuine two-member Union containing None as Optional. The old
    # structural test called issubclass() on the second type argument, which raised
    # TypeError for e.g. Dict[str, List[int]] (List[int] is not a class).
    union_origin = get_origin(field_type)
    if union_origin is Union or union_origin is getattr(types, 'UnionType', Union):
        args = get_args(field_type)
        if len(args) == 2 and type(None) in args:
            field_type = args[0] if args[1] is type(None) else args[1]

    origin = get_origin(field_type)
    if origin is not None:
        field_type = origin
    if isclass(field_type) and issubclass(field_type, Enum):
        return 'str'
    return field_type.__name__


def get_dc_key_fields(dc: dataclasses.dataclass) -> List[str]:
    """Return, in declaration order, the names of fields whose metadata has a truthy ``'is_key'``."""
    return [
        dc_field.name
        for dc_field in dataclasses.fields(dc)
        if dc_field.metadata.get('is_key')
    ]


def _is_dc_field_type(field_type: Any) -> bool:
    """Return True if ``field_type`` is a dataclass or ``Optional`` of a dataclass."""
    if dataclasses.is_dataclass(field_type):
        return True
    origin = get_origin(field_type)
    if origin is Union or origin is getattr(types, 'UnionType', Union):
        args = get_args(field_type)
        return len(args) == 2 and type(None) in args and any(dataclasses.is_dataclass(arg) for arg in args)
    return False


def _get_dc_schema(dc: Any, type_map: Dict[str, str], nested_dc_type: str) -> Dict[str, str]:
    """Build an ordered ``{field_name: column_type}`` schema for ``dc``.

    Column types are looked up in ``type_map`` by the name returned from
    ``get_field_type``; nested dataclasses (optionally wrapped in ``Optional``) map to
    ``nested_dc_type``.

    :raises KeyError: if a field's type name is missing from ``type_map``.
    """
    hints = get_type_hints(dc)
    # For dataclasses, iterate real fields only: get_type_hints also reports ClassVar
    # annotations, which are not columns (and would not be found in type_map).
    if dataclasses.is_dataclass(dc):
        names = [dc_field.name for dc_field in dataclasses.fields(dc)]
    else:
        names = list(hints)

    schema = OrderedDict()
    for field_name in names:
        field_type = hints[field_name]
        if _is_dc_field_type(field_type):
            schema[field_name] = nested_dc_type
        else:
            schema[field_name] = type_map[get_field_type(field_type)]
    return schema


def get_dc_yt_schema(dc: dataclasses.dataclass) -> Dict[str, str]:
    """Return an ordered ``{field_name: YT type}`` schema for dataclass ``dc``.

    Nested (optionally ``Optional``) dataclasses become ``'any'``.

    :raises KeyError: for a field type not present in ``DC_YT_TYPE_MAP``.
    """
    return _get_dc_schema(dc, DC_YT_TYPE_MAP, 'any')


def get_dc_yql_schema(dc: dataclasses.dataclass) -> Dict[str, str]:
    """Return an ordered ``{field_name: YQL type}`` schema for dataclass ``dc``.

    Nested (optionally ``Optional``) dataclasses become ``'Any'``.

    :raises KeyError: for a field type not present in ``DC_YQL_TYPE_MAP``.
    """
    return _get_dc_schema(dc, DC_YQL_TYPE_MAP, 'Any')


def get_dc_defaults(dc: dataclasses.dataclass) -> Dict[str, Any]:
    """Return the field values of a default-constructed ``dc``, serialized via ``dc_to_dict``.

    ``dc`` is called without arguments, so every ``__init__`` field must have a default.
    The top-level mapping is an ``OrderedDict``.
    """
    default_instance = dc()
    return dc_to_dict(default_instance, dict_factory=OrderedDict)


def dc_from_dict(
        dc_cls: Type[D],
        d: Dict[str, Any],
        ignore_unknown: bool = False,
) -> D:
    """Build a ``dc_cls`` instance from a plain dict, recursively converting nested values.

    Conversion is driven by the field type annotations:

    - nested dataclasses are built from dicts;
    - Enum fields are constructed from their raw values;
    - ``list``/``tuple``/``set`` items and ``dict`` values are converted element-wise;
    - ``Optional[...]`` fields accept ``None``.

    This applies at any nesting depth. A field's ``metadata['converter']`` callable, if
    present, is applied to the raw value before the type-driven conversion.

    If ``dc_cls`` is not a dataclass, an Enum class is called with ``d`` and any other
    type returns ``d`` unchanged.

    :param dc_cls: target dataclass type.
    :param d: source mapping; fields absent from it are left to the dataclass defaults.
    :param ignore_unknown: when False, keys of ``d`` that are not fields of ``dc_cls``
        raise ``TypeError``.
    :raises TypeError: on unknown keys (unless ``ignore_unknown``), or from the
        ``dc_cls`` constructor when a required field is missing.
    """
    if not dataclasses.is_dataclass(dc_cls):
        if isclass(dc_cls) and issubclass(dc_cls, Enum):
            # noinspection PyArgumentList
            return dc_cls(d)
        return d

    type_hints = _resolve_type_hints(dc_cls)
    dc_dict = dict()

    for f in dataclasses.fields(dc_cls):
        if f.name not in d:
            continue

        value = d[f.name]
        converter = f.metadata.get('converter')
        if converter:
            value = converter(value)

        field_type = type_hints.get(f.name, f.type)
        dc_dict[f.name] = _convert_value(field_type, value, ignore_unknown)

    if not ignore_unknown:
        unknown_keys = set(d.keys()) - set(dc_dict.keys())
        if unknown_keys:
            raise TypeError(f'Dict for {dc_cls} has unknown keys: {unknown_keys}')

    return dc_cls(**dc_dict)


def _resolve_type_hints(dc_cls: type) -> Dict[str, Any]:
    """Return the evaluated annotations of ``dc_cls``, or ``{}`` if they cannot be evaluated.

    Needed because under ``from __future__ import annotations`` ``Field.type`` is a
    plain string, which the type-driven conversion cannot interpret.
    """
    try:
        return get_type_hints(dc_cls)
    except (NameError, TypeError, AttributeError):
        # Unresolvable forward references (or ``X | Y`` on old Pythons): fall back to
        # the raw ``Field.type`` values.
        return {}


def _unwrap_optional(field_type: Any) -> Tuple[Any, bool]:
    """Split ``Optional[X]`` (also ``Union[None, X]`` / ``X | None``) into ``(X, True)``.

    Any other type is returned as ``(field_type, False)``.
    """
    origin = get_origin(field_type)
    if origin is Union or origin is getattr(types, 'UnionType', Union):
        args = get_args(field_type)
        if len(args) == 2 and type(None) in args:
            return (args[0] if args[1] is type(None) else args[1]), True
    return field_type, False


def _convert_value(field_type: Any, value: Any, ignore_unknown: bool) -> Any:
    """Convert a raw ``value`` to match ``field_type``, recursing into containers."""
    field_type, is_optional = _unwrap_optional(field_type)
    if is_optional and value is None:
        return None

    if isclass(field_type) and issubclass(field_type, Enum):
        # noinspection PyArgumentList
        return field_type(value)

    origin = get_origin(field_type)
    if origin and isclass(origin):
        args = get_args(field_type)
        if issubclass(origin, (list, tuple, set)):
            # Every item is converted with the first type argument (as before).
            item_type = args[0] if args else Any
            return origin(_convert_value(item_type, item, ignore_unknown) for item in value)
        if issubclass(origin, dict):
            value_type = args[1] if len(args) > 1 else Any
            return {k: _convert_value(value_type, v, ignore_unknown) for k, v in value.items()}
        return value

    if isclass(field_type) and dataclasses.is_dataclass(field_type):
        if isinstance(value, field_type):
            # Already an instance (e.g. produced by a metadata converter).
            return value
        return dc_from_dict(field_type, value, ignore_unknown)

    return value


def dc_to_dict(dc: dataclasses.dataclass, dict_factory: Optional[Type[dict]] = None) -> Dict[str, Any]:
    """Serialize a dataclass instance into a mapping of plain values.

    Conversion rules:

    - nested dataclass instances become mappings built with the same ``dict_factory``;
    - Enum members are replaced by their ``.value``;
    - list and tuple items and dict values are converted recursively.

    Other values (including sets) are returned as-is.

    :param dc: dataclass instance to serialize.
    :param dict_factory: mapping type used for every dataclass produced (default ``dict``).
    """
    if dict_factory is None:
        dict_factory = dict
    result = dict_factory()
    for field in dataclasses.fields(dc):
        result[field.name] = _get_field_value(getattr(dc, field.name), dict_factory)
    return result


def _get_field_value(field_value: Any, dict_factory: Optional[Type[dict]] = None) -> Any:
    """Convert a single field value according to the rules documented on ``dc_to_dict``."""
    # Dataclass *classes* also satisfy is_dataclass(); only instances are serialized.
    if dataclasses.is_dataclass(field_value) and not isinstance(field_value, type):
        return dc_to_dict(field_value, dict_factory)
    if isinstance(field_value, Enum):
        return field_value.value
    if isinstance(field_value, list):
        return [_get_field_value(item, dict_factory) for item in field_value]
    if type(field_value) is tuple:
        # Exact tuples only: subclasses such as namedtuples take positional constructor args.
        return tuple(_get_field_value(item, dict_factory) for item in field_value)
    if isinstance(field_value, dict):
        return {key: _get_field_value(value, dict_factory) for key, value in field_value.items()}
    # Sets are left untouched: converting dataclass items to dicts would make them unhashable.
    return field_value


def get_modified_dc_node(node_name: Optional[str], node_value: Any, modifier: Callable[[str], str]) -> Any:
    """Return a copy of ``node_value`` with ``modifier`` applied to every secret string.

    Dataclass instances (field by field), lists/sets/tuples and dicts are walked
    recursively. A string leaf is passed through ``modifier`` when
    ``is_secret(name, value)`` is true. ``name`` is the field name or dict key
    holding the string.

    Sequence items have no name of their own, so an empty name is used for them and
    only the value is checked. Other values are returned unchanged. The input is not
    mutated; dataclasses are shallow-copied before their fields are replaced.

    :param node_name: field name or dict key under which ``node_value`` is stored;
        ``None`` at the root and for sequence items.
    :param node_value: value to walk.
    :param modifier: transformation applied to secret strings.
    """
    if dataclasses.is_dataclass(node_value) and not isinstance(node_value, type):
        result = copy(node_value)
        for field in dataclasses.fields(node_value):
            new_value = get_modified_dc_node(field.name, getattr(node_value, field.name), modifier)
            try:
                setattr(result, field.name, new_value)
            except dataclasses.FrozenInstanceError:
                # Frozen dataclasses reject setattr; ``result`` is our private copy,
                # so writing through object.__setattr__ leaves the input untouched.
                object.__setattr__(result, field.name, new_value)
        return result
    if isinstance(node_value, (list, set, tuple)):
        items = [get_modified_dc_node(None, item, modifier) for item in node_value]
        if isinstance(node_value, tuple) and hasattr(node_value, '_fields'):
            # namedtuple constructors take the items as positional arguments.
            return type(node_value)(*items)
        return type(node_value)(items)
    if isinstance(node_value, dict):
        return {key: get_modified_dc_node(key, value, modifier) for key, value in node_value.items()}
    # Sequence items (None) and non-string dict keys carry no usable name; passing
    # them to is_secret() made its substring test raise TypeError.
    name = node_name if isinstance(node_name, str) else ''
    if isinstance(node_value, str) and is_secret(name, node_value):
        return modifier(node_value)
    return node_value


def hide_secrets_dc(options: dataclasses.dataclass) -> dataclasses.dataclass:
    """Return a copy of ``options`` where every secret string keeps only its last 4 characters.

    Secrets are masked as ``'******'`` followed by the final four characters.
    """
    def _mask(secret: str) -> str:
        return '******' + secret[-4:]

    return get_modified_dc_node(None, options, _mask)


def hide_secrets_ns(options: Namespace) -> Namespace:
    """Return a new Namespace in which secret-looking string values are masked.

    A string value is masked (``'******'`` plus its last 4 characters) when
    ``is_secret(key, value)`` is true. Non-string values are copied unchanged. The
    previous ``value[-4:]`` raised ``TypeError`` for e.g. ints and bools under a
    secret-like name, and produced a stringified slice for lists.

    :param options: parsed arguments; not modified.
    """
    result = dict()
    for key, value in vars(options).items():
        if isinstance(value, str) and is_secret(key, value):
            value = f'******{value[-4:]}'
        result[key] = value
    return Namespace(**result)


def is_secret(k: str, v: Any) -> bool:
    """Heuristically decide whether the value ``v`` stored under name ``k`` is a secret.

    Returns True if ``v`` is a string starting with ``'AQAD-'``. That prefix is
    presumably the token format of the vault/OAuth service; confirm.

    Also returns True if ``k`` contains (case-sensitively) any of ``'auth'``,
    ``'password'``, ``'secret'``, ``'token'``.

    A non-string ``k`` (e.g. ``None`` for list items or an int dict key) is judged by
    the value alone instead of raising ``TypeError``.
    """
    stopwords = ('auth', 'password', 'secret', 'token')

    if isinstance(v, str) and v.startswith('AQAD-'):
        return True
    if not isinstance(k, str):
        return False
    return any(sw in k for sw in stopwords)


def resolve_secrets_dc(options: dataclasses.dataclass, vault_token: str) -> dataclasses.dataclass:
    """Return a copy of ``options`` with secret-looking strings resolved to real values.

    The resolution source depends on ``vault_token``:

    - With a non-empty ``vault_token``, each secret string is passed to
      ``yandex_vault.get_filled_arg`` using a Production vault client.
    - Otherwise, if the ``ARG_REPLACES`` environment variable is set, it is parsed as
      JSON. Secret strings that appear as keys of that JSON object are replaced by the
      mapped value.
    - If neither applies, ``options`` is returned as-is (not copied).
    """
    if not vault_token:
        raw_replaces = os.environ.get('ARG_REPLACES')
        if raw_replaces is None:
            return options
        mapping = json.loads(raw_replaces)
        return get_modified_dc_node(None, options, lambda value: mapping.get(value, value))

    client = yandex_vault.instances.Production(authorization=vault_token)
    return get_modified_dc_node(None, options, lambda value: yandex_vault.get_filled_arg(value, client))


def str_from_set(items: Set[str], converter: Optional[Callable[[Any], str]] = None) -> str:
    """Join ``items`` into a sorted, comma-separated string.

    :param items: values to join.
    :param converter: optional callable applied to each item before sorting; when
        omitted (or falsy) the items are used as-is and must already be strings.
    """
    if converter:
        rendered = [converter(item) for item in items]
    else:
        rendered = list(items)
    rendered.sort()
    return ','.join(rendered)


def str_to_set(s: str, converter: Optional[Callable[[str], T]] = None) -> Set[T]:
    """Split a comma-separated string into a set, optionally converting each part.

    An empty string (or any falsy ``s``) yields an empty set. Parts are not stripped.

    :param converter: optional callable applied to every part; parts are kept as
        strings when it is omitted (or falsy).
    """
    if not s:
        return set()
    parts = s.split(',')
    if converter:
        return {converter(part) for part in parts}
    return set(parts)


def ts_to_str_msk_tz(ts: int) -> str:
    """Format a Unix timestamp as ``'YYYY-MM-DD HH:MM:SS'`` at a fixed UTC+3 offset (Moscow time).

    A constant +3h offset is used; historical offset changes are not taken into account.
    """
    moscow = timezone(timedelta(hours=3))
    moment = datetime.fromtimestamp(ts, moscow)
    return moment.strftime('%Y-%m-%d %H:%M:%S')
