from types import SimpleNamespace
from typing import Optional

from aiohttp import ClientSession, TraceConfig, TraceRequestStartParams, TraceRequestEndParams, TraceRequestExceptionParams

import asyncio
import re
import yarl

from mail.python.theatre.logging.log_adapter import CustomAdapter


# Despite the "ALLOWED_CHARS" name, this pattern matches runs of characters that
# are NOT allowed in a unistat signal name (note the negated class ``[^...]``):
# anything other than ASCII letters, ASCII digits, '.', '-', '/', '@' and '_'.
# normalize_unistat_signal_name() replaces each such run with a single '_'.
_UNISTAT_SIGNAL_NAME_ALLOWED_CHARS_RE = re.compile(r'[^a-zA-Z0-9\.\-/@_]+')


def make_http_logger(logger):
    """Wrap *logger* in a ``CustomAdapter`` set up for HTTP request records.

    The adapter receives the names and types of the extra fields emitted per
    request: ``host``, ``port``, ``uri``, ``status`` and ``total_time``.
    NOTE(review): how ``CustomAdapter`` uses the type map (coercion, validation,
    formatting) is defined elsewhere; confirm there.

    :param logger: the underlying logger to adapt.
    :return: the configured ``CustomAdapter``.
    """
    http_fields = dict(
        host=str,
        port=int,
        uri=str,
        status=int,
        total_time=float,
    )
    return CustomAdapter(logger, http_fields)


class ProfiledClientSession(ClientSession):
    """``ClientSession`` that records metrics and a log line for every request.

    Hooks are attached through an aiohttp ``TraceConfig``. For each request:

    * on a response, ``metrics.increase_global_meter`` is called with
      ``http_<host>_<N>xx_summ`` where ``N`` is ``status // 100``; when the
      request raises, it is called with ``http_<host>_exception_summ``;
    * the elapsed time in milliseconds is passed to ``metrics.put_in_hist``
      under ``http_<host>_ms``;
    * ``log`` is called with the URL, the status (``None`` on exception) and
      the elapsed time in seconds.

    All signal names are passed through ``normalize_unistat_signal_name``.

    NOTE(review): aiohttp 3.x warns (DeprecationWarning) when ClientSession is
    subclassed; wrapping a session instead may be preferable long term.
    """

    def __init__(self, metrics, logger, *args, **kwargs):
        """Create the session and register the request-tracing hooks.

        :param metrics: object providing ``increase_global_meter(name)`` and
            ``put_in_hist(name, value)``.
        :param logger: base logger; wrapped with ``make_http_logger``.
        :param args: forwarded positionally to ``ClientSession``.
        :param kwargs: forwarded to ``ClientSession``. A caller-supplied
            ``trace_configs`` is kept, and this session's own trace config is
            appended to it.
        """
        self.logger = make_http_logger(logger)
        self.metrics = metrics
        self.trace_config = TraceConfig()
        self.trace_config.on_request_start.append(self.on_request_start)
        self.trace_config.on_request_end.append(self.on_request_end)
        self.trace_config.on_request_exception.append(self.on_request_exception)
        # Passing trace_configs explicitly used to raise TypeError ("got multiple
        # values for keyword argument"); merge the caller's configs with ours.
        trace_configs = list(kwargs.pop('trace_configs', None) or ())
        trace_configs.append(self.trace_config)
        super(ProfiledClientSession, self).__init__(*args, trace_configs=trace_configs, **kwargs)

    @staticmethod
    def _elapsed_since_start(trace_ctx: SimpleNamespace) -> float:
        """Return event-loop clock seconds elapsed since ``on_request_start`` stamped *trace_ctx*."""
        return asyncio.get_running_loop().time() - trace_ctx.start

    async def on_request_start(self, session: ClientSession, trace_ctx: SimpleNamespace,
                               params: TraceRequestStartParams) -> None:
        """Stamp *trace_ctx* with the event-loop clock so later hooks can compute latency."""
        trace_ctx.start = asyncio.get_running_loop().time()

    async def on_request_end(self, session: ClientSession, trace_ctx: SimpleNamespace,
                             params: TraceRequestEndParams) -> None:
        """Count the response by status class, record its latency and log it."""
        # Read the clock first so metric/logging overhead is not counted as latency.
        time_s = self._elapsed_since_start(trace_ctx)
        status = params.response.status
        host = params.url.host
        self.metrics.increase_global_meter(normalize_unistat_signal_name(f'http_{host}_{status // 100}xx_summ'))
        self.metrics.put_in_hist(normalize_unistat_signal_name(f'http_{host}_ms'), time_s * 1000)

        self.log(url=params.url, status=status, total_time=time_s)

    async def on_request_exception(self, session: ClientSession, trace_ctx: SimpleNamespace,
                                   params: TraceRequestExceptionParams) -> None:
        """Count the failed request, record its latency and log the exception."""
        time_s = self._elapsed_since_start(trace_ctx)
        host = params.url.host
        self.metrics.increase_global_meter(normalize_unistat_signal_name(f'http_{host}_exception_summ'))
        self.metrics.put_in_hist(normalize_unistat_signal_name(f'http_{host}_ms'), time_s * 1000)

        exc = params.exception
        # str() is empty for exceptions raised without arguments (e.g.
        # asyncio.TimeoutError()); fall back to the class name so the log line
        # is never blank.
        self.log(url=params.url, status=None, total_time=time_s, msg=str(exc) or type(exc).__name__)

    def log(self, url: yarl.URL, status: Optional[int], total_time: float, msg: str = 'success') -> None:
        """Emit one structured log record for a finished request.

        Requests that raised (``status is None``) and 5xx responses are logged
        at ERROR level; everything else is logged at INFO level.

        :param url: request URL; its host and port are logged as separate fields
            and the URL object itself is passed as ``uri``.
        :param status: HTTP status code, or ``None`` when the request raised.
        :param total_time: elapsed time in seconds.
        :param msg: log message.
        """
        is_error = status is None or status // 100 == 5
        emit = self.logger.error if is_error else self.logger.info
        emit(msg=msg, host=url.host, port=url.port, uri=url, status=status, total_time=total_time)


def normalize_unistat_signal_name(signal_name: str) -> str:
    """Turn an arbitrary string into a name usable as a unistat signal.

    Each run of characters other than ASCII letters, ASCII digits, ``.``,
    ``-``, ``/``, ``@`` and ``_`` becomes a single ``_``. Consecutive
    underscores are then squeezed into one. Finally, a leading underscore is
    dropped; after squeezing, at most one can be present. Trailing underscores
    are kept.

    >>> normalize_unistat_signal_name('http_api.example.com:8080_2xx_summ')
    'http_api.example.com_8080_2xx_summ'
    """
    cleaned = _UNISTAT_SIGNAL_NAME_ALLOWED_CHARS_RE.sub('_', signal_name)
    squeezed = re.sub('_+', '_', cleaned)
    return squeezed[1:] if squeezed.startswith('_') else squeezed
