from __future__ import absolute_import

import abc
import json
import time
import logging
import datetime
import threading
import collections

import six

try:
    import logbroker.unified_agent.client.python as unified_agent
except ImportError:
    unified_agent = None

from ..types import statistics as ctss

from .. import fs
from .. import api
from .. import rest
from .. import config as common_config
from .. import patterns
from .. import itertools as cit


@six.add_metaclass(abc.ABCMeta)
class SignalHandler(object):
    """
    Abstract base for signal handlers.

    A concrete handler lists the signal types it subscribes to in `types` and implements `handle()`.
    Overriding `reset()` is optional: the base implementation does nothing.
    """

    # names of the signal types this handler processes; subclasses must override it
    types = abc.abstractproperty()

    @abc.abstractmethod
    def handle(self, signal_type, signals):
        """
        Process one batch of signals.

        :param signal_type: type the batch was collected under
        :param signals: signals of that type
        """

    def reset(self):
        """
        Drop the handler's internal state, e.g. for a singleton handler instance after a process fork.
        Does nothing by default.
        """


@six.add_metaclass(patterns.ThreadSafeSingletonMeta)
class Signaler(object):
    """
    Signaler lets you, well, send signals of any type asynchronously from any place.
    After initialization, on the first push() call it starts a daemon thread, which then fetches accumulated
    signals every `update_interval` seconds (`DEFAULT_UPDATE_INTERVAL` by default) and passes them to handlers.

    Although a signal can have an arbitrary structure, it follows two restrictions:
    1) it is a dict-like object, and
    2) it has a `type` key, which denotes the signal's type.

    Class usage:

    .. code-block:: python

        import logging

        from sandbox import common


        class YourOwnSignalHandler(common.statistics.SignalHandler):
            types = (common.types.statistics.SignalType.WHATEVER,)

            def handle(self, signal_type, signals):
                logging.info("%d signals processed in this batch", len(signals))


        common.statistics.Signaler(YourOwnSignalHandler(), logger=logging.getLogger(__file__))
        for _ in xrange(100):
            # as a result, one or more batch sizes are logged, depending on Signaler's processing speed
            common.statistics.Signaler().push(dict(
                type=common.types.statistics.SignalType.WHATEVER,
                hitcount=1234,
                omg="wtf",
            ))
    """

    # NOTE: class-level mapping {signal type: [handler, ...]}, shared by all instances
    # (the class is constructed through a singleton metaclass)
    handlers = collections.defaultdict(list)

    TYPE_KEY = "type"
    DEFAULT_UPDATE_INTERVAL = 5  # seconds

    def __init__(self, *handlers, **kws):
        """
        :param handlers: `SignalHandler` instances to register
        :param kws: optional settings:
            `logger` -- logger to use (this module's logger by default);
            `update_interval` -- seconds between processing rounds;
            `component` -- stored on the instance as-is;
            `enabled` -- if false, push() and wait() do nothing; when omitted or None,
                `config.common.statistics.enabled` is used;
            `config` -- configuration registry consulted when `enabled` is not given;
            `all_signals` -- if true, signals are kept intact (the type key is not removed)
                and all of them are queued under `ctss.ALL_SIGNALS`
        """
        self.logger = kws.get("logger", None) or logging.getLogger(__name__)
        self.update_interval = kws.get("update_interval", self.DEFAULT_UPDATE_INTERVAL)
        self.component = kws.get("component")
        self.enabled = kws.get("enabled")
        if self.enabled is None:
            self.enabled = kws.get("config", common_config.Registry()).common.statistics.enabled
        self.all_signals = kws.get("all_signals")

        self.__thread = None
        self.__init_sync_state()

        self.register(*handlers)

    def __init_sync_state(self):
        """(Re)create the pending-signals buffer, its lock and the worker wakeup/completion events."""
        self.__signals = collections.defaultdict(list)
        self.__lock = threading.Lock()
        self.__wakeup = threading.Event()
        self.__processed = threading.Event()

    def register(self, *handlers):
        """
        Register additional signal handlers on the fly
        """
        for handler in handlers:
            for signal_type in handler.types:
                self.handlers[signal_type].append(handler)

    def __loop(self):
        """Worker thread body: on every wakeup or timeout, swap out pending signals and dispatch them."""
        while True:
            # noinspection PyBroadException
            try:
                self.__wakeup.wait(self.update_interval)
                self.__wakeup.clear()
                if self.__signals:
                    with self.__lock:
                        signals, self.__signals = self.__signals, collections.defaultdict(list)
                    self.__dispatch(signals)

            except Exception:
                self.logger.exception("Error sending API usage statistics")

            self.__processed.set()  # signal values are already lost anyway during the above swap

    def __dispatch(self, signals):
        """
        Hand every batch to the handlers registered for its type, plus those registered for `ctss.ALL_SIGNALS`.
        A failing handler is logged and skipped, so it cannot cost other handlers or signal types their data.
        """
        for signal_type, batch in six.iteritems(signals):
            handlers = list(self.handlers[signal_type])
            if signal_type != ctss.ALL_SIGNALS:
                handlers.extend(self.handlers[ctss.ALL_SIGNALS])
            if not handlers:
                self.logger.info("No handlers for type %s", signal_type)
            for handler in handlers:
                # noinspection PyBroadException
                try:
                    handler.handle(signal_type, batch)
                except Exception:
                    self.logger.exception("Handler %r failed to process signals of type %s", handler, signal_type)

    def push(self, signals):
        """
        Enqueue a signal for processing in a separate thread

        :param signals: a dict or a list of dicts with `Signaler.TYPE_KEY` key;
            unless `all_signals` is set, the type key is popped from every signal dict
        """

        if not self.enabled:
            return

        signals = [signals] if isinstance(signals, dict) else list(signals)
        if not signals:
            # avoid queueing an empty batch (possible with `all_signals`) that handlers would then receive
            return

        with self.__lock:
            if self.__thread is None or not self.__thread.is_alive():
                self.__thread = threading.Thread(target=self.__loop)
                self.__thread.daemon = True
                self.__thread.start()

            if self.all_signals:
                self.__signals[ctss.ALL_SIGNALS].extend(signals)
                return
            for signal in signals:
                signal_type = signal.pop(self.TYPE_KEY, None)
                if signal_type:
                    self.__signals[signal_type].append(signal)
                else:
                    self.logger.error("Unable to process signal with no type %r", signal)

    def wait(self):
        """
        Forcibly handle currently unprocessed signals.
        Useful when timeout hasn't expired yet or is too big, but there are already signals to process
        (say, before the executable is terminated).

        Best effort: waits at most `update_interval` seconds, and may return as soon as a processing round
        that was already in progress finishes. Returns immediately if the worker thread is not running,
        since then there is nothing it could process.
        """
        if not self.enabled:
            return

        thread = self.__thread
        if thread is None or not thread.is_alive():
            return

        self.__processed.clear()
        self.__wakeup.set()
        # do not wait indefinitely, since statistics loss is less important than a halted task
        self.__processed.wait(self.update_interval)

    def reset(self):
        """
        Drop pending signals, forget the worker thread and reset all registered handlers,
        e.g. in a child process after fork. The lock and events are recreated too: the parent's
        worker thread may have held them at the moment of fork, which would leave them stuck in the child.
        """
        self.__init_sync_state()
        self.__thread = None
        for handlers in self.handlers.values():
            for handler in handlers:
                handler.reset()


