import asyncio
import logging
import time
from typing import Union

from aiohttp.web import Request, Response, json_response, middleware
from smb.common.sensors import MetricGroup

# Public API of this module: the two aiohttp middlewares, the metrics
# endpoint handler, the shared metric-recording helper and the logging
# handler that turns log records into counters.
__all__ = [
    "SolomonLoggingHandler",
    "handler",
    "log_request",
    "response_time_middleware",
    "rps_middleware",
]


@middleware
async def rps_middleware(request: Request, handler) -> Response:
    """Count requests whose client went away before a response was produced.

    Only the cancellation path is recorded here: when the handler is
    cancelled and the connection is closing (or already gone), one hit is
    added to the ``"rps"`` metric group with response code
    ``"closed_by_client"``. Successful responses are presumably counted
    elsewhere (not visible in this module). The ``CancelledError`` is always
    re-raised.
    """
    # Grab the transport before awaiting: request.transport is read through
    # the protocol, which aiohttp resets to None once the connection is lost,
    # so reading it after cancellation would lose the information we need.
    transport = request.transport

    try:
        return await handler(request)
    except asyncio.CancelledError:
        # A None transport means the connection was already torn down before
        # this middleware ran. Calling is_closing() on it would raise
        # AttributeError and mask the cancellation, so treat it as closed.
        if transport is None or transport.is_closing():
            log_request("rps", request, "closed_by_client", 1)
        raise


@middleware
async def response_time_middleware(request: Request, handler) -> Response:
    """Record how long a handler ran before its client disconnected.

    On cancellation with a closing (or already gone) connection, the elapsed
    time in milliseconds is added to the ``"response_time"`` metric group with
    response code ``"closed_by_client"``. The ``CancelledError`` is always
    re-raised.
    """
    # Use a monotonic clock: time.time() follows the wall clock, which can be
    # stepped (NTP, manual changes) and yield negative or inflated durations.
    start = time.monotonic()
    # Captured up front because aiohttp resets the protocol's transport to
    # None once the connection is lost.
    transport = request.transport

    try:
        return await handler(request)
    except asyncio.CancelledError:
        # None means the connection was already gone; guard it so that
        # AttributeError cannot replace the CancelledError being propagated.
        if transport is None or transport.is_closing():
            elapsed_ms = (time.monotonic() - start) * 1000
            log_request("response_time", request, "closed_by_client", elapsed_ms)
        raise


def log_request(
    metric_group_name: str,
    request: Request,
    response_code: Union[str, int],
    value: Union[int, float],
):
    """Add ``value`` to the sensor of ``metric_group_name`` for this request.

    The sensor is labelled by the matched route and ``response_code``. The
    route label prefers the route info's ``"formatter"`` entry (the URL
    template of a parameterised route) and falls back to ``"path"``. If
    neither key is present, the label is ``None``.

    Args:
        metric_group_name: Name of the metric group on the app's sensors.
        request: The request whose matched route provides the path label.
        response_code: HTTP status, or a symbolic code such as
            ``"closed_by_client"``.
        value: Amount to add to the selected sensor.
    """
    app_context = request.app["lasagna"]

    route_info = request.match_info.get_info()
    route_label = route_info.get("formatter", route_info.get("path"))

    sensor = app_context.sensors.take(
        metric_group_name,
        path=route_label,
        response_code=response_code,
    )
    sensor.add(value)


async def handler(request: Request) -> Response:
    """Respond with a JSON snapshot of all sensors kept by the app's "lasagna" object."""
    return json_response(request.app["lasagna"].sensors.serialize())


class SolomonLoggingHandler(logging.Handler):
    """Logging handler that counts log records instead of writing them out.

    Every record that reaches ``emit`` increments the sensor of
    ``metric_group`` labelled with the record's logger name and level name.
    Records from the loggers listed in ``ignored_loggers`` are skipped.
    """

    # Exact logger names whose records are not counted (aiohttp's own
    # loggers; presumably excluded to keep framework noise out of the
    # metrics — confirm with the owners).
    ignored_loggers = (
        "aiohttp.access",
        "aiohttp.client",
        "aiohttp.internal",
        "aiohttp.server",
        "aiohttp.web",
        "aiohttp.websocket",
    )

    # Metric group receiving one increment per counted record.
    metric_group: MetricGroup

    def __init__(self, metric_group: MetricGroup, *args, **kwargs):
        """Create the handler.

        Args:
            metric_group: Group whose sensors are incremented per record.
            *args, **kwargs: Forwarded unchanged to ``logging.Handler``
                (e.g. ``level``).
        """
        self.metric_group = metric_group
        super().__init__(*args, **kwargs)

    def emit(self, record: logging.LogRecord) -> None:
        """Increment the counter for ``record``'s logger name and level.

        Follows the ``logging.Handler`` contract: errors raised while
        recording the metric are reported through ``handleError`` instead of
        propagating into the code that issued the log call.
        """
        name = record.name

        if name in self.ignored_loggers:
            return

        try:
            self.metric_group.take(name=name, level=record.levelname).add(1)
        except RecursionError:
            # Same as the stdlib handlers: never swallow runaway recursion.
            raise
        except Exception:
            self.handleError(record)
