import logging
from datetime import datetime, timedelta
from typing import Optional

from sqlalchemy import select, func, or_, and_, exists
from sqlalchemy.dialects.postgresql import insert

from intranet.trip.src.api.schemas import TripFilter
from intranet.trip.src.config import settings
from intranet.trip.src.db.tables import (
    trip_table,
    trip_purpose_table,
    company_table,
    conf_details_table,
    person_relationship_table,
    person_table,
    person_trip_table,
    person_conf_details_table,
    person_document_table,
    person_trip_purpose_table,
    person_trip_document_table,
    travel_details_table,
    trip_route_point_table,
)
from intranet.trip.src.models import Trip
from intranet.trip.src.models.person import User
from intranet.trip.src.db.gateways.base import DBGateway, RecordNotFound
from intranet.trip.src.db.subqueries.trip import (
    get_conf_details_column,
    get_person_trips_column,
    get_person_trips_with_details_column,
    get_person_with_company_column,
    get_trip_purposes_column,
    get_trip_route_points_column,
)
from intranet.trip.src.db.subqueries.common import get_person_column

from intranet.trip.src.enums import TripStatus, PTStatus, PersonRole, RelationType


# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)


class TripGateway(DBGateway):
    """Database gateway for trips (``trip_table``) and the per-trip tables around it."""

    table = trip_table
    pk = 'trip_id'
    model_class = Trip

    async def get_list_and_count(self, fltr: TripFilter, user: User) -> tuple[list[Trip], int]:
        """
        Return one page of trips that match ``fltr`` and are visible to ``user``, plus the
        total number of matching trips.

        Visibility: a non-coordinator (or a coordinator asking for subordinates' trips) only
        gets trips they authored, trips they take part in, and trips of people for whom they
        are chief (only direct subordinates for ``RelationType.direct_subordinates``).
        For a limited-access non-coordinator, ``user.person_id`` is also passed to the nested
        column builders (presumably to narrow nested person data - verify in db.subqueries).

        :param fltr: filter values plus ``limit`` / ``offset`` paging (falsy means "not set")
        :param user: the requesting user
        :return: ``(trips ordered by trip_id descending, count ignoring limit/offset)``
        """
        person_id = user.person_id if not user.is_coordinator and user.is_limited_access else None
        query = (
            select([
                trip_table,
                get_trip_purposes_column(),
                get_person_trips_with_details_column(person_id=person_id),
                get_person_column(correlate_table=trip_table, label='author', person_id=person_id),
                get_conf_details_column(),
                get_trip_route_points_column(),
            ])
            .select_from(trip_table.outerjoin(conf_details_table))
            .order_by(trip_table.c.trip_id.desc())
        )

        if not user.is_coordinator or fltr.relation_type in (
            RelationType.all_subordinates,
            RelationType.direct_subordinates,
        ):
            subordinate_ids = (
                select([person_relationship_table.c.dependant_id])
                .where(person_relationship_table.c.owner_id == user.person_id)
                .where(person_relationship_table.c.role == PersonRole.chief)
            )
            if fltr.relation_type == RelationType.direct_subordinates:
                subordinate_ids = subordinate_ids.where(
                    person_relationship_table.c.is_direct.is_(True)
                )
            own_or_subordinate_trip_ids = select([person_trip_table.c.trip_id]).where(
                or_(
                    person_trip_table.c.person_id == user.person_id,
                    person_trip_table.c.person_id.in_(subordinate_ids),
                )
            )
            query = query.where(
                or_(
                    trip_table.c.trip_id.in_(own_or_subordinate_trip_ids),
                    trip_table.c.author_id == user.person_id,
                )
            )

        if fltr.holding_id is not None:
            # Trips with at least one participant whose company belongs to the holding.
            holding_trip_ids = (
                select([person_trip_table.c.trip_id])
                .select_from(
                    person_trip_table
                    .join(
                        person_table,
                        person_trip_table.c.person_id == person_table.c.person_id,
                    )
                    .join(
                        company_table,
                        person_table.c.company_id == company_table.c.company_id,
                    )
                )
                .where(company_table.c.holding_id == fltr.holding_id)
            )
            query = query.where(trip_table.c.trip_id.in_(holding_trip_ids))

        if (
            fltr.person_id is not None
            or fltr.manager_id is not None
            or fltr.relation_type == RelationType.personal
        ):
            # All given person-trip conditions must hold for the same person_trip row.
            person_trip_ids = select([person_trip_table.c.trip_id])
            if fltr.person_id is not None:
                person_trip_ids = person_trip_ids.where(
                    person_trip_table.c.person_id == fltr.person_id
                )
            if fltr.manager_id is not None:
                person_trip_ids = person_trip_ids.where(
                    person_trip_table.c.manager_id == fltr.manager_id
                )
            if fltr.relation_type == RelationType.personal:
                person_trip_ids = person_trip_ids.where(
                    person_trip_table.c.person_id == user.person_id
                )
            query = query.where(trip_table.c.trip_id.in_(person_trip_ids))

        if fltr.date_from__gte is not None:
            query = query.where(func.date(trip_table.c.date_from) >= fltr.date_from__gte)

        if fltr.date_from__lte is not None:
            query = query.where(func.date(trip_table.c.date_from) <= fltr.date_from__lte)

        if fltr.provider_city_to_id is not None:
            # Route-point conditions stay inside the correlated EXISTS: referencing
            # trip_route_point_table in the outer WHERE would add it to the outer FROM
            # (a cross join with every trip) and break the subquery's correlation.
            query = query.where(self._destination_city_exists(fltr.provider_city_to_id))

        if fltr.tracker_issue is not None:
            trip_ids_for_travel_details = (
                select([travel_details_table.c.trip_id])
                .where(travel_details_table.c.tracker_issue == fltr.tracker_issue)
            )
            trip_ids_for_person_conf_details = (
                select([person_conf_details_table.c.trip_id])
                .where(person_conf_details_table.c.tracker_issue == fltr.tracker_issue)
            )
            query = query.where(
                or_(
                    trip_table.c.issue_travel == fltr.tracker_issue,
                    conf_details_table.c.tracker_issue == fltr.tracker_issue,
                    trip_table.c.trip_id.in_(trip_ids_for_travel_details),
                    trip_table.c.trip_id.in_(trip_ids_for_person_conf_details),
                )
            )

        # Count before paging; ORDER BY does not affect the count, so drop it from the subquery.
        count_query = select([func.count()]).select_from(query.order_by(None).subquery())
        count = await self.conn.scalar(count_query)

        if fltr.limit:
            query = query.limit(fltr.limit)
        if fltr.offset:
            query = query.offset(fltr.offset)
        rows = await self._fetchall(query)
        return [Trip(**row) for row in rows], count

    @staticmethod
    def _destination_city_exists(provider_city_id):
        """
        Build an EXISTS clause matching trips with a non-start route point in the given city.

        The city is compared with ``provider_city_id`` and, when the value converts to an int,
        also with ``aeroclub_city_id``. The subquery correlates to ``trip_table`` of the
        enclosing query.
        """
        city_conditions = [trip_route_point_table.c.provider_city_id == provider_city_id]
        try:
            aeroclub_city_id = int(provider_city_id)
        except (TypeError, ValueError):
            pass  # not numeric: only the provider city id can match
        else:
            city_conditions.append(trip_route_point_table.c.aeroclub_city_id == aeroclub_city_id)
        return exists(
            select([trip_route_point_table.c.point_id])
            .where(
                and_(
                    or_(*city_conditions),
                    trip_route_point_table.c.trip_id == trip_table.c.trip_id,
                    trip_route_point_table.c.is_start_city.is_(False),
                )
            )
        )

    async def _get_one_by_id(self, trip_id: int, select_list: list) -> dict:
        """
        Select ``select_list`` for the trip with ``trip_id`` and return the first row.

        :raises RecordNotFound: if no row is returned
        """
        query = (
            select(select_list)
            .where(trip_table.c.trip_id == trip_id)
        )
        trip_record = await self._first(query)
        if not trip_record:
            msg = 'Trip with id: {} does not exists'
            raise RecordNotFound(msg.format(trip_id))
        return trip_record

    async def get_trip(self, trip_id: int) -> Trip:
        """Return the trip row together with its route points (no nested person data)."""
        return await self._get_one(
            select_list=[
                trip_table,
                get_trip_route_points_column(),
            ],
            where_clause=trip_table.c.trip_id == trip_id,
        )

    async def get_detailed_trip(
        self,
        trip_id: int,
        person_id: Optional[int] = None,
        manager_id: Optional[int] = None,
    ) -> Trip:
        """
        Get a Trip with all nested entities (person-trips, persons, companies etc.), with the
        nested person-trips filtered by ``person_id`` and/or ``manager_id``.

        If both are given, the filter is
            person_trip.person_id == person_id OR person_trip.manager_id == manager_id

        :param trip_id: id of the trip to load (required)
        :param person_id: if not None, filter person-trips by person_trip.person_id
        :param manager_id: if not None, filter person-trips by person_trip.manager_id
        :return: the trip model built by ``_get_one``
        """
        return await self._get_one(
            select_list=[
                trip_table,
                get_trip_purposes_column(),
                get_person_trips_with_details_column(person_id=person_id, manager_id=manager_id),
                get_person_with_company_column(correlate_table=trip_table, label='author'),
                get_conf_details_column(),
                get_trip_route_points_column(),
            ],
            where_clause=trip_table.c.trip_id == trip_id,
        )

    async def get_holding_ids(self, trip_id: int) -> list[int]:
        """
        Return the ``holding_id`` of the company of every participant (person_trip row) of
        the trip.

        There is one entry per participant, so values may repeat. The list is empty when the
        trip has no participants. Entries may be None if a company has no holding (depends on
        the schema - verify).
        """
        query = (
            select([company_table.c.holding_id])
            .select_from(
                trip_table
                .join(
                    person_trip_table,
                    trip_table.c.trip_id == person_trip_table.c.trip_id,
                )
                .join(
                    person_table,
                    person_trip_table.c.person_id == person_table.c.person_id,
                )
                .join(
                    company_table,
                    person_table.c.company_id == company_table.c.company_id,
                )
            )
            .where(trip_table.c.trip_id == trip_id)
        )
        rows = await self._fetchall(query)
        return [row['holding_id'] for row in rows]

    async def get_author_id(self, trip_id: int) -> Optional[int]:
        """Return the trip's ``author_id``, or None if the trip does not exist."""
        query = (
            select([trip_table.c.author_id])
            .where(trip_table.c.trip_id == trip_id)
        )
        return await self.conn.scalar(query)

    @staticmethod
    def _awaiting_staff_push():
        """WHERE condition shared by the staff-push queries: no staff_trip_uuid and status new."""
        return and_(
            trip_table.c.staff_trip_uuid.is_(None),
            trip_table.c.status == TripStatus.new,
        )

    async def get_trip_ids_for_staff_push(self) -> list[int]:
        """Return ids of trips that have no ``staff_trip_uuid`` yet and are in status ``new``."""
        query = select([trip_table.c.trip_id]).where(self._awaiting_staff_push())
        rows = await self._fetchall(query)
        return [row['trip_id'] for row in rows]

    async def get_trip_for_staff_push(self, trip_id: int) -> Optional[Trip]:
        """
        Return the trip that has to be created/edited on Staff, or None if it does not exist.

        The trip row is locked with ``SELECT ... FOR UPDATE OF trip`` for the rest of the
        current transaction.
        """
        query = (
            select([
                trip_table,
                get_trip_purposes_column(),
                get_person_trips_column(),
                get_person_column(correlate_table=trip_table, label='author'),
                get_conf_details_column(),
                get_trip_route_points_column(),
            ])
            .where(trip_table.c.trip_id == trip_id)
        )

        row = await self._first(query.with_for_update(of=trip_table))
        if row is None:
            return None
        return Trip(**row)

    async def get_trips_for_staff_push_count(self) -> int:
        """Return how many trips match ``get_trip_ids_for_staff_push``."""
        query = select([func.count(trip_table.c.trip_id)]).where(self._awaiting_staff_push())
        return await self.conn.scalar(query)

    async def get_id_by_trip_uuid(self, staff_trip_uuid: str) -> Optional[int]:
        """Return the id of the trip with this ``staff_trip_uuid``, or None if there is none."""
        query = (
            select([trip_table.c.trip_id])
            .where(trip_table.c.staff_trip_uuid == staff_trip_uuid)
        )
        return await self.conn.scalar(query)

    async def update(self, trip_id: int, **fields) -> int:
        """
        Update the given columns of a trip and return its id.

        :raises RecordNotFound: if no trip with ``trip_id`` exists
        """
        query = (
            trip_table
            .update()
            .where(trip_table.c.trip_id == trip_id)
            .values(**fields)
            .returning(trip_table.c.trip_id)
        )
        updated_id = await self.conn.scalar(query)
        if updated_id is None:
            raise RecordNotFound('Trip is not found')
        return updated_id

    async def add_purposes(self, trip_id: int, purpose_ids: list[int]) -> None:
        """Link purposes to a trip with one multi-row INSERT; does nothing for an empty list."""
        if not purpose_ids:
            # There is no zero-row INSERT: Insert.values() with an empty list is not a no-op.
            return
        values = [{'purpose_id': purpose_id, 'trip_id': trip_id} for purpose_id in purpose_ids]
        await self.conn.execute(trip_purpose_table.insert().values(values))

    async def clean_purposes(self, trip_id: int):
        """Remove all purpose links of a trip."""
        query = (
            trip_purpose_table.delete()
            .where(trip_purpose_table.c.trip_id == trip_id)
        )
        await self.conn.execute(query)

    async def create_conf_details(self, trip_id: int, **fields) -> None:
        """Insert the conference details row of a trip."""
        fields['trip_id'] = trip_id
        await self.conn.execute(conf_details_table.insert().values(**fields))

    async def update_conf_details(self, trip_id: int, **fields) -> None:
        """Update the given columns of the trip's conference details (no error if absent)."""
        query = (
            conf_details_table.update()
            .where(conf_details_table.c.trip_id == trip_id)
            .values(**fields)
        )
        await self.conn.execute(query)

    async def create_route_points(self, route_points: list[dict]) -> None:
        """Insert route points (dicts of column values); does nothing for an empty list."""
        if not route_points:
            # There is no zero-row INSERT: Insert.values() with an empty list is not a no-op.
            return
        await self.conn.execute(trip_route_point_table.insert().values(route_points))

    async def delete(self, trip_id: int):
        """
        Delete a trip together with its rows in the dependent tables.

        Tables are cleared child-first (person-trip level rows, person-trips, trip-level
        rows, then the trip itself) so that foreign keys, if present, are never violated.
        """
        related_tables = (
            person_conf_details_table,
            travel_details_table,
            person_trip_purpose_table,
            person_trip_document_table,
            person_trip_table,
            conf_details_table,
            trip_purpose_table,
            # Route points also carry trip_id and must go before the trip row itself.
            trip_route_point_table,
            trip_table,
        )
        for table in related_tables:
            query = table.delete().where(table.c.trip_id == trip_id)
            await self.conn.execute(query)

    async def add_document(self, trip_id: int, document_id: int, person_id: int) -> None:
        """
        Attach a person's document to a trip; attaching the same document again is a no-op.

        :raises RecordNotFound: if the trip does not exist, or the document does not belong
            to ``person_id``
        """
        found_trip_id = await self.conn.scalar(
            select([trip_table.c.trip_id])
            .where(trip_table.c.trip_id == trip_id)
        )
        if found_trip_id is None:
            raise RecordNotFound(f'Trip with id: {trip_id} does not exist')

        owner_id = await self.conn.scalar(
            select([person_document_table.c.person_id])
            .where(person_document_table.c.document_id == document_id)
            .where(person_document_table.c.person_id == person_id)
        )
        if owner_id is None:
            raise RecordNotFound(
                f'Document with id: {document_id} of person with id: {person_id} does not exist'
            )

        query = (
            insert(person_trip_document_table)
            .values(
                trip_id=trip_id,
                document_id=document_id,
                person_id=person_id,
            )
            # PostgreSQL ON CONFLICT DO NOTHING makes re-attaching idempotent.
            .on_conflict_do_nothing(constraint=person_trip_document_table.primary_key)
        )
        await self.conn.execute(query)

    async def remove_document(self, trip_id: int, document_id: int, person_id: int) -> None:
        """Detach a person's document from a trip (no error if it was not attached)."""
        query = (
            person_trip_document_table
            .delete()
            .where(person_trip_document_table.c.trip_id == trip_id)
            .where(person_trip_document_table.c.document_id == document_id)
            .where(person_trip_document_table.c.person_id == person_id)
        )
        await self.conn.execute(query)

    async def close_completed_trips(self):
        """
        Close group trips in status ``new`` that are over.

        A trip is considered over when either
          * none of its person-trips is in an open status (draft/new/executing/executed)
            and at least one of them is closed; or
          * it has no person-trips other than cancelled ones (or none at all) and its
            ``date_to`` is at least ``settings.DAYS_BEFORE_CLOSING_TRIPS`` days in the past.

        :return: the rows produced by ``RETURNING trip_id`` for the closed trips
        """
        open_statuses = [PTStatus.draft, PTStatus.new, PTStatus.executing, PTStatus.executed]
        # NOTE(review): naive local time; confirm it matches how date_to is stored.
        threshold_date = datetime.now() - timedelta(days=settings.DAYS_BEFORE_CLOSING_TRIPS)
        query = (
            trip_table
            .update()
            .where(trip_table.c.status == TripStatus.new)
            .where(
                or_(
                    # Trips with no open person-trips but at least one closed one.
                    and_(
                        trip_table.c.trip_id.notin_(
                            select([person_trip_table.c.trip_id])
                            .where(person_trip_table.c.status.in_(open_statuses))
                        ),
                        trip_table.c.trip_id.in_(
                            select([person_trip_table.c.trip_id])
                            .where(person_trip_table.c.status == PTStatus.closed)
                        ),
                    ),
                    # Trips with only cancelled person-trips (or none at all)
                    # whose end date passed the threshold.
                    and_(
                        trip_table.c.trip_id.notin_(
                            select([person_trip_table.c.trip_id])
                            .where(person_trip_table.c.status != PTStatus.cancelled)
                        ),
                        trip_table.c.date_to <= threshold_date,
                    ),
                )
            )
            .values(status=TripStatus.closed)
            .returning(trip_table.c.trip_id)
        )
        return await self._fetchall(query)
