import json
import logging
import calendar
import datetime as dt

from concurrent import futures

try:
    import kikimr.public.sdk.python.persqueue.grpc_pq_streaming_api as pqlib
    import kikimr.public.sdk.python.persqueue.auth as auth
    import kikimr.public.sdk.python.persqueue.errors as errors
except ImportError:
    # sandbox non-arcadia tests import all modules in services/modules/__init__.py, just ignore for these tests
    pqlib = None
    auth = None
    errors = None

from sandbox.services import base
from sandbox.yasandbox.database import mapping
import sandbox.common.config as common_config
import sandbox.common.tvm as common_tvm

logger = logging.getLogger(__name__)


class LogbrokerSender(object):
    """
    Writes task status events as JSON messages to the configured Logbroker topic via PqLib.

    Lifecycle: ``initialize()`` -> any number of ``send_async()`` -> ``stop()``.
    """

    def __init__(self, config):
        self.__fqdn = config.this.fqdn
        cfg = config.server.services.logbroker_publisher
        self.__topic = cfg.tasks_topic_name
        self.__endpoint = cfg.endpoint
        self.__destination_tvm_client_id = cfg.tvm_id
        self.__api = None
        self.__configurator = None
        self.__credentials_provider = None
        # set only after the producer has started successfully; doubles as the "initialized" flag
        self.__producer = None
        # sequence number of the last written message (seeded from the server on start)
        self.__max_seq_no = 0

    def initialize(self):
        """
        Start the PqLib API and a producer for the configured topic.

        Must be called before ``send_async``. Re-enterable: returns immediately when a
        producer has already been started (the server may legitimately report
        ``max_seq_no == 0`` for a new source id, so that value cannot be the flag).

        :raises LogbrokerPublisherException: when the producer fails to start
        """
        if self.__producer is not None:
            return
        if self.__api is not None:
            # a previous attempt failed half-way: release its API before creating a new one
            self.stop()

        host, port = self.__endpoint.split(":")
        self.__api = pqlib.PQStreamingAPI(host, port)
        logger.debug("Starting PqLib")
        result = self.__api.start().result(timeout=10)
        logger.debug("Api started with result: %s", result)

        self.__credentials_provider = auth.TVMCredentialsProvider(
            common_tvm.TVMClient,
            self.__destination_tvm_client_id
        )
        # this host's FQDN is used as the producer's source id
        self.__configurator = pqlib.ProducerConfigurator(self.__topic, self.__fqdn)

        producer = self.__api.create_producer(
            self.__configurator,
            credentials_provider=self.__credentials_provider
        )
        logger.debug("Starting Producer")
        start_result = producer.start().result(timeout=10)
        if isinstance(start_result, errors.SessionFailureResult):
            msg = "Error occurred on start of producer: {}.".format(start_result)
            logger.error(msg)
            raise LogbrokerPublisherException(msg)
        if not start_result.HasField("init"):
            msg = "Unexpected producer start result from server: {}.".format(start_result)
            logger.error(msg)
            raise LogbrokerPublisherException(msg)

        logger.debug("Producer start result was: %s", start_result)
        self.__max_seq_no = start_result.init.max_seq_no
        # publish the producer only once it is usable, so a failed start is retried next time
        self.__producer = producer
        logger.debug("Producer started")

    def stop(self):
        """
        Stop the PqLib API (if any) and reset the sender to its uninitialized state.

        Safe to call repeatedly; errors from PqLib are logged, not raised.
        """
        api, self.__api = self.__api, None
        self.__producer = None
        self.__max_seq_no = 0
        if api is None:
            return
        try:
            api.stop()
        except Exception:
            logger.exception("Error on PqLib stop")

    def send_async(self, event):
        """
        Write ``event`` as a JSON message with the next sequence number.

        :return: the future returned by the producer's ``write``
        :raises LogbrokerPublisherException: when called before a successful ``initialize()``
        """
        if self.__producer is None:
            raise LogbrokerPublisherException("LogbrokerSender is not initialized")
        self.__max_seq_no += 1
        data = self.__to_data(event)
        return self.__producer.write(self.__max_seq_no, json.dumps(data))

    @staticmethod
    def __to_data(event):
        """Serialize an event into a JSON-compatible dict; ``created`` becomes Unix seconds (naive values read as UTC)."""
        return {
            "id": str(event.id),
            "created": calendar.timegm(event.created.utctimetuple()),
            "task_id": event.task_id,
            "status": event.status,
            "tags": event.tags,
            "hints": event.hints,
        }


class LogbrokerPublisherException(Exception):
    """Wraps internal failures of the Logbroker publisher, e.g. a producer that did not start."""


