from typing import Optional

from fastapi import Request

from intranet.trip.src.api.schemas import PaginatedResponse
from intranet.trip.src.cache import Cache
from intranet.trip.src.config import settings
from intranet.trip.src.enums import Provider
from intranet.trip.src.exceptions import PermissionDenied
from intranet.trip.src.logic.providers import create_provider_gateway
from intranet.trip.src.unit_of_work import UnitOfWork


if not settings.ENABLE_PYSCOPG2:
    # Without the psycopg2 flag a single module-level dependency is shared by
    # every get_unit_of_work() call; read_only is accepted only for signature
    # compatibility and is not forwarded to db.acquire().
    async def _get_unit_of_work(request: Request):
        """Yield a UnitOfWork over a connection acquired from ``request.app.state.db``."""
        app_state = request.app.state
        async with app_state.db.acquire() as connection:
            yield UnitOfWork(conn=connection, redis=app_state.redis)

    def get_unit_of_work(read_only: bool = False):
        """Return the shared unit-of-work dependency; ``read_only`` is ignored in this mode."""
        return _get_unit_of_work
else:
    def get_unit_of_work(read_only: bool = False):
        """Return a dependency that yields a UnitOfWork for the current request.

        ``read_only`` is forwarded to ``request.app.state.db.acquire``; what it
        means is up to that pool implementation. A new dependency function is
        created on every call.
        """
        async def _get_unit_of_work(request: Request):
            app_state = request.app.state
            async with app_state.db.acquire(read_only=read_only) as connection:
                yield UnitOfWork(conn=connection, redis=app_state.redis)

        return _get_unit_of_work


async def get_provider_gateway(request: Request, provider: Optional[Provider] = None):
    """Create a provider gateway for the user attached to the current request.

    :param request: incoming request; ``request.state.user`` is passed through.
    :param provider: optional provider selector, passed as-is (may be None).
    :returns: whatever ``create_provider_gateway`` returns for that user/provider.
    """
    return create_provider_gateway(user=request.state.user, provider=provider)


async def get_cache(request: Request):
    """Yield a Cache built around the application's ``app.state.redis`` client."""
    redis_client = request.app.state.redis
    cache = Cache(redis_client)
    yield cache


def only_for_tvm(tvm_service: str):
    """Build a dependency that only admits requests whose TVM service ticket comes from ``tvm_service``.

    :param tvm_service: key looked up in ``settings.tvm_services`` to get the
        expected ticket source id.
    :returns: an async dependency that returns None when the request is allowed.

    The returned dependency raises ``PermissionDenied`` when the ticket is missing
    or empty, or when its ``src`` differs from the configured id. All checks are
    skipped when ``settings.ENV_TYPE == 'development'``.
    """

    async def dependency(request: Request):
        if settings.ENV_TYPE == 'development':
            return

        # A ticket that was never attached to request.state is treated like an
        # empty one, so the request is denied instead of failing with AttributeError.
        service_ticket = getattr(request.state, 'service_ticket', None)
        if not service_ticket:
            raise PermissionDenied(log_message='Service ticket is empty')
        if service_ticket.src != settings.tvm_services[tvm_service]:
            raise PermissionDenied(log_message='Service ticket src is not valid')

    return dependency


class Pagination:
    """Offset/limit pagination driven by ``page`` and ``limit`` query parameters.

    ``default_page`` is both the page used when none is requested and the index
    of the first page (1 here; subclasses may lower it, e.g. to 0). The offset and
    the "previous" link are computed relative to it, so the first page always has
    offset 0. ``default_limit`` is the page size used when none is given.
    """

    default_page = 1
    default_limit = 20

    def __init__(self, request: Request, page: Optional[int] = None, limit: Optional[int] = None):
        self.request = request
        # Missing pages, and pages before the first one (e.g. 0 with one-based
        # numbering, or negatives), fall back to default_page.
        if page is None or page < self.default_page:
            page = self.default_page
        # Missing, zero or negative limits fall back to default_limit, so the
        # resulting LIMIT/OFFSET are never negative.
        if limit is None or limit <= 0:
            limit = self.default_limit
        self.page = page
        self.limit = limit
        self.offset = (self.page - self.default_page) * self.limit

    def _page_url(self, page: int) -> str:
        """Return the current request URL with ``limit`` and ``page`` query params set."""
        return str(self.request.url.include_query_params(limit=self.limit, page=page))

    def _get_next(self, count: int) -> Optional[str]:
        """Return the next page URL, or None if this page already reaches ``count`` items."""
        if self.offset + self.limit >= count:
            return None
        return self._page_url(self.page + 1)

    def _get_previous(self) -> Optional[str]:
        """Return the previous page URL, or None on the first page."""
        if self.page <= self.default_page:
            return None
        return self._page_url(self.page - 1)

    def get_paginated_response(self, data: list, count: int) -> PaginatedResponse:
        """Wrap ``data`` (one page of items) and the total ``count`` with paging links."""
        return PaginatedResponse(
            data=data,
            count=count,
            page=self.page,
            limit=self.limit,
            next=self._get_next(count),
            previous=self._get_previous(),
        )


class ZeroBasedPagination(Pagination):
    """Pagination variant whose page numbering starts at 0 instead of 1.

    Only ``default_page`` is overridden; all other behavior is inherited from
    ``Pagination``.
    """

    default_page: int = 0
