import json
import types
import httplib
import datetime
import requests
import StringIO
import traceback
import functools

import flask

import sandbox.common.types.misc as ctm

from sandbox.common import api
from sandbox.common import config
from sandbox.common import rest as common_rest
from sandbox.common import profiler

from sandbox.yasandbox import context
from sandbox.serviceapi import metrics
from sandbox.yasandbox import controller
from sandbox.yasandbox.database import mapping
from sandbox.serviceapi import constants as sa_consts

from . import request
from . import aggregator
from . import exceptions
from . import middlewares
from . import statistics


try:
    import uwsgi
except ImportError:
    uwsgi = None


__all__ = ("blueprints", "apply_max_concurrent_requests_guard", "Route")


# Configuration registry (`config.Registry`) used by the request hooks below.
registry = config.Registry()

# Blueprint cache: URL prefix ("/<path_prefix>/v<version>") -> `flask.Blueprint`.
blueprints = dict()


# Bookkeeping for the "requests in progress" bitmask (one bit per worker):
#   WORKER_ID                 -- id of the current worker
#   SHAREDAREA_LEN            -- number of bytes in the workers bitmask
#   WORKER_BYTE_INDEX         -- index of the byte holding the current worker's bit
#   WORKER_BYTE_OFFSET        -- position of the current worker's bit inside that byte
#   WORKER_BYTE_MASK          -- byte with only the current worker's bit set
#   INVERTED_WORKER_BYTE_MASK -- WORKER_BYTE_MASK with every bit flipped
#   LIMIT_MASK                -- bitmask with every worker's bit set except the current one
WORKER_ID = SHAREDAREA_LEN = WORKER_BYTE_INDEX = WORKER_BYTE_OFFSET = None
WORKER_BYTE_MASK = INVERTED_WORKER_BYTE_MASK = LIMIT_MASK = None


def apply_max_concurrent_requests_guard(bp):
    """
    Install hooks on `bp` that reject a request with 503 when all other uWSGI workers are busy.

    Each worker owns one bit of the uWSGI shared-area bitmask
    `sa_consts.UWSGIKey.REQUESTS_IN_PROGRESS`. The bit is set while the worker processes an
    admitted request. A new request is rejected when the bitmask equals `LIMIT_MASK`, i.e. every
    bit except this worker's own is set. All reads and writes of the bitmask happen under the
    global uWSGI lock.

    Does nothing when not running under uWSGI (the `uwsgi` module could not be imported).

    Side effects: fills the module-level worker bookkeeping globals and clears this worker's
    bit in the shared bitmask.

    :param bp: `flask.Blueprint` to attach `before_request` and `teardown_request` hooks to
    """
    # These are module-level state. Without `global`, the assignments below would only create
    # function locals and leave the module globals at None.
    global WORKER_ID, SHAREDAREA_LEN, WORKER_BYTE_INDEX, WORKER_BYTE_OFFSET
    global WORKER_BYTE_MASK, INVERTED_WORKER_BYTE_MASK, LIMIT_MASK

    if not uwsgi:
        return

    def requests_area():
        # Writable view over the shared "requests in progress" bitmask.
        return uwsgi.sharedarea_memoryview(sa_consts.UWSGIKey.REQUESTS_IN_PROGRESS)

    def before():
        # Batch request has already occupied a worker
        if getattr(flask.g, "is_batch", False):
            flask.request.rejected = False
            flask.request.rejected_in_progress = False
            return None

        # Check-and-set under the global lock. `finally` guarantees the lock is released even if
        # touching the shared area fails; otherwise every worker would block on it forever.
        uwsgi.lock()
        try:
            area = requests_area()
            if area[:SHAREDAREA_LEN] != LIMIT_MASK:
                # Mark this worker as "busy".
                area[WORKER_BYTE_INDEX] = chr(ord(area[WORKER_BYTE_INDEX]) | WORKER_BYTE_MASK)
                flask.request.rejected = False
                flask.request.rejected_in_progress = False
            else:
                # Leave this worker's bit clear so other workers still see it as free
                # while it is only busy rejecting.
                flask.request.rejected = True
        finally:
            uwsgi.unlock()

        if flask.request.rejected:
            metrics.rate_inc("rejected_requests")
            return "Too many requests in progress", httplib.SERVICE_UNAVAILABLE
        return None

    def teardown_request(exception=None):
        if getattr(flask.g, "is_batch", False):
            return
        # Nothing to release if the request was rejected or `before` never ran for it.
        if getattr(flask.request, "rejected", True):
            return
        # Read-modify-write of a byte shared with up to 7 other workers. It must hold the same
        # lock as `before`, or a concurrent update of a neighbouring bit can be lost.
        uwsgi.lock()
        try:
            area = requests_area()
            area[WORKER_BYTE_INDEX] = chr(ord(area[WORKER_BYTE_INDEX]) & INVERTED_WORKER_BYTE_MASK)
        finally:
            uwsgi.unlock()

    workers = registry.server.api.workers
    WORKER_ID = uwsgi.worker_id() - 1
    # Room for `workers + 1` bits (ids 0..workers), rounded up to whole bytes.
    SHAREDAREA_LEN = (workers + 8) // 8
    WORKER_BYTE_INDEX = WORKER_ID // 8
    WORKER_BYTE_OFFSET = WORKER_ID % 8
    WORKER_BYTE_MASK = 1 << WORKER_BYTE_OFFSET
    INVERTED_WORKER_BYTE_MASK = 255 ^ WORKER_BYTE_MASK
    # Every bit set except this worker's own: "all other workers are busy".
    LIMIT_MASK = "".join(
        chr(INVERTED_WORKER_BYTE_MASK) if index == WORKER_BYTE_INDEX else chr(255)
        for index in range(SHAREDAREA_LEN)
    )

    uwsgi.lock()
    try:
        area = requests_area()
        mask = ord(area[WORKER_BYTE_INDEX])
        if WORKER_ID == workers:
            # Presumably meant for the highest-numbered worker. It sets every bit from its own
            # position upward, so unused positions in the last byte always read as "busy" and
            # LIMIT_MASK can match. NOTE(review): WORKER_ID is 0-based, so this fires only if
            # uWSGI runs `workers + 1` processes (consistent with SHAREDAREA_LEN). Confirm this
            # against the uWSGI config.
            mask |= 255 ^ (WORKER_BYTE_MASK - 1)
        # This worker starts out free.
        mask &= INVERTED_WORKER_BYTE_MASK
        area[WORKER_BYTE_INDEX] = chr(mask)
    finally:
        uwsgi.unlock()

    bp.before_request(before)
    bp.teardown_request(teardown_request)


