import asyncio
import logging

from contextlib import suppress
from typing import List

from saas.library.python.deploy_manager_api import SaasService
from saas.library.python.logbroker import SaaSConfigurationManager
from saas.library.python.zk import ZKClient

from saas.tools.devops.lb_dc_checker.dc_checker.config import AppConfig
from saas.tools.devops.lb_dc_checker.dc_checker.checkers.service import ServiceChecker
from saas.tools.devops.lb_dc_checker.dc_checker.service import ServiceWrapper


class DatacenterChecker:
    """Periodically reload the service list and run one ServiceChecker per service.

    Every iteration of :meth:`run` loads the services whose config has a truthy
    ``logbroker_mirror`` value, starts a ``ServiceChecker`` task for each of
    them and lets the tasks run for at most ``_CACHE_UPDATE_SECONDS`` before
    cancelling them and starting over with a freshly loaded service list.
    """

    # Upper bound on one iteration; after it the service list is reloaded.
    _CACHE_UPDATE_SECONDS = 60*60*8
    # Delay before retrying when an iteration has nothing to run.
    _ERROR_SLEEP_SECONDS = 10

    def __init__(self, config: AppConfig, cluster_checker, zk_client: ZKClient) -> None:
        """Store the collaborators that are handed to every ServiceChecker."""
        self._config: AppConfig = config
        self._cluster_checker = cluster_checker
        self._zk_client: ZKClient = zk_client

    @staticmethod
    async def _load_services() -> List[ServiceWrapper]:
        """Collect a ServiceWrapper for every service that has logbroker mirroring.

        Services whose config has a falsy ``logbroker_mirror``, or for which
        ``SaasService(...).slots_by_interval`` is empty, are skipped (the latter
        with a warning).

        Returns:
            One wrapper per remaining service, carrying the interval ids and the
            largest per-interval count of slots whose ``is_sd`` flag is falsy.
        """
        manager = SaaSConfigurationManager()
        ns_names = await manager.get_namespace_names()
        loop = asyncio.get_running_loop()

        result = []
        for ns_name in ns_names:
            ns = manager.get_namespace(ns_name)

            async for service in ns.get_services():
                config = await service.get_config()
                if not config.logbroker_mirror:
                    continue

                # The lookup is run in the default executor rather than awaited.
                # `service` is bound as a default argument so the callable does
                # not depend on the loop variable.
                slots_by_interval = await loop.run_in_executor(
                    None,
                    lambda svc=service: SaasService(svc.ctype, svc.name).slots_by_interval
                )
                if not slots_by_interval:
                    logging.warning(
                        'Unable to get shard & slot info for service %s %s, skipping...',
                        service.name, service.ctype
                    )
                    continue

                shards = []
                replicas_cnt = 0
                for interval in slots_by_interval:
                    shards.append(interval['id'])
                    non_sd_slots = sum(1 for slot in interval['slots'] if not slot['is_sd'])
                    replicas_cnt = max(replicas_cnt, non_sd_slots)

                wrapper = ServiceWrapper(
                    service.name,
                    service.ctype,
                    shards,
                    replicas_cnt,
                    config,
                )
                logging.debug('Loaded service %s', wrapper)
                result.append(wrapper)

        return result

    async def run(self) -> None:
        """Run the check loop forever.

        Raises:
            Exception: whatever a ``ServiceChecker`` constructor or ``run()``
                raises; every checker task of the iteration is cancelled and
                awaited before the exception leaves this method.
        """
        while True:
            logging.info('Starting a new iteration of the dc checker')

            try:
                services = await self._load_services()
            except Exception:
                logging.exception('Unable to load services, retrying in %d seconds', self._ERROR_SLEEP_SECONDS)
                await asyncio.sleep(self._ERROR_SLEEP_SECONDS)
                continue

            logging.info('Loaded %d services', len(services))

            if not services:
                # gather() over zero tasks completes immediately, which would
                # make this loop reload the services back-to-back without pause.
                logging.warning('No services to check, retrying in %d seconds', self._ERROR_SLEEP_SECONDS)
                await asyncio.sleep(self._ERROR_SLEEP_SECONDS)
                continue

            tasks = []
            try:
                # Tasks are created inside the try block so that a failure
                # part-way through still cleans up the ones already started.
                for service in services:
                    service_checker = ServiceChecker(self._config, service, self._cluster_checker, self._zk_client)
                    tasks.append(asyncio.create_task(service_checker.run()))

                await asyncio.wait_for(asyncio.gather(*tasks), timeout=self._CACHE_UPDATE_SECONDS)
            except asyncio.TimeoutError:
                logging.debug('The task execution was interrupted by cache timeout = %ds', self._CACHE_UPDATE_SECONDS)
            finally:
                # Cancel the tasks themselves: when one checker fails, gather()
                # is already done and cancelling it would leave the remaining
                # checkers running in the background.
                for task in tasks:
                    task.cancel()
                if tasks:
                    # return_exceptions=True waits for every task to finish and
                    # keeps their CancelledError/exceptions from masking the
                    # exception (if any) that is propagating out of the try.
                    await asyncio.gather(*tasks, return_exceptions=True)
