import time
import threading

import msgpack

from sandbox import common
import sandbox.common.types.misc as ctm
import sandbox.common.types.task as ctt

from sandbox.serviceapi.mules import aggregator
from sandbox.serviceapi import constants as sa_consts

import sandbox.serviceq.config as qconfig

from sandbox.yasandbox import context


class AutoUpdatable(object):
    """
    Holder of a ``value`` that is refreshed by a pluggable function at most once per TTL.

    The refresh function is expected to store fresh data itself (e.g. into ``value``).
    A truthy return value marks the refresh as successful and restarts the TTL countdown;
    a falsy one leaves the timestamp untouched, so the next ``update()`` tries again.
    """

    def __init__(self, ttl, update_func=None, update_args=(), default=None):
        """
        :param ttl: number of whole seconds that must be exceeded between successful refreshes
        :param update_func: refresh callable; a no-op returning None when not given
        :param update_args: positional arguments passed to ``update_func``
        :param default: initial ``value``
        """
        if not update_func:
            update_func = lambda: None
        self._ttl = ttl
        self._update_func = update_func
        self._update_args = update_args
        self._last_updated = 0
        self.value = default

    @common.utils.singleton_property
    def _lock(self):
        """Lock serializing concurrent refresh attempts."""
        return threading.Lock()

    def set_func(self, update_func, update_args=()):
        """Replace the refresh callable and its positional arguments."""
        self._update_func, self._update_args = update_func, update_args

    def _expired(self, moment):
        # Whole seconds elapsed since the last successful refresh must exceed the TTL.
        return int(moment - self._last_updated) > self._ttl

    def update(self, force=False):
        """
        Run the refresh function if the TTL has expired (or unconditionally when ``force``).

        :param force: skip the TTL check
        """
        # Cheap check without the lock first.
        if not (force or self._expired(time.time())):
            return
        with self._lock:
            now = time.time()
            # Re-check under the lock: another thread may have refreshed in the meantime.
            if not (force or self._expired(now)):
                return
            if self._update_func(*self._update_args):
                self._last_updated = now


