"""Simple in-database cache."""

import hashlib
import inspect
import json
import logging
from collections import OrderedDict
from functools import wraps
from threading import Lock

import cachetools
import mongoengine
import six
from mongoengine import Q
from mongoengine import StringField, LongField, DynamicField

from sepelib.mongo.util import register_model
from walle.models import Document, timestamp

log = logging.getLogger(__name__)

# Upper bound for a cache key length. The real MongoDB key limit is 1024, but BSON
# adds some structural overhead on top of the key itself, so only half of it is used.
MONGO_MAX_KEY_LEN = 1024 // 2
# Length of a SHA-256 hex digest: digest_size bytes, two hex characters per byte.
SHA_DIGEST_LEN = hashlib.sha256().digest_size * 2


class DbCacheTTLExpired(Exception):
    """Raised by get_cache_value() when the stored entry is at least ``ttl`` old."""


class DbCacheNotFound(Exception):
    """Raised by get_cache_value() when no entry exists for the requested key."""


@register_model
class _Cache(Document):
    """Caches an arbitrary value.

    One document per cache key, stored in the "cache" collection. Documents are
    upserted by set_value(), read by get_value() and removed by gc_cache().
    """

    # Cache key, used as the document primary key (e.g. built by cache_key_from_params()).
    id = StringField(primary_key=True, required=True, help_text="Value ID")
    # timestamp() at the moment the value was last stored by set_value().
    time = LongField(required=True, help_text="Value actualization time")
    # `time + ttl` at the moment of storing; gc_cache() deletes documents past it.
    expire = LongField(required=True, help_text="Value expiration time")
    # The cached payload itself; any value mongoengine's DynamicField can store.
    value = DynamicField(help_text="An arbitrary value")

    meta = {"collection": "cache"}


def get_value(cache_id):
    """Look up the cache entry stored under ``cache_id``.

    :return: a ``(time, value)`` pair, where ``time`` is when the value was last
        stored, or ``(None, None)`` if there is no entry for ``cache_id``.
        Expiration is NOT checked here; callers compare ``time`` themselves.
    """

    try:
        entry = _Cache.objects(id=cache_id).get()
    except mongoengine.DoesNotExist:
        entry = None

    if entry is None:
        return None, None
    return entry.time, entry.value


def set_value(cache_id, value, ttl):
    """Store ``value`` under ``cache_id``, creating the entry if it does not exist.

    The entry's actualization time becomes ``timestamp()`` and its expiration time
    ``timestamp() + ttl`` (``ttl`` is in the same units as ``timestamp()``).
    """
    now = timestamp()
    changes = dict(set__time=now, set__value=value, set__expire=now + ttl)
    _Cache.objects(id=cache_id).update(multi=False, upsert=True, **changes)


def get_cache_value(cache_id, ttl):
    """Return the value cached under ``cache_id`` if it is younger than ``ttl``.

    :raises DbCacheNotFound: there is no entry for ``cache_id``.
    :raises DbCacheTTLExpired: the entry was stored ``ttl`` or more time units ago.
    """
    stored_at, cached_value = get_value(cache_id)

    if stored_at is None:
        raise DbCacheNotFound("cache key '{}' not found".format(cache_id))

    age = timestamp() - stored_at
    if age >= ttl:
        raise DbCacheTTLExpired("cache key '{}' ttl expired".format(cache_id))

    return cached_value


def _is_fresh(value_time, ttl):
    """Return True if a value stored at ``value_time`` is younger than ``ttl``.

    A missing entry (``value_time is None``) is never fresh.
    """
    return value_time is not None and timestamp() - value_time < ttl


