from typing import Awaitable, Callable

from aiohttp import web

from sendr_qstats.http.aiohttp import get_stats_middleware
from sendr_tvm.qloud_async_tvm import TicketCheckResult
from sendr_tvm.server.aiohttp import get_tvm_restrictor

from mail.ohio.ohio.api.exceptions import APIException, TVMServiceTicketException
from mail.ohio.ohio.api.handlers.base import BaseHandler
from mail.ohio.ohio.api.schemas.base import fail_response_schema
from mail.ohio.ohio.conf import settings
from mail.ohio.ohio.utils.stats import REGISTRY
from mail.ohio.ohio.utils.tvm import TVM_CONFIG

# Signature of an aiohttp request handler as received by the middlewares below:
# an async callable taking a request and producing a response.
HandlerType = Callable[[web.Request], Awaitable[web.Response]]

# ACL name used as the fallback when a client or a route has no explicit ACL
# list configured (see tvm_check_func).
COMMON_ACL = 'common'


@web.middleware
async def middleware_response_formatter(request: web.Request, handler: HandlerType) -> web.Response:
    """Run *handler* and turn any ``APIException`` it raises into a response.

    On success the handler's response is returned unchanged. When the handler
    raises ``APIException``, the exception is serialized with
    ``fail_response_schema`` via ``BaseHandler.make_schema_response`` and its
    ``code`` attribute is used as the response status. Any other exception
    type is not caught here and propagates to the caller.
    """
    try:
        return await handler(request)
    except APIException as api_error:
        return BaseHandler.make_schema_response(
            data=api_error,
            schema=fail_response_schema,
            status=api_error.code,
        )


def apply_debug_tvm_user_ticket(request: web.Request, check_result: TicketCheckResult) -> None:
    """Override the user ticket on *check_result* from a debug request header.

    Reads the ``X-Ya-User-Ticket-Debug`` header. When it is present, non-empty
    and parses as an integer, ``check_result._user_ticket`` is replaced with a
    dict carrying that integer as both ``uid`` and ``default_uid``. A missing,
    empty or non-integer header leaves *check_result* untouched.
    """
    raw_uid = request.headers.get('X-Ya-User-Ticket-Debug')
    if not raw_uid:
        return

    try:
        debug_uid = int(raw_uid)
    except ValueError:
        # Malformed debug header: silently keep the real ticket data.
        return

    check_result._user_ticket = {'uid': debug_uid, 'default_uid': debug_uid}


def tvm_check_func(request: web.Request, check_result: TicketCheckResult) -> bool:
    """Decide whether a request may proceed given its TVM ticket check result.

    Side effects: when ``settings.TVM_DEBUG_USER_TICKET`` is enabled, the debug
    user-ticket override is applied to *check_result*; the (possibly modified)
    *check_result* is then stored on the request under the ``'tvm'`` key.

    The request is allowed when any of the following holds:

    * ``settings.TVM_CHECK_SERVICE_TICKET`` is falsy;
    * the route name is listed in ``settings.TVM_OPEN_PATHS``;
    * the ticket is valid, its source is a key of
      ``settings.TVM_ALLOWED_CLIENTS``, and the client's ACLs share at least
      one entry with the ACLs configured for the route in
      ``settings.TVM_ROUTES_ACLS``. Clients and routes without an explicit ACL
      list fall back to ``[COMMON_ACL]``.

    Returns:
        ``True`` if the request is allowed, ``False`` otherwise.
    """
    if settings.TVM_DEBUG_USER_TICKET:
        apply_debug_tvm_user_ticket(request, check_result)
    request['tvm'] = check_result

    route_name = request.match_info.route.name

    allowed_client = False
    if check_result.valid and check_result.src in settings.TVM_ALLOWED_CLIENTS:
        client_settings = settings.TVM_ALLOWED_CLIENTS[check_result.src]
        granted_acls = set(client_settings.get('acl', [COMMON_ACL]))
        required_acls = set(settings.TVM_ROUTES_ACLS.get(route_name, [COMMON_ACL]))
        # The client passes when it holds at least one ACL the route accepts.
        allowed_client = not required_acls.isdisjoint(granted_acls)

    allowed_path = route_name in settings.TVM_OPEN_PATHS

    if not settings.TVM_CHECK_SERVICE_TICKET:
        return True
    if allowed_path:
        return True
    return allowed_client


def tvm_service_ticket_error():
    """Signal a TVM service ticket failure.

    Raises:
        TVMServiceTicketException: always.
    """
    raise TVMServiceTicketException()


# Stats middleware built by sendr_qstats, registered against the shared REGISTRY.
middleware_stats = get_stats_middleware(handle_registry=REGISTRY)

# TVM restrictor middleware built by sendr_tvm: tvm_check_func is the access
# predicate; tvm_service_ticket_error is passed as the error callback
# (presumably invoked when the predicate rejects a request — confirm against
# get_tvm_restrictor).
middleware_tvm = get_tvm_restrictor(
    TVM_CONFIG,
    tvm_check_func,
    tvm_service_ticket_error,
    handle_registry=REGISTRY,
)
