import asyncio
from abc import ABC, abstractmethod
from typing import List, Optional, Union

import kikimr.public.sdk.python.persqueue.auth as pqauth
import kikimr.public.sdk.python.persqueue.errors as pqerr
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
from tvmauth import TvmClient, TvmToolClientSettings


class LogbrokerError(Exception):
    """Raised when a Logbroker operation fails or does not finish in time."""


class LogbrokerReadTimeout(Exception):
    """Raised by ``TopicReader.read_batch`` when no event arrives in time.

    Derives directly from ``Exception`` rather than ``LogbrokerError``, so an
    ``except LogbrokerError`` clause does not catch it.
    """


class LogbrokerWrapper(ABC):
    """Base class enforcing a one-shot start/stop lifecycle.

    Subclasses override :meth:`start` and :meth:`stop` and call the base
    implementation first (``await super().start()`` / ``await super().stop()``).
    The base implementation validates and records the lifecycle state.
    Instances can be used as async context managers: ``start()`` runs on entry
    and ``stop()`` runs on exit.
    """

    __slots__ = ("_started", "_stopped")

    _started: bool  # set as soon as start() is entered
    _stopped: bool  # set as soon as stop() is entered

    def __init__(self):
        self._started, self._stopped = False, False

    @abstractmethod
    async def start(self):
        """Record that the object was started.

        Raises:
            RuntimeError: if ``start()`` was already called.
        """
        already_started = self._started
        if already_started:
            raise RuntimeError("Attempt to call start() twice")
        self._started = True

    @abstractmethod
    async def stop(self):
        """Record that the object was stopped.

        Raises:
            RuntimeError: if ``start()`` was never called.
        """
        if self._started:
            self._stopped = True
            return
        raise RuntimeError("Attempt to call stop() before start")

    async def __aenter__(self):
        await self.start()
        return self

    async def __aexit__(self, *exc):
        # Exceptions from the body are not suppressed (implicit None return).
        await self.stop()

    def _raise_if_not_started_or_stopped(self):
        """Raise RuntimeError unless the object is between start() and stop()."""
        usable = self._started and not self._stopped
        if not usable:
            raise RuntimeError("Attempt to use object before start or after stop")


class TopicWriter(LogbrokerWrapper):
    """Asynchronous writer for a single Logbroker topic.

    Wraps a ``pqlib.PQStreamingProducer`` and bridges its concurrent futures
    into asyncio. Sequence numbers are assigned locally. They continue from the
    ``max_seq_no`` returned in the producer's ``init`` response when
    :meth:`start` succeeds.
    """

    __slots__ = ("max_seq_no", "producer")

    max_seq_no: Optional[int]  # last sequence number used; None until start()
    producer: pqlib.PQStreamingProducer

    # Timeouts in seconds (passed to asyncio.wait_for).
    START_TIMEOUT = 10
    STOP_TIMEOUT = 10
    WRITE_TIMEOUT = 10

    def __init__(self, producer: pqlib.PQStreamingProducer):
        super().__init__()
        self.producer = producer
        self.max_seq_no = None

    async def start(self):
        """Start the producer session and remember the server's ``max_seq_no``.

        Raises:
            RuntimeError: if called more than once.
            LogbrokerError: if the producer does not start within
                ``START_TIMEOUT`` seconds, returns a session failure, or
                replies without an ``init`` field.
        """
        await super().start()
        try:
            start_result = await asyncio.wait_for(
                asyncio.wrap_future(self.producer.start()), timeout=self.START_TIMEOUT
            )
        except asyncio.TimeoutError:
            raise LogbrokerError("Producer start timeout") from None

        if isinstance(
            start_result, pqerr.SessionFailureResult
        ) or not start_result.HasField("init"):
            raise LogbrokerError(f"Producer failed to start with error {start_result}")

        self.max_seq_no = start_result.init.max_seq_no

    async def stop(self):
        """Stop the producer.

        Raises:
            RuntimeError: if called before :meth:`start`.
            LogbrokerError: if the producer does not stop within
                ``STOP_TIMEOUT`` seconds.
        """
        await super().stop()
        try:
            await asyncio.wait_for(
                asyncio.wrap_future(self.producer.stop()), timeout=self.STOP_TIMEOUT
            )
        except asyncio.TimeoutError:
            raise LogbrokerError("Producer stop timeout") from None

    async def write_one(self, message: bytes):
        """Write a single message and wait for its acknowledgement.

        The message gets the next local sequence number. The number is consumed
        even if the write fails, so a subsequent call uses a fresh one.

        Raises:
            RuntimeError: if the writer is not started or already stopped.
            LogbrokerError: on write timeout (``WRITE_TIMEOUT`` seconds), on a
                session failure, or when the response has no ``ack`` field.
        """
        self._raise_if_not_started_or_stopped()
        self.max_seq_no += 1
        response = self.producer.write(self.max_seq_no, message)
        try:
            result = await asyncio.wait_for(
                asyncio.wrap_future(response), timeout=self.WRITE_TIMEOUT
            )
        except asyncio.TimeoutError:
            raise LogbrokerError("Message write timeout") from None
        # start() treats SessionFailureResult as a possible future result; apply
        # the same check here so it surfaces as LogbrokerError instead of
        # depending on HasField() being available on that object.
        if isinstance(result, pqerr.SessionFailureResult) or not result.HasField(
            "ack"
        ):
            raise LogbrokerError(f"Message write failed with error {result}")


