import functools
import logging

from dataclasses import dataclass
from typing import Optional, Dict

from saas.library.python.logbroker import ClusterInfoClient, ReadInfoClient, LogbrokerEndpoint
from saas.library.python.logbroker.internal_api.client import LogbrokerTopicInfo


@dataclass(frozen=True)
class ShardOffsetCheckResult:
    """Snapshot of one shard's origin topic and its mirror topic at check time.

    Instances are immutable; successive snapshots are compared via
    :meth:`is_passed` to decide whether the mirror has caught up.
    """

    # Topic info of the origin topic captured by this check.
    origin_topic_info: LogbrokerTopicInfo

    # Consumer identifier associated with the mirror (not read by is_passed).
    mirror_active_consumer: str
    # Topic info of the mirror topic captured by this check.
    mirror_topic_info: LogbrokerTopicInfo

    # TODO(vbushev): consider read_time_lag_ms here?
    def is_passed(self, last_check: Optional['ShardOffsetCheckResult'] = None) -> bool:
        """Tell whether the mirror has reached the origin offsets of a previous check.

        :param last_check: an earlier snapshot of the same shard. When it is
            missing (or falsy) there is nothing to compare against and the
            check passes.
        :returns: ``True`` if, for every partition index of
            ``last_check.origin_topic_info.partitions``, this snapshot's mirror
            ``read_offset`` is ``>=`` the previous origin ``read_offset``;
            ``False`` at the first partition where it is not.

        Every per-partition comparison is logged at DEBUG level. If the
        mirror has fewer partitions than the previous origin snapshot, the
        indexing error from the partitions container propagates.
        """
        if not last_check:
            return True

        # Origin partitions as they were at the previous check, and mirror
        # partitions as they are now; both are addressed by partition index.
        partitions_before = last_check.origin_topic_info.partitions
        mirror_now = self.mirror_topic_info.partitions

        for idx in range(len(partitions_before)):
            # Read the origin side first, then the mirror side.
            required_offset = partitions_before[idx].read_offset
            reached_offset = mirror_now[idx].read_offset

            caught_up = reached_offset >= required_offset
            logging.debug(
                'Compared topic offsets for partition %d, result = %s; '
                'Previous topics info: (%s [origin] %s [mirror]); '
                'Current topics info: (%s [origin] %s [mirror]);',
                idx,
                caught_up,
                last_check.origin_topic_info,
                last_check.mirror_topic_info,
                self.origin_topic_info,
                self.mirror_topic_info,
            )
            if not caught_up:
                return False

        return True


@dataclass(frozen=True)
class OffsetCheckResult:
    """Collection of per-shard offset check results."""

    # Per-shard results keyed by shard identifier.
    shard_to_check_result: Dict[str, ShardOffsetCheckResult]

    def is_passed(self, last_check: Optional['OffsetCheckResult'] = None) -> bool:
        """Tell whether every shard passes against its counterpart in *last_check*.

        Each shard of this result is checked against the same-keyed shard of
        *last_check*. A shard absent from *last_check*, or a missing/falsy
        *last_check*, is compared against ``None``. Evaluation stops at the
        first failing shard (in dict iteration order).
        """
        # An empty mapping yields None for every shard, same as having no previous check.
        previous_by_shard = last_check.shard_to_check_result if last_check else {}
        return all(
            current.is_passed(previous_by_shard.get(shard_name))
            for shard_name, current in self.shard_to_check_result.items()
        )


class LogbrokerFactory:
    """Memoizing factory for Logbroker endpoints and info clients.

    Each constructor is wrapped in ``functools.lru_cache`` (``maxsize=128``,
    the library default). Repeated calls with equal, hashable arguments
    therefore return the very same object. Because the cache sits beneath
    ``classmethod``, the class itself is part of the cache key.
    """

    @classmethod
    @functools.lru_cache(maxsize=128)
    def create_cluster_info_client(cls, endpoint: LogbrokerEndpoint) -> ClusterInfoClient:
        """Return a cached ``ClusterInfoClient`` bound to *endpoint*."""
        client = ClusterInfoClient(endpoint=endpoint)
        return client

    @classmethod
    @functools.lru_cache(maxsize=128)
    def create_read_info_client(cls, endpoint: LogbrokerEndpoint) -> ReadInfoClient:
        """Return a cached ``ReadInfoClient`` bound to *endpoint*."""
        client = ReadInfoClient(endpoint=endpoint)
        return client

    @classmethod
    @functools.lru_cache(maxsize=128)
    def create_dc_endpoint(cls, dc: str, host: str, port: int) -> LogbrokerEndpoint:
        """Return a cached endpoint for host ``<dc>.<host>`` on *port*."""
        dc_host = '{}.{}'.format(dc, host)
        return LogbrokerEndpoint(dc_host, port)


class DataWriteException(Exception):
    """Project-specific exception; presumably signals a failed data write (confirm at raise sites)."""
