import json
import time
import flask
import re
import copy

from infra.swatlib.gevent.geventutil import gevent_idle_iter
from itertools import chain

# Characters replaced by '-' when sanitizing metric names (see Registry.prepare_name).
# Presumably '_' is included so that it only appears as the "<name>_<suffix>" separator.
# Raw string: a plain '\.' is an invalid escape sequence (DeprecationWarning/SyntaxWarning).
PREPARE_SIGNAL_NAME = re.compile(r'[<>_/\.:]')


class Counter(object):
    """
    Accumulating counter exported as ``['<name>_<suffix>', value]``.

    :param name: metric name, kept as-is on ``self.name``.
    :param suffix: appended to the exported name after an underscore.
    :param v: initial value (0 by default).
    """

    def __init__(self, name, suffix, v=0):
        self._fmt = '{}_{}'.format(name, suffix)
        self._v = v
        self.name = name

    def inc(self, d=1):
        """Add ``d`` (1 by default) to the stored value."""
        self._v = self._v + d

    def get(self):
        """Return the stored value."""
        return self._v

    def fmt(self):
        """Return the ``[exported_name, value]`` pair."""
        exported_name = self._fmt
        return [exported_name, self._v]

    def __str__(self):
        return '{}'.format(self.fmt())


class Gauge(object):
    """
    Last-value metric exported as ``['<name>_<suffix>', value]``.

    :param name: metric name, kept as-is on ``self.name``.
    :param suffix: appended to the exported name after an underscore.
    :param v: initial value (0 by default).
    """

    def __init__(self, name, suffix, v=0):
        self._fmt = '{}_{}'.format(name, suffix)
        self._v = v
        self.name = name

    def set(self, v):
        """Replace the stored value with ``v``."""
        self._v = v

    def fmt(self):
        """Return the ``[exported_name, value]`` pair."""
        return list((self._fmt, self._v))


class Timer(object):
    """
    Measures elapsed time and reports it to a histogram.

    The clock starts when the Timer is constructed. ``stop()``, or leaving a
    ``with`` block, passes the elapsed seconds to ``hgram.observe()``.
    """
    __slots__ = ['hgram', '_start']

    # Use a monotonic clock where available. Wall-clock time (time.time) can
    # jump on system clock adjustments, producing negative or inflated
    # durations. Fall back to time.time on interpreters without
    # time.monotonic (Python < 3.3).
    _clock = staticmethod(getattr(time, 'monotonic', time.time))

    def __init__(self, hgram):
        self.hgram = hgram
        self._start = self._clock()

    def stop(self):
        """Report the seconds elapsed since construction to ``hgram.observe``."""
        self.hgram.observe(self._clock() - self._start)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The duration is recorded even when the block raised; returning None
        # lets the exception propagate.
        self.stop()


class Histogram(object):
    """
    Simple histogram counting observed values.

    Each observation is counted once, in the bucket with the largest lower
    bound that is <= the value. Values below the lowest bound are not counted.
    Exported as ``['<name>_hgram', [[bound, count], ...]]`` in ascending bound
    order.
    """
    DEFAULT_BUCKETS_SEC = (
        .01, .025, .05, .1, .25, .5, .75, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
        4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 15.0,
        20.0, 30.0, 50.0, 100.0, 1000.0
    )

    def __init__(self, name, buckets=None):
        """
        :param name: metric name; exported as ``<name>_hgram``.
        :param buckets: ascending bucket lower bounds (at most 50). When empty
            or None, the class's DEFAULT_BUCKETS_SEC is used. A 0 bound is
            added when every given bound is positive.
        :raises ValueError: if more than 50 bounds are given or they are not
            in ascending order.
        """
        self.name = name + '_hgram'
        if not buckets:
            # Looked up on the instance so subclasses can override the default.
            buckets = self.DEFAULT_BUCKETS_SEC
        if len(buckets) > 50:
            raise ValueError('Too many buckets: {}'.format(len(buckets)))
        buckets = [float(b) for b in buckets]
        if buckets != sorted(buckets):
            # This is probably an error on the part of the user,
            # so raise rather than sorting for them.
            raise ValueError('Buckets not in sorted order: {}'.format(buckets))
        # Stored highest bound first so observe() can stop at the first bound
        # the value reaches.
        self._buckets = [[b, 0] for b in reversed(buckets)]
        # Add a 0 bucket only when it becomes the new lowest bound. Appending
        # it after a negative bound would break the descending order and
        # export the buckets unsorted.
        if self._buckets[-1][0] > 0:
            self._buckets.append([0, 0])

    def timer(self):
        """Return a Timer that reports its elapsed time to this histogram."""
        return Timer(self)

    def observe(self, amount):
        """Count ``amount`` in the highest bucket whose bound it reaches."""
        for b in self._buckets:
            if amount >= b[0]:
                b[1] += 1
                break

    def fmt(self):
        """Return ``[name, [[bound, count], ...]]`` with bounds ascending."""
        # Copy each pair so the exported snapshot is not mutated by later
        # observations.
        return [self.name, [[bound, count] for bound, count in reversed(self._buckets)]]


