import asyncio
import functools
import logging
import math
import os
import time

from typing import Tuple, List, Iterable, Dict, Sequence

from saas.library.python.deploy_manager_api import SaasService
from saas.tools.devops.fban.config import FBanConfig, ServiceSettings
from saas.tools.devops.fban.fusion import FusionInstance
from saas.tools.devops.fban.utils import setup_logging
from saas.tools.ssm.modules.nanny_yp_api import SaaSNannyYpWorkflow


class ServiceChecker:
    """Periodically bans lagging and unbans caught-up instances of a single SaaS service.

    Each iteration resolves the service pods per shard via ``SaaSNannyYpWorkflow``, loads the
    state of every instance (``FusionInstance``), unbans out-of-search instances whose receive
    lag dropped below ``unban_seconds`` and bans the most lagging in-search instances whose lag
    is at least ``ban_seconds``, never exceeding the per-shard limit of out-of-search instances.
    """

    _SAAS_LOCATIONS = ['MAN', 'SAS', 'VLA']

    def __init__(self, config: FBanConfig, service_settings: ServiceSettings) -> None:
        self._config: FBanConfig = config
        self._service: SaasService = SaasService(service_settings.ctype, service_settings.service)
        self._service_settings: ServiceSettings = service_settings
        self._yp_client: SaaSNannyYpWorkflow = SaaSNannyYpWorkflow(self._service_settings.nanny_name)

    @staticmethod
    def _split_instance(instance: str) -> Tuple[str, int]:
        """Split a ``host:port`` string into its host and integer port.

        Only the last colon is used as the separator, so a host part that itself contains
        colons is kept intact.
        """
        host, port = instance.rsplit(':', 1)
        return host, int(port)

    def _dump_fusions(self, shard_to_fusions: Dict[str, List[FusionInstance]]) -> None:
        """Log every instance with its info, grouped by shard, at DEBUG level."""
        # The dump stringifies every instance; skip building it when DEBUG output is disabled.
        if not logging.getLogger().isEnabledFor(logging.DEBUG):
            return

        logging.debug(
            '[%s %s] Fusions:\n%s',
            self._service.name, self._service.ctype,
            '\n'.join([
                f'Shard {shard}:\n' + '\n'.join(map(lambda f: f'{f} {f.info}', fusions))
                for shard, fusions in shard_to_fusions.items()
            ])
        )

    def _print_action_results(
            self,
            action: str,
            changed: List[FusionInstance],
            total_processed: int
    ) -> None:
        """Log (WARNING) the instances an action was attempted on and how many succeeded.

        :param action: action verb, ``'ban'`` or ``'unban'`` (rendered as "...ned").
        :param changed: instances the action was attempted on; nothing is logged when empty.
        :param total_processed: number of successful attempts.
        """
        if changed:
            logging.warning(
                '[%s %s] These instances were candidates for %s:\n'
                '%s\n'
                'Successfully %sned: %d instances',
                self._service.name, self._service.ctype, action,
                '\n'.join([f'{fusion}: {fusion.info.action_result}' for fusion in changed]),
                action, total_processed
            )

    async def _unban_ready_instances(self, fusions: Iterable[FusionInstance]) -> int:
        """Return alive, out-of-search instances with a small enough receive lag to search.

        An instance qualifies when its state is ALIVE, search is disabled and its receive lag
        is known and below ``unban_seconds``. In dry-run mode the unban request is skipped but
        the instance is still marked as unbanned.

        :returns: the number of instances successfully unbanned.
        """
        changed = []
        total_processed = 0

        for fusion in fusions:
            if fusion.info.state == FusionInstance.State.ALIVE \
                    and not fusion.info.search_enabled \
                    and fusion.info.receive_lag is not None \
                    and fusion.info.receive_lag < self._service_settings.unban_seconds:

                logging.debug('+ UNBAN attempt: %s', fusion)

                if self._config.dry_run or await fusion.unban():
                    fusion.info.search_enabled = True
                    fusion.info.action_result = FusionInstance.ActionResult.UNBAN_OK

                    total_processed += 1
                else:
                    fusion.info.action_result = FusionInstance.ActionResult.UNBAN_FAIL

                changed.append(fusion)

        self._print_action_results('unban', changed, total_processed)
        return total_processed

    async def _ban_lagging_instances(self, fusions: Iterable[FusionInstance], remaining_slots_cnt: int) -> int:
        """Ban in-search instances, most lagging first, within the available slots.

        Instances with unknown receive lag are ignored. Banning stops at the first instance whose
        lag is below ``ban_seconds`` or once ``remaining_slots_cnt`` bans have succeeded; failed
        bans do not consume a slot. In dry-run mode the ban request is skipped but the instance
        is still marked as banned.

        :returns: the number of instances successfully banned.
        """
        sorted_fusions = sorted(
            filter(lambda f: f.info.search_enabled and f.info.receive_lag is not None, fusions),
            key=lambda f: f.info.receive_lag,
            reverse=True
        )

        changed = []
        total_processed = 0

        for fusion in sorted_fusions:
            if fusion.info.receive_lag < self._service_settings.ban_seconds or total_processed >= remaining_slots_cnt:
                break

            logging.debug('+ BAN attempt: %s', fusion)

            if self._config.dry_run or await fusion.ban():
                fusion.info.search_enabled = False
                fusion.info.action_result = FusionInstance.ActionResult.BAN_OK

                total_processed += 1
            else:
                fusion.info.action_result = FusionInstance.ActionResult.BAN_FAIL

            changed.append(fusion)

        self._print_action_results('ban', changed, total_processed)
        return total_processed

    @staticmethod
    def _get_banned_cnt(fusions: Sequence[FusionInstance]) -> int:
        """Count instances that are currently out of search (banned or otherwise disabled)."""
        return len([fusion for fusion in fusions if not fusion.info.search_enabled])

    def _get_max_banned_cnt(self, fusions: Sequence[FusionInstance]) -> int:
        """Return the max number of out-of-search instances allowed for one shard.

        When ``perc`` is set (non-zero) the limit is ``perc`` percent of the shard size rounded
        up; otherwise the absolute ``cnt`` setting is used.
        """
        if self._service_settings.perc:
            return int(math.ceil(len(fusions) * self._service_settings.perc / 100.0))
        return self._service_settings.cnt

    # NOTE(vbushev): advanced optimizations can be made here if we parallelize instance processing
    async def _make_decisions(self, shard_to_fusions: Dict[str, List[FusionInstance]], locations: List[str]) -> None:
        """Unban recovered instances in all shards, then ban lagging ones shard by shard.

        :param shard_to_fusions: loaded instances grouped by shard.
        :param locations: locations being processed (used in log messages only).
        """
        self._dump_fusions(shard_to_fusions)

        # Flat O(n) concatenation; also safe for an empty mapping (reduce() would raise).
        all_fusions = [fusion for fusions in shard_to_fusions.values() for fusion in fusions]
        await self._unban_ready_instances(all_fusions)

        for shard, shard_fusions in shard_to_fusions.items():
            banned_cnt = self._get_banned_cnt(shard_fusions)
            max_banned_cnt = self._get_max_banned_cnt(shard_fusions)

            if banned_cnt >= max_banned_cnt:
                logging.warning(
                    '[%s %s] Max limit for shard %s reached, banned/unavailable: %d, max to ban: %d [%s]',
                    self._service.name, self._service.ctype,
                    shard, banned_cnt, max_banned_cnt,
                    ','.join(locations)
                )
                continue

            remaining_slots = max_banned_cnt - banned_cnt
            curr_banned_cnt = await self._ban_lagging_instances(shard_fusions, remaining_slots)

            logging.info(
                '[%s %s] Total/current/remaining banned slots: %d/%d/%d in shard %s [%s]',
                self._service.name, self._service.ctype,
                max_banned_cnt, banned_cnt + curr_banned_cnt, remaining_slots - curr_banned_cnt, shard,
                ','.join(locations)
            )

        self._dump_fusions(shard_to_fusions)

    async def _process_instance(self, instance: Tuple[str, str]) -> Tuple[str, FusionInstance]:
        """Build a ``FusionInstance`` for a ``(shard, 'host:port')`` pair and load its info."""
        shard, address = instance
        host, port = self._split_instance(address)

        fusion = FusionInstance(self._service, self._service_settings, host, port)
        await fusion.load_fusion_info()

        return shard, fusion

    async def _process_locations(self, locations: List[str]) -> None:
        """Resolve the service pods in ``locations``, load them concurrently and apply decisions.

        :raises RuntimeError: if no instances are resolved.
        """
        logging.info(
            'Started checks for %s %s in %s',
            self._service.name, self._service.ctype, ','.join(locations)
        )

        # NOTE(vbushev):
        #  according to flame graphs we spend much time to read request data & parse json,
        #  we can use grpc API if needed
        instances = self._yp_client.get_pods_by_shards(locations)  # blocking call
        if not instances:
            raise RuntimeError(f'No instances resolved for {self._service.name} {self._service.ctype}')

        fusions = await asyncio.gather(*[self._process_instance(instance) for instance in instances])

        shard_to_fusions = {}
        for shard, fusion in fusions:
            shard_to_fusions.setdefault(shard, []).append(fusion)

        await self._make_decisions(shard_to_fusions, locations)

    async def _process_iteration(self) -> None:
        """Run one check pass, per location when the service searches per DC.

        In per-DC mode, every location is processed to completion even if some fail. Extra
        failures are logged, and the first one is re-raised to the caller.
        """
        logging.info('Check iteration for %s %s...', self._service.name, self._service.ctype)

        if self._service.per_dc_search:  # blocking call
            locations = self._SAAS_LOCATIONS
            # With a plain gather() the first failure would propagate immediately while the other
            # locations keep banning/unbanning in the background, overlapping the next iteration.
            results = await asyncio.gather(
                *[self._process_locations([location]) for location in locations],
                return_exceptions=True
            )
            failures = [
                (location, result)
                for location, result in zip(locations, results)
                if isinstance(result, BaseException)
            ]
            for location, error in failures[1:]:
                logging.error(
                    '[%s %s] Checks failed in %s',
                    self._service.name, self._service.ctype, location,
                    exc_info=error
                )
            if failures:
                raise failures[0][1]
        else:
            await self._process_locations(self._SAAS_LOCATIONS)

        logging.warning('Check iteration for %s %s... OK', self._service.name, self._service.ctype)

    async def run(self) -> None:
        """Run check iterations forever, starting one every ``check_interval`` seconds.

        Exceptions raised by an iteration are logged and do not stop the loop. An iteration
        that takes longer than the interval is followed immediately by the next one.
        """
        logging.debug('Starting fban loop for %s %s...', self._service.name, self._service.ctype)

        while True:
            # Monotonic clock: wall-clock adjustments must not distort the measured duration.
            start_ts = time.monotonic()

            try:
                await self._process_iteration()
            except Exception:
                logging.exception('Unable to process iteration for %s %s', self._service.name, self._service.ctype)

            diff = time.monotonic() - start_ts
            if diff > self._service_settings.check_interval:
                logging.warning(
                    '[%s %s] Attention! '
                    'FBan is unable to process the iteration within %d seconds, the execution took %d seconds',
                    self._service.name, self._service.ctype,
                    self._service_settings.check_interval, diff
                )
            else:
                logging.debug(
                    '[%s %s] The execution took %d / %d seconds',
                    self._service.name, self._service.ctype,
                    diff, self._service_settings.check_interval
                )

            sleep_time = max(0, self._service_settings.check_interval - diff)
            await asyncio.sleep(sleep_time)


