import asyncio
from typing import AsyncIterator, ClassVar, Generic, Iterable, Optional, Tuple, TypeVar

from kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api import (
    ConsumerConfigurator, ConsumerMessageType, PQStreamingConsumer
)

from mail.payments.payments.storage.logbroker.enums import LogbrokerInstallation
from mail.payments.payments.storage.logbroker.exceptions import (
    ConsumerMsgReleaseCannotCommitException, ConsumerNotRunningException, ConsumerUnexpectedEventType
)
from mail.payments.payments.storage.logbroker.factory import LogbrokerFactory
from mail.payments.payments.utils.runnable import Runnable

# Type of the parsed items a concrete consumer produces (see ``parse_data`` / ``read``).
_T = TypeVar('_T')


class BaseConsumer(Runnable, Generic[_T]):
    """
    Base class for Logbroker consumers.

    Concrete consumers are expected to define the ``INSTALLATION``, ``CONSUMER``
    and ``TOPICS`` class attributes and to override ``parse_data``.
    """

    # Installation passed to the factory when requesting a client.
    INSTALLATION: ClassVar[LogbrokerInstallation]
    # Consumer name; ASCII-encoded and used as the ``client_id`` of the configurator.
    CONSUMER: ClassVar[str]
    # Topic names to read from; each one is ASCII-encoded for the configurator.
    TOPICS: ClassVar[Iterable[str]]

    def __init__(self, lb_factory: LogbrokerFactory):
        """
        :param lb_factory: factory used to obtain the client and the base logger.
        """
        super().__init__()
        self._lb_factory = lb_factory
        self._consumer: Optional[PQStreamingConsumer] = None

        # Work on a clone so that the context pushed here stays on this consumer's logger.
        self._logger = lb_factory.logger.clone()
        self._logger.context_push(
            consumer=self.CONSUMER,
            topics=self.TOPICS,
        )

    @property
    def configurator(self) -> ConsumerConfigurator:
        """Build the SDK configurator from ``TOPICS`` and ``CONSUMER``."""
        return ConsumerConfigurator(
            topics=[topic.encode('ascii') for topic in self.TOPICS],
            client_id=self.CONSUMER.encode('ascii'),
            read_only_local=False,
            use_client_locks=True,
        )

    def _clear(self) -> None:
        """Forget the current consumer object."""
        self._consumer = None

    async def _run(self) -> None:
        """Obtain a client for ``INSTALLATION`` and create the consumer on it."""
        lb_client = await self._lb_factory.get_client(self.INSTALLATION)
        self._consumer = await lb_client.create_consumer(self.configurator)

    async def _close(self) -> None:
        """Close the consumer through its client; a no-op when there is no consumer."""
        consumer = self._consumer
        if consumer is None:
            # Nothing to close, so do not request a client at all.
            return
        lb_client = await self._lb_factory.get_client(self.INSTALLATION)
        await lb_client.close_user(consumer)

    def commit(self, *cookies: int) -> None:
        """
        Send a commit for the given cookies without waiting for the result.

        :raises ConsumerNotRunningException: if there is no consumer.
        """
        if self._consumer is None:
            raise ConsumerNotRunningException
        self._consumer.commit(cookies)  # Fire and forget

    def parse_data(self, data: bytes) -> Iterable[_T]:
        """
        Convert one raw message payload into parsed items. Must be overridden.
        """
        raise NotImplementedError

    async def read_event(self) -> Tuple[Iterable[bytes], int]:
        """
        Wait for the first MSG_DATA event, handling and skipping other events.

        :return: tuple of (lazy iterable over the payloads of every message in
            every batch of the event, cookie of the event).
        :raises ConsumerNotRunningException: if there is no consumer, including
            when it disappears between two events.
        :raises ConsumerMsgReleaseCannotCommitException: on a MSG_RELEASE event
            whose ``can_commit`` flag is false.
        :raises ConsumerUnexpectedEventType: on an event type not handled below.
        """

        while True:
            # Re-check on every iteration: ``_clear`` may reset the consumer while
            # a previous (non-data) event was being awaited, and the caller should
            # get ConsumerNotRunningException rather than an AttributeError on None.
            consumer = self._consumer
            if consumer is None:
                raise ConsumerNotRunningException

            # NOTE(review): presumably the logger context manager scopes the
            # context pushed below to this iteration - confirm in the logger.
            with self._logger:
                # next_event() is wrapped so that its future can be awaited here.
                next_event_response = await asyncio.wrap_future(
                    consumer.next_event()
                )
                event_type = next_event_response.type
                self._logger.context_push(event_type=event_type)
                self._logger.info('Read next event')

                if event_type == ConsumerMessageType.MSG_DATA:
                    data = next_event_response.message.data
                    # The payloads are produced lazily; the cookie is what
                    # ``commit`` expects once they have been processed.
                    return (
                        message.data
                        for batch in data.message_batch
                        for message in batch.message
                    ), data.cookie

                elif event_type == ConsumerMessageType.MSG_LOCK:
                    next_event_response.ready_to_read()  # Fire and forget
                    lock = next_event_response.message.lock
                    self._logger.context_push(
                        topic=lock.topic,
                        partition=lock.partition,
                    )
                    self._logger.info('Locked. Getting next event')

                elif event_type == ConsumerMessageType.MSG_COMMIT:
                    self._logger.info('Skipping event')

                elif event_type == ConsumerMessageType.MSG_RELEASE:
                    if not next_event_response.message.release.can_commit:
                        raise ConsumerMsgReleaseCannotCommitException
                    self._logger.info('Skipping event')

                elif event_type == ConsumerMessageType.MSG_ERROR:
                    self._logger.context_push(
                        code=next_event_response.message.error.code,
                        description=next_event_response.message.error.description,
                    )
                    self._logger.info('Skipping error event')

                else:
                    self._logger.error('Unexpected event type')
                    raise ConsumerUnexpectedEventType

    async def read(self) -> AsyncIterator[_T]:
        """
        Endlessly yield parsed items from incoming data events.

        The cookie of an event is committed only after every item parsed from
        that event has been yielded and the caller has asked for the next one.
        """
        while True:
            messages, cookie = await self.read_event()
            for message in messages:
                for data in self.parse_data(message):
                    yield data
            self.commit(cookie)