class AbsoluteHistogram(object):
    """
    Absolute histogram: counts observed values, but exports only the counts
    collected between two ``finalize()`` calls.

    Observations go into a scratch set of buckets. ``finalize()`` (also called
    on leaving a ``with`` block) publishes the scratch counts for ``fmt()`` and
    resets the scratch buckets to zero. Before the first ``finalize()``,
    ``fmt()`` returns None.

    Bucket semantics match the plain histogram: each value is counted once, in
    the bucket with the largest lower bound <= the value. Values below the
    lowest bound are not counted.
    """
    DEFAULT_BUCKETS_SEC = (
        .01, .025, .05, .1, .25, .5, .75, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0,
        4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 15.0,
        20.0, 30.0, 50.0, 100.0, 1000.0
    )

    def __init__(self, name, buckets=None):
        """
        :param name: metric name; exported as ``<name>_ahhh``.
        :param buckets: ascending bucket lower bounds (at most 50). When empty
            or None, this class's DEFAULT_BUCKETS_SEC is used. A 0 bound is
            added when every given bound is positive.
        :raises ValueError: if more than 50 bounds are given or they are not
            in ascending order.
        """
        self.name = name + '_ahhh'
        if not buckets:
            # Use this class's own default (overridable by subclasses) rather
            # than reaching into another class.
            buckets = self.DEFAULT_BUCKETS_SEC
        if len(buckets) > 50:
            raise ValueError('Too many buckets: {}'.format(len(buckets)))
        buckets = [float(b) for b in buckets]
        if buckets != sorted(buckets):
            # This is probably an error on the part of the user,
            # so raise rather than sorting for them.
            raise ValueError('Buckets not in sorted order: {}'.format(buckets))
        # Zeroed template, highest bound first, used to reset the scratch buckets.
        self._initial_buckets = [[b, 0] for b in reversed(buckets)]
        # Add a 0 bucket only when it becomes the new lowest bound; after a
        # negative bound it would break the descending order.
        if self._initial_buckets[-1][0] > 0:
            self._initial_buckets.append([0, 0])

        self._buckets = None  # published counts; None until the first finalize()
        self._tmp_buckets = copy.deepcopy(self._initial_buckets)  # in-progress counts

    def observe(self, amount):
        """Count ``amount`` in the scratch bucket with the highest bound it reaches."""
        for b in self._tmp_buckets:
            if amount >= b[0]:
                b[1] += 1
                break

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.finalize()

    def finalize(self):
        """Publish the scratch counts and reset the scratch buckets to zero."""
        self._buckets = copy.deepcopy(self._tmp_buckets)
        self._tmp_buckets = copy.deepcopy(self._initial_buckets)

    def fmt(self):
        """
        Return ``[name, [[bound, count], ...]]`` (bounds ascending) for the
        last finalized period, or None if finalize() has not been called yet.
        """
        if self._buckets is None:
            return None
        return [self.name, [[bound, count] for bound, count in reversed(self._buckets)]]


