import logging
from datetime import datetime
from typing import Iterable

import dateutil.tz
import sqlalchemy as sa
from asyncpg import Connection
from asyncpgsa import compile_query

from mail.callmeback.callmeback.stages.worker.props.callback_item import CallbackItem
from mail.callmeback.callmeback.stages.worker.props.current_events import CurrentEvents
from mail.callmeback.callmeback.stages.worker.settings.event_poller import EventPollerSettings
from mail.python.theatre.roles import Timer
from mail.python.theatre.app.roles.db_multihost_pool import DbMultihostPool
from .bucket_holder import BucketHolder
from .herald import CallbackHerald

# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)


class EventPoller(Timer):
    """
    Polls nearby events and sends them to `CallbackHerald`. The `wakeup_at` method lets callers request an early
    wakeup, forcing a poll at the given time.

    Polling keeps a cursor made of `_max_run_at` and `_max_event_id`: the largest `run_at` and the largest
    `event_id` among events handed to the herald by `_event_poller_job`. The two components are advanced
    independently, so a pending event with ``run_at <= _max_run_at`` and ``event_id <= _max_event_id`` matches
    neither `GET_PREV_Q` nor `GET_NEXT_Q`; `find_skipped` (via `FIND_SKIPPED_Q`) exists to pick such events up.
    """
    # Pending events at or behind the cursor on both axes -- invisible to GET_PREV_Q and GET_NEXT_Q.
    FIND_SKIPPED_Q = '''
        -- find_skipped_events
        SELECT *
          FROM reminders.events
         WHERE bucket_id = ANY( :bucket_ids ::bigint[] )
           AND status = 'pending'
           AND run_at <= :max_run_at
           AND event_id <= :max_event_id
        ORDER BY run_at, event_id
        LIMIT :event_limit
    '''
    # Pending events whose run_at is not after the cursor but whose event_id is beyond it
    # (presumably rows inserted after the cursor had already passed their run_at).
    GET_PREV_Q = '''
        -- fetch_prev_events
        SELECT *
          FROM reminders.events
         WHERE bucket_id = ANY( :bucket_ids ::bigint[] )
           AND status = 'pending'
           AND run_at <= :max_run_at
           AND event_id > :max_event_id
        ORDER BY event_id
        LIMIT :event_limit
    '''
    # Pending events strictly after the cursor in (run_at, event_id) order, restricted to those due
    # within :max_lookup_secs seconds of the database's now().
    GET_NEXT_Q = '''
        -- fetch_next_events
        SELECT *
          FROM reminders.events
         WHERE bucket_id = ANY( :bucket_ids ::bigint[] )
           AND status = 'pending'
           AND (run_at, event_id) > (:max_run_at, :max_event_id)
           AND run_at <= now() + make_interval(secs => :max_lookup_secs)
        ORDER BY run_at, event_id
        LIMIT :event_limit
    '''

    TZ = dateutil.tz.tzlocal()

    def __init__(
        self,
        pg_pool: DbMultihostPool,
        buckets_holder: BucketHolder,
        herald: CallbackHerald,
        current_events: CurrentEvents,
        settings: EventPollerSettings,
        **kwargs
    ):
        """
        :param pg_pool: database pool; every query runs on a connection from ``pg_pool(ro=True)``.
        :param buckets_holder: provides `bucket_ids`; every query is restricted to these buckets.
        :param herald: receives polled events through `put`.
        :param current_events: supplies the fetch limit (`available_events_cnt`), the ids used to filter
            skipped events (`total_ids`) and the id sets reported by `state`.
        :param settings: supplies `poll_max_delay`, `max_lookup_delta` and `skipped_items_limit`.
        :param kwargs: forwarded to `Timer.__init__`.
        """
        self._pg_pool = pg_pool
        self._buckets_holder = buckets_holder
        self._herald = herald
        self._current_events = current_events
        self._poll_max_delay = settings.poll_max_delay
        # Whole seconds; bound as :max_lookup_secs in GET_NEXT_Q.
        self._max_lookup_secs = int(settings.max_lookup_delta.total_seconds())
        self._skipped_items_limit = settings.skipped_items_limit

        # TODO :: faulty if bucket set changes. Need to separate flow for each bucket
        # Cursor starts at the epoch / id 0 so the first poll sees everything up to the lookup horizon.
        self._max_run_at = datetime.fromtimestamp(0, tz=self.TZ)
        self._max_event_id = 0

        super().__init__(job=self._event_poller_job, **kwargs)

    async def start(self):
        """Start the timer and schedule the first poll via ``run_at_unixtime(0)`` (presumably immediately)."""
        await super().start()
        self.run_at_unixtime(0)

    def state(self):
        """Return the cursor, bucket ids and `CurrentEvents` id sets, merged with ``Timer.state()``."""
        return {
            'max_run_at': self._max_run_at,
            'max_event_id': self._max_event_id,
            'buckets': self._buckets_holder.bucket_ids,
            'events': {
                'in_progress': self._current_events.in_progress_ids,
                'failed': self._current_events.failed_ids,
                'completed': self._current_events.done_ids,
                'rejected': self._current_events.rejected_ids,
                'available': self._current_events.available_events_cnt,
            },
            **super().state()
        }

    def wakeup_at(self, at: datetime):
        """Request a poll at `at`; delegates to ``self.run_at``."""
        self.run_at(at)

    async def find_skipped(self):
        """
        Hand events that regular polling cannot reach (see `FIND_SKIPPED_Q`) to the herald.

        At most `skipped_items_limit` rows are fetched per call. The cursor is not moved.
        """
        now = datetime.now(tz=self.TZ)
        for event in await self._get_skipped_events(limit=self._skipped_items_limit):
            self._herald.put(event, now=now)

    async def _get_skipped_events(self, limit: int) -> Iterable[CallbackItem]:
        """
        Fetch up to `limit` pending events at or behind the cursor, excluding ids already in
        ``CurrentEvents.total_ids``.

        :param limit: maximum number of rows to request; a non-positive value returns an empty list
            without querying (mirrors `_get_events_impl`).
        :return: list of `CallbackItem` in (run_at, event_id) order.
        """
        if limit <= 0:
            return []
        q, p = compile_query(
            sa.text(self.FIND_SKIPPED_Q).params(
                bucket_ids=self._buckets_holder.bucket_ids,
                max_event_id=self._max_event_id,
                max_run_at=self._max_run_at,
                event_limit=limit,
            )
        )
        async with self._pg_pool(ro=True).acquire(timeout=2) as conn:
            records = await conn.fetch(q, *p)
        # Snapshot the tracked ids once: building the set inside the comprehension's condition would
        # rebuild it for every record.
        known_ids = set(self._current_events.total_ids)
        return [CallbackItem(**rec) for rec in records if rec['event_id'] not in known_ids]

    async def _event_poller_job(self):
        """
        Timer job: fetch as many events as `CurrentEvents` has room for, hand them to the herald and advance
        the cursor.

        The next poll is scheduled at ``now + poll_max_delay`` (unless the timer is stopped) even when the
        fetch raises; the exception still propagates.
        """
        now = datetime.now(tz=self.TZ)
        log.info('_event_poller_job: now= %s', now)
        try:
            for event in await self._get_events(limit=self._current_events.available_events_cnt):
                self._herald.put(event, now=now)
                # Each cursor component only moves forward, independently of the other.
                self._max_run_at = max(self._max_run_at, event.run_at)
                self._max_event_id = max(self._max_event_id, event.event_id)
        finally:
            if not self._stopped:
                # Can sleep for too long in edge cases
                log.info('_event_poller_job, run_at: now=%s, poll_max_delay=%s', now, self._poll_max_delay)
                self.run_at(now + self._poll_max_delay)

    async def _get_events(self, limit):
        """
        Fetch up to `limit` events: first from `GET_PREV_Q`, then the remaining capacity from `GET_NEXT_Q`.

        Both queries run on one read-only connection inside a single transaction.

        :param limit: total number of events wanted across both queries.
        :return: list of `CallbackItem`, PREV results first.
        """
        log.info('_get_events: limit=%s', limit)
        async with self._pg_pool(ro=True).acquire(timeout=2) as conn:
            async with conn.transaction():
                result = await self._get_events_impl(conn, self.GET_PREV_Q, limit)
                limit -= len(result)
                result += await self._get_events_impl(conn, self.GET_NEXT_Q, limit,
                                                      max_lookup_secs=self._max_lookup_secs)
                log.info('_get_events, returns %s records', len(result))
                return result

    async def _get_events_impl(self, conn: Connection, q: str, limit: int, **kwargs):
        """
        Run one polling query bound to the current cursor and bucket set.

        :param conn: connection to run the query on.
        :param q: SQL text using the :max_event_id, :max_run_at, :bucket_ids and :event_limit placeholders.
        :param limit: maximum number of rows; a non-positive value returns an empty list without querying.
        :param kwargs: extra bind parameters required by `q` (e.g. ``max_lookup_secs`` for `GET_NEXT_Q`).
        :return: list of `CallbackItem` in the query's order.
        """
        if limit <= 0:
            return []
        q, p = compile_query(
            sa.text(q).params(
                max_event_id=self._max_event_id,
                max_run_at=self._max_run_at,
                bucket_ids=self._buckets_holder.bucket_ids,
                event_limit=limit,
                **kwargs,
            )
        )
        return [CallbackItem(**rec) for rec in await conn.fetch(q, *p, timeout=1)]
