#!/usr/bin/env python
# -*- coding:utf-8 -*-

import gzip
import logging
import retrying
import queue
import six
import uuid

from concurrent import futures

from kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api import (
    ConsumerMessageType,
    PQStreamingAPI,
    ProducerConfigurator,
    WriterCodec,
)
from kikimr.public.sdk.python.persqueue.auth import TVMCredentialsProvider
from tvmauth import TvmApiClientSettings, TvmClient


class PQReader(object):
    """Context manager around a PQ consumer with explicit cookie commits.

    ``__enter__`` creates and starts the consumer, ``get_messages`` yields the
    payloads of the next data event, and ``commit`` acknowledges the cookies of
    every read that was consumed completely.
    """

    # Seconds to wait for the next consumer event before get_messages gives up.
    POLL_TIMEOUT = 5
    # Seconds to wait for the consumer to stop in __exit__.
    STOP_TIMEOUT = 10

    def __init__(self, pq_client, credentials, reader_config, timeout, logger):
        """Remember the collaborators; the consumer is created in ``__enter__``.

        :param pq_client: object providing ``create_consumer``.
        :param credentials: passed to ``create_consumer`` as ``credentials_provider``.
        :param reader_config: object with a ``consumer_configurator`` attribute.
        :param timeout: seconds to wait for the consumer start response.
        :param logger: object with ``info`` and ``error`` methods.
        """
        self.pq_client = pq_client
        self.credentials = credentials
        self.reader_config = reader_config
        self.timeout = timeout
        self.logger = logger

        self.consumer = None
        # Pending next_event() future; kept between calls so that a timed-out
        # wait is resumed instead of requesting a second event.
        self.event_future = None
        # Cookies of fully consumed reads that have not been committed yet.
        self.done_cookies = []

    def __enter__(self):
        """Create and start the consumer.

        :raises AssertionError: if the start response has a falsy ``init``.
        """
        self.consumer = self.pq_client.create_consumer(
            self.reader_config.consumer_configurator,
            credentials_provider=self.credentials,
        )
        self.consumer.stop_future.add_done_callback(self.consumer_stop_callback)
        response = self.consumer.start().result(timeout=self.timeout)
        # Explicit raise instead of ``assert``: the check must survive
        # ``python -O``. AssertionError is kept for backward compatibility.
        if not response.init:
            raise AssertionError("PQ consumer failed to start: response has no init")
        self.logger.info("PQ consumer started")

        return self

    def get_messages(self):
        """Yield the ``data`` of every message in the next data event.

        Yields nothing when no event arrives within ``POLL_TIMEOUT`` seconds.
        The cookie of the read is appended to ``done_cookies`` only after the
        caller has iterated over all messages, so a partially consumed
        generator does not mark the read as done.
        """
        try:
            raw_message = self.__get_message().data
        except futures.TimeoutError:
            return

        for batch in raw_message.message_batch:
            for message in batch.message:
                yield message.data

        self.done_cookies.append(raw_message.cookie)

    def clear(self):
        """Forget all collected cookies without committing them."""
        self.done_cookies = []

    def consumer_stop_callback(self, result_future):
        """Log (at error level) the result of the consumer's stop future."""
        assert result_future.done()
        self.logger.error("PQ consumer stopped: {}".format(result_future.result()))

    def __get_message(self):
        """Block until a ``MSG_DATA`` event arrives and return its message.

        :raises concurrent.futures.TimeoutError: if a single wait exceeds
            ``POLL_TIMEOUT`` seconds; the pending future is kept for the next call.
        """
        while True:
            if self.event_future is None:
                self.event_future = self.consumer.next_event()

            event = self.event_future.result(timeout=self.POLL_TIMEOUT)
            self.event_future = None
            if event.type == ConsumerMessageType.MSG_DATA:
                return event.message
            # Every non-data event is reported as a commit response.
            self.logger.info("Got COMMIT response")  # debug

    def commit(self):
        """Commit all collected cookies, if any, and reset the list."""
        if self.done_cookies:
            self.consumer.commit(self.done_cookies)
            self.done_cookies = []

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the consumer, waiting up to ``STOP_TIMEOUT`` seconds.

        :raises ValueError: if the consumer was never created.
        """
        if not self.consumer:
            raise ValueError("Consumer was not started")

        self.consumer.stop().result(timeout=self.STOP_TIMEOUT)
        self.logger.info("PQ consumer stopped")


class AsyncPQReader(object):
    """Context manager around a PQ consumer driven by a callback loop.

    ``loop`` hands every received batch and its cookie to a callback. Cookies
    that are ready to be acknowledged are taken from the ``done_cookies`` queue
    and committed at the start of each loop iteration. Nothing in this class
    puts cookies into that queue; presumably the callback (or a worker it
    delegates to) does so once a batch is processed.
    """

    def __init__(self, pq_client, credentials, reader_config, timeout, logger):
        """Remember the collaborators; the consumer is created in ``__enter__``.

        :param pq_client: object providing ``create_consumer``.
        :param credentials: passed to ``create_consumer`` as ``credentials_provider``.
        :param reader_config: object with a ``consumer_configurator`` attribute.
        :param timeout: seconds used for start, stop and every event wait.
        :param logger: object with ``debug``, ``info`` and ``error`` methods.
        """
        self.pq_client = pq_client
        self.credentials = credentials
        self.reader_config = reader_config
        self.timeout = timeout
        self.logger = logger

        self.consumer = None
        # Pending next_event() future; kept between calls so that a timed-out
        # wait is resumed instead of requesting a second event.
        self.event_future = None
        # Cookies waiting to be committed by the next commit() call.
        self.done_cookies = queue.Queue()
        # Cleared by stop() or by the consumer stop callback to end loop().
        self.running = True

    def __enter__(self):
        """Create and start the consumer.

        :raises AssertionError: if the start response has a falsy ``init``.
        """
        self.consumer = self.pq_client.create_consumer(
            self.reader_config.consumer_configurator,
            credentials_provider=self.credentials,
        )
        self.consumer.stop_future.add_done_callback(self.consumer_stop_callback)
        response = self.consumer.start().result(timeout=self.timeout)
        # Explicit raise instead of ``assert``: the check must survive
        # ``python -O``. AssertionError is kept for backward compatibility.
        if not response.init:
            raise AssertionError("PQ async consumer failed to start: response has no init")
        self.logger.info("PQ async consumer started")

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the consumer, waiting up to ``timeout`` seconds.

        :raises ValueError: if the consumer was never created.
        """
        if not self.consumer:
            raise ValueError("Consumer was not started")

        self.consumer.stop().result(timeout=self.timeout)
        self.logger.info("PQ async consumer stopped")

    def loop(self, func):
        """Commit pending cookies and dispatch batches until stopped.

        :param func: callable invoked as ``func(batch, cookie)`` for every read
            that returned a truthy cookie.
        """
        while self.running:
            self.commit()
            batch, cookie = self.get_messages()
            if cookie:
                func(batch, cookie)

    def get_messages(self):
        """Return ``(batch, cookie)`` for the next data event.

        ``batch`` is the list of ``data`` fields of all messages in the event.
        Returns ``(None, None)`` if the wait for an event timed out.
        """
        try:
            raw_message = self.__get_message().data
        except futures.TimeoutError:
            self.logger.debug("get_messages timed out")
            return None, None

        batch = [
            message.data
            for message_batch in raw_message.message_batch
            for message in message_batch.message
        ]
        return batch, raw_message.cookie

    def consumer_stop_callback(self, result_future):
        """Log the consumer stop result and make ``loop`` terminate."""
        assert result_future.done()
        self.logger.error("PQ async consumer stopped: {}".format(result_future.result()))
        self.running = False

    def __get_message(self):
        """Block until a ``MSG_DATA`` event arrives and return its message.

        Events of any other type are skipped silently.

        :raises concurrent.futures.TimeoutError: if a single wait exceeds
            ``timeout`` seconds; the pending future is kept for the next call.
        """
        while True:
            if self.event_future is None:
                self.event_future = self.consumer.next_event()

            event = self.event_future.result(timeout=self.timeout)
            self.event_future = None
            if event.type == ConsumerMessageType.MSG_DATA:
                return event.message

    def commit(self):
        """Drain ``done_cookies`` and commit everything that was queued."""
        cookies = []
        try:
            # Non-blocking drain: queue.Empty is the normal loop terminator.
            while True:
                cookies.append(self.done_cookies.get_nowait())
        except queue.Empty:
            pass

        # Logged after the drain so the message shows the real cookie list.
        self.logger.debug("Committing cookies: %s", cookies)
        if cookies:
            self.consumer.commit(cookies)

    def stop(self):
        """Ask ``loop`` to exit after the current iteration."""
        self.running = False


