# -*- coding: utf-8 -*-

from concurrent.futures import TimeoutError as FutureTimeoutError
import logging
import time

from kikimr.public.sdk.python.persqueue.errors import (
    ActorTerminatedException,
    SessionClosedException,
    SessionFailureResult,
)
from kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api import (
    ConsumerConfigurator,
    ConsumerMessageType,
    PQStreamingAPI,
)
from passport.backend.core.logbroker.logbroker_base import (
    create_credentials,
    format_protobuf_safe,
)
from passport.backend.core.logging_utils.helpers import escape_text_for_log
from passport.backend.logbroker_client.core.consumers.exceptions import ExitException
from passport.backend.logbroker_client.core.logbroker.client import (
    BaseLogbrokerClientException,
    TimeoutException,
)
from passport.backend.logbroker_client.core.logbroker.decompress import (
    AutoDecompressor,
    BadArchive,
)
from passport.backend.logbroker_client.core.utils import LogPrefixManager


# Module-wide logger; every message from this module goes to the logger
# named 'logbroker'.
log = logging.getLogger(name='logbroker')


class ConnectError(BaseLogbrokerClientException):
    """Raised when the Logbroker consumer cannot start, stops while in use,
    or fails to confirm a partition lock."""


class ProtocolError(BaseLogbrokerClientException):
    """Raised when Logbroker returns a result of an unexpected type or
    content, reports a read error, or sends data that cannot be
    decompressed."""


class WatchedFutureProperty(object):
    """Descriptor that warns when a still-pending future gets replaced.

    Only ``__set__`` is defined: the value is stored in the owner instance's
    ``__dict__`` under ``name``, and ordinary attribute lookup reads it back
    from there. ``name`` must therefore equal the attribute name the
    descriptor is bound to.
    """

    def __init__(self, name):
        self.name = name

    def _warn(self, msg):
        log.warning(msg)

    def watch_set_value(self, old):
        """Log a warning if ``old`` is a future that has not finished yet."""
        if old is None or old.done():
            return
        self._warn('Future is {!r} while creating new one'.format(old))

    def __set__(self, instance, value):
        storage = instance.__dict__
        self.watch_set_value(storage.get(self.name))
        storage[self.name] = value


class FuturePoll(object):
    """Lazily creates a future and waits on it across repeated polls.

    The future is created by ``getter`` on the first ``poll`` and kept while
    it is still pending, so a poll that timed out is resumed by the next call
    without creating a new future. Once the future has completed, either with
    a result or with an exception, it is forgotten so the next poll calls
    ``getter`` again.
    """
    future = WatchedFutureProperty('future')

    def __init__(self, getter, debug_message):
        # getter: zero-argument callable returning a future. Presumably a
        # concurrent.futures-style object (``result(timeout)`` and ``done()``
        # are used) -- TODO confirm for every SDK future passed in.
        self.future = None
        self.getter = getter
        # Text used only in the debug line logged when a future is created.
        self.debug_message = debug_message

    def poll(self, timeout=None):
        """Wait up to ``timeout`` seconds for the current future's result.

        :param timeout: seconds to wait; ``None`` waits without limit.
        :returns: the future's result; the future is then discarded.
        :raises FutureTimeoutError: the future is still pending. It is kept,
            so the next call keeps waiting on the same future.
        :raises Exception: whatever the future completed with. The failed
            future is discarded so the next call creates a new one.
        """
        if self.future is None:
            log.debug('Creating future: {}'.format(self.debug_message))
            self.future = self.getter()
        try:
            result = self.future.result(timeout)
        except FutureTimeoutError:
            # Still pending: keep the future so the wait can be resumed.
            raise
        except Exception:
            # A future that completed with an error can never succeed.
            # Forget it so a retry calls getter() again instead of
            # re-raising the same stale exception forever.
            self.future = None
            raise
        self.future = None
        return result