def support_transfer_encoding(bp, max_body_size=16777216):
    """
    Make `bp` accept request bodies sent with `Transfer-Encoding: chunked` under uWSGI.

    For chunked requests, the whole body is read with `uwsgi.chunked_read()` into an in-memory
    buffer. That buffer then replaces `wsgi.input`, and `wsgi.input_terminated` is set so
    downstream code can read it normally. Outside of uWSGI the hook leaves the request untouched.

    :param bp: `flask.Blueprint` to attach the `before_request` hook to
    :param max_body_size: maximum accepted body size in bytes (default 16 MiB); larger
        bodies are answered with 413
    """
    def before():
        if flask.request.headers.get(ctm.HTTPHeader.TRANSFER_ENCODING, None) != "chunked":
            return None
        if uwsgi is None:
            # `uwsgi.chunked_read` exists only inside uWSGI; leave the body to the WSGI server.
            return None

        buf = StringIO.StringIO()
        size = 0
        while True:
            try:
                chunk = uwsgi.chunked_read()
            except IOError as ex:
                context.current.logger.info("IOError %s", ex)
                return "IOError: {}".format(ex), httplib.BAD_REQUEST
            if not chunk:
                # An empty chunk terminates the body.
                break
            size += len(chunk)
            if size > max_body_size:
                return "Request body too large", httplib.REQUEST_ENTITY_TOO_LARGE
            buf.write(chunk)

        buf.seek(0)
        flask.request.environ["wsgi.input"] = buf
        flask.request.environ["wsgi.input_terminated"] = True
        return None

    bp.before_request(before)


