# -*- coding: utf-8 -*-
from __future__ import unicode_literals

import json

from kikimr.public.sdk.python.persqueue.auth import OAuthTokenCredentialsProvider
from kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api import PQStreamingAPI
from kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api import ProducerConfigurator
from concurrent import futures


from travel.cpa.lib.errors import ErrorType, ProcessError
from travel.cpa.lib.lib_logging import get_logger


# Module-level logger, obtained from the project's logging helper.
LOG = get_logger(__name__)
# Timeout, in seconds, applied to each blocking future wait in LBWriter.write
# (API start, producer start, and every individual write acknowledgement).
TIMEOUT_SECONDS = 60


class LBWriter(object):
    """Write JSON-serialised messages to a topic through the persqueue streaming API.

    The instance only stores connection settings; every call to ``write``
    creates, starts and stops its own ``PQStreamingAPI`` object.
    """

    def __init__(self, url, port, source_id, token):
        """Remember the connection settings without opening any connection.

        :param url: first positional argument later given to ``PQStreamingAPI``.
        :param port: second positional argument later given to ``PQStreamingAPI``.
        :param source_id: source id put into the producer configuration.
        :param token: OAuth token, or ``None`` to create the producer without
            a credentials provider.
        """
        self.token = token
        self.source_id = source_id
        self.port = port
        self.url = url

    def write(self, topic, data):
        """Send every message from ``data`` to ``topic`` and wait for the responses.

        The API object is stopped before returning, whether the write
        succeeded or an exception was raised inside the guarded section.

        :param topic: topic name put into the producer configuration.
        :param data: iterable of messages; each one is serialised with
            ``_get_message_bytes`` right before it is handed to the producer.
        :raises ProcessError: with ``ErrorType.ET_LB_TIMEOUT`` when any of the
            waits exceeds ``TIMEOUT_SECONDS``.
        :raises Exception: when the producer start response has no ``init``
            field, or a write response has no ``ack`` field.
        """
        pq_api = PQStreamingAPI(self.url, self.port)
        producer_config = ProducerConfigurator(topic=topic, source_id=self.source_id)
        credentials = OAuthTokenCredentialsProvider(self.token) if self.token is not None else None
        try:
            pq_api.start().result(TIMEOUT_SECONDS)

            producer = pq_api.create_retrying_producer(producer_config, credentials_provider=credentials)

            init_response = producer.start_future.result(timeout=TIMEOUT_SECONDS)
            if not init_response.HasField('init'):
                raise Exception('Failed to initialize logbroker connection')

            # Numbering continues right after the max_seq_no reported in the
            # init response (presumably the last sequence number already
            # stored for this source id -- confirm against the SDK docs).
            first_seq_no = init_response.init.max_seq_no + 1
            LOG.debug('Writing from %s to %s with offset %d', self.source_id, topic, first_seq_no)

            # Hand over all messages first, then wait for the responses, so
            # the waits do not serialise the submissions.
            pending_writes = []
            for seq_no, message in enumerate(data, first_seq_no):
                payload = self._get_message_bytes(message)
                pending_writes.append(producer.write(seq_no, payload))

            for pending in pending_writes:
                if not pending.result(timeout=TIMEOUT_SECONDS).HasField('ack'):
                    raise Exception('Failed to write {}'.format(pending))
            LOG.debug('Write to LB complete')
        except futures.TimeoutError:
            LOG.exception('LB timeout')
            raise ProcessError(ErrorType.ET_LB_TIMEOUT)
        finally:
            pq_api.stop()
            LOG.debug('API stopped')

    @staticmethod
    def _get_message_bytes(message):
        """Serialise ``message`` to UTF-8 encoded JSON bytes.

        A ``list`` is treated as a batch: each element is dumped separately
        and the results are joined with a newline. Any other value is dumped
        as a single JSON document.
        """
        if not isinstance(message, list):
            return json.dumps(message).encode('utf8')
        encoded_records = [json.dumps(record).encode('utf8') for record in message]
        return b'\n'.join(encoded_records)
