import asyncio
import logging
import collections
from typing import Set, Dict, Iterable, Any, AsyncIterator, List

from infra.yp_drp import dru
from infra.yp_drp.http import StatisticsServer
from infra.yp_drp.yp_client import YpClient
from infra.yp_drp.podutil import Resource


async def idle_iter(
    iterator: Iterable[Any],
    idle_period: int = 100,
) -> AsyncIterator[Any]:
    """Re-yield the items of a synchronous iterable, periodically yielding to the event loop.

    Items are counted from 1. Before yielding each item whose count is a
    multiple of ``idle_period``, ``asyncio.sleep(0)`` is awaited. This gives
    other tasks a chance to run while a long synchronous loop is consumed.

    :param iterator: any synchronous iterable; consumed lazily.
    :param idle_period: number of items between event-loop yields. A value
        of 0 raises ``ZeroDivisionError`` as soon as the first item is reached.
    """
    counter = 0
    for element in iterator:
        counter += 1
        if not counter % idle_period:
            await asyncio.sleep(0)
        yield element


class YpPoller:
    """Periodically read YP state and drive dynamic-resource reconciliation.

    Each iteration (see ``_iterate``) reads pod sets, dynamic resources and
    pods at a single YP timestamp. It hands them to
    ``dru.DynamicResourceUpdater`` and then commits whatever updates the
    updater scheduled.
    """

    # Pause, in seconds, between the end of one iteration and the start of the next.
    ITERATION_DELAY = 5.0

    def __init__(
        self,
        yp_client: YpClient,
        batch_size: int,
        username: str,
        statistics: StatisticsServer,
    ) -> None:
        """
        :param yp_client: client used for all YP reads; also passed to the updater.
        :param batch_size: batch size for YP selects and for the updater's bulk
            operations (``drop_old_resources`` / ``commit``).
        :param username: stored on the instance; not read anywhere in this class.
        :param statistics: metrics sink that receives ``push``/``set`` calls.
        """
        super().__init__()
        self.batch_size = batch_size
        self.yp_client = yp_client
        self.username = username
        self.log = logging.getLogger('yp-poller')
        self.statistics = statistics

    async def _poll_resources(
        self,
        resources: Dict[str, Resource],
        pod_sets: Dict[str, Set[str]],
        timestamp: int,
    ) -> None:
        """Load every dynamic resource visible at ``timestamp`` into ``resources``.

        ``resources`` must create entries on first access, as the
        ``defaultdict(Resource)`` built in ``_iterate`` does.

        A resource whose pod set is absent from ``pod_sets`` is marked expired.
        Otherwise its id is added to that pod set's entry in ``pod_sets``.
        """
        batch = self.yp_client.select_resources(
            timestamp=timestamp,
            batch_size=self.batch_size,
        )
        total = 0
        expired = 0
        async for resource in batch:
            r = resources[resource.meta.id]
            r.resource = resource
            r.expired = resource.meta.pod_set_id not in pod_sets
            r.managed = resource.deploy_engine == 'drp'  # TODO take managed value from config, for testing purposes
            total += 1
            if not r.expired:
                pod_sets[resource.meta.pod_set_id].add(resource.meta.id)
            else:
                expired += 1

        self.statistics.push('poll_resources', 1)
        self.statistics.set('polled_resources', total)
        self.statistics.set('expired_resources', expired)

    async def _poll_pods(
        self,
        resources: Dict[str, Resource],
        pod_sets: Dict[str, Set[str]],
        expired: List[Any],
        timestamp: int,
    ) -> List[Any]:
        """Load every pod visible at ``timestamp``.

        Each pod is appended to ``.pods`` of every resource registered under
        its pod set. A pod is appended once to ``expired`` if any entry of its
        ``spec_dynamic_resources`` names an id missing from ``resources``.

        :returns: all polled pods, in the order they were received.
        """
        pods = []
        batch = self.yp_client.select_pods(
            timestamp=timestamp,
            batch_size=self.batch_size,
        )
        async for pod in batch:
            for resource_id in pod_sets.get(pod.meta.pod_set_id, []):
                r = resources[resource_id]
                r.pods.append(pod)
            for resource in pod.spec_dynamic_resources:
                # ``in`` does not insert into a defaultdict, so this is a pure lookup.
                if resource.get('id') not in resources:
                    expired.append(pod)
                    break
            pods.append(pod)
        self.statistics.push('poll_pods', 1)
        self.statistics.set('polled_pods', len(pods))
        self.statistics.set('pods_with_expired', len(expired))
        return pods

    async def _poll_pod_sets(self, pod_sets: Dict[str, Set[str]], timestamp: int) -> None:
        """Register every pod set visible at ``timestamp`` with an empty resource-id set.

        The keys are the values yielded by ``list_pod_sets``. These are
        presumably pod set ids, which is how ``_poll_resources`` looks them up.
        """
        batch = self.yp_client.list_pod_sets(
            timestamp=timestamp,
            batch_size=self.batch_size,
        )
        async for pod_set in batch:
            pod_sets[pod_set] = set()

        self.statistics.push('poll_pod_sets', 1)
        self.statistics.set('polled_pod_sets', len(pod_sets))

    async def _iterate(self) -> None:
        """Run one poll-and-reconcile pass.

        1. Poll pod sets, resources and pods, all at one generated timestamp.
        2. Pass pods that reference nonexistent resources to
           ``drop_old_resources`` for cleanup.
        3. Sync per-pod statuses, aggregate them per resource, and sync the
           status of every managed resource.
        4. Update allocations of every managed resource. A failure is logged
           per resource and does not abort the pass.
        5. Publish the counts of scheduled updates and commit them.
        """
        self.log.info('Starting iteration')
        updater = dru.DynamicResourceUpdater(self.yp_client)

        d = collections.defaultdict(Resource)
        pod_sets = {}
        pods_with_expired = []
        # The same timestamp is used for every read, presumably so all three
        # selects observe one consistent YP snapshot.
        ts = await self.yp_client.generate_timestamp()
        self.log.info('Polling pod_sets')
        await self._poll_pod_sets(pod_sets, ts)
        self.log.info(f"Loaded {len(pod_sets)} pod_sets")
        self.log.info('Polling resources')
        await self._poll_resources(d, pod_sets, ts)
        self.log.info(f"Loaded {len(d)} resources")
        self.log.info('Polling pods')
        pods = await self._poll_pods(d, pod_sets, pods_with_expired, ts)
        self.log.info(f"Loaded {len(pods)} pods")

        self.log.info(f"Found {len(pods_with_expired)} pods with nonexistent resources, will cleanup")
        await updater.drop_old_resources(pods_with_expired, d, self.batch_size)

        self.log.debug(f'Updating resource statuses for {len(pods)} pods')
        async for pod in idle_iter(pods, idle_period=1000):
            await updater.sync_pod_status(pod)

        # Only resources whose deploy engine is 'drp' are reconciled. Snapshot
        # them once so both passes below see the same set. The loops await
        # between items, so iterating a list is also safer than iterating the
        # live ``d.items()`` view.
        managed = [(resource_id, r) for resource_id, r in d.items() if r.managed]

        self.log.debug(
            f'Synchronizing resource statuses for {len(managed)} managed resources from {len(pods)} pods'
        )
        new_statuses = {}
        async for pod in idle_iter(pods, idle_period=1000):
            await updater.count_resource_statuses(pod, new_statuses)

        # NOTE(review): unlike the allocation pass below, a failure here aborts
        # the whole iteration (including commit). Confirm this is intended.
        async for resource_id, r in idle_iter(managed):
            await updater.sync_resource_status(
                resource_id=resource_id,
                resource_status=r.resource.status,
                statuses=new_statuses.get(resource_id, {}),
                update_window=r.resource.spec.get('update_window', 0),
                pods_total=len(r.pods),
            )

        self.log.debug(f'Updating allocations for {len(managed)} managed resources')
        async for resource_id, r in idle_iter(managed):
            try:
                await updater.update_allocations(resource_id, r.resource.spec, r.pods, r.expired)
            except asyncio.CancelledError:
                # Before Python 3.8 CancelledError subclasses Exception; never swallow it.
                raise
            except Exception as e:
                self.log.exception('Could not process DR %r: %s', resource_id, e)

        self.statistics.set('resource_removals', len(updater.scheduled_resource_removals))
        self.statistics.set('resource_status_updates', len(updater.scheduled_resource_status_updates))
        self.statistics.set('pod_status_updates', len(updater.scheduled_pod_status_updates))
        self.statistics.set('pod_spec_updates', len(updater.scheduled_pod_spec_updates))  # can contain pod dups
        await updater.commit(self.batch_size)

        self.log.info('Iteration done')

    async def run(self) -> None:
        """Run iterations forever, sleeping ``ITERATION_DELAY`` seconds between them.

        A failed iteration is logged and counted in ``failed_iterations``, and
        the loop continues. A successful one is counted in
        ``successful_iterations``. Task cancellation always propagates.
        """
        while True:
            try:
                await self._iterate()
            except asyncio.CancelledError:
                # Before Python 3.8 CancelledError subclasses Exception; re-raise
                # so cancelling the poller task actually stops it.
                raise
            except Exception as e:
                self.log.exception('Iteration failed: %s', e)
                self.statistics.push('failed_iterations', 1)
            else:
                self.statistics.push('successful_iterations', 1)
            await asyncio.sleep(self.ITERATION_DELAY)