class AggregatingClientSignalHandler(SignalHandler):
    """
    Client-side handler that merges signals before sending them via the REST API client.

    Signals sharing the same values of the `aggregate_by` fields collapse into one signal
    whose `sum_by` fields hold the totals (e.g. ten signals with a value of 3 become a single
    signal with a value of 30). Handy for collecting statistics over a period while keeping
    the service load low.
    """

    types = (ctss.SignalType.TASK_OPERATION,)
    DATE_KEY = "date"
    TIMESTAMP_KEY = "timestamp"

    def __init__(self, **kws):
        """
        :param kws: optional settings:
            `aggregate_by` -- field names forming the aggregation key (empty: all signals are merged together);
            `sum_by` -- field names summed within a group;
            `fixed_args` -- fields written into every signal before aggregation, overriding its own values;
            `replace_timestamp` -- if true (default), every result is stamped with the current UTC time,
                otherwise the signal's own timestamp/date is kept when present;
            `api` -- REST client to use; `logger` -- logger for the client created when `api` is not given
        """
        option = kws.get
        self.aggregate_by = option("aggregate_by", [])
        self.sum_by = option("sum_by", [])
        self.fixed_args = option("fixed_args", {})
        self.replace_timestamp = option("replace_timestamp", True)

        # when called from task or executor, this instance is guaranteed to have the task's authorization,
        # thanks to _external_auth setup done in bin.executor.execute_command()
        self.api = option("api") or rest.Client(logger=option("logger"))

    def handle(self, signal_type, signals):
        """
        Aggregate the batch and send the results under `signal_type`.
        Note that `fixed_args` are merged into the caller's signal dicts in place.
        """
        groups = {}
        for signal in signals:
            signal.update(self.fixed_args)
            group_key = tuple(signal.get(name) for name in self.aggregate_by)
            total = groups.get(group_key)
            if total is None:
                # first signal of the group: copy it so the sums do not touch the original
                groups[group_key] = dict(signal)
            else:
                for name in self.sum_by:
                    total[name] += signal[name]

        encoder = api.DateTime()
        now = encoder.encode(datetime.datetime.utcnow())
        for aggregated in six.itervalues(groups):
            stamp = now
            if not self.replace_timestamp:
                own = aggregated.get(self.TIMESTAMP_KEY, aggregated.get(self.DATE_KEY))
                if own:
                    stamp = encoder.encode(own)
            aggregated[self.DATE_KEY] = aggregated[self.TIMESTAMP_KEY] = stamp
        self.api.statistics[signal_type](list(groups.values()))


