# -*- coding: utf-8 -*-
from concurrent.futures import TimeoutError as FutureTimeoutError
import logging
import time

from kikimr.public.sdk.python.persqueue.errors import SessionFailureResult
import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
from passport.backend.core.logbroker.exceptions import (
    BaseLogbrokerError,
    ConnectionLost,
    ProtocolError,
    TimeoutError,
    TransportError,
)
from passport.backend.core.logbroker.logbroker_base import (
    create_credentials,
    format_protobuf_safe,
)
from passport.backend.core.logging_utils.helpers import (
    escape_text_for_log,
    trim_message,
)


# Module-level logger. The name is a fixed literal (not __name__), so records
# always land under 'passport.logbroker.logbroker_writer_client'.
log = logging.getLogger(name='passport.logbroker.logbroker_writer_client')


class LogbrokerProducer(object):
    """
    Synchronous wrapper around a pqlib producer session.

    Typical usage: call ``start()`` (or ``start_api()`` + ``start_producer()``),
    then ``write()`` messages. When ``write()`` raises ``ConnectionLost`` the
    producer session has been dropped, and ``start()`` must be called again
    before further writes.
    """

    # Substrings (lower-case) of a producer start error description that make
    # the failure a ProtocolError instead of a TransportError.
    _CRITICAL_START_ERROR_MARKERS = ('no topics found', 'access denied')

    def __init__(
        self, host, port, topic, source_id, connect_timeout, write_timeout,
        credentials_config, partition_group=None, extra_fields=None,
        graphite_logger=None, ca_cert=None,
    ):
        """
        Lightweight constructor: creates objects in memory only and does not
        open any network connections.

        :param host: Logbroker host.
        :param port: Logbroker port.
        :param topic: topic to write to.
        :param source_id: source id passed to the producer configurator.
        :param connect_timeout: timeout passed to ``Future.result`` while
            starting the API and the producer.
        :param write_timeout: timeout passed to ``Future.result`` while
            waiting for a write result.
        :param credentials_config: passed to ``create_credentials``.
        :param partition_group: optional; forwarded to the configurator only
            when truthy.
        :param extra_fields: forwarded to the configurator as is.
        :param graphite_logger: optional object whose
            ``log(duration=..., response=..., status=...)`` is called after
            every ``write()``.
        :param ca_cert: passed as ``root_certificates`` to the API.
        """
        self.host = host
        self.port = port
        self.topic = topic
        self.connect_timeout = connect_timeout
        self.write_timeout = write_timeout
        self._graphite_logger = graphite_logger
        self._credentials = create_credentials(credentials_config)
        self._partition_group = partition_group
        self._ca_cert = ca_cert
        configurator_args = dict(
            topic=topic,
            source_id=source_id,
            extra_fields=extra_fields,
        )
        if self._partition_group:
            configurator_args.update(partition_group=self._partition_group)
        self._configurator = pqlib.ProducerConfigurator(**configurator_args)
        self._api = pqlib.PQStreamingAPI(
            self.host, self.port, root_certificates=self._ca_cert,
        )
        # Pending start futures are kept across calls so that a retry after a
        # timeout keeps waiting on the same start attempt.
        self._api_start_future = None
        self._producer_start_future = None
        self._producer = None
        # Last used sequence number; re-initialised from the server on
        # producer start and incremented before every write.
        self._max_seq_no = 0
        self.api_started = False
        self.producer_started = False
        # Fields of the producer init response, logged with every write.
        self._remote_status = {}

    def _start_api(self):
        """
        Start the API (once) and wait up to ``connect_timeout`` for it.

        :raises: TimeoutError - if the start future does not complete in time;
            the future is kept, so the next call keeps waiting on it
        """
        if self._api_start_future is None:
            self._api_start_future = self._api.start()
        try:
            result = self._api_start_future.result(timeout=self.connect_timeout)
        except FutureTimeoutError:
            message = 'Timeout starting API at {}:{}{} (timeout {})'.format(
                self.host,
                self.port,
                ' (TLS)' if self._ca_cert else '',
                self.connect_timeout,
            )
            log.error(message)
            raise TimeoutError(message)
        log.info('Api started with result: {}'.format(result))

    def start_api(self):
        """
        Connect to the API.
        The API is a wrapper around gRPC and manages reconnects by itself.

        :raises: TimeoutError(TransportError) - on connection timeout
        :raises: TransportError - on a network error
        :raises: ProtocolError - on a logbroker protocol-level error
        """
        self._start_api()
        self.api_started = True

    def _create_producer(self):
        """Create the producer object unless one already exists."""
        if self._producer is None:
            log.debug('Creating producer with config {}'.format(self._configurator))
            self._producer = self._api.create_producer(
                self._configurator,
                self._credentials,
            )

    @staticmethod
    def _try_format_session_failure(result):
        """
        Format a session failure as ``reason=... description=...``.

        Falls back to an escaped ``str(result)`` when ``result`` lacks the
        ``reason``/``description`` attributes.
        """
        try:
            if result.description is not None:
                description = format_protobuf_safe(result.description)
            else:
                description = 'None'
            return u'reason={} description={}'.format(result.reason, description)
        except AttributeError:
            return escape_text_for_log(str(result))

    @staticmethod
    def _try_format_write_result(result):
        """Format a producer response for logging via ``format_protobuf_safe``."""
        return format_protobuf_safe(result)

    @staticmethod
    def _parse_init_result(init_result):
        """
        Collect the known init-response fields that are present on
        ``init_result`` into a dict. Missing fields are skipped.
        """
        return {
            f: getattr(init_result, f)
            for f in ['max_seq_no', 'session_id', 'partition', 'topic']
            if hasattr(init_result, f)
        }

    @staticmethod
    def _is_critical_start_error(response):
        """
        Return True when ``response.description.error.description`` contains
        one of ``_CRITICAL_START_ERROR_MARKERS`` (case-insensitive).

        ``_parse_producer_start_result`` maps such failures to ProtocolError
        instead of TransportError. Returns False if the attribute chain is
        missing.
        """
        try:
            description = response.description.error.description
        except AttributeError:
            return False
        description = str(description).lower()
        return any(
            marker in description
            for marker in LogbrokerProducer._CRITICAL_START_ERROR_MARKERS
        )

    def _parse_producer_start_result(self, result):
        """
        Validate a producer start result. On success, store ``max_seq_no``
        and the remote status from its ``init`` field.

        :raises: ProtocolError - on a critical session failure, an unexpected
            result type, or a result without ``init``
        :raises: TransportError - on any other session failure
        """
        if isinstance(result, SessionFailureResult):
            message = u'Error starting producer: {}'.format(
                self._try_format_session_failure(result),
            )
            log.error(message)
            if self._is_critical_start_error(result):
                raise ProtocolError(message)
            else:
                raise TransportError(message)
        if not hasattr(result, 'HasField'):
            message = u'Wrong producer start result type: {} {}'.format(result, type(result))
            log.error(message)
            raise ProtocolError(message)
        if not result.HasField('init'):
            message = 'Wrong producer start result: {}'.format(
                self._try_format_write_result(result),
            )
            log.error(message)
            raise ProtocolError(message)
        self._max_seq_no = result.init.max_seq_no
        self._remote_status = self._parse_init_result(result.init)
        log.info('Producer for {}@{}:{} started with result: {}'.format(
            self.topic, self.host, self.port, self._try_format_write_result(result),
        ))

    def _start_producer(self):
        """
        Create (if needed) and start the producer, waiting up to
        ``connect_timeout`` for the start result.

        On timeout, the pending start future is kept so the next call keeps
        waiting on the same attempt. On any other failure, the producer and
        its start future are discarded so the next call starts afresh.
        """
        self._create_producer()
        if self._producer_start_future is None:
            self._producer_start_future = self._producer.start()

        try:
            result = self._producer_start_future.result(
                timeout=self.connect_timeout,
            )
        except FutureTimeoutError:
            message = 'Timeout starting producer for {}@{}:{} (timeout {})'.format(
                self.topic, self.host, self.port, self.connect_timeout,
            )
            log.error(message)
            raise TimeoutError(message)
        except Exception:
            # The future completed with an error. Keeping it would make every
            # retry re-raise the same exception without ever reconnecting.
            self._producer_start_future = None
            self._producer = None
            raise
        self._producer_start_future = None

        try:
            self._parse_producer_start_result(result)
        except BaseLogbrokerError:
            self._producer = None
            raise

    def start_producer(self):
        """
        Start a new producer session.
        The session has to be restarted after ConnectionLost.

        :raises: TimeoutError(TransportError) - on connection timeout
        :raises: TransportError - on a network error
        :raises: ProtocolError - on a logbroker protocol-level error
        """
        self._start_producer()
        self.producer_started = True

    def start(self):
        """
        Start both the API and the producer session, skipping whichever
        is already started.

        :raises: TimeoutError(TransportError) - on connection timeout
        :raises: TransportError - on a network error
        :raises: ProtocolError - on a logbroker protocol-level error
        """
        if not self.api_started:
            log.debug('Logbroker API is down, starting')
            self.start_api()
        if not self.producer_started:
            log.debug('Logbroker producer is down, starting')
            self.start_producer()

    def _graphite_log(self, duration, response, status):
        """Report a write outcome to the graphite logger, if configured."""
        if self._graphite_logger is not None:
            self._graphite_logger.log(
                duration=duration,
                response=response,
                status=status,
            )

    def _check_is_connected(self):
        """
        Raise ConnectionLost unless a live producer session exists.

        This covers a producer that was never (successfully) started and one
        whose ``stop_future`` has completed. In the latter case, the producer
        is dropped so the next ``start()`` creates a fresh one.
        """
        if self._producer is None:
            # Never started, or dropped by an earlier failure. Without this
            # guard, the attribute access below fails with AttributeError.
            self.producer_started = False
            raise ConnectionLost(u'Producer is not started')
        if self._producer.stop_future.done():
            self.producer_started = False
            result = self._producer.stop_future.result()
            self._producer = None
            raise ConnectionLost(u'Producer stopped: {}'.format(
                self._try_format_session_failure(result),
            ))

    def write(self, message):
        """
        Write one message and wait for the write result.

        The sequence number is incremented before every attempt. It is
        re-initialised from the server's ``max_seq_no`` when the producer
        starts. Duration, status and response are always reported to the
        graphite logger, if one is configured.

        :param message: payload passed to the producer's ``write``.
        :raises: ConnectionLost - the producer is stopped or not started;
            call ``start()`` before writing again
        :raises: TimeoutError(TransportError) - on send timeout
        :raises: TransportError - when the write result carries no ``ack``
        :raises: ProtocolError - on an unexpected write result type
        """
        self._max_seq_no += 1
        start_time = time.time()
        status = 'ok'
        response = 'success'
        try:
            self._check_is_connected()
            write_future = self._producer.write(self._max_seq_no, message)
            log.info(
                'Sending logbroker message len={} seq={} remote_status={}'.format(
                    len(message),
                    self._max_seq_no,
                    self._remote_status,
                ),
            )
            write_result = write_future.result(timeout=self.write_timeout)
            if not hasattr(write_result, 'HasField'):
                raise ProtocolError(
                    u'Wrong write result type: {} {}'.format(write_result, type(write_result)),
                )
            if not write_result.HasField('ack'):
                status = 'TransportError'
                response = 'failed'
                error_message = 'Send error, seq={}: {}'.format(
                    self._max_seq_no,
                    self._try_format_write_result(write_result),
                )
                log.error(error_message)
                raise TransportError(error_message)
            log.info('Sent logbroker message len={} seq={} remote_status={} with result: {}'.format(
                len(message),
                self._max_seq_no,
                self._remote_status,
                trim_message(str(self._try_format_write_result(write_result))),
            ))
        except ConnectionLost as err:
            status = 'ConnectionLost'
            response = 'failed'
            error_message = 'Send error, seq={}: {}'.format(
                self._max_seq_no,
                err,
            )
            log.error(error_message)
            raise
        except FutureTimeoutError:
            status = 'TimeoutError'
            response = 'timeout'
            error_message = 'Timeout writing message seq={}'.format(self._max_seq_no)
            log.error(error_message)
            raise TimeoutError(error_message)
        except Exception as e:
            status = e.__class__.__name__
            response = 'failed'
            raise
        finally:
            self._graphite_log(
                duration=time.time() - start_time,
                status=status,
                response=response,
            )
