from sqlalchemy import select, func, distinct, or_

from intranet.trip.src.config import settings
from intranet.trip.src.db.gateways.base import DBGateway, RecordNotFound
from intranet.trip.src.db.subqueries.common import get_company_column
from intranet.trip.src.db.subqueries.person import (
    get_person_passports_column,
    get_person_bonus_cards_column,
    get_ext_person_passports_column,
    get_ext_person_documents_column,
    get_ext_person_bonus_cards_column,
)
from intranet.trip.src.db.tables import (
    person_table,
    person_document_table,
    person_trip_table,
    person_relationship_table,
    person_bonus_card_table,
    company_table,
    ext_person_table,
)
from intranet.trip.src.enums import PersonRole
from intranet.trip.src.lib.utils import paginate
from intranet.trip.src.models import (
    User,
    Person,
    PersonRelationship,
    ExtPerson,
)


class PersonGateway(DBGateway):
    """Data access for ``person_table``, including IDM role flags on persons."""

    table = person_table
    pk = 'person_id'
    model_class = Person

    async def get_user(self, person_id: int) -> User:
        """Return the person with ``person_id`` as a ``User`` (plus company column).

        NOTE(review): unlike ``get_user_by_uid`` this passes no explicit
        ``select_from`` outer join to ``company_table`` — confirm intended.
        """
        return await self._get_one(
            select_list=[
                person_table,
                get_company_column(),
            ],
            where_clause=person_table.c.person_id == person_id,
            model_class=User,
        )

    async def get_user_by_uid(self, uid: str) -> User:
        """Return the person with the given ``uid`` as a ``User`` (plus company column)."""
        return await self._get_one(
            select_list=[
                person_table,
                get_company_column(),
            ],
            where_clause=person_table.c.uid == uid,
            model_class=User,
            select_from=person_table.outerjoin(company_table),  # company_id is nullable
        )

    async def is_person_exists(self, uid: str) -> bool:
        """Return True if a person row with the given ``uid`` exists."""
        query = (
            select([person_table.c.person_id])
            .where(person_table.c.uid == uid)
        )
        obj_id = await self.conn.scalar(query)
        return obj_id is not None

    async def get_person(self, person_id: int) -> Person:
        """Return the person with ``person_id`` (plus company column)."""
        return await self._get_one(
            select_list=[
                person_table,
                get_company_column(),
            ],
            where_clause=person_table.c.person_id == person_id,
        )

    async def search_persons(
        self,
        text: str,
        limit: int,
        holding_id: int,
        person_id: int = None,
    ) -> list[Person]:
        """Case-insensitive substring search over non-dismissed persons.

        ``text`` is matched against login and first/last names (both the plain
        and the ``_en`` columns). Only persons whose ``company_id`` matches a
        company row are returned.

        :param text: search string, treated literally (LIKE wildcards escaped)
        :param limit: maximum number of rows to return
        :param holding_id: if truthy, restrict to companies of this holding
        :param person_id: if truthy, restrict to this single person
        """
        text = text.lower()
        search_fields = (
            person_table.c.login,
            person_table.c.first_name,
            person_table.c.first_name_en,
            person_table.c.last_name,
            person_table.c.last_name_en,
        )
        where_parts = [
            # autoescape=True makes '%' and '_' in user input match literally
            # instead of acting as LIKE wildcards.
            func.lower(field.collate(settings.POSTGRES_COLLATION)).contains(text, autoescape=True)
            for field in search_fields
        ]
        query = (
            select([
                person_table,
                company_table,
            ])
            .where(person_table.c.is_dismissed.is_(False))
            .where(person_table.c.company_id == company_table.c.company_id)
            .where(or_(*where_parts))
        )
        if holding_id:
            query = query.where(company_table.c.holding_id == holding_id)
        if person_id:
            query = query.where(person_table.c.person_id == person_id)
        query = query.limit(limit)
        rows = await self._fetchall(query)
        return [Person(**item) for item in rows]

    async def get_person_by_aeroclub_profile(self, profile_id: int) -> Person:
        """Return the person whose ``provider_profile_id`` equals ``profile_id``."""
        return await self._get_one(
            select_list=[
                person_table,
                get_company_column(),
            ],
            where_clause=person_table.c.provider_profile_id == profile_id,
        )

    async def get_provider_profile_id_by_uid(self, person_uid: str) -> int:
        """Return ``provider_profile_id`` of the person with ``person_uid``.

        Presumably None when there is no such person or the column is NULL.
        """
        query = (
            select([person_table.c.provider_profile_id])
            .where(person_table.c.uid == person_uid)
        )
        return await self.conn.scalar(query)

    async def get_person_ids_for_hub_sync(self, profile_ids: list[int] = None) -> list[int]:
        """Return person_ids of persons that have ``provider_profile_id`` filled in.

        If ``profile_ids`` is non-empty, the result is further restricted.
        NOTE(review): despite its name, ``profile_ids`` is matched against
        ``person_id``, not ``provider_profile_id`` — confirm with callers.
        """
        query = (
            select([person_table.c.person_id])
            .where(person_table.c.provider_profile_id.isnot(None))
        )
        if profile_ids:
            query = query.where(person_table.c.person_id.in_(profile_ids))

        rows = await self._fetchall(query)
        return [row['person_id'] for row in rows]

    async def get_id_by_login(self, login: str) -> int:
        """Return the person_id for ``login``.

        :raises RecordNotFound: if no person has this login
        """
        data = await self.get_ids_by_logins([login])
        if login not in data:
            raise RecordNotFound('Not found')
        return data[login]

    async def get_ids_by_logins(self, logins: list[str]) -> dict[str, int]:
        """Return a ``{login: person_id}`` mapping for the logins that exist."""
        query = (
            select([
                person_table.c.person_id,
                person_table.c.login,
            ])
            .where(person_table.c.login.in_(logins))
        )
        rows = await self._fetchall(query)
        return {row['login']: row['person_id'] for row in rows}

    async def get_persons(self, person_ids: list[int]) -> list[Person]:
        """Return persons (plus company column) whose ids are in ``person_ids``."""
        query = (
            select([
                person_table,
                get_company_column(),
            ])
            .where(person_table.c.person_id.in_(person_ids))
        )
        rows = await self._fetchall(query)
        return [Person(**item) for item in rows]

    async def bulk_create(self, values: list[dict]) -> list[int]:
        """Insert all ``values`` in one statement and return the new person_ids.

        Returns an empty list without touching the DB when ``values`` is empty.
        """
        if not values:
            return []
        query = (
            person_table
            .insert()
            .values(values)
            # An INSERT without RETURNING produces no result rows, so the
            # generated ids could not be read back below.
            .returning(person_table.c.person_id)
        )
        rows = await self._fetchall(query)
        return [row['person_id'] for row in rows]

    async def raise_if_not_exists(self, id_: int):
        """Raise ``RecordNotFound`` if no row with primary key ``id_`` exists."""
        query = (
            select([self.table.c[self.pk]])
            .where(self.table.c[self.pk] == id_)
        )
        obj_id = await self.conn.scalar(query)
        # Compare against None explicitly so a legitimate id of 0 is not
        # mistaken for "missing" (same check as is_person_exists).
        if obj_id is None:
            raise RecordNotFound('Not found')

    async def get_all_ids_with_person_trip(self) -> set[int]:
        """Return the set of person_ids that appear in ``person_trip_table``."""
        query = (
            select([person_trip_table.c.person_id])
            .distinct()
        )
        data = await self._fetchall(query)
        return {item['person_id'] for item in data}

    async def get_for_ihub_sync(self, person_ids: list[int] = None) -> list[dict]:
        """Return raw person rows (with company, passports, bonus cards) for ihub sync.

        If ``person_ids`` is empty/None, all persons that have at least one
        person_trip are used instead. Persons with no gender are skipped.
        """
        involved_persons_subquery = (
            select([person_trip_table.c.person_id])
            .distinct()
        )
        query = (
            select([
                person_table,
                get_company_column(),
                get_person_passports_column(),
                get_person_bonus_cards_column(),
            ])
            .where(
                person_table.c.person_id.in_(
                    person_ids if person_ids else involved_persons_subquery,
                )
            )
            .where(person_table.c.gender.isnot(None))
        )
        data = await self._fetchall(query)
        return list(data)

    async def get_count_for_ihub_sync(self) -> int:
        """Return the number of distinct persons referenced by ``person_trip_table``.

        NOTE(review): unlike get_for_ihub_sync this does not exclude persons
        without gender — confirm whether the counts should match.
        """
        query = (
            select([func.count(distinct(person_trip_table.c.person_id))])
            .select_from(person_trip_table)
        )
        return await self.conn.scalar(query)

    async def get_for_staffapi_sync(self, person_id_from: int, person_id_to: int) -> list[Person]:
        """Return persons whose person_id lies in ``[person_id_from, person_id_to]`` (inclusive)."""
        query = (
            select([person_table])
            .where(person_table.c.person_id >= person_id_from)
            .where(person_table.c.person_id <= person_id_to)
        )
        data = await self._fetchall(query)
        return [Person(**item) for item in data]

    async def update(self, person_id: int, **fields) -> int:
        """Set ``fields`` on the person with ``person_id`` and return its id.

        :raises RecordNotFound: if no row was updated
        """
        query = (
            person_table
            .update()
            .where(person_table.c.person_id == person_id)
            .values(**fields)
            .returning(person_table.c.person_id)
        )
        person_id = await self.conn.scalar(query)
        if person_id is None:
            raise RecordNotFound('Person is not found')
        return person_id

    async def get_all_ids(self, with_dismissed=False) -> list[int]:
        """Return all person_ids; dismissed persons are excluded unless ``with_dismissed``."""
        query = select([person_table.c.person_id])
        if not with_dismissed:
            query = query.where(person_table.c.is_dismissed.is_(False))

        data = await self._fetchall(query)
        return [item['person_id'] for item in data]

    # for IDM
    async def add_role(self, login: str, role_key: str) -> int:
        """Set boolean column ``role_key`` to True for ``login``; return affected row count.

        NOTE(review): ``role_key`` is used directly as a column name, so callers
        must restrict it to the role columns.
        """
        fields = {role_key: True}
        query = (
            person_table
            .update()
            .where(person_table.c.login == login)
            .values(**fields)
        )
        result = await self.conn.execute(query)
        return result.rowcount

    async def remove_role(self, login: str, role_key: str) -> int:
        """Set boolean column ``role_key`` to False for ``login``; return affected row count.

        NOTE(review): ``role_key`` is used directly as a column name, so callers
        must restrict it to the role columns.
        """
        fields = {role_key: False}
        query = (
            person_table
            .update()
            .where(person_table.c.login == login)
            .values(**fields)
        )
        result = await self.conn.execute(query)
        return result.rowcount

    async def get_roles(self) -> dict[str, dict[str, bool]]:
        """Return ``{login: {'coordinator': bool, 'limited_access': bool}}``.

        Only non-dismissed persons holding at least one of the two roles are
        included; rows are ordered by login.
        """
        query = (
            select([
                person_table.c.login,
                person_table.c.is_coordinator,
                person_table.c.is_limited_access,
            ])
            .where(person_table.c.is_dismissed.is_(False))
            .where(or_(
                person_table.c.is_coordinator.is_(True),
                person_table.c.is_limited_access.is_(True),
            ))
            .order_by(person_table.c.login)
        )
        rows = await self._fetchall(query)
        return {
            row['login']: {
                'coordinator': row['is_coordinator'],
                'limited_access': row['is_limited_access'],
            } for row in rows
        }


