from sqlalchemy import select, and_

from intranet.trip.src.db.gateways.base import DBGateway, RecordNotFound
from intranet.trip.src.db.tables import (
    person_table,
    person_document_table,
    person_bonus_card_table,
    service_provider_table,
)
from intranet.trip.src.enums import DocumentType, ServiceType
from intranet.trip.src.models import (
    Document,
    BonusCard,
)


class PersonDocumentGateway(DBGateway):
    """Gateway for reading and modifying rows of ``person_document_table``.

    All read methods defined here skip rows whose ``is_deleted`` flag is set.
    """

    table = person_document_table
    pk = 'document_id'
    model_class = Document

    async def update(self, document_id: int, **fields) -> None:
        """Write ``fields`` into the row with the given ``document_id``.

        :param document_id: primary key of the document row to change.
        :param fields: column name/value pairs forwarded to ``.values()``.
        """
        matches_id = person_document_table.c.document_id == document_id
        statement = (
            person_document_table.update().where(matches_id).values(**fields)
        )
        await self.conn.execute(statement)

    async def delete(self, document_id: int) -> None:
        """Soft-delete a document: only its ``is_deleted`` flag is set."""
        await self.update(document_id, is_deleted=True)

    async def get_document(self, document_id: int) -> Document:
        """Fetch a single non-deleted document by primary key.

        The lookup is delegated to the inherited ``_get_one`` helper, which
        also decides what happens when no row matches.

        :param document_id: primary key of the requested document.
        """
        doc = person_document_table.c
        alive_with_id = and_(
            doc.document_id == document_id,
            doc.is_deleted.is_(False),
        )
        return await self._get_one(
            select_list=[person_document_table],
            where_clause=alive_with_id,
        )

    async def get_documents(self, person_id: int) -> list[Document]:
        """Return the non-deleted documents of a person, ordered by id.

        :param person_id: primary key of the person owning the documents.
        :raises RecordNotFound: if ``person_table`` has no such person.
        """
        # Check the person first so that "unknown person" is reported as an
        # error instead of looking like a person without documents.
        person_lookup = select([person_table.c.person_id]).where(
            person_table.c.person_id == person_id
        )
        found_person_id = await self.conn.scalar(person_lookup)
        if not found_person_id:
            raise RecordNotFound('Person is not found')

        doc = person_document_table.c
        documents_query = (
            select([person_document_table])
            .where(doc.person_id == found_person_id)
            .where(doc.is_deleted.is_(False))
            .order_by('document_id')
        )
        return [
            Document(**row)
            for row in await self._fetchall(documents_query)
        ]

    async def get_all_passports_for_staff_sync(self) -> list[Document]:
        """Return every non-deleted passport and external passport.

        Rows are ordered by ``person_id``, then ``series``, then ``number``.
        """
        doc = person_document_table.c
        passport_types = [
            DocumentType.passport,
            DocumentType.external_passport,
        ]
        passports_query = (
            select([person_document_table])
            .where(doc.document_type.in_(passport_types))
            .where(doc.is_deleted.is_(False))
            .order_by(doc.person_id, doc.series, doc.number)
        )
        rows = await self._fetchall(passports_query)
        return [Document(**row) for row in rows]


class PersonBonusCardGateway(DBGateway):
    """Gateway for reading bonus cards from ``person_bonus_card_table``.

    Every query joins ``service_provider_table`` so that the provider's
    ``name`` and ``name_en`` columns are selected next to the card columns.
    """

    table = person_bonus_card_table
    pk = 'bonus_card_id'
    model_class = BonusCard

    @staticmethod
    def _select_with_provider():
        """Build the SELECT shared by all bonus card lookups."""
        return (
            select([
                person_bonus_card_table,
                service_provider_table.c.name,
                service_provider_table.c.name_en,
            ])
            .select_from(
                person_bonus_card_table.join(service_provider_table)
            )
        )

    async def get_bonus_card(self, bonus_card_id: int) -> BonusCard:
        """Fetch one bonus card together with its provider names.

        :param bonus_card_id: primary key of the requested bonus card.
        :raises RecordNotFound: if the query produced no row.
        """
        query = self._select_with_provider().where(
            person_bonus_card_table.c.bonus_card_id == bonus_card_id
        )
        record = await self._first(query)
        if record is None:
            # NOTE(review): assumes ``_first`` yields None when nothing
            # matches - confirm against DBGateway. Without this guard
            # ``BonusCard(**None)`` fails with an unrelated TypeError
            # instead of the not-found error used elsewhere in this module.
            raise RecordNotFound('Bonus card is not found')
        return BonusCard(**record)

    async def get_person_bonus_cards(
        self,
        person_id: int,
        service_type: ServiceType = None,
    ) -> list[BonusCard]:
        """Return the bonus cards of a person, with provider names.

        :param person_id: primary key of the person owning the cards.
        :param service_type: when given (truthy), only cards whose
            ``service_provider_type`` equals it are returned.
        """
        query = self._select_with_provider().where(
            person_bonus_card_table.c.person_id == person_id
        )
        if service_type:
            query = query.where(
                person_bonus_card_table.c.service_provider_type == service_type,
            )
        rows = await self._fetchall(query)
        return [BonusCard(**row) for row in rows]
