# coding: utf-8
import contextlib
from six.moves import http_client as httplib
import logging
import time
import six
from gevent import threadpool

import ujson
import flask
import inject
import google.protobuf.message
import google.protobuf.json_format

from infra.swatlib.rpc import parse_request, exceptions, authentication
from infra.swatlib import pbutil, climit, jsonschemautil


# MIME type constants used for response content negotiation.
# Only the literal kind depends on the interpreter: plain string literals on
# Python 3, explicit byte-string literals on Python 2.
if six.PY3:
    PROTOBUF_CONTENT_TYPE = 'application/x-protobuf'
    JSON_CONTENT_TYPE = 'application/json'
    PROTOBUF_CONTENT_TYPES = (PROTOBUF_CONTENT_TYPE, 'application/octet-stream',)
else:
    PROTOBUF_CONTENT_TYPE = b'application/x-protobuf'
    JSON_CONTENT_TYPE = b'application/json'
    PROTOBUF_CONTENT_TYPES = (PROTOBUF_CONTENT_TYPE, b'application/octet-stream',)

# Derived tuples are built the same way for both interpreters.
JSON_CONTENT_TYPES = (JSON_CONTENT_TYPE,)
# Candidates offered to Accept-header matching; JSON types come first.
MIME_MATCHES = JSON_CONTENT_TYPES + PROTOBUF_CONTENT_TYPES


def make_response(protobuf_object, status_msg, accept_mimetypes, status=httplib.OK, headers=None,
                  pb_json_printer_cls=google.protobuf.json_format._Printer, json_preserve_field_names=False):
    """
    Transforms protobuf message to flask response
    according to accept_types (from flask request) setting provided HTTP status.

    The body is serialized as binary protobuf when the best match of
    ``accept_mimetypes`` is one of PROTOBUF_CONTENT_TYPES, and as JSON otherwise
    (JSON is also the default when nothing matches).

    If converting ``protobuf_object`` to JSON fails, the failure is logged and
    an error message built from ``status_msg`` is returned instead, with the
    HTTP status forced to 500 so that it agrees with the ``code`` in the body.

    :param protobuf_object: protobuf message to serialize and put in HTTP body
    :param status_msg: protobuf message class to be returned in case of error.
                       Must have fields code, status and message.
    :param accept_mimetypes: accept_mimetypes attribute of flask request
    :param status: HTTP status to set
    :param headers: headers to set
    :param pb_json_printer_cls: printer class forwarded to pbutil.pb_to_jsondict
    :param bool json_preserve_field_names: use the original protobuf field names
                                           in JSON rather than make them camelCase.

    :return: flask Response object
    :rtype: flask.Response
    """
    match = accept_mimetypes.best_match(MIME_MATCHES, default=JSON_CONTENT_TYPE)
    if match in PROTOBUF_CONTENT_TYPES:
        body = protobuf_object.SerializeToString()
        content_type = PROTOBUF_CONTENT_TYPE
    else:
        # If we don't use including_default_value_fields, empty lists
        # won't be present in the resulting JSON, which can be unexpected and
        # would require complicated handling on the receiving side.
        try:
            d = pbutil.pb_to_jsondict(protobuf_object, including_default_value_fields=True,
                                      pb_json_printer_cls=pb_json_printer_cls,
                                      preserve_field_names=json_preserve_field_names)
        except Exception as e:
            # Do not lose the traceback: the client only receives the message text.
            logging.getLogger(__name__).exception('Failed to serialize response to JSON')
            # The body below reports an internal error, so the HTTP status must
            # say the same instead of keeping the caller-provided one.
            status = httplib.INTERNAL_SERVER_ERROR
            d = pbutil.pb_to_jsondict(status_msg(code=httplib.INTERNAL_SERVER_ERROR,
                                                 status="Failure",
                                                 message=six.text_type(e)),
                                      pb_json_printer_cls=pb_json_printer_cls,
                                      preserve_field_names=json_preserve_field_names)
        body = ujson.dumps(d)  # Use ujson to serialize response - it's faster
        content_type = JSON_CONTENT_TYPE
    return flask.Response(body, status=status, content_type=content_type, headers=headers)


