import time
from contextlib import contextmanager

from gevent.lock import Semaphore

from sepelib.core.exceptions import LogicalError
from walle.errors import TooManyRequestsError

# Default timeout, in seconds, that ConcurrentRateLimiter passes to the
# semaphore's acquire() while waiting for a free concurrency slot.
DEFAULT_CONCURRENT_TIMEOUT: int = 30


class LimitManager:
    """Lazily build and cache one limiter per API function.

    The limiter kind is picked from ``rps`` / ``max_concurrent`` the first time
    a given ``func`` is seen; afterwards the cached limiter is reused and the
    settings passed on later calls are ignored.
    """

    def __init__(self):
        # func -> limiter instance (RateLimiter or one of its subclasses).
        self.limiters = {}

    def initialize(self, func, rps=None, max_concurrent=None, concurrent_timeout=DEFAULT_CONCURRENT_TIMEOUT):
        """Register a limiter for ``func``; a no-op if one is already registered."""
        if func in self.limiters:
            return
        limiter = self.decide_limiter(func, rps, max_concurrent, concurrent_timeout)
        self.limiters[func] = limiter

    @staticmethod
    def decide_limiter(func, rps, max_concurrent, concurrent_timeout=DEFAULT_CONCURRENT_TIMEOUT):
        """Return the limiter matching the given settings.

        - neither ``rps`` nor ``max_concurrent`` set: a plain ``RateLimiter``;
        - only ``rps`` set: an ``RPSRateLimiter``;
        - only ``max_concurrent`` set: a ``ConcurrentRateLimiter`` that uses
          ``concurrent_timeout``.

        "Set" means truthy, so ``0`` and ``None`` both count as unset.

        Raises:
            LogicalError: if both ``rps`` and ``max_concurrent`` are set, because
                combining the two limits is not supported.
        """
        wants_rps = bool(rps)
        wants_concurrency = bool(max_concurrent)

        if wants_rps and wants_concurrency:
            raise LogicalError  # Proposed by n-malakhov

        if wants_rps:
            return RPSRateLimiter(func, rps)
        if wants_concurrency:
            return ConcurrentRateLimiter(func, max_concurrent, concurrent_timeout)
        return RateLimiter(func)

    @contextmanager
    def check_limit(self, func, rps, max_concurrent, concurrent_timeout=DEFAULT_CONCURRENT_TIMEOUT):
        """Guard one call to ``func`` with its limiter.

        Creates the limiter on first use and then delegates to its ``check()``
        context manager. Any exception raised while entering the limiter
        propagates to the caller.
        """
        if func not in self.limiters:
            self.initialize(func, rps, max_concurrent, concurrent_timeout)
        limiter = self.limiters[func]
        with limiter.check():
            yield


class RateLimiter:
    """Pass-through limiter and base class for concrete limiting strategies.

    Subclasses override :meth:`enter` and :meth:`exit`. :meth:`check` wraps a
    unit of work so that ``exit`` runs after a successful ``enter``, even when
    the work raises. If ``enter`` itself raises, ``exit`` is not called.
    """

    def __init__(self, func):
        # Identifier of the limited function (subclasses pass it to their errors).
        self.func = func

    def enter(self):
        """Hook run before the guarded work; the base class admits every call."""

    def exit(self):
        """Hook run after the guarded work; the base class has nothing to release."""

    @contextmanager
    def check(self):
        """Run the ``with`` body between :meth:`enter` and :meth:`exit`."""
        self.enter()
        try:
            yield
        finally:
            # Undo whatever enter() took, no matter how the body finished.
            self.exit()


class ConcurrentRateLimiter(RateLimiter):
    """Cap the number of in-flight calls with a semaphore (a semaphore-based 'queue').

    A call beyond ``max_concurrent`` waits up to ``concurrent_timeout`` for a
    slot to free up. If none does, it is rejected with ``TooManyRequestsError``.
    """

    def __init__(self, func, max_concurrent, concurrent_timeout=DEFAULT_CONCURRENT_TIMEOUT):
        super().__init__(func)
        # One semaphore permit per allowed concurrent call.
        self.concurrent_lock = Semaphore(value=max_concurrent)
        self.concurrent_timeout = concurrent_timeout

    def enter(self):
        """Take a slot, waiting at most ``concurrent_timeout``; reject on timeout."""
        # acquire() with a timeout reports success as a boolean instead of raising.
        acquired = self.concurrent_lock.acquire(timeout=self.concurrent_timeout)
        if acquired:
            return
        raise TooManyRequestsError(func=self.func)

    def exit(self):
        """Give back the slot taken in :meth:`enter`."""
        self.concurrent_lock.release()


class RPSRateLimiter(RateLimiter):
    """Limit ``func`` to at most ``rps`` calls per one-second window.

    This is a fixed-window counter, not a token bucket. The counter is reset
    when a call arrives at least one second after the current window started,
    and that call anchors the new window. Calls beyond ``rps`` within a window
    raise ``TooManyRequestsError``. ``exit()`` is not overridden: admitted calls
    are counted, not held, so nothing is released when they finish.
    """

    def __init__(self, func, rps):
        super().__init__(func)
        self.max_size = rps  # calls admitted per window
        self.cur_size = 0  # calls admitted so far in the current window
        # Use a monotonic clock: wall-clock time (time.time) can jump backwards
        # or forwards on NTP/manual adjustments, which would stall window resets
        # (rejecting everything) or reset windows early.
        self.last_check = time.monotonic()

    def enter(self):
        """Admit one call, or raise ``TooManyRequestsError`` if the window is full."""
        now = time.monotonic()
        if now - self.last_check >= 1:
            # The previous window has expired; start a fresh one at this call.
            self.cur_size = 0
            self.last_check = now

        # Kept as "+ 1 >" rather than ">=" so non-integer rps values behave as before.
        if self.cur_size + 1 > self.max_size:
            raise TooManyRequestsError(func=self.func)

        self.cur_size += 1
