import logging
from asyncio import QueueFull
from bisect import insort_right
from datetime import datetime, timedelta
from itertools import dropwhile
from time import time
from typing import List, Callable, Any

from mail.python.theatre.profiling.typing import Metrics
from mail.python.theatre.utils.nested_job import nested_job
from .timer import Timer

# Module-level logger; used below to report exceptions raised by the delayed-item consumer.
log = logging.getLogger(__name__)


class DelayedItem:
    """
    Object wrapper provides `a < b` semantics for `bisect` module.

    Instances are ordered by their planned time `at` only; the wrapped `item` never
    takes part in comparisons, so payloads that are not themselves orderable can be
    kept in a sorted list.
    """
    def __init__(self, item, unix_at: float):
        self.item = item  # payload handed over when the item becomes due
        self.at = unix_at  # planned time as a unix timestamp (seconds)

    def __lt__(self, other):
        # Duck-typed on `.at`. For objects without it, return NotImplemented so that
        # Python raises the standard TypeError instead of leaking an AttributeError.
        try:
            other_at = other.at
        except AttributeError:
            return NotImplemented
        return self.at < other_at

    def __repr__(self):
        return f'{type(self).__name__}(item={self.item!r}, at={self.at!r})'


class Delayed(Timer):
    """
    Role that plans delayed processing of events. Maintains an event queue sorted by planned time.
        When an event is ready to fire, `consumer` is called.
        `consumer` is a client function (not a coroutine) that accepts an item and returns a boolean.
        A falsy result, or an exception, means the item was not accepted: it stays at the head
        of the queue, together with every due item behind it, and is retried after `retry_delayed`.
    """
    def __init__(self, consumer: Callable[[Any], bool], max_delayed: int, retry_delayed: timedelta, **kwargs):
        # Pending items, kept sorted by `DelayedItem.at` (ascending).
        self._waiting: List[DelayedItem] = []
        self._max_delayed: int = max_delayed
        self._retry_delayed_secs: float = retry_delayed.total_seconds()
        super().__init__(job=nested_job(self._process_ready_items, consumer), **kwargs)

    def put_delayed(self, item, at: datetime) -> None:
        """
        Plan `item` to be handed to the consumer at `at`.

        :raises QueueFull: when `max_delayed` items are already waiting.
        """
        if len(self._waiting) >= self._max_delayed:
            raise QueueFull
        # insort_right keeps insertion (FIFO) order among items planned for the same moment.
        insort_right(self._waiting, DelayedItem(item, at.timestamp()))
        self.run_at(at)

    @staticmethod
    def _split_items(events: List[DelayedItem], unix_timepoint: float):
        """Split `events` into (due at or before `unix_timepoint`, the rest), preserving order."""
        ready, waiting = [], []
        for ev in events:
            if ev.at <= unix_timepoint:
                ready.append(ev)
            else:
                waiting.append(ev)
        return ready, waiting

    async def _process_ready_items(self, consumer):
        """
        Hand every due item to `consumer` in planned order, then reschedule the timer.

        Consumption stops at the first item the consumer rejects. That item and every due
        item after it stay queued and are retried after `retry_delayed`. If every due item
        was accepted, the timer is scheduled for the next planned item without the retry
        back-off, so that item is not delayed past its planned time.
        """
        def try_put(delayed: DelayedItem) -> bool:
            # An exception from the consumer counts as a rejection: the item will be retried.
            try:
                return consumer(delayed.item)
            except Exception:
                log.exception('delayed item consumer failed, the item will be retried')
                return False

        now = time()
        ready, waiting = self._split_items(self._waiting, unix_timepoint=now)
        rejected = list(dropwhile(try_put, ready))
        self._waiting = rejected + waiting
        if rejected:
            # The consumer refused a due item: back off before retrying.
            self.run_at_unixtime(now + self._retry_delayed_secs)
        elif waiting:
            # Everything due was accepted: wake up exactly when the next item is due.
            self.run_at_unixtime(waiting[0].at)

    def metrics(self) -> Metrics:
        """Extend the base timer metrics with the current number of waiting items."""
        return super().metrics() + [
            (f'{self._name}_qsize_ammx', len(self._waiting))
        ]
