import asyncio
import logging
from typing import Dict

from infra.deploy_queue_controller.lib.http import StatisticsServer
from infra.deploy_queue_controller.lib.yp_client import YpClient
from infra.deploy_queue_controller.lib.models import Stage, ReleaseRule, Release
from infra.deploy_queue_controller.lib.actions import Actions


class QueueController:
    """Periodically reconcile deploy-queue tickets against their release rules.

    Each iteration reads stages, release rules, tickets and releases from YP at
    a single generated timestamp. It then decides which tickets to cancel, which
    to mark as waiting and which to commit, and hands those decisions to
    ``Actions`` to perform.
    """

    # Seconds to sleep between iterations in ``run``.
    ITERATION_DELAY = 5.0

    def __init__(
        self,
        yp_client: YpClient,
        batch_size: int,
        username: str,
        statistics: StatisticsServer,
    ) -> None:
        """Store collaborators.

        Args:
            yp_client: client used for all selects and passed to ``Actions``.
            batch_size: batch size forwarded to every YP select / ``perform``.
            username: stored on the instance.
                NOTE(review): not read by this class; presumably used externally — confirm.
            statistics: sink for per-iteration counters (``set`` / ``push``).
        """
        self.batch_size = batch_size
        self.yp_client = yp_client
        self.username = username
        self.statistics = statistics
        self.log = logging.getLogger('queue-ctl')

    async def _poll_stages(self, timestamp: int, batch_size: int) -> Dict[str, Stage]:
        """Load the stages visible at ``timestamp``, keyed by ``stage.meta.id``."""
        self.log.info('Polling stages')

        batch = self.yp_client.select_stages(
            timestamp=timestamp,
            batch_size=batch_size,
        )
        stages = {stage.meta.id: stage async for stage in batch}

        self.statistics.set('poll_stages', len(stages))
        self.log.info("Loaded %d stages", len(stages))
        return stages

    async def _poll_releases(
        self,
        release_rules: Dict[str, ReleaseRule],
        timestamp: int,
        batch_size: int
    ) -> Dict[str, Release]:
        """Load the releases referenced by target tickets of unblocked rules.

        A rule contributes a release id only when ``updates_blocked()`` is false
        and its target ticket has a non-empty ``spec.release_id``. Nothing stops
        several rules from pointing at the same release, so ids are de-duplicated
        (first-seen order kept) before querying YP.

        Returns:
            Mapping of ``release.meta.id`` to the loaded release.
        """
        self.log.info('Polling releases')
        release_ids = list(dict.fromkeys(
            rule.target_ticket.spec.release_id
            for rule in release_rules.values()
            if not rule.updates_blocked() and rule.target_ticket and rule.target_ticket.spec.release_id
        ))
        batch = self.yp_client.select_releases(
            release_ids=release_ids,
            timestamp=timestamp,
            batch_size=batch_size,
        )
        releases = {release.meta.id: release async for release in batch}

        self.statistics.set('poll_releases', len(releases))
        self.log.info("Loaded %d releases", len(releases))
        return releases

    async def _poll_rules(self, timestamp: int, batch_size: int) -> Dict[str, ReleaseRule]:
        """Load release rules visible at ``timestamp``.

        ``select_release_rules`` yields ``(rule_id, rule)`` pairs, which become
        the returned mapping.
        """
        self.log.info('Polling release rules')

        batch = self.yp_client.select_release_rules(
            timestamp=timestamp,
            batch_size=batch_size,
        )
        rules = {rule_id: rule async for rule_id, rule in batch}

        self.statistics.set('poll_rules', len(rules))
        self.log.info("Loaded %d release rules", len(rules))
        return rules

    def _populate_releases(self, release_rules: Dict[str, ReleaseRule], releases: Dict[str, Release]) -> None:
        """Attach loaded releases to the target tickets of unblocked rules.

        If a rule's target ticket references a release that was not loaded, the
        rule's ``target_ticket`` is reset to ``None``. ``_schedule_tickets`` then
        skips that rule.
        """
        for rule_id, rule in release_rules.items():
            if rule.updates_blocked() or rule.target_ticket is None:
                continue

            release_id = rule.target_ticket.spec.release_id
            if release_id not in releases:
                self.log.debug(
                    "skipping release rule %r: no matching release %r for target ticket %r found",
                    rule_id,
                    release_id,
                    rule.target_ticket.meta.id,
                )
                rule.target_ticket = None
            else:
                rule.target_ticket.release = releases[release_id]

    async def _poll_tickets(
        self,
        timestamp: int,
        batch_size: int,
        stages: Dict[str, Stage],
        release_rules: Dict[str, ReleaseRule],
    ) -> None:
        """Load tickets and register each one with its release rule.

        A ticket is kept only if both its stage (``meta.stage_id``) and its rule
        (``spec.release_rule_id``) were loaded. A kept ticket gets ``ticket.stage``
        set and is passed to ``ReleaseRule.add_ticket``; ``release_rules`` is
        mutated in place.
        """
        self.log.info('Polling tickets')

        tickets = 0

        batch = self.yp_client.select_tickets(
            timestamp=timestamp,
            batch_size=batch_size,
        )
        async for ticket in batch:
            stage_id = ticket.meta.stage_id
            rule_id = ticket.spec.release_rule_id

            if stage_id not in stages:
                self.log.debug("skipped ticket %r: no matching stage %r found", ticket.meta.id, stage_id)
                continue
            if rule_id not in release_rules:
                # Intentionally not logged (the debug line was disabled upstream,
                # presumably because it is too noisy).
                continue

            ticket.stage = stages[stage_id]
            release_rules[rule_id].add_ticket(ticket)
            tickets += 1

        self.statistics.set('poll_tickets', tickets)
        self.log.info("Loaded %d tickets", tickets)

    def _cancel_outdated_tickets(self, actions: Actions, rules: Dict[str, ReleaseRule]) -> None:
        """Queue cancellation of each unblocked rule's ``outdated_tickets``.

        Rules whose target ticket is on hold are left alone.
        """
        self.log.info("Canceling superseded tickets")

        for rule in rules.values():
            if rule.updates_blocked():
                continue
            if rule.target_ticket and rule.target_ticket.is_on_hold():
                continue
            actions.cancel_tickets(rule.outdated_tickets)

        self.statistics.push('cancel_tickets', len(actions.cancelled_tickets))
        self.log.info("%d tickets has been superseded and thus cancelled", len(actions.cancelled_tickets))

    def _mark_waiting_tickets(self, actions: Actions, rules: Dict[str, ReleaseRule]) -> None:
        """Queue every rule's ``waiting_tickets`` as waiting for its target ticket."""
        self.log.info("Marking tickets as waiting")

        for rule in rules.values():
            actions.wait_for_tickets(rule.waiting_tickets, rule.target_ticket)

        self.statistics.push('wait_tickets', len(actions.waiting_tickets))
        self.log.info("%d are set as waiting", len(actions.waiting_tickets))

    def _schedule_tickets(
        self,
        actions: Actions,
        rules: Dict[str, ReleaseRule],
    ) -> None:
        """Queue a commit of each rule's target ticket when it is eligible.

        A rule is skipped when any of these hold:

        - it is blocked or has no target ticket;
        - the ticket is in progress and the rule does not allow termination;
        - the ticket is on hold;
        - the ticket is already committed.
        """
        self.log.info("Scheduling tickets")

        for rule_id, rule in rules.items():
            if rule.updates_blocked() or not rule.target_ticket:
                self.log.debug("Rule %r schedule is blocked: skipping", rule_id)
                continue

            ticket = rule.target_ticket
            stage = ticket.stage
            if ticket.is_in_progress() and not rule.termination_allowed():
                # TODO check if the rule can be nevertheless safely applied to stage
                self.log.debug("Stage %r is in progress, rule %r is ignored", stage.meta.id, rule_id)
                continue

            if ticket.is_on_hold():
                self.log.debug("Ticket %r is on hold", ticket.meta.id)
                continue

            if ticket.is_committed():
                self.log.debug("Ticket %r is already committed", ticket.meta.id)
                continue

            self.log.debug("Committing ticket %r to stage %r", ticket.meta.id, stage.meta.id)
            actions.commit_ticket(ticket)

        self.statistics.push('commit_tickets', len(actions.committed_tickets))
        self.log.info("%d are scheduled for commit", len(actions.committed_tickets))

    async def iterate(self) -> None:
        """Run one full reconcile pass.

        Any ``Exception`` raised during the pass is logged and counted in
        ``failed_iterations``. This includes a failure to generate the timestamp,
        which previously escaped uncounted. A successful pass pushes
        ``successful_iterations``. Task cancellation is always re-raised.
        """
        self.log.info('Starting iteration')

        try:
            # Every select below reads at this same timestamp.
            ts = await self.yp_client.generate_timestamp()

            # TODO use watch after initial poll
            stages = await self._poll_stages(ts, self.batch_size)
            release_rules = await self._poll_rules(ts, self.batch_size)
            await self._poll_tickets(
                timestamp=ts,
                batch_size=self.batch_size,
                stages=stages,
                release_rules=release_rules,
            )
            # Must run after _poll_tickets: the release ids come from the rules'
            # target tickets.
            releases = await self._poll_releases(release_rules, ts, self.batch_size)
            self._populate_releases(release_rules, releases)

            actions = Actions(self.yp_client)

            self._cancel_outdated_tickets(actions, release_rules)
            self._mark_waiting_tickets(actions, release_rules)
            self._schedule_tickets(actions, release_rules)

            self.log.debug('Committing new tickets to stages')

            await actions.perform(ts, self.batch_size)
            self.statistics.push("successful_iterations", 1)
        except asyncio.CancelledError:
            # On Python < 3.8 CancelledError subclasses Exception; never swallow it.
            raise
        except Exception as e:
            self.log.exception("iteration failed: %s", e)
            self.statistics.push("failed_iterations", 1)

        self.log.info('Iteration done')

    async def run(self) -> None:
        """Call ``iterate`` forever, sleeping ``ITERATION_DELAY`` seconds between passes.

        Cancelling the task stops the loop.
        """
        while True:
            try:
                await self.iterate()
            except asyncio.CancelledError:
                raise
            except Exception as e:
                # Safety net: iterate() handles its own errors, but a failure inside
                # its error path (e.g. pushing statistics) must not stop the loop.
                self.log.exception('Iteration failed: %s', e)

            await asyncio.sleep(self.ITERATION_DELAY)
