import logging

import kikimr.public.sdk.python.persqueue.auth as pqlib_auth
import kikimr.public.sdk.python.persqueue.errors as pqlib_errors
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib

# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)


class LBProducer(object):
    """Context manager that writes messages to a persqueue (LB) topic.

    Entering the context starts a ``PQStreamingAPI`` connection and a retrying
    producer; exiting stops both. Each :meth:`write` call blocks until the
    server's response arrives or ``_timeout`` seconds pass.

    Usage::

        with LBProducer(endpoint, token, topic, task_id) as producer:
            producer.write(data)
    """

    _api = None
    _producer = None
    # Last sequence number used. It is seeded from the server's
    # ``init.max_seq_no`` on start and incremented before every write.
    _max_seq_no = None

    _endpoint = None  # type: str
    _token = None  # type: str
    _topic = None  # type: str
    # NOTE(review): task_id is stored but not used anywhere in this class;
    # confirm whether it was meant to be part of the producer's source_id.
    _task_id = None  # type: int
    _port = 2135  # type: int
    _source_id = None  # type: str

    # Timeout in seconds applied to every future: API start, producer start
    # and each write acknowledgement.
    _timeout = 90

    def __init__(self, endpoint, token, topic, task_id, port=2135, source_id=__name__):
        # type: (str, str, str, int, int, str) -> None
        """Store connection settings. No network activity happens until ``__enter__``.

        :param endpoint: persqueue API host passed to ``PQStreamingAPI``.
        :param token: OAuth token for ``OAuthTokenCredentialsProvider``.
        :param topic: topic passed to ``ProducerConfigurator``.
        :param task_id: stored as ``_task_id`` (currently unused here).
        :param port: API port; defaults to the previously hard-coded 2135.
        :param source_id: producer source id; defaults to this module's name,
            which was the previously hard-coded value.
        """
        self._endpoint = endpoint
        self._token = token
        self._topic = topic
        self._task_id = task_id
        self._port = port
        self._source_id = source_id

    @staticmethod
    def _stop_quietly(*components):
        """Stop each non-None component, logging (not raising) any failure."""
        for component in components:
            if component is None:
                continue
            try:
                component.stop()
            except Exception:
                logger.exception("Failed to stop %r during cleanup", component)

    def __enter__(self):
        """Start the API and the producer, then return ``self``.

        If any step fails, whatever was already started is stopped before the
        exception propagates.

        :raises RuntimeError: if the producer start result is a
            ``SessionFailureResult`` or has no ``init`` field.
        """
        api = pqlib.PQStreamingAPI(self._endpoint, self._port)
        producer = None
        started = False
        try:
            api_start_future = api.start()
            result = api_start_future.result(timeout=self._timeout)
            logger.info("lb api start future result: %s", result)

            credentials_provider = pqlib_auth.OAuthTokenCredentialsProvider(self._token)
            configurator = pqlib.ProducerConfigurator(self._topic, source_id=self._source_id)

            producer = api.create_retrying_producer(configurator, credentials_provider)
            producer_start_future = producer.start()
            producer_start_result = producer_start_future.result(timeout=self._timeout)

            if isinstance(producer_start_result, pqlib_errors.SessionFailureResult):
                raise RuntimeError("Error occurred on start of producer: {}".format(producer_start_result))
            if not producer_start_result.HasField("init"):
                raise RuntimeError("Unexpected producer start result from server: {}".format(producer_start_result))

            logger.info("lb producer start future result: %s", producer_start_result)
            max_seq_no = producer_start_result.init.max_seq_no
            started = True
        finally:
            # __exit__ is not invoked when __enter__ raises, so release what
            # was started here. try/finally (rather than except + bare raise)
            # keeps the original exception intact on Python 2 as well.
            if not started:
                self._stop_quietly(producer, api)

        self._max_seq_no = max_seq_no
        self._api = api
        self._producer = producer

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the producer and then the API. Exceptions are never suppressed.

        The references are cleared first, so a later ``write()`` raises
        "Producer is not initialized!" instead of using a stopped producer.
        """
        producer, api = self._producer, self._api
        self._producer = None
        self._api = None
        try:
            if producer is not None:
                producer.stop()
        finally:
            # Stop the API even if stopping the producer raised.
            if api is not None:
                api.stop()

    def write(self, data):
        # type: (str) -> None
        """Write one message and block until the server responds.

        :param data: message payload passed to the producer's ``write``.
        :raises RuntimeError: if called outside the ``with`` block, or if the
            server response is a session failure or lacks an ``ack`` field.
        """
        if self._producer is None:
            raise RuntimeError("Producer is not initialized!")

        self._max_seq_no += 1
        result = self._producer.write(self._max_seq_no, data).result(timeout=self._timeout)

        # Mirrors the start-up check in __enter__. The result may presumably be
        # a SessionFailureResult instead of a protobuf response (TODO confirm
        # against the SDK), in which case HasField would not be available.
        if isinstance(result, pqlib_errors.SessionFailureResult) or not result.HasField("ack"):
            raise RuntimeError("Message write failed with error: {}".format(result))
