from typing import Awaitable, Callable, NoReturn

from aiohttp import web

from sendr_qstats.http.aiohttp import get_stats_middleware
from sendr_tvm.qloud_async_tvm import TicketCheckResult
from sendr_tvm.server.aiohttp import get_tvm_restrictor

from mail.ipa.ipa.api.exceptions import APIException, TVMServiceTicketException
from mail.ipa.ipa.api.handlers.base import BaseHandler
from mail.ipa.ipa.api.schemas.base import fail_response_schema
from mail.ipa.ipa.conf import settings
from mail.ipa.ipa.utils.stats import REGISTRY
from mail.ipa.ipa.utils.tvm import TVM_CONFIG

# Signature of an aiohttp request handler as seen by a middleware:
# takes the request, asynchronously produces a (stream) response.
HandlerType = Callable[[web.Request], Awaitable[web.StreamResponse]]


@web.middleware
async def middleware_response_formatter(request: web.Request, handler: HandlerType) -> web.StreamResponse:
    """Render ``APIException`` errors raised downstream as schema-formatted responses.

    The wrapped ``handler`` is awaited and its response returned untouched on
    success. If it raises ``APIException``, the exception object itself is
    serialized through ``fail_response_schema`` and the exception's ``code``
    is used as the HTTP status. Any other exception propagates unchanged.
    """
    try:
        return await handler(request)
    except APIException as api_error:
        # The exception instance is the payload; its ``code`` drives the status.
        return BaseHandler.make_schema_response(
            data=api_error,
            schema=fail_response_schema,
            status=api_error.code,
        )


def tvm_service_ticket_error() -> NoReturn:
    """Raise ``TVMServiceTicketException`` unconditionally.

    This is passed to ``get_tvm_restrictor`` as its third positional argument.
    It is presumably the callback used when ``tvm_check_func`` rejects a
    request (TODO confirm against sendr_tvm).

    The function never returns normally, so it is annotated ``NoReturn``
    rather than ``None``.

    Raises:
        TVMServiceTicketException: always.
    """
    raise TVMServiceTicketException


def tvm_check_func(request: web.Request, check_result: TicketCheckResult) -> bool:
    """Decide whether ``request`` passes the TVM service-ticket restriction.

    Side effect: ``check_result`` is always stored on the request under the
    ``'tvm'`` key before any decision is made.

    The request is allowed when at least one of the following holds:

    * ``settings.TVM_CHECK_SERVICE_TICKET`` is falsy (checking disabled);
    * the ticket is valid and its ``src`` is in ``settings.TVM_ALLOWED_CLIENTS``;
    * the matched route's name is in ``settings.TVM_OPEN_PATHS``.
    """
    request['tvm'] = check_result

    # Both predicates are evaluated up front, even when checking is disabled.
    client_is_trusted = check_result.valid and check_result.src in settings.TVM_ALLOWED_CLIENTS
    route_is_public = request.match_info.route.name in settings.TVM_OPEN_PATHS

    if not settings.TVM_CHECK_SERVICE_TICKET:
        return True
    return client_is_trusted or route_is_public


# Stats middleware from sendr_qstats, bound to the shared REGISTRY.
middleware_stats = get_stats_middleware(handle_registry=REGISTRY)
# TVM restrictor middleware: tvm_check_func decides access and
# tvm_service_ticket_error raises on failure (presumably called on rejection;
# confirm against sendr_tvm). It shares the same REGISTRY as the stats middleware.
middleware_tvm = get_tvm_restrictor(
    TVM_CONFIG,
    tvm_check_func,
    tvm_service_ticket_error,
    handle_registry=REGISTRY,
)