class PersonRelationshipGateway(DBGateway):
    """Data access for ``person_relationship_table`` (owner -> dependant links)."""

    model_class = PersonRelationship
    table = person_relationship_table

    async def drop_all_chiefs(self):
        """Delete every relationship row whose role is ``PersonRole.chief``."""
        rel = person_relationship_table.c
        stmt = person_relationship_table.delete().where(rel.role == PersonRole.chief)
        await self.conn.execute(stmt)

    async def bulk_create_chiefs(self, chiefs_values: list[tuple[int, int, bool]]):
        """Insert chief relationships in batches of 20000 rows.

        Each tuple is ``(chief_id, dependant_id, is_direct)``; the chief is
        stored as ``owner_id`` and every row gets ``role=PersonRole.chief``.
        Does nothing when ``chiefs_values`` is empty.
        """
        if chiefs_values:
            for batch in paginate(chiefs_values, by=20000):
                rows_to_insert = [
                    dict(
                        owner_id=owner,
                        dependant_id=dependant,
                        is_direct=direct,
                        role=PersonRole.chief,
                    )
                    for owner, dependant, direct in batch
                ]
                insert_stmt = person_relationship_table.insert().values(rows_to_insert)
                await self.conn.execute(insert_stmt)

    async def get_relationships(
        self,
        roles: list[PersonRole],
        person_ids: list[int],
    ) -> list[PersonRelationship]:
        """Return relationships with a role in ``roles`` whose dependant is in ``person_ids``.

        Results are ordered by ``owner_id``.
        """
        rel = person_relationship_table.c
        role_filter = rel.role.in_(roles)
        dependant_filter = rel.dependant_id.in_(person_ids)
        stmt = (
            select([person_relationship_table])
            .where(role_filter)
            .where(dependant_filter)
            .order_by(rel.owner_id)
        )
        fetched = await self._fetchall(stmt)
        return [PersonRelationship(**record) for record in fetched]