async def _run_service_checks(config: FBanConfig, service_settings: ServiceSettings) -> None:
    """Run the endless check loop for one service and terminate the process on failure.

    Logging is configured before the checker is built, and construction happens inside the
    ``try``. A failure while creating ``ServiceChecker`` is therefore logged and handled the same
    way as a failure inside ``checker.run()``. Otherwise it would propagate out of this worker
    without going through the exit handler.

    :param config: global FBan configuration.
    :param service_settings: settings of the service to check.
    """
    setup_logging(config, service_settings=service_settings)

    logging.debug('[%s %s] Starting service checker...', service_settings.service, service_settings.ctype)

    try:
        checker = ServiceChecker(config, service_settings)
        await checker.run()
    except Exception:
        logging.exception(
            '[%s %s] An error occurred during checker.run, FBAN is stopped',
            service_settings.service, service_settings.ctype,
        )
        # os._exit() skips interpreter cleanup, so flush and close the log handlers explicitly.
        logging.shutdown()
        os._exit(1)  # causes BrokenProcessPool and stops the parent as well


def run_service_checks(args: Tuple[FBanConfig, ServiceSettings]) -> None:
    """Synchronous entry point: run the async checker for one service in a fresh event loop.

    :param args: a ``(config, service_settings)`` pair packed into a single argument
        (presumably so the function can be mapped over by a process pool — TODO confirm).
    """
    config, settings = args
    coroutine = _run_service_checks(config, settings)
    asyncio.run(coroutine)