def get_blueprint_by_version(
    version, path_prefix="api", custom_http_error_handler=None, custom_base_error_handler=None
):
    """
    Return the blueprint serving "/<path_prefix>/v<version>", creating and caching it on first use.

    A newly created blueprint gets the concurrency guard, `before_request`, the chunked-body
    reader, `after_request`, and error handlers for `exceptions.HttpError` and `Exception`.

    NOTE: blueprints are cached by URL prefix. The custom error handlers only take effect
    when this call is the one that creates the blueprint; later calls return the cached
    instance unchanged.

    :param version: API version
    :param path_prefix: first URL path component
    :param custom_http_error_handler: replaces `http_error_handler` when given
    :param custom_base_error_handler: replaces `base_error_handler` when given
    :return: `flask.Blueprint`
    """
    url_prefix = "/{}/v{}".format(path_prefix, version)
    if url_prefix in blueprints:
        return blueprints[url_prefix]

    new_bp = flask.Blueprint(url_prefix, __name__)
    # Keep this order: Flask runs before_request hooks in the order they were added.
    apply_max_concurrent_requests_guard(new_bp)
    new_bp.before_request(before_request)
    support_transfer_encoding(new_bp)
    new_bp.after_request(after_request)

    http_handler = custom_http_error_handler or http_error_handler
    fallback_handler = custom_base_error_handler or base_error_handler
    new_bp.errorhandler(exceptions.HttpError)(http_handler)
    new_bp.errorhandler(Exception)(fallback_handler)

    blueprints[url_prefix] = new_bp
    return new_bp


class RequestContext(context.Context):
    """`context.Context` built from a `request.Request`'s logger and user, keeping the request itself."""

    def __init__(self, req):
        logger, user = req.logger, req.user
        super(RequestContext, self).__init__(logger, user)
        # The originating request wrapper, available as `ctx.request`.
        self.request = req


@middlewares.wrap_temporary_errors
def before_request():
    """
    Per-request setup hook for API blueprints.

    - Replaces the current context with a `RequestContext`. For statistics collection this is
      done even for requests that are about to be rejected.
    - Optionally enables the profiler.
    - Enforces API quotas.

    :return: a `(body, status)` tuple that short-circuits the request, either 400 for a
        malformed profiler header or 429 when the quota is exceeded; otherwise None
    """
    # `after_request` reads this flag. Define it on every path out of this hook, including the
    # early returns below and runs without the uWSGI concurrency guard.
    flask.request.rejected_in_progress = False

    # Create a new context and replace the base one.
    mapping.base.tls.request = None
    req = request.Request(flask.request, context.current.logger)
    ctx = RequestContext(req)
    context.set_current(ctx)

    # Profiling is allowed for super users, or for anyone on a local installation.
    if (
        ctm.HTTPHeader.PROFILER in req.req.headers and
        (req.user.super_user or config.Registry().common.installation == ctm.Installation.LOCAL)
    ):
        try:
            sort_field = int(req.req.headers[ctm.HTTPHeader.PROFILER])
        except ValueError:
            sort_field = None
        # Explicit range check: an `assert` would be stripped when running under `python -O`.
        if sort_field is None or not -1 <= sort_field <= 2:
            return (
                "Header '{}' must contain integer from '-1' to '2'".format(ctm.HTTPHeader.PROFILER),
                httplib.BAD_REQUEST
            )
        req.profiler = profiler.Profiler(sort_field=sort_field)
        req.profiler.enable()
    mapping.base.tls.request = ctx.request

    if config.Registry().server.api.quotas.enabled:
        if not aggregator.Aggregator().check_consumption(req.quota_owner, req.source):
            if uwsgi:
                uwsgi.set_logvar("login", req.quota_owner)
            # Tells `after_request` to leave this response untouched.
            flask.request.rejected_in_progress = True
            return "Quota exceeded", 429
    if flask.request.headers.get(ctm.HTTPHeader.TRANSFER_ENCODING) == "chunked":
        context.current.logger.info("Request %s with chunked transfer encoding.", req.id[:8])

    if uwsgi:
        uwsgi.set_logvar("login", context.current.user.login)


def profiler_response(response, legacy_result=False):
    """
    Stop the current request's profiler and return its dump together with `response`.

    The result is a JSON response with status 418 and keys "profile", "legacy_profile" and "result".

    :param response: payload to embed. With `legacy_result` it is presumably a dict; its
        "result" value (wrapped in a list) and its "profile" value are used.
    :param legacy_result: whether `response` is in the legacy {"result", "profile"} form
    """
    request_profiler = context.current.request.profiler
    request_profiler.disable()
    profile = request_profiler.dump_to_str()

    result, legacy_profile = response, None
    if legacy_result:
        result = [response.get("result")]
        legacy_profile = response.get("profile")

    payload = {
        "profile": profile,
        "legacy_profile": legacy_profile,
        "result": result,
    }
    return flask.current_app.response_class(
        response=json.dumps(payload, indent=4),
        status=requests.codes.I_AM_A_TEAPOT,
        content_type="application/json; charset=utf-8",
    )


