# -*- coding: utf-8 -*-
import heapq
import logging
import math
import time as os_time
from datetime import datetime, timedelta
from functools import wraps
from typing import Any, Optional, Callable, SupportsFloat

logger = logging.getLogger(__name__)


def _fn_path(fn):
    return fn.__module__, fn.__name__


# Registry of every cache created through ``memoize`` so that
# ``reset_all_caches`` can invalidate all of them in one call.
all_cache = []


def memoize(keyfun=None, cache=None):
    """Decorator factory that caches a function's results.

    :param keyfun: called with the same arguments as the decorated function;
        its return value is used as the cache key. The default is a
        zero-argument function returning ``True``, so without an explicit
        ``keyfun`` only zero-argument functions can be memoized.
    :param cache: object supporting ``cache[key]`` (raising ``KeyError`` on a
        miss), ``cache[key] = value`` and ``reset()``. Defaults to a fresh
        ``CacheSimpleDict``.
    :returns: a decorator. The wrapped function additionally exposes
        ``original_func`` (the undecorated function), ``reset`` (the cache's
        ``reset`` method) and ``_cache`` (the cache object).

    The cache is appended to ``all_cache`` as soon as the factory is called.
    """
    if keyfun is None:
        keyfun = lambda: True

    if cache is None:
        cache = CacheSimpleDict()
    all_cache.append(cache)

    def _memoize(fun):
        @wraps(fun)
        def memoized(*args, **kwargs):
            key = keyfun(*args, **kwargs)
            try:
                return cache[key]
            except KeyError:
                pass
            # Compute outside the ``except`` block so that exceptions raised by
            # ``fun`` are not implicitly chained to the cache-miss KeyError,
            # which would make their tracebacks misleading.
            value = fun(*args, **kwargs)
            cache[key] = value
            return value

        memoized.original_func = fun
        memoized.reset = cache.reset
        memoized._cache = cache
        return memoized

    return _memoize


def reset_all_caches():
    """Call ``reset()`` on every cache registered by ``memoize``."""
    for c in all_cache:
        c.reset()


class CacheInMemcache(object):
    """Cache backed by a memcache-like client, storing keys as ``prefix/salt/key``.

    ``reset`` switches to a new salt, so entries written before the reset are
    no longer reachable (they stay in memcache until their TTL runs out).
    """

    def __init__(self, mc, ttl, prefix='', suffix=None, no_reset=False, serializer=None, deserializer=None):
        # type: (Any, SupportsFloat, Optional[str], Optional[str], Optional[bool], Optional[Callable[[Any],str]], Optional[Callable[[str],Any]])->None
        """
        :param mc: client exposing ``get(key)`` and ``set(key, value, ttl)``;
            a ``None`` result from ``get`` is treated as a miss.
        :param ttl: lifetime passed to ``mc.set``, rounded up to an int.
        :param prefix: first component of every stored key.
        :param suffix: fixed salt; when falsy, a salt is generated from the
            current time.
        :param no_reset: when true, ``reset`` does nothing.
        :param serializer: applied to values before ``mc.set``.
        :param deserializer: applied to values returned by ``mc.get``.
        :raises ValueError: if only one of serializer/deserializer is given.
        """
        if (serializer is None) != (deserializer is None):
            raise ValueError('You should provide serializer and deserializer at the same time')
        self._mc = mc
        self._ttl = int(math.ceil(ttl))
        self._prefix = prefix
        self._salt = suffix or self._gen_salt()
        self._no_reset = no_reset
        self.serializer = serializer
        self.deserializer = deserializer

    def _gen_salt(self):
        """Return the current time in whole microseconds since the epoch, as a string."""
        return str(int(os_time.time() * 10 ** 6))

    def reset(self):
        """Switch to a new salt so that every previously written key misses."""
        if self._no_reset:
            return
        new_salt = self._gen_salt()
        try:
            previous = int(self._salt)
        except (TypeError, ValueError):
            previous = None  # non-numeric suffix: any time-based salt differs
        # time.time() may not advance between calls (coarse clock resolution),
        # and could even step backwards; keep numeric salts strictly increasing
        # so a reset never reuses the current or an earlier salt.
        if previous is not None and int(new_salt) <= previous:
            new_salt = str(previous + 1)
        self._salt = new_salt

    def prepare_key(self, key):
        """Return the memcache key ``prefix/salt/key`` for *key*."""
        return '{}/{}/{}'.format(self._prefix, self._salt, key)

    def __setitem__(self, key, value):
        if self.serializer is not None:
            value = self.serializer(value)
        self._mc.set(self.prepare_key(key), value, self._ttl)

    def __getitem__(self, key):
        """Return the stored value for *key*.

        :raises KeyError: if the client returns ``None`` or deserialization fails.
        """
        val = self._mc.get(self.prepare_key(key))
        if val is None:
            raise KeyError(key)
        if self.deserializer is not None:
            try:
                val = self.deserializer(val)
            except Exception:
                logger.exception('Deserialization failed')
                raise KeyError(key)  # in case we failed deserialization, pretend like there is no data
        return val


