import logging

from typing import Tuple, Optional

from saas.library.python.logbroker import ReadInfoClient
from saas.library.python.logbroker.internal_api.client import LogbrokerTopicInfo

from saas.tools.devops.lb_dc_checker.dc_checker.checkers.common import (
    ShardOffsetCheckResult, OffsetCheckResult, LogbrokerFactory
)
from saas.tools.devops.lb_dc_checker.dc_checker.service import ServiceWrapper


class TopicChecker:
    """Collect Logbroker read info for every shard topic of a service in one DC.

    For each shard of the wrapped service two clusters are queried:

    * the origin cluster (``lb_config.logbroker``): topic
      ``{lb_config.topics_path}/shard-{shard}`` read by consumer
      ``{lb_config.consumers_path}/{dc}``;
    * the DC-specific mirror endpoint (built from ``dc`` and the
      ``lb_config.logbroker_mirror`` host/port): topic
      ``{lb_config.mirror_topics_path}/shard-{shard}`` read by the first consumer
      ``{lb_config.mirror_consumers_path}/{replica_id}`` that has a read session.

    The per-shard results are returned as ``ShardOffsetCheckResult`` objects
    wrapped in an ``OffsetCheckResult``.
    """

    def __init__(self, service: ServiceWrapper, dc: str) -> None:
        """Create read-info clients for the origin and the DC mirror endpoints.

        :param service: service whose shard topics are checked.
        :param dc: data center name used both for the origin consumer path and
            to build the mirror endpoint.
        """
        self._service: ServiceWrapper = service
        self._dc: str = dc

        endpoint = self._service.lb_config.logbroker
        self._client: ReadInfoClient = LogbrokerFactory.create_read_info_client(endpoint)

        mirror_endpoint = self._service.lb_config.logbroker_mirror
        dc_endpoint = LogbrokerFactory.create_dc_endpoint(dc, mirror_endpoint.host, mirror_endpoint.port)
        self._mirror_client: ReadInfoClient = LogbrokerFactory.create_read_info_client(dc_endpoint)

    async def _get_mirror_topic_info_by_consumer(
        self,
        topic_path: str,
        consumer_path: str
    ) -> Optional[LogbrokerTopicInfo]:
        """Return mirror read info for ``topic_path`` as seen by ``consumer_path``.

        :param topic_path: full path of the mirror topic.
        :param consumer_path: full path of the mirror consumer to query.
        :return: the first returned topic info if it reports an active read
            session; ``None`` if it does not, or if the response is empty.
        """
        result = await self._mirror_client.get_read_info(
            topic_paths=[topic_path],
            consumer_path=consumer_path,
            only_original=True
        )
        # Only one topic path is requested and only the first entry is inspected;
        # an empty response is treated like "no read session" rather than
        # crashing the whole check run with IndexError.
        if not result:
            return None
        topic_info = result[0]
        return topic_info if topic_info.has_read_session() else None

    async def _get_consumer_and_mirror_topic_info(self, topic_path: str) -> Optional[Tuple[str, LogbrokerTopicInfo]]:
        """Find the first replica consumer with an active read session on ``topic_path``.

        Consumers ``{mirror_consumers_path}/0 .. replicas_cnt - 1`` are probed in
        order, one request at a time.

        :param topic_path: full path of the mirror topic.
        :return: ``(consumer_path, topic_info)`` for the first consumer with a
            read session, or ``None`` (after logging) if none has one.
        """
        consumers_root = self._service.lb_config.mirror_consumers_path
        for consumer_id in range(self._service.replicas_cnt):
            consumer_path = f'{consumers_root}/{consumer_id}'
            topic_info = await self._get_mirror_topic_info_by_consumer(topic_path, consumer_path)
            if topic_info is not None:
                return consumer_path, topic_info

        logging.info(
            'No active read sessions found for topic %s, service=%s %s, dc = %s',
            topic_path,
            self._service.name,
            self._service.ctype,
            self._dc
        )
        return None

    async def _get_mirror_topic_info(
        self,
        shard: str,
        last_check: Optional[ShardOffsetCheckResult] = None
    ) -> Optional[Tuple[str, LogbrokerTopicInfo]]:
        """Return the active mirror consumer and its read info for ``shard``.

        The consumer recorded in ``last_check`` is tried first, so a stable
        setup costs a single request. Otherwise all replica consumers are scanned.

        :param shard: shard identifier used in the topic name.
        :param last_check: result of the previous check for this shard, if any.
        :return: ``(consumer_path, topic_info)`` or ``None`` if no consumer has
            an active read session.
        """
        topic_path = f'{self._service.lb_config.mirror_topics_path}/shard-{shard}'

        last_consumer = last_check.mirror_active_consumer if last_check is not None else None
        if last_consumer:
            mirror_topic_info = await self._get_mirror_topic_info_by_consumer(topic_path, last_consumer)
            if mirror_topic_info is not None:
                return last_consumer, mirror_topic_info

        # No usable previous consumer, or it lost its read session: full scan.
        return await self._get_consumer_and_mirror_topic_info(topic_path)

    async def _get_origin_topic_info(self, shard: str) -> LogbrokerTopicInfo:
        """Return origin-cluster read info for ``shard`` as seen by this DC's consumer.

        :param shard: shard identifier used in the topic name.
        :return: the first entry of the read-info response.
        :raises IndexError: if the read-info response is empty.
        """
        topic_path = f'{self._service.lb_config.topics_path}/shard-{shard}'
        consumer_path = f'{self._service.lb_config.consumers_path}/{self._dc}'

        result = await self._client.get_read_info(
            topic_paths=[topic_path],
            consumer_path=consumer_path,
            only_original=True
        )
        return result[0]

    async def check_topics(self, last_check: Optional[OffsetCheckResult] = None) -> OffsetCheckResult:
        """Collect origin and mirror read info for every shard of the service.

        Shards whose mirror topic has no active read session are logged and
        left out of the result.

        :param last_check: previous result; its per-shard active mirror
            consumers are reused as a first guess.
        :return: an ``OffsetCheckResult`` keyed by shard.
        """
        logging.info('Running topic checks for %s %s, dc=%s', self._service.name, self._service.ctype, self._dc)

        previous_results = last_check.shard_to_check_result if last_check else {}
        shard_to_check_result = {}
        for shard in self._service.shards:
            mirror_result = await self._get_mirror_topic_info(shard, previous_results.get(shard))
            if mirror_result is None:
                logging.info(
                    'Unable to get topic info for %s %s, shard = %s, dc = %s, skipping...',
                    self._service.name,
                    self._service.ctype,
                    shard,
                    self._dc
                )
                continue

            info = await self._get_origin_topic_info(shard)
            mirror_consumer, mirror_info = mirror_result

            shard_to_check_result[shard] = ShardOffsetCheckResult(info, mirror_consumer, mirror_info)

        logging.info('Topic checks has finished for %s %s, dc=%s', self._service.name, self._service.ctype, self._dc)
        return OffsetCheckResult(shard_to_check_result)