@middlewares.wrap_after_request_temporary_errors
def after_request(r):
    """
    Per-response hook for API blueprints.

    - Records request statistics.
    - Accounts quota consumption.
    - Sets service headers: backend node, service mode, current user, request id, API
      consumption, session reset, duration, span metrics and tasks revision.
    - Replaces the response with a profiler dump when profiling was enabled for the request.

    Rejected requests and requests whose context has no user get only the statistics record.

    :param r: response produced by the view
    :return: the (possibly replaced) response
    """
    mapping.base.tls.request = None
    statistics.StatRecorder.record(r)

    # Rejected by the load guard or by the quota check: leave the response as is.
    if flask.request.rejected or flask.request.rejected_in_progress:
        return r

    ctx = context.current
    if ctx.user is None:
        ctx.logger.error("Current user is None, before_request has failed")
        return r
    login = ctx.user.login
    req = ctx.request
    headers = r.headers

    if config.Registry().server.api.quotas.enabled:
        aggregator.Aggregator().add_delta()

    headers[ctm.HTTPHeader.BACKEND_NODE] = registry.this.fqdn
    headers[ctm.HTTPHeader.SERVICE_MODE] = controller.Settings.mode().lower()
    headers[ctm.HTTPHeader.CURRENT_USER] = login
    headers[ctm.HTTPHeader.REQUEST_ID] = flask.request.req_id

    aggregator.Aggregator().update_api_consumption_headers(headers, login)

    if req.is_authenticated and req.need_reset_cookie:
        headers[ctm.HTTPHeader.RESET_SESSION] = "true"

    elapsed = datetime.datetime.utcnow() - req.request_started
    headers[ctm.HTTPHeader.REQ_DURATION] = str(elapsed.total_seconds())
    headers[ctm.HTTPHeader.REQ_METRICS] = ctx.spans.as_short_string(group_by_name=True)
    headers[ctm.HTTPHeader.TASKS_REVISION] = aggregator.Aggregator().revision

    metrics.rate_inc("serviceapi_requests")

    # The context itself is reset by the `context` plugin, not here.
    if req.profiler is None:
        return r
    return profiler_response(r.response)


def http_error_handler(ex):
    """Render an `exceptions.HttpError` as JSON {"reason", "error"} with the error's own HTTP code."""
    payload = dict(reason=ex.message, error=str(ex.error))
    status = ex.code
    return flask.jsonify(payload), status


def base_error_handler(ex):
    """
    Log an unhandled exception and render it as a 500 JSON response.

    The body contains the exception text ("reason"), its class name ("error") and the
    formatted traceback of the exception being handled ("traceback").
    """
    logger = context.current.logger
    if logger is not None:
        logger.exception("Request error: %s", ex)
    payload = dict(reason=str(ex), error=type(ex).__name__, traceback=traceback.format_exc())
    return flask.jsonify(payload), httplib.INTERNAL_SERVER_ERROR


def register_route(version, path, method, restriction, allow_ro, handler, endpoint, path_prefix="api"):
    """
    Register api handler.

    Before calling `handler`, the registered view checks the read-only mode (`allow_ro`) and
    the security scope (`restriction`). It then runs the handler wrapped with
    `middlewares.wrap_temporary_errors`.

    :param version: api version
    :param path: route path
    :param method: http method
    :param restriction: security scope
    :param allow_ro: allow processing when cluster is read-only
    :param handler: handler function
    :param endpoint: unique route id
    :param path_prefix: route paths prefix
    """
    @functools.wraps(handler)
    def view_func(*args, **kwargs):
        middlewares.check_for_readonly(allow_ro)
        middlewares.check_security_scope(restriction)
        guarded_handler = middlewares.wrap_temporary_errors(handler)
        return guarded_handler(*args, **kwargs)

    get_blueprint_by_version(version, path_prefix=path_prefix).add_url_rule(
        path, methods=[method], view_func=view_func, endpoint=endpoint
    )