class ExtPersonGateway(DBGateway):
    """Data access for ``ext_person_table`` and the documents/bonus cards linked to ext persons."""

    table = ext_person_table
    model_class = ExtPerson
    pk = 'ext_person_id'

    async def get_ext_persons(self, person_id: int) -> list[ExtPerson]:
        """
        Return the ext persons belonging to ``person_id`` together with
        their documents and bonus cards.
        """
        query = (
            select([
                ext_person_table,
                get_ext_person_documents_column(),
                get_ext_person_bonus_cards_column(),
            ])
            .where(ext_person_table.c.person_id == person_id)
        )
        return [ExtPerson(**item) for item in await self._fetchall(query)]

    async def get_for_ihub_sync(self, ext_person_ids: list[int]) -> list[dict]:
        """Return raw ext person rows (with passports and bonus cards) for ihub sync.

        Only ext persons whose id is in ``ext_person_ids`` and whose gender is
        set are returned.
        """
        query = (
            select([
                ext_person_table,
                get_ext_person_passports_column(),
                get_ext_person_bonus_cards_column(),
            ])
            # Filter on ext_person_table's own columns: referencing person_table
            # here would add it to FROM as an unconstrained cross join and
            # compare the ext ids against person ids.
            # NOTE(review): assumes ext_person_table has a ``gender`` column,
            # mirroring person_table — confirm against the table definition.
            .where(ext_person_table.c.ext_person_id.in_(ext_person_ids))
            .where(ext_person_table.c.gender.isnot(None))
        )
        data = await self._fetchall(query)
        return list(data)

    async def get_by_id(self, ext_person_id: int) -> ExtPerson:
        """Return one ext person (with documents and bonus cards) by its id."""
        return await self._get_one(
            select_list=[
                ext_person_table,
                get_ext_person_documents_column(),
                get_ext_person_bonus_cards_column(),
            ],
            where_clause=ext_person_table.c.ext_person_id == ext_person_id,
        )

    async def update(self, ext_person_id: int, **fields) -> int:
        """Set ``fields`` on the ext person with ``ext_person_id`` and return its id.

        :raises RecordNotFound: if no row was updated
        """
        query = (
            ext_person_table
            .update()
            .where(ext_person_table.c.ext_person_id == ext_person_id)
            .values(**fields)
            .returning(ext_person_table.c.ext_person_id)
        )
        ext_person_id = await self.conn.scalar(query)
        # Explicit None check: RETURNING yields no row only when nothing matched.
        if ext_person_id is None:
            raise RecordNotFound('ExtPerson is not found')
        return ext_person_id

    async def delete_documents(self, ext_person_id: int) -> None:
        """Delete all person documents attached to ``ext_person_id``."""
        query = (
            person_document_table
            .delete()
            .where(person_document_table.c.ext_person_id == ext_person_id)
        )
        await self.conn.execute(query)

    async def delete_bonus_cards(self, ext_person_id: int) -> None:
        """Delete all bonus cards attached to ``ext_person_id``."""
        query = (
            person_bonus_card_table
            .delete()
            .where(person_bonus_card_table.c.ext_person_id == ext_person_id)
        )
        await self.conn.execute(query)
