from typing import Awaitable, Callable

from aiohttp import web

from sendr_qstats.http.aiohttp import get_stats_middleware
from sendr_tvm.qloud_async_tvm import TicketCheckResult
from sendr_tvm.server.aiohttp import get_tvm_restrictor

from mail.ciao.ciao.api.exceptions import APIException
from mail.ciao.ciao.api.handlers.base import BaseHandler
from mail.ciao.ciao.api.schemas.base import fail_response_schema
from mail.ciao.ciao.conf import settings
from mail.ciao.ciao.utils.logging import LOGGER
from mail.ciao.ciao.utils.stats import REGISTRY
from mail.ciao.ciao.utils.tvm import TVM_CONFIG

# Signature of the downstream handler that the middlewares in this module wrap:
# an async callable that takes a web.Request and resolves to a web.Response.
HandlerType = Callable[[web.Request], Awaitable[web.Response]]


@web.middleware
async def middleware_response_formatter(request: web.Request, handler: HandlerType) -> web.Response:
    """Convert an ``APIException`` raised by the handler into a failure response.

    The wrapped handler's response is passed through unchanged. If the handler
    raises ``APIException``, the exception is serialized with
    ``fail_response_schema`` via ``BaseHandler.make_schema_response``. The HTTP
    status of that response is taken from the exception's ``code`` attribute.
    Any other exception propagates untouched.

    :param request: incoming aiohttp request.
    :param handler: next handler in the middleware chain.
    :return: the handler's response, or the schema-formatted error response.
    """
    try:
        return await handler(request)
    except APIException as api_error:
        return BaseHandler.make_schema_response(
            data=api_error,
            schema=fail_response_schema,
            status=api_error.code,
        )


@web.middleware
async def middleware_logging_context(request: web.Request, handler: HandlerType) -> web.Response:
    """Install the request's logger into ``LOGGER``, then invoke the handler.

    The logger is read from ``request['logger']``, so an earlier middleware or
    request setup must have stored it there; a missing key raises ``KeyError``.
    NOTE(review): ``LOGGER.set`` suggests a context-variable-style holder scoped
    to the current task. Confirm this in mail.ciao.ciao.utils.logging.

    :param request: incoming aiohttp request carrying a ``'logger'`` item.
    :param handler: next handler in the middleware chain.
    :return: whatever the handler returns.
    """
    request_logger = request['logger']
    LOGGER.set(request_logger)
    response = await handler(request)
    return response


def tvm_check_func(request: web.Request, check_result: TicketCheckResult) -> bool:
    """Decide whether a request passes the TVM restriction, recording the result.

    Side effects:
      * If ``settings.TVM_DEBUG_USER_TICKET`` is truthy and the
        ``X-Ya-User-Ticket-Debug`` header parses as an ``int``, the check
        result's user ticket is replaced with ``{'uid': uid, 'default_uid': uid}``.
        A missing or non-integer header is silently ignored.
      * The (possibly patched) ``check_result`` is stored in ``request['tvm']``.

    :param request: incoming aiohttp request.
    :param check_result: ticket check outcome supplied by the TVM restrictor.
    :return: truthy when service-ticket checking is disabled
        (``settings.TVM_CHECK_SERVICE_TICKET`` falsy), or when the source is valid
        and listed in ``settings.TVM_ALLOWED_CLIENTS``, or when the matched route
        name is in ``settings.TVM_OPEN_PATHS``.
    """
    if settings.TVM_DEBUG_USER_TICKET:
        try:
            debug_uid = int(request.headers['X-Ya-User-Ticket-Debug'])
        except (ValueError, KeyError):
            # No debug header, or it is not an integer: keep the real user ticket.
            pass
        else:
            # NOTE(review): this writes a private attribute of TicketCheckResult,
            # so it relies on sendr_tvm internals. Debug-only override; make sure the
            # flag is never enabled where clients are untrusted.
            check_result._user_ticket = {'uid': debug_uid, 'default_uid': debug_uid}

    request['tvm'] = check_result

    client_is_allowed = check_result.valid and check_result.src in settings.TVM_ALLOWED_CLIENTS
    route_name = request.match_info.route.name
    path_is_open = route_name in settings.TVM_OPEN_PATHS
    service_check_enabled = settings.TVM_CHECK_SERVICE_TICKET
    return not service_check_enabled or client_is_allowed or path_is_open


# TVM restriction middleware built by sendr_tvm's get_tvm_restrictor from
# TVM_CONFIG; tvm_check_func is passed as the per-request allow/deny callback.
# NOTE(review): the meaning of handle_registry (presumably per-handle metrics
# registered into REGISTRY) is defined by sendr_tvm; confirm there.
middleware_tvm_restrictor = get_tvm_restrictor(TVM_CONFIG, tvm_check_func, handle_registry=REGISTRY)

# Statistics middleware from sendr_qstats, sharing the same REGISTRY as the TVM
# restrictor. NOTE(review): presumably records per-handle request stats; confirm
# in sendr_qstats.http.aiohttp.
middleware_stats = get_stats_middleware(handle_registry=REGISTRY)
