from sqlalchemy import and_, func, select

from intranet.trip.src.db.gateways.base import DBGateway, RecordNotFound
from intranet.trip.src.db.subqueries.common import get_person_with_company_column
from intranet.trip.src.db.subqueries.trip import get_person_trip_column
from intranet.trip.src.db.tables import (
    person_trip_table,
    service_provider_table,
    service_table,
    trip_table,
)
from intranet.trip.src.enums import ServiceType
from intranet.trip.src.models import Service, ServiceProvider


class ServiceGateway(DBGateway):
    """Data-access gateway for rows of ``service_table`` (model: ``Service``)."""

    table = service_table
    pk = 'service_id'
    model_class = Service

    async def get_service(self, service_id: int) -> Service:
        """Fetch one service by primary key, enriched with related data.

        Besides all service columns, the row carries a ``person`` column (built by
        ``get_person_with_company_column``), the trip author id as
        ``trip_author_id``, the person-trip approval flag as
        ``is_person_trip_approved`` and the column from ``get_person_trip_column``.

        Raises whatever ``_get_one`` raises when nothing matches (presumably
        ``RecordNotFound`` — confirm in ``DBGateway``).
        """
        return await self._get_one(
            select_list=[
                service_table,
                get_person_with_company_column(correlate_table=service_table, label='person'),
                trip_table.c.author_id.label('trip_author_id'),
                person_trip_table.c.is_approved.label('is_person_trip_approved'),
                get_person_trip_column(),
            ],
            where_clause=service_table.c.service_id == service_id,
            select_from=(
                service_table
                .join(trip_table, onclause=service_table.c.trip_id == trip_table.c.trip_id)
                # person_trip must be joined explicitly: selecting ``is_approved``
                # without a join adds person_trip to FROM as a cartesian product,
                # yielding an arbitrary person's approval flag. Outer join so the
                # service is still found if its person_trip row is missing.
                .outerjoin(
                    person_trip_table,
                    onclause=and_(
                        person_trip_table.c.trip_id == service_table.c.trip_id,
                        person_trip_table.c.person_id == service_table.c.person_id,
                    ),
                )
            ),
        )

    async def get_service_by_aeroclub_id(
            self,
            provider_order_id: int,
            provider_service_id: int,
    ) -> Service:
        """Fetch one service by the provider-side (Aeroclub) order and service ids.

        The row contains all service columns plus ``person``, ``trip_author_id``
        and the ``get_person_trip_column`` column. Raises whatever ``_get_one``
        raises when nothing matches.
        """
        return await self._get_one(
            select_list=[
                service_table,
                get_person_with_company_column(correlate_table=service_table, label='person'),
                trip_table.c.author_id.label('trip_author_id'),
                get_person_trip_column(),
            ],
            where_clause=and_(
                service_table.c.provider_service_id == provider_service_id,
                service_table.c.provider_order_id == provider_order_id,
            ),
            select_from=(
                service_table
                .join(trip_table, onclause=service_table.c.trip_id == trip_table.c.trip_id)
            ),
        )

    async def update(self, service_id: int, **fields) -> int:
        """Update columns of a single service.

        Args:
            service_id: primary key of the service to update.
            **fields: column name -> new value, passed to ``.values()``.

        Returns:
            The id of the updated row, taken from ``RETURNING``.

        Raises:
            RecordNotFound: if no service has this id.
        """
        query = (
            service_table
            .update()
            .where(service_table.c.service_id == service_id)
            .values(**fields)
            .returning(service_table.c.service_id)
        )
        updated_id = await self.conn.scalar(query)
        # ``scalar`` gives None when RETURNING produced no row; compare with None
        # explicitly so a falsy-but-valid id is not reported as missing.
        if updated_id is None:
            raise RecordNotFound('Service is not found')
        return updated_id

    async def bulk_update(self, service_ids: list[int], **fields):
        """Apply the same column values to every service in ``service_ids``.

        Ids that do not exist are silently ignored. An empty id list is a no-op
        and does not touch the database.
        """
        if not service_ids:
            return
        query = (
            service_table
            .update()
            .where(service_table.c.service_id.in_(service_ids))
            .values(**fields)
        )
        await self.conn.execute(query)

    async def delete(self, service_id: int):
        """Delete the service with this id; no error if it does not exist."""
        query = service_table.delete().where(service_table.c.service_id == service_id)
        await self.conn.execute(query)

    async def get_count_group_by_status(self) -> list[tuple[str, int]]:
        """Return ``(status, count)`` pairs for every status present in the table.

        The ``status`` column is read as an enum member and reported by its
        ``.value``.
        """
        query = (
            select([
                service_table.c.status,
                func.count().label('count_service'),
            ])
            .group_by(service_table.c.status)
        )
        rows = await self._fetchall(query)
        return [(row['status'].value, row['count_service']) for row in rows]

    async def get_broken_services_count(self) -> int:
        """Return the number of services whose ``is_broken`` flag is true."""
        query = (
            select([func.count(service_table.c.service_id)])
            .where(service_table.c.is_broken.is_(True))
        )
        return await self.conn.scalar(query)


class ServiceProviderGateway(DBGateway):
    """Data-access gateway for rows of ``service_provider_table``."""

    table = service_provider_table
    model_class = ServiceProvider

    # Escape character for LIKE patterns built from user input (the same
    # character SQLAlchemy's ``autoescape`` option uses).
    _LIKE_ESCAPE = '/'

    @classmethod
    def _escape_like(cls, value: str) -> str:
        """Escape the escape char, ``%`` and ``_`` so *value* matches literally."""
        esc = cls._LIKE_ESCAPE
        return (
            value
            .replace(esc, esc + esc)  # must come first so later escapes survive
            .replace('%', esc + '%')
            .replace('_', esc + '_')
        )

    async def find_provider(
        self,
        search_query: str,
        service_type: ServiceType = None,
    ) -> list[ServiceProvider]:
        """Search providers whose code, name or English name contains the query.

        Matching is case-insensitive (ILIKE). Wildcard characters in
        ``search_query`` are matched literally. If ``service_type`` is given,
        only providers of that type are returned.
        """
        # Escape user input so ``%``/``_`` typed by the user are not wildcards.
        pattern = f'%{self._escape_like(search_query)}%'
        escape = self._LIKE_ESCAPE
        query = (
            select([service_provider_table])
            .where(
                service_provider_table.c.code.ilike(pattern, escape=escape)
                | service_provider_table.c.name.ilike(pattern, escape=escape)
                | service_provider_table.c.name_en.ilike(pattern, escape=escape)
            )
        )
        if service_type is not None:
            query = query.where(
                service_provider_table.c.service_type == service_type,
            )
        return [ServiceProvider(**item) for item in await self._fetchall(query)]
