import dataclasses
import logging
import random
from typing import Callable, Coroutine, List, Tuple

from asyncpg import UniqueViolationError

from mail.callmeback.callmeback.detail.coro.stateful_event import StatefulEvent
from mail.callmeback.callmeback.detail.errors import EntityAlreadyExists
from mail.callmeback.callmeback.stages.api.libretto.worker_notifier import WorkerNotifier
from mail.callmeback.callmeback.stages.api.libretto.worker_monitor import ActiveWorkerMonitor
from mail.callmeback.callmeback.stages.api.settings import EventBufferSettings
from mail.python.theatre.app.roles.db_multihost_pool import DbMultihostPool
from .notify_item import NotifyItem

log = logging.getLogger(__name__)


class EventBuffer:
    """
    Packs `NotifyItem`s from concurrent `add` calls into batches and stores each batch in the DB at once.

    Items accumulate in memory until the buffer reaches `settings.chunk_size` (the `add` call that fills it
    runs `process` itself) or until `process` is called from elsewhere. Every `add` caller stays blocked
    until the batch holding its item has been written, and then either returns or raises the error recorded
    for its own item. Items that were stored successfully are passed to `WorkerNotifier`.
    """

    # Timeout passed to the pool's `acquire` when taking a connection for an insert.
    _ACQUIRE_TIMEOUT = 2
    # Key of the transaction-scoped advisory lock taken before every insert issued by this class,
    # so those inserts run one at a time even across different connections.
    _INSERT_LOCK_KEY = 42

    def __init__(
            self,
            pool: DbMultihostPool,
            notifier: WorkerNotifier,
            worker_monitor: ActiveWorkerMonitor,
            settings: EventBufferSettings,
    ):
        self._pool = pool
        self._chunk_size = settings.chunk_size
        self._notifier = notifier
        self._worker_mon = worker_monitor

        # Current (not yet stored) batch and the event its `add` callers wait on.
        self._data: List[NotifyItem] = []
        self._callers_event = StatefulEvent()

    async def add(self, item: NotifyItem):
        """
        Buffer `item` and wait until the batch containing it has been stored.

        If `item.bucket_id` is unset it is assigned a random bucket taken from the worker monitor's
        `bucket2worker` keys (`random.choice` raises `IndexError` when there are none).

        :raises: the exception recorded for this particular item (e.g. `EntityAlreadyExists` on a
            duplicate). NOTE(review): presumably `StatefulEvent.wait` also re-raises the exception that
            aborted the whole batch (see `process`) — confirm.
        """
        if item.bucket_id is None:
            item.bucket_id = random.choice(self._active_buckets)
        self._data.append(item)
        # Capture the event of the batch this item joined: `process` swaps in a fresh one.
        batch_stored = self._callers_event
        if len(self._data) >= self._chunk_size:
            await self.process()
        failed_events = await batch_stored.wait()
        # Match by identity, not `==`: equal-but-distinct items from different callers can share a batch,
        # and only the caller whose own item failed must see the error.
        for failed_item, exc in failed_events:
            if failed_item is item:
                raise exc

    async def process(self):
        """
        Flush the current buffer: store it, wake the `add` callers waiting on it, notify about stored items.

        Does nothing when the buffer is empty. A failure of the insert itself is logged and handed to the
        waiting callers via `set_exception` instead of being raised from here.
        """
        if not self._data:
            return
        # Detach the current batch before awaiting, so `add` calls made meanwhile start a new one.
        data, callers_event = self._data, self._callers_event
        self._data, self._callers_event = [], StatefulEvent()
        try:
            success_events, failed_events = await self._insert(data)
            callers_event.set(failed_events)
        except Exception as e:
            log.exception('Failed to store a batch of %d events', len(data))
            callers_event.set_exception(e)
            return
        self._notifier.put_nowait(success_events)

    @property
    def _active_buckets(self):
        """Bucket ids currently present in the worker monitor's `bucket2worker` mapping."""
        return list(self._worker_mon.bucket2worker.keys())

    # TODO :: smart backoff
    async def _insert(
            self, data: List[NotifyItem]
    ) -> Tuple[List[NotifyItem], List[Tuple[NotifyItem, Exception]]]:
        """
        Copy `data` into `reminders.events`, retrying item by item if the batch hits a unique violation.

        :returns: `(stored_items, [(item, exc), ...])`; unique violations are reported as
            `EntityAlreadyExists`, other per-item errors are passed through unchanged.
        """
        # `dataclasses.astuple` emits values in field-definition order, matching this column list.
        columns = [field.name for field in dataclasses.fields(NotifyItem)]
        lock_query = f'select pg_advisory_xact_lock({self._INSERT_LOCK_KEY})'

        async with self._pool().acquire(timeout=self._ACQUIRE_TIMEOUT) as conn:
            async def serialized_insert(items: List[NotifyItem]):
                async with conn.transaction():
                    # The advisory lock is released when the transaction ends.
                    await conn.execute(lock_query)
                    return await conn.copy_records_to_table(
                        'events',
                        schema_name='reminders',
                        records=map(dataclasses.astuple, items),
                        columns=columns,
                    )

            succeeded, failed = await self.batch_op_split_on_fail(serialized_insert, items=data)
        return succeeded, [(item, self._translate_insert_error(item, exc)) for item, exc in failed]

    @staticmethod
    def _translate_insert_error(item: NotifyItem, exc: Exception) -> Exception:
        """Map a DB unique violation to the API-level `EntityAlreadyExists`; return other errors as-is."""
        if not isinstance(exc, UniqueViolationError):
            return exc
        error = EntityAlreadyExists('Event already exists', fail_status='already_exists', fail_data={
            'owner_client_id': item.owner_client_id,
            'group_key': item.group_key,
            'event_key': item.event_key,
        })
        # Keep the original DB error reachable for debugging.
        error.__cause__ = exc
        return error

    @staticmethod
    async def batch_op_split_on_fail(
            op: Callable[[List], Coroutine],
            items: List,
            exc_type=UniqueViolationError
    ) -> Tuple[List, List]:
        """
        Apply `op` to the whole batch; if that raises `exc_type`, retry `op` on every item separately.

        :returns: `(succeeded_items, [(item, exc), ...])`. During the per-item retry any `Exception` is
            recorded for its item; an exception other than `exc_type` from the batch attempt propagates.
        """
        try:
            await op(items)
            return items, []
        except exc_type:
            succeeded, failed = [], []
            for item in items:
                try:
                    await op([item])
                    succeeded.append(item)
                except Exception as e:
                    failed.append((item, e))
            return succeeded, failed