class ApiCallClientSideHandler(AggregatingClientSignalHandler):
    """Aggregating client-side handler subscribed to `API_CALL_CLIENT_SIDE` signals."""

    types = (
        ctss.SignalType.API_CALL_CLIENT_SIDE,
    )


class ClientSignalHandler(SignalHandler):
    """
    Client-side handler sending exception statistics via the REST API.
    When a task ID is set, `EXCEPTION_STATISTICS` signals are sent as `TASK_EXCEPTION_STATISTICS`
    with the task ID attached and the "component" field removed.
    """

    types = (
        ctss.SignalType.EXCEPTION_STATISTICS,
        ctss.SignalType.TASK_EXCEPTION_STATISTICS,
    )

    def __init__(self, token=None, url=None, task_id=None):
        """
        :param token: value resolved with `fs.read_settings_value_from_file()` into the auth token;
            the `client.auth.oauth_token` setting is used when omitted
        :param url: base URL for the REST client
        :param task_id: ID of the task the signals belong to, if any
        """
        if token is None:
            token = common_config.Registry().client.auth.oauth_token
        token = fs.read_settings_value_from_file(token) if token else None
        self.task_id = task_id

        self.api = rest.Client(base_url=url, auth=token)
        # timeouts are ten times the client's defaults; reset() presumably makes the client pick them up
        self.api.DEFAULT_TIMEOUT *= 10
        self.api.MAX_TIMEOUT *= 10
        self.api.reset()
        self.encoder = api.DateTime()

    def handle(self, signal_type, signals):
        """
        Encode datetime timestamps (also copying them to "date") and send the batch.
        Signals are modified in place.
        """
        for signal in signals:
            timestamp = signal.get("timestamp")
            if isinstance(timestamp, datetime.datetime):
                signal["date"] = signal["timestamp"] = self.encoder.encode(timestamp)

        if self.task_id and signal_type == ctss.SignalType.EXCEPTION_STATISTICS:
            signal_type = ctss.SignalType.TASK_EXCEPTION_STATISTICS
            for signal in signals:
                # tolerate signals without "component": a KeyError here would lose the whole batch
                signal.pop("component", None)
                signal["task_id"] = self.task_id

        self.api.statistics[signal_type](signals)


