import asyncio
import logging
from typing import AsyncIterator, Generator, List, TypeVar

import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
from kikimr.public.sdk.python.persqueue.auth import OAuthTokenCredentialsProvider
from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult

from saas.library.python.logbroker.common import LogbrokerEndpoint

# Type variable naming the element type produced by LogbrokerReader.read.
TMessage = TypeVar("TMessage")


class LogbrokerReader:
    """Read messages from Logbroker topics through the PQ streaming API.

    Every call to :meth:`read` starts its own ``PQStreamingAPI`` and consumer
    and stops both when the returned async generator finishes, fails or is closed.
    """

    def __init__(self, endpoint: LogbrokerEndpoint, topics: List[str], consumer: str, oauth_token: str,
                 start_timeout: float = 10) -> None:
        """
        :param endpoint: Logbroker endpoint; its ``host`` and ``port`` are used to connect.
        :param topics: topics to read from.
        :param consumer: consumer name passed to ``pqlib.ConsumerConfigurator``.
        :param oauth_token: OAuth token used to authenticate the consumer.
        :param start_timeout: seconds to wait for the API and for the consumer to start.
        """
        self._endpoint = endpoint
        self._topics = topics
        self._consumer = consumer
        self._oauth_token = oauth_token
        self._start_timeout = start_timeout

    @staticmethod
    def _stop_quietly(resource, name: str) -> None:
        """Call ``resource.stop()``, logging instead of raising so cleanup never masks the original error."""
        try:
            resource.stop()
        except Exception:
            logging.exception('Failed to stop %s', name)

    async def _init_api(self) -> pqlib.PQStreamingAPI:
        """Create and start a ``PQStreamingAPI`` for the configured endpoint.

        The API is stopped again if it fails to start, so nothing is leaked on error.

        :raises PQAPIInitException: if the API start future resolves to a falsy result.
        :raises asyncio.TimeoutError: if the API does not start within ``start_timeout`` seconds.
        """
        logging.info('Initializing PQStreamingAPI')

        api = pqlib.PQStreamingAPI(self._endpoint.host, self._endpoint.port)
        try:
            start_future = asyncio.wrap_future(api.start())
            result = await asyncio.wait_for(start_future, timeout=self._start_timeout)
            if not result:
                raise PQAPIInitException(
                    f'Failed to start PQStreamingAPI at {self._endpoint.host}:{self._endpoint.port}: {result}'
                )
        except BaseException:
            # Includes cancellation: an API that never became usable must not keep running.
            self._stop_quietly(api, 'PQStreamingAPI')
            raise

        logging.info('PQStreamingAPI has been initialized')
        return api

    async def _start_consumer(self, api: pqlib.PQStreamingAPI) -> pqlib.PQStreamingConsumer:
        """Create a consumer on ``api`` for the configured topics and start it.

        The consumer is stopped again if it fails to start; ``api`` is left to the caller.

        :raises ConsumerInitException: if the start result is a ``SessionFailureResult`` or has no ``init`` field.
        :raises asyncio.TimeoutError: if the consumer does not start within ``start_timeout`` seconds.
        """
        configurator = pqlib.ConsumerConfigurator(self._topics, self._consumer)
        credentials_provider = OAuthTokenCredentialsProvider(self._oauth_token)

        consumer = api.create_consumer(configurator, credentials_provider=credentials_provider)
        try:
            start_future = asyncio.wrap_future(consumer.start())
            start_result = await asyncio.wait_for(start_future, timeout=self._start_timeout)
            if isinstance(start_result, SessionFailureResult) or not start_result.HasField('init'):
                raise ConsumerInitException(f'Failed to start the consumer with error {start_result}')
        except BaseException:
            self._stop_quietly(consumer, 'consumer')
            raise

        logging.info('The consumer has started with result %s', start_result)
        return consumer

    async def _init_consumer(self) -> pqlib.PQStreamingConsumer:
        """Start a new API and a consumer on it and return the consumer.

        Kept for backward compatibility. On success the caller owns the consumer,
        but the API object is not returned, so it cannot be stopped from here;
        :meth:`read` uses :meth:`_init_api` / :meth:`_start_consumer` instead.
        """
        api = await self._init_api()
        try:
            return await self._start_consumer(api)
        except BaseException:
            self._stop_quietly(api, 'PQStreamingAPI')
            raise

    async def read(self, timeout=None) -> AsyncIterator[TMessage]:
        """Yield messages from the configured topics until an error occurs or the generator is closed.

        A data batch is committed only after all of its messages have been yielded.
        The consumer and then the API are stopped when the generator exits for any reason.

        :param timeout: seconds to wait for each consumer event; ``None`` waits indefinitely.
        :raises asyncio.TimeoutError: if no consumer event arrives within ``timeout``.
        :raises ConsumerRecreateRequiredException: if the consumer reports an error event.
        """
        api = await self._init_api()
        try:
            consumer = await self._start_consumer(api)
            try:
                while True:
                    event_future = asyncio.wrap_future(consumer.next_event())
                    result = await asyncio.wait_for(event_future, timeout=timeout)

                    if result.type == pqlib.ConsumerMessageType.MSG_DATA:
                        for batch in result.message.data.message_batch:
                            for message in batch.message:
                                yield message
                        # Commit only after the whole batch has been handed to the caller.
                        consumer.commit(result.message.data.cookie)
                    elif result.type == pqlib.ConsumerMessageType.MSG_LOCK:
                        result.ready_to_read()
                        logging.info('Partition %s was assigned for topic %s',
                                     result.message.lock.partition,
                                     result.message.lock.topic)
                    elif result.type == pqlib.ConsumerMessageType.MSG_RELEASE:
                        logging.info('Partition %s was revoked for topic %s',
                                     result.message.release.partition,
                                     result.message.release.topic)
                    elif result.type == pqlib.ConsumerMessageType.MSG_ERROR:
                        raise ConsumerRecreateRequiredException(f'Consumer reported an error: {result.message}')
            finally:
                self._stop_quietly(consumer, 'consumer')
        finally:
            self._stop_quietly(api, 'PQStreamingAPI')


class LogbrokerReaderError(Exception):
    """Common base for every error raised by ``LogbrokerReader``; catch this to handle them all."""


class PQAPIInitException(LogbrokerReaderError):
    """Raised when ``PQStreamingAPI`` does not start successfully."""


class ConsumerInitException(LogbrokerReaderError):
    """Raised when the Logbroker consumer fails to start."""


class ConsumerRecreateRequiredException(LogbrokerReaderError):
    """Raised by ``LogbrokerReader.read`` when the consumer reports an error event; reading must be restarted."""