class CacheSimpleDict(dict):
    """Plain in-process dict used as the default ``memoize`` cache."""

    def reset(self):
        """Forget every cached entry."""
        self.clear()
        return None


class CacheWithKeyTTL(dict):
    """Dict whose entries expire ``key_ttl`` seconds after they were last set.

    Expired entries are purged lazily whenever an item is read or written via
    ``[]``. When ``maxsize`` is set, entries closest to expiry are evicted once
    the dict grows beyond it. Only ``__getitem__``/``__setitem__`` take part in
    the expiry bookkeeping; other ``dict`` methods (``get``, ``in``,
    ``update``, ...) operate on the raw dict.
    """

    def __init__(self, key_ttl, maxsize=None, *args, **kwargs):
        """
        :param key_ttl: lifetime of each entry, in seconds.
        :param maxsize: maximum number of entries; falsy means unbounded.
        :param args, kwargs: optional initial items, as accepted by ``dict``.
        """
        super(CacheWithKeyTTL, self).__init__()
        self._ttl = key_ttl
        self._maxsize = maxsize
        self.reset()
        # Route initial items through __setitem__ so they get expirations
        # (passing them to dict.__init__ would have them wiped by reset()).
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def reset(self):
        """Remove all entries and all expiry bookkeeping."""
        self.clear()
        # Min-heap of (expiration, sequence, key). The sequence number breaks
        # ties between equal expirations so keys themselves are never compared
        # (keys of different types may not be orderable).
        self._priority_heapq = []
        self._keys_expirations = {}
        self._heap_seq = 0

    def _now(self):
        """Return the current time used for expiry calculations."""
        return datetime.now()

    def __setitem__(self, key, value):
        self.actualize()
        expiration = self._now() + timedelta(seconds=self._ttl)
        self._keys_expirations[key] = expiration
        self._heap_seq += 1
        heapq.heappush(self._priority_heapq, (expiration, self._heap_seq, key))
        try:
            return super(CacheWithKeyTTL, self).__setitem__(key, value)
        finally:
            self.ensure_maxsize()

    def __getitem__(self, key):
        self.actualize()
        return super(CacheWithKeyTTL, self).__getitem__(key)

    def actualize(self):
        """Drop every entry whose expiration is not later than now."""
        now = self._now()
        heap = self._priority_heapq
        while heap and heap[0][0] <= now:
            _, _, key = heapq.heappop(heap)
            self._clean_expirations_to(now, key)

    def ensure_maxsize(self):
        """Evict entries closest to expiry until ``len(self) <= maxsize``."""
        while self._maxsize and len(self) > self._maxsize:
            assert self._priority_heapq
            omni_expiration, _, key = heapq.heappop(self._priority_heapq)
            self._clean_expirations_to(omni_expiration, key)

    def _clean_expirations_to(self, moment, key):
        """Remove *key* if its current expiration is not later than *moment*.

        Heap entries left over from earlier sets of the same key carry an
        older expiration than the key's current one and are thus ignored.
        """
        expiration = self._keys_expirations.get(key)
        if expiration is not None and expiration <= moment:
            del self._keys_expirations[key]
            # The key may already be gone if it was deleted through the plain
            # dict API, so don't insist on it being present.
            self.pop(key, None)

    def get_key_ttl(self, key):
        """Return the remaining lifetime of *key* as a ``timedelta``, or ``None``.

        Expired-but-not-yet-purged keys yield a negative ``timedelta``.
        """
        expiration = self._keys_expirations.get(key)
        if expiration is None:
            return None
        return expiration - self._now()


class _AbstractWarmGroup(object):
    """Named group of cache-warming functions.

    Use an instance as a decorator to register a function; ``warm_up`` calls
    every registered function. Only functions without arguments may be wrapped.
    Subclasses must define a class-level ``_warm_groups`` list, to which every
    instance is appended on creation.
    """

    def __init__(self, name=None):
        self._name = name
        self._functions = []
        self._warm_groups.append(self)

    def __call__(self, fn):
        """Register *fn* in this group and return it unchanged."""
        self._functions.append(fn)
        return fn

    @classmethod
    def get_groups(cls):
        """Return this class's groups sorted by name.

        Unnamed groups come first, in creation order. ``None`` names are keyed
        separately because Python 3 cannot order ``None`` against ``None`` or
        against strings.
        """
        return sorted(cls._warm_groups, key=lambda g: (g._name is not None, g._name))

    def warm_up(self, logger=None):
        """Call every registered function, ordered by ``(module, name)``.

        :param logger: optional logger for progress messages (this parameter
            shadows the module-level ``logger``).
        """
        if logger:
            logger.info('Warm group %r:', self._name)
        for fn in sorted(self._functions, key=_fn_path):
            if logger:
                logger.info('\t%s.%s', *_fn_path(fn))
            fn()


class SimpleWarmGroup(_AbstractWarmGroup):
    """Warm-group family with its own registry of instances."""

    _warm_groups = []
