import asyncio
import time
from typing import Awaitable, Callable
from uuid import uuid4

from aiohttp import web
from webargs.aiohttpparser import AIOHTTPParser as BaseAIOHTTPParser

from sendr_qlog import LoggerContext
from sendr_qlog.http.aiohttp import handler_logger
from sendr_qstats.http.aiohttp import get_stats_middleware
from sendr_tvm.qloud_async_tvm import QTVM, TicketCheckResult
from sendr_tvm.server.aiohttp import get_tvm_restrictor

from mail.payments.payments.api.exceptions import APIException, TVMServiceTicketException
from mail.payments.payments.api.handlers.base import BaseHandler
from mail.payments.payments.api.schemas.base import fail_response_schema
from mail.payments.payments.conf import settings
from mail.payments.payments.utils.environment import APPLICATION_ENVIRONMENT
from mail.payments.payments.utils.stats import REGISTRY

# Signature of an aiohttp request handler: takes a request, returns an awaitable response.
# Used to annotate the handlers wrapped by the middlewares in this module.
HandlerType = Callable[[web.Request], Awaitable[web.Response]]

# TVM client configuration built from settings (client name, host, port).
# It is passed to `get_tvm_restrictor` further down to build `middleware_tvm`.
TVMConfig = QTVM(
    client_name=settings.TVM_CLIENT,
    host=settings.TVM_HOST,
    port=settings.TVM_PORT,
)

# Stats middleware from sendr_qstats, configured with the shared stats REGISTRY.
middleware_stats = get_stats_middleware(handle_registry=REGISTRY)


async def middleware_logging_adapter(_: web.Application, handler: HandlerType) -> HandlerType:
    """Wrap *handler* so every request carries a request id and a bound logger.

    The request id is resolved as follows:

    - the ``X-Request-Id`` header is used first;
    - otherwise the ``X-Req-Id`` header is used;
    - if both are missing or empty, a fresh ``uuid4().hex`` is generated.

    The id is stored as ``request['request_id']``. A ``LoggerContext`` carrying
    the id and the request URL is stored as ``request['logger']``.

    NOTE: this uses the ``(app, handler)`` middleware-factory form rather than
    ``@web.middleware``; the application argument is ignored.

    :param _: the aiohttp application (unused).
    :param handler: the next handler in the chain.
    :return: a handler that prepares the request and then delegates to *handler*.
    """
    async def _wrapped(request):
        headers = request.headers
        request_id = headers.get('X-Request-Id') or headers.get('X-Req-Id') or uuid4().hex

        request['request_id'] = request_id
        request['logger'] = LoggerContext(
            handler_logger,
            {'request-id': request_id, 'request-url': request.url},
        )
        return await handler(request)

    return _wrapped


async def measure_time(coro: Awaitable[web.Response], logger: LoggerContext) -> web.Response:
    """Await *coro* and record how long it took on *logger*.

    The elapsed wall-clock time in seconds (a ``time.perf_counter`` delta) is
    pushed into the logger context as ``response_time``. This happens whether
    the awaitable returns normally or raises, including on cancellation.

    :param coro: the awaitable producing the response.
    :param logger: the logger context that receives ``response_time``.
    :return: whatever *coro* resolves to.
    """
    started_at = time.perf_counter()
    try:
        result = await coro
    finally:
        logger.context_push(response_time=time.perf_counter() - started_at)
    return result


@web.middleware
async def middleware_response_formatter(request: web.Request, handler: HandlerType) -> web.Response:
    """Run the next handler, log the outcome and render ``APIException`` as an error response.

    - Pushes the request path (``url``) and the handler duration
      (``response_time``, via ``measure_time``) into the request logger context.
    - On success, logs "Request performed" with the response status, unless
      the matched route name is listed in ``settings.LOG_ACCESS_MUTED_ROUTES``.
    - An ``APIException`` is logged and converted into a response built with
      ``fail_response_schema``, using ``exc.code`` as the HTTP status.
    - Cancellation and any other exception are logged and re-raised.

    :param request: web request; ``request['logger']`` must already be set
        (presumably by ``middleware_logging_adapter`` — confirm middleware order).
    :param handler: next handler.
    :return: the handler's response, or the formatted error response.
    """
    logger = request['logger']
    logger.context_push(url=request.path)

    try:
        response = await measure_time(handler(request), logger)
        if request.match_info.route.name not in settings.LOG_ACCESS_MUTED_ROUTES:
            logger.context_push(code=response.status)
            logger.info('Request performed')
    except asyncio.CancelledError:
        # Must precede ``except Exception``. On Python < 3.8 CancelledError
        # subclasses Exception, so a later clause would never be reached and
        # cancellations would be reported as "Unhandled exception".
        logger.exception('Request handler is cancelled')
        raise
    except APIException as exc:
        logger.context_push(message=exc.message, code=exc.code)
        logger.exception('An exception occurred while processing request')
        response = BaseHandler.make_response(
            data=exc,
            schema=fail_response_schema,
            status=exc.code,
        )
    except Exception:
        logger.exception('Unhandled exception')
        raise
    return response


