import concurrent.futures
import json
import logging

import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
import kikimr.public.sdk.python.persqueue.auth as auth
from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult

from stackbot.config import settings


logger = logging.getLogger(__name__)


class Message:
    """A single record to be sent through ``LogbrokerClient``.

    Attributes:
        id: caller-supplied identifier; ``LogbrokerClient.write_messages``
            returns the ids of the messages whose writes were acknowledged.
        content: payload that ``pack`` serialises with ``json.dumps``.
    """

    def __init__(self, _id: int, content: dict):
        self.id = _id
        self.content = content

    def pack(self) -> str:
        """Return the payload encoded as a JSON string (``json.dumps`` defaults)."""
        encoded = json.dumps(self.content)
        return encoded


class LogbrokerClient:
    """Synchronous wrapper around a persqueue (Logbroker) producer session.

    The constructor connects to the endpoint and starts a producer session.
    ``write_messages`` pushes batches of ``Message`` objects, and ``shutdown``
    stops the producer and the API connection.
    """

    # Upper bound on the number of write futures pending at the same time.
    MAX_INFLIGHT = settings.LOGBROKER_CLIENT_MAX_INFLIGHT
    # Timeout (passed to concurrent.futures.wait) for each poll of pending writes.
    WAIT_TIMEOUT = settings.LOGBROKER_CLIENT_WAIT_TIMEOUT

    def __init__(self, endpoint: str, topic: str, producer: str):
        """Connect to Logbroker and start a producer session.

        Args:
            endpoint: Logbroker host; the port comes from ``settings.LOGBROKER_PORT``.
            topic: topic passed to ``pqlib.ProducerConfigurator``.
            producer: producer id passed to ``pqlib.ProducerConfigurator``.

        Raises:
            RuntimeError: if the server returns a ``SessionFailureResult`` or a
                start response without an ``init`` field.
            Exceptions raised while waiting on the start futures (for example a
            timeout) also propagate. If anything fails after the API connection
            has started, the connection is stopped before the error propagates.
        """
        logger.info(f'Starting logbroker producer: {endpoint}, {topic}, {producer}')
        self.api = pqlib.PQStreamingAPI(endpoint, settings.LOGBROKER_PORT)
        self.api.start().result(timeout=10)

        self.max_seq_no = None

        try:
            credentials_provider = auth.OAuthTokenCredentialsProvider(settings.LOGBROKER_TOKEN)
            configurator = pqlib.ProducerConfigurator(topic, producer)
            self.producer = self.api.create_producer(
                configurator,
                credentials_provider=credentials_provider,
            )
            start_result = self.producer.start().result(timeout=10)

            if isinstance(start_result, SessionFailureResult):
                raise RuntimeError('Error occurred on start of producer: {}.'.format(start_result))
            if not start_result.HasField('init'):
                raise RuntimeError('Unexpected producer start result from server: {}.'.format(start_result))
        except Exception:
            # The API connection is already running. The caller never receives
            # an object to call ``shutdown`` on, so release the connection here
            # instead of leaking it.
            self.api.stop()
            raise

        logger.info(f'Producer start result was: {start_result}')
        # write_messages increments this before every write, so new writes get
        # sequence numbers after the value reported by the server at start.
        self.max_seq_no = start_result.init.max_seq_no
        logger.info(f'Producer started: {endpoint}, {topic}, {producer}')

    def shutdown(self):
        """Stop the producer session, then the underlying API connection."""
        logger.info('Stopping logbroker producer')
        self.producer.stop()
        self.api.stop()
        logger.info('Logbroker producer stopped')

    def write_messages(self, messages: list[Message]) -> set[int]:
        """Write ``messages`` in order, with at most ``MAX_INFLIGHT`` writes pending.

        Each write takes the next sequence number (``self.max_seq_no`` is
        advanced once per issued write) and is sent with GZIP compression.

        Args:
            messages: messages to send, in the order they should be written.

        Returns:
            The ``id`` values of the messages whose writes were acknowledged.
            Writing stops at the first failed write, whether it failed with an
            exception or with a non-ack result. In that case only the
            acknowledged prefix is returned; writes still in flight at that
            moment are not reported.
        """
        total_count = len(messages)
        written_count = 0
        current_idx = 0
        messages_inflight = []
        indices_inflight = []
        processed_indices = set()

        while written_count < total_count:
            # Top up the in-flight window.
            while len(messages_inflight) < self.MAX_INFLIGHT and current_idx < total_count:
                message = messages[current_idx]
                self.max_seq_no += 1
                packed_message = message.pack()
                messages_inflight.append(
                    self.producer.write(
                        self.max_seq_no, packed_message, codec=pqlib.WriterCodec.GZIP
                    )
                )
                indices_inflight.append(message.id)
                current_idx += 1

            done, _ = concurrent.futures.wait(
                messages_inflight, timeout=self.WAIT_TIMEOUT, return_when=concurrent.futures.FIRST_COMPLETED
            )
            if not done:
                # Nothing finished within WAIT_TIMEOUT; keep waiting, but leave a trace.
                logger.warning(
                    'No logbroker write completed within %s, %d still in flight',
                    self.WAIT_TIMEOUT, len(messages_inflight),
                )

            # futures.wait already returns the set of completed futures. Instead,
            # we rely on a protocol property: writes are acknowledged strictly in
            # the order they were issued. So we walk the in-flight list from the
            # front and stop at the first write that is still pending.
            completed_count = 0
            for future, message_id in zip(messages_inflight, indices_inflight):
                if not future.done():
                    break

                # Check for an exception first: future.result() would re-raise it
                # and skip the error handling below.
                error = future.exception(timeout=0)
                if error is not None:
                    logger.error('Exception occurred during message write {}'.format(error))
                    written_count += completed_count
                    logger.info(f'Written {written_count} messages in total')
                    return processed_indices

                result = future.result(timeout=0)
                if isinstance(result, SessionFailureResult) or not result.HasField('ack'):
                    logger.error('Message write failed with result: {}'.format(result))
                    written_count += completed_count
                    logger.info(f'Written {written_count} messages in total')
                    return processed_indices

                logger.info('Message written with result: {}'.format(result))
                processed_indices.add(message_id)
                completed_count += 1

            written_count += completed_count
            messages_inflight = messages_inflight[completed_count:]
            indices_inflight = indices_inflight[completed_count:]

        logger.info(f'Written {written_count} messages in total')
        return processed_indices
