from __future__ import absolute_import

import abc
import logging
import datetime
import threading
import collections

from .types import misc as ctm
from .types import statistics as ctss

from . import api
from . import rest
from . import utils
from . import config


class SignalHandler(object):
    """
    Abstract base for every consumer of signals dispatched by `Signaler`.

    A concrete handler declares which signal types it subscribes to through the `types`
    attribute and implements `handle()`, which is then invoked with batches of those signals.
    """

    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def handle(self, signal_type, signals):
        """
        Process a single batch of signals sharing the same type.

        :param signal_type: type shared by every signal in the batch
        :param signals: list of dict-like signals belonging to that type
        """

    # collection of signal type names this handler subscribes to; concrete subclasses must override it
    types = abc.abstractproperty()


class Signaler(object):
    """
    Signaler lets you, well, send signals of any type asynchronously from any place.
    After initialization, the first push() call starts a daemon thread, which then dispatches accumulated signals
    to the registered handlers every `update_interval` seconds (`DEFAULT_UPDATE_INTERVAL` unless overridden
    via the `update_interval` keyword argument), or immediately when `wait()` is called.

    Although a signal can have an arbitrary structure, it follows two restrictions:
    1) it is a dict-like object, and
    2) it has a `type` key, which denotes the signal's type (the key is removed from the signal on push).

    Class usage:

    .. code-block:: python

        import logging

        from sandbox import common


        class YourOwnSignalHandler(common.statistics.SignalHandler):
            types = (common.types.statistics.SignalType.WHATEVER,)

            def handle(self, signal_type, signals):
                logging.info("%d signals processed in this batch", len(signals))


        common.statistics.Signaler(YourOwnSignalHandler(), logger=logging.getLogger(__file__))
        for _ in xrange(100):
            # as a result, one or more lines are logged, depending on Signaler's processing speed
            common.statistics.Signaler().push(dict(
                type=common.types.statistics.SignalType.WHATEVER,
                hitcount=1234,
                omg="wtf",
            ))
    """

    __metaclass__ = utils.ThreadSafeSingletonMeta

    __signals = collections.defaultdict(list)  # signal type -> list of pending signals
    __thread = None  # background dispatcher thread, started lazily by push()
    __lock = threading.Lock()  # guards __signals and __thread
    __wakeup = threading.Event()  # set by wait() to make the dispatcher run before its timeout expires
    __processed = threading.Event()  # set by the dispatcher once no signals are pending

    handlers = collections.defaultdict(list)  # signal type -> list of handlers subscribed to it

    TYPE_KEY = "type"
    DEFAULT_UPDATE_INTERVAL = 5  # seconds

    def __init__(self, *handlers, **kws):
        """
        :param handlers: `SignalHandler` instances to register
        :param kws: optional `logger` (defaults to this module's logger)
            and `update_interval` in seconds (defaults to `DEFAULT_UPDATE_INTERVAL`)
        """
        self.logger = kws.get("logger", None) or logging.getLogger(__name__)
        self.update_interval = kws.get("update_interval", self.DEFAULT_UPDATE_INTERVAL)
        self.register(*handlers)

    def register(self, *handlers):
        """
        Register additional signal handlers on the fly
        """
        for handler in handlers:
            for signal_type in handler.types:
                self.handlers[signal_type].append(handler)

    def __loop(self):
        """
        Body of the dispatcher thread: periodically take all pending signals and pass them to the handlers.
        """
        while True:
            # noinspection PyBroadException
            try:
                self.__wakeup.wait(self.update_interval)
                self.__wakeup.clear()
                with self.__lock:
                    signals, self.__signals = self.__signals, collections.defaultdict(list)
                for signal_type, batch in signals.iteritems():
                    for handler in utils.chain(
                        self.handlers[signal_type],
                        self.handlers[ctss.ALL_SIGNALS]
                    ):
                        # isolate handlers from each other: one failure must not drop the other handlers' batches
                        # noinspection PyBroadException
                        try:
                            handler.handle(signal_type, batch)
                        except Exception:
                            self.logger.exception(
                                "Error sending API usage statistics (handler %r, signal type %r)",
                                handler, signal_type
                            )

            except Exception:
                self.logger.exception("Error sending API usage statistics")

            # Report completion only if nothing was pushed while the batch above was being handled;
            # otherwise wait() could return before those signals are processed. The check and the set()
            # happen under the lock so that a concurrent push() is either seen here or happens afterwards.
            # Signals of a failed batch are not retried: they were already taken out by the swap above.
            with self.__lock:
                if not self.__signals:
                    self.__processed.set()

    def push(self, signals):
        """
        Enqueue a signal for processing in a separate thread

        :param signals: a dict or a list of dicts with `Signaler.TYPE_KEY` key
            (the key is popped from each signal in place)
        """

        with self.__lock:
            if self.__thread is None:
                self.__thread = threading.Thread(target=self.__loop)
                self.__thread.daemon = True
                self.__thread.start()

            if isinstance(signals, dict):
                signals = [signals]
            for signal in signals:
                signal_type = signal.pop(self.TYPE_KEY, None)
                if signal_type:
                    self.__signals[signal_type].append(signal)
                else:
                    self.logger.error("Unable to process signal with no type %r", signal)

    def wait(self):
        """
        Forcibly handle currently unprocessed signals.
        Useful when timeout hasn't expired yet or is too big, but there are already signals to process
        (say, before the executable is terminated).
        Blocks for at most `update_interval` seconds; returns immediately if nothing has ever been pushed.
        """

        if self.__thread is None:
            # no push() yet, hence no dispatcher thread and nothing to process
            return
        self.__processed.clear()
        self.__wakeup.set()
        # do not wait indefinitely, since statistics loss is less important than a halted task
        self.__processed.wait(self.update_interval)