class Registry(object):
    """
    Creates and caches named metrics.

    Each ``get_*`` method sanitizes the name with ``prepare_name``. It returns
    the cached metric for that name, creating it on first use. Creation
    arguments (``buckets``) only take effect on the first call for a name.
    """

    def __init__(self):
        self._counters = {}  # (name, suffix) -> Counter
        self._gauges = {}  # (name, suffix) -> Gauge
        self._histograms = {}  # name -> Histogram
        # Kept separate from _histograms so that a Histogram and an
        # AbsoluteHistogram requested under the same name do not shadow each
        # other. They export different names ('_hgram' vs '_ahhh') and have
        # different APIs (timer() vs finalize()).
        self._absolute_histograms = {}  # name -> AbsoluteHistogram

    @staticmethod
    def _get_or_create(store, key, factory):
        """Return ``store[key]``, first storing ``factory()`` there if the key is missing."""
        metric = store.get(key)
        if metric is None:
            metric = factory()
            store[key] = metric
        return metric

    def prepare_name(self, name):
        """Return ``name`` with every PREPARE_SIGNAL_NAME character replaced by '-'."""
        return PREPARE_SIGNAL_NAME.sub('-', name)

    def get_counter(self, name, suffix='axxx'):
        """Return the Counter for (sanitized ``name``, ``suffix``)."""
        name = self.prepare_name(name)
        return self._get_or_create(self._counters, (name, suffix), lambda: Counter(name, suffix))

    def get_gauge(self, name, suffix='axxx'):
        """Return the Gauge for (sanitized ``name``, ``suffix``)."""
        name = self.prepare_name(name)
        return self._get_or_create(self._gauges, (name, suffix), lambda: Gauge(name, suffix))

    def get_histogram(self, name, buckets=None):
        """Return the Histogram for sanitized ``name``; ``buckets`` is used only on creation."""
        name = self.prepare_name(name)
        return self._get_or_create(self._histograms, name, lambda: Histogram(name, buckets))

    def get_absolute_histogram(self, name, buckets=None):
        """Return the AbsoluteHistogram for sanitized ``name``; ``buckets`` is used only on creation."""
        name = self.prepare_name(name)
        return self._get_or_create(self._absolute_histograms, name,
                                   lambda: AbsoluteHistogram(name, buckets))

    def items(self):
        """
        Yield the ``fmt()`` value of every registered metric, skipping metrics
        whose ``fmt()`` returns None (e.g. an unfinalized AbsoluteHistogram).
        """
        metrics = chain(
            self._histograms.values(),
            self._absolute_histograms.values(),
            self._counters.values(),
            self._gauges.values(),
        )
        for metric in metrics:
            fmt = metric.fmt()
            if fmt is not None:
                yield fmt

    def path(self, *name):
        """Return a PathRegistry that prefixes metric names with the '-'-joined ``name`` parts."""
        return PathRegistry(self, "-".join(name))


class PathRegistry(object):
    """
    View over a Registry that prefixes every metric name with a fixed path.

    ``PathRegistry(reg, 'a-b').get_counter('c')`` is equivalent to
    ``reg.get_counter('a-b-c')``.
    """

    def __init__(self, registry, path):  # type: (Registry, str) -> None
        self._registry = registry
        self._path = path

    def _qualify(self, name):
        """Return ``name`` prefixed with this view's path, separated by '-'."""
        return '{}-{}'.format(self._path, name)

    def get_counter(self, name, suffix='axxx'):
        full_name = self._qualify(name)
        return self._registry.get_counter(full_name, suffix)

    def get_histogram(self, name, buckets=None):
        full_name = self._qualify(name)
        return self._registry.get_histogram(full_name, buckets=buckets)

    def get_absolute_histogram(self, name, buckets=None):
        full_name = self._qualify(name)
        return self._registry.get_absolute_histogram(full_name, buckets=buckets)

    def get_gauge(self, name, suffix='axxx'):  # type: (str, str) -> Gauge
        full_name = self._qualify(name)
        return self._registry.get_gauge(full_name, suffix)

    def path(self, *name):
        """Return a deeper view whose path appends the '-'-joined ``name`` parts."""
        sub_path = self._qualify('-'.join(name))
        return PathRegistry(self._registry, sub_path)


# Module-level default registry; MetricsExt records into it when no registry is passed.
ROOT_REGISTRY = Registry()


