import logging

from starlette.middleware.base import BaseHTTPMiddleware, Response, RequestResponseEndpoint
from starlette.responses import JSONResponse
from starlette.requests import Request
from starlette.types import ASGIApp

from intranet.library.fastapi_csrf.src.token import check_csrf_token


logger = logging.getLogger(__name__)


SAFE_HTTP_METHODS = {'GET', 'HEAD', 'OPTIONS', 'TRACE'}
DEFAULT_HEADER_NAME = 'X-Csrf-Token'
DEFAULT_TOKEN_LIFETIME = 24 * 60 * 60  # 1 day


class CsrfMiddleware(BaseHTTPMiddleware):
    """
    Middleware для проверки csrf токена

    Для корректной работы нужно наличие атрибута user с полем uid в request.state
    (должна устанавливаться в auth middleware вашего сервиса)

    По умолчанию пытается получить токен из заголовка X-Csrf-Token.
    При желании можно либо заменить заголовок на свой, установив header_name при добавлении
    middleware в приложение, либо получать токен из куки, указанной в cookie_name.
    Если передан cookie_name - токен будет всегда получаться из куки, даже если
    одновременно с этим передается header_name.
    """
    def __init__(
        self,
        app: ASGIApp,
        secret_key: str,
        token_lifetime: int = DEFAULT_TOKEN_LIFETIME,
        cookie_name: str = None,
        header_name: str = DEFAULT_HEADER_NAME,
        **kwargs,
    ):
        self.secret_key = secret_key
        self.token_lifetime = token_lifetime
        self.cookie_name = cookie_name
        self.header_name = header_name
        super().__init__(app, **kwargs)

    def get_token_from_request(self, request: Request) -> str:
        if self.cookie_name is not None:
            return request.cookies.get(self.cookie_name)
        if self.header_name is not None:
            return request.headers.get(self.header_name)
        raise NotImplementedError()

    @staticmethod
    def get_user_from_request(request: Request):
        if hasattr(request.state, 'user') and hasattr(request.state.user, 'uid'):
            return request.state.user
        if 'user' in request.scope and hasattr(request.user, 'uid'):
            return request.user
        return None

    @staticmethod
    def is_url_exempt(request: Request) -> bool:
        for regex in request.app.state.csrf_exempt_endpoints:
            path = request.url.path
            if regex.match(path):
                return True

            alternative_path = path[:-1] if path.endswith('/') else path + '/'
            if regex.match(alternative_path):
                return True
        return False

    def should_validate_token(self, request: Request) -> bool:
        return (
            request.cookies.get('Session_id')
            and request.method not in SAFE_HTTP_METHODS
            and not self.is_url_exempt(request)
        )

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        if self.should_validate_token(request):
            user = self.get_user_from_request(request)
            uid = user.uid if user is not None and getattr(user, 'uid', None) else None
            if not check_csrf_token(
                uid=uid,
                yandex_uid=request.cookies.get('yandexuid'),
                token_lifetime=self.token_lifetime,
                secret_key=self.secret_key,
                csrf_token=self.get_token_from_request(request),
            ):
                logger.info('CSRF validation failed for uid %s', uid)
                return JSONResponse(status_code=403, content={'detail': 'Forbidden'})
        return await call_next(request)
