import logging
from typing import Awaitable, Callable, Optional, Type
from uuid import uuid4

from aiohttp import web

from sendr_qlog import LoggerContext

# Stdlib logger wrapped by every per-request LoggerContext created in the
# middleware returned by get_middleware_logging_adapter.
handler_logger = logging.getLogger('request_logger')

# Signature of an aiohttp request handler: takes a request and returns an
# awaitable that resolves to the response.
HANDLER_TYPE = Callable[[web.Request], Awaitable[web.Response]]


def get_middleware_logging_adapter(
    logger_context_cls: Optional[Type[LoggerContext]] = None,
) -> Callable[[web.Request, HANDLER_TYPE], Awaitable[web.Response]]:
    """Build a middleware that attaches a request id and a logger to each request.

    For every request, the returned middleware stores the id under
    ``request['request-id']``. It stores a
    ``logger_cls(handler_logger, {'request-id': <id>})`` instance under
    ``request['logger']``, then delegates to the next handler.

    :param logger_context_cls: ``LoggerContext`` subclass to instantiate per
        request; ``LoggerContext`` itself is used when ``None``.
    :return: An aiohttp new-style middleware (decorated with ``web.middleware``).
    """
    logger_cls = LoggerContext if logger_context_cls is None else logger_context_cls

    @web.middleware
    async def middleware_logging_adapter(
        request: web.Request,
        handler: HANDLER_TYPE,
    ) -> web.Response:
        # Prefer an id supplied by the caller, accepting either header spelling.
        # Empty header values are falsy and fall through to the next option;
        # a fresh random hex id is generated as the last resort.
        req_id = (
            request.headers.get('X-Request-Id')
            or request.headers.get('X-Req-Id')
            or uuid4().hex
        )
        request['request-id'] = req_id
        request['logger'] = logger_cls(handler_logger, {'request-id': req_id})
        return await handler(request)

    return middleware_logging_adapter


@web.middleware
async def middleware_debug_logging(
    request: web.Request,
    handler: HANDLER_TYPE,
) -> web.Response:
    """Push request path/URL/body-presence fields onto the request's logger context.

    Expects ``request['logger']`` to already be set, e.g. by the middleware
    built by ``get_middleware_logging_adapter``. A ``KeyError`` is raised
    otherwise.
    """
    # NOTE(review): the ``with`` presumably scopes the pushed fields to the
    # handler call; confirm against sendr_qlog.LoggerContext.
    with request['logger'] as logger:
        logger.context_push(
            path=request.rel_url.raw_path,
            url=request.url,
            # ``has_body`` is deprecated since aiohttp 2.3 and warns on every
            # access; ``can_read_body`` returns the same value. The log field
            # name is kept unchanged.
            has_body=request.can_read_body,
        )
        return await handler(request)


async def signal_request_id_header(request: web.Request, response: web.Response) -> None:
    """Copy the request's stored id onto the response as ``X-Request-Id``.

    The ``(request, response)`` signature matches aiohttp's
    ``on_response_prepare`` signal; presumably it is registered there (confirm
    at the registration site).

    :raises AssertionError: if the request has no non-empty ``'request-id'``
        entry, e.g. because the logging-adapter middleware did not run.
    """
    request_id = request.get('request-id')
    # Explicit raise rather than ``assert``: assert statements are stripped
    # under ``python -O``. That would let a missing id reach the header
    # assignment below. AssertionError is kept so the exception type is unchanged.
    if not request_id:
        raise AssertionError('No X-Request-Id in request')
    response.headers['X-Request-Id'] = request_id
