import time
import re
import weakref
from contextlib import contextmanager
from itertools import chain

import six
from six.moves.urllib.parse import urlparse
import flask
import yaml
import ujson
import requests
from urllib3.util import Timeout as TimeoutSauce
from sepelib.core import config
from sepelib.flask.auth.util import login_exempt
from sepelib.flask.h import prep_response
from infra.swatlib.rpc import parse_request
from infra.swatlib import gutils


# Optional dependency: the requests/urllib3 modules shipped under ``yt.packages``.
# YpInstrumentedSession below is only defined when both imports succeed.
try:
    import yt.packages.requests as yt_requests
    from yt.packages.urllib3.util import Timeout as YtTimeoutSauce
except ImportError:
    yt_requests_imported = False
else:
    yt_requests_imported = True

# Characters that Registry.prepare_name replaces with '-' in metric names.
PREPARE_SIGNAL_NAME = re.compile(r'[<>_/\.:]')


class Counter(object):
    """
    Accumulating counter exported as ``['<name>_<suffix>', value]``.
    """
    # Suffix appended to the exported metric name (presumably the yasm
    # aggregation type). NOTE(review): 'dmmm' (delta, sum, sum, sum) used to be
    # documented here, which does not match 'summ' -- confirm the intended
    # aggregation.
    DEFAULT_SUFFIX = 'summ'

    def __init__(self, name, suffix=DEFAULT_SUFFIX, v=0):
        """
        :param name: metric name.
        :param suffix: suffix appended to the exported name.
        :param v: initial value.
        """
        self.name = name
        self._fmt = '{}_{}'.format(name, suffix)
        self._v = v

    def inc(self, d=1):
        """Add ``d`` (default 1) to the value."""
        self._v += d

    def clear(self):
        """Reset the value to 0."""
        self._v = 0

    def get(self):
        """Return the current value."""
        return self._v

    def fmt(self):
        """Return ``[exported_name, value]``."""
        return [self._fmt, self._v]

    @contextmanager
    def count_errors(self, error_types=(Exception,)):
        """
        Context manager that increments the counter by one when the wrapped
        block raises one of ``error_types``. The exception is always re-raised.
        """
        try:
            yield
        except error_types:
            self.inc()
            raise

    def __str__(self):
        return str(self.fmt())


class Gauge(object):
    """
    Settable value exported as ``['<name>_<suffix>', value]``.

    Usable as a context manager: the value is incremented on enter and
    decremented on exit (exceptions are not suppressed).
    """
    DEFAULT_SUFFIX = 'axxx'  # Absolute, Max, Max, Max

    def __init__(self, name, suffix=DEFAULT_SUFFIX, v=0):
        """
        :param name: metric name.
        :param suffix: suffix appended to the exported name; defaults to
            :attr:`DEFAULT_SUFFIX`, the same default ``Registry.get_gauge`` uses
            (previously this argument was mandatory, unlike ``Counter``).
        :param v: initial value.
        """
        self.name = name
        self._v = v
        self._fmt = '{}_{}'.format(name, suffix)

    def clear(self):
        """Reset the value to 0."""
        self._v = 0

    def inc(self, d=1):
        """Add ``d`` (default 1) to the value."""
        self._v += d

    def dec(self, d=1):
        """Subtract ``d`` (default 1) from the value."""
        self._v -= d

    def set(self, v):
        """Replace the value with ``v``."""
        self._v = v

    def get(self):
        """Return the current value."""
        return self._v

    def fmt(self):
        """Return ``[exported_name, value]``."""
        return [self._fmt, self._v]

    def timer(self):
        """Return a GaugeTimer that sets this gauge to a measured duration."""
        return GaugeTimer(self)

    def __enter__(self):
        self.inc()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.dec()


