# coding=utf-8
import logging
import json
import time

logger = logging.getLogger(name=__name__)

# Accepted values for the ``message_format`` argument of ``LogBroker.send`` and
# ``write_to_logbroker``; each constant's value is simply its own name.  Any
# other value makes ``send`` pass the message to the producer unconverted.
MESSAGE_FORMAT_JSON = "MESSAGE_FORMAT_JSON"
MESSAGE_FORMAT_TSKV = "MESSAGE_FORMAT_TSKV"


def write_to_logbroker(topic, endpoint, port, source, oauth, messages, file, message_format=MESSAGE_FORMAT_TSKV):
    """Write ``messages`` to a Logbroker topic using a short-lived producer.

    Args:
        topic: Logbroker topic to write to.
        endpoint: Logbroker host name.
        port: Logbroker port.
        source: producer source id; falls back to ``"SomeSandboxTask"`` when
            empty/falsy.
        oauth: OAuth token used to build the credentials provider.
        messages: rows to send (see ``LogBroker.send``).
        file: value put into the producer's ``file`` extra field.
        message_format: ``MESSAGE_FORMAT_TSKV`` (default) or
            ``MESSAGE_FORMAT_JSON``.

    Raises:
        RuntimeError: if the producer fails to start or not every message is
            acknowledged (raised by ``LogBroker``).
    """
    import kikimr.public.sdk.python.persqueue.auth as pqlib_auth
    lb = LogBroker(
        cred_provider=pqlib_auth.OAuthTokenCredentialsProvider(oauth),
        endpoint=endpoint,
        port=port,
        topic=topic,
        source_id=source if source else "SomeSandboxTask",
        file=file
    )
    try:
        lb.send(messages, message_format=message_format)
    finally:
        # The producer is only used for this one call; release it even when
        # send() raises so connections are not leaked.
        lb.stop()


