import collections
import functools
import logging
import multiprocessing
import queue
import threading

from frozendict import frozendict
from library.python.monlib import metric_registry

# One queued metric update, produced by MetricProxy and consumed by
# MultiprocessingMetricRegistry.collect: `labels` selects the registered
# metric, `method` is the name of the metric method to invoke, and `args`
# holds the positional arguments for that call.
StatMessage = collections.namedtuple("StatMessage", "labels method args")

logger = logging.getLogger(__name__)


class MetricProxy:
    """Stand-in for a metric that enqueues calls instead of executing them.

    For each name in ``funcs`` an attribute of that name is set on the
    instance. Calling it with positional arguments puts
    ``StatMessage(self.labels, name, args)`` on ``self.queue`` via
    ``put_nowait`` and returns whatever ``put_nowait`` returns.
    """

    def __init__(self, labels, queue, funcs):
        self.labels = labels
        self.queue = queue
        for name in funcs:
            forwarder = self._make_proxy_method(name)
            setattr(self, name, forwarder)

    def _make_proxy_method(self, func):
        # `func` is a parameter of this helper, so each forwarder captures its
        # own method name (no late-binding surprise across loop iterations).
        def send(*args):
            enqueue = self.queue.put_nowait
            return enqueue(StatMessage(self.labels, func, args))

        return send


def proxy(method):
    """Turn a method that lists proxied method names into a metric factory.

    The decorated method must accept only ``self`` and return the names of the
    metric methods a proxy should expose. The resulting method has the
    signature ``(self, labels, *args, **kwargs)``. It does three things:

    * calls the registry method of the same name with
      ``(labels, *args, **kwargs)``;
    * stores the created metric in ``self.metrics`` under
      ``frozendict(labels)``;
    * returns a ``MetricProxy`` that forwards calls to those names through
      ``self.queue``.
    """

    @functools.wraps(method)
    def wrapped(self, labels, *args, **kwargs):
        # Snapshot the labels. The proxy sends them with every message, and the
        # collector looks the metric up by frozendict(labels). If the caller
        # later mutates its dict, messages must still target the metric
        # registered here rather than an unregistered key.
        labels = dict(labels)
        self.metrics[frozendict(labels)] = getattr(self.registry, method.__name__)(labels, *args, **kwargs)
        return MetricProxy(labels, self.queue, method(self))

    return wrapped


class MultiprocessingMetricRegistry:
    """Metric registry whose metrics can be updated through a process-safe queue.

    Metrics are created in the owning process via the ``@proxy``-decorated
    factories (``rate``, ``histogram_rate``), each of which returns a
    ``MetricProxy``. Calls made on a proxy are put on a
    ``multiprocessing.Queue``. A background thread started by ``start()``
    drains that queue and applies each call to the real metric, and also to
    any aggregated metrics selected via ``add_aggregation_label``.
    """

    # Seconds ``collect`` blocks on the queue before re-checking ``running``.
    # This bounds how long ``stop()`` waits for the collecting thread to exit.
    poll_interval = 1.0

    def __init__(self, common_labels=None):
        self.registry = metric_registry.MetricRegistry(common_labels)
        self.queue = multiprocessing.Queue()
        # frozendict(labels) -> metric object returned by self.registry.
        self.metrics = {}
        # Label names for which get_metrics also yields the metric registered
        # under the same labels with that one label removed.
        self.aggregate = set()
        self.running = False
        self.collecting_thread = None

    @proxy
    def rate(self):
        """Register a rate metric for ``labels``; return a proxy exposing ``add``/``inc``.

        The body only lists the proxied method names; registration and proxy
        construction are done by ``@proxy``.
        """
        return ["add", "inc"]

    @proxy
    def histogram_rate(self):
        """Register a histogram-rate metric for ``labels``; return a proxy exposing ``collect``.

        Extra positional/keyword arguments are forwarded to the registry's
        ``histogram_rate`` by ``@proxy``.
        """
        return ["collect"]

    def collect(self):
        """Apply queued StatMessages to their metrics until ``running`` is cleared.

        Runs on the thread created by ``start()``. The queue is polled with a
        timeout so that a ``stop()`` call is noticed within roughly
        ``poll_interval`` seconds even when no messages arrive. Errors while
        handling a message are logged and the loop continues.
        """
        while self.running:
            try:
                stat = self.queue.get(timeout=self.poll_interval)
                logger.info(stat)
                for gauge in self.get_metrics(stat.labels):
                    getattr(gauge, stat.method)(*stat.args)
            except queue.Empty:
                # Nothing arrived within the poll interval; re-check `running`.
                continue
            except Exception:
                logger.exception("Collect error")

    def add_aggregation_label(self, label):
        """Also update the metric without ``label`` whenever a message carries it.

        The aggregated metric must itself be registered (e.g. via ``rate``);
        otherwise collecting a matching message logs a ``KeyError``.
        """
        self.aggregate.add(label)

    def get_metrics(self, labels):
        """Yield every metric a message with ``labels`` should update.

        First yields the metric registered for exactly ``labels``. Then, for
        each label present in ``labels`` that is also in ``self.aggregate``,
        yields the metric registered for ``labels`` minus that label.

        Raises ``KeyError`` if any of those label sets was never registered.
        Because this is a generator, earlier metrics may already have been
        updated by the caller when the error is raised.
        """
        yield self.metrics[frozendict(labels)]
        for label in labels:
            if label in self.aggregate:
                aggregated_labels = {k: v for k, v in labels.items() if k != label}
                yield self.metrics[frozendict(aggregated_labels)]

    def start(self):
        """Start the daemon thread that runs ``collect``."""
        self.running = True
        # `daemon=` replaces the deprecated Thread.setDaemon().
        self.collecting_thread = threading.Thread(target=self.collect, daemon=True)
        self.collecting_thread.start()

    def stop(self):
        """Signal the collecting thread to exit and wait up to 10 seconds for it.

        Messages still in the queue when the thread exits are not processed.
        Safe to call even if ``start()`` was never called.
        """
        self.running = False
        if self.collecting_thread is not None:
            self.collecting_thread.join(timeout=10)

    def accept(self, *args, **kwargs):
        """Delegate to the underlying registry's ``accept``.

        NOTE(review): presumably the monlib encoder/visitor entry point used for
        exporting metrics; confirm against monlib's MetricRegistry.
        """
        self.registry.accept(*args, **kwargs)