class AggregatingClientSignalHandler(SignalHandler):
    """
    Aggregate signals by certain fields (for example, turn 10 signals with a value of 3
    into a single signal with a value of 30) and send them via REST API over HTTP.
    Quite useful when you want to collect statistics over a certain period and
    avoid unnecessary service load at the same time.

    Keyword arguments:
      - `aggregate_by`: names of fields whose values form the grouping key;
        signals with equal values of all these fields are merged into one;
      - `sum_by`: names of fields summed up within each group
        (every other field is taken from the first signal of the group);
      - `fixed_args`: fields added to (and overriding) every signal before grouping;
      - `api`: REST client to send data with; a new `rest.Client` is created if omitted;
      - `logger`: logger for the `rest.Client` created when `api` is omitted.
    """

    types = (ctss.SignalType.TASK_OPERATION,)
    DATE_KEY = "date"
    TIMESTAMP_KEY = "timestamp"

    def __init__(self, **kws):
        self.aggregate_by = kws.get("aggregate_by", [])
        self.sum_by = kws.get("sum_by", [])
        self.fixed_args = kws.get("fixed_args", {})

        # when called from task or executor, this instance is guaranteed to have the task's authorization,
        # thanks to _external_auth setup done in bin.executor.execute_command()
        self.api = kws.get("api") or rest.Client(logger=kws.get("logger"))

    def handle(self, signal_type, signals):
        """
        Aggregate a batch of signals and send the result with a single API call.

        Every aggregated record gets both its `DATE_KEY` and `TIMESTAMP_KEY` fields set to the current UTC time.
        Input signals are left untouched (the same objects may be handed to other handlers as well).

        :param signal_type: type of the signals; selects the `statistics` API endpoint
        :param signals: list of dict-like signals
        """
        data = {}
        for signal in signals:
            # work on a copy instead of updating the caller's signal in place
            merged = dict(signal)
            merged.update(self.fixed_args)
            key = tuple(map(merged.get, self.aggregate_by))
            if key not in data:
                data[key] = merged
                continue

            group = data[key]
            for field in self.sum_by:
                group[field] += merged[field]

        if not data:
            # nothing to report; avoid an empty request
            return

        utcnow = api.DateTime().encode(datetime.datetime.utcnow())
        for value in data.itervalues():
            value[self.DATE_KEY] = value[self.TIMESTAMP_KEY] = utcnow
        self.api.statistics[signal_type](data.values())


class ClientSignalHandler(SignalHandler):
    """
    Client-side handler that forwards resource registration and exception statistics
    to the service through the REST API (skipped entirely on local installations).
    """

    types = (ctss.SignalType.RESOURCE_REGISTRATION, ctss.SignalType.EXCEPTION_STATISTICS,)

    def __init__(self):
        # the configured OAuth token setting is resolved through `utils.read_settings_value_from_file()`
        oauth_token = config.Registry().client.auth.oauth_token
        if oauth_token:
            oauth_token = utils.read_settings_value_from_file(oauth_token)
        else:
            oauth_token = None

        self.api = client = rest.Client(auth=oauth_token)
        # statistics uploads get ten times the client's usual timeouts
        client.DEFAULT_TIMEOUT *= 10
        client.MAX_TIMEOUT *= 10
        client.reset()
        self.encoder = api.DateTime()

    def handle(self, signal_type, signals):
        """
        Encode each signal's timestamp, copy it into the "date" field and send the batch.

        :param signal_type: type of the signals; selects the `statistics` API endpoint
        :param signals: list of dicts, each with a "timestamp" field; modified in place
        """
        if config.Registry().common.installation in ctm.Installation.Group.LOCAL:
            return

        for item in signals:
            moment = self.encoder.encode(item["timestamp"])
            item["date"] = moment
            item["timestamp"] = moment
        self.api.statistics[signal_type](signals)


class ServerSignalHandler(SignalHandler):
    """
    A server-side signal handler saves signals into an intermediate MongoDB collection
    (one collection per signal type) of the statistics processor's database.
    """

    types = (ctss.ALL_SIGNALS,)

    def __init__(self):
        # imported lazily, presumably so that the module stays importable where the database layer is absent
        from sandbox.yasandbox.database import mapping
        self.mapping = mapping

    @utils.singleton_property
    def db(self):
        """
        :return: a database object
        :rtype: pymongo.database.Database
        """
        return self.mapping.ensure_connection().rw.connection[
            config.Registry().server.services.statistics_processor.database
        ]

    def handle(self, signal_type, signals):
        """
        Insert a batch of signals into the collection named after their type.

        :param signal_type: signal type, used as the collection name
        :param signals: list of signal documents
        """
        # `insert_many()` rejects an empty document list, and there is nothing to store anyway
        if not signals:
            return
        # NOTE: pymongo adds an `_id` field in place to documents that lack one
        self.db[signal_type].insert_many(signals)
