from sepelib.metrics.registry import metrics_inventory
from requests import Session


class InstrumentedSession(Session):
    """
    ``requests.Session`` subclass that instruments every request with metrics.

    Metrics are recorded in the registry returned by
    ``metrics_inventory.get_metrics(name)``:
        * {general}: timer for all requests
        * {total_error}: counter for error responses (status 400-599)
        * {total_success}: counter for success responses (not 4xx or 5xx)
        * {total_exception}: counter for exceptions raised while sending
          the request (e.g. timeouts, connection refused)
        * {method}: timer for particular HTTP methods (e.g. GET)
        * {method}_{status_type}: counter for method/status-class pairs
          (e.g. GET_5xx); see ``_get_status_type`` for the possible classes
    Should be used as drop-in replacement for ``requests.Session``.
    """

    @staticmethod
    def _get_status_type(status_code):
        """
        Map an HTTP status code to the label used in per-method metrics.

        Returns '2xx', '3xx', '404', '4xx', '5xx', or 'other' (1xx codes and
        anything outside 100-599). 404 is reported separately from the
        remaining 4xx codes.
        """
        status_class = status_code // 100
        if status_class == 2:
            return '2xx'
        if status_class == 3:
            return '3xx'
        if status_code == 404:
            return '404'
        if status_class == 4:
            return '4xx'
        if status_class == 5:
            return '5xx'
        return 'other'

    def __init__(self, name, trust_env=False):
        """
        :param name: name of the metrics registry to record into; also
            exposed via the ``name`` property.
        :param trust_env: value for ``requests.Session.trust_env``.
            With ``trust_env`` enabled every request will try to load
            settings for proxy and auth from .netrc and env variables.
            We don't need this by default, but leave it as an option.
        """
        super(InstrumentedSession, self).__init__()
        # Must be assigned *after* Session.__init__: the parent constructor
        # unconditionally sets ``self.trust_env = True``, which would
        # otherwise overwrite the caller's choice.
        self.trust_env = trust_env
        self._name = name
        self._metrics = metrics_inventory.get_metrics(name)
        self._general_timer = self._metrics.timer('general')
        self._total_error = self._metrics.counter('total_error')
        self._total_success = self._metrics.counter('total_success')
        self._total_exception = self._metrics.counter('total_exception')

    @property
    def name(self):
        """Name of the metrics registry this session records into."""
        return self._name

    def request(self, method, url, **kwargs):
        """
        Send a request via ``requests.Session.request`` and record metrics.

        The per-method timer and the general timer both cover the whole call,
        including failed calls. Exceptions increment ``total_exception`` and
        are re-raised unchanged. Responses increment either ``total_error``
        (status 400-599) or ``total_success`` (any other status), plus the
        ``{METHOD}_{status_type}`` counter.

        :returns: the response returned by ``requests.Session.request``.
        """
        method_name = method.upper()
        method_time_ctx = self._metrics.timer(method_name).time()
        general_time_ctx = self._general_timer.time()
        try:
            response = super(InstrumentedSession, self).request(method, url, **kwargs)
        except Exception:
            self._total_exception.inc()
            raise
        else:
            status_code = response.status_code
            if 400 <= status_code < 600:
                self._total_error.inc()
            else:
                self._total_success.inc()
            # Per method/status-class counter, e.g. GET_5xx. This presumes the
            # registry hands back the same counter for a repeated name, as the
            # per-method timer lookup above already does -- TODO confirm.
            self._metrics.counter(
                '{}_{}'.format(method_name, self._get_status_type(status_code))
            ).inc()
        finally:
            general_time_ctx.stop()
            method_time_ctx.stop()
        return response

    def get_metrics(self):
        """Return ``dump_metrics()`` of this session's metrics registry."""
        return self._metrics.dump_metrics()
