import logging
import typing
from dataclasses import dataclass

from aiohttp.web import middleware, Request, Response
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized

from crm.agency_cabinet.common.request_id_utils import USER_ID_VAR
from crm.agency_cabinet.common.server.common.tvm import TvmClient

# Module logger under the fixed name 'middlewares.tvm'.
# NOTE(review): not referenced by the middleware classes below as written — confirm whether it is still needed.
LOGGER = logging.getLogger('middlewares.tvm')


@dataclass
@middleware
class TvmServiceMiddleware:
    """aiohttp middleware admitting only requests that carry an accepted TVM service ticket.

    The ticket is taken from the ``X-Ya-Service-Ticket`` header and parsed with
    ``tvm_client``; the parsed ticket's ``src`` must be listed in ``allowed_clients``.

    Attributes:
        tvm_client: client used to parse incoming service tickets.
        allowed_clients: source ids whose requests are let through.
        development_mode: when true, a ``No-Check-Service-Ticket`` header skips every check.
    """

    tvm_client: TvmClient
    allowed_clients: typing.List[int]
    development_mode: bool = False

    async def __call__(self, request: Request, handler: typing.Callable[[Request], typing.Awaitable[Response]]) -> Response:
        """Check the service ticket of ``request`` and, if accepted, delegate to ``handler``.

        Raises:
            HTTPUnauthorized: the ticket header is absent, or the TVM client returned no parsed ticket.
            HTTPForbidden: the parsed ticket's source is not in ``allowed_clients``.
        """
        bypass_requested = self.development_mode and 'No-Check-Service-Ticket' in request.headers

        if not bypass_requested:
            raw_ticket = request.headers.get('X-Ya-Service-Ticket')
            if raw_ticket is None:
                raise HTTPUnauthorized(reason='Header X-Ya-Service-Ticket not found')

            service_ticket = await self.tvm_client.parse_service_ticket(raw_ticket)
            if service_ticket is None:
                raise HTTPUnauthorized(reason='Empty parsed TVM service ticket')

            source_id = service_ticket.src
            if source_id not in self.allowed_clients:
                raise HTTPForbidden(reason=f'Unknown source: {source_id}')

        return await handler(request)


@dataclass
@middleware
class TvmUserMiddleware:
    """aiohttp middleware that identifies the end user of a request from a TVM user ticket.

    The resolved uid is stored in ``request['yandex_uid']`` and set on ``USER_ID_VAR``
    while the downstream handler runs. ``USER_ID_VAR`` is reset with the token returned
    by ``set`` once the handler finishes or raises.

    Attributes:
        tvm_client: client used to parse the ``X-Ya-User-Ticket`` header.
        development_mode: when true, a ``No-Check-User-Ticket`` header carrying an integer
            uid is trusted as-is and user-ticket validation is skipped. Outside development
            mode that header is ignored, matching how ``TvmServiceMiddleware`` gates its
            ``No-Check-Service-Ticket`` bypass.
    """

    tvm_client: TvmClient
    development_mode: bool = False

    async def __call__(self, request: Request, handler: typing.Callable[[Request], typing.Awaitable[Response]]) -> Response:
        """Resolve the user's uid, expose it to ``handler`` and call it.

        Raises:
            HTTPUnauthorized: the ``X-Ya-User-Ticket`` header is missing, the TVM client
                returned no parsed ticket, or (development mode only) the
                ``No-Check-User-Ticket`` value is not an integer.
        """
        yandex_uid = await self._resolve_uid(request)
        request['yandex_uid'] = yandex_uid

        token = USER_ID_VAR.set(yandex_uid)
        try:
            return await handler(request)
        finally:
            # Restore the previous value so this uid does not outlive the request handling.
            USER_ID_VAR.reset(token)

    async def _resolve_uid(self, request: Request):
        """Return the uid for ``request``, from the dev bypass header or the validated user ticket."""
        # The bypass lets a caller impersonate any uid, so it must only be honored in
        # development mode (previously it was accepted unconditionally).
        if self.development_mode and 'No-Check-User-Ticket' in request.headers:
            raw_uid = request.headers['No-Check-User-Ticket']
            try:
                return int(raw_uid)
            except ValueError as err:
                # A malformed override is a client error; do not let it surface as a 500.
                raise HTTPUnauthorized(reason='Header No-Check-User-Ticket is not an integer uid') from err

        ticket = request.headers.get('X-Ya-User-Ticket')
        if ticket is None:
            raise HTTPUnauthorized(reason='Header X-Ya-User-Ticket not found')

        parsed_ticket = await self.tvm_client.parse_user_ticket(ticket)
        if parsed_ticket is None:
            raise HTTPUnauthorized(reason='X-Ya-User-Ticket is invalid')

        return parsed_ticket.default_uid