def make_authentication_request(flask_request):
    """
    Builds an authentication.Request out of a flask request.

    The user IP is the first entry of ``access_route`` when that list is not
    empty, and ``remote_addr`` otherwise. Credentials are read from the
    'Authorization' header, the 'Session_id' cookie and the
    'X-Ya-Service-Ticket' header; any of them may be None when absent.

    :param flask_request: flask request object
    :return: authentication.Request instance
    """
    if flask_request.access_route:
        user_ip = flask_request.access_route[0]
    else:
        user_ip = flask_request.remote_addr

    url = flask_request.url
    host = flask_request.host
    authorization = flask_request.headers.get('Authorization')
    session_id = flask_request.cookies.get('Session_id')
    service_ticket = flask_request.headers.get('X-Ya-Service-Ticket')
    return authentication.Request(url, host, user_ip, authorization, session_id, service_ticket)


def parse_flask_request(protobuf_request, flask_request, status_msg, log, validate_schema=False):
    """
    Fills ``protobuf_request`` in place from the incoming flask request.

    GET requests are parsed from query arguments, POST requests from the
    request body (and form). Any other HTTP method is rejected.

    :param protobuf_request: empty protobuf message to populate
    :param flask_request: flask request object
    :param status_msg: protobuf message class used to build error responses.
                       Must have fields code, status and message.
    :param log: logger to report parsing failures to
    :param validate_schema: forwarded to the parse_request helpers
    :return: None on success, or a 400 flask response describing the failure
    :rtype: flask.Response | None
    """
    preserve_names = getattr(flask_request, 'json_preserve_field_names', False)

    def bad_request(message):
        # Every failure path answers with the same 400 "Failure" envelope.
        return make_response(
            status_msg(status="Failure", code=httplib.BAD_REQUEST, message=message),
            accept_mimetypes=flask_request.accept_mimetypes,
            status_msg=status_msg,
            status=httplib.BAD_REQUEST,
            json_preserve_field_names=preserve_names,
        )

    http_method = flask_request.method
    if http_method not in ('GET', 'POST'):
        return bad_request('Method not supported')

    try:
        if http_method == 'GET':
            parse_request.init_from_args(protobuf_request, flask_request.args,
                                         validate_schema=validate_schema)
        else:
            parse_request.init_from_content(protobuf_request,
                                            flask_request.get_data(),
                                            flask_request.content_type,
                                            validate_schema=validate_schema,
                                            request_form=flask_request.form)
    except Exception as e:
        log.info('Failed to parse request to "%s"', flask_request.path, exc_info=True)
        return bad_request(six.text_type(e))
    return None


def authenticate_request(flask_request, authenticator, status_msg, log):
    """
    Helper which calls authenticator and handles errors.

    Returns a tuple of (AuthSubject, Response). If the response is not None,
    then it should be returned to the user (and the subject is None).

    An RpcError raised by the authenticator is converted into a response with
    the error's own status (and redirect_url, when set). Any other exception
    is logged and converted into a 500 response.

    NOTE: moved from blueprint for ease of testing.

    :param flask_request: flask request object
    :param authenticator: authenticator
    :param status_msg: protobuf message class used to build error responses
    :param log: logger
    """
    auth_request = make_authentication_request(flask_request)
    preserve_names = getattr(flask_request, 'json_preserve_field_names', False)
    try:
        subject = authenticator.authenticate_request(auth_request)
    except exceptions.RpcError as e:
        failure = status_msg(status="Failure", code=e.status, message=e.message if six.PY2 else str(e))
        if e.redirect_url:
            failure.redirect_url = e.redirect_url
        response = make_response(failure, accept_mimetypes=flask_request.accept_mimetypes,
                                 status_msg=status_msg, status=e.status,
                                 json_preserve_field_names=preserve_names)
        return None, response
    except Exception as e:
        log.info('Failed to authenticate request to "%s"', flask_request.path, exc_info=True)
        failure = status_msg(status="Failure",
                             code=httplib.INTERNAL_SERVER_ERROR,
                             message=six.text_type(e))
        response = make_response(failure,
                                 accept_mimetypes=flask_request.accept_mimetypes,
                                 status_msg=status_msg,
                                 status=httplib.INTERNAL_SERVER_ERROR,
                                 json_preserve_field_names=preserve_names)
        return None, response
    return subject, None