class Timer(object):
    """
    Measures durations with ``time.time()`` and feeds them into a histogram.

    The clock starts at construction. It is restarted by :meth:`start` or by
    entering the object as a context manager; leaving the context records the
    elapsed seconds via ``hgram.observe``.
    """
    __slots__ = ['hgram', '_start']

    def __init__(self, hgram):
        self.hgram = hgram
        self.start()

    def start(self):
        """(Re)start the clock."""
        self._start = time.time()

    def stop(self):
        """Record the seconds elapsed since the last start into the histogram."""
        elapsed = time.time() - self._start
        self.hgram.observe(elapsed)

    def clear(self):
        """Reset the underlying histogram (the clock is left untouched)."""
        self.hgram.clear()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()


class GaugeTimer(object):
    """
    Measures a duration with ``time.time()`` and stores it in a gauge.

    The clock starts at construction. It is restarted by :meth:`start` or by
    entering the object as a context manager; leaving the context sets the
    gauge to the elapsed seconds via ``gauge.set``.
    """
    __slots__ = ['gauge', '_start']

    def __init__(self, gauge):
        self.gauge = gauge
        self.start()

    def start(self):
        """(Re)start the clock."""
        self._start = time.time()

    def stop(self):
        """Set the gauge to the seconds elapsed since the last start."""
        elapsed = time.time() - self._start
        self.gauge.set(elapsed)

    def clear(self):
        """Reset the underlying gauge (the clock is left untouched)."""
        self.gauge.clear()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()


class Histogram(object):
    """
    Simple histogram counting observed values.

    Bucket bounds are lower bounds. An observation is counted in the bucket
    with the largest bound that is ``<=`` the observed value. Observations
    below the smallest bound are dropped. A ``0`` bound is added
    automatically when every configured bound is positive.
    """
    DEFAULT_BUCKETS_SEC = (
        .01, .025, .05, .1, .25, .5, .75, 1.0, 1.5, 2.0,
        2.5, 3.0, 3.5, 4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0,
        7.5, 8.0, 8.5, 9.0, 9.5, 10.0, 15.0, 20.0, 30.0, 50.0, 100.0,
        200.0, 300.0, 400.0, 500.0, 1000.0
    )

    def __init__(self, name, buckets=None):
        """
        :param name: metric name; exported as ``<name>_hgram``.
        :param buckets: iterable of bucket lower bounds in ascending order.
            Falls back to :attr:`DEFAULT_BUCKETS_SEC` when ``None`` or empty.
        :raises ValueError: if more than 50 bounds are given or they are not
            in sorted order.
        """
        self.name = name + '_hgram'
        if not buckets:
            buckets = Histogram.DEFAULT_BUCKETS_SEC
        # Materialize first so any iterable (e.g. a generator) can be
        # length-checked and compared.
        buckets = [float(b) for b in buckets]
        if len(buckets) > 50:
            raise ValueError('Too many buckets: {}'.format(len(buckets)))
        if buckets != sorted(buckets):
            # This is probably an error on the part of the user,
            # so raise rather than sorting for them.
            raise ValueError('Buckets not in sorted order: {}'.format(buckets))
        # Stored largest-first so observe() can stop at the first bound <= amount.
        self._buckets = [[b, 0] for b in reversed(buckets)]
        # Add a 0 lower bound only when all bounds are positive. Checking
        # `!= 0` would also append a 0 after negative bounds, breaking the
        # descending order and duplicating the 0 bound.
        if not self._buckets or self._buckets[-1][0] > 0:
            self._buckets.append([0, 0])

    def clear(self):
        """Reset every bucket count to 0, keeping the bounds."""
        self._buckets = [[b[0], 0] for b in self._buckets]

    def timer(self):
        """Return a Timer that records durations into this histogram."""
        return Timer(self)

    def observe(self, amount):
        """Count ``amount`` in the bucket with the largest bound <= ``amount``."""
        for b in self._buckets:
            if amount >= b[0]:
                b[1] += 1
                break

    def fmt(self):
        """
        Return ``[name, [[bound, count], ...]]`` with bounds ascending.

        The pairs are copies, so callers cannot mutate the internal counters.
        """
        return [self.name, [[bound, count] for bound, count in reversed(self._buckets)]]