def make_flask_path(path, params):
    """
    Convert swagger-like path to Flask format.
    Example: `/task/{id}` -> `/task/<id>`

    Only parameters whose scope is `api.Scope.PATH` are substituted.

    :raises RuntimeError: if a path parameter is missing from `path`, or if any `{`
        placeholder remains after substitution
    """
    result = path
    for param in params:
        if param.scope != api.Scope.PATH:
            continue
        name = param.__param_name__
        placeholder = "{" + name + "}"
        if placeholder not in result:
            raise RuntimeError("Route doesn't have param '{}': {}".format(placeholder, result))
        result = result.replace(placeholder, "<" + name + ">")

    if "{" in result:
        raise RuntimeError("Route has unknown params: {}".format(result))
    return result


def function_or_classmethod(m):
    """
    Tell whether `m` is a plain function or a method bound to a (truthy) object.

    `Route` uses this on attributes fetched from a class, where it accepts staticmethods
    (plain functions) and classmethods (bound to the class).

    :return: True or False
    """
    if isinstance(m, types.FunctionType):
        return True
    return isinstance(m, types.MethodType) and bool(getattr(m, "__self__", None))


class SchemaObjectEncoder(json.JSONEncoder):
    """JSON encoder that serializes any object exposing `__getstate__()` as that method's result."""

    def default(self, obj):
        # Fall back to the stock encoder (which raises TypeError) for everything else.
        if not hasattr(obj, "__getstate__"):
            return super(SchemaObjectEncoder, self).default(obj)
        return obj.__getstate__()


def generate_json_args():
    """
    Return keyword arguments for `json.dumps` suited to the current request.

    XHR requests (header `X-Requested-With: XMLHttpRequest`) get compact output. All other
    requests are pretty-printed with a 2-space indent. Non-ASCII characters are emitted as-is,
    and objects are serialized through `SchemaObjectEncoder`.
    """
    # Same test Werkzeug's `Request.is_xhr` performed. That property was deprecated and then
    # removed in Werkzeug 1.0, so check the header directly.
    is_xhr = flask.request.headers.get("X-Requested-With", "").lower() == "xmlhttprequest"
    if is_xhr:
        indent, separators = None, (",", ":")
    else:
        indent, separators = 2, (", ", ": ")
    return {
        "indent": indent,
        "separators": separators,
        "ensure_ascii": False,
        "cls": SchemaObjectEncoder
    }


def wrap_response(rv):
    """
    Convert a handler's return value into something Flask can send.

    - Schema objects, dicts and lists become JSON responses.
    - A `flask.Response` whose payload is a Schema is re-rendered as JSON, keeping its
      status and headers.
    - Anything else is returned unchanged.
    """
    from sandbox.common.api import Schema

    def json_body(obj):
        # Encoded JSON document followed by a trailing newline, as response chunks.
        return json.dumps(obj, **generate_json_args()).encode("utf8"), "\n"

    if isinstance(rv, flask.Response):
        if not isinstance(rv.response, Schema):
            return rv
        return flask.current_app.response_class(
            json_body(rv.response),
            rv.status, rv.headers, "application/json; charset=utf8",
        )

    if isinstance(rv, (Schema, dict, list)):
        return flask.current_app.response_class(
            json_body(rv),
            mimetype="application/json; charset=utf8",
        )

    return rv