class HttpRpcBlueprint(flask.Blueprint):
    """
    Blueprint which provides utilities for protobuf-over-http RPC services.

    RPC handlers are registered with the :meth:`method` decorator. The blueprint
    also serves the list of registered methods at '/' and per-method schemas at
    '/schemas/<method_name>/'.

    :type _authenticator: authentication.IRpcAuthenticator
    """
    _authenticator = inject.attr(authentication.IRpcAuthenticator)

    _default_methods = ('GET', 'POST')
    # We register all methods in flask so that we can manually check actual method
    # and return proper API response.
    # Not quite sure about HEAD and OPTIONS.
    ALL_METHODS = ('GET', 'POST', 'PUT', 'DELETE', 'PATCH')

    def __init__(self, name, import_name, url_prefix, status_msg,
                 validate_schema=False, infer_response_schema=False, serialize_resp_threads_count=None,
                 maybe_reject_outstaff_user=None):
        """
        :param name: blueprint name, also used as the logger name
        :param import_name: forwarded to flask.Blueprint
        :param url_prefix: forwarded to flask.Blueprint
        :param status_msg: protobuf message class used for error responses.
                           Constructed here with fields status, code, message and reason.
        :param validate_schema: infer request schemas for registered methods and
                                pass the flag on to request parsing
        :param infer_response_schema: infer response schemas for registered methods
        :param serialize_resp_threads_count: when truthy, successful responses are
                                             serialized via a gevent thread pool of that max size
        :param maybe_reject_outstaff_user: optional callable invoked as
                                           f(auth_subject, handler, status_msg); a non-None
                                           result is returned to the client as the response
        """
        super(HttpRpcBlueprint, self).__init__(name, import_name, url_prefix=url_prefix)
        self.log = logging.getLogger(name)
        self.route('/', endpoint='schema', methods=('GET',))(self._schema_ctrl)
        self.route('/schemas/<method_name>/', endpoint='method_schema', methods=('GET',))(self._method_schema_ctrl)
        self.rpc_methods = []
        self.schema_bodies_lookup = {}
        self.validate_schema = validate_schema
        self.infer_response_schema = infer_response_schema
        self.status_msg = status_msg
        # Error messages that never change are built once and reused for every request.
        self.BUSY_ERROR = status_msg(status="Failure",
                                     code=429,
                                     message="Server is too busy",
                                     reason="Concurrent requests limit reached")
        self.BAD_METHOD_ERROR = status_msg(status="Failure",
                                           message='Method not supported',
                                           code=httplib.BAD_REQUEST)
        self.tp = threadpool.ThreadPool(maxsize=serialize_resp_threads_count) if serialize_resp_threads_count else None
        self._maybe_reject_outstaff_user = maybe_reject_outstaff_user

    def get_authenticator(self):
        """Returns the injected authenticator used for incoming requests."""
        return self._authenticator

    @classmethod
    def _get_pb_json_printer_cls(cls):
        """Returns the protobuf JSON printer class used to render responses."""
        return google.protobuf.json_format._Printer

    def _schema_ctrl(self):
        """Serves the list of registered (method_name, request_schema, response_schema) entries as JSON."""
        return flask.Response(ujson.dumps(self.rpc_methods), status=200, mimetype=JSON_CONTENT_TYPE)

    def _method_schema_ctrl(self, method_name):
        """
        Serves the pre-serialized schema of a single registered method.

        :param method_name: RPC method name taken from the URL
        :return: 200 with the schema JSON, or 404 with a JSON status message
        """
        b = self.schema_bodies_lookup.get(method_name)
        if b:
            return flask.Response(b, status=200, mimetype=JSON_CONTENT_TYPE)

        d = pbutil.pb_to_jsondict(
            self.status_msg(code=httplib.NOT_FOUND,
                            status="Failure",
                            message="Method '{}' not found".format(method_name)),
            preserve_field_names=self.json_preserve_field_names(flask.request),
            pb_json_printer_cls=self._get_pb_json_printer_cls()
        )
        return flask.Response(ujson.dumps(d), status=404, mimetype=JSON_CONTENT_TYPE)

    @classmethod
    @contextlib.contextmanager
    def translate_errors(cls):
        """
        Context manager wrapped around every user handler call.

        Does nothing here; presumably a hook for subclasses to translate
        their own exceptions into RpcErrors - TODO confirm against subclasses.
        """
        yield

    def json_preserve_field_names(self, flask_request):
        """
        Tells whether JSON responses should keep original protobuf field names.

        Always False here; presumably meant to be overridden - TODO confirm.

        :param flask_request: flask request object (unused here)
        :rtype: bool
        """
        return False

    def call_user_handler(self, handler, accept_mimetypes, protobuf_request, auth_subject, status_msg,
                          log=None, json_preserve_field_names=False, request_kwargs=None):
        """
        Calls user RPC handler and catches exceptions, transforming them into RPC response.

        :param handler: user handler to call
        :param accept_mimetypes: accept_mimetypes attribute of flask request
        :param protobuf_request: protobuf request
        :param auth_subject: authentication subject information if authentication enabled or None
        :param status_msg: protobuf message class to be returned in case of error.
                           Must have fields code, status and message.
        :param log: logger to report exceptions to; the blueprint logger is used when None
        :param bool json_preserve_field_names: use the original protobuf field names
                                               in JSON rather than make them camelCase.
        :param request_kwargs: extra keyword arguments passed to the handler (or None)
        :return: flask response object
        :raises RuntimeError: if the handler returns neither a protobuf message nor a flask.Response
        """
        if log is None:
            # The signature allows omitting the logger, so fall back to our own
            # instead of failing with AttributeError inside an error handler.
            log = self.log
        pb_json_printer_cls = self._get_pb_json_printer_cls()
        try:
            with self.translate_errors():
                request_kwargs = request_kwargs or {}
                response = handler(protobuf_request, auth_subject, **request_kwargs)
        except exceptions.ConflictError as e:
            status_pb = status_msg(status="Failure", code=e.status, message=six.text_type(e))
            # Attach every conflicting object so the client can see what clashed.
            for c in e.conflicts:
                status_pb.objects.add(object_type=c['object_type'],
                                      object_id=c['object_id'],
                                      status=c['status'])
            return make_response(status_pb,
                                 accept_mimetypes=accept_mimetypes, status_msg=status_msg, status=e.status,
                                 pb_json_printer_cls=pb_json_printer_cls,
                                 json_preserve_field_names=json_preserve_field_names)
        except exceptions.RpcError as e:
            # Other RpcError statuses are returned to the client without being logged here.
            if e.status == httplib.INTERNAL_SERVER_ERROR:
                log.info('Request processing failed.', exc_info=True)
            status_pb = status_msg(status="Failure", code=e.status, message=six.text_type(e))
            return make_response(status_pb,
                                 accept_mimetypes=accept_mimetypes, status_msg=status_msg, status=e.status,
                                 pb_json_printer_cls=pb_json_printer_cls,
                                 json_preserve_field_names=json_preserve_field_names)
        except Exception as e:
            log.exception('Request processing unexpectedly failed.')
            status_pb = status_msg(status="Failure", code=httplib.INTERNAL_SERVER_ERROR, message=six.text_type(e))
            return make_response(status_pb,
                                 accept_mimetypes=accept_mimetypes, status_msg=status_msg,
                                 status=httplib.INTERNAL_SERVER_ERROR,
                                 pb_json_printer_cls=pb_json_printer_cls,
                                 json_preserve_field_names=json_preserve_field_names)

        if isinstance(response, google.protobuf.message.Message):
            if self.tp is None:
                return make_response(response, accept_mimetypes=accept_mimetypes, status_msg=status_msg,
                                     pb_json_printer_cls=pb_json_printer_cls,
                                     json_preserve_field_names=json_preserve_field_names)
            # Serialize through the thread pool configured in the constructor.
            return self.tp.apply(make_response, (response, status_msg, accept_mimetypes),
                                 dict(pb_json_printer_cls=pb_json_printer_cls,
                                      json_preserve_field_names=json_preserve_field_names))
        elif isinstance(response, flask.Response):
            # Handler built the HTTP response itself; pass it through untouched.
            return response
        else:
            raise RuntimeError('Response is neither protobuf message nor flask.Response')

    def _maybe_write_extended_log(self, started_at, finished_at, protobuf_request, auth_subject, method_name,
                                  is_method_destructive=False, sent_at=None):
        """
        Hook called after the user handler has been called; does nothing here.

        :type started_at: float
        :type finished_at: float
        :param protobuf_request: Protobuf message
        :type auth_subject: infra.swatlib.rpc.authentication.AuthSubject
        :type method_name: six.text_type
        :type is_method_destructive: bool
        :type sent_at: float | None
        """
        pass

    def _maybe_reject_method(self, method):
        """
        Hook called first for every RPC request; a non-None result is returned
        to the client instead of processing the request. Rejects nothing here.

        :type method: six.text_type
        :rtype: Optional[flask.Response]
        """
        return None

    def _maybe_reject_user(self, auth_subject):
        """
        Hook called after successful authentication; a non-None result is
        returned to the client instead of calling the handler. Rejects nobody here.

        :type auth_subject: infra.swatlib.rpc.authentication.AuthSubject
        :rtype: Optional[flask.Response]
        """
        return None

    def method(self, method_name, request_type, response_type,
               allow_http_methods=_default_methods,
               need_authentication=True, max_in_flight=None, is_destructive=False):
        """
        Decorator to register provided function as RPC handler.

        Upon invocation user function will be called as
        >> function(protobuf_request, auth_subject, **url_kwargs)
        Where:
            * protobuf_request - protobuf message object, which was instantiated from :param request_type:
            * auth_subject - Object holding authentication info, particular login or
                             None (if authentication was disabled).
            * url_kwargs - keyword arguments flask passed to the view, if any.
        Function should either:
            * return protobuf response, which must be of type :param response_type:
            * raise one of RpcErrors, defined in .exceptions.py

        Either way response or exception will be serialized in JSON/Protobuf depending on
        Accept: HTTP header. Default content type for response is JSON.

        The handler is served at '/<method_name>/'. The decorator returns the
        flask view function (not the original function), which carries the
        attributes need_auth, handler, bp, response_msg, status_msg and is_destructive.

        :param method_name: RPC method name.
        :param request_type: Protobuf message, which holds incoming parameters.
        :param response_type: Protobuf message, which will contain response.
        :param allow_http_methods: A list of allowed HTTP methods. POST and GET by default.
        :param need_authentication: If we need user request to be authenticated.
        :param max_in_flight: Number of maximum concurrent requests in flight;
                              requests over the limit get a 429 response.
        :param is_destructive: bool, need for logging
        """
        if self.validate_schema:
            request_schema = jsonschemautil.infer_schema(request_type.DESCRIPTOR)
        else:
            request_schema = {}
        if self.infer_response_schema:
            response_schema = jsonschemautil.infer_schema(response_type.DESCRIPTOR)
        else:
            response_schema = {}

        self.rpc_methods.append((
            method_name,
            request_schema,
            response_schema
        ))
        # Serialize once at registration time; _method_schema_ctrl serves this string as is.
        self.schema_bodies_lookup[method_name] = ujson.dumps({
            'request_schema': request_schema,
            'response_schema': response_schema
        })

        def real_method(fun):
            # One concurrency limiter per registered method, shared by all its requests.
            if max_in_flight:
                limiter = climit.CLimit(max_in_flight)
            else:
                limiter = None

            def handle_http_request(**kwargs):
                flask_request = flask.request
                # Stored on the request so the module-level helpers can pick it up.
                flask_request.json_preserve_field_names = self.json_preserve_field_names(flask_request)

                error_response = self._maybe_reject_method(method_name)
                if error_response is not None:
                    return error_response

                if limiter is not None:
                    if not limiter.add():
                        return make_response(self.BUSY_ERROR,
                                             accept_mimetypes=flask_request.accept_mimetypes,
                                             status_msg=self.status_msg,
                                             status=429,
                                             json_preserve_field_names=flask_request.json_preserve_field_names)
                # From here on the limiter slot is taken and must be released in `finally`.
                try:
                    started = time.time()
                    if flask_request.method not in allow_http_methods:
                        return make_response(self.BAD_METHOD_ERROR,
                                             accept_mimetypes=flask_request.accept_mimetypes,
                                             status_msg=self.status_msg,
                                             status=httplib.BAD_REQUEST,
                                             json_preserve_field_names=flask_request.json_preserve_field_names)

                    # Initialize empty protobuf object
                    protobuf_request = request_type()
                    # Parse request
                    error_response = parse_flask_request(protobuf_request, flask_request, self.status_msg, self.log,
                                                         validate_schema=self.validate_schema)
                    if error_response is not None:
                        return error_response
                    # Authenticate if needed
                    if need_authentication:
                        auth_subject, error_response = authenticate_request(
                            flask_request, self.get_authenticator(), self.status_msg, self.log)
                        if error_response is not None:
                            return error_response

                        error_response = self._maybe_reject_user(auth_subject)
                        if error_response is not None:
                            return error_response

                        if self._maybe_reject_outstaff_user:
                            error_response = self._maybe_reject_outstaff_user(auth_subject, fun, self.status_msg)
                            if error_response is not None:
                                return error_response
                    else:
                        auth_subject = None
                    # Call user handler
                    r = self.call_user_handler(fun, flask_request.accept_mimetypes, protobuf_request,
                                               auth_subject, self.status_msg, log=self.log,
                                               json_preserve_field_names=flask_request.json_preserve_field_names,
                                               request_kwargs=kwargs)
                    finished = time.time()
                    # Extended logging is best-effort: a failure here must not
                    # replace the response that has already been produced.
                    try:
                        sent_at = parse_request.parse_x_start_time_header(flask_request)
                        self._maybe_write_extended_log(
                            started, finished, protobuf_request, auth_subject, method_name,
                            is_method_destructive=is_destructive, sent_at=sent_at)
                    except Exception:
                        self.log.exception('Failed to write extended access log')
                    return r
                finally:
                    if limiter is not None:
                        limiter.done()

            url = '/' + method_name + '/'
            # Disable legacy authentication mechanism, we have our own
            handle_http_request.need_auth = False
            handle_http_request.handler = fun
            handle_http_request.bp = self
            handle_http_request.response_msg = response_type
            handle_http_request.status_msg = self.status_msg
            handle_http_request.is_destructive = is_destructive
            # Register flask handler
            return self.route(url, endpoint=method_name, methods=self.ALL_METHODS)(handle_http_request)

        return real_method
