import functools
import logging
from typing import Collection, Optional

from aiohttp.web import Request, Response, middleware

from smb.common.aiotvm import CheckTicketFails, TvmClient

# Public API of this module: the aiohttp auth middleware and the per-handler
# decorators. TvmChecker is not listed here, so `import *` does not export it.
__all__ = ["Middleware", "tvm_auth"]


class tvm_auth:
    """Handler decorators that choose which TVM check ``Middleware`` runs.

    Each decorator stores a checker coroutine on the handler under the
    ``tvm_checker`` attribute and returns the handler itself unchanged.
    """

    @staticmethod
    def skip(func):
        """Exempt ``func`` from TVM authorization entirely."""
        setattr(func, "tvm_checker", TvmChecker.skip_check)
        return func

    @staticmethod
    def only(whitelist_name: str):
        """Restrict a handler to TVM ids listed under ``config[whitelist_name]``.

        Returns a decorator. Because the handler carries an explicit checker,
        ``Middleware`` treats the named whitelist as required.
        """

        def decorator(func):
            # Bind the whitelist name now; the middleware supplies the
            # request, client and config as keyword arguments later.
            setattr(
                func,
                "tvm_checker",
                functools.partial(
                    TvmChecker.check_by_whitelist, whitelist_name=whitelist_name
                ),
            )
            return func

        return decorator


class TvmChecker:
    """Coroutines that decide whether a request passes TVM authorization.

    Every check returns ``None`` when the request is allowed, or the HTTP
    status code (401 or 403) the caller should reply with.
    """

    @classmethod
    async def skip_check(cls, *args, **kwargs):
        """Allow every request; any arguments are accepted and ignored."""
        return None

    @classmethod
    async def check_by_whitelist(
        cls,
        whitelist_name: str,
        request: Request,
        tvm_client: TvmClient,
        config: dict,
        is_whitelist_required: bool = True,
    ) -> Optional[int]:
        """Authorize ``request`` against the whitelist ``config[whitelist_name]``.

        :param whitelist_name: config key holding the allowed TVM source ids.
        :param request: incoming request carrying the service ticket header.
        :param tvm_client: client used to resolve the ticket to a source id.
        :param config: configuration mapping the whitelist is read from.
        :param is_whitelist_required: if true, a missing whitelist denies the
            request; if false, a missing whitelist allows it.
        :return: ``None`` if allowed, otherwise the HTTP status to reply with.
            A whitelist that is present but empty always yields 403.
        """
        tvm_whitelist = config.get(whitelist_name, None)

        if tvm_whitelist:
            return await cls.auth_with_tvm(tvm_client, request, tvm_whitelist)

        logger = logging.getLogger(__name__)

        if tvm_whitelist is not None:
            # Configured but empty (or some other falsy value such as 0 or
            # False): fail closed. An `is not None` test is used instead of
            # len() so falsy values without a length cannot raise TypeError.
            logger.warning(
                "Tvm whitelist %s is empty. Auth failed.", whitelist_name
            )
            return 403

        if is_whitelist_required:
            logger.warning(
                "Missed tvm auth config %s. Auth failed.", whitelist_name
            )
            return 403

        # Optional whitelist that is not configured: let the request through.
        return None

    @staticmethod
    async def auth_with_tvm(
        tvm_client: TvmClient, request: Request, tvm_whitelist: Collection[int]
    ) -> Optional[int]:
        """Validate the request's service ticket and check its source id.

        :return: 401 if the ``X-Ya-Service-Ticket`` header is absent; 403 if
            ``tvm_client`` raises ``CheckTicketFails`` or the resolved source
            id is not in ``tvm_whitelist``; otherwise ``None``.
        """
        # A single lookup replaces the `in` check followed by indexing.
        ticket = request.headers.get("X-Ya-Service-Ticket")
        if ticket is None:
            return 401

        try:
            tvm_id = await tvm_client.fetch_service_source_id(ticket=ticket)
        except CheckTicketFails:
            return 403

        if tvm_id not in tvm_whitelist:
            return 403

        return None


class Middleware:
    """aiohttp middleware enforcing TVM service-ticket auth per route.

    A handler's own ``tvm_checker`` attribute (set by ``tvm_auth``) decides
    how it is checked. Handlers without one fall back to the optional
    ``TVM_WHITELIST`` config entry. Requests whose resource ``path`` is in
    ``urls_whitelist`` bypass authorization entirely.
    """

    # Exact resource paths that are never authenticated.
    urls_whitelist = ("/ping", "/sensors/")
    tvm_client: TvmClient
    config: dict

    def __new__(cls, *args, **kwargs):
        # Pass the fresh instance through aiohttp's `middleware` marker.
        # NOTE(review): this relies on `middleware` returning the same
        # object, so that `__init__` still runs on it.
        return middleware(super().__new__(cls))

    def __init__(self, tvm_client: TvmClient, config: dict):
        self.tvm_client = tvm_client
        self.config = config

    async def __call__(self, request: Request, handler) -> Response:
        """Authorize ``request``.

        Returns an empty response carrying the check's error status, or the
        result of ``handler`` when the check passes.
        """
        match_info = request.match_info
        if match_info.get_info().get("path") in self.urls_whitelist:
            return await handler(request)

        checker = getattr(match_info.handler, "tvm_checker", None)
        if checker:
            # The handler chose its own check explicitly, so any whitelist
            # it names must be configured.
            whitelist_required = True
        else:
            # Undecorated handler: use the service-wide whitelist, if any.
            checker = functools.partial(
                TvmChecker.check_by_whitelist, whitelist_name="TVM_WHITELIST"
            )
            whitelist_required = False

        status = await checker(
            tvm_client=self.tvm_client,
            config=self.config,
            request=request,
            is_whitelist_required=whitelist_required,
        )
        return Response(status=status) if status else await handler(request)