class LogbrokerPublisher(base.SingletonService):
    """
    Sends notifications of entity (e.g. tasks) changes to other systems.

    Every tick reads a batch of ``TaskStatusEvent`` records, writes them to Logbroker
    through a ``LogbrokerSender`` and deletes the records that were delivered or were
    deliberately skipped by the rollout filter.
    """

    def __init__(self, sender=None, *args, **kwargs):
        super(LogbrokerPublisher, self).__init__(*args, **kwargs)
        self.__config = common_config.Registry()
        cfg = self.__config.server.services.logbroker_publisher
        self.__endpoint = cfg.endpoint
        self.__batch_size = cfg.batch_size
        self.__event_expires_hours = cfg.event_expires_hours
        self.__run_interval = cfg.run_interval
        self.__rollout_percent = cfg.rollout_percent
        self.__sender = sender or LogbrokerSender(self.__config)
        # TaskStatusEvent id -> task id for events of the last batch that are not yet listed in
        # their task's status_events; reported as one aggregated log line to avoid log spam
        self.__inconsistent_events = dict()

    @property
    def tick_interval(self):
        return self.__run_interval

    def tick(self):
        """
        Publish one batch of TaskStatusEvent records.

        Deletes records older than ``event_expires_hours``, sends up to ``batch_size`` of the
        oldest remaining ones and deletes those that were acknowledged or skipped by the
        rollout filter. The sender is always stopped at the end of the tick.
        """
        try:
            self.__sender.initialize()

            # cleanup obsolete records
            threshold = dt.datetime.utcnow() - dt.timedelta(hours=self.__event_expires_hours)
            mapping.TaskStatusEvent.objects(created__lt=threshold).delete()

            # materialize once: the batch is used twice below and must be the same set of rows
            events = list(mapping.TaskStatusEvent.objects().order_by("created").limit(self.__batch_size))
            if not events:
                return
            task_ids = [event.task_id for event in events]
            task_se_ids = dict(mapping.Task.objects(id__in=task_ids).fast_scalar("id", "status_events"))

            messages_inflight = self.__send_events(events, task_se_ids)
            self.__log_inconsistent_events()

            # note that suggested size of inflight messages is 8 MiB
            to_be_awaited = [f for f in messages_inflight.values() if f is not None]
            futures.wait(to_be_awaited, timeout=10, return_when=futures.ALL_COMPLETED)

            sent_tse_ids, skipped_tse_ids = self.__collect_results(messages_inflight)
            logger.info("Successfully sent %d events; skipped %d events", len(sent_tse_ids), len(skipped_tse_ids))

            # remove rows with confirmed delivery to logbroker (or deliberately skipped)
            to_delete = sent_tse_ids | skipped_tse_ids
            if to_delete:
                mapping.TaskStatusEvent.objects(id__in=to_delete).delete()
        finally:
            self.__sender.stop()

    def __send_events(self, events, task_se_ids):
        """
        Start asynchronous writes for the consistent events of the batch.

        :param events: TaskStatusEvent documents of the current batch
        :param task_se_ids: task id -> that task's ``status_events`` list
        :return: dict event.id -> write future, or None for events skipped by the rollout filter
        """
        inconsistent = dict()
        messages_inflight = dict()
        for event in events:
            event_id = str(event.id)
            tse_ids = task_se_ids.get(event.task_id)
            if not tse_ids:
                logger.warning("Task #%s absents for TaskStatusEvent %s", event.task_id, event_id)
                continue
            if event_id not in tse_ids:
                # this may mean race condition, when TaskStatusEvent record was already created,
                # and task status was not yet updated
                inconsistent[event_id] = event.task_id
                continue

            if event.task_id % 100 + 1 > self.__rollout_percent:
                # outside of the rollout: nothing is sent, but the record is still deleted
                messages_inflight[event.id] = None
                continue

            try:
                messages_inflight[event.id] = self.__sender.send_async(event)
            except errors.ActorTerminatedException:
                logger.exception("Sending events to logbroker failed. Stopping API...")
                self.__sender.stop()
                break
            except Exception:
                logger.exception("Sending events to logbroker failed")

        # snapshot of the current batch only, so stale ids (sent or expired since) do not accumulate
        self.__inconsistent_events = inconsistent
        return messages_inflight

    def __log_inconsistent_events(self):
        """Log (at most 500 chars of) the events whose ids are not yet committed to their task."""
        if not self.__inconsistent_events:
            return
        message = "; ".join("{}: task {}".format(tse, tid) for tse, tid in self.__inconsistent_events.items())
        logger.info(
            "TaskStatusEvent is not committed in Task status_events: total %d\n%s",
            len(self.__inconsistent_events),
            message[:500]
        )

    @staticmethod
    def __collect_results(messages_inflight):
        """
        Split the batch into acknowledged and skipped event ids.

        Writes that are still pending, were cancelled, raised, or were not acknowledged are
        logged and left out, so their records are retried on a later tick.

        :return: (sent ids, skipped ids) as sets of event ids
        """
        sent_tse_ids = set()
        skipped_tse_ids = set()
        for event_id, future in messages_inflight.items():
            if future is None:
                # scenario of skipped sending, just mark event to delete
                skipped_tse_ids.add(event_id)
                continue
            if not future.done() or future.cancelled():
                # never call result() here: a pending future would raise TimeoutError and
                # abort the deletion of everything that was already acknowledged
                logger.error("Message write for TaskStatusEvent %s was not completed in time", event_id)
                continue
            exc = future.exception()
            if exc is not None:
                logger.error("Exception occurred during message write %s", exc)
                continue
            result = future.result()
            if isinstance(result, errors.SessionFailureResult) or not result.HasField("ack"):
                logger.error("Exception occurred during message write %s", result)
                continue
            sent_tse_ids.add(event_id)
        return sent_tse_ids, skipped_tse_ids