class PQWriter(object):
    """Context manager around a retrying PQ producer for a single topic.

    Each instance writes under its own freshly generated source id and numbers
    its messages sequentially, continuing from the ``max_seq_no`` reported by
    the producer on start.
    """

    def __init__(self, pq_client, topic, credentials, logger, timeout, partition_group):
        """Remember the collaborators; the producer is created in ``__enter__``.

        :param pq_client: object providing ``create_retrying_producer``.
        :param topic: topic passed to ``ProducerConfigurator``.
        :param credentials: passed to the producer as ``credentials_provider``.
        :param logger: object with an ``info`` method.
        :param timeout: seconds to wait for the producer start response.
        :param partition_group: passed to ``ProducerConfigurator``.
        """
        self.pq_client = pq_client
        self.credentials = credentials
        self.logger = logger
        self.timeout = timeout

        self.topic = topic
        self.partition_group = partition_group
        # Unique per writer instance, stored as bytes.
        self.source_id = six.ensure_binary(str(uuid.uuid1()))

        self.producer = None
        self.seq_num = 0

    def write(self, data, codec=WriterCodec.RAW):
        """Send ``data`` under the next sequence number.

        :param data: payload; converted to bytes with ``six.ensure_binary``.
        :param codec: codec forwarded to the producer.
        :return: whatever the producer's ``write`` returns.
        """
        self.seq_num += 1
        send = self.producer.write
        return send(self.seq_num, six.ensure_binary(data), codec=codec)

    def __enter__(self):
        """Create and start the producer, making up to five attempts."""
        configurator = ProducerConfigurator(
            topic=self.topic,
            source_id=self.source_id,
            partition_group=self.partition_group
        )

        def start_producer():
            # A new producer is created on every attempt.
            self.producer = self.pq_client.create_retrying_producer(
                configurator, credentials_provider=self.credentials
            )
            start_future = self.producer.start()
            return start_future.result(timeout=self.timeout)

        start_with_retries = retrying.retry(stop_max_attempt_number=5)(start_producer)
        start_response = start_with_retries()

        # Continue numbering after the last sequence number the producer reports.
        self.seq_num = start_response.init.max_seq_no
        self.logger.info("PQ producer started")

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the producer without waiting for the result."""
        self.producer.stop()

        self.logger.info("PQ API stopped")


class BatchingPQWriter(object):
    """Buffer records and send them through a writer as one joined message.

    Records are accumulated until their total size reaches ``max_batch_size``
    and are then joined with ``delimiter``, optionally gzip-compressed, and
    written as a single message. Entering and leaving this context manager
    also enters and leaves the wrapped writer.
    """

    def __init__(self, writer, max_batch_size=1 << 20, codec=WriterCodec.GZIP, delimiter=six.b("\n"), timeout=300):
        """Configure batching around ``writer``.

        :param writer: object with ``write(data, codec)`` returning a future,
            plus ``__enter__`` / ``__exit__``.
        :param max_batch_size: total record bytes (delimiters not counted) at
            which ``write`` triggers a flush.
        :param codec: ``WriterCodec.RAW`` or ``WriterCodec.GZIP``.
        :param delimiter: separator placed between records; converted to bytes.
        :param timeout: seconds to wait for the write acknowledgement.
        """
        self.writer = writer
        self.max_batch_size = max_batch_size
        self.codec = codec
        self.delimiter = six.ensure_binary(delimiter)
        self.timeout = timeout

        self.batch = []
        self.current_batch_size = 0

    def write(self, data):
        """Buffer ``data`` and flush once the batch is large enough.

        Falsy ``data`` (``None``, empty string) is ignored.
        """
        if not data:
            return

        binary_data = six.ensure_binary(data)
        self.batch.append(binary_data)
        self.current_batch_size += len(binary_data)
        if self.current_batch_size >= self.max_batch_size:
            self.flush()

    def flush(self):
        """Send the buffered records as one message and reset the buffer.

        Does nothing when the buffer is empty. The buffer is cleared only
        after the write has been acknowledged, so a failed flush keeps the
        records for a later attempt.

        :raises NotImplementedError: for codecs other than RAW and GZIP.
        :raises AssertionError: if the write response carries no ``ack``.
        """
        if not self.batch:
            return

        # Both the delimiter and every record are already bytes.
        result = self.delimiter.join(self.batch)

        if self.codec == WriterCodec.RAW:
            pass
        elif self.codec == WriterCodec.GZIP:
            out = six.BytesIO()
            with gzip.GzipFile(fileobj=out, mode="w") as gzip_stream:
                gzip_stream.write(result)
            result = out.getvalue()
        else:
            raise NotImplementedError

        # The write must not live inside an ``assert``: under ``python -O``
        # the whole statement is removed and the batch would be dropped
        # silently. AssertionError is raised explicitly to keep the exception
        # type callers may already handle.
        response = self.writer.write(result, self.codec).result(timeout=self.timeout)
        if not response.HasField("ack"):
            raise AssertionError("PQ write was not acknowledged")

        self.batch = []
        self.current_batch_size = 0

    def __enter__(self):
        """Enter the wrapped writer and return this batching writer."""
        self.writer.__enter__()
        return self

    def __exit__(self, *args):
        """Flush the remaining records, then always leave the wrapped writer."""
        try:
            self.flush()
        finally:
            # Release the writer even when the final flush fails.
            self.writer.__exit__(*args)


class PQClient(object):
    """Entry point that owns the PQ streaming API and builds readers/writers.

    Use as a context manager: entering starts the API, leaving stops it.
    Readers and writers obtained from the factory methods share this client's
    API object, credentials and logger.
    """

    def __init__(self, url, port, tvm_id, tvm_secret, logger=logging, timeout=30):
        """Create the API object and the credentials provider.

        :param url: passed to ``PQStreamingAPI`` as the first argument.
        :param port: passed to ``PQStreamingAPI`` as the second argument.
        :param tvm_id: TVM id used to build credentials.
        :param tvm_secret: TVM secret; when falsy no credentials are created.
        :param logger: object with an ``info`` method (the ``logging`` module by default).
        :param timeout: default timeout in seconds for start and for created
            readers/writers.
        """
        self.logger = logger
        self.timeout = timeout
        self.pq_client = PQStreamingAPI(url, port)
        self.credentials = self._get_credentials(tvm_id, tvm_secret)

    def __enter__(self):
        """Start the API and wait up to ``timeout`` seconds for it."""
        start_future = self.pq_client.start()
        start_future.result(timeout=self.timeout)
        self.logger.info("PQClient started")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Stop the API."""
        self.pq_client.stop()
        self.logger.info("PQClient stopped")

    def get_reader(self, reader_config, timeout=None):
        """Build a ``PQReader``; a falsy ``timeout`` means the client default."""
        effective_timeout = timeout if timeout else self.timeout
        reader_kwargs = dict(
            pq_client=self.pq_client,
            credentials=self.credentials,
            reader_config=reader_config,
            timeout=effective_timeout,
            logger=self.logger,
        )
        return PQReader(**reader_kwargs)

    def get_async_reader(self, reader_config, timeout=None):
        """Build an ``AsyncPQReader``; a falsy ``timeout`` means the client default."""
        effective_timeout = timeout if timeout else self.timeout
        reader_kwargs = dict(
            pq_client=self.pq_client,
            credentials=self.credentials,
            reader_config=reader_config,
            timeout=effective_timeout,
            logger=self.logger,
        )
        return AsyncPQReader(**reader_kwargs)

    def get_writer(self, topic, timeout=None, partition_group=None):
        """Build a ``PQWriter`` for ``topic``; a falsy ``timeout`` means the client default."""
        effective_timeout = timeout if timeout else self.timeout
        return PQWriter(
            self.pq_client,
            topic,
            self.credentials,
            partition_group=partition_group,
            timeout=effective_timeout,
            logger=self.logger,
        )

    def _get_credentials(self, id, secret):
        """Return a TVM credentials provider, or ``None`` when ``secret`` is falsy."""
        if secret:
            settings = TvmApiClientSettings(self_tvm_id=id, self_secret=secret, dsts={"pq": 2001059})
            return TVMCredentialsProvider(TvmClient(settings), destination_alias="pq")
        return None
