import logging
from dataclasses import dataclass
from typing import List, Optional

import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
from concurrent.futures import TimeoutError

from src.common.logbroker.client import get_logbroker_client
from src.config import settings

logger = logging.getLogger(__name__)


@dataclass(init=True, repr=True, eq=True)
class MessagesWithCookie:
    """One batch of payloads read from Logbroker plus the cookie that identifies it.

    The ``cookie`` is the value to hand back to ``LogbrokerConsumer.commit``
    once every payload in ``messages_data`` has been processed.
    """

    # Payloads extracted from the batch, in the order they appear in it.
    # NOTE(review): annotated as str, but persqueue message payloads may be bytes — confirm.
    messages_data: List[str]
    # Read cookie of the batch these payloads came from.
    cookie: int


class LogbrokerConsumer:
    """Async wrapper around a single Logbroker (persqueue) streaming consumer.

    Build instances with :meth:`create`, which attaches the Logbroker client.
    The underlying ``pqlib.PQStreamingConsumer`` is started lazily on the first
    :meth:`read` and then reused by later reads and commits until :meth:`stop`
    is called. Reusing one consumer matters: read cookies belong to the consumer
    session that produced them, so commits must go to that same consumer.

    NOTE(review): ``start()`` and ``next_event()`` return futures that are
    waited on with a blocking ``.result(timeout=...)``. This blocks the event
    loop for up to ``settings.LOGBROKER_TIMEOUT``. Consider offloading the wait
    (e.g. ``loop.run_in_executor``) if other coroutines share this loop.
    """

    def __init__(self, topic_name: str, consumer_name: str):
        self.topic_name = topic_name
        self.consumer_name = consumer_name
        # Filled in by ``create``; initialised here so the attribute always exists.
        self._logbroker_client = None
        # Running consumer, or None when not started (or after a failed start / stop).
        self.consumer = None

    @classmethod
    async def create(cls, topic_name: str, consumer_name: str):
        """Build a consumer wrapper and attach the Logbroker client to it.

        No Logbroker consumer is started here; that happens on the first ``read``.
        """
        # Use ``cls`` (not the hard-coded class name) so subclasses construct themselves.
        self = cls(topic_name, consumer_name)
        self._logbroker_client = await get_logbroker_client()
        return self

    async def _create_consumer(self) -> None:
        """Obtain and start a consumer, storing it in ``self.consumer`` on success.

        On a failed start the consumer is stopped and ``self.consumer`` is left
        as ``None``. A start timeout (``TimeoutError``) is propagated after the
        half-started consumer has been stopped.
        """
        consumer: pqlib.PQStreamingConsumer = await self._logbroker_client.get_consumer(
            self.consumer_name, self.topic_name
        )
        try:
            start_result = consumer.start().result(timeout=settings.LOGBROKER_TIMEOUT)
        except TimeoutError:
            # Do not leak a consumer that never finished starting.
            consumer.stop()
            raise

        if not start_result.HasField('init'):
            logger.error('Consumer failed to start with error %s', start_result)
            consumer.stop()
            return

        self.consumer = consumer

    async def read(self) -> Optional[MessagesWithCookie]:
        """Wait for the next consumer event and handle it.

        Returns:
            A ``MessagesWithCookie`` for a data event. Returns ``None`` for
            commit / partition lock / partition release events, when no event
            arrives within ``settings.LOGBROKER_TIMEOUT``, or when the consumer
            could not be started.
        """
        # Start the consumer once and reuse it: creating one per read leaked
        # consumers and made cookies unusable for later commits.
        if self.consumer is None:
            await self._create_consumer()
            if self.consumer is None:
                return None

        try:
            result: pqlib.ConsumerMessage = self.consumer.next_event().result(timeout=settings.LOGBROKER_TIMEOUT)
        except TimeoutError:
            logger.info('Does not have any unread messages. Stop reading.')
            return None

        if result.type == pqlib.ConsumerMessageType.MSG_COMMIT:
            logger.info('Message committed successfully: %s', result.message)
        elif result.type == pqlib.ConsumerMessageType.MSG_LOCK:
            # Acknowledge the partition assignment before reading from it.
            result.ready_to_read()
            logger.info(
                'Got partition assignment: topic %s, partition %s',
                result.message.lock.topic,
                result.message.lock.partition,
            )
        elif result.type == pqlib.ConsumerMessageType.MSG_RELEASE:
            logger.info(
                'Partition revoked. Topic %s, partition %s',
                result.message.release.topic,
                result.message.release.partition,
            )
        elif result.type == pqlib.ConsumerMessageType.MSG_DATA:
            extracted_messages = await self._process_single_batch(result.message)
            cookie: int = result.message.data.cookie
            logger.info('Result cookie %s', cookie)
            return MessagesWithCookie(messages_data=extracted_messages, cookie=cookie)
        return None

    async def stop(self) -> None:
        """Stop the running consumer (if any) and then the Logbroker client."""
        if self.consumer is not None:
            self.consumer.stop()
            self.consumer = None
        await self._logbroker_client.stop()

    async def commit(self, cookie: int) -> None:
        """Commit the batch identified by ``cookie`` on the running consumer.

        Raises:
            RuntimeError: if no consumer is running (nothing read yet, a failed
                start, or already stopped).
        """
        if self.consumer is None:
            raise RuntimeError('Cannot commit cookie %s: consumer is not running' % cookie)
        self.consumer.commit(cookie)

    async def _process_single_batch(self, consumer_message):
        """Flatten a data event into the list of raw payloads it contains.

        Payloads are returned in batch order, then in message order within each batch.
        """
        ret = []
        for batch in consumer_message.data.message_batch:
            ret.extend(message.data for message in batch.message)
        return ret
