#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Logbroker integration

https://logbroker.yandex-team.ru/docs/
"""

import logging
import os
import socket
import time
from concurrent.futures import TimeoutError, FIRST_COMPLETED, ALL_COMPLETED, wait

import kikimr.public.sdk.python.persqueue.auth as auth
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult

logger = logging.getLogger(__name__)


def get_default_source_id():
    #  id группы сообщений при записи
    #  группа сообщений привязывется к партиции при первой инициализации
    #  в рамках группы сообщений обеспечивается порядок чтения сообщений
    #  нельзя параллельно писать в сессиях с одниковым source_id
    #  должно быть не более 100000 уникальных групп сообщений на временном окне 14 дней
    #  по дефолту будет HOSTNAME-PID
    return "{}-{}".format(socket.gethostname().split(".")[0], str(os.getpid()))


def start_session(source_id=None):
    """
    Стартует сессию логброкера, не создавая при этом ни читателя, ни писателя.
    Перед вызовом write_message/read_message
    нужно самостоятельно создать писателя или читателя, соответственно.
    """
    return Logbroker(source_id=source_id)


def start_read_session(logbroker_host,
                       read_processor, consumer_name=None
                       ):
    """
    Стартует сессию логброкера и создает читателя
    в read_message_processor передается класс, наследник ReadMessageProcessor, реализующий метод process

    """
    return Logbroker(
        logbroker_host,
        create_consumer=True,
        read_message_processor=read_processor,
        consumer_name=consumer_name,
    )


def start_write_session(logbroker_host,
                        source_id=None, retrying_producer=False
                        ):
    """
        Стартует сессию логброкера и создает писателя
    """
    return Logbroker(
        logbroker_host,
        create_producer=True,
        retrying_producer=retrying_producer,
        source_id=source_id,
    )


class ReadMessageProcessor(object):
    def process(self, messages):
        raise NotImplementedError


class Logbroker(object):
    """
    Базовый класс взаимодействия с Логброкером.
    Может быть использован в режиме контекст-менеджера (рекомендуемый способ).
    При использовании НЕ в режиме контекст-менеджера,
    надо не зыбыть вызвать в конце stop_consumer(), stop_producer(), stop_api().
    Для инициализации рекомендуется изпользовать методы
    start_session(), start_read_session(), start_write_session().
    """

    def __init__(
            self,
            logbroker_host,
            create_producer=False,
            create_consumer=False,
            retrying_producer=False,
            read_message_processor=None,
            consumer_name=None,
            source_id=None,
    ):
        self.host = logbroker_host
        self._parse_config()

        self.api = self._create_api()
        self.credentials_provider = auth.OAuthTokenCredentialsProvider(self.oauth_token.encode('utf-8'))
        self.read_message_processor = read_message_processor
        self.source_id = source_id or get_default_source_id()
        self.consumer = None
        self.producer = None

        if consumer_name:
            self.consumer_name = consumer_name
        if create_consumer:
            self.create_consumer()
        if create_producer:
            self.create_producer(retrying=retrying_producer)

    def _parse_config(self):
        self.port = int(os.environ.get("LB_PORT"))
        self.timeout = int(os.environ.get("LB_TIMEOUT"))
        self.topic = os.environ.get("LB_TOPIC")
        self.consumer_name = os.environ.get("LB_CONSUMER")
        self.msg_batch_size = int(os.environ.get("LB_MESSAGE_BATCH_SIZE"))
        self.oauth_token = os.environ.get("LB_OAUTH_TOKEN")
        self.infly_read_count = int(os.environ.get("LB_READ_INFLY_COUNT"))
        self.balance_partition_now = int(os.environ.get("LB_BALANCE_PARTITION"))

    @staticmethod
    def get_millis():
        return int(round(time.time() * 1000))

    def _create_api(self):
        api = pqlib.PQStreamingAPI(self.host, self.port)
        api_start_future = api.start()
        result = api_start_future.result(timeout=self.timeout)
        logger.debug("Logbroker api started with result: %s", result)

        return api

    def stop_api(self):
        if self.api:
            self.api.stop()

    def create_consumer(self):
        configurator = pqlib.ConsumerConfigurator(
            self.topic, self.consumer_name, max_count=self.msg_batch_size,
            read_infly_count=self.infly_read_count, balance_partition_now=(self.balance_partition_now > 0)
        )

        self.consumer = self.api.create_consumer(
            configurator, credentials_provider=self.credentials_provider
        )

        logger.debug("Starting consumer")
        start_future = self.consumer.start()
        start_result = start_future.result(timeout=self.timeout)
        if not isinstance(start_result, SessionFailureResult):
            if start_result.HasField("init"):
                logger.debug("Consumer start result was: %s", start_result)
            else:
                raise RuntimeError(
                    "Bad consumer start result from server: {}.".format(start_result)
                )
        else:
            raise RuntimeError(
                "Error occurred on start of consumer: {}.".format(start_result)
            )

        logger.debug("Consumer started")

    def stop_consumer(self):
        if self.consumer:
            self.consumer.stop()

    def _recreate_consumer(self):
        if self.consumer:
            self.consumer.stop()
        self.create_consumer()

    def read_message(self):  # noqa: R701 too complex
        if not self.consumer:
            raise RuntimeError("Cannot read message, no consumer has provided")
        if not self.read_message_processor:
            raise RuntimeError(
                "Cannot process message, no read_message_processor has provided"
            )
        try:
            result = self.consumer.next_event().result(timeout=self.timeout)
        except TimeoutError:
            logger.debug("Failed to get any messages from topic: %s", self.topic)
            self._recreate_consumer()
            return list()

        messages = list()
        if result.type == pqlib.ConsumerMessageType.MSG_DATA:
            messages = [
                message
                for batch in result.message.data.message_batch
                for message in batch.message
            ]

            self._process(messages)
            self.consumer.commit(result.message.data.cookie)
        elif result.type == pqlib.ConsumerMessageType.MSG_LOCK:
            # Сервер назначил партицию нашему консьюмеру.
            # Для того чтобы начать получать данные из партиции
            # надо вызвать ready_to_read()
            result.ready_to_read()
            logger.debug(
                "Got partition assignment. Topic %s, partition %s",
                result.message.lock.topic,
                result.message.lock.partition,
            )
        elif result.type == pqlib.ConsumerMessageType.MSG_RELEASE:
            # Сервер забрал партицию. Никаких действий не требуется.
            logger.debug(
                "Partition revoked. Topic %s, partition %s",
                result.message.release.topic,
                result.message.release.partition,
            )
        elif result.type == pqlib.ConsumerMessageType.MSG_COMMIT:
            # Получили ответ на коммит сообщения. Никаких действий не требуется.
            logger.debug("Got ask on commit request")
        elif result.type == pqlib.ConsumerMessageType.MSG_ERROR:
            # Из доки:
            # В случае, если результатом Future consumer.next_event()
            # будет объект с result.type == pqlib.ConsumerMessageType.MSG_ERROR,
            # следует пересоздать консьюмера вызовом create_consumer
            logger.debug("Got result.type == MSG_ERROR. Try to recreate consumer")
            self.create_consumer()

        return messages

    def create_producer(self, retrying=False):
        configurator = pqlib.ProducerConfigurator(self.topic, self.source_id)

        if retrying:
            self.producer = self.api.create_retrying_producer(
                configurator, credentials_provider=self.credentials_provider
            )
        else:
            self.producer = self.api.create_producer(
                configurator, credentials_provider=self.credentials_provider
            )

        logger.debug("Starting Producer")
        start_future = self.producer.start()
        start_result = start_future.result(timeout=self.timeout)

        if not isinstance(start_result, SessionFailureResult):
            if start_result.HasField("init"):
                logger.debug("Producer start result was: %s", start_result)
            else:
                raise RuntimeError(
                    "Unexpected producer start result from server: {}.".format(
                        start_result
                    )
                )
        else:
            raise RuntimeError(
                "Error occurred on start of producer: {}.".format(start_result)
            )

        logger.debug("Producer started")

    def stop_producer(self):
        if self.producer:
            self.producer.stop()

    def write_message(self, message):
        if not self.producer:
            raise RuntimeError("Cannot write message, no producer has provided")
        response = self.producer.write(self.get_millis(), message)
        write_result = response.result(timeout=self.timeout)
        if not write_result.HasField("ack"):
            raise RuntimeError(
                "Message write failed with error {}".format(write_result)
            )

    def write_message_batch(self, messages, max_inflight=10):
        """
        Отличается от метода write_message тем
        что не ждет каждый раз синхронно подтверждения отправки сообщения.
        В терминах логброкера называется inflight-режим.
        Рекомендуемый логброкером объем inflight-сообщений - 8mb
        """
        if not self.producer:
            raise RuntimeError("Cannot write message, no producer has provided")

        messages_written = set()
        messages_inflight = set()

        for m in messages:
            if len(messages_inflight) < max_inflight:
                # спим 1 милисекунду чтобы были гарантированно уникальные оффсеты
                time.sleep(0.001)
                messages_inflight.add(self.producer.write(self.get_millis(), m))
                continue

            # исчерпали лимит inflight-сообщений и ждем пока хотя бы одно из них обработается
            res = wait(
                messages_inflight, timeout=self.timeout, return_when=FIRST_COMPLETED
            )
            # res.done содержит все обработанные сообщения
            messages_written.update(res.done)  # pytype: disable=attribute-error

            messages_inflight = {
                m for m in messages_inflight if m not in messages_written
            }

            messages_inflight.add(self.producer.write(self.get_millis(), m))

        # дождемся пока обработаются все остатки inflight
        res = wait(messages_inflight, timeout=self.timeout, return_when=ALL_COMPLETED)
        messages_written.update(res.done)  # pytype: disable=attribute-error

        if len(messages) != len(messages_written):
            logger.debug("Total messages: %s", len(messages))
            logger.debug("Messages written: %s", len(messages_written))
            raise RuntimeError("There are some unwritten messages")

        logger.debug("Messages written: %s", len(messages_written))

    def _process(self, messages):
        try:
            self.read_message_processor.process(messages)
        except Exception as ex:
            logger.error("Error occured while processing messages. Keeps messages uncommited. Exception: {}".format(ex))
            raise

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.stop_consumer()
        self.stop_producer()
        self.stop_api()
