import asyncio
import logging
from itertools import chain, islice
from typing import Awaitable, Callable, Iterable, Set

from mail.python.theatre.roles import Cron
from mail.python.theatre.stages.bucket_holder.interactions.db.buckets import BucketsQueries
from mail.python.theatre.stages.bucket_holder.settings.rebalancer import RebalancerSettings
from mail.python.theatre.stages.bucket_holder.typing import BucketId

from .bucket_holder import BucketHolder
from .buckets_processor import BucketsProcessor

log = logging.getLogger(__name__)


class Rebalancer(Cron):
    """Tracks the count of running workers and adjusts processed and acquired buckets accordingly.

    On every cron tick :meth:`rebalance` computes a per-worker bucket quota,
    stops processing buckets that are no longer acknowledged or that exceed
    the quota, and starts processing acknowledged buckets that fit into it.
    """

    def __init__(self,
                 queries: BucketsQueries,
                 holder: BucketHolder,
                 processor: BucketsProcessor,
                 settings: RebalancerSettings):
        self._settings = settings
        self._queries = queries
        self._holder = holder
        self._processor = processor
        # Register ``rebalance`` as the cron job, scheduled per the cron settings.
        super().__init__(job=self.rebalance, **settings.cron.as_dict())

    @property
    def ack_buckets(self) -> Set[BucketId]:
        """Fresh set of the bucket ids the holder currently reports as acknowledged."""
        return set(self._holder.ack_bucket_ids)

    @property
    def proc_buckets(self) -> Set[BucketId]:
        """Fresh set of the bucket ids present in the processor's ``processors`` mapping."""
        return set(self._processor.processors.keys())

    @staticmethod
    async def _run_for_each(action: str,
                            bucket_ids: Iterable[BucketId],
                            make_coro: Callable[[BucketId], Awaitable[object]]) -> None:
        """Await ``make_coro(bucket_id)`` concurrently for every id and log each failure.

        Failures are logged per bucket rather than raised, so a single failing
        bucket does not abort the remaining rebalance steps. (Passing bare
        coroutines to ``asyncio.wait`` is deprecated and rejected on 3.11+,
        and it would also hide these exceptions.)

        :param action: verb used in the error message, e.g. ``'detach'``.
        :param bucket_ids: buckets to act on.
        :param make_coro: factory producing the awaitable for one bucket.
        """
        ids = list(bucket_ids)
        results = await asyncio.gather(
            *(make_coro(bucket_id) for bucket_id in ids),
            return_exceptions=True,
        )
        for bucket_id, result in zip(ids, results):
            if isinstance(result, BaseException):
                log.error('Failed to %s bucket %s', action, bucket_id, exc_info=result)

    async def rebalance(self):
        """Recompute this worker's bucket quota and converge processing towards it.

        1. Publish the target bucket count to the holder.
        2. Detach buckets that are processed but not acknowledged (``wait=False``).
        3. Detach acknowledged buckets that did not make the target set (``wait=True``).
        4. Attach target buckets that are not processed yet.

        The holder's ``proc_bucket_ids`` is refreshed after the detaches and
        again after the attaches.
        """
        worker_count = await self._queries.active_worker_count(deadline_secs=self._holder.deadline_secs)
        bucket_count = await self._queries.bucket_count()
        # With W = worker_count + 1, floor(B / W) + 1 >= ceil(B / W), so W workers
        # each taking this many buckets can cover all of them.
        # NOTE(review): assumes active_worker_count() excludes this worker — confirm.
        target_buckets_cnt = bucket_count // (worker_count + 1) + 1
        self._holder.target_bucket_cnt = target_buckets_cnt

        # No await happens during selection, so one snapshot of each set is consistent.
        ack_buckets = self.ack_buckets
        proc_buckets = self.proc_buckets
        log.debug('Curr bucket count: %s, target bucket count: %s', len(proc_buckets), target_buckets_cnt)

        target_buckets = set(islice(
            chain(
                # Prefer buckets that are both acknowledged and already in process...
                ack_buckets & proc_buckets,
                # ...then buckets that are acknowledged but not in process at the moment
                ack_buckets - proc_buckets,
            ),
            target_buckets_cnt,
        ))

        non_ack_buckets = proc_buckets - ack_buckets
        if non_ack_buckets:
            log.info('Going to stop non-ack buckets asap: %s', non_ack_buckets)
            await self._run_for_each(
                'detach', non_ack_buckets,
                lambda bucket_id: self._processor.detach(bucket_id, wait=False),
            )
            log.info('Done')

        # Re-read processed buckets after the detaches; non-ack buckets are still
        # subtracted in case a wait=False detach has not removed them yet.
        non_target_buckets = self.proc_buckets - non_ack_buckets - target_buckets
        if non_target_buckets:
            log.info('Going to stop non-target buckets gracefully: %s', non_target_buckets)
            await self._run_for_each(
                'detach', non_target_buckets,
                lambda bucket_id: self._processor.detach(bucket_id, wait=True),
            )
            log.info('Done')

        # Publish the post-detach state before starting new buckets.
        self._holder.proc_bucket_ids = list(self.proc_buckets)

        buckets_to_start = target_buckets - self.proc_buckets
        if buckets_to_start:
            log.info('Going to start buckets %s', buckets_to_start)
            await self._run_for_each('attach', buckets_to_start, self._processor.attach)
        self._holder.proc_bucket_ids = list(self.proc_buckets)
