import yenv
import logging
from typing import Optional
from django.core.cache import caches

from redis_cache import RedisCache
from django.core.cache.backends.locmem import LocMemCache
from smarttv.droideka.protos.profile.profile_pb2 import TUserProfile, TSmotreshkaProfile
from smarttv.droideka.unistat.metrics import CacheSignal, CacheType, RedisCachePlace
from smarttv.droideka import unistat
from google.protobuf.message import DecodeError


logger = logging.getLogger(__name__)


def increment_counter(cache_signal_type: CacheSignal.Type, cache_place: str, cache_type: CacheType):
    """
    Increment the unistat counter that corresponds to a cache signal.

    The signal is built by ``CacheSignal.get_cache_signal`` from the three
    arguments and looked up through ``unistat.manager.get_counter``. When the
    lookup yields a falsy result nothing is incremented.
    """
    signal = CacheSignal.get_cache_signal(cache_signal_type, cache_place, cache_type)
    counter = unistat.manager.get_counter(signal)
    if not counter:
        return
    counter.increment()


class ProfileCache:
    """
    Stores protobuf profile messages in a cache under tagged keys.

    Profiles are written as the text obtained by decoding their serialized
    protobuf bytes and are parsed back into ``profile_class`` instances on read.
    """

    def __init__(self, unversioned_cache: RedisCache, profile_class, tag: str):
        # Cache backend used for every read and write.
        self.unversioned_cache = unversioned_cache
        # Class instantiated (with no arguments) for every parsed profile.
        self.profile_class = profile_class
        # Short label embedded in cache keys and log messages.
        self.tag = tag

    def get_cache_key(self, raw_key):
        """
        Return the full cache key for ``raw_key``.

        The key is composed of ``yenv.type``, the tag and the raw key.
        """
        return f'{yenv.type}:{self.tag}_"profile":{raw_key}'

    @staticmethod
    def _serialize(profile) -> str:
        """
        Return ``profile`` serialized with ``SerializeToString`` and decoded to ``str``.
        """
        # NOTE(review): ``bytes.decode()`` defaults to UTF-8 and raises
        # UnicodeDecodeError for byte sequences that are not valid UTF-8.
        # Callers below do not catch it - confirm profiles always serialize
        # to decodable bytes.
        return profile.SerializeToString().decode()

    def get(self, raw_key: str) -> Optional[TUserProfile]:
        """
        Fetch the cached profile for ``raw_key`` and parse it.

        Returns a ``self.profile_class`` instance, or ``None`` when the cache
        holds no (or an empty) value, when the cached value is not a ``str``,
        or when parsing fails. The return annotation names ``TUserProfile``,
        but the object returned is whatever ``self.profile_class`` builds.

        A MISS counter is incremented for the first two ``None`` cases. The HIT
        counter is incremented as soon as a non-empty ``str`` is found, i.e.
        before parsing, so a parse failure is still counted as a hit. Every
        counter is reported under ``RedisCachePlace.KP_PROFILE_CACHE``
        regardless of the tag.
        """
        key = self.get_cache_key(raw_key)
        place = RedisCachePlace.KP_PROFILE_CACHE.value

        raw_profile = self.unversioned_cache.get(key)
        if not raw_profile:
            logger.debug('No %s profile available', self.tag)
            increment_counter(CacheSignal.Type.MISS, place, CacheType.REDIS)
            return None
        if not isinstance(raw_profile, str):
            logger.debug('Raw profile has wrong type')
            increment_counter(CacheSignal.Type.MISS, place, CacheType.REDIS)
            return None

        increment_counter(CacheSignal.Type.HIT, place, CacheType.REDIS)
        logger.info('Retrieved %s profile:%s', self.tag, raw_profile)
        result = self.profile_class()
        try:
            result.ParseFromString(raw_profile.encode())
        except (TypeError, DecodeError):
            # Keep the traceback in the log, as the write paths below do.
            logger.error('Error parsing raw profile for key: %s', key, exc_info=True)
            return None
        return result

    def set(self, raw_key, profile):
        """
        Serialize ``profile`` and store it under the key derived from ``raw_key``.

        A ``TypeError`` raised by the cache backend is logged and swallowed.
        """
        key = self.get_cache_key(raw_key)

        raw_profile = self._serialize(profile)
        logger.info('Set %s profile: %s', self.tag, raw_profile)
        try:
            # timeout=None: in Django's cache API this means "never expire".
            self.unversioned_cache.set(key, raw_profile, timeout=None)
        except TypeError:
            logger.error('Error saving profile', exc_info=True)

    def set_many(self, raw_data: dict):
        """
        Serialize and store several profiles with one cache call.

        ``raw_data`` is a dict mapping raw key -> profile proto. A ``TypeError``
        raised by the cache backend is logged and swallowed.
        """
        serializer_data = {
            self.get_cache_key(raw_key): self._serialize(raw_profile_proto)
            for raw_key, raw_profile_proto in raw_data.items()
        }
        logger.info('Save %s profile data: %s', self.tag, serializer_data)
        try:
            # timeout=None: in Django's cache API this means "never expire".
            self.unversioned_cache.set_many(serializer_data, timeout=None)
        except TypeError:
            logger.error('Error saving profile', exc_info=True)


class CacheHitRatioEvaluator:
    """
    Reports cache hits and misses for one cache place / cache type pair.
    """

    def __init__(self, cache_place: str, cache_type: CacheType):
        self.cache_place, self.cache_type = cache_place, cache_type

    def _report(self, signal_type):
        """Increment the counter of ``signal_type`` for this place and type."""
        increment_counter(signal_type, self.cache_place, self.cache_type)

    def hit(self, *_, **__):
        """Count a cache hit; any arguments are accepted and ignored."""
        self._report(CacheSignal.Type.HIT)

    def miss(self, *_, **__):
        """Count a cache miss; any arguments are accepted and ignored."""
        self._report(CacheSignal.Type.MISS)

    def get_cache_memoize_callables(self) -> dict:
        """
        Return the bound ``hit`` / ``miss`` methods keyed as
        ``hit_callable`` / ``miss_callable``.
        """
        return dict(hit_callable=self.hit, miss_callable=self.miss)


def get_cache_memoize_callables(cache_place: str, cache_type: CacheType) -> dict:
    """
    Build a ``CacheHitRatioEvaluator`` for the given place and type and return
    its ``hit_callable`` / ``miss_callable`` mapping.
    """
    evaluator = CacheHitRatioEvaluator(cache_place, cache_type)
    return evaluator.get_cache_memoize_callables()


# Cache handles looked up by alias from Django's ``caches`` registry.
# NOTE(review): the annotations state the expected backend classes; the actual
# classes depend on the project's cache settings - confirm there.
default: RedisCache = caches['default']
unversioned: RedisCache = caches['unversioned']
local: LocMemCache = caches['local']
# Module-level profile caches. Both use the ``unversioned`` cache and differ
# in the profile class and the tag ('user' vs 'smotreshka').
user_profile = ProfileCache(unversioned, TUserProfile, 'user')
smotreshka_profile = ProfileCache(unversioned, TSmotreshkaProfile, 'smotreshka')


# NOTE(review): initialised empty and not modified in the visible part of this
# module - presumably populated or read elsewhere; confirm against callers.
budapest_device_ids = []