class MetricsExt(object):
    """
    Flask extension that exports a Registry's metrics over HTTP and,
    optionally, records request timings and response status counters.

    Metrics are recorded under the registry path ``services-http``:
      * timings: one histogram per endpoint plus ``total``;
      * statuses: under ``services-http-statuses-<group>``, one counter per
        endpoint (``invalid`` when the request has none) plus ``total``.
    """

    def __init__(self, flask_app=None, registry=None, performance_enabled=True, status_code_enabled=True,
                 export_current_time=True):
        """
        :param flask_app: if given, ``init_flask_app`` is called with it immediately.
        :param registry: Registry to record into; defaults to ROOT_REGISTRY.
        :param performance_enabled: time requests with per-endpoint and total histograms.
        :param status_code_enabled: count responses per status group and endpoint.
        :param export_current_time: append ``['current-timestamp_axxx', <unix time>]``
            to the JSON stats.
        """
        self.app = None
        self._performance_enabled = performance_enabled
        self._status_code_enabled = status_code_enabled
        self._registry = registry or ROOT_REGISTRY
        self._http_service_registry = self._registry.path('services', 'http')
        # One sub-registry per status group; 404 is split out from other 4xx codes.
        self._http_statuses_registry = {
            metric_name: self._http_service_registry.path('statuses', metric_name)
            for metric_name in ('1xx', '2xx', '3xx', '404', '4xx', '5xx', 'xxx')
        }
        self._export_current_time = export_current_time

        if flask_app is not None:
            self.init_flask_app(flask_app)

    def init_flask_app(self, app):
        """Install the enabled request hooks on ``app`` and register the stats routes."""
        self.app = app

        if self._performance_enabled:
            app.before_request(self._request_performance_init)
            app.teardown_request(self._request_performance_close)

        if self._status_code_enabled:
            app.after_request(self._http_status_counter)

        self.app.add_url_rule('/yasm_stats/', view_func=self.render_yasm_stats)
        self.app.add_url_rule('/yasm_stats_h/', view_func=self.render_yasm_stats_human_readable)

    def render_yasm_stats(self):
        """
        Returns a list with stats for yasm format.

        The response body is a JSON list of ``[name, value]`` pairs from the
        registry, optionally followed by ``['current-timestamp_axxx', <unix time>]``.
        """
        # NOTE(review): gevent_idle_iter is a project helper; presumably it
        # yields to other greenlets every `idle_period` items — confirm.
        stats_result = list(gevent_idle_iter(self._registry.items(), idle_period=50))
        if self._export_current_time:
            stats_result.append([
                'current-timestamp_axxx', int(time.time())
            ])
        # Text, not bytes: a bytes header value is not portable across
        # Werkzeug versions on Python 3 (it can be str()-ed into
        # "b'application/json'"). On Python 2 this literal is identical.
        return flask.Response(response=json.dumps(stats_result), content_type='application/json')

    def render_yasm_stats_human_readable(self):
        """Return all registry metrics, sorted by name, as a numbered HTML table."""
        line_tpl = "<tr><td>{}</td><td>{}</td><td>{}</td></tr>"
        stats = sorted(gevent_idle_iter(self._registry.items(), idle_period=50))
        # join instead of repeated += to avoid quadratic string building.
        rows = [line_tpl.format(i, stat_key, value) for i, (stat_key, value) in enumerate(stats)]
        return "<table>" + "".join(rows) + "</table>"

    def _request_performance_init(self):
        """before_request hook: start the per-endpoint and total request timers."""
        # NOTE(review): requests without a matched endpoint (endpoint is None)
        # are not timed, although _http_status_counter counts them as
        # 'invalid'. The `or 'invalid'` fallback only applies to an empty
        # endpoint string. Confirm this asymmetry is intended.
        if flask.request.endpoint is not None:
            target = flask.request.endpoint or 'invalid'
            flask.g._url_perf_timer = self._http_service_registry.get_histogram(target).timer()
            flask.g._total_perf_timer = self._http_service_registry.get_histogram('total').timer()

    @staticmethod
    def _request_performance_close(*args, **kwargs):
        """teardown_request hook: stop whichever request timers were started."""
        for attr in ('_url_perf_timer', '_total_perf_timer'):
            timer = getattr(flask.g, attr, None)
            if timer is not None:
                # Consume the timer. flask.g lives on the app context, which
                # can outlive a single request (e.g. a manually pushed
                # context), and a later untimed request must not stop it again.
                delattr(flask.g, attr)
                timer.stop()

    def _http_status_counter(self, response):
        """after_request hook: count ``response`` by status group and endpoint; returns it unchanged."""
        status_code = response.status_code
        status_class = status_code // 100

        if status_code == 404:
            metric_name = '404'
        elif 1 <= status_class <= 5:
            metric_name = '{}xx'.format(status_class)
        else:
            metric_name = 'xxx'

        registry = self._http_statuses_registry[metric_name]
        target = flask.request.endpoint or 'invalid'

        registry.get_counter(target).inc()
        registry.get_counter('total').inc()

        return response
