import asyncio
import uuid
from typing import ClassVar

import ujson
from kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api import ProducerConfigurator

from mail.payments.payments.storage.logbroker.enums import LogbrokerInstallation
from mail.payments.payments.storage.logbroker.exceptions import (
    ProducerAlreadyWrittenException, ProducerNotRunningException, ProducerWriteException,
    ProducerWriteSeqNoMismatchException
)
from mail.payments.payments.storage.logbroker.factory import LogbrokerFactory
from mail.payments.payments.utils.runnable import Runnable


class BaseProducer(Runnable):
    """Base class for a Logbroker producer that writes to a single topic.

    Subclasses must define ``INSTALLATION`` (passed to the factory to get a
    client) and ``TOPIC`` (the topic name; it is ASCII-encoded for the
    producer configuration). Each instance writes under its own randomly
    generated source id, and every message gets a sequence number taken
    from ``get_next_seq_no``.
    """

    INSTALLATION: ClassVar[LogbrokerInstallation]
    TOPIC: ClassVar[str]

    def __init__(self, lb_factory: LogbrokerFactory):
        super().__init__()
        self._lb_factory = lb_factory
        # A fresh source id per instance; it is also added to the log context.
        self._source_id = uuid.uuid4().hex
        # Both are filled in by _run() and reset by _clear().
        self._producer = None
        self._seq_no = None
        self._seq_no_lock = asyncio.Lock()

        self._logger = lb_factory.logger.clone()
        self._logger.context_push(
            topic=self.TOPIC,
            source_id=self._source_id,
        )

    @property
    def configurator(self) -> ProducerConfigurator:
        """Producer configuration built from ``TOPIC`` and this instance's source id."""
        return ProducerConfigurator(
            topic=self.TOPIC.encode('ascii'),
            source_id=self.source_id.encode('ascii'),
        )

    @property
    def source_id(self) -> str:
        """Hex UUID identifying this producer instance."""
        return self._source_id

    def _clear(self):
        """Drop the producer and sequence state, and start with a new lock."""
        self._producer = None
        self._seq_no = None
        self._seq_no_lock = asyncio.Lock()

    async def _run(self):
        """Create the underlying producer and store the seq_no it reports.

        NOTE(review): the stored value is presumably the last sequence number
        already used for this source id, since get_next_seq_no() increments
        before returning — confirm against ``create_producer``.
        """
        lb_client = await self._lb_factory.get_client(self.INSTALLATION)
        self._producer, self._seq_no = await lb_client.create_producer(self.configurator)

    async def _close(self):
        """Release the underlying producer, if one was created."""
        # _run() may have failed before a producer was assigned, or _clear()
        # may already have reset it; there is nothing to release then, so do
        # not pass None to the client.
        if self._producer is None:
            return
        lb_client = await self._lb_factory.get_client(self.INSTALLATION)
        await lb_client.close_user(self._producer)

    async def get_next_seq_no(self) -> int:
        """Increment and return the sequence number for the next message.

        Raises:
            ProducerNotRunningException: if the producer is not running or
                has no sequence number yet.
        """
        if not self._running or self._seq_no is None:
            raise ProducerNotRunningException

        # No await happens while the lock is held, so the increment and the
        # read of the new value are not interleaved with other coroutines.
        async with self._seq_no_lock:
            self._seq_no += 1
            return self._seq_no

    async def write(self, message: bytes) -> None:
        """Write one message and check the server acknowledgement.

        Raises:
            ProducerNotRunningException: if the producer is not running.
            ProducerWriteException: if the response has no ``ack`` field.
            ProducerWriteSeqNoMismatchException: if the ack is for a
                different seq_no than the one written.
            ProducerAlreadyWrittenException: if the ack reports the message
                as already written.
        """
        if not self._running or self._producer is None:
            raise ProducerNotRunningException

        # NOTE(review): self._logger is shared by every write() call on this
        # instance. With concurrent writes, the seq_no/write_response pushed
        # below could mix between calls, depending on how the logger scopes
        # context inside ``with`` — confirm.
        with self._logger:
            seq_no = await self.get_next_seq_no()
            # producer.write() returns a future; wrap_future makes it awaitable
            # from this event loop.
            write_response = await asyncio.wrap_future(
                self._producer.write(seq_no, message)
            )

            self._logger.context_push(seq_no=seq_no)

            if not write_response.HasField('ack'):
                self._logger.context_push(write_response=write_response)
                self._logger.error('Write response is not ACK')
                raise ProducerWriteException

            if write_response.ack.seq_no != seq_no:
                self._logger.context_push(write_response=write_response)
                self._logger.error('Seq_no mismatch')
                raise ProducerWriteSeqNoMismatchException

            if write_response.ack.already_written:
                raise ProducerAlreadyWrittenException

    async def write_dict(self, data: dict) -> None:
        """Serialize ``data`` to JSON with ujson and write it as one message.

        Assumes ujson.dumps emits ASCII-only text (its default escaping of
        non-ASCII characters); otherwise ``encode('ascii')`` raises
        UnicodeEncodeError.
        """
        await self.write(ujson.dumps(data).encode('ascii'))