class Aggregator(aggregator.BaseAggregator):
    """
    API quota aggregator.

    :meth:`check_consumption` marks the start of a request for a user (refusing banned users when
    quota checks are enabled) and :meth:`add_delta` records the time spent on it into ``self._delta``.
    The background thread started by :meth:`start` periodically runs :meth:`proc`, which hands the
    accumulated records to ``self.send_msg``, reloads the serialized banned-user lists from the uWSGI
    shared area ``sa_consts.UWSGIKey.API_QUOTAS`` and refreshes the cached per-user consumption table
    used by :meth:`update_api_consumption_headers`.
    """

    # Parsed banned-user sets for API and WEB requests. None means "not parsed yet":
    # check_consumption() lazily deserializes self._serialized_banned_list into them.
    _banned_list = None
    _web_banned_list = None
    # (user, request start timestamp in milliseconds) of the request being accounted, or (None, None).
    _current_user = (None, None)
    # ctt.RequestSource of the request being accounted.
    _request_type = None
    # user -> (consumption, quota) table read from the uWSGI cache; proc() refreshes it once the
    # 5-second TTL of AutoUpdatable has expired.
    # NOTE: this is a class attribute, so it is shared by all Aggregator instances.
    _user_to_api_quota = AutoUpdatable(5, default={})

    def __init__(self):
        super(Aggregator, self).__init__()
        self._serialized_banned_list = self.EMPTY_SERIALIZED_BANNED_LIST
        self._user_to_api_quota.set_func(self.update_api_consumption)
        self._api_quotas_config = qconfig.Registry().serviceq.server.api_quotas
        context.set_current(None)

    def check_consumption(self, user, request_type=ctt.RequestSource.API):
        """
        Start accounting a request for `user` unless the user is banned.

        :param user: user identifier looked up in the banned lists
        :param request_type: ``ctt.RequestSource`` value; WEB requests are checked against the web
                             banned list, everything else against the API banned list
        :return: False if the user is banned and quota checks are enabled in the server config,
                 True otherwise (the request start time is then remembered for :meth:`add_delta`)
        """
        with self._lock:
            if self._banned_list is None:
                if self._serialized_banned_list is not None:
                    try:
                        self._banned_list, self._web_banned_list = map(set, msgpack.loads(self._serialized_banned_list))
                    except Exception:
                        self.logger.warning("Can't unpack banned list. Run next interval without api quotas.")
                        self._banned_list = set()
                        self._web_banned_list = set()
                else:
                    self._banned_list = set()
                    self._web_banned_list = set()
            banned_list = self._web_banned_list if request_type == ctt.RequestSource.WEB else self._banned_list
            if user in banned_list and common.config.Registry().server.api.quotas.check:
                return False
            self._current_user = (user, int(time.time() * 1000))
            self._request_type = request_type
        return True

    def __delta_part(self, now):
        """
        Build a consumption record for the current user up to `now`.

        Must be called with ``self._lock`` held and a current user set.

        :param now: current timestamp in milliseconds
        :return: (user, request start timestamp / 1000, API consumption ms, WEB consumption ms);
                 exactly one of the consumption fields is non-zero, chosen by the request type
        """
        consumption = now - self._current_user[1]
        web_consumption = 0
        if self._request_type == ctt.RequestSource.WEB:
            consumption, web_consumption = web_consumption, consumption
        return self._current_user[0], self._current_user[1] / 1000, consumption, web_consumption

    def add_delta(self):
        """
        Record the consumption of the current request (if any) and reset the current user.

        The pending ``self._delta`` list is dropped entirely once it exceeds the configured
        ``max_delta_size``, which bounds its memory.
        """
        with self._lock:
            if self._current_user[0] is not None:
                self._delta.append(self.__delta_part(int(time.time() * 1000)))
                if len(self._delta) > common.config.Registry().server.api.quotas.max_delta_size:
                    self.logger.warning("Delta array is too big, clear it.")
                    self._delta = []
            self._current_user = (None, None)

    def update_api_consumption(self):
        """
        Reload the user -> (consumption, quota) table from the uWSGI cache.

        Used as the update function of ``_user_to_api_quota``. AutoUpdatable only restarts its TTL
        when this returns a truthy value, so success must be reported explicitly; returning None
        (as before) made every call re-read the cache regardless of the TTL.

        :return: True if the table was refreshed, False if reading or decoding it failed
                 (the previous table is kept and the refresh is retried on the next update)
        """
        import uwsgi
        try:
            serialized_consumption_table = uwsgi.cache_get(
                sa_consts.ApiConsumption.CONSUMPTION_KEY_NAME, sa_consts.ApiConsumption.UWSGI_CACHE_NAME
            )
            self._user_to_api_quota.value = (
                msgpack.loads(serialized_consumption_table) if serialized_consumption_table else {}
            )
        except Exception as ex:
            if not self._user_to_api_quota.value:
                self._user_to_api_quota.value = {}
            self.logger.error(
                "Error in communication with mule. Run next interval without updating api quotas",
                exc_info=ex
            )
            return False
        return True

    def proc(self):
        """
        Run one aggregation cycle.

        Takes the accumulated deltas (including the in-flight request, whose timer is restarted),
        passes them to ``self.send_msg``, reads the revision and serialized banned lists from the
        uWSGI shared area, refreshes the consumption table if its TTL expired, and finally
        invalidates the parsed banned lists so :meth:`check_consumption` re-parses the new ones.
        On any communication error an empty banned list is used for the next interval.
        """
        import uwsgi
        with self._lock:
            delta = self._delta
            self._delta = []
            if self._current_user[0] is not None:
                # Account the in-flight request up to now and restart its timer, so a long
                # request is billed across consecutive cycles rather than only when it ends.
                now_ms = int(time.time() * 1000)
                delta.append(self.__delta_part(now_ms))
                self._current_user = (self._current_user[0], now_ms)

        area_key = sa_consts.UWSGIKey.API_QUOTAS
        try:
            self.send_msg(delta)
            uwsgi.sharedarea_rlock(area_key)
            try:
                area = uwsgi.sharedarea_memoryview(area_key)
                # Layout: [0:4] revision, [4:8] payload length, [8:8 + length] serialized banned lists.
                self.revision = self._length_struct.unpack(area[:4])[0]
                read_length = self._length_struct.unpack(area[4:8])[0]
                if read_length > 0:
                    serialized_banned_list = area[8:8 + read_length].tobytes()
                else:
                    serialized_banned_list = self.EMPTY_SERIALIZED_BANNED_LIST
            finally:
                # Release the read lock even if parsing the header fails; a leaked read lock
                # would block writers of the shared area.
                uwsgi.sharedarea_unlock(area_key)
        except Exception as ex:
            self.logger.error(
                "Error in communication with mule. Run next interval without api quotas.", exc_info=ex
            )
            serialized_banned_list = self.EMPTY_SERIALIZED_BANNED_LIST

        self._user_to_api_quota.update()

        with self._lock:
            # Invalidate the parsed lists; check_consumption() lazily re-parses them.
            self._banned_list = None
            self._serialized_banned_list = serialized_banned_list

    def update_api_consumption_headers(self, headers, user=None):
        """
        Add API quota headers for `user` (default: the user of the current request).

        Nothing is added when there is no user or the consumption table is empty. The consumption
        header is always set otherwise (0 for unknown users); the quota header only if a quota is known.

        :param headers: mutable mapping of response headers, updated in place
        :param user: user to report; falls back to the current request's user
        """
        user = user or self._current_user[0]
        quotas = self._user_to_api_quota.value
        if not user or not quotas:
            return

        consumption, quota = quotas.get(user, (0, None))

        headers[ctm.HTTPHeader.API_QUOTA_CONSUMPTION] = consumption

        if quota is not None:
            headers[ctm.HTTPHeader.API_QUOTA] = quota

    def main(self):
        """
        Aggregation loop: run :meth:`proc` every ``UPDATE_INTERVAL`` seconds, or earlier when
        ``self._wakeup`` is set. Never returns.
        """
        while True:
            self._wakeup.wait(self.UPDATE_INTERVAL)
            self._wakeup.clear()
            try:
                self.proc()
            except Exception as ex:
                # Keep the daemon thread alive: an unexpected error must not stop quota aggregation
                # for the rest of the process lifetime.
                self.logger.error("Unexpected error in API quotas aggregation cycle", exc_info=ex)

    def start(self):
        """Run :meth:`main` in a daemon thread."""
        self._thread = threading.Thread(target=self.main)
        self._thread.daemon = True
        self._thread.start()
