import logging
from collections import defaultdict
from typing import Optional
from datetime import datetime, timedelta
from sqlalchemy import select, and_, or_, func, extract
from sqlalchemy.dialects.postgresql import insert

from intranet.trip.src.config import settings
from intranet.trip.src.db.tables import (
    trip_table,
    conf_details_table,
    person_table,
    person_trip_table,
    person_conf_details_table,
    person_trip_purpose_table,
    person_trip_document_table,
    travel_details_table,
    person_trip_route_point_table,
)
from intranet.trip.src.enums import PTStatus, Provider
from intranet.trip.src.models import PersonTrip
from intranet.trip.src.db.gateways.base import DBGateway, RecordNotFound
from intranet.trip.src.db.subqueries.trip import (
    get_person_trip_purposes_column,
    get_person_conf_details_column,
    get_travel_details_column,
    get_documents_column,
    get_services_column,
    get_person_trip_route_points_column,
)
from intranet.trip.src.db.subqueries.common import get_person_column, get_person_with_company_column


# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class PersonTripGateway(DBGateway):
    """Data-access gateway for person trips (``person_trip_table``).

    A person trip is addressed by the pair ``(trip_id, person_id)``.
    Besides the person trip rows themselves, this gateway reads and writes
    the per-person satellite tables: conference details, travel details,
    purposes, documents and route points.
    """

    table = person_trip_table
    model_class = PersonTrip

    async def get_list(self, trip_id: int) -> list[PersonTrip]:
        """Return every person trip of ``trip_id``.

        Each row is extended with the column built by ``get_person_column``.
        """
        query = (
            select([
                person_trip_table,
                get_person_column(correlate_table=person_trip_table),
            ])
            .where(person_trip_table.c.trip_id == trip_id)
        )
        rows = await self._fetchall(query)
        return [PersonTrip(**row) for row in rows]

    async def get_person_trips_for_staff_pull(self) -> list[dict]:
        """Return person trips whose data should be pulled from Staff.

        A person trip is selected when all of these hold:

        * its trip has a ``staff_trip_uuid``;
        * its status is draft, new or executing;
        * it is not approved yet, or it has neither a travel nor a
          conference tracker issue.

        Each dict holds the selected person trip fields, the person's login,
        the trip's Staff uuid and the related tracker issue keys. Rows are
        ordered by ``trip_id``.
        """
        sync_statuses = [PTStatus.draft, PTStatus.new, PTStatus.executing]
        query = (
            select([
                person_trip_table.c.trip_id,
                person_trip_table.c.person_id,
                person_trip_table.c.is_approved,
                person_trip_table.c.is_offline,
                person_trip_table.c.gap_date_from,
                person_trip_table.c.gap_date_to,
                person_trip_table.c.provider,
                person_table.c.login,
                trip_table.c.staff_trip_uuid,
                trip_table.c.issue_travel.label('group_travel_issue'),
                conf_details_table.c.tracker_issue.label('group_conf_issue'),
                travel_details_table.c.tracker_issue.label('travel_issue'),
                person_conf_details_table.c.tracker_issue.label('conf_issue'),
            ])
            .select_from(
                person_trip_table
                .outerjoin(
                    person_table,
                    onclause=person_trip_table.c.person_id == person_table.c.person_id,
                )
                .outerjoin(
                    trip_table,
                    onclause=person_trip_table.c.trip_id == trip_table.c.trip_id,
                )
                .outerjoin(
                    conf_details_table,
                    onclause=conf_details_table.c.trip_id == trip_table.c.trip_id,
                )
                .outerjoin(
                    travel_details_table,
                    onclause=and_(
                        person_trip_table.c.trip_id == travel_details_table.c.trip_id,
                        person_trip_table.c.person_id == travel_details_table.c.person_id,
                    )
                )
                .outerjoin(
                    person_conf_details_table,
                    onclause=and_(
                        person_trip_table.c.trip_id == person_conf_details_table.c.trip_id,
                        person_trip_table.c.person_id == person_conf_details_table.c.person_id,
                    )
                )
            )
            .where(trip_table.c.staff_trip_uuid.isnot(None))
            .where(person_trip_table.c.status.in_(sync_statuses))
            .where(
                or_(
                    person_trip_table.c.is_approved.is_(False),
                    and_(
                        travel_details_table.c.tracker_issue.is_(None),
                        person_conf_details_table.c.tracker_issue.is_(None),
                    ),
                )
            )
            .order_by(trip_table.c.trip_id)
        )
        rows = await self._fetchall(query)
        return [dict(row) for row in rows]

    async def get_person_trip_ids(
            self,
            status: Optional[PTStatus] = None,
            provider: Optional[Provider] = None,
    ) -> list[tuple[int, int]]:
        """Return ``(trip_id, person_id)`` pairs, optionally filtered.

        :param status: keep only person trips in this status.
        :param provider: keep only person trips with this provider.
        :return: pairs sorted by ``trip_id`` then ``person_id``.
        """
        query = (
            select([
                person_trip_table.c.trip_id,
                person_trip_table.c.person_id,
            ])
            .order_by('trip_id', 'person_id')
        )
        if status is not None:
            query = query.where(person_trip_table.c.status == status)
        if provider is not None:
            query = query.where(person_trip_table.c.provider == provider)
        rows = await self._fetchall(query)
        return [(row['trip_id'], row['person_id']) for row in rows]

    def _get_base_query_for_aeroclub_create(self, select_list: list):
        """Build a SELECT of ``select_list`` over person trips lacking an Aeroclub trip.

        Matches online Aeroclub person trips that meet all of these:

        * ``aeroclub_trip_id`` is not set yet;
        * the trip has a ``provider_city_to_id``;
        * the status is neither cancelled nor closed.
        """
        return (
            select(select_list)
            .select_from(
                person_trip_table
                .outerjoin(trip_table)
            )
            .where(person_trip_table.c.provider == Provider.aeroclub)
            .where(person_trip_table.c.aeroclub_trip_id.is_(None))
            .where(person_trip_table.c.is_offline.is_(False))
            # Filtering on a trip_table column makes the outer join behave
            # like an inner join: rows without a matching trip are dropped.
            .where(trip_table.c.provider_city_to_id.isnot(None))
            .where(person_trip_table.c.status.not_in([PTStatus.cancelled, PTStatus.closed]))
        )

    async def get_person_trip_ids_for_aeroclub_create(
            self,
            trip_id: Optional[int] = None,
            person_ids: Optional[list[int]] = None,
    ) -> list[tuple[int, int]]:
        """Return ``(trip_id, person_id)`` pairs of person trips awaiting Aeroclub creation.

        :param trip_id: restrict to this trip.
        :param person_ids: restrict to these persons.
        """
        select_list = [
            person_trip_table.c.trip_id,
            person_trip_table.c.person_id,
        ]
        query = self._get_base_query_for_aeroclub_create(select_list)

        if trip_id is not None:
            query = query.where(person_trip_table.c.trip_id == trip_id)
        if person_ids is not None:
            query = query.where(person_trip_table.c.person_id.in_(person_ids))
        rows = await self._fetchall(query)
        return [(row['trip_id'], row['person_id']) for row in rows]

    async def get_person_trip_for_aeroclub_create(
            self,
            trip_id: int,
            person_id: int,
    ) -> Optional[PersonTrip]:
        """Fetch and row-lock one person trip that still needs an Aeroclub trip.

        The query uses ``SELECT ... FOR UPDATE OF person_trip``, so the row
        stays locked until the enclosing transaction ends.

        :return: the person trip, or ``None`` if it does not match the
            "awaiting Aeroclub creation" conditions.
        """
        select_list = [
            person_trip_table,
            get_person_with_company_column(correlate_table=person_trip_table),
            trip_table.c.provider_city_to_id,
        ]
        query = (
            self._get_base_query_for_aeroclub_create(select_list)
            .where(person_trip_table.c.trip_id == trip_id)
            .where(person_trip_table.c.person_id == person_id)
            .with_for_update(of=person_trip_table)
        )
        row = await self._first(query)
        if row is None:
            return None
        return PersonTrip(**row)

    async def get_person_trips_for_aeroclub_create_count(self) -> int:
        """Count person trips awaiting Aeroclub trip creation."""
        select_list = [func.count(person_trip_table.c.trip_id)]
        query = self._get_base_query_for_aeroclub_create(select_list)
        return await self.conn.scalar(query)

    def _filter_not_authorized_person_trips(self, query):
        """Restrict ``query`` to person trips that still need authorization.

        These are Aeroclub person trips that meet all of these:

        * an Aeroclub trip exists;
        * ``is_authorized`` is false;
        * the status is neither cancelled nor closed.
        """
        return (
            query
            .where(person_trip_table.c.provider == Provider.aeroclub)
            .where(person_trip_table.c.aeroclub_trip_id.isnot(None))
            .where(person_trip_table.c.is_authorized.is_(False))
            .where(person_trip_table.c.status.not_in([PTStatus.cancelled, PTStatus.closed]))
        )

    async def get_not_authorized_person_trips(
            self,
            trip_id: Optional[int] = None,
            person_id: Optional[int] = None,
    ) -> list[PersonTrip]:
        """Return not-yet-authorized Aeroclub person trips with their details.

        :param trip_id: restrict to this trip.
        :param person_id: restrict to this person.
        """
        query = (
            select([
                person_trip_table,
                get_person_column(correlate_table=person_trip_table),
                get_person_trip_purposes_column(),
                get_travel_details_column(),
                get_person_conf_details_column(),
            ])
            .select_from(
                person_trip_table
                .outerjoin(travel_details_table)
                .outerjoin(person_conf_details_table)
            )
        )
        query = self._filter_not_authorized_person_trips(query)
        if trip_id is not None:
            query = query.where(person_trip_table.c.trip_id == trip_id)
        if person_id is not None:
            query = query.where(person_trip_table.c.person_id == person_id)
        rows = await self._fetchall(query)
        return [PersonTrip(**row) for row in rows]

    async def get_not_authorized_person_trips_count(self) -> int:
        """Count not-yet-authorized Aeroclub person trips."""
        query = select([func.count(person_trip_table.c.trip_id)])
        query = self._filter_not_authorized_person_trips(query)
        return await self.conn.scalar(query)

    async def get_person_trips_without_chat_id(self) -> list[PersonTrip]:
        """Return person trips for which a chat still has to be created (no ``chat_id``)."""
        query = (
            select([
                person_trip_table,
                get_person_column(correlate_table=person_trip_table),
            ])
            .where(person_trip_table.c.chat_id.is_(None))
        )
        rows = await self._fetchall(query)

        return [PersonTrip(**row) for row in rows]

    async def get_person_trips_without_chat_id_count(self) -> int:
        """Count person trips that have no ``chat_id`` yet."""
        query = (
            select([func.count(person_trip_table.c.trip_id)])
            .where(person_trip_table.c.chat_id.is_(None))
        )
        return await self.conn.scalar(query)

    def _get_where_by_pk(self, trip_id: int, person_id: int):
        """Return the WHERE clause matching one person trip by ``(trip_id, person_id)``."""
        return and_(
            person_trip_table.c.trip_id == trip_id,
            person_trip_table.c.person_id == person_id,
        )

    async def get_person_trip(self, trip_id: int, person_id: int) -> PersonTrip:
        """Return one person trip with its person column.

        Lookup goes through ``DBGateway._get_one``; missing-row behaviour is
        defined there (presumably raises ``RecordNotFound`` — confirm).
        """
        return await self._get_one(
            select_list=[
                person_trip_table,
                get_person_column(correlate_table=person_trip_table),
            ],
            where_clause=self._get_where_by_pk(trip_id, person_id),
        )

    async def get_person_trip_by_aeroclub_id(
            self,
            aeroclub_journey_id: int,
            aeroclub_trip_id: int,
    ) -> PersonTrip:
        """Return the person trip linked to the given Aeroclub journey/trip ids.

        Includes the person, the travel details and the trip author id as
        ``trip_author_id``.
        """
        return await self._get_one(
            select_list=[
                person_trip_table,
                get_person_column(correlate_table=person_trip_table),
                get_travel_details_column(),
                trip_table.c.author_id.label('trip_author_id'),
            ],
            where_clause=and_(
                person_trip_table.c.aeroclub_journey_id == aeroclub_journey_id,
                person_trip_table.c.aeroclub_trip_id == aeroclub_trip_id,
            ),
            select_from=(
                person_trip_table
                .join(trip_table)
            ),
        )

    async def get_person_trip_by_chat_id(self, chat_id: str) -> PersonTrip:
        """Return the person trip bound to ``chat_id``."""
        return await self._get_one(
            select_list=[
                person_trip_table,
                get_person_column(correlate_table=person_trip_table),
            ],
            where_clause=person_trip_table.c.chat_id == chat_id,
        )

    async def get_detailed_person_trip(self, trip_id: int, person_id: int) -> PersonTrip:
        """Return one person trip together with all of its related data.

        The related data covers:

        * the person with company, and the manager (label ``manager``);
        * purposes, conference details and travel details;
        * documents, services and route points;
        * from the trip: author id (as ``trip_author_id``), ``city_to``,
          ``country_to`` and ``provider_city_to_id``.
        """
        return await self._get_one(
            select_list=[
                person_trip_table,
                get_person_with_company_column(correlate_table=person_trip_table),
                get_person_trip_purposes_column(),
                get_person_conf_details_column(),
                get_travel_details_column(),
                get_documents_column(),
                get_services_column(),
                get_person_trip_route_points_column(),
                get_person_column(correlate_table=person_trip_table, label='manager'),
                trip_table.c.author_id.label('trip_author_id'),
                trip_table.c.city_to,
                trip_table.c.country_to,
                trip_table.c.provider_city_to_id,
            ],
            where_clause=self._get_where_by_pk(trip_id, person_id),
            select_from=(
                person_trip_table
                .join(trip_table)
            ),
        )

    async def get_person_trips_with_purpose_count(
            self,
            person_id: int,
            purpose_id: int,
            month: int,
            year: Optional[int] = None,
    ) -> int:
        """Count the person's non-cancelled trips with ``purpose_id`` starting in ``month``.

        :param person_id: the traveller.
        :param purpose_id: the purpose to look for.
        :param month: month number matched against ``gap_date_from``.
        :param year: optional year matched against ``gap_date_from``. When
            omitted (the original behaviour), the month is matched in every
            year, so e.g. March of different years is counted together.
        """
        query = (
            select([func.count(person_trip_table.c.trip_id)])
            .select_from(
                person_trip_table
                .outerjoin(person_trip_purpose_table)
            )
            # Filtering on the purpose table turns the outer join into an
            # inner join: only person trips that have this purpose are counted.
            .where(person_trip_purpose_table.c.purpose_id == purpose_id)
            .where(person_trip_table.c.person_id == person_id)
            .where(person_trip_table.c.status != PTStatus.cancelled)
            .where(extract('month', person_trip_table.c.gap_date_from) == month)
        )
        if year is not None:
            query = query.where(extract('year', person_trip_table.c.gap_date_from) == year)
        return await self.conn.scalar(query)

    async def create(self, trip_id: int, person_id: int, **fields):
        """Insert a person trip, or update it if its primary key already exists (upsert).

        :return: the ``trip_id`` of the inserted or updated row.
        """
        fields['trip_id'] = trip_id
        fields['person_id'] = person_id

        query = (
            insert(person_trip_table)
            .values(**fields)
            .on_conflict_do_update(
                constraint=person_trip_table.primary_key,
                set_=fields,
            )
            .returning(person_trip_table.c.trip_id)
        )
        return await self.conn.scalar(query)

    async def bulk_create(self, trip_id: int, values: list[dict]) -> None:
        """Insert several person trips of ``trip_id`` in one statement; no-op for an empty list."""
        if not values:
            return
        values = [{'trip_id': trip_id, **item} for item in values]
        await self.conn.execute(person_trip_table.insert().values(values))

    async def update(self, trip_id: int, person_id: int, **fields):
        """Update columns of one person trip; no-op when ``fields`` is empty."""
        if not fields:
            # An UPDATE without a SET clause makes SQLAlchemy require every
            # column as a bind parameter, which fails at execution time.
            return
        query = (
            person_trip_table
            .update()
            .where(self._get_where_by_pk(trip_id, person_id))
            .values(**fields)
        )
        await self.conn.execute(query)

    async def get_person_trip_by_issue(self, issue_key: str) -> PersonTrip:
        """Return the person trip whose travel or conference tracker issue is ``issue_key``.

        :raises RecordNotFound: if no person trip references the issue.
        """
        query = (
            select([
                person_trip_table,
                get_person_column(correlate_table=person_trip_table),
            ])
            .select_from(
                person_trip_table
                .outerjoin(travel_details_table)
                .outerjoin(person_conf_details_table)
            )
            .where(
                or_(
                    travel_details_table.c.tracker_issue == issue_key,
                    person_conf_details_table.c.tracker_issue == issue_key,
                )
            )
        )
        pt_record = await self._first(query)
        if not pt_record:
            msg = 'PersonTrip for issue: {} does not exists'
            raise RecordNotFound(msg.format(issue_key))
        return PersonTrip(**pt_record)

    async def bulk_update(self, trip_id: int, **fields):
        """Update columns of every person trip of ``trip_id``; no-op when ``fields`` is empty."""
        if not fields:
            return
        query = (
            person_trip_table
            .update()
            .where(person_trip_table.c.trip_id == trip_id)
            .values(**fields)
        )
        await self.conn.execute(query)

    async def update_conf_details(self, trip_id: int, person_id: int, **fields):
        """Upsert the person's conference details row.

        :return: the ``trip_id`` of the affected row.
        """
        fields['trip_id'] = trip_id
        fields['person_id'] = person_id
        query = (
            insert(person_conf_details_table)
            .values(**fields)
            .on_conflict_do_update(
                constraint=person_conf_details_table.primary_key,
                set_=fields,
            )
            .returning(person_conf_details_table.c.trip_id)
        )
        return await self.conn.scalar(query)

    async def update_travel_details(self, trip_id: int, person_id: int, **fields):
        """Upsert the person's travel details row.

        :return: the ``trip_id`` of the affected row.
        """
        fields['trip_id'] = trip_id
        fields['person_id'] = person_id
        query = (
            insert(travel_details_table)
            .values(**fields)
            .on_conflict_do_update(
                constraint=travel_details_table.primary_key,
                set_=fields,
            )
            .returning(travel_details_table.c.trip_id)
        )
        return await self.conn.scalar(query)

    async def add_purposes(self, trip_id: int, person_id: int, purpose_ids: list[int]) -> None:
        """Attach ``purpose_ids`` to the person trip; no-op for an empty list."""
        if not purpose_ids:
            # insert().values([]) does not mean "zero rows": it builds a
            # single-row INSERT without values, which fails.
            return
        values = [
            {
                'purpose_id': purpose_id,
                'trip_id': trip_id,
                'person_id': person_id,
            }
            for purpose_id in purpose_ids
        ]
        await self.conn.execute(person_trip_purpose_table.insert().values(values))

    async def clean_purposes(self, trip_id: int, person_id: Optional[int] = None):
        """Delete purposes of ``trip_id``, only for ``person_id`` when it is given."""
        query = (
            person_trip_purpose_table.delete()
            .where(person_trip_purpose_table.c.trip_id == trip_id)
        )
        # Compare against None explicitly: a falsy id (0) must not widen
        # the DELETE to every person of the trip.
        if person_id is not None:
            query = query.where(person_trip_purpose_table.c.person_id == person_id)
        await self.conn.execute(query)

    async def update_purposes(self, trip_id: int, person_id: int, purpose_ids: list[int]) -> None:
        """Replace the person trip's purposes with ``purpose_ids``."""
        await self.clean_purposes(trip_id, person_id)
        if purpose_ids:
            await self.add_purposes(trip_id, person_id, purpose_ids)

    async def add_documents(
            self,
            trip_id: int,
            person_id: int,
            document_ids: list[int]
    ) -> None:
        """Attach ``document_ids`` to the person trip; no-op for an empty list."""
        if not document_ids:
            return
        values = [
            {
                'document_id': document_id,
                'trip_id': trip_id,
                'person_id': person_id,
            }
            for document_id in document_ids
        ]
        await self.conn.execute(person_trip_document_table.insert().values(values))

    async def clean_documents(self, trip_id: int, person_id: int):
        """Detach all documents from the person trip."""
        query = (
            person_trip_document_table
            .delete()
            .where(person_trip_document_table.c.trip_id == trip_id)
            .where(person_trip_document_table.c.person_id == person_id)
        )
        await self.conn.execute(query)

    async def update_person_trip_documents(
            self,
            trip_id: int,
            person_id: int,
            document_ids: list[int],
    ):
        """Replace the person trip's documents with ``document_ids``."""
        await self.clean_documents(trip_id, person_id)
        if document_ids:
            await self.add_documents(trip_id, person_id, document_ids)

    async def delete(self, trip_id: int, person_id: int):
        """Delete a person trip together with its per-person related rows.

        Related tables are cleared first and ``person_trip_table`` last, in
        the order listed below.
        """
        related_tables = (
            person_conf_details_table,
            travel_details_table,
            person_trip_purpose_table,
            person_trip_document_table,
            person_trip_route_point_table,
            person_trip_table,
        )
        for table in related_tables:
            query = (
                table
                .delete()
                .where(table.c.trip_id == trip_id)
                .where(table.c.person_id == person_id)
            )
            await self.conn.execute(query)

    async def get_count_group_by_status(self) -> list[tuple[str, int]]:
        """Return ``(status value, count)`` for every status present in the table."""
        query = (
            select([
                person_trip_table.c.status,
                func.count().label('count_person_trip'),
            ])
            .group_by(person_trip_table.c.status)
        )
        rows = await self._fetchall(query)
        return [(row['status'].value, row['count_person_trip']) for row in rows]

    async def close_completed_person_trips(self):
        """Close person trips that have finished.

        A person trip counts as finished when both of these hold:

        * its status is ``executed``;
        * its ``gap_date_to`` is at least ``settings.DAYS_BEFORE_CLOSING_TRIPS``
          days in the past.

        Such trips are moved to ``closed``.

        :return: the ``(trip_id, person_id)`` rows of the closed person trips.
        """
        # NOTE(review): datetime.now() is naive local time; confirm it matches
        # how gap_date_to is stored.
        threshold_date = datetime.now() - timedelta(days=settings.DAYS_BEFORE_CLOSING_TRIPS)
        query = (
            person_trip_table
            .update()
            .where(person_trip_table.c.status == PTStatus.executed)
            .where(person_trip_table.c.gap_date_to <= threshold_date)
            .values(status=PTStatus.closed)
            .returning(person_trip_table.c.trip_id, person_trip_table.c.person_id)
        )
        return await self._fetchall(query)

    async def get_travel_tracker_issue(self, trip_id: int, person_id: int) -> Optional[str]:
        """Return the travel tracker issue key of the person trip, or ``None`` if absent."""
        query = (
            select([travel_details_table.c.tracker_issue])
            .where(travel_details_table.c.trip_id == trip_id)
            .where(travel_details_table.c.person_id == person_id)
        )
        return await self.conn.scalar(query)

    async def create_route_points(self, route_points: list[dict]) -> None:
        """Insert route point rows in one statement; no-op for an empty list."""
        if not route_points:
            return
        await self.conn.execute(person_trip_route_point_table.insert().values(route_points))

    async def update_route_point(
            self,
            trip_id: int,
            person_id: int,
            provider_city_id: str,
            **fields,
    ) -> None:
        """Update the route point of the person trip in ``provider_city_id``.

        This is a no-op when ``fields`` is empty.
        """
        if not fields:
            return
        query = (
            person_trip_route_point_table
            .update()
            .values(**fields)
            .where(and_(
                person_trip_route_point_table.c.trip_id == trip_id,
                person_trip_route_point_table.c.person_id == person_id,
                person_trip_route_point_table.c.provider_city_id == provider_city_id,
            ))
        )
        await self.conn.execute(query)

    async def get_manager_ids(self, trips_ids: list[int]) -> dict[int, list[int]]:
        """Map each trip id to the sorted distinct ``manager_id`` values of its person trips.

        Trips without person trips are absent from the result. If
        ``manager_id`` is nullable, the lists may contain ``None`` — confirm
        against the schema.
        """
        if not trips_ids:
            # Same result as querying with an empty IN, without a DB round trip.
            return {}
        query = (
            select([
                person_trip_table.c.trip_id,
                person_trip_table.c.manager_id,
            ])
            .where(person_trip_table.c.trip_id.in_(trips_ids))
            .distinct()
            .order_by(
                person_trip_table.c.trip_id,
                person_trip_table.c.manager_id,
            )
        )

        result = defaultdict(list)
        for item in await self._fetchall(query):
            result[item['trip_id']].append(item['manager_id'])

        return dict(result)

    async def get_manager_ids_for_persons(self, person_ids: list[int]) -> dict[int, list[int]]:
        """Map each person id to the sorted distinct ``manager_id`` values of their person trips.

        Persons without person trips are absent from the result.
        """
        if not person_ids:
            return {}
        query = (
            select([
                person_trip_table.c.person_id,
                person_trip_table.c.manager_id,
            ])
            .where(person_trip_table.c.person_id.in_(person_ids))
            .distinct()
            .order_by(
                person_trip_table.c.person_id,
                person_trip_table.c.manager_id,
            )
        )

        result = defaultdict(list)
        for item in await self._fetchall(query):
            result[item['person_id']].append(item['manager_id'])

        return dict(result)