class ServerSignalHandler(SignalHandler):
    """
    A server-side signal handler saves signals into an intermediate MongoDB collection.
    Every signal type is stored in the collection of the same name in the statistics database.
    """

    types = (ctss.ALL_SIGNALS,)

    def __init__(self, config=None):
        """
        :param config: configuration registry providing `common.statistics.database`;
            `common_config.Registry()` is consulted when omitted
        """
        from sandbox.yasandbox.database import mapping
        self.mapping = mapping
        self.config = config
        self.logger = logging.getLogger(__name__)

    @patterns.singleton_property
    def db(self):
        """
        :return: a database object
        :rtype: pymongo.database.Database
        """
        database = (
            self.config.common.statistics.database
            if self.config is not None
            else common_config.Registry().common.statistics.database
        )
        return self.mapping.get_connection()[database]

    def handle(self, signal_type, signals):
        """Insert the signals into the collection named after `signal_type`; failures are logged, not raised."""
        if not signals:
            # nothing to store; besides, pymongo's insert_many() rejects an empty document list
            return
        try:
            self.logger.info("Inserting %s signal(s) of type %s", len(signals), signal_type)
            self.db[signal_type].insert_many(signals)
        except Exception as ex:
            self.logger.error("Failed to insert signals.", exc_info=ex)

    def reset(self):
        # reset DB connection object to re-establish it
        del self.db


class UASignalHandler(SignalHandler):
    """
    A signal handler sending signals via Unified Agent.
    'signals_group' defines the UA topic that will be used for signals delivery: it selects the
    `common.unified_agent` config section whose "uri" is used to reach the agent.
    Each chunk of signals (see `MAX_CHUNK_SIZE`) is sent as one JSON line {signal_type: [signal, ...]}.
    """
    MAX_CHUNK_SIZE = 1000

    types = (ctss.ALL_SIGNALS,)

    class CustomEncoder(json.JSONEncoder):
        """JSON encoder rendering `datetime.datetime` values as ISO 8601 UTC strings with a "Z" suffix."""

        def default(self, o):
            if isinstance(o, datetime.datetime):
                offset = o.utcoffset()
                if offset is not None:
                    # timezone-aware value: convert to naive UTC first, otherwise isoformat() already
                    # ends with "+HH:MM" and appending "Z" would produce an invalid timestamp
                    o = o.replace(tzinfo=None) - offset
                # naive values are assumed to be UTC already
                return o.isoformat() + "Z"
            return super(UASignalHandler.CustomEncoder, self).default(o)

    def __init__(self, signals_group):
        """
        :param signals_group: name of the `common.unified_agent` config section to take the agent URI from
        """
        self._signals_group = signals_group
        self._config = common_config.Registry()
        self._ua_uri = self._config.common.unified_agent.get(signals_group, {}).get("uri")
        self._logger = logging.getLogger(__name__)
        self._ua_client = None
        self._ua_session = None

    @property
    def ua_session(self):
        """Unified Agent session; the client and the session are created lazily on first access."""
        if self._ua_session is not None:
            return self._ua_session
        if self._ua_client is None:
            self._ua_client = unified_agent.Client(self._ua_uri, log_level=logging.ERROR)
        self._ua_session = self._ua_client.create_session()
        return self._ua_session

    @ua_session.deleter
    def ua_session(self):
        """Close the current session (errors are only logged) so the next access opens a new one."""
        if self._ua_session is not None:
            try:
                self._ua_session.close()
            except Exception as ex:
                self._logger.warning("Error while closing unified agent session: %s", ex)
            self._ua_session = None

    def handle(self, signal_type, signals):
        """
        Send the signals in chunks. Errors are logged, not raised; on error the session is dropped
        so that the next call starts a fresh one.
        """
        try:
            if not unified_agent:
                self._logger.warning("Cannot send signals, Unified Agent is not available")
                return
            if not self._ua_uri:
                self._logger.warning("Cannot send signals, Unified Agent socket not defined")
                return
            self._logger.info("Sending statistics")
            for chunk in cit.chunker(signals, self.MAX_CHUNK_SIZE):
                self._logger.info("Sending %s signal(s) of type %s", len(chunk), signal_type)
                self.ua_session.send(json.dumps({signal_type: chunk}, cls=self.CustomEncoder) + "\n", time.time())
        except Exception as ex:
            self._logger.error("Failed to insert signals.", exc_info=ex)
            del self.ua_session

    def reset(self):
        # drop the references without closing them (presumably they belong to the parent process
        # after a fork); new ones are created lazily on next use
        self._ua_client = None
        self._ua_session = None