class NativeLogbrokerConsumer(object):
    """Wrapper around the native Logbroker client (python Logbroker SDK).

    Owns one ``PQStreamingAPI`` connection and at most one consumer created
    on it. ``handle_message`` waits for the next data response, passes every
    message in it to a callback and then commits the response cookie.

    Can be used as a context manager: entering starts the API and the
    consumer, exiting stops them.
    """
    # Minimum interval in seconds between two "Consumer is alive" debug lines.
    BEACON_THRESHOLD_SEC = 30

    def __init__(
        self,
        host,
        port,
        ca_cert,
        client_id,
        credentials_config,
        topic,
        use_client_locks,
        partition_group,
        decompress,
        connect_timeout,
        max_count,
        is_interrupted_callback,
    ):
        """
        :param host: Logbroker host.
        :param port: Logbroker port.
        :param ca_cert: passed as ``root_certificates`` to ``PQStreamingAPI``.
        :param client_id: consumer client id.
        :param credentials_config: passed to ``create_credentials``.
        :param topic: passed as ``topics`` to ``ConsumerConfigurator`` and
            reported to the message callback.
        :param use_client_locks: passed to the configurator; when true,
            ``MSG_LOCK``/``MSG_RELEASE`` events are handled by ``await_data``.
        :param partition_group: explicit partition group to read; requires
            ``use_client_locks``.
        :param decompress: if truthy, message data is decompressed with
            ``AutoDecompressor``.
        :param connect_timeout: seconds to wait when starting/stopping the
            API and the consumer.
        :param max_count: passed as ``max_count`` to the configurator.
        :param is_interrupted_callback: zero-argument callable; a truthy
            result aborts reading with ``ExitException``.
        :raises ValueError: ``partition_group`` given without client locks.
        """
        self._host = host
        self._port = port
        self._ca_cert = ca_cert
        self._topic = topic
        self._connect_timeout = connect_timeout
        self._is_interrupted_callback = is_interrupted_callback
        self._credentials = create_credentials(credentials_config)
        self._use_client_locks = use_client_locks
        self._partition_group = partition_group
        if self._partition_group and not self._use_client_locks:
            raise ValueError('Explicit partition is supported only with client locks')
        configurator_args = dict(
            topics=topic,
            client_id=client_id,
            use_client_locks=use_client_locks,
            max_count=max_count,
        )
        if self._partition_group:
            configurator_args.update(partition_groups=[self._partition_group])
        log.debug('Created logbroker client. host={} port={} session args={}'.format(
            self._host, self._port, configurator_args,
        ))
        self._configurator = ConsumerConfigurator(**configurator_args)
        self._consumer = None
        self._decompressor = AutoDecompressor() if decompress else None
        self._api = PQStreamingAPI(
            host=self._host,
            port=self._port,
            root_certificates=self._ca_cert,
        )
        self._received_count = 0
        self._last_received_ts = 0
        self._last_beacon_ts = time.time()
        self._last_heartbeat_ts = time.time()
        self._session_id = None
        self._remote_partition_state = None
        self.api_started = False
        self.consumer_started = False

        self._api_start_poll = FuturePoll(
            lambda: self._api.start(),
            'api start',
        )
        # Polls whose futures belong to the current consumer; they are
        # recreated every time that consumer is dropped.
        self._create_consumer_polls()

    def _create_consumer_polls(self):
        """(Re)create the polls whose futures are bound to ``self._consumer``.

        A future left pending by a consumer that has since been dropped must
        not be waited on again on behalf of a new consumer, so the polls are
        replaced rather than reused.
        """
        self._consumer_start_poll = FuturePoll(
            lambda: self._consumer.start(),
            'consumer start',
        )
        self._next_event_poll = FuturePoll(
            lambda: self._consumer.next_event(),
            'next event',
        )

    def _drop_consumer(self):
        """Forget the current consumer together with every future bound to it."""
        self._consumer = None
        self.consumer_started = False
        self._create_consumer_polls()

    def _start_api(self):
        """Wait for ``PQStreamingAPI.start()`` to complete.

        :raises TimeoutException: the API did not start within
            ``connect_timeout``. The pending start future is kept, so a later
            call continues waiting on it.
        """
        dbg_addr = '{}:{}'.format(self._host, self._port)
        try:
            result = self._api_start_poll.poll(self._connect_timeout)
        except FutureTimeoutError:
            raise TimeoutException('Timeout starting API {}'.format(dbg_addr))
        log.debug('PQStreamingAPI {} started with result: {}'.format(dbg_addr, result))

    def _create_consumer(self):
        """Create the consumer on the API unless one already exists."""
        if not self._consumer:
            self._consumer = self._api.create_consumer(
                self._configurator,
                self._credentials,
            )

    @staticmethod
    def _try_format_session_failure(result):
        """Format a session result as ``reason=... description=...``.

        Falls back to the escaped ``str(result)`` if ``result`` lacks the
        ``reason``/``description`` attributes.
        """
        try:
            if result.description is not None:
                description = format_protobuf_safe(result.description)
            else:
                description = 'None'
            return u'reason={} description={}'.format(result.reason, description)
        except AttributeError:
            return escape_text_for_log(str(result))

    @staticmethod
    def _try_format_read_result(result):
        """Format a protobuf read result for logging."""
        return format_protobuf_safe(result)

    def _parse_consumer_start_result(self, result):
        """Validate the consumer start result and remember the session id.

        :raises ConnectError: the start result is a ``SessionFailureResult``.
        :raises ProtocolError: the result is not a protobuf message or has no
            ``init`` field.
        """
        dbg_addr = '{}@{}:{}'.format(self._topic, self._host, self._port)
        if isinstance(result, SessionFailureResult):
            raise ConnectError(
                u'LB consumer {} failed to start with error {}'.format(
                    dbg_addr,
                    self._try_format_session_failure(result),
                ),
            )
        if not hasattr(result, 'HasField'):
            # Previously the format string had a single placeholder, so the
            # type was silently dropped from the message.
            message = 'Wrong consumer start result type: {} {}'.format(result, type(result))
            log.error(message)
            raise ProtocolError(message)
        if not result.HasField('init'):
            message = u'Wrong consumer start result: {}'.format(
                self._try_format_read_result(result),
            )
            log.error(message)
            raise ProtocolError(message)
        self._session_id = result.init.session_id
        log.debug(
            u'LB consumer {} started. Result was: {}'.format(
                dbg_addr,
                self._try_format_read_result(result),
            ),
        )

    def _start_consumer(self):
        """Create the consumer if needed and wait for it to start.

        :raises TimeoutException: the consumer did not start within
            ``connect_timeout``. The consumer and its pending start future
            are kept for a later retry.
        :raises BaseLogbrokerClientException: the start result was rejected;
            the consumer is dropped so the next attempt creates a new one.
        """
        self._create_consumer()

        try:
            result = self._consumer_start_poll.poll(self._connect_timeout)
        except FutureTimeoutError:
            raise TimeoutException('Timeout starting consumer')

        try:
            self._parse_consumer_start_result(result)
        except BaseLogbrokerClientException:
            self._drop_consumer()
            raise

    def start_api(self):
        """Start the API connection and mark it as started."""
        self._start_api()
        self.api_started = True

    def start_consumer(self):
        """Start the consumer and mark it as started."""
        self._start_consumer()
        self.consumer_started = True

    def start(self):
        """Start whatever of the API and the consumer is not running yet."""
        if not self.api_started:
            self.start_api()
        if not self.consumer_started:
            self.start_consumer()

    def _stop_consumer(self):
        """Stop the running consumer; it is dropped even if stopping fails."""
        try:
            self._consumer.reads_done()
            stop_future = self._consumer.stop()
            try:
                result = stop_future.result(timeout=self._connect_timeout)
                log.info(u'LB consumer stop resulted in {}'.format(
                    self._try_format_session_failure(result),
                ))
            except FutureTimeoutError:
                log.warning('LB consumer stop resulted in TimeoutError')
        finally:
            self._drop_consumer()

    def stop(self):
        """Stop the consumer (if started), then the API (if started).

        The API is stopped even if stopping the consumer raises.
        """
        # NOTE(review): a consumer that was created but never finished
        # starting (start timed out) is not touched here -- confirm whether
        # it should also be stopped/dropped when the API goes down.
        try:
            if self.consumer_started:
                self._stop_consumer()
        finally:
            if self.api_started:
                self._api.stop()
                self.api_started = False

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()

    def _check_consumer_is_running(self):
        """Raise ``ConnectError`` if there is no consumer or it has stopped.

        A stopped consumer is dropped (with its pending futures) before the
        error is raised.
        """
        if self._consumer is None:
            raise ConnectError(u'Consumer is not running')
        stop_future = self._consumer.stop_future
        if stop_future.done():
            # Drop first so the state is consistent even if result() raises.
            self._drop_consumer()
            result = stop_future.result()
            raise ConnectError(u'Consumer stopped: {}'.format(
                self._try_format_session_failure(result),
            ))

    @property
    def details(self):
        """Diagnostic snapshot of the consumer state, as a dict."""
        info = {
            'host': self._host,
            'port': self._port,
            'topic': self._topic,
            'n_received': self._received_count,
            'last_heartbeat_ago_s': time.time() - self._last_heartbeat_ts,
        }
        if self._last_received_ts:
            info['last_msg_ago_s'] = time.time() - self._last_received_ts
        if self._session_id:
            info['session_id'] = self._session_id
        if self._partition_group is not None:
            info['partition_group'] = self._partition_group
        if self._remote_partition_state is not None:
            info['remote_state'] = self._remote_partition_state

        return info

    def _heartbeat(self):
        """Record that the consumer loop is still making progress."""
        self._last_heartbeat_ts = time.time()

    def _beacon(self):
        """Log ``details`` at most once per ``BEACON_THRESHOLD_SEC`` seconds."""
        cur_time = time.time()
        if cur_time - self._last_beacon_ts >= self.BEACON_THRESHOLD_SEC:
            log.debug('Consumer is alive: {}'.format(
                ' '.join('{}={}'.format(k, v) for k, v in self.details.items()))
            )
            self._last_beacon_ts = cur_time

    def _assert_not_interrupted(self):
        """:raises ExitException: the interruption callback returned true."""
        if self._is_interrupted_callback():
            raise ExitException()

    def await_event(self):
        """Block until the consumer yields its next event and return it.

        :raises ConnectError: the consumer is missing or has stopped.
        :raises ExitException: reading was interrupted.
        """
        # The loop waits with short timeouts so that exit/disconnect
        # conditions are re-checked and heartbeats are logged regularly.
        while True:
            self._check_consumer_is_running()
            self._beacon()
            self._heartbeat()
            self._assert_not_interrupted()
            try:
                return self._next_event_poll.poll(timeout=1)
            except FutureTimeoutError:
                continue
            finally:
                # This extra check makes sure the right exception class is
                # raised if the session broke while waiting for the next-event
                # future: the LB SDK documentation does not guarantee any
                # particular client behaviour in that situation.
                self._check_consumer_is_running()

    def await_data(self):
        """Wait for the next ``MSG_DATA`` event and return it.

        ``MSG_COMMIT`` events are skipped. With client locks, ``MSG_LOCK`` is
        confirmed via ``ready_to_read()`` and ``MSG_LOCK``/``MSG_RELEASE`` are
        recorded in ``details``.

        :raises ProtocolError: a read error or an unexpected event arrived.
        :raises ConnectError: a partition lock could not be confirmed, or the
            consumer stopped.
        """
        while True:
            self._assert_not_interrupted()
            result = self.await_event()
            if not hasattr(result, 'type'):
                raise ProtocolError(u'Wrong result type: {} {}'.format(result, type(result)))
            if result.type == ConsumerMessageType.MSG_DATA:
                return result
            elif result.type == ConsumerMessageType.MSG_COMMIT:
                continue
            elif result.type == ConsumerMessageType.MSG_ERROR:
                raise ProtocolError(u'Message read error {}'.format(
                    self._try_format_read_result(result.message),
                ))
            elif self._use_client_locks and result.type == ConsumerMessageType.MSG_LOCK:
                log.debug('Received partition lock {}'.format(result.message.lock))
                self._remote_partition_state = {
                    'state': 'ASSIGNED',
                    'topic': result.message.lock.topic,
                    'partition': result.message.lock.partition,
                }
                try:
                    result.ready_to_read()
                except (ActorTerminatedException, SessionClosedException) as err:
                    raise ConnectError(
                        'Failed to confirm partition lock: {} {}'.format(
                            err.__class__, err,
                        ),
                    )
            elif self._use_client_locks and result.type == ConsumerMessageType.MSG_RELEASE:
                log.debug('Received partition revoke {}'.format(result.message.release))
                self._remote_partition_state = {
                    'state': 'RELEASED',
                    'old_topic': result.message.release.topic,
                    'old_partition': result.message.release.partition,
                }
            else:
                raise ProtocolError(
                    u'Wrong message type {}. MSG_DATA/MSG_COMMIT/MSG_ERROR expected ({})'.format(
                        result.type,
                        self._try_format_read_result(result.message),
                    ),
                )

    def decompress(self, data):
        """Return ``data`` decompressed, or unchanged if decompression is off.

        :raises ProtocolError: the data is not a valid archive.
        """
        if self._decompressor is not None:
            try:
                return self._decompressor.decompress(data)
            except BadArchive as err:
                raise ProtocolError('Failed to decompress data: {}'.format(err))
        else:
            return data

    def handle_message(self, message_callback):
        """Read one data response, hand every message to a callback, commit.

        ``message_callback`` is called with keyword arguments ``topic``,
        ``partition``, ``message`` and ``decompressed_data``. The response
        cookie is committed only after the callback has run for every
        message; if the callback raises, nothing is committed.
        """
        self._assert_not_interrupted()
        data_response = self.await_data().message.data
        request_ids = []
        log.debug(
            'Received data response, cookie={}, n_batches={}'.format(
                data_response.cookie,
                len(data_response.message_batch),
            ),
        )
        for batch in data_response.message_batch:
            log.debug(
                'Received message batch, cookie={}, topic={}, partition={}, '
                'n_messsages={}'.format(
                    data_response.cookie,
                    batch.topic,
                    batch.partition,
                    len(batch.message),
                ),
            )
            for message in batch.message:
                request_id = LogPrefixManager.new_id()
                request_ids.append(request_id)
                with LogPrefixManager.system_context(
                    data_response.cookie,
                    request_id,
                ):
                    log.debug(
                        'Received message, topic={} partition={} offset={} '
                        'cookie={} seq_no={} create_time={} write_time={} '
                        'len={}'.format(
                            batch.topic,
                            batch.partition,
                            message.offset,
                            data_response.cookie,
                            message.meta.seq_no,
                            message.meta.create_time_ms,
                            message.meta.write_time_ms,
                            len(message.data),
                        ),
                    )
                    self._received_count += 1
                    self._last_received_ts = time.time()
                    decompressed_data = self.decompress(message.data)
                    self._heartbeat()
                    message_callback(
                        topic=self._topic,
                        partition=batch.partition,
                        message=message,
                        decompressed_data=decompressed_data,
                    )
        with LogPrefixManager.system_context('|'.join(request_ids)):
            log.debug('Commit cookie {}'.format(data_response.cookie))
            self._consumer.commit(data_response.cookie)