class TopicReader(LogbrokerWrapper):
    """Asynchronous reader for a Logbroker topic.

    Wraps a ``pqlib.PQStreamingConsumer`` and bridges its concurrent futures
    into asyncio.

    - :meth:`read_batch` pulls data messages and records the cookie of every
      data event it receives.
    - :meth:`commit` hands the recorded cookies to the consumer.
    - :meth:`finish_reading` calls ``reads_done()`` and waits until the last
      committed cookie is acknowledged.
    """

    __slots__ = ("consumer", "cookies", "event_fut", "last_unacknowledged_cookie")

    consumer: pqlib.PQStreamingConsumer
    cookies: List[int]  # cookies of data events read but not yet committed
    event_fut: Optional[asyncio.Future]  # next_event() request kept across timeouts
    last_unacknowledged_cookie: Optional[int]  # last committed cookie awaiting ack

    # Timeouts in seconds (passed to asyncio.wait_for).
    START_TIMEOUT = 10
    STOP_TIMEOUT = 10
    READ_TIMEOUT = 10
    ACK_TIMEOUT = 10

    def __init__(self, consumer: pqlib.PQStreamingConsumer):
        super().__init__()
        self.consumer = consumer
        self.cookies = []
        self.event_fut = None
        self.last_unacknowledged_cookie = None

    async def start(self):
        """Start the consumer session.

        Raises:
            RuntimeError: if called more than once.
            LogbrokerError: if the consumer does not start within
                ``START_TIMEOUT`` seconds, returns a session failure, or
                replies without an ``init`` field.
        """
        await super().start()
        try:
            start_result = await asyncio.wait_for(
                asyncio.wrap_future(self.consumer.start()), timeout=self.START_TIMEOUT
            )
        except asyncio.TimeoutError:
            raise LogbrokerError("Consumer start timeout") from None

        if isinstance(
            start_result, pqerr.SessionFailureResult
        ) or not start_result.HasField("init"):
            raise LogbrokerError(f"Consumer failed to start with error {start_result}")

    async def stop(self):
        """Stop the consumer.

        Raises:
            RuntimeError: if called before :meth:`start`.
            LogbrokerError: if the consumer does not stop within
                ``STOP_TIMEOUT`` seconds.
        """
        await super().stop()
        try:
            await asyncio.wait_for(
                asyncio.wrap_future(self.consumer.stop()), timeout=self.STOP_TIMEOUT
            )
        except asyncio.TimeoutError:
            raise LogbrokerError("Consumer stop timeout") from None

    async def read_batch(
        self, message_count_threshold: int, read_timeout: Optional[int] = None
    ):
        """Yield message payloads until at least ``message_count_threshold`` are read.

        Messages are yielded a whole data event at a time, so the number
        yielded may exceed the threshold. Commit acknowledgements received
        while reading clear ``last_unacknowledged_cookie`` when they cover it.

        Args:
            message_count_threshold: minimum number of messages to yield.
            read_timeout: per-event wait in seconds; ``READ_TIMEOUT`` if None.

        Raises:
            RuntimeError: if the reader is not started or already stopped.
            LogbrokerReadTimeout: if no event arrives within the timeout.
            LogbrokerError: on an error event or an unsupported event type.
        """
        self._raise_if_not_started_or_stopped()
        timeout = read_timeout if read_timeout is not None else self.READ_TIMEOUT
        message_count = 0
        while message_count < message_count_threshold:
            event = await self._fetch_event(timeout)
            if event is None:
                raise LogbrokerReadTimeout
            elif event.type == pqlib.ConsumerMessageType.MSG_DATA:
                # The cookie is recorded before its messages are yielded, so a
                # commit() after the caller stops iterating mid-event also
                # covers the rest of this event's messages.
                self.cookies.append(event.message.data.cookie)
                for batch in event.message.data.message_batch:
                    for message in batch.message:
                        yield message.data
                    message_count += len(batch.message)
            elif event.type == pqlib.ConsumerMessageType.MSG_COMMIT:
                # Membership test: tolerates an empty cookie list and acks in
                # which the awaited cookie is not the last element.
                if (
                    self.last_unacknowledged_cookie is not None
                    and self.last_unacknowledged_cookie in event.message.commit.cookie
                ):
                    self.last_unacknowledged_cookie = None
            elif event.type == pqlib.ConsumerMessageType.MSG_ERROR:
                raise LogbrokerError(f"Read failed with error {event.message}")
            else:
                raise LogbrokerError(f"Event type not supported: {event.type}")

    def commit(self):
        """Commit all cookies recorded since the previous commit (no-op if none).

        Remembers the last committed cookie so :meth:`finish_reading` can wait
        for its acknowledgement.

        Raises:
            RuntimeError: if the reader is not started or already stopped.
        """
        self._raise_if_not_started_or_stopped()
        if self.cookies:
            self.consumer.commit(self.cookies)
            self.last_unacknowledged_cookie = self.cookies[-1]
            self.cookies = []

    async def finish_reading(self):
        """Call ``reads_done()`` and wait for the last commit to be acknowledged.

        Data events received while waiting are skipped (not recorded).

        Raises:
            RuntimeError: if the reader is not started or already stopped.
            LogbrokerError: if no event arrives within ``ACK_TIMEOUT`` seconds,
                on an error event, or on an unsupported event type.
        """
        self._raise_if_not_started_or_stopped()
        self.consumer.reads_done()
        target = self.last_unacknowledged_cookie
        if target is None:
            return
        acknowledged = False
        while not acknowledged:
            event = await self._fetch_event(self.ACK_TIMEOUT)
            if event is None:
                raise LogbrokerError("Commit ack wait timeout")
            if event.type == pqlib.ConsumerMessageType.MSG_DATA:
                continue
            elif event.type == pqlib.ConsumerMessageType.MSG_COMMIT:
                acknowledged = target in event.message.commit.cookie
            elif event.type == pqlib.ConsumerMessageType.MSG_ERROR:
                raise LogbrokerError(f"Read failed with error {event}")
            else:
                raise LogbrokerError(f"Event type not supported: {event.type}")
        self.last_unacknowledged_cookie = None

    async def _fetch_event(self, timeout: int) -> Optional[pqlib.ConsumerMessage]:
        """Wait up to ``timeout`` seconds for the next consumer event.

        Returns the event, or ``None`` on timeout. On timeout the pending
        ``next_event()`` request is kept (it is shielded from cancellation) and
        reused by the next call, so an event that arrives late is not lost.
        """
        if self.event_fut is None:
            self.event_fut = asyncio.wrap_future(self.consumer.next_event())
        fut = self.event_fut

        try:
            event = await asyncio.wait_for(asyncio.shield(fut), timeout=timeout)
        except asyncio.TimeoutError:
            return None
        except BaseException:
            # A request that failed or was cancelled will never yield an event:
            # forget it so the next call issues a fresh next_event() instead of
            # re-raising the same stale error forever. A request that is still
            # pending (e.g. the calling task was cancelled) or that completed
            # successfully is kept so its event is not lost.
            if fut.done() and (fut.cancelled() or fut.exception() is not None):
                self.event_fut = None
            raise

        self.event_fut = None
        return event


