"""
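HTTP server boilerplate for Flask applications: gevent WSGI servers,
access logging and a concurrent-request limiting middleware.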
"""
from __future__ import absolute_import

import copy
import datetime
import logging
import logging.handlers

import six
import yaml
import flask
from gevent import pywsgi
from gevent import socket


if six.PY2:
    from werkzeug.contrib.fixers import ProxyFix
else:
    from werkzeug.middleware.proxy_fix import ProxyFix

from sepelib.core import config
from sepelib.util.log import create_handler_from_config
from sepelib.flask.auth.util import login_exempt
from sepelib.flask.h import prep_response

from infra.swatlib.util.fs import set_close_exec
from infra.swatlib import climit
from infra.swatlib import metrics


log = logging.getLogger(name='flask')  # logger shared by the server classes below


def setup_access_log_stream(cfg):
    """
    Abuse handy log handler and create stream like object for wsgi to write access log.

    The handler is built from ``cfg`` by ``create_handler_from_config``, accepts
    DEBUG and above and formats records as the bare message, so every ``write()``
    call ends up as exactly one access log line.

    :param cfg: handler configuration passed to ``create_handler_from_config``
    :returns: file-like object (``write``/``flush``/``isatty``) usable as the
              ``log`` argument of ``gevent.pywsgi.WSGIServer``
    """

    class StreamToHandler(object):
        """
        Fake file-like stream object that redirects writes to a log handler instance.
        """

        def __init__(self, handler):
            self.handler = handler
            # Not used by this class; kept so existing attribute access keeps working.
            self.linebuf = ''

        def isatty(self):
            return False

        def write(self, buf):
            # remove line endings, handler will add them anyway
            buf = buf.rstrip()
            record = logging.makeLogRecord({'msg': buf,
                                            'levelno': logging.DEBUG,
                                            # keep levelname consistent with levelno
                                            # for filters/formatters that look at it
                                            'levelname': logging.getLevelName(logging.DEBUG)})
            self.handler.handle(record)

        def flush(self):
            # File-like consumers may flush the stream; forward it so that
            # buffering handlers actually push pending records out.
            self.handler.flush()

    handler = create_handler_from_config(cfg)
    formatter = logging.Formatter("%(message)s")
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(formatter)

    stream = StreamToHandler(handler)
    return stream


class WSGIHandler(pywsgi.WSGIHandler):
    """
    gevent WSGI handler with a customized access log line.
    """

    def format_request(self):
        """
        Overridden to try to log client address from headers.

        Produces::

            <client> - - [<time>] "<request line>" <status> <length> <duration> [<header> ...]

        The client is the first ``X-Forwarded-For`` entry when that header is
        present and its first entry is non-empty; otherwise it is the peer
        address of the connection. Values of the request headers named in
        ``self.application.logged_headers`` (if that attribute exists) are
        appended, with ``-`` standing in for missing ones.
        """
        now = datetime.datetime.now().replace(microsecond=0)
        length = self.response_length or '-'
        if self.time_finish:
            delta = '%.6f' % (self.time_finish - self.time_start)
        else:
            delta = '-'

        client_address = None
        if self.environ is not None:
            forwarded_for = self.environ.get('HTTP_X_FORWARDED_FOR')
            if forwarded_for:
                # Good case: take the first (left-most) X-Forwarded-For entry.
                client_address = forwarded_for.split(',')[0].strip()
        if not client_address:
            # No usable X-Forwarded-For (absent or blank): log the socket peer
            # instead of an uninformative '-'.
            peer = self.client_address
            client_address = peer[0] if isinstance(peer, tuple) else peer

        format_string = '%s - - [%s] "%s" %s %s %s'
        logged_headers = getattr(self.application, 'logged_headers', ())
        if logged_headers:
            format_string += ' %s' % (' '.join('%s' for _ in logged_headers),)
            extra_args = tuple((self.headers is not None and self.headers.get(header)) or '-'
                               for header in logged_headers)
        else:
            extra_args = ()

        return format_string % ((client_address or '-',
                                 now,
                                 getattr(self, 'requestline', ''),
                                 (getattr(self, 'status', None) or '000').split()[0],
                                 length,
                                 delta) + extra_args)


class CapacityLimiter(object):
    """
    WSGI middleware limiting the number of concurrently processed requests.

    When the limit is reached, requests get an immediate ``429 Too Many Requests``
    JSON response. Requests from local/node addresses (yasm-agent) and requests
    to ``exempted_paths`` are never limited.
    """
    TOO_MANY_REQUESTS_STATUS = '429 Too Many Requests'
    TOO_MANY_REQUESTS_HEADERS = [
        ('Content-Type', 'application/json')
    ]
    TOO_MANY_REQUESTS_BODY = (b'{'
                              b'"error": "TOO_MANY_REQUESTS", '
                              b'"message": "Global concurrent requests limit reached"'
                              b'}',)

    LOCAL_IPS = ('::1', '127.0.0.1')
    NODE_SRC_IP_SUFFIX = 'badc:ab1e'

    def __init__(self, app, max_in_flight=1000, headers=None, registry=None, exempted_paths=None):
        """
        :param app: wrapped WSGI application
        :param max_in_flight: capacity passed to ``climit.CLimit``
        :param headers: extra ``(name, value)`` headers appended to 429 responses
        :param registry: optional metrics registry; when given, 429 responses,
                         exempted requests and in-flight requests are tracked
        :param exempted_paths: paths that are never limited; leading and
                               trailing slashes are ignored
        """
        self.app = app
        self.l = climit.CLimit(max_in_flight)
        self.headers = self.TOO_MANY_REQUESTS_HEADERS[:]
        # PATH_INFO is compared with its slashes stripped, so strip the entries
        # the same way; otherwise an entry such as '/ping' could never match.
        self.exempted_paths = [path.strip('/') for path in exempted_paths or ()]
        if headers:
            self.headers.extend(headers)
        if registry is not None:
            self._too_many_requests_counter = registry.get_counter('too-many-requests')
            self._in_flight_gauge = registry.get_summable_gauge('in-flight-requests')
            self._exempted_requests_counter = registry.get_counter('in-flight-limit-exempted-requests')
        else:
            self._too_many_requests_counter = None
            self._in_flight_gauge = None
            self._exempted_requests_counter = None

    def _is_limit_applicable(self, environ):
        """
        Return False for requests that must bypass the limit, True otherwise.
        """
        # https://st.yandex-team.ru/SPI-23526
        # We *must not* respond 429 to yasm-agent even if server is overloaded!
        # Otherwise we won't see metrics from this instance.
        remote_addr = environ.get('REMOTE_ADDR')
        if remote_addr in self.LOCAL_IPS:
            # yasm-agent makes requests from ::1 in case of instances without network isolation
            return False
        if remote_addr and remote_addr.endswith(self.NODE_SRC_IP_SUFFIX):
            # yasm-agent makes requests from IP which has suffix BADC:AB1E if network isolation enabled
            return False
        path_info = environ.get('PATH_INFO')
        if path_info and path_info.strip('/') in self.exempted_paths:
            return False
        return True

    def __call__(self, environ, start_response):
        limit_applicable = self._is_limit_applicable(environ)
        if not limit_applicable and self._exempted_requests_counter is not None:
            self._exempted_requests_counter.inc()
        if limit_applicable and not self.l.add():
            if self._too_many_requests_counter is not None:
                self._too_many_requests_counter.inc()
            start_response(self.TOO_MANY_REQUESTS_STATUS, self.headers)
            return self.TOO_MANY_REQUESTS_BODY
        if self._in_flight_gauge is not None:
            self._in_flight_gauge.inc(1)
        try:
            # Flask can return an iterable which will call user code upon .next()
            # We do not want that (because otherwise) requests accounting will not work.
            # But there is a catch - streaming responses will be accumulated in memory.
            # To work around this we need to put limiting logic in WSGIHandler.process_result.
            # Leaving that for future improvements
            app_iter = self.app(environ, start_response)
            try:
                return list(app_iter)
            finally:
                # The server only ever sees the list returned here, so it cannot
                # call close() on the original iterable; PEP 3333 requires that
                # close() be called, so do it ourselves (while still counted).
                close = getattr(app_iter, 'close', None)
                if close is not None:
                    close()
        finally:
            if self._in_flight_gauge is not None:
                self._in_flight_gauge.inc(-1)
            if limit_applicable:
                self.l.done()


class WebServer(object):
    """
    An HTTP server for Flask applications.
    Implemented using gevent's WSGIServer.
    """
    _DEFAULT_BACKLOG = 1024

    @classmethod
    def _create_server_socket(cls, host, port, backlog=None):
        """
        Try to create ipv6 socket, falling back to ipv4.
        Start listening on it.

        :param host: address to bind; '0.0.0.0' is replaced by the wildcard ''
        :param port: port to bind
        :param backlog: listen() backlog, ``_DEFAULT_BACKLOG`` when None
        :returns: listening socket
        :raises socket.error: if the socket cannot be bound or listened on;
                              the socket is closed before the error propagates
        """
        if host == '0.0.0.0':
            host = ''
        try:
            listener = socket.socket(socket.AF_INET6)
        except EnvironmentError as e:
            log.warning("failed to create ipv6 socket: {0}".format(e.strerror))
            log.warning("falling back to ipv4 only")
            listener = socket.socket()
        try:
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # mark the socket fd as non-inheritable
            set_close_exec(listener.fileno())
            try:
                listener.bind((host, port))
            except socket.error:
                # it's useful to log what host and port we failed to bind
                log.error("failed to bind '{}' on port {}".format(host, port))
                raise
            listener.listen(backlog if backlog is not None else cls._DEFAULT_BACKLOG)
        except BaseException:
            # do not leak the descriptor when socket setup fails
            listener.close()
            raise
        return listener

    def __init__(self, cfg, app, version, logstream=None, certfile=None, metrics_registry=None):
        """
        Wrap ``app`` with the configured middlewares, bind the listening socket
        and register the service URLs.

        :type cfg: dict
        :param cfg: service config; the ``web`` section is used
        :type app: flask.Flask
        :type version: basestring
        :param version: value rendered by ``/version``
        :param logstream: file-like stream object for the access log; replaced by
                          a handler-backed stream when ``web.access_log`` is set
        :param certfile: optional certificate file passed to ``WSGIServer``
        :param metrics_registry: registry passed to ``metrics.MetricsExt``
        """
        self.cfg = cfg
        self.version = version

        # === setup http ===
        webcfg = cfg['web']
        http_cfg = webcfg['http']
        self.app = app
        num_proxies = webcfg.get('num_proxies', 0)
        if num_proxies:
            if six.PY2:
                # werkzeug.contrib.fixers.ProxyFix takes the proxy count directly
                proxy_fix_kwargs = {'num_proxies': num_proxies}
            else:
                # werkzeug.middleware.proxy_fix.ProxyFix dropped ``num_proxies``
                # in werkzeug 1.0; ``x_for`` is its replacement
                proxy_fix_kwargs = {'x_for': num_proxies}
            self.app.wsgi_app = ProxyFix(self.app.wsgi_app, **proxy_fix_kwargs)
        max_in_flight = http_cfg.get('max_in_flight', 0)
        if max_in_flight:
            # Wrapped after ProxyFix, i.e. outermost: the limiter inspects
            # REMOTE_ADDR before ProxyFix rewrites it from X-Forwarded-For.
            self.app.wsgi_app = CapacityLimiter(self.app.wsgi_app, max_in_flight=max_in_flight)
        host, port = http_cfg['host'], http_cfg['port']
        if webcfg.get('access_log'):
            logstream = setup_access_log_stream(webcfg['access_log'])
        listener = self._create_server_socket(host, port, backlog=http_cfg.get('backlog'))
        kwargs = {
            'application': self.app,
            'log': logstream,
            'handler_class': WSGIHandler,
        }
        if certfile:
            kwargs['certfile'] = certfile
        self.wsgi = pywsgi.WSGIServer(listener, **kwargs)

        # add some performance hooks
        metrics_ext = metrics.MetricsExt(registry=metrics_registry,
                                         cfg=webcfg.get('metrics', {}) or config.get_value('metrics', {}))
        metrics_ext.init_flask_app(self.app)

        # register custom urls
        self._register_urls()

    def run(self):
        """Serve requests in the current greenlet until stopped."""
        log.info("starting flask interface on: '{0[0]}':{0[1]}".format(self.wsgi.address))
        self.wsgi.serve_forever()

    def stop(self):
        """Stop the server, waiting at most 1 second for active requests."""
        log.info('stopping flask service')
        self.wsgi.stop(timeout=1)

    def start(self):
        """Start serving in the background and return immediately."""
        log.info("starting flask interface on: '{0[0]}':{0[1]}".format(self.wsgi.address))
        self.wsgi.start()

    def _register_urls(self):
        # functions part
        self.app.add_url_rule('/ping', view_func=self.ping)
        self.app.add_url_rule('/version', view_func=self.render_version)

    @classmethod
    @login_exempt
    def ping(cls):
        """Liveness probe: empty 200 response."""
        return flask.Response('', status=200)

    @login_exempt
    def render_version(self):
        """Render the service version as text."""
        return prep_response({'version': self.version}, fmt='txt')

    @login_exempt
    def render_config(self):
        """
        Render ``self.cfg`` as YAML text.

        Options listed in ``web.http.hidden_config_options`` that are present in
        the config are replaced by ``"HIDDEN"`` in a deep copy; ``self.cfg``
        itself is left untouched.

        NOTE(review): not registered by ``_register_urls`` — presumably wired up
        by callers/subclasses; confirm before relying on it.
        """
        output_config = self.cfg
        hidden_options = config.get_value("web.http.hidden_config_options", default=None, config=self.cfg)

        if hidden_options:
            undefined = object()
            output_config = copy.deepcopy(output_config)

            for hidden_option in hidden_options:
                if config.get_value(hidden_option, default=undefined, config=output_config) is not undefined:
                    config.set_value(hidden_option, "HIDDEN", config=output_config)

        obj = yaml.safe_dump(output_config, default_flow_style=False)
        return prep_response(obj, fmt='txt')


class LocalWebServer(object):
    """
    A small and simple web server that intends to be run in a separate system thread in
    gevent-patched environment. Can be useful to provide a responsive HTTP server
    (for serving /yasm_stats, /ping, etc) in a CPU-heavy process.

    It uses gevent-safe pywsgi.WSGIServer and does not use logging module to avoid
    errors that occur when different system threads with their own gevent hubs
    try to "use" each other's hub-local gevent objects like locks.

    Please keep it that way.
    """
    DEFAULT_BACKLOG = 128

    @classmethod
    def _create_server_socket(cls, host, port):
        """
        Try to create ipv6 socket, falling back to ipv4.
        Start listening on it.

        Deliberately logs nothing (see the class docstring).

        :returns: listening socket
        :raises socket.error: if the socket cannot be bound or listened on;
                              the socket is closed before the error propagates
        """
        if host == '0.0.0.0':
            host = ''
        try:
            listener = socket.socket(socket.AF_INET6)
        except EnvironmentError:
            listener = socket.socket()
        try:
            listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            # mark the socket fd as non-inheritable
            set_close_exec(listener.fileno())
            listener.bind((host, port))
            listener.listen(cls.DEFAULT_BACKLOG)
        except BaseException:
            # do not leak the descriptor when bind/listen fails
            listener.close()
            raise
        return listener

    def __init__(self, cfg, app, metrics_registry, version):
        """
        Bind the listening socket and register ``/ping`` and metrics export.

        :type cfg: dict
        :param cfg: service config; ``web.http.port`` (and optionally
                    ``web.http.host``, default '::') is used
        :type app: flask.Flask
        :type metrics_registry: infra.swatlib.metrics.Registry
        :param version: stored as ``self.version``
        """
        self.version = version
        self.app = app

        webcfg = cfg['web']
        host = webcfg['http'].get('host', '::')
        port = webcfg['http']['port']
        listener = self._create_server_socket(host, port)
        self.wsgi = pywsgi.WSGIServer(
            listener=listener,
            application=self.app,
            handler_class=WSGIHandler,
            log=None)
        self.app.add_url_rule('/ping', view_func=self.ping)
        # export metrics to yasm
        metrics.MetricsExt(app, cfg=webcfg.get('metrics', {}) or config.get_value('metrics', {}), registry=metrics_registry)

    def run(self):
        """Serve requests in the current greenlet until stopped."""
        self.wsgi.serve_forever()

    def stop(self):
        """Stop the server, waiting at most 1 second for active requests."""
        self.wsgi.stop(timeout=1)

    def start(self):
        """Start serving in the background and return immediately."""
        self.wsgi.start()

    @classmethod
    def ping(cls):
        """Liveness probe: empty 200 response."""
        return flask.Response('', status=200)
