from typing import Any

from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert

from intranet.trip.src.db.gateways.base import DBGateway, RecordNotFound
from intranet.trip.src.db.tables import company_table, company_domain_table
from intranet.trip.src.models import Company


class CompanyGateway(DBGateway):
    """Data-access gateway for ``company_table`` and ``company_domain_table``."""

    table = company_table
    pk = 'company_id'
    out_of_company_id = 5  # id of the "outside of organizations" company
    model_class = Company

    async def get_company(self, company_id: int) -> Company:
        """Return the company whose ``company_id`` matches.

        Row lookup and any not-found handling are delegated to
        ``DBGateway._get_one``.
        """
        return await self._get_one(
            select_list=[company_table],
            where_clause=company_table.c.company_id == company_id,
        )

    async def delete(self, company_id: int) -> None:
        """Delete the company row with the given ``company_id``.

        No check is made here that a row actually matched.
        """
        query = company_table.delete().where(company_table.c.company_id == company_id)
        await self.conn.execute(query)

    async def get_all_company_ids(self) -> list[int]:
        """Return the ids of all companies (no ordering is applied)."""
        query = select([company_table.c.company_id])
        rows = await self._fetchall(query)
        return [row['company_id'] for row in rows]

    async def get_for_staffapi_sync(self, first_id: int, last_id: int) -> list[Company]:
        """Return companies whose id lies in the inclusive range ``[first_id, last_id]``.

        Only ``company_id`` and ``name`` are selected and passed to ``Company``.
        """
        query = (
            select([
                company_table.c.company_id,
                company_table.c.name,
            ])
            # BETWEEN is inclusive on both ends, same as ``>= first_id AND <= last_id``.
            .where(company_table.c.company_id.between(first_id, last_id))
        )
        rows = await self._fetchall(query)
        return [Company(**row) for row in rows]

    async def update(self, company_id: int, **fields: Any) -> int:
        """Update columns of one company and return the value reported by the base update.

        :param company_id: primary key of the company to update.
        :param fields: column name -> new value.
        :raises RecordNotFound: if the base ``update`` returns a falsy value
            (presumably meaning no row matched -- see ``DBGateway.update``).
        """
        updated_id = await super().update(company_id, **fields)
        if not updated_id:
            raise RecordNotFound('Company is not found')
        return updated_id

    async def get_companies_by_domain(self, domain: str) -> list[Company]:
        """Return the companies linked to ``domain`` in ``company_domain_table``.

        Only ``company_id`` and ``name`` are selected and passed to ``Company``.
        """
        query = (
            select([
                company_table.c.company_id,
                company_table.c.name,
            ])
            .select_from(
                # Explicit onclause: do not rely on foreign-key inference, and
                # do not repeat the join condition in WHERE.
                company_table.join(
                    company_domain_table,
                    company_domain_table.c.company_id == company_table.c.company_id,
                )
            )
            .where(company_domain_table.c.domain == domain)
        )
        rows = await self._fetchall(query)
        return [Company(**row) for row in rows]

    async def add_company_domain(self, company_id: int, domain: str) -> None:
        """Insert a ``(company_id, domain)`` row into ``company_domain_table``.

        Conflicts are not handled here, so constraint violations propagate.
        """
        query = (
            insert(company_domain_table)
            .values(company_id=company_id, domain=domain)
        )
        # The INSERT has no RETURNING clause, so execute it rather than fetch a
        # scalar (fetching from a non-row-returning result can raise).
        await self.conn.execute(query)