def wrap_schemas(handler, parameters, content):
    """
    Build a Flask view that decodes request parameters by their schema and calls `handler`.

    `handler` receives its arguments positionally, in this order:
      - each PATH parameter, decoded, in the order given by `parameters`;
      - a dict of decoded QUERY parameters (passed only if it is non-empty);
      - the BODY (passed only if it is not None). The body is the raw request data when
        `content` is `common_rest.Client.BINARY`; otherwise it is parsed as JSON and then
        decoded by the parameter.

    The handler's return value is converted with `wrap_response`.

    :raises exceptions.BadRequest: when a parameter or the body cannot be decoded
    """
    def view_func(**kwargs):
        arguments = []
        query = {}
        body = None

        def decode(param, value):
            try:
                return param.decode(value and value.strip())
            except (ValueError, TypeError) as ex:
                raise exceptions.BadRequest("Error parsing {}: {}".format(param.__param_name__, ex))

        for param in parameters:
            name = param.__param_name__
            if param.scope == api.Scope.PATH:
                arguments.append(decode(param, kwargs.pop(name)))
            elif param.scope == api.Scope.BODY:
                try:
                    if content == common_rest.Client.BINARY:
                        body = flask.request.data
                    else:
                        body = param.decode(json.loads(flask.request.data))
                # TypeError is mapped to 400 as well, consistent with `decode` above;
                # previously it escaped as a 500.
                except (ValueError, TypeError) as ex:
                    raise exceptions.BadRequest("Unable to parse input data: {}".format(ex))
            elif param.scope == api.Scope.QUERY:
                raw_values = flask.request.args.getlist(name)
                if len(raw_values) > 1:
                    # Repeated parameter: concatenate the decoded values of every occurrence.
                    query[name] = [v for raw_value in raw_values for v in decode(param, raw_value)]
                else:
                    # Absent parameter, or a single empty value, decodes as None.
                    # TODO: Remove empty-value special case after all clients are ready (FEI-9859)
                    raw_value = raw_values[0] if raw_values and raw_values[0] != "" else None
                    query[name] = decode(param, raw_value)

        if query:
            arguments.append(query)
        if body is not None:
            arguments.append(body)

        rv = handler(*arguments)

        # Convert handler return value to a valid flask response
        return wrap_response(rv)

    return view_func


class Route(object):
    """
    Base class for declarative API handlers.

    `Route(api_path)` returns a new class (not an instance) that carries `api_path`. When that
    class is subclassed with `get`/`post`/... handlers defined as @staticmethod or
    @classmethod, the metaclass below registers every implemented handler as a Flask route at
    class-creation time.
    """
    # Object exposing `.path` and `.requests` (presumably from `sandbox.common.api`);
    # None means "no routes to register".
    api_path = None
    # API version; part of the URL prefix "/<path_prefix>/v<version>".
    version = None
    # First URL path component of the registered routes.
    path_prefix = "api"

    # List-query mapping passed to `api.remap_query` by `remap_query`.
    LIST_QUERY_MAP = {}

    # noinspection PyPep8Naming
    class __metaclass__(type):
        """Python 2 metaclass hook: registers the handlers of each concrete `Route` subclass."""

        def __new__(mcs, name, bases, namespace):
            # Classes produced by `Route.__new__` carry `base_path=True`: they only hold
            # `api_path` and are not registered themselves; their subclasses are.
            base_path = namespace.pop("base_path", False)
            cls = type.__new__(mcs, name, bases, namespace)
            # noinspection PyUnresolvedReferences
            api_path = cls.api_path
            if api_path is None or base_path:
                return cls

            for req in api_path.requests:
                # Request descriptions are named after HTTP methods; the handler is the
                # lowercase attribute of the same name ("GET" -> `get`).
                method = req.__name__.upper()
                handler = getattr(cls, method.lower(), None)

                if handler is None:
                    # Method is not implemented yet
                    continue

                if not function_or_classmethod(handler):
                    raise RuntimeError(
                        "{}.{}.{} should be either @staticmethod or @classmethod".format(
                            cls.__module__, cls.__name__, handler.__name__
                        )
                    )

                register_route(
                    version=cls.version,
                    path=make_flask_path(api_path.path, req.__parameters__),
                    method=method,
                    restriction=req.__security__,
                    allow_ro=req.__allow_ro__,
                    handler=wrap_schemas(handler, req.__parameters__, req.__content__),
                    endpoint=req.__operation_id__,
                    path_prefix=cls.path_prefix,
                )

            return cls

    def __new__(cls, api_path):
        """
        Return a new `Route`-derived class bound to `api_path`, instead of an instance.

        The new class copies `version` and `path_prefix` from `cls`.
        NOTE(review): its base is `Route`, not `cls`, so other attributes of a `Route`
        subclass (e.g. `LIST_QUERY_MAP`) are not inherited. Confirm this is intended.
        """
        return type(cls)(
            cls.__name__,
            (Route,),
            dict(api_path=api_path, base_path=True, version=cls.version, path_prefix=cls.path_prefix)
        )

    @classmethod
    def remap_query(cls, query, save_query=False):
        """
        Remap `query` with `api.remap_query` using this class's `LIST_QUERY_MAP`.

        :raises exceptions.BadRequest: when `api.remap_query` raises ValueError
        """
        try:
            return api.remap_query(query, cls.LIST_QUERY_MAP, save_query=save_query)
        except ValueError as ex:
            raise exceptions.BadRequest(ex)
