import asyncio
import dataclasses
import datetime
import os
import re
import socket
from collections import defaultdict, deque
from contextlib import contextmanager
from decimal import Decimal
from enum import Enum, unique
from functools import wraps
from inspect import getmodule
from typing import (
    Any, AsyncIterable, AsyncIterator, Callable, Collection, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union
)
from uuid import UUID

import pytz


class SentinelMissing:
    pass


MISSING = SentinelMissing()
_T = TypeVar('_T')

MaybeMissing = Union[SentinelMissing, _T]

__all__ = (
    'copy_context',
    'enum_values',
    'frommsktimestamp',
    'fromtimestamp',
    'get_dc',
    'get_hostname',
    'get_subclasses',
    'json_dict_factory',
    'json_value',
    'sort_hosts_by_geo',
    'str_to_underscore',
    'utcnow',
)

_camelcase_regex = re.compile('([A-Z])')
_timezone_msk = pytz.timezone('Europe/Moscow')

# PDDMAIL-271 https://gist.github.yandex-team.ru/mkznts/8f25c19950c0e2a959694e4e66eedec7
PREFERRED_DCS: Dict[str, Tuple[str]] = defaultdict(tuple, **{  # type: ignore
    'iva': ('iva', 'myt', 'vla', 'sas', 'man'),
    'myt': ('myt', 'iva', 'vla', 'sas', 'man'),
    'sas': ('sas', 'vla', 'iva', 'myt', 'man'),
    'vla': ('vla', 'iva', 'myt', 'sas', 'man'),
    'man': ('man', 'myt', 'iva', 'vla', 'sas'),
})


def utcnow() -> datetime.datetime:
    return datetime.datetime.now(tz=datetime.timezone.utc)


def fromtimestamp(ts: Union[float, int], tzname: datetime.tzinfo = datetime.timezone.utc) -> datetime.datetime:
    return datetime.datetime.fromtimestamp(ts, tz=tzname)


def frommsktimestamp(ts: Union[float, int]) -> datetime.datetime:
    return fromtimestamp(ts, _timezone_msk)


def get_hostname() -> str:
    return os.environ.get(
        'DEPLOY_POD_PERSISTENT_FQDN',
        os.environ.get('QLOUD_DISCOVERY_INSTANCE', ''),
    ) or socket.getfqdn()


def get_dc(default: str = 'local') -> str:
    dc = os.environ.get('DEPLOY_NODE_DC', os.environ.get('QLOUD_DATACENTER', default))
    return str(dc).lower()


def sort_hosts_by_geo(hosts: Iterable[str], dc: Optional[str] = None) -> List[str]:
    dc = dc or get_dc()
    priority: Tuple[str] = PREFERRED_DCS[dc]

    def _sort_key(item: str) -> int:
        for i, target_dc in enumerate(priority):
            if item.startswith(target_dc):
                return i
        return len(priority)

    return sorted(hosts, key=_sort_key)


def json_value(v: Any, custom_encoder: Optional[Callable] = None) -> Any:  # noqa: C901
    """
    Превращаем объекты в JSON-совместимое представление (рекурсивно).
    На выходе должен получиться объект, который можно безопасно скормить json.dumps.
    :param v: конвертируемое значение
    :param custom_encoder: функция, которая может переопределять поведение конвертации
    :return: JSON-совместимый объект

    N.B.: При обработке экземляров датаклассов json_value не включит в итоговый словарь
    скрытые поля (у которых repr == False), т.к. предполагается, что эти поля могут
    содержать секреты, не подлежащие логированию/сериализации. При необходимости
    сохранения объекта в базу или иное хранилище с последующим восстановлением из него,
    json_value может не дать нужного результата при наличии таких скрытых полей.

    N.B. НИКОГДА не используй json.JSONEncoder. Если КЛЮЧ словаря окажется Enum, то упадёт.
    JSONEncoder.default не вызывается напротив ключей. Только напротив значений
    """
    if custom_encoder:
        result, match = custom_encoder(v)
        if match:
            return result

    if isinstance(v, (str, int, float)):
        return v

    if v is None:
        return v

    if isinstance(v, dict):
        data = {}
        for key in v.keys():
            if isinstance(key, str):
                new_key = key
            else:
                new_key = json_value(key, custom_encoder=custom_encoder)
            data[new_key] = json_value(v[key], custom_encoder=custom_encoder)

        return data

    if isinstance(v, list):
        return [json_value(i, custom_encoder=custom_encoder) for i in v]

    if isinstance(v, tuple):
        return tuple(json_value(i, custom_encoder=custom_encoder) for i in v)

    if isinstance(v, set):
        return [json_value(i, custom_encoder=custom_encoder) for i in v]

    if isinstance(v, datetime.datetime):
        return v.isoformat(sep=' ')

    if isinstance(v, datetime.date):
        return v.isoformat()

    if isinstance(v, Enum):
        return json_value(v.value, custom_encoder=custom_encoder)

    if isinstance(v, (Decimal, UUID)):
        return str(v)

    if dataclasses.is_dataclass(type(v)):
        # only dataclass instances pass this check
        return {
            field.name: json_value(getattr(v, field.name), custom_encoder=custom_encoder)
            for field in dataclasses.fields(type(v))
            if field.repr  # skip the fields hidden from __repr__()
        }

    return repr(v)