class Proxy(object):
    """
    Metric whose value is produced by calling ``getter`` every time it is
    formatted; exported as ``['<name>_<suffix>', getter()]``.
    """

    def __init__(self, name, suffix, getter):
        self._fmt = '{}_{}'.format(name, suffix)
        self._getter = getter

    def fmt(self):
        """Call the getter and return ``[exported_name, value]``."""
        return [self._fmt, self._getter()]


class Registry(object):
    """
    Container of named metrics whose formatted values are produced by
    :meth:`items`.

    Counters, gauges and proxies are keyed by ``(prepared_name, suffix)``;
    histograms by prepared name.
    """

    def __init__(self):
        self._counters = {}
        self._gauges = {}
        self._histograms = {}
        # Proxies are held weakly: a proxy disappears from the registry once
        # nothing else references it.
        self._proxies = weakref.WeakValueDictionary()

    def clear(self):
        """Reset counters, gauges and histograms to zero; drop all proxies."""
        for c in six.itervalues(self._counters):
            c.clear()
        for g in six.itervalues(self._gauges):
            g.clear()
        for h in six.itervalues(self._histograms):
            h.clear()
        self._proxies.clear()

    def prepare_name(self, name, tags=None):
        """
        Build the registry name for ``name``.

        Characters matched by ``PREPARE_SIGNAL_NAME`` are replaced with ``-``.
        If ``tags`` (a mapping) is given, ``key=value;`` for each tag is
        prepended, in the mapping's iteration order. Tag keys and values are
        not sanitized.
        """
        prepared_name = PREPARE_SIGNAL_NAME.sub('-', name)
        if tags:
            # six.iteritems works on Python 2 and 3 (dict.iteritems is Py2-only);
            # '%s' also accepts non-string values and keeps unicode on Py2.
            tags_string = ''.join('%s=%s;' % (key, value) for key, value in six.iteritems(tags))
            prepared_name = tags_string + prepared_name
        return prepared_name

    def get_counter(self, name, suffix=Counter.DEFAULT_SUFFIX, tags=None):
        """Return the counter for ``(name, suffix, tags)``, creating it if needed."""
        name = self.prepare_name(name, tags)
        key = (name, suffix)
        c = self._counters.get(key)
        if c is None:
            c = Counter(name, suffix)
            self._counters[key] = c
        return c

    def get_gauge(self, name, suffix=Gauge.DEFAULT_SUFFIX, tags=None):
        """Return the gauge for ``(name, suffix, tags)``, creating it if needed."""
        name = self.prepare_name(name, tags)
        key = (name, suffix)
        g = self._gauges.get(key)
        if g is None:
            g = Gauge(name, suffix)
            self._gauges[key] = g
        return g

    def get_summable_gauge(self, name, tags=None):
        """Return a gauge with the ``ammx`` suffix."""
        return self.get_gauge(name, 'ammx', tags)

    def get_histogram(self, name, buckets=None, tags=None):
        """
        Return the histogram for ``(name, tags)``, creating it if needed.

        ``buckets`` is only used on creation; an existing histogram keeps its
        original buckets.
        """
        name = self.prepare_name(name, tags)
        h = self._histograms.get(name)
        if h is None:
            h = Histogram(name, buckets)
            self._histograms[name] = h
        return h

    def get_proxy(self, name, suffix, getter, tags=None):
        """
        Return the proxy for ``(name, suffix, tags)``, creating it with
        ``getter`` if needed.

        Keyed by suffix as well (like counters and gauges), so two proxies that
        differ only in suffix no longer collide. The registry keeps only a weak
        reference: the caller must hold on to the returned proxy for it to
        stay registered.
        """
        name = self.prepare_name(name, tags)
        key = (name, suffix)
        p = self._proxies.get(key)
        if p is None:
            p = Proxy(name, suffix, getter)
            self._proxies[key] = p
        return p

    def items(self, filter_keys=None):
        """
        Yield ``metric.fmt()`` for histograms, counters, gauges and proxies.

        :param filter_keys: if given, only metrics whose exported name contains
            this substring are yielded.
        Metrics whose ``fmt()`` returns ``None`` are skipped.
        """
        # Snapshot each collection so metrics registered while this generator
        # is suspended cannot invalidate the iteration.
        metrics = (
            list(six.itervalues(self._histograms)),
            list(six.itervalues(self._counters)),
            list(six.itervalues(self._gauges)),
            list(six.itervalues(self._proxies)),
        )
        for metric in chain(*metrics):
            fmt = metric.fmt()
            if fmt is not None and (filter_keys is None or filter_keys in fmt[0]):
                yield fmt

    def path(self, *name):
        """Return a PathRegistry that prefixes names with ``'-'.join(name)``."""
        return PathRegistry(self, "-".join(name))