class LogBroker:
    """Synchronous wrapper around a Logbroker (persqueue) producer.

    The constructor connects to the streaming API and starts a producer.
    ``send`` serializes rows, writes them with increasing sequence numbers and
    waits for the acknowledgements. ``stop`` shuts the producer and the API down.
    """

    # TSKV escape table, applied in order by ``quote``.  The backslash MUST be
    # escaped first: doing it after the other replacements would double the
    # backslashes those replacements introduce (``a=b`` -> ``a\\=b``).
    _TSKV_ESCAPES = (
        ('\\', '\\\\'),
        ('\t', '\\t'),
        ('\n', '\\n'),
        ('\r', '\\r'),
        ('\0', '\\0'),
        ('=', '\\='),
    )

    def __init__(
        self, cred_provider,
        endpoint, port, topic, source_id,
        file,
        throttle_pause=0.25
    ):
        """Start the streaming API and a producer for ``topic``.

        Args:
            cred_provider: credentials provider passed to ``create_producer``.
            endpoint: Logbroker host name.
            port: Logbroker port.
            topic: topic the producer writes to.
            source_id: producer source id (also used in log messages).
            file: value stored in the producer's ``file`` extra field.
            throttle_pause: seconds ``send`` sleeps every 1000 sequence numbers.

        Raises:
            RuntimeError: if the producer start result is a session failure or
                carries no ``init`` section.
        """
        import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
        import kikimr.public.sdk.python.persqueue.errors as pqerrors

        logger.info("Spawning a Logbroker client %s", source_id)
        logging.getLogger('kikimr.public.sdk.python').setLevel(logging.INFO)

        self._source_id = source_id
        self._throttle_pause = throttle_pause

        logger.debug('Logbroker connector %s instance to %s', source_id, endpoint)
        self.api = pqlib.PQStreamingAPI(endpoint, port)
        api_start_future = self.api.start()
        result = api_start_future.result(timeout=30)
        logger.debug('Logbroker API %s started: %s', source_id, result)

        configurator = pqlib.ProducerConfigurator(topic, source_id, extra_fields={
            "file": file
        })
        try:
            self.producer = self.api.create_producer(
                producer_configurator=configurator, credentials_provider=cred_provider
            )
            logger.debug('Starting Logbroker producer %s', source_id)
            start_result = self.producer.start().result(timeout=30)
        except Exception:
            # Don't leak the already-started API when the producer can't start.
            self.api.stop()
            raise
        if (isinstance(start_result, pqerrors.SessionFailureResult)
                or not start_result.HasField('init')):
            self.api.stop()
            raise RuntimeError(
                'Logbroker producer %s (%s @ %s) start failure: %s' % (
                    source_id, topic, endpoint, start_result))
        # Next sequence number to use; init.max_seq_no is presumably the
        # highest seq no already stored for this source id.
        self.max_seq = start_result.init.max_seq_no + 1
        logger.debug(
            'Logbroker producer %s started: %s', source_id, start_result)

    def quote(self, value):
        """Escape a string for TSKV: backslash, tab, LF, CR, NUL and ``=``."""
        for raw, escaped in self._TSKV_ESCAPES:
            value = value.replace(raw, escaped)
        return value

    def serialize(self, value):
        """Render one TSKV key or value as a string.

        Strings are TSKV-escaped. Lists/tuples become a comma-joined string of
        their escaped items, and commas inside items are escaped as ``\\,``.
        Any other value is rendered with ``str()``.
        """
        if isinstance(value, str):
            return self.quote(value)
        if isinstance(value, (list, tuple)):
            # Escape backslashes first (inside quote) so the comma escape added
            # afterwards is not itself doubled; str() allows non-str items.
            return ','.join(
                self.quote(str(item)).replace(',', '\\,') for item in value)
        return str(value)

    def as_tskv(self, row):
        """Serialize a mapping as tab-separated ``key=value`` pairs."""
        return '\t'.join(['%s=%s' % (self.serialize(k), self.serialize(v)) for k, v in row.items()])

    def as_json(self, row):
        """Serialize a mapping as compact JSON, comma-joining list/tuple values."""
        flat = {}
        for key, value in row.items():
            if isinstance(value, (list, tuple)):
                value = ','.join(map(str, value))
            flat[key] = value
        return json.dumps(flat, separators=(',', ':'))

    def send(self, messages, start_seq=None, message_format=MESSAGE_FORMAT_TSKV):
        """Write ``messages`` and wait until every write is acknowledged.

        Args:
            messages: iterable of rows. For the TSKV/JSON formats each row must
                be a mapping. For any other format the row is written unchanged.
            start_seq: first sequence number. A falsy value (``None`` or ``0``)
                means "continue from ``self.max_seq``".
            message_format: ``MESSAGE_FORMAT_TSKV`` or ``MESSAGE_FORMAT_JSON``.

        After writing, ``self.max_seq`` is advanced past the last sequence
        number used, so a later ``send()`` does not reuse sequence numbers.

        Raises:
            RuntimeError: if fewer acknowledgements than writes are received.
        """
        sid = self._source_id
        if not messages:
            logger.info('No messages to be sent by %s.', sid)
            return
        seq = start_seq or self.max_seq
        logger.debug('Start %s seq: %s', sid, seq)
        logger.debug('Message: %s', messages)

        write_responses = []
        for message in messages:
            if message_format == MESSAGE_FORMAT_TSKV:
                processed_message = self.as_tskv(message)
            elif message_format == MESSAGE_FORMAT_JSON:
                processed_message = self.as_json(message)
            else:
                processed_message = message

            logger.debug('Writing %s', processed_message)
            write_responses.append(self.producer.write(seq, processed_message))
            seq += 1
            if seq % 1000 == 0:
                logger.debug('%s sending, seq = %i', sid, seq)
                # Lazy %s formatting: json.dumps here could raise on
                # non-JSON-serializable rows and abort the whole send.
                logger.debug('%s', message)
                time.sleep(self._throttle_pause)
        # Never move backwards (an explicit lower start_seq must not rewind).
        self.max_seq = max(self.max_seq, seq)

        total = len(write_responses)
        logger.debug('%s messages enqueued by %s.', total, sid)
        if not total:
            # e.g. an empty generator, which passes the `not messages` check.
            logger.info('No messages to be sent by %s.', sid)
            return

        acks_count = sum(
            1 for r in write_responses if r.result(timeout=30).HasField('ack'))

        logger.debug('%s: %s messages confirmed out of %i (%i%%).',
                     sid, acks_count, total, 100 * acks_count // total)
        logger.debug('%s done sending messages', sid)
        if acks_count < total:
            raise RuntimeError(
                'Some messages have not been sent by %s (see above)!' % sid)

    def stop(self):
        """Stop the producer, then the underlying streaming API."""
        self.producer.stop()
        self.api.stop()
