import enum
import random
import asyncio
import logging
from functools import partial

import ylock


# Unique marker objects: a bare ``object()`` is equal only to itself, so these
# must be compared with ``is``.  Presumably used to report why a wait ended
# (standoff timeout vs. lost connection) - TODO confirm against callers.
_STANDOFF_TIMEOUT_EXCEEDED, _DISCONNECTED = object(), object()


class RestartPolicy(enum.Enum):
    """What :class:`ExclusiveService` does after its service stops running."""

    # give up once the service finishes or raises
    DO_NOT_RESTART = 1
    # start the service again only if it raised an exception
    RESTART_ON_EXCEPTION = 2
    # start the service again whether it finished normally or raised
    ALWAYS_RESTART = 3


class RetryFailedError(Exception):
    """
    Signals that an operation kept failing until the retry budget
    (the maximum number of attempts) was exhausted.
    """


class RetrySleeper(object):
    """
    Bookkeeping for a retry loop: counts attempts and computes (and optionally
    sleeps) the pause before the next one.

    Each pause is the current base delay plus a random jitter in
    ``[0, max_jitter)``, capped at ``max_delay``; the base delay is then
    multiplied by ``backoff`` (also capped at ``max_delay``).
    """
    def __init__(self, max_tries=-1, delay=0.5, backoff=2, max_jitter=0.8,
                 max_delay=600):
        """
        Build a :class:`RetrySleeper`.

        :param max_tries: Number of attempts allowed before giving up;
                          -1 means no limit.
        :param delay: Base delay before the first retry, in seconds.
        :param backoff: Factor applied to the base delay after every attempt
                        (2 gives exponential backoff).
        :param max_jitter: Upper bound of the random extra wait added to each
                           pause, so that clients do not retry in lockstep.
        :param max_delay: Hard cap on any pause, in seconds (ten minutes by
                          default); stored as a float.
        """
        self.max_tries = max_tries
        self.delay = delay
        self.backoff = backoff
        self.max_jitter = max_jitter
        self.max_delay = float(max_delay)
        self._attempts, self._cur_delay = 0, delay

    @property
    def cur_delay(self):
        """Base delay (before jitter and capping) for the next attempt."""
        return self._cur_delay

    def reset(self):
        """
        Forget all previous attempts and go back to the initial delay.
        """
        self._attempts, self._cur_delay = 0, self.delay

    async def increment(self, exception=True):
        """
        Record a failed attempt and sleep for the computed pause.

        :param exception: when the attempt budget is exhausted, re-raise
                          :class:`RetryFailedError` if true, else return False.
        :return: True after sleeping; False if out of attempts and
                 ``exception`` is false.
        :rtype: bool
        """
        try:
            pause = self.get_next_time_to_sleep()
        except RetryFailedError:
            if not exception:
                return False
            raise
        await asyncio.sleep(pause)
        return True

    def get_next_time_to_sleep(self):
        """
        Record a failed attempt and return the pause before the next one,
        without sleeping.

        :raises RetryFailedError: if ``max_tries`` attempts were already made.
        :rtype: float
        """
        if self._attempts == self.max_tries:
            raise RetryFailedError("Too many retry attempts")
        self._attempts += 1
        jitter = random.random() * self.max_jitter
        base = self._cur_delay
        cap = self.max_delay
        self._cur_delay = min(base * self.backoff, cap)
        return min(base + jitter, cap)