class UASignalHandlerServiceApiCalls(UASignalHandler):
    """Sends `API_CALL` signals via Unified Agent, using the "serviceapi_calls_statistics" signals group."""

    types = (ctss.SignalType.API_CALL,)

    def __init__(self):
        super(UASignalHandlerServiceApiCalls, self).__init__(signals_group="serviceapi_calls_statistics")


class UASignalHandlerResourceAudit(UASignalHandler):
    """Sends `RESOURCE_AUDIT` signals via Unified Agent, using the "resource_audit_statistics" signals group."""

    types = (ctss.YTSignalType.RESOURCE_AUDIT,)

    def __init__(self):
        super(UASignalHandlerResourceAudit, self).__init__(signals_group="resource_audit_statistics")


class UASignalHandlerServiceApi(UASignalHandler):
    """
    Handles all signals that come via HTTP requests, using the "serviceapi_statistics" signals group.
    """

    # every `ctss.SignalType` member except those handled by UASignalHandlerServiceApiCalls
    types = tuple(set(ctss.SignalType) - set(UASignalHandlerServiceApiCalls.types))

    def __init__(self):
        super(UASignalHandlerServiceApi, self).__init__(signals_group="serviceapi_statistics")


class UASignalHandlerForProxy(UASignalHandler):
    """Sends the signal types listed below via Unified Agent, using the "proxy_statistics" signals group."""

    types = (
        ctss.SignalType.EXCEPTION_STATISTICS,
        ctss.SignalType.MDS_API_CALL,
        ctss.SignalType.PROXY_API_CALL,
        ctss.SignalType.PROXY_REQUESTS_DYNAMIC_NETS,
    )

    def __init__(self):
        super(UASignalHandlerForProxy, self).__init__(signals_group="proxy_statistics")


class SignalHandlerForAgentR(ClientSignalHandler):
    """
    REST API client handler subscribed to the signal types listed below
    (presumably the ones emitted by AgentR); constructed exactly like `ClientSignalHandler`.
    """

    # TODO: switch to UASignalHandler(signals_group="agentr_statistics") SANDBOX-9127
    types = (
        ctss.SignalType.EXCEPTION_STATISTICS,
        ctss.SignalType.MDS_API_CALL,
        ctss.SignalType.RESOURCE_REGISTRATION,
        ctss.SignalType.RESOURCE_SYNC,
        ctss.SignalType.RESOURCES_SYNC_TO_MDS_DELAY,
        ctss.SignalType.TASK_HARDWARE_METRICS,
        ctss.SignalType.COREDUMPS,
    )


class SignalHandlerInternalServiceApi(SignalHandler):
    """
    Forwards every batch of signals, of any type, to `send_msg()` of `sandbox.serviceapi.mules.signaler`.
    The signal type itself is not forwarded.
    """

    types = (ctss.ALL_SIGNALS,)

    def __init__(self):
        # imported at construction time rather than at module import time
        from sandbox.serviceapi.mules import signaler as mules_signaler
        self.signaler = mules_signaler

    def handle(self, signal_type, signals):
        send = self.signaler.send_msg
        send(signals)