def json_dict_factory(dict_arg: Any) -> Dict[Any, Any]:
    return {
        key: json_value(value)
        for key, value in dict(dict_arg).items()
    }


def enum_name(enum_type: Optional[Enum]) -> Optional[str]:
    return enum_type.name if isinstance(enum_type, Enum) else None


def enum_value(enum_type: Optional[Enum], default: Any = None) -> Optional[str]:
    return enum_type.value if isinstance(enum_type, Enum) else default


def enum_values(enum_type: Type[Enum]) -> List[str]:
    return [v.value for v in enum_type]


def copy_context(async_func: Callable) -> Callable:
    @wraps(async_func)
    async def _inner(*args: Any, **kwargs: Any) -> Any:
        loop = asyncio.get_event_loop()
        task = loop.create_task(async_func(*args, **kwargs))
        return await task

    return _inner


def split_list(lst: List, size: Optional[int] = None) -> List[List]:
    """Split given list into list of lists with given maximum size."""
    if not size or len(lst) == 0:
        return [lst]
    return [lst[i: i + size] for i in range(0, len(lst), size)]


async def alist(iter_: AsyncIterable[_T]) -> List[_T]:
    return [_ async for _ in iter_]


async def anext(iterator: AsyncIterator[_T], default: Optional[Union[_T, object]] = MISSING) -> Optional[_T]:
    try:
        return await iterator.__anext__()
    except StopAsyncIteration as e:
        if default is not MISSING:
            return default  # type: ignore
        raise e


@contextmanager
def temp_set(obj, attr, value):
    before = getattr(obj, attr, MISSING)
    setattr(obj, attr, value)
    yield
    if before is MISSING:
        delattr(obj, attr)
    else:
        setattr(obj, attr, before)


@unique
class UserType(Enum):
    CONNECT = 'connect'
    STAFF = 'staff'
    PORTAL = 'portal'


def get_user_type(uid: int) -> UserType:
    if 1130000000000000 <= uid < 1140000000000000:  # connect range of uids
        return UserType.CONNECT
    if 1100000000000000 <= uid < 1130000000000000:  # staff range of uids
        return UserType.STAFF
    return UserType.PORTAL


def without_none(dict_: dict) -> dict:
    return {
        key: value
        for key, value in dict_.items()
        if value is not None
    }


def without_missing(dict_: dict) -> dict:
    return {
        key: value
        for key, value in dict_.items()
        if not isinstance(value, SentinelMissing)
    }


def str_to_underscore(s: str) -> str:
    return _camelcase_regex.sub(lambda match: '_' + match.group(1).lower(), s)


def get_subclasses(
    cls: type,
    exclude_module_pattern: Optional[Union[str, re.Pattern[str]]] = None
) -> Collection[type]:
    """
    Возвращает все дочерние классы.
    Не только непосредственные дочерние классы. Не только листовые дочерние классы.
    А вообще все дочерние классы - весь граф.

    Результат не включает сам класс.

    * exclude_module_pattern - если нужно исключить какие-то классы.
        Можно передать строку - тогда будет проверяться вхождение строки в название модуля.
        Можно передать re.Pattern - тогда будет проверяться вхождение паттерна в строку (в любой позиции)
    """
    all_subclasses = []
    q = deque(cls.__subclasses__())
    if exclude_module_pattern is not None:
        if not isinstance(exclude_module_pattern, re.Pattern):
            exclude_module_pattern = re.compile(re.escape(exclude_module_pattern))

    while len(q) > 0:
        cls = q.pop()
        if exclude_module_pattern is not None:
            module_name = str(getmodule(cls))
            if exclude_module_pattern.search(module_name):
                continue

        all_subclasses.append(cls)
        subclasses = cls.__subclasses__()
        q.extend(subclasses)

    return [cls for cls in all_subclasses]