@web.middleware
async def middleware_header_cloud(request: web.Request, handler: HandlerType) -> web.Response:
    """
    Tag every response with the `X-Cloud` header.
    The header value is `APPLICATION_ENVIRONMENT`; documented possible values are
    ['deploy', 'qloud', 'unknown'].
    :param request: web request.
    :param handler: next handler.
    :return: the handler's response with `X-Cloud` set to the environment the application runs in.
    """
    produced = await handler(request)
    response_headers = produced.headers
    response_headers['X-Cloud'] = APPLICATION_ENVIRONMENT
    return produced


def tvm_service_ticket_error():
    """Error callback handed to the TVM restrictor; rejects the request unconditionally.

    :raises TVMServiceTicketException: always.
    """
    raise TVMServiceTicketException()


def apply_debug_tvm_user_ticket(request: web.Request, check_result: TicketCheckResult) -> None:
    """Override the user ticket on *check_result* from the ``X-Ya-User-Ticket-Debug`` header.

    How the header value becomes a uid:

    - It is first looked up in ``settings.TVM_DEBUG_USER_TICKET_VALUES``
      (converted to a dict).
    - If there is no mapping (or the mapped value is falsy), the header itself
      is parsed as an integer uid.
    - The resulting uid is written to ``check_result._user_ticket`` as both
      ``uid`` and ``default_uid``.

    A missing or empty header leaves *check_result* untouched. So does any
    ``ValueError`` raised while resolving the uid (e.g. a non-numeric header).

    :param request: web request carrying the debug header.
    :param check_result: TVM check result to patch in place.
    """
    header_value = request.headers.get('X-Ya-User-Ticket-Debug', None)
    if not header_value:
        return

    try:
        known_uids = dict(settings.TVM_DEBUG_USER_TICKET_VALUES)
        uid = known_uids.get(header_value, None) or int(header_value)
        check_result._user_ticket = {'uid': uid, 'default_uid': uid}
    except ValueError:
        return


def tvm_check_func(request: web.Request, check_result: TicketCheckResult) -> bool:
    """Decide whether a request may pass the TVM restrictor.

    Steps:

    1. When ``settings.TVM_DEBUG_USER_TICKET`` is enabled, the debug user
       ticket header is applied first.
    2. The check result is stored as ``request['tvm']``.
    3. The caller's ``check_result.src`` is matched against
       ``settings.TVM_ALLOWED_CLIENTS``. Each entry is either a bare client id
       (ACL ``['common']``) or a ``(client_id, params)`` tuple whose
       ``params['acl']`` lists the client's ACLs.
    4. The client is allowed when its ACLs intersect the route's ACLs
       (``settings.TVM_ROUTES_ACLS``, default ``['common']``) and the ticket
       is valid.

    :param request: web request; its matched route name selects ACLs and open paths.
    :param check_result: TVM ticket check result.
    :return: truthy when service-ticket checking is disabled, the client is
        allowed, or the route is in ``settings.TVM_OPEN_PATHS``.
    """
    route_name = request.match_info.route.name
    if settings.TVM_DEBUG_USER_TICKET:
        apply_debug_tvm_user_ticket(request, check_result)
    request['tvm'] = check_result

    client_ok = False
    for entry in settings.TVM_ALLOWED_CLIENTS:
        # Bare ids get the default 'common' ACL.
        client_id, params = entry if isinstance(entry, tuple) else (entry, {'acl': ['common']})
        if check_result.src != client_id:
            continue
        required_acls = settings.TVM_ROUTES_ACLS.get(route_name, ['common'])
        client_ok = bool(set(params.get('acl', [])).intersection(required_acls))
        break

    client_ok = client_ok and check_result.valid
    path_is_open = route_name in settings.TVM_OPEN_PATHS

    return not settings.TVM_CHECK_SERVICE_TICKET or client_ok or path_is_open


# TVM access-control middleware built by sendr_tvm from the TVM client config.
# `tvm_check_func` decides whether a request is allowed; `tvm_service_ticket_error`
# always raises TVMServiceTicketException. NOTE(review): presumably the error
# callback is invoked when the check returns a falsy value — confirm against
# `get_tvm_restrictor`.
middleware_tvm = get_tvm_restrictor(TVMConfig, tvm_check_func, tvm_service_ticket_error)


class AIOHTTPParser(BaseAIOHTTPParser):
    """webargs aiohttp parser that reports validation failures as ``APIException``."""

    def handle_error(self, error, *args, **kwargs):
        """Turn a webargs validation error into an HTTP 400 ``APIException``.

        :param error: the validation error; its ``messages`` become the exception params.
        :raises APIException: always, with code 400 and message 'Bad Request'.
        """
        validation_messages = error.messages
        raise APIException(code=400, message='Bad Request', params=validation_messages)