class ExclusiveService:
    """
    Run a service in at most one place at a time, guarded by a ylock lock.

    :meth:`run` repeatedly contends for the lock called ``name``. While it
    holds the lock it runs the service until one of three things happens:
    the service finishes, :meth:`stop` is requested, or the standoff timeout
    elapses. The service is always stopped before the lock is released.
    """
    # Default strategies: stand off after 12 hours; acquire with timeout=None.
    DEFAULT_STANDOFF_STRATEGY = lambda: 12 * 3600
    DEFAULT_ACQUIRE_TIMEOUT_STRATEGY = lambda: None

    def __init__(
        self, cfg, name, runnable,
        acquire_timeout_strategy=DEFAULT_ACQUIRE_TIMEOUT_STRATEGY,
        standoff_strategy=DEFAULT_STANDOFF_STRATEGY,
        restart_policy=RestartPolicy.DO_NOT_RESTART
    ):
        """
        :param cfg: keyword arguments passed to ``ylock.create_manager``.
        :param name: lock name; also used in the logger name.
        :param runnable: a callable returning an awaitable, or an object whose
                         ``run`` attribute is such a callable.
        :param acquire_timeout_strategy: zero-argument callable returning the
            lock acquisition timeout in seconds. It is only passed to the lock
            when the lock backend is detected to support it.
        :param standoff_strategy: zero-argument callable returning how many
            seconds to keep leadership before standing off (None = never).
        :param restart_policy: :class:`RestartPolicy` applied when the service
            finishes or raises.
        """
        # Task wrapping the currently/last running service, or None.
        self._service = None
        if callable(runnable):
            self._run = runnable
        else:
            self._run = runnable.run

        self.name = name
        self.manager = ylock.create_manager(**cfg)
        self.log = logging.getLogger(f'exclusive.{name}')

        self._standoff_strategy = standoff_strategy
        self._acquire_timeout_strategy = acquire_timeout_strategy
        self._restart_policy = restart_policy

        self._stopped = asyncio.Event()
        self._stopping_lock = asyncio.Lock()

    async def acquire_lock(self, lock, kwargs) -> bool:
        """
        Try to acquire *lock* in a worker thread, giving up early on stop.

        :param lock: a lock obtained from ``self.manager.lock(...)``.
        :param kwargs: keyword arguments forwarded to ``lock.acquire``.
        :return: the result of ``lock.acquire``, or False if :meth:`stop` was
                 requested before the acquisition finished.
        """
        loop = asyncio.get_event_loop()
        acquire_fut = loop.run_in_executor(None, partial(lock.acquire, **kwargs))
        stop_fut = asyncio.ensure_future(self._stopped.wait())
        try:
            done, _ = await asyncio.wait(
                (acquire_fut, stop_fut), return_when=asyncio.FIRST_COMPLETED
            )
        finally:
            stop_fut.cancel()
            if not acquire_fut.done():
                # A thread blocked in lock.acquire() cannot be interrupted, and
                # cancelling the future would only discard its result. If the
                # lock is eventually taken, release it instead of holding it forever.
                acquire_fut.add_done_callback(partial(self._release_abandoned_lock, lock))
        if acquire_fut in done:
            return acquire_fut.result()
        return False

    def _release_abandoned_lock(self, lock, acquire_fut) -> None:
        """Done-callback: release *lock* if an abandoned ``acquire()`` succeeded."""
        if acquire_fut.cancelled() or acquire_fut.exception() is not None:
            return
        if acquire_fut.result():
            self.log.info("lock acquired after acquisition was abandoned - releasing it")
            asyncio.ensure_future(self._release_lock_quietly(lock))

    async def release_lock(self, lock) -> None:
        """Release *lock* in a worker thread so the event loop is not blocked."""
        await asyncio.get_event_loop().run_in_executor(None, lock.release)

    async def _release_lock_quietly(self, lock) -> None:
        """Release *lock*, logging (not raising) any error from the backend."""
        try:
            await self.release_lock(lock)
        except Exception:
            self.log.exception("failed to release lock")

    async def _sleep_unless_stopped(self, delay) -> None:
        """Sleep for *delay* seconds, waking up early if :meth:`stop` is called."""
        try:
            await asyncio.wait_for(self._stopped.wait(), timeout=delay)
        except asyncio.TimeoutError:
            pass

    async def _cancel_and_wait(self, service) -> None:
        """Cancel the *service* task and wait until it has actually finished."""
        service.cancel()
        # asyncio.wait() does not raise the task's exception or CancelledError
        await asyncio.wait((service,))
        if not service.cancelled() and service.exception() is not None:
            self.log.error("service raised an exception while stopping",
                           exc_info=service.exception())

    async def wait_task_or_stop(self, lock, task, timeout):
        """
        Run *task* while holding *lock*; *lock* is always released on return.

        Waits until one of three things happens: the task finishes,
        :meth:`stop` is requested, or *timeout* seconds elapse (None waits
        indefinitely). In the last two cases the task is cancelled and
        awaited before the lock is released, so the service no longer runs
        once another instance can take the lock.

        :return: ``_STANDOFF_TIMEOUT_EXCEEDED`` if the timeout elapsed,
                 otherwise None.
        :raises: whatever the task raised, if it finished with an exception.
        """
        service = asyncio.ensure_future(task)
        stop_fut = asyncio.ensure_future(self._stopped.wait())
        try:
            done, _ = await asyncio.wait(
                (service, stop_fut), timeout=timeout, return_when=asyncio.FIRST_COMPLETED
            )
            if service in done:
                service.result()  # re-raise the service's exception, if any
                return None
            if stop_fut in done:
                self.log.info("stop requested - stopping service")
                outcome = None
            else:
                # neither finished: asyncio.wait() returned because of the timeout
                self.log.info(f"was leading for too long (more than {timeout} seconds), "
                              "stand off - stopping service")
                outcome = _STANDOFF_TIMEOUT_EXCEEDED
            await self._cancel_and_wait(service)
            self.log.info("service stopped")
            return outcome
        finally:
            stop_fut.cancel()
            if not service.done():
                # only reachable when we ourselves are cancelled: don't leave
                # the service running unattended
                service.cancel()
            await self._release_lock_quietly(lock)

    async def run(self) -> None:
        """
        Contend for the lock and run the service while holding it, until stopped.

        After the service finishes or raises, the restart policy decides
        whether to contend again. A standoff (holding leadership longer than
        the standoff timeout) always leads to contending again, so that
        leadership can move to another instance.
        """
        self._stopped.clear()

        sleeper = RetrySleeper(max_delay=5)
        lock = self.manager.lock(name=self.name, block=True)
        # NOTE (torkve) zomg: detect whether acquire() accepts a timeout from the
        # qualified name of the backend's acquire implementation.
        timeout_supported = 'YT' in lock.acquire.__func__.__qualname__

        while not self._stopped.is_set():
            acquire_timeout = self._acquire_timeout_strategy()
            self.log.info(f"acquiring lock with timeout of {acquire_timeout} seconds...")

            args = dict(timeout=acquire_timeout) if timeout_supported else {}

            try:
                acquired = await self.acquire_lock(lock, args)
            except Exception as e:
                self.log.info(f'failed to acquire lock: {e}')
                await self._release_lock_quietly(lock)
                # sleeper has no attempt limit (max_tries=-1), so this never raises
                await self._sleep_unless_stopped(sleeper.get_next_time_to_sleep())
                continue

            if not acquired:
                self.log.info(f'failed to acquire lock within timeout of {acquire_timeout} seconds')
                continue

            # we locked, yey! The lock must now be released exactly once: by
            # wait_task_or_stop() once it takes over, otherwise by the finally below.
            lock_handed_off = False
            try:
                sleeper.reset()
                standoff_timeout = self._standoff_strategy()
                self.log.info("became singleton - "
                              f"starting service with standoff timeout of {standoff_timeout} seconds...")

                if self._stopped.is_set():
                    break
                self._service = asyncio.ensure_future(self._run())
                lock_handed_off = True
                outcome = await self.wait_task_or_stop(lock, self._service, standoff_timeout)
            except Exception:
                self.log.exception("service raised an exception")
                if self._restart_policy not in (RestartPolicy.ALWAYS_RESTART, RestartPolicy.RESTART_ON_EXCEPTION):
                    break
                self.log.info("restarting according to the restart policy...")
            else:
                if self._stopped.is_set():
                    break
                if outcome is _STANDOFF_TIMEOUT_EXCEEDED:
                    self.log.info("stood off - contending for the lock again...")
                else:
                    self.log.info("service finished")
                    if self._restart_policy != RestartPolicy.ALWAYS_RESTART:
                        break
                    self.log.info("restarting according to the restart policy...")
            finally:
                if not lock_handed_off:
                    await self._release_lock_quietly(lock)

            # sleep for sometime to let someone take leadership
            await self._sleep_unless_stopped(2)

    async def stop(self) -> None:
        """
        Request :meth:`run` to stop, then wait for the current service task
        (if any) to finish.

        The service's exception or cancellation is not propagated to the caller.
        """
        self.log.info("stopping exclusive service...")
        self._stopped.set()
        service = self._service
        if service is not None:
            # asyncio.wait() returns once the task is done without re-raising
            # its exception or CancelledError
            await asyncio.wait((service,))
        self.log.info("stopped exclusive service")