class PathRegistry(object):
    """
    View over a Registry that prefixes every metric name with a fixed path.

    Names are joined with ``-``: ``registry.path('a', 'b').get_counter('c')``
    creates the counter ``a-b-c`` in the underlying registry. Every getter
    accepts the same optional ``tags`` mapping as the Registry getters.
    """

    def __init__(self, registry, path):  # type: (Registry, str) -> None
        self._registry = registry
        self._path = path

    def _full_name(self, name):
        """Return ``name`` prefixed with this view's path."""
        return '{}-{}'.format(self._path, name)

    def get_counter(self, name, suffix=Counter.DEFAULT_SUFFIX, tags=None):
        return self._registry.get_counter(self._full_name(name), suffix, tags=tags)

    def get_gauge(self, name, suffix=Gauge.DEFAULT_SUFFIX, tags=None):  # type: (str, str, dict) -> Gauge
        return self._registry.get_gauge(self._full_name(name), suffix, tags=tags)

    def get_summable_gauge(self, name, tags=None):
        return self._registry.get_summable_gauge(self._full_name(name), tags=tags)

    def get_histogram(self, name, buckets=None, tags=None):
        return self._registry.get_histogram(self._full_name(name), buckets=buckets, tags=tags)

    def get_proxy(self, name, suffix, getter, tags=None):
        return self._registry.get_proxy(self._full_name(name), suffix, getter, tags=tags)

    def path(self, *name):
        """Return a nested view whose prefix is this path plus ``'-'.join(name)``."""
        return PathRegistry(self._registry, self._full_name('-'.join(name)))


