import asyncio
import logging
import time

from typing import Dict, Optional, Set

from google.protobuf import text_format
from kazoo.client import NoNodeError
from saas.library.python.zk import ZKClient

from saas.tools.devops.lb_dc_checker.dc_checker.checkers.common import DataWriteException, OffsetCheckResult
from saas.tools.devops.lb_dc_checker.dc_checker.checkers.topic import TopicChecker
from saas.tools.devops.lb_dc_checker.dc_checker.config import AppConfig
from saas.tools.devops.lb_dc_checker.dc_checker.service import ServiceWrapper
from saas.tools.devops.lb_dc_checker.proto.result_pb2 import TDcCheckerResult


class ServiceChecker:
    """Periodically evaluate per-data-center availability for one service and publish it to ZooKeeper.

    Every iteration combines two signals for each data center listed in
    ``_SAAS_DATACENTERS``:

    * ``cluster_check`` - the last result reported by the injected cluster checker;
    * ``offset_check`` - the result of ``TopicChecker.check_topics`` for this service.

    The outcome is serialized as a ``TDcCheckerResult`` message in protobuf text format
    and written to the ZooKeeper node ``<zk_base_path>/<ns>/<service>/<ctype>``.
    """

    _SAAS_DATACENTERS = ('man', 'sas', 'vla')
    _CHECK_INTERVAL_SECONDS = 30

    def __init__(self, config: AppConfig, service: ServiceWrapper, cluster_checker, zk_client: ZKClient) -> None:
        """Remember the collaborators and precompute the ZooKeeper node path for the service.

        :param config: application config; only ``zk_base_path`` is read here.
        :param service: the service being checked (name, ctype and logbroker config are used).
        :param cluster_checker: object providing ``get_last_check_result(dc=..., endpoint=...)``.
        :param zk_client: client used to write the check result.
        """
        self._service: ServiceWrapper = service
        self._cluster_checker = cluster_checker

        self._zk_path: str = '{zk_path}/{ns}/{service}/{ctype}'.format(
            zk_path=config.zk_base_path,
            ns=self._service.lb_config.ns_name,
            service=self._service.name,
            ctype=self._service.ctype
        )
        self._zk_client: ZKClient = zk_client

    async def _write_zk_data(self, data: bytes, ensure_path: bool = False) -> None:
        """Write ``data`` to the service node in ZooKeeper.

        If the node does not exist, it is created and the write is retried exactly once;
        a second failure is reported instead of retrying without bound.

        :param data: serialized check result.
        :param ensure_path: create the node before the first write attempt.
        :raises DataWriteException: if the data could not be written for any reason
            other than task cancellation.
        """
        try:
            if not ensure_path:
                try:
                    await self._zk_client.set(self._zk_path, data)
                    return
                except NoNodeError:
                    # The node is missing: fall through to create it and write again.
                    pass

            await self._zk_client.ensure_path(self._zk_path)
            await self._zk_client.set(self._zk_path, data)
        except asyncio.CancelledError:
            # Cancellation must never be converted into a write error.
            raise
        except Exception as exc:
            logging.exception(
                'Unable to write data to zk for %s %s, data = %s',
                self._service.name,
                self._service.ctype,
                data
            )
            raise DataWriteException from exc

    # TODO(vbushev): identify the oldest failed dc?
    def _get_unavailable_dc(
        self,
        dc_to_availability_info: Dict[str, Dict[str, bool]],
        check_key: str
    ) -> Optional[str]:
        """Return a data center that failed the given check, if the result is trustworthy.

        :param dc_to_availability_info: mapping ``dc -> {check name -> passed}``; must contain
            every data center from ``_SAAS_DATACENTERS``.
        :param check_key: name of the check to inspect (e.g. ``'cluster_check'``).
        :return: the first failed data center in ``_SAAS_DATACENTERS`` order, or ``None`` when
            nothing failed or when every data center failed (such a result is ignored).
        """
        unavailable_dcs = [
            dc for dc in self._SAAS_DATACENTERS
            if not dc_to_availability_info[dc][check_key]
        ]

        if len(unavailable_dcs) == len(self._SAAS_DATACENTERS):
            logging.warning(
                'All data centers are unavailable for %s %s according to %s, ignoring this check result',
                self._service.name, self._service.ctype, check_key
            )
            return None

        if not unavailable_dcs:
            return None

        # Only a single data center is reported even if several of them failed.
        dc = unavailable_dcs[0]
        logging.info(
            'Data center %s is unavailable for %s %s according to %s',
            dc, self._service.name, self._service.ctype, check_key
        )
        return dc

    def _update_unavailable_dcs(
        self,
        dc_to_availability_info: Dict[str, Dict[str, bool]],
        check_key: str,
        unavailable_dcs: Set[str]
    ) -> None:
        """Add the data center failed according to ``check_key`` (if any) to ``unavailable_dcs`` in place."""
        dc = self._get_unavailable_dc(dc_to_availability_info, check_key)
        if not dc:
            return
        unavailable_dcs.add(dc)

    async def _save_dc_availability_info(self, dc_to_availability_info: Dict[str, Dict[str, bool]]) -> None:
        """Build the check result from the availability info and write it to ZooKeeper.

        The cluster check takes priority: the offset check is consulted only when the
        cluster check did not mark any data center as unavailable.

        :raises DataWriteException: if the result could not be written.
        """
        unavailable_dcs: Set[str] = set()
        self._update_unavailable_dcs(dc_to_availability_info, 'cluster_check', unavailable_dcs)

        if not unavailable_dcs:
            self._update_unavailable_dcs(dc_to_availability_info, 'offset_check', unavailable_dcs)

        # Keep the _SAAS_DATACENTERS order so that the serialized result does not depend
        # on set iteration order (which varies between processes for strings).
        result = TDcCheckerResult(
            AvailableDcs=[dc for dc in self._SAAS_DATACENTERS if dc not in unavailable_dcs],
            UnavailableDcs=[dc for dc in self._SAAS_DATACENTERS if dc in unavailable_dcs],
            UpdatedTs=int(time.time())
        )
        data = text_format.MessageToString(result).encode()
        logging.debug('Check result for %s %s = %s', self._service.name, self._service.ctype, data)

        await self._write_zk_data(data)

    async def _process_current_iteration(
        self,
        dc_to_offset_check_result: Dict[str, Optional[OffsetCheckResult]]
    ) -> None:
        """Run one round of checks for all data centers and publish the result.

        :param dc_to_offset_check_result: offset check results of the previous iteration keyed by
            data center; updated in place with the results of this iteration (``None`` for a data
            center whose offset check result is missing).
        :raises DataWriteException: if the result could not be written to ZooKeeper.
        """
        coro_arr = []
        for dc in self._SAAS_DATACENTERS:
            last_offset_check_result = dc_to_offset_check_result.get(dc)
            coro_arr.append(TopicChecker(self._service, dc).check_topics(last_check=last_offset_check_result))

        try:
            offset_check_results = await asyncio.gather(*coro_arr)
        except asyncio.CancelledError:
            # On Python < 3.8 CancelledError derives from Exception; without this clause the
            # handler below would swallow the cancellation of the whole checker task.
            raise
        except Exception:
            # A failure of any single offset check discards the offset results of all data centers.
            offset_check_results = [None] * len(self._SAAS_DATACENTERS)

            logging.exception(
                'Unable to make offset checks for service %s %s, only cluster checks will be considered',
                self._service.name,
                self._service.ctype
            )

        dc_to_availability_info = {}
        for dc, offset_check_result in zip(self._SAAS_DATACENTERS, offset_check_results):
            cluster_check_result = self._cluster_checker.get_last_check_result(
                dc=dc,
                endpoint=self._service.lb_config.logbroker_mirror
            )

            if offset_check_result:
                last_offset_check_result = dc_to_offset_check_result.get(dc)
                offset_check_passed = offset_check_result.is_passed(last_check=last_offset_check_result)
            else:
                # No offset data for this data center: do not treat it as a failure.
                offset_check_passed = True

            dc_to_availability_info[dc] = {
                'offset_check': offset_check_passed,
                # A missing cluster check result is treated as "available" as well.
                'cluster_check': cluster_check_result.available if cluster_check_result else True
            }

            dc_to_offset_check_result[dc] = offset_check_result

        await self._save_dc_availability_info(dc_to_availability_info)

    async def run(self) -> None:
        """Run the check loop forever, sleeping ``_CHECK_INTERVAL_SECONDS`` between iterations.

        Unexpected errors of a single iteration are logged and the loop continues.

        :raises DataWriteException: if a check result could not be written to ZooKeeper;
            this terminates the loop.
        """
        logging.info('Running service watcher for service %s %s', self._service.name, self._service.ctype)

        dc_to_check_result: Dict[str, Optional[OffsetCheckResult]] = {}
        while True:
            try:
                await self._process_current_iteration(dc_to_check_result)
            except DataWriteException:
                raise
            except asyncio.CancelledError:
                raise
            except Exception:
                logging.exception(
                    'An error occurred during the loop iteration of service %s %s, the result will be ignored',
                    self._service.name,
                    self._service.ctype
                )

            await asyncio.sleep(self._CHECK_INTERVAL_SECONDS)
