# coding: utf-8
import time
import ujson

import cachetools
import inject
from kazoo.exceptions import NoNodeError

from awacs.lib import zookeeper_client, zk_storage
from awacs.model.codecs import StaffCacheEntryCodec
from infra.awacs.proto import internals_pb2


class StaffCache(object):
    """Two-level cache (in-process memory + ZooKeeper) for JSON-serializable values.

    Entries are grouped into buckets. Each bucket lazily gets its own
    ``cachetools.TTLCache`` and its own ``ZkStorageClient`` whose prefix is built
    from ``zk_path`` and the bucket name via ``zk_storage.path_with_prefix``.
    Values are stored in ZooKeeper as ``ujson``-encoded UTF-8 bytes inside an
    ``internals_pb2.CacheEntry``, with ``mtime`` holding the write timestamp.

    ``get`` returns ``(content, expired)``. Content older than ``cache_ttl`` is
    still returned, but flagged with ``expired=True``. Callers presumably refresh
    such values and store them back with ``set`` (verify against callers).
    """
    MEM_CACHE_SIZE = 1000
    MEM_CACHE_TTL = 60 * 60 * 24  # a day: how long an entry may stay in process memory
    CACHE_TTL = 10 * 60  # 10 minutes: age after which an entry is reported as expired

    coord = inject.attr(zookeeper_client.IZookeeperClient)

    def __init__(self, zk_path, mem_cache_size=MEM_CACHE_SIZE, mem_cache_ttl=MEM_CACHE_TTL, cache_ttl=CACHE_TTL):
        """
        :param zk_path: ZooKeeper path under which per-bucket storage is rooted
        :param mem_cache_size: max number of entries per bucket in the in-memory cache
        :param mem_cache_ttl: seconds an entry is kept in the in-memory cache before eviction
        :param cache_ttl: seconds after an entry's update time at which it is reported as expired
        """
        self._mem_cache_size = mem_cache_size
        self._mem_cache_ttl = mem_cache_ttl
        self._cache_ttl = cache_ttl
        self._zk_path = zk_path

        # bucket -> cachetools.TTLCache of key -> (content, updated_at)
        self.mem_caches = {}
        # bucket -> zk_storage.ZkStorageClient
        self.zk_storage_clients = {}

    def _get_mem_cache(self, bucket):
        """Return the in-memory TTL cache for ``bucket``, creating it on first use."""
        if bucket not in self.mem_caches:
            self.mem_caches[bucket] = cachetools.TTLCache(maxsize=self._mem_cache_size, ttl=self._mem_cache_ttl)
        return self.mem_caches[bucket]

    def _get_zk_storage_client(self, bucket):
        """Return the ZooKeeper storage client for ``bucket``, creating it on first use."""
        if bucket not in self.zk_storage_clients:
            zk_bucket_path = zk_storage.path_with_prefix(self._zk_path, bucket)
            self.zk_storage_clients[bucket] = zk_storage.ZkStorageClient(self.coord,
                                                                         prefix=zk_bucket_path,
                                                                         codec=StaffCacheEntryCodec)
        return self.zk_storage_clients[bucket]

    def get(self, bucket, key):
        """Look up ``key`` in ``bucket``: in memory first, then in ZooKeeper.

        :returns: ``(content, expired)``. ``content`` is the decoded value, or None
            when the key is found in neither layer. ``expired`` is True when nothing
            was found, or when the entry is older than ``cache_ttl`` seconds.
        """
        mem_cache = self._get_mem_cache(bucket)
        # Use a single .get() rather than ``key in cache`` followed by ``cache[key]``.
        # A TTLCache entry can expire between those two calls, and the indexing
        # would then raise KeyError.
        cached = mem_cache.get(key)
        if cached is not None:
            content, updated_at = cached
            # Judge staleness by cache_ttl, exactly like ZooKeeper entries below.
            # mem_cache_ttl only bounds how long the entry is kept in memory; the
            # TTLCache already evicts by it.
            return content, updated_at + self._cache_ttl < time.time()

        zk_cache_entry = self._get_zk_storage_client(bucket).get(key)
        if not zk_cache_entry:
            return None, True

        content = ujson.loads(zk_cache_entry.content.decode('utf8'))
        updated_at = zk_cache_entry.mtime.ToSeconds()
        expired = updated_at + self._cache_ttl < time.time()
        if not expired:
            # Only fresh entries are promoted into memory. Stale ones are returned
            # as-is so that the next call consults ZooKeeper again.
            mem_cache[key] = (content, updated_at)
        return content, expired

    def set(self, bucket, key, value):
        """Store ``value`` under ``key`` in both the in-memory cache and ZooKeeper.

        ``value`` must be serializable by ``ujson.dumps``. The ZooKeeper node is
        updated with ``put``. If ``put`` raises ``NoNodeError``, the node is created
        with ``create`` instead.
        """
        ts = int(time.time())
        mem_cache = self._get_mem_cache(bucket)
        mem_cache[key] = (value, ts)

        zk_cache = self._get_zk_storage_client(bucket)
        e = internals_pb2.CacheEntry(content=ujson.dumps(value).encode('utf8'))
        # The same timestamp goes into memory and ZooKeeper, so both layers agree on age.
        e.mtime.FromSeconds(ts)
        try:
            zk_cache.put(key, e)
        except NoNodeError:
            # First write for this key: the node does not exist yet.
            zk_cache.create(key, e)
