import asyncio
import contextlib
from logging import Logger
from typing import Any, Dict, List, Optional, Union

import kikimr.public.sdk.python.persqueue.auth as auth
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult
from kikimr.yndx.api.protos.persqueue_pb2 import ReadResponse


class LogbrokerException(Exception):
    """Raised when a Logbroker client, consumer or producer fails to start or operate."""


class Logbroker:
    """Logbroker client.

    Starts the ``PQStreamingAPI`` from which consumers (readers) and producers
    (writers) are then created. Entering the context returns the started API
    object itself, not this wrapper. Leaving the context stops it.

    .. code::

        async with Logbroker("localhost", 2135) as api:
            async with Consumer(api, ...) as consumer:
                await consumer.messages()
    """
    _host: str
    _port: int
    _api: Optional[pqlib.PQStreamingAPI]

    def __init__(self, host: str, port: int):
        self._host = host
        self._port = port
        self._api = None

    async def __aenter__(self) -> pqlib.PQStreamingAPI:
        """Start the API and return it.

        :raises LogbrokerException: if the API reports that it did not start.
        """
        api = pqlib.PQStreamingAPI(self._host, self._port)
        try:
            started = await asyncio.wrap_future(api.start())
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so release the
            # half-started API here (this also covers task cancellation).
            api.stop()
            raise
        if not started:
            api.stop()
            raise LogbrokerException("Client not started")
        self._api = api
        return api

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Stop the API started by ``__aenter__``, if any."""
        api, self._api = self._api, None
        if api is not None:
            api.stop()


class Consumer:
    """Logbroker message reader.

    Reading and committing are separate operations. ``messages()`` returns the
    next batch of data and remembers its cookie. ``commit()`` commits every
    cookie collected since the previous commit.
    Two usage modes are supported:

    1. Manual commit

    .. code::

        credentials_provider = auth.TVMCredentialsProvider(
            tvm_client=TVMClient(TVM_CONFIG),
            destination_client_id=settings.LOGBROKER_LBKX_TVM_ID
        )
        async with Consumer(api, credentials_provider, consumer="consumer", topic="topic", logger=logger) as consumer:
            messages1 = await consumer.messages()
            # do some stuff with messages
            messages2 = await consumer.messages()
            await consumer.commit()
            # all messages committed here

    2. Commit on leaving the ``begin()`` context manager

    .. code::

        async with Consumer(api, credentials_provider, consumer="consumer", topic="topic", logger=logger) as consumer:
            async with consumer.begin():
                messages1 = await consumer.messages()
                # do some stuff with messages
                messages2 = await consumer.messages()
            # all messages committed here (nothing is committed if the body raises)

    When the read session reports ``MSG_ERROR``, the consumer client is stopped
    and recreated. Cookies collected in the failed session are discarded rather
    than committed through the new session.
    """
    _api: pqlib.PQStreamingAPI
    _consumer: str
    _topic: str
    _consumer_client: Optional[pqlib.PQStreamingConsumer]
    _credentials_provider: Optional[auth.CredentialsProvider]
    _logger: Logger
    _options: Dict[str, Any]
    _cookies: List[str]

    def __init__(
        self,
        api: pqlib.PQStreamingAPI,
        credentials_provider: auth.CredentialsProvider,
        *,
        consumer: str,
        topic: str,
        logger: Logger,
        **kwargs: Any
    ):
        """
        :param api: started API (as returned by ``Logbroker.__aenter__``).
        :param credentials_provider: passed to ``create_consumer``.
        :param consumer: consumer name for ``ConsumerConfigurator``.
        :param topic: topic to read from.
        :param logger: logger for session events.
        :param kwargs: extra options forwarded to ``ConsumerConfigurator``.
        """
        self._api = api
        self._credentials_provider = credentials_provider
        self._consumer = consumer
        self._topic = topic
        self._consumer_client = None
        self._logger = logger
        self._cookies = []
        self._options = kwargs

    async def __aenter__(self):
        await self._create()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._stop_client()

    async def messages(self, only_data: bool = True) -> Union[List[bytes], List[ReadResponse.BatchedData.MessageData]]:
        """Wait for the next data event and return its messages.

        Lock, commit-acknowledgement and release events are logged and skipped.
        On ``MSG_ERROR`` the consumer client is recreated and waiting continues.

        :param only_data: return raw payloads (``message.data``) when True,
            full ``MessageData`` objects otherwise.
        :return: messages of every batch in the data event, flattened in order.
        :raises LogbrokerException: if the consumer is not started, if
            recreation fails, or if an unexpected event type arrives.
        """
        while True:
            if self._consumer_client is None:
                raise LogbrokerException('Consumer is not started')
            next_event = await asyncio.wrap_future(self._consumer_client.next_event())
            event_type = next_event.type
            if event_type == pqlib.ConsumerMessageType.MSG_DATA:
                data = next_event.message.data
                messages = [
                    message.data if only_data else message
                    for batch in data.message_batch for message in batch.message
                ]
                # Remember the cookie so a later commit() acknowledges this read.
                self._cookies.append(data.cookie)
                return messages
            elif event_type == pqlib.ConsumerMessageType.MSG_LOCK:
                next_event.ready_to_read()
                lock = next_event.message.lock
                self._logger.debug('Locked. Getting next event. [topic=%s; partition=%s]', lock.topic, lock.partition)
            elif event_type == pqlib.ConsumerMessageType.MSG_COMMIT:
                self._logger.debug('Commit. Getting next event')
            elif event_type == pqlib.ConsumerMessageType.MSG_RELEASE:
                self._logger.debug('Release. Getting next event')
            elif event_type == pqlib.ConsumerMessageType.MSG_ERROR:
                self._logger.error("Got MSG_ERROR '%s'. Try to recreate consumer", next_event.message.error.description)
                # _create() stops the failed client before starting a new one.
                await self._create()
            else:
                raise LogbrokerException(f'Unexpected event type: {event_type}')

    async def commit(self):
        """Commit every cookie collected by ``messages()`` since the last commit.

        Does nothing when there is nothing to commit.

        :raises LogbrokerException: if there are cookies but the consumer is not started.
        """
        if not self._cookies:
            return
        if self._consumer_client is None:
            raise LogbrokerException('Consumer is not started')
        self._consumer_client.commit(self._cookies)
        self._cookies = []

    @contextlib.asynccontextmanager
    async def begin(self):
        """Yield this consumer and commit collected cookies on normal exit.

        If the body raises, the exception propagates through the ``yield`` and
        ``commit()`` is not called.
        """
        yield self
        await self.commit()

    def _stop_client(self):
        """Stop the current consumer client, if any, and drop its uncommitted cookies."""
        client, self._consumer_client = self._consumer_client, None
        # NOTE(review): cookies are assumed to be scoped to the read session that
        # issued them, so they must not be committed through a later session.
        self._cookies = []
        if client is not None:
            client.stop()

    async def _create(self):
        """Stop any existing client, then create and start a new consumer client.

        :raises LogbrokerException: if the read session fails to start.
        """
        # Recreating after MSG_ERROR must not leak the previous client.
        self._stop_client()
        configurator = pqlib.ConsumerConfigurator(
            self._topic,
            self._consumer,
            **self._options,
        )
        client = self._api.create_consumer(configurator, credentials_provider=self._credentials_provider)
        try:
            start_result = await asyncio.wrap_future(client.start())
        except BaseException:
            client.stop()
            raise
        if isinstance(start_result, SessionFailureResult):
            client.stop()
            raise LogbrokerException(f'Session error: {start_result.reason}; {start_result.description}')
        self._consumer_client = client


class Producer:
    """Logbroker writer.

    Each ``write()`` sends one message. Its sequence number is one greater than
    the previous one, starting after the ``max_seq_no`` reported when the
    session started.

    .. code::

        credentials_provider = auth.TVMCredentialsProvider(
            tvm_client=TVMClient(TVM_CONFIG),
            destination_client_id=settings.LOGBROKER_LBKX_TVM_ID
        )

        async with Producer(api, credentials_provider, topic="topic", source_id=b"source_id", logger=logger) as producer:
            await producer.write(b'hello world')
    """
    _api: pqlib.PQStreamingAPI
    _credentials_provider: auth.CredentialsProvider
    _topic: str
    _source_id: bytes
    _producer_client: Optional[pqlib.PQStreamingProducer]
    _logger: Logger
    _options: Dict[str, Any]
    _max_seq_no: int

    def __init__(
        self,
        api: pqlib.PQStreamingAPI,
        credentials_provider: auth.CredentialsProvider,
        *,
        topic: str,
        source_id: bytes,
        logger: Logger,
        **kwargs: Any
    ) -> None:
        """
        :param api: started API (as returned by ``Logbroker.__aenter__``).
        :param credentials_provider: passed to ``create_retrying_producer``.
        :param topic: topic to write to.
        :param source_id: source id for ``ProducerConfigurator``.
        :param logger: logger instance.
        :param kwargs: extra options forwarded to ``ProducerConfigurator``.
        """
        self._api = api
        self._credentials_provider = credentials_provider
        self._source_id = source_id
        self._topic = topic
        self._logger = logger
        self._options = kwargs
        self._producer_client = None

    async def __aenter__(self):
        await self._create()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        self._stop_client()

    def _stop_client(self):
        """Stop the current producer client, if any."""
        client, self._producer_client = self._producer_client, None
        if client is not None:
            client.stop()

    async def _create(self):
        """Create and start a retrying producer, then remember the server's max_seq_no.

        :raises LogbrokerException: if the session fails to start.
        """
        self._stop_client()
        configurator = pqlib.ProducerConfigurator(
            self._topic,
            self._source_id,
            **self._options,
        )
        client = self._api.create_retrying_producer(
            configurator,
            credentials_provider=self._credentials_provider
        )
        try:
            start_result = await asyncio.wrap_future(client.start())
        except BaseException:
            # __aexit__ does not run when __aenter__ raises, so clean up here.
            client.stop()
            raise
        if isinstance(start_result, SessionFailureResult):
            client.stop()
            raise LogbrokerException(f"Producer failed to start with error {start_result}")

        self._max_seq_no = start_result.init.max_seq_no
        self._producer_client = client

    async def write(self, data: bytes) -> None:
        """Write one message and wait for the server's response.

        :param data: message payload.
        :raises LogbrokerException: if the producer is not started or the
            response carries no ``ack``.
        """
        if self._producer_client is None:
            raise LogbrokerException('Producer is not started')
        self._max_seq_no += 1
        result = await asyncio.wrap_future(self._producer_client.write(self._max_seq_no, data))
        if not result.HasField("ack"):
            raise LogbrokerException(f"Message write failed with error {result}")