class MetricsExt(object):
    """
    Flask extension that records request metrics into a Registry and serves
    them at ``yasm_stat_url``.

    Recognised ``cfg`` keys (all optional):

    * ``performance_enabled`` (default ``True``): per-endpoint request counters
      and timing histograms;
    * ``status_code_enabled`` (default ``True``): per-status response counters;
    * ``export_current_timestamp`` (default ``True``): append
      ``current-timestamp_axxx`` to the stats output;
    * ``yasm_stat_url`` (default ``'/yasm_stats/'``): URL of the stats endpoint.
    """

    def __init__(self, flask_app=None, cfg=None, registry=None):
        self._config = cfg or {}
        self._performance_enabled = self._config.get('performance_enabled', True)
        self._status_code_enabled = self._config.get('status_code_enabled', True)
        self._export_current_timestamp = self._config.get('export_current_timestamp', True)
        self._yasm_stat_url = self._config.get('yasm_stat_url', '/yasm_stats/')

        self._registry = registry or ROOT_REGISTRY
        # All server-side HTTP metrics live under "services-http-".
        self._http_service_registry = self._registry.path('services', 'http')

        # One sub-registry per status bucket produced by _http_status_counter.
        self._http_statuses_registry = {
            metric_name: self._http_service_registry.path('statuses', metric_name)
            for metric_name in ('1xx', '2xx', '3xx', '404', '429', '4xx', '5xx', 'xxx')
        }

        self.app = None
        # Endpoint names whose view function has a truthy `is_destructive` attribute.
        self._destructive_targets = set()
        if flask_app is not None:
            self.init_flask_app(flask_app)

    def init_flask_app(self, app):
        """
        Attach to ``app``: register request hooks and the stats endpoint.

        Only view functions already registered on ``app`` at this point are
        scanned for ``is_destructive``.
        """
        self.app = app
        for target, handler in six.iteritems(self.app.view_functions):
            if getattr(handler, 'is_destructive', False):
                self._destructive_targets.add(target)

        if self._performance_enabled:
            app.before_request(self._request_performance_init)
            app.teardown_request(self._request_performance_close)

        if self._status_code_enabled:
            app.after_request(self._http_status_counter)

        self.app.add_url_rule(self._yasm_stat_url, view_func=self.render_yasm_stats)

    @login_exempt
    def render_yasm_stats(self):
        """
        Returns a list with stats for yasm format.

        Query parameters:
        * ``pretty``: return a YAML mapping ``{name: value}`` as text instead of JSON;
        * ``filter``: only include metrics whose name contains this substring.
        """
        pretty = 'pretty' in flask.request.args
        filter_keys = flask.request.args.get('filter', None)
        stats_result = list(gutils.idle_iter(self._registry.items(filter_keys), idle_period=50))

        if self._export_current_timestamp:
            stats_result.append([
                'current-timestamp_axxx', int(time.time())
            ])

        if pretty:
            return prep_response(yaml.safe_dump({s[0]: s[1] for s in stats_result}), fmt='txt')

        # Native str: on Python 3 a b'...' literal is bytes, and header values
        # are expected to be str (newer Werkzeug would stringify the bytes as
        # "b'application/json'"). On Python 2 the two literals are identical.
        return flask.Response(response=ujson.dumps(stats_result), content_type='application/json')

    def _request_performance_init(self):
        """
        ``before_request`` hook: count the request, record client wait time and
        start the per-endpoint and total timers (stored on ``flask.g``).

        Requests that did not match an endpoint are not measured.
        """
        if flask.request.endpoint is not None:
            target = flask.request.endpoint or 'invalid'
            self._http_service_registry.get_counter('{}-{}'.format(target, 'count')).inc()
            if target in self._destructive_targets:
                self._http_service_registry.get_counter('destructive-count').inc()
            # NOTE(review): assumes the parsed value is a unix timestamp in
            # seconds, comparable to time.time() -- confirm in parse_request.
            sent_at = parse_request.parse_x_start_time_header(flask.request)
            if sent_at is not None:
                wait = time.time() - sent_at
                self._http_service_registry.get_histogram('wait-timer').observe(wait)

            flask.g._url_perf_timer = self._http_service_registry.get_histogram('{}-timer'.format(target)).timer()
            flask.g._total_perf_timer = self._http_service_registry.get_histogram('total-timer').timer()

    @staticmethod
    def _request_performance_close(*args, **kwargs):
        """``teardown_request`` hook: stop whichever timers were started for this request."""
        t = getattr(flask.g, '_url_perf_timer', None)
        if t is not None:
            t.stop()
        t = getattr(flask.g, '_total_perf_timer', None)
        if t is not None:
            t.stop()

    def _http_status_counter(self, response):
        """
        ``after_request`` hook: count the response by status bucket
        (404, 429, 1xx-5xx, or xxx for anything else), per endpoint and in
        total. Returns ``response`` unchanged.
        """
        status_code = response.status_code
        status_class = status_code // 100

        if status_code == 404:
            metric_name = '404'
        elif status_code == 429:
            metric_name = '429'
        elif 1 <= status_class <= 5:
            metric_name = '{}xx'.format(status_class)
        else:
            metric_name = 'xxx'

        registry = self._http_statuses_registry[metric_name]
        target = flask.request.endpoint or 'invalid'

        registry.get_counter('{}-count'.format(target)).inc()
        registry.get_counter('total-count').inc()

        return response


