"""
    aiohttp middleware that checks TVM ticket headers
"""
from typing import Any, Callable, Optional

from sendr_qstats import Counter, MetricsRegistry, registry


def get_tvm_restrictor(tvm: Any,
                       check_func: Callable = lambda x, y: True,
                       error_response: Optional[Callable] = None,
                       handle_registry: MetricsRegistry = registry.REGISTRY,
                       app_name: Optional[str] = None) -> Callable:
    """Build an aiohttp old-style middleware factory that gates requests on TVM tickets.

    For every request the wrapped handler calls ``ticket_checker.check_headers``
    on the request headers. The request is forwarded to the real handler only
    when the check produced a non-``None`` result AND ``check_func`` accepts it.
    Otherwise the middleware returns ``error_response()`` if one was given, or a
    plain ``403`` response with the body ``TVM error``.

    Two counters are created in ``handle_registry`` when this function is called:

    * ``{prefix}tvm_usage_counter``, labelled by ``status``, which is
      incremented with ``'tvm'`` for accepted requests and ``'no_tvm'`` for
      rejected ones;
    * ``{prefix}tvm_usage_by_service``, labelled by ``service``, which is
      incremented with ``check_result.src`` for accepted requests.

    ``{prefix}`` is ``app_name + '_'``, or empty when ``app_name`` is ``None``.

    :param tvm: either a dict of keyword arguments for ``QTVM(...)`` or a
        ready object exposing ``ticket_checker()``. A dict is turned into a
        ``QTVM`` lazily, on the first middleware invocation, and reused after that.
    :param check_func: ``check_func(request, check_result)``. A truthy return
        lets the request through. The default accepts every non-``None`` result.
    :param error_response: optional zero-argument callable. Its return value is
        sent instead of the default 403 response when a request is rejected.
    :param handle_registry: metrics registry that the counters are registered in.
    :param app_name: optional prefix for the metric names.
    :return: an ``async def middleware(app, handler)`` factory that returns the
        wrapping handler.
    """
    from collections.abc import Awaitable

    from aiohttp import web

    from sendr_tvm.common.exceptions import TicketParsingException
    from sendr_tvm.qloud_async_tvm import QTVM, TicketCheckResult

    # Resolved on the first middleware call, then shared by every later call.
    qtvm: Optional[QTVM] = None
    app_name = '' if app_name is None else app_name + '_'

    tvm_usage_counter = Counter(f'{app_name}tvm_usage_counter', labelnames=('status',), registry=handle_registry)
    tvm_usage_by_service = Counter(f'{app_name}tvm_usage_by_service', labelnames=('service',), registry=handle_registry)

    async def middleware(_: web.Application, handler: Callable) -> Callable:
        """Wrap ``handler`` with the TVM ticket check.

        NOTE(review): aiohttp presumably invokes old-style factories per request,
        which would make ``qtvm.ticket_checker()`` run per request too. Confirm
        that this is cheap.
        """
        nonlocal qtvm
        if qtvm is None:
            qtvm = QTVM(**tvm) if isinstance(tvm, dict) else tvm
        ticket_checker = qtvm.ticket_checker()

        async def _handler(request: web.Request) -> web.StreamResponse:
            logger = request.get('logger')
            try:
                check_result = ticket_checker.check_headers(request.headers)
                # The checker may be synchronous or asynchronous; await only when needed.
                if isinstance(check_result, Awaitable):
                    check_result = await check_result
            except TicketParsingException as exc:
                # A malformed ticket becomes an empty result that is still passed to check_func.
                # NOTE(review): with the default check_func (always True) such a request is
                # accepted, while a None result is rejected. Confirm this is intended.
                check_result = TicketCheckResult(None, None)
                if logger:
                    logger.error('Ticket parsing error: %s (%s)', exc.message, exc.debug_info)

            if check_result is not None and check_func(request, check_result):
                tvm_usage_counter.labels('tvm').inc()
                tvm_usage_by_service.labels(check_result.src).inc()
                return await handler(request)

            tvm_usage_counter.labels('no_tvm').inc()
            if error_response:
                return error_response()
            return web.Response(
                text='TVM error',
                status=403,
            )

        return _handler

    return middleware