def _cached(key_function, value_ttl, set_error, get_error):
    """A decorator that caches returned function value with the specified timeout.

    :param key_function: called as ``key_function(func, args, kwargs)``; returns the
        cache key (document ID) for a particular call.
    :param value_ttl: how long a stored value is served without calling the function
        again, in the same units as ``timestamp()``.
    :param set_error: ``set_error(cache_key, exc)`` records a failure of the function.
    :param get_error: ``get_error(cache_key)`` returns the recorded failure, or None.

    On each call:
    * a fresh cached value is returned without calling the function;
    * otherwise, if a failure is currently recorded for the key, the stale value is
      returned, or the recorded exception is re-raised when there is no value at all;
    * otherwise the function is called and its result stored and returned. If it
      raises, the failure is recorded and logged; the stale value is returned when
      there is one, and the exception propagates when there is none.

    Attention: The cached value is returned on error regardless of the time it was cached.
    Note: The function mustn't be recursive: recomputation runs under a
    non-reentrant lock, so a recursive call that misses the cache would deadlock.
    """

    def decorator(func):
        # One lock per decorated function; it serializes recomputation for all keys.
        lock = Lock()

        @wraps(func)
        def decorated(*args, **kwargs):
            cache_key = key_function(func, args, kwargs)

            # Fast path: serve a fresh value without taking the lock.
            value_time, value = get_value(cache_key)
            if _is_fresh(value_time, value_ttl):
                return value

            with lock:
                # Re-read under the lock: another thread may have refreshed the
                # value while this one was waiting (double-checked locking).
                value_time, value = get_value(cache_key)
                if _is_fresh(value_time, value_ttl):
                    return value

                # A recently recorded failure suppresses retries until it is forgotten.
                # Compare with None explicitly: exception objects may define a falsy
                # __bool__/__len__, which must not disable the back-off.
                last_error = get_error(cache_key)
                if last_error is not None:
                    if value_time is None:
                        raise last_error
                    return value

                try:
                    value = func(*args, **kwargs)
                except Exception as e:
                    set_error(cache_key, e)
                    log.exception("Failed to update '%s' DB cache:", cache_key)

                    if value_time is None:
                        raise
                    # Otherwise fall through and serve the stale value.
                else:
                    set_value(cache_key, value, value_ttl)

            return value

        return decorated

    return decorator


def gc_cache(max_ttl):
    """Delete stale entries from the cache collection.

    An entry is removed when it was stored more than ``max_ttl`` ago, or when its
    own expiration time has already passed (units are those of ``timestamp()``).
    """
    # Read the clock once so both predicates are evaluated against the same instant.
    now = timestamp()
    _Cache.objects(Q(time__lt=now - max_ttl) | Q(expire__lt=now)).delete()


def cached(cache_id, value_ttl, error_ttl=1):
    """Decorator: cache a function's result in the DB under the fixed key ``cache_id``.

    Every call shares one cache entry regardless of its arguments. A failure of the
    function is remembered for ``error_ttl`` (``timestamp()`` units); during that
    window the function is not called again.
    """
    errors = cachetools.TTLCache(maxsize=1, ttl=error_ttl, timer=timestamp)

    def constant_key(*_):
        return cache_id

    return _cached(constant_key, value_ttl, errors.__setitem__, errors.get)


def _shorten_key_with_hash(key):
    """Squeeze ``key`` into MONGO_MAX_KEY_LEN characters.

    The head of the key is kept for readability. The rest is replaced with "..."
    followed by the SHA-256 hex digest of the whole original key, so different long
    keys still map to different results. Length is counted in characters, not in
    encoded bytes.
    """
    head_len = MONGO_MAX_KEY_LEN - SHA_DIGEST_LEN - len("...")
    fingerprint = hashlib.sha256(six.ensure_binary(key, "utf-8")).hexdigest()
    return key[:head_len] + "..." + fingerprint


def cache_key_from_params(cache_id, params):
    """Build a cache key from ``cache_id`` and a mapping of call parameters.

    Top-level parameters are ordered by name before JSON serialization, so the same
    arguments always give the same key. Nested values are serialized as they are.
    Keys of MONGO_MAX_KEY_LEN characters or more are shortened with a hash.
    """
    by_name = OrderedDict(sorted(params.items()))
    full_key = ":".join((cache_id, json.dumps(by_name)))

    if len(full_key) < MONGO_MAX_KEY_LEN:
        return full_key
    return _shorten_key_with_hash(full_key)


def memoized(cache_id, value_ttl, error_ttl=1, error_cache_size=100):
    """Decorator: cache results in the DB separately for every distinct set of arguments.

    The key is ``cache_id`` plus the JSON-serialized call arguments (bound to parameter
    names), so all arguments must be JSON-serializable. Only top-level parameters are
    sorted. Nested dicts are serialized in their iteration order, so equal dicts built
    in different orders give different keys. Sets are not JSON-serializable at all.
    Avoid this helper for functions taking such arguments.

    Failures are remembered per key for ``error_ttl``, for at most
    ``error_cache_size`` keys at a time.
    """
    errors = cachetools.TTLCache(maxsize=error_cache_size, ttl=error_ttl, timer=timestamp)

    def make_cache_key(func, args, kwargs):
        call_args = inspect.getcallargs(func, *args, **kwargs)
        return cache_key_from_params(cache_id, call_args)

    return _cached(make_cache_key, value_ttl, errors.__setitem__, errors.get)