def _get_status_type(status_code):
    if status_code // 100 == 2:
        return '2xx'
    elif status_code // 100 == 3:
        return '3xx'
    elif status_code == 404:
        return '404'
    elif status_code == 429:
        return '429'
    elif status_code // 100 == 4:
        return '4xx'
    elif status_code // 100 == 5:
        return '5xx'
    else:
        return 'other'


class InstrumentedSession(requests.Session):
    """
    requests.Session that records client metrics under ``clients-<client_name>-``.

    Recorded metrics: ``<METHOD>-timer``, ``general-timer`` and
    ``<METHOD>-elapsed`` histograms; ``total_success``, ``total_error``,
    ``total_exception``, ``connection_timeout_exception`` and
    ``error_<status type>`` counters.

    :param client_name: name used in the metric path.
    :param registry: registry to record into; ROOT_REGISTRY when omitted.
    :param connection_timeout: connect timeout combined with each request's
        ``timeout``; when ``None``, read from config key
        ``instrumented_session.default_connection_timeout``.
    :param http_codes_to_exclude_from_total_error: 4xx/5xx codes that do not
        increment ``total_error`` (they still increment ``error_<type>``).
    """

    def __init__(self, client_name, registry=None, connection_timeout=None,
                 http_codes_to_exclude_from_total_error=frozenset()):
        # Fall back to the shared registry only when none was given. The
        # reversed `ROOT_REGISTRY or registry` always picked ROOT_REGISTRY,
        # because a Registry instance is always truthy.
        registry = registry or ROOT_REGISTRY
        self._registry = registry.path('clients', client_name)
        if connection_timeout is None:
            self._connection_timeout = config.get_value('instrumented_session.default_connection_timeout', None)
        else:
            self._connection_timeout = connection_timeout
        self._http_codes_to_exclude_from_total_error = http_codes_to_exclude_from_total_error
        self._general_time_histogram = self._registry.get_histogram('general-timer')
        self._total_error = self._registry.get_counter('total_error')
        self._total_success = self._registry.get_counter('total_success')
        self._total_exception = self._registry.get_counter('total_exception')
        self._connection_timeout_exception = self._registry.get_counter('connection_timeout_exception')
        super(InstrumentedSession, self).__init__()

    def request(self, method, url, **kwargs):
        """
        Send a request via requests.Session.request and record metrics.

        If a default connect timeout is configured and ``timeout`` is not
        already a tuple or urllib3 Timeout, it becomes
        ``(connection_timeout, timeout)``. Responses with status 400-599 count
        as errors; every other status counts as success. Exceptions are
        counted and re-raised.
        """
        method_timer = self._registry.get_histogram('{}-timer'.format(method.upper())).timer()
        general_timer = self._general_time_histogram.timer()
        elapsed_time_histogram = self._registry.get_histogram('{}-elapsed'.format(method.upper()))
        timeout = kwargs.pop('timeout', None)
        if self._connection_timeout is not None and not isinstance(timeout, (tuple, TimeoutSauce)):
            timeout = (self._connection_timeout, timeout)
        with method_timer, general_timer:
            try:
                response = super(InstrumentedSession, self).request(method, url, timeout=timeout, **kwargs)
            except requests.ConnectTimeout:
                self._connection_timeout_exception.inc()
                self._total_exception.inc()
                raise
            except Exception:
                self._total_exception.inc()
                raise
            else:
                elapsed_time_histogram.observe(response.elapsed.total_seconds())
                if 400 <= response.status_code < 600:
                    if response.status_code not in self._http_codes_to_exclude_from_total_error:
                        self._total_error.inc()
                    st = _get_status_type(response.status_code)
                    self._registry.get_counter('error_{}'.format(st)).inc()
                else:
                    self._total_success.inc()
            return response


