from typing import Awaitable, Callable, NoReturn

from aiohttp import web

from sendr_qstats.http.aiohttp import get_stats_middleware
from sendr_tvm.qloud_async_tvm import QTVM, TicketCheckResult
from sendr_tvm.server.aiohttp import get_tvm_restrictor

from mail.beagle.beagle.api.exceptions import APIException, TVMServiceTicketException
from mail.beagle.beagle.api.handlers.base import BaseHandler
from mail.beagle.beagle.api.schemas.base import fail_response_schema
from mail.beagle.beagle.conf import settings
from mail.beagle.beagle.utils.stats import REGISTRY

# Type of the downstream handler a middleware receives: it takes an aiohttp
# request and returns an awaitable resolving to a response.
HandlerType = Callable[[web.Request], Awaitable[web.Response]]

# Module-level QTVM object built from settings when this module is imported;
# it is passed to get_tvm_restrictor to build ``middleware_tvm``.
# NOTE(review): despite the PascalCase name this holds the result of calling
# QTVM(...), not a class; left unrenamed because other modules may import it.
TVMConfig = QTVM(
    client_name=settings.TVM_CLIENT,
    host=settings.TVM_HOST,
    port=settings.TVM_PORT,
)


@web.middleware
async def middleware_response_formatter(request: web.Request, handler: HandlerType) -> web.Response:
    """Turn an ``APIException`` raised by *handler* into a failure response.

    Responses from *handler* are returned unchanged. When *handler* raises
    ``APIException``, it is serialized with ``fail_response_schema`` through
    ``BaseHandler.make_schema_response``, and the exception's ``code``
    attribute becomes the HTTP status. Any other exception propagates.
    """
    try:
        return await handler(request)
    except APIException as api_error:
        return BaseHandler.make_schema_response(
            data=api_error,
            schema=fail_response_schema,
            status=api_error.code,
        )


def apply_debug_tvm_user_ticket(request: web.Request, check_result: TicketCheckResult) -> None:
    """Override the user ticket in *check_result* from a debug header.

    Reads the ``X-Ya-User-Ticket-Debug`` request header. If the header is
    present, non-empty and parses with ``int()``, it replaces
    ``check_result._user_ticket`` with ``{'uid': uid, 'default_uid': uid}``.
    If the header is missing, empty or not an integer, *check_result* is
    left unchanged.

    NOTE(review): this writes a private attribute of ``TicketCheckResult``
    and so depends on that class's internals. Re-check when upgrading
    sendr_tvm.
    """
    debug_header = request.headers.get('X-Ya-User-Ticket-Debug')
    if not debug_header:
        return
    # Only the parse can fail; keep the try body limited to it.
    try:
        uid = int(debug_header)
    except ValueError:
        # Malformed debug header: ignore it and leave check_result unchanged.
        return
    check_result._user_ticket = {'uid': uid, 'default_uid': uid}


def tvm_check_func(request: web.Request, check_result: TicketCheckResult) -> bool:
    """Store the TVM check result on the request and accept it.

    When ``settings.TVM_DEBUG_USER_TICKET`` is truthy,
    ``apply_debug_tvm_user_ticket`` is called first so a debug header can
    override the user ticket. *check_result* is then saved under
    ``request['tvm']``.

    Returns:
        Always ``True``. Presumably the restrictor treats this as "allow";
        confirm against ``get_tvm_restrictor``.
    """
    debug_tickets_enabled = settings.TVM_DEBUG_USER_TICKET
    if debug_tickets_enabled:
        apply_debug_tvm_user_ticket(request, check_result)

    request['tvm'] = check_result
    return True


def tvm_service_ticket_error() -> NoReturn:
    """Raise ``TVMServiceTicketException``; never returns normally.

    Passed to ``get_tvm_restrictor`` when building ``middleware_tvm``.
    Presumably it is invoked when a request's service ticket is rejected;
    confirm against sendr_tvm.

    Raises:
        TVMServiceTicketException: always.
    """
    raise TVMServiceTicketException


# Stats middleware from sendr_qstats. It is wired to the shared REGISTRY and
# presumably records per-request/handler metrics there.
middleware_stats = get_stats_middleware(handle_registry=REGISTRY)
# TVM middleware from sendr_tvm, built around TVMConfig.
# tvm_check_func always accepts and stores the check result in request['tvm'].
# tvm_service_ticket_error raises TVMServiceTicketException; presumably the
# restrictor calls it on a bad service ticket (confirm against sendr_tvm).
middleware_tvm = get_tvm_restrictor(TVMConfig, tvm_check_func, tvm_service_ticket_error, handle_registry=REGISTRY)
