import asyncio
import concurrent.futures
import logging
import time
from collections import defaultdict
from itertools import chain
from random import shuffle
from typing import Optional, Iterator

from infra.deploy_notifications_controller.lib.async_queue_with_statistic import AsyncQueueWithStatistic
from infra.deploy_notifications_controller.lib.http import StatisticsServer
from infra.deploy_notifications_controller.lib.infra_client import InfraClient
from infra.deploy_notifications_controller.lib.jns_client import JNSClient
from infra.deploy_notifications_controller.lib.models.action import QnotifierMessage, DummyQnotifierMessage, \
    InfraChange, Notification
from infra.deploy_notifications_controller.lib.models.notification_policy import JNSMessageAction
from infra.deploy_notifications_controller.lib.models.stage import Stage
from infra.deploy_notifications_controller.lib.models.url_formatter import UrlFormatter
from infra.deploy_notifications_controller.lib.paste_client import PasteClient
from infra.deploy_notifications_controller.lib.qnotifier_client import QnotifierClient
from infra.deploy_notifications_controller.lib.yp_client import YpClient


class Looper:
    """Run the deploy-notifications pipeline as a set of cooperating asyncio loops.

    ``run()`` starts, each as a separate task:

    * a timestamp loop that periodically generates a fresh YP timestamp;
    * a stage-polling loop that selects stages and notification policies from
      YP and schedules stages for history reads;
    * ``history_workers`` history loops that turn stage history events into
      qnotifier messages, Infra changes and notifications;
    * consumer loops that deliver those to qnotifier, Infra and JNS;
    * the ticket renewal loops of the qnotifier and Infra clients.

    The loops hand work to each other through ``AsyncQueueWithStatistic``
    queues and report progress through ``statistics``.
    """

    def __init__(
        self,
        yp_client: YpClient,
        infra_client: InfraClient,
        infra_attempts_delay: float,
        qnotifier_client: QnotifierClient,
        paste_client: PasteClient,
        batch_size: int,
        stage_read_period: int,
        timestamp_generate_period: int,
        qnotifier_send_attempts: int,
        qnotifier_aggregation_period: int,
        history_workers: int,
        history_read_delay: int,
        statistics: StatisticsServer,
        url_formatter: UrlFormatter,
        jns_client: JNSClient,
        jns_attempts_delay: float,
    ) -> None:
        """Store clients and tuning parameters and create the inter-loop queues.

        Args:
            yp_client: source of timestamps, stages, notification policies and
                stage history; also used to persist stage state.
            infra_client: client used to create and close Infra events.
            infra_attempts_delay: seconds to sleep after a failed Infra change
                (the change is put back into the queue).
            qnotifier_client: client used to post aggregated messages.
            paste_client: handed to history changes when they are processed.
            batch_size: batch size / limit used for YP selects.
            stage_read_period: seconds to sleep between stage polls.
            timestamp_generate_period: seconds to sleep after each successful
                timestamp generation.
            qnotifier_send_attempts: number of failed sends after which a
                qnotifier message is dropped.
            qnotifier_aggregation_period: seconds to accumulate qnotifier
                messages before aggregating and sending them.
            history_workers: number of history loops and size of the thread
                pool used for heavy history processing.
            history_read_delay: seconds subtracted from the current time to get
                the upper bound of history reads.
            statistics: sink for counters and gauges.
            url_formatter: handed to history changes when they are processed.
            jns_client: client used to send notifications to JNS channels.
            jns_attempts_delay: seconds to sleep after a failed JNS send
                (the notification is put back into the queue).
        """
        self.batch_size = batch_size
        self.stage_read_period = stage_read_period
        self.timestamp_generate_period = timestamp_generate_period
        self.qnotifier_send_attempts = qnotifier_send_attempts
        self.history_workers = history_workers
        self.heavy_task_pool = concurrent.futures.ThreadPoolExecutor(
            max_workers=history_workers,
            thread_name_prefix='history-worker',
        )
        self.yp_client = yp_client
        self.infra_client = infra_client
        self.infra_attempts_delay = infra_attempts_delay
        self.qnotifier_client = qnotifier_client
        self.qnotifier_aggregation_period = qnotifier_aggregation_period
        self.paste_client = paste_client
        self.url_formatter = url_formatter
        self.jns_client = jns_client
        self.jns_attempts_delay = jns_attempts_delay
        # We intentionally limit history reads to "now() - some time",
        # because YP may have some time drift, a leap second,
        # or a disk write lag, and events could appear
        # after we've already read more recent ones.
        self.history_read_delay = history_read_delay
        self.statistics = statistics

        self.log = logging.getLogger('looper')
        self.qnotifier_queue: AsyncQueueWithStatistic[QnotifierMessage] = AsyncQueueWithStatistic('qnotifier', statistics)
        self.infra_queue: AsyncQueueWithStatistic[InfraChange] = AsyncQueueWithStatistic('infra', statistics)
        self.stages_queue: AsyncQueueWithStatistic[Stage] = AsyncQueueWithStatistic('stages', statistics)
        self.notify_queue: AsyncQueueWithStatistic[Notification] = AsyncQueueWithStatistic('notification', statistics)
        # Tracked stages, keyed by uuid and by id respectively.
        self.stages = {}
        self.stages_by_id = {}
        # Notification policies from the latest poll, keyed by stage id.
        self.notification_policies_by_id = {}
        # Latest timestamp published by _timestamp_loop; None until the first one.
        self.last_timestamp = None

    async def _get_next_timestamp(self, log: logging.Logger, old_ts: Optional[int], old_is_ok=False) -> int:
        """Wait for a usable timestamp published by ``_timestamp_loop`` and return it.

        Checks ``self.last_timestamp`` once per second until it is set and,
        unless ``old_is_ok`` is true, differs from ``old_ts``.
        """
        # TODO since watches virtually ain't working, should we really have this
        # function and _timestamp_loop?
        while True:
            ts = self.last_timestamp
            if ts is None:
                sleep_time = 1.
                log.debug('timestamp is not ready, will sleep for %fs', sleep_time)
                await asyncio.sleep(sleep_time)
                continue

            if ts == old_ts and not old_is_ok:
                sleep_time = 1.
                await asyncio.sleep(sleep_time)
                continue

            return ts

    async def _timestamp_loop(self) -> None:
        """Periodically generate a YP timestamp and publish it in ``self.last_timestamp``.

        On failure the error is logged and counted, and generation is retried
        after 3 seconds.
        """
        log = self.log.getChild('ts')
        while True:
            try:
                log.debug('generating new timestamp')
                ts = await self.yp_client.generate_timestamp()
            except asyncio.CancelledError:
                break
            except Exception as e:
                log.exception('timestamp generation failed: %s', e)
                self.statistics.push('failed_timestamp', 1)
                await asyncio.sleep(3)
            else:
                self.last_timestamp = ts
                self.statistics.push('successful_timestamp', 1)
                log.info('generated new timestamp: %d', ts)
                await asyncio.sleep(self.timestamp_generate_period)

    async def _poll_stages_from_yp(
        self,
        log: logging.Logger,
        ts: int,
    ):
        """Select all stages from YP at timestamp ``ts``.

        Returns:
            A pair ``(stages_by_uuid, stages_by_id)`` of freshly read stages.
            Stages that already carry ``last_timestamp`` but are not tracked by
            this process yet are additionally re-read with their saved state
            and merged in via ``Stage.update_by``.
        """
        log.info('polling for new stages (current queue size = %d)', self.stages_queue.qsize())
        batch = self.yp_client.select_stages(
            timestamp=ts,
            batch_size=self.batch_size,
        )

        stages = {}
        stages_by_id = {}
        async for stage in batch:
            stages[stage.uuid] = stage
            stages_by_id[stage.id] = stage

        # We do not load state from annotations by default, since they are heavy and
        # we do not need them at most cases.
        # So we load it only if the state was stored earlier in some other process
        # and this one is reading this stage for the first time.
        stages_to_restore = [
            stage.id
            for stage in stages.values()
            if stage.last_timestamp
            and stage.uuid not in self.stages
        ]

        log.debug("will restore %d stages from saved state", len(stages_to_restore))
        async for stage in self.yp_client.select_stages(
            timestamp=ts,
            batch_size=self.batch_size,
            state_from_annotation=True,
            stage_ids=stages_to_restore,
        ):
            # NOTE(review): raises KeyError (failing the whole poll) if a stage
            # was re-created with a new uuid between the two selects.
            stages[stage.uuid].update_by(stage)

        return stages, stages_by_id

    async def update_stages_last_timestamp_label(self, log, stages, ts):
        """Initialise ``last_timestamp`` on stages that lack it and persist it to YP.

        ``ts >> 30`` is taken as seconds and converted to nanoseconds, the same
        unit ``_get_history_events`` uses for its ``to_timestamp``.
        """
        stages_without_last_timestamp_label = [
            stage
            for stage in stages.values()
            if stage.last_timestamp is None
        ]

        # Exact integer arithmetic; no float rounding on the nanosecond value.
        current_timestamp = (ts >> 30) * 10 ** 9

        log.info("polled %r stages without last_timestamp label", len(stages_without_last_timestamp_label))
        self.statistics.set_absolute('polled_stages_without_last_timestamp', len(stages_without_last_timestamp_label))

        for stage in stages_without_last_timestamp_label:
            stage.last_timestamp = current_timestamp
            log.info(
                "[%r] requested last_timestamp label update %r",
                stage.id,
                current_timestamp,
            )

        await self.yp_client.save_stages_last_timestamp_label(
            stages=stages_without_last_timestamp_label,
            batch_size=self.batch_size,
        )

    async def shedule_removed_stages(self, log, stages):
        """Stop tracking stages missing from ``stages`` and queue them for a final history check."""
        removed_stages = [
            stage
            for stage in self.stages.values()
            if stage.uuid not in stages
        ]

        for stage in removed_stages:
            log.info("scheduled removed %r for history check", stage.id)
            del self.stages[stage.uuid]

        self.stages_queue.put_all_nowait(removed_stages)

    async def update_self_stages(self, stages):
        """Merge freshly polled ``stages`` (keyed by uuid) into the tracked set.

        Unknown uuids are tracked as-is. For known ones only the Infra
        service/environment are refreshed; the rest of the tracked object's
        state is kept.
        """
        for stage in stages.values():
            # it's either new stage, or stage was removed and created
            if stage.uuid not in self.stages:
                self.stages[stage.uuid] = stage
                self.stages_by_id[stage.id] = stage
            else:
                # we'll update infra as we are not tracking it in history
                dest_stage = self.stages[stage.uuid]
                dest_stage.infra_service = stage.infra_service
                dest_stage.infra_environment = stage.infra_environment

    async def shedule_new_stages(self, stages):
        """Queue, in random order, tracked stages from this poll not already waiting in ``stages_queue``."""
        # Check the tracked object (the one that is actually enqueued), not the
        # freshly polled copy, so already-queued known stages are not duplicated.
        new_stages = [
            self.stages[stage.uuid]
            for stage in stages.values()
            if self.stages[stage.uuid] not in self.stages_queue
        ]

        shuffle(new_stages)
        self.stages_queue.put_all_nowait(new_stages)

    async def update_stages_lag_statistics(self):
        """Publish max and sum of ``update_lag`` over tracked stages (max is 0 when none)."""
        lags = [stage.update_lag for stage in self.stages.values()]
        self.statistics.set_absolute('max_update_lag', max(lags, default=0))
        self.statistics.set_absolute('sum_update_lag', sum(lags))

    async def _poll_stages(
        self,
        log: logging.Logger,
        ts: int,
    ) -> int:
        """Run one polling round at timestamp ``ts``.

        Reads stages, initialises missing ``last_timestamp`` labels, retires
        removed stages, merges the rest into the tracked set, reloads
        notification policies and schedules stages for history reads.

        Returns:
            Number of stages read from YP in this round.
        """
        stages, _ = await self._poll_stages_from_yp(
            log=log,
            ts=ts,
        )

        await self.update_stages_last_timestamp_label(log, stages, ts)

        await self.shedule_removed_stages(log, stages)

        await self.update_self_stages(stages)

        notification_policies = self.yp_client.select_notification_policies(
            timestamp=ts,
            batch_size=self.batch_size,
        )

        self.notification_policies_by_id = {}
        async for np in notification_policies:
            self.notification_policies_by_id[np.stage_id] = np

        await self.shedule_new_stages(stages)

        await self.update_stages_lag_statistics()

        self.stages_by_id = {
            stage.id: stage
            for stage in self.stages.values()
        }

        return len(stages)

    async def _poll_stages_loop(self):
        """Poll stages on every new timestamp, sleeping ``stage_read_period`` between rounds."""
        log = self.log.getChild('stages')
        ts = None
        while True:
            try:
                ts = await self._get_next_timestamp(log, ts)
                polled = await self._poll_stages(log, ts)
            except asyncio.CancelledError:
                break
            except Exception as e:
                log.exception("poll failed: %s", e)
                self.statistics.push('failed_poll_stages', 1)
                self.statistics.set('polled_stages', 0)
            else:
                log.info("polled %d stages", polled)
                self.statistics.set('polled_stages', polled)
                self.statistics.set('total_stages', len(self.stages))
                self.statistics.set('total_stages_by_id', len(self.stages_by_id))
                self.statistics.push('successful_poll_stages', 1)

            await asyncio.sleep(self.stage_read_period)

    async def _watch_stages_loop(self):
        """Schedule history reads for tracked stages reported by YP watches.

        Marked broken (FIXME) and currently not started by ``run()``.
        """
        # FIXME broken
        log = self.log.getChild('watch')
        ts = await self._get_next_timestamp(log, None)
        while True:
            try:
                log.info("watching for stage changes")
                stage_ids = set()
                last_ts = ts
                no_results = True
                async for ev_ts, stage_id in self.yp_client.watch_stages(
                    from_timestamp=last_ts,
                    batch_size=self.batch_size,  # FIXME?
                ):
                    no_results = False
                    log.debug("[%s] modified at %d", stage_id, ev_ts)
                    if stage_id in self.stages_by_id:
                        stage_ids.add(stage_id)
                    last_ts = max(last_ts, ev_ts)

                if no_results:
                    ts = await self._get_next_timestamp(log, ts)

                for stage_id in stage_ids:
                    if stage_id in self.stages_by_id:
                        self.stages_queue.put_nowait(self.stages_by_id[stage_id])
                        log.info("[%s] scheduled update", stage_id)

                self.stages_queue.update_statistic()

            except asyncio.CancelledError:
                break
            except Exception as e:
                log.exception("watch failed: %s", e)
                await asyncio.sleep(0.5)

    async def _update_yp_state(
        self,
        stage_id: str,
        timestamp: int,
        state: Optional[dict] = None,
    ):
        """Persist the processed ``timestamp`` and optional ``state`` of a stage to YP."""
        await self.yp_client.save_stage_state(
            object_id=stage_id,
            timestamp=timestamp,
            state=state,
        )

    async def _get_history_events(
        self,
        log: logging.Logger,
        stage: Stage,
    ):
        """Read up to ``batch_size`` history events of ``stage`` and dispatch their results.

        Events are read from ``stage.last_timestamp`` up to
        ``now - history_read_delay`` (in nanoseconds). Each event is processed
        (in ``heavy_task_pool`` when ``change.is_heavy_task``), and the results
        are put into the qnotifier, Infra and notification queues.
        """
        to_timestamp = int((time.time() - self.history_read_delay) * 1e9)
        log.info("[%s] selecting history since %d till %d",
                 stage.id,
                 stage.last_timestamp,
                 to_timestamp,
                 )

        stage.max_timestamp = await self.yp_client.select_stage_last_timestamp(stage.id, stage.uuid)
        selected = 0
        async for change in self.yp_client.select_stage_history(
            # batch_size=self.batch_size,
            object_id=stage.id,
            object_uuid=stage.uuid,
            from_timestamp=stage.last_timestamp,
            to_timestamp=to_timestamp,
            limit=self.batch_size,
        ):
            selected += 1
            log.debug("[%s] got event %s", stage.id, change)

            input_data = change.InputData(
                stage,
                log=log,
                url_formatter=self.url_formatter,
                paste_client=self.paste_client,
                notification_policy=self.notification_policies_by_id.get(stage.id)
            )

            if change.is_heavy_task:
                # Offload to the thread pool so the event loop is not blocked.
                message, output_data = await asyncio.get_running_loop().run_in_executor(
                    self.heavy_task_pool,
                    change.process_changes,
                    input_data
                )
            else:
                message, output_data = change.process_changes(input_data)

            state = change.update_stage(input_data)

            if message is not None:
                message.state = state
                self.qnotifier_queue.put_nowait(message)
                log.info(
                    "[%s] queued qnotifier message %s from %s at %s",
                    stage.id,
                    message,
                    type(change).__name__,
                    change.str_time
                )
            else:
                log.debug("[%s] event has empty diff", stage.id)

            for infra in output_data.infra_changes:
                self.infra_queue.put_nowait(infra)
                log.debug(
                    "[%s] queued infra event %s for spec revision %d",
                    stage.id,
                    infra.event_kind,
                    infra.revision
                )

            for notification in output_data.notifications:
                self.notify_queue.put_nowait(notification)
                log.debug(
                    "[%s] queued notification action %s for spec revision %d",
                    notification.stage_id,
                    notification.event_kind,
                    notification.revision
                )

        log.info("[%s] selected %d events", stage.id, selected)
        self.statistics.push('history_events_read', selected)

    async def _get_history_events_loop(self):
        """Worker: take stages from ``stages_queue`` and process their history.

        ``run()`` starts ``history_workers`` copies of this loop.
        """
        log = self.log.getChild('history')
        while True:
            try:
                stage = await self.stages_queue.get()
            except asyncio.CancelledError:
                break

            try:
                await self._get_history_events(log, stage)
                # Handles an empty self.stages (e.g. right after the last stage
                # was removed); a bare max() would raise ValueError there and
                # turn a successful fetch into a reported failure.
                await self.update_stages_lag_statistics()
            except asyncio.CancelledError:
                break
            except Exception as e:
                log.exception('[%s] history fetch failed: %s', stage.id, e)
                self.statistics.push('failed_history_fetch', 1)
            else:
                log.debug('[%s] history request finished', stage.id)
                self.statistics.push('successful_history_fetch', 1)

    @staticmethod
    def log_infra_change(log: logging.Logger, change: InfraChange, event_id: int):
        """Log one Infra event operation for ``change``."""
        log.info(
            "[%s] %s event %r in service %r, environment %r",
            change.stage_id,
            change.event_kind.value,
            event_id,
            change.service_id,
            change.environment_id
        )

    async def _process_infra_create(
        self,
        log: logging.Logger,
        client: InfraClient,
        change: InfraChange,
        statistic_name: str,
    ):
        """Create an Infra event for ``change`` and count it under ``statistic_name``."""
        event_id = await client.create_event(
            service_id=change.service_id,
            env_id=change.environment_id,
            author=change.author,
            title=change.name,
            description=change.description,
            meta=change.create_meta(),
            start_time=change.start_time,
        )
        self.statistics.push(statistic_name, 1)
        self.log_infra_change(log, change, event_id)

    async def _process_infra_update(
        self,
        log: logging.Logger,
        client: InfraClient,
        change: InfraChange,
        statistic_name: str,
    ):
        """Close current STARTED Infra events of ``change`` with revision up to ``change.revision``.

        Each updated event is logged and counted under ``statistic_name``.
        """
        current_events = await client.get_current_events(
            env_id=change.environment_id,
            meta=change.create_meta(
                event_kind=InfraChange.EventKind.STARTED,
                with_revision=False,
            ),
            latest_revision=change.revision,
        )

        log.info(
            "[%s] started events with revision at most %r in service %r, environment %r: %s",
            change.stage_id,
            change.revision,
            change.service_id,
            change.environment_id,
            [event[0] for event in current_events],
        )

        for event_id, event_title, event_meta in current_events:
            await client.update_event(
                event_id,
                event_title=change.update_title(event_title),
                meta=event_meta,
                author=change.author,
                finish_time=change.finish_time,
            )

            self.log_infra_change(log, change, event_id)
            self.statistics.push(statistic_name, 1)

    # Dispatch table: InfraChange.EventChange -> unbound handler (called with self).
    change_processors = {
        InfraChange.EventChange.CREATE: _process_infra_create,
        InfraChange.EventChange.CLOSE: _process_infra_update,
    }

    async def _process_infra_queue_item(self, log: logging.Logger, client: InfraClient) -> None:
        """Take one change from ``infra_queue`` and apply it to Infra.

        On failure the change is put back into the queue and the method sleeps
        ``infra_attempts_delay`` seconds; cancellation is propagated.
        """
        change = await self.infra_queue.get()

        event_change = change.event_change
        statistic_prefix = 'infra_events_' + event_change.value

        change_processor = self.change_processors[event_change]

        try:
            await change_processor(self, log, client, change, statistic_prefix + 'd')
        except asyncio.CancelledError:
            # Never treat cancellation as a retryable failure (on Python < 3.8
            # CancelledError is an Exception subclass and would be swallowed).
            raise
        except Exception as e:
            self.statistics.push(statistic_prefix + '_errors', 1)
            log.exception("[%s] failed to %s event: %s", change.stage_id, event_change.value, e)
            self.infra_queue.put_nowait(change)
            await asyncio.sleep(self.infra_attempts_delay)

    async def _process_infra_queue_loop(self):
        """Apply Infra changes from ``infra_queue`` until cancelled."""
        log = self.log.getChild('infra')
        async with self.infra_client:
            while True:
                try:
                    await self._process_infra_queue_item(log, self.infra_client)
                except asyncio.CancelledError:
                    break

    async def _get_qnotifier_item(
        self,
        log: logging.Logger,
        has_pending_messages: bool,
        timeout: Optional[float] = None,
    ) -> Optional[QnotifierMessage]:
        """Fetch the next qnotifier message via ``AsyncQueueWithStatistic.get_with_pending``."""
        return await self.qnotifier_queue.get_with_pending(
            has_pending_items=has_pending_messages,
            timeout=timeout,
        )

    def _aggregate_qnotifier_items(
        self,
        log: logging.Logger,
        pending_messages: dict,
    ) -> Iterator[QnotifierMessage]:
        """Merge buffered messages into one message per stage.

        ``pending_messages`` maps stage id to a list of messages; each list is
        sorted by timestamp in place. Texts are joined with separators, tags,
        authors, change kinds and revisions are unioned, title and project id
        come from the latest non-dummy message, timestamp and state from the
        latest message, and attempts is the minimum over the group. If every
        message of a stage is a ``DummyQnotifierMessage``, a dummy is yielded.
        """
        for stage_id, message_list in pending_messages.items():
            message_list.sort(key=lambda msg: msg.timestamp)
            rev_non_dummy_messages = [
                message for message in reversed(message_list)
                if not isinstance(message, DummyQnotifierMessage)
            ]

            plain_text = f'\n---{QnotifierMessage.SEPARATOR}'.join(
                msg.plain_text for msg in message_list
                if msg.plain_text
            )
            html = f'\n<hr/>{QnotifierMessage.SEPARATOR}'.join(
                msg.html for msg in message_list
                if msg.html
            )
            title = rev_non_dummy_messages[0].title if rev_non_dummy_messages else None
            tags = list(set(chain.from_iterable(msg.tags for msg in message_list)))
            attempts = min(msg.attempts for msg in message_list)
            timestamp = message_list[-1].timestamp
            state = message_list[-1].state
            authors = set(chain.from_iterable(msg.authors for msg in message_list))
            change_kinds = set(chain.from_iterable(msg.change_kinds for msg in message_list))
            project_id = rev_non_dummy_messages[0].project_id if rev_non_dummy_messages else None
            revisions = set(chain.from_iterable(msg.revisions for msg in message_list))

            if not rev_non_dummy_messages:  # all messages are dummy changes:
                yield DummyQnotifierMessage(
                    stage_id=stage_id,
                    timestamp=timestamp,
                    attempts=attempts,
                    state=state,
                )
            else:
                log.debug(
                    "[%s] aggregated message: plain_text=%r, html=%r, title=%r, tags=%r",
                    stage_id,
                    plain_text,
                    html,
                    title,
                    tags,
                )

                yield QnotifierMessage(
                    stage_id=stage_id,
                    timestamp=timestamp,
                    title=title,
                    plain_text=plain_text,
                    html=html,
                    tags=tags,
                    attempts=attempts,
                    state=state,
                    authors=authors,
                    project_id=project_id,
                    change_kinds=change_kinds,
                    revisions=revisions,
                )

    async def _send_qnotifier_message(
        self,
        log: logging.Logger,
        message: QnotifierMessage,
    ) -> None:
        """Deliver one aggregated message and persist the stage's progress to YP.

        Dummy messages are only counted, not posted. On failure ``attempts`` is
        incremented and the message is re-queued until
        ``qnotifier_send_attempts`` is reached, after which it is dropped.
        Cancellation is propagated to the caller.
        """
        try:
            if isinstance(message, DummyQnotifierMessage):
                self.statistics.push('qnotifier_dummy_events_processed', 1)
            else:
                await self.qnotifier_client.post_event(
                    title=message.title,
                    plain_text=message.plain_text,
                    html=f'<html><body>{message.html_style}{message.html}</body></html>',
                    tags=message.tags,
                    headers=message.headers,
                )
                self.statistics.push('qnotifier_messages_sent', 1)
                log.info("[%s] sent message: %s", message.stage_id, message.tags)

            await self._update_yp_state(
                stage_id=message.stage_id,
                timestamp=message.timestamp,
                state=message.state,
            )
            log.info("[%s] updated last_timestamp: %s", message.stage_id, message.timestamp)
        except asyncio.CancelledError:
            # Swallowing cancellation here would keep the qnotifier loop
            # running after run() cancelled it.
            raise
        except Exception as e:
            message.attempts += 1
            self.statistics.push('qnotifier_messages_failed', 1)
            log.exception(
                "[%s] message send to %r failed (attempt %d/%d): %s",
                message.stage_id,
                message.tags,
                message.attempts,
                self.qnotifier_send_attempts,
                e
            )
            if message.attempts >= self.qnotifier_send_attempts:
                log.info("[%s] message is dropped", message.stage_id)
            else:
                self.qnotifier_queue.put_nowait(message)
                log.debug('[%s] message put back to queue', message.stage_id)

    async def _process_qnotifier_queue_loop(self):
        """Buffer qnotifier messages, then aggregate and send them in batches.

        The aggregation window starts when a message arrives into an empty
        buffer; once ``qnotifier_aggregation_period`` seconds have elapsed, all
        buffered messages are aggregated per stage and sent.
        """
        log = self.log.getChild('qnotifier')
        last_aggregation = time.time()
        pending_messages = defaultdict(list)

        while True:
            msg = await self._get_qnotifier_item(
                log,
                has_pending_messages=bool(pending_messages),
                timeout=self.qnotifier_aggregation_period - (time.time() - last_aggregation),
            )

            if not pending_messages:
                last_aggregation = time.time()

            if msg:
                log.info("[%s] got and queued message for aggregation", msg.stage_id)
                pending_messages[msg.stage_id].append(msg)

            if time.time() - last_aggregation >= self.qnotifier_aggregation_period:
                for message in self._aggregate_qnotifier_items(log, pending_messages):
                    await self._send_qnotifier_message(log, message)

                pending_messages.clear()

            if not pending_messages:
                last_aggregation = time.time()

    async def _send_jns_message(
        self,
        log: logging.Logger,
        notification: Notification,
        jns_action: JNSMessageAction,
    ) -> None:
        """Send ``notification`` to the JNS channel of ``jns_action``.

        On failure the notification is re-queued (with no attempt limit) and
        the method sleeps ``jns_attempts_delay`` seconds.
        """
        try:
            await self.jns_client.send_to_channel(
                target_project=jns_action.project,
                channel=jns_action.channel,
                body=notification.to_json()
            )
            self.statistics.push('jns_message_sent', 1)
            log.info('[%s] action done: [%s]', notification.stage_id, notification.event_kind)
        except asyncio.CancelledError:
            raise
        except Exception as e:
            self.statistics.push('jns_message_errors', 1)
            log.exception('[%s] failed to send jns message: %s', notification.stage_id, e)
            self.notify_queue.put_nowait(notification)
            await asyncio.sleep(self.jns_attempts_delay)

    async def _process_notify_actions_queue_loop(self):
        """Deliver notifications from ``notify_queue``; only JNS message actions are handled."""
        log = self.log.getChild('notify_action')
        async with self.jns_client:
            while True:
                try:
                    notification = await self.notify_queue.get()
                    if isinstance(notification.action, JNSMessageAction):
                        await self._send_jns_message(log, notification, notification.action)
                    # Ignore other types
                except asyncio.CancelledError:
                    break

    async def run(self):
        """Start all loops and wait until any of them raises (or all finish).

        Results/exceptions of finished tasks are logged; the remaining tasks
        are then cancelled.
        """
        loops = {
            asyncio.create_task(self._timestamp_loop()): 'timestamp',
            asyncio.create_task(self._poll_stages_loop()): 'poll_stages',
            # asyncio.create_task(self._watch_stages_loop()): 'watch_stages',
            asyncio.create_task(self._process_infra_queue_loop()): 'process_infra_queue',
            asyncio.create_task(self._process_qnotifier_queue_loop()): 'process_qnotifer_queue',
            asyncio.create_task(self._process_notify_actions_queue_loop()): 'process_notify_actions_queue',
            asyncio.create_task(self.qnotifier_client.ticket_renewal_loop()): 'qnotifier_ticket_renewal_loop',
            asyncio.create_task(self.infra_client.ticket_renewal_loop()): 'infra_ticket_renewal_loop',
        }
        for idx in range(self.history_workers):
            loops[asyncio.create_task(self._get_history_events_loop())] = f'get_history_events_{idx}'

        done, pending = await asyncio.wait(list(loops.keys()), return_when=asyncio.FIRST_EXCEPTION)
        for task in done:
            try:
                task_result = await task
                self.log.info("Task %s finished with %s", loops.get(task), task_result)
            except Exception as e:
                self.log.exception("Task %s failed with %s", loops.get(task), e)

        for task in pending:
            task.cancel()
            self.log.info("Task %s has been canceled", loops.get(task))

        self.log.info('looper stopped')

    async def stop(self):
        """Currently a no-op."""
        pass