if yt_requests_imported:
    class YpInstrumentedSession(yt_requests.Session):
        """
        Session on top of ``yt.packages.requests`` that records metrics under
        ``clients-yp-``.

        Per request, ``<path>-timer`` and ``<path>-elapsed-timer`` histograms
        are used, where ``<path>`` is the URL path segments joined with ``-``.
        Counters are the same as InstrumentedSession's.

        :param registry: registry to record into; ROOT_REGISTRY when omitted.
        :param connection_timeout: connect timeout combined with each request's
            ``timeout``; when ``None``, read from config key
            ``instrumented_session.default_connection_timeout``.
        :param http_codes_to_exclude_from_total_error: 4xx/5xx codes that do
            not increment ``total_error`` (they still increment
            ``error_<type>``). Defaults to none, which matches the previous
            behaviour.
        """

        def __init__(self, registry=None, connection_timeout=None,
                     http_codes_to_exclude_from_total_error=frozenset()):
            # Fall back to the shared registry only when none was given. The
            # reversed `ROOT_REGISTRY or registry` always picked ROOT_REGISTRY,
            # because a Registry instance is always truthy.
            registry = registry or ROOT_REGISTRY
            self._registry = registry.path('clients', 'yp')
            if connection_timeout is None:
                self._connection_timeout = config.get_value('instrumented_session.default_connection_timeout', None)
            else:
                self._connection_timeout = connection_timeout
            self._http_codes_to_exclude_from_total_error = http_codes_to_exclude_from_total_error
            self._general_time_histogram = self._registry.get_histogram('general-timer')
            self._total_error = self._registry.get_counter('total_error')
            self._total_success = self._registry.get_counter('total_success')
            self._total_exception = self._registry.get_counter('total_exception')
            self._connection_timeout_exception = self._registry.get_counter('connection_timeout_exception')
            super(YpInstrumentedSession, self).__init__()

        def request(self, method, url, **kwargs):
            """
            Send a request and record per-path metrics.

            Timeout handling and success/error accounting mirror
            InstrumentedSession.request.
            """
            parsed_url = urlparse(url)
            path_parts = '-'.join([p for p in parsed_url.path.split('/') if p])
            path_timer = self._registry.get_histogram('{}-timer'.format(path_parts)).timer()
            elapsed_time_histogram = self._registry.get_histogram('{}-elapsed-timer'.format(path_parts))
            general_timer = self._general_time_histogram.timer()
            timeout = kwargs.pop('timeout', None)
            if self._connection_timeout is not None and not isinstance(timeout, (tuple, YtTimeoutSauce)):
                timeout = (self._connection_timeout, timeout)
            with general_timer, path_timer:
                try:
                    response = super(YpInstrumentedSession, self).request(method, url, timeout=timeout, **kwargs)
                except yt_requests.ConnectTimeout:
                    self._connection_timeout_exception.inc()
                    self._total_exception.inc()
                    raise
                except Exception:
                    self._total_exception.inc()
                    raise
                else:
                    elapsed_time_histogram.observe(response.elapsed.total_seconds())
                    if 400 <= response.status_code < 600:
                        if response.status_code not in self._http_codes_to_exclude_from_total_error:
                            self._total_error.inc()
                        st = _get_status_type(response.status_code)
                        self._registry.get_counter('error_{}'.format(st)).inc()
                    else:
                        self._total_success.inc()
                return response
else:
    # The yt-bundled requests/urllib3 could not be imported.
    YpInstrumentedSession = None


# Module-level shared Registry. MetricsExt, InstrumentedSession and
# YpInstrumentedSession reference it when building their registries. Those
# classes look the name up only when instantiated, so defining it at the
# bottom of the module is fine.
ROOT_REGISTRY = Registry()