class LogbrokerClient(LogbrokerWrapper):
    """Owner of a Logbroker streaming API connection and its TVM credentials.

    After :meth:`start`, :meth:`create_reader` and :meth:`create_writer` return
    topic readers and writers. They share this client's API connection and
    credentials provider. The returned objects are not started.
    """

    __slots__ = ("api", "cred_provider", "default_source_id")

    api: pqlib.PQStreamingAPI
    cred_provider: pqauth.TVMCredentialsProvider
    default_source_id: bytes

    START_TIMEOUT = 10  # seconds (passed to asyncio.wait_for)

    def __init__(
        self,
        host: str,
        port: int,
        default_source_id: bytes,
        tvm_destination: Union[str, int],
        tvm_self_alias: str = "self",
        tvm_port: Optional[int] = None,
    ):
        """Create the API object and the TVM credentials provider.

        Args:
            host: Logbroker endpoint host, passed to ``PQStreamingAPI``.
            port: Logbroker endpoint port, passed to ``PQStreamingAPI``.
            default_source_id: source id used by :meth:`create_writer` when
                none is given.
            tvm_destination: an ``int`` is used as the TVM destination client
                id; anything else is used (as ``str``) as a destination alias.
            tvm_self_alias: alias passed to ``TvmToolClientSettings``.
            tvm_port: port passed to ``TvmToolClientSettings`` as is.
        """
        super().__init__()
        self.api = pqlib.PQStreamingAPI(host, port)
        self.default_source_id = default_source_id

        tvm_client = TvmClient(
            TvmToolClientSettings(self_alias=tvm_self_alias, port=tvm_port)
        )
        cred_kwargs = (
            dict(destination_client_id=tvm_destination)
            if isinstance(tvm_destination, int)
            else dict(destination_alias=str(tvm_destination))
        )
        self.cred_provider = pqauth.TVMCredentialsProvider(
            tvm_client=tvm_client, **cred_kwargs
        )

    async def start(self):
        """Start the streaming API.

        Raises:
            RuntimeError: if called more than once.
            LogbrokerError: if the API does not start within ``START_TIMEOUT``
                seconds or reports a falsy start result.
        """
        await super().start()
        try:
            start_result = await asyncio.wait_for(
                asyncio.wrap_future(self.api.start()), timeout=self.START_TIMEOUT
            )
        except asyncio.TimeoutError:
            raise LogbrokerError("Api start timeout") from None
        if not start_result:
            raise LogbrokerError("Api start failed")

    async def stop(self):
        """Stop the streaming API (``api.stop()`` is called synchronously)."""
        await super().stop()
        self.api.stop()

    async def close(self):
        """Alias for :meth:`stop`."""
        await self.stop()

    def create_reader(self, topic: str, consumer_id: str, **kwargs) -> TopicReader:
        """Create an unstarted reader for ``topic`` as ``consumer_id``.

        Extra keyword arguments are forwarded to ``pqlib.ConsumerConfigurator``.

        Raises:
            RuntimeError: if the client is not started or already stopped.
        """
        self._raise_if_not_started_or_stopped()
        consumer = self.api.create_consumer(
            pqlib.ConsumerConfigurator(topic, consumer_id, **kwargs),
            credentials_provider=self.cred_provider,
        )
        return TopicReader(consumer)

    def create_writer(
        self, topic: str, source_id: Optional[Union[str, bytes]] = None
    ) -> TopicWriter:
        """Create an unstarted writer for ``topic``.

        Args:
            topic: topic to write to.
            source_id: producer source id; ``default_source_id`` if None.

        Raises:
            RuntimeError: if the client is not started or already stopped.
        """
        self._raise_if_not_started_or_stopped()
        producer = self.api.create_producer(
            pqlib.ProducerConfigurator(
                topic, source_id if source_id is not None else self.default_source_id
            ),
            credentials_provider=self.cred_provider,
        )
        return TopicWriter(producer)
