import datetime
from typing import List, Optional, Union, Iterable

from sqlalchemy import or_, and_, column
from sqlalchemy.orm import joinedload, Session
from sqlalchemy.orm.query import Query

from watcher import enums
from watcher.api.schemas.shift import ShiftPatchSchema, ShiftPutSchema
from watcher.config import settings
from watcher.crud.base import (
    get_object_by_model,
    get_object_by_model_or_404,
    patch_object,
)
from watcher.db import (
    Shift, Interval, Slot,
    Schedule, Revision, Service,
)
from watcher.logic.timezone import now


def patch_shift(
    db: Session,
    db_obj: Shift,
    shift: Union[ShiftPatchSchema, ShiftPutSchema],
    author_id: int,
    recalculate_rating: Optional[bool] = False,
    commit: bool = True,
) -> Shift:
    from watcher.logic.shift import set_shift_empty, set_shift_approved, set_shift_staff
    update_data = shift.dict(exclude_unset=True)
    if 'empty' in update_data:
        set_shift_empty(db_obj, shift.empty)
    if 'staff_id' in update_data:
        if shift.staff_id != db_obj.staff_id:
            # При смене staff_id, approved взводится безусловно
            shift.approved = True
            update_data['approved'] = True
        set_shift_staff(db, db_obj, shift.staff_id, recalculate_rating=recalculate_rating)
    if 'approved' in update_data:
        set_shift_approved(db_obj, shift.approved, author_id=author_id)

    db_obj = patch_object(db=db, obj=db_obj, schema=shift, commit=commit)

    return db_obj


def query_shifts_by_ids(db: Session, shift_ids: Iterable[int], query: Optional[Query] = None) -> Query:
    if not query:
        query = db.query(Shift)

    return query.filter(Shift.id.in_(shift_ids))


def query_shift_by_id(db: Session, shift_id: int, joined: List[str], query: Optional[Query] = None) -> Query:
    query = query_shifts_by_ids(db, [shift_id], query)

    joinedload_params = [joinedload(getattr(Shift, param)) for param in joined]
    if joinedload_params:
        query = query.options(*joinedload_params)

    return query


def get_shift_by_id(db: Session, shift_id: int) -> Shift:
    return get_object_by_model(db=db, model=Shift, object_id=shift_id)


def get_shift_or_404(db: Session, shift_id: int) -> Shift:
    return get_object_by_model_or_404(db=db, model=Shift, object_id=shift_id)


def get_last_shift_by_schedule(db: Session, schedule_id: int) -> Union[Shift, None]:
    """
    Находим и возвращаем последний шифт у графика.
    Последним шифтом считаем тот, у которого в поле next не указан следующий шифт.
    Если таких шифтов несколько (например, по какой-то причине не произошло еще формирование последовательности),
    то выберем из них тот, который является самым поздним по старту, если таких несколько - берем подсмену.
    Проверяем, что последовательность не сломана - в query не должно быть более одного шифта
    Если последовательность сломана - возвращаем None
    """
    query = (
        query_shifts_without_next(db, schedule_id)
        .filter(~Shift.sub_shifts.any())
        .order_by(Shift.start.desc(), Shift.replacement_for_id.is_(None))
    )

    # если вернулось более одного объекта - последовательность уже сломана
    if query.count() > 1:
        return

    return query.first()


def get_first_shift_by_schedule(
    db: Session, schedule_id: int, query: Optional[Query] = None
) -> Optional[tuple[Shift, bool]]:
    """
    Находим и возвращаем первый шифт в последовательности среди текущих (если не задан квери).
    Если последовательность шифтов еще не задана вернём None.

    valid: bool возвращаем False  когда нашли несколько первых шифтов
    """
    if not query:
        query = query_current_shifts_by_schedule(db=db, schedule_id=schedule_id)

    result_query = (
        query
        .options(
            joinedload(Shift.prev),
            joinedload(Shift.replacement_for),
        )
        .filter(
            ~Shift.sub_shifts.any(),
            or_(
                ~Shift.prev.has(),
                ~Shift.prev.has(Shift.id.in_({t[0] for t in query.values(column('id'))}))
            )
        )
    )

    result = result_query.all()
    first_shift = None
    valid = True
    if result:
        result = sorted(result, key=lambda x: (bool(x.prev), x.start))

        first_shift = result[0]
        shifts_without_prev = [shift for shift in result if not shift.prev]
        if len(shifts_without_prev) > 1:
            # Для графика найдено несколько шифтов без prev
            valid = False
            first_shift = shifts_without_prev[0]

    if first_shift and first_shift.replacement_for_id is not None:
        first_shift = first_shift.replacement_for

    return first_shift, valid


def query_all_shifts_by_schedule(db: Session, schedule_id: int, query: Optional[Query] = None) -> Query:
    if not query:
        query = db.query(Shift)
    return query.filter(Shift.schedule_id == schedule_id).order_by(Shift.start, Shift.replacement_for_id.isnot(None))


def query_all_shifts_by_schedule_with_staff(db: Session, schedule_id: int, query: Optional[Query] = None) -> Query:
    return (
        query_all_shifts_by_schedule(db=db, schedule_id=schedule_id, query=query)
        .filter(Shift.staff_id.isnot(None))
    )


def query_all_shifts_by_service(db: Session, service_id: int, query: Optional[Query] = None) -> Query:
    if not query:
        query = db.query(Shift)

    return (
        query
        .join(Schedule, Shift.schedule_id == Schedule.id)
        .filter(Schedule.service_id == service_id)
    )


def query_main_shifts_by_slots(db: Session, slot_id: int) -> Query:
    return db.query(Shift).filter(
        Shift.slot_id == slot_id,
        Shift.replacement_for_id.is_(None),
    )


def query_schedule_main_shifts(db: Session, schedule_id: int) -> Query:
    query = query_all_shifts_by_schedule(db, schedule_id)
    return query.filter(Shift.replacement_for_id.is_(None)).order_by(Shift.start)


def query_all_approved_shifts_by_schedule(db: Session, schedule_id: int, query: Optional[Query] = None) -> Query:
    if not query:
        query = db.query(Shift)

    query_all_shifts = query_all_shifts_by_schedule(db=db, schedule_id=schedule_id, query=query)
    return query_all_shifts.filter(Shift.approved.is_(True))


def query_shifts_without_next(db: Session, schedule_id: int, query: Optional[Query] = None) -> Query:
    if not query:
        query = db.query(Shift)

    return query_all_shifts_by_schedule(db, schedule_id, query).filter(Shift.next_id.is_(None))


def query_shifts_with_intervals_by_schedule(db: Session, schedule_id: int) -> Query:
    return (
        db.query(Shift, Interval)
        .join(Slot, Slot.id == Shift.slot_id)
        .join(Interval, Interval.id == Slot.interval_id)
        .filter(Shift.schedule_id == schedule_id)
        .order_by(Shift.start)
    )


def query_need_approve_shifts_by_schedules(db: Session, schedules_ids: List[int]) -> Query:
    recently_disapproved_threshold = now() - datetime.timedelta(
        seconds=settings.DONT_APPROVE_DISAPPROVED_SHIFTS_FOR)
    return (
        db.query(Shift)
        .join(Schedule, Schedule.id == Shift.schedule_id)
        .filter(
            Schedule.id.in_(schedules_ids),
            ~Shift.approved, Shift.start < now() + Schedule.pin_shifts,
            or_(
                Shift.approved_removed_at.is_(None),
                Shift.approved_removed_at <= recently_disapproved_threshold,
            )
        )
        .options(
            joinedload(Shift.sub_shifts).joinedload(Shift.schedule),
            joinedload(Shift.schedule),
        )
    )


def query_need_start_shifts(db: Session, future_time_shift: bool = True) -> Query:
    return (
        db.query(Shift.id)
        .join(Schedule, Schedule.id == Shift.schedule_id)
        .filter(
            Schedule.state == enums.ScheduleState.active,
            Shift.status == enums.ShiftStatus.scheduled,
            Shift.start < now() + settings.SHIFT_START_TIMEDELTA * (1 if future_time_shift else -1),
            Shift.end > now(),
            ~Shift.sub_shifts.any(),
        )
    )


def query_need_finish_shifts(db: Session) -> Query:
    return (
        db.query(Shift.id)
        .filter(
            Shift.status == enums.ShiftStatus.active,
            Shift.end < now() - settings.SHIFT_FINISH_TIMEDELTA,
        )
    )


def query_current_shifts_by_schedule(db: Session, schedule_id: int, query: Optional[Query] = None) -> Query:
    if not query:
        query = db.query(Shift)

    return (
        query_all_shifts_by_schedule(db=db, schedule_id=schedule_id, query=query)
        .filter(Shift.status == enums.ShiftStatus.active)
    )


def query_current_shifts_by_service(db: Session, service_id: int, query: Optional[Query] = None) -> Query:
    if not query:
        query = db.query(Shift)

    return (
        query_all_shifts_by_service(db=db, service_id=service_id, query=query)
        .filter(Shift.status == enums.ShiftStatus.active)
    )


def query_not_completed_shifts_by_members(db: Session, composition_id: int, staff_ids: Iterable[int]) -> Query:
    return (
        db.query(Shift)
        .options(joinedload(Shift.schedule))
        .join(Schedule, Schedule.id == Shift.schedule_id)
        .join(Interval, Interval.schedule_id == Schedule.id)
        .join(Slot, Slot.interval_id == Interval.id)
        .filter(Slot.composition_id == composition_id)
        .filter(Shift.status.in_((enums.ShiftStatus.active, enums.ShiftStatus.scheduled)))
        .filter(Shift.staff_id.in_(staff_ids))
    )


def query_shifts_by_revision(db: Session, revision_id: int, query: Optional[Query] = None) -> Query:
    if not query:
        query = db.query(Shift)

    return (
        query
        .join(Slot, Slot.id == Shift.slot_id)
        .join(Interval, Interval.id == Slot.interval_id)
        .join(Revision, Revision.id == Interval.revision_id)
        .filter(Revision.id == revision_id)
    )


def query_shifts_by_schedule_in_interval(
    db: Session, schedule_id: int, start: datetime.datetime, end: datetime.datetime,
    query: Optional[Query] = None
) -> Query:
    if not query:
        query = db.query(Shift)
    return (
        query_all_shifts_by_schedule(db=db, schedule_id=schedule_id, query=query)
        .filter(start < Shift.end, end > Shift.start)
        .options(joinedload(Shift.prev), joinedload(Shift.next))
    )


def get_shifts_with_same_role(db: Session, shift: Shift) -> List[Shift]:
    """
    Находит все смены в сервисе с такой же ролью и дежурным, как у переданной,
    либо идущие параллельно, либо начинающиеся впритык после переданной
    """
    return (
        db.query(Shift.id)
        .join(Slot, Slot.id == Shift.slot_id)
        .join(Schedule, Schedule.id == Shift.schedule_id)
        .join(Service, Service.id == Schedule.service_id)
        .filter(
            Service.id==shift.schedule.service_id,
            Shift.staff_id==shift.staff_id,
            Slot.role_on_duty_id==shift.slot.role_on_duty_id,
            Shift.id!=shift.id,
            or_(
                Shift.start==shift.end,  # следующая смена у того же человека
                and_(  # параллельная смена
                    Shift.status == enums.ShiftStatus.active,
                    Shift.end > shift.end,
                ),
                and_(  # подсмена (такая ситуация может быть только при редактировании подсмен)
                    Shift.replacement_for_id == shift.id,
                    Shift.status == enums.ShiftStatus.scheduled,
                    Shift.end > now(),
                    Shift.start <= now(),
                )
            )
        )
    ).all()


def get_simultaneous_shifts(db: Session, shift: Shift, target_is_primary: bool) -> list[Shift]:
    return db.query(Shift).join(
        Slot, Slot.id == Shift.slot_id,
    ).filter(
        Shift.id != shift.id,
        Slot.interval == shift.slot.interval,
        Shift.start == shift.start, Shift.end == shift.end,
        Shift.is_primary.is_(target_is_primary),
    ).all()


def get_shifts_from_parallel_slots(db: Session, main_shift: Shift):
    return (
        db.query(Shift)
        .options(
            joinedload(Shift.prev),
            joinedload(Shift.next),
        )
        .filter(
            Shift.schedule_id == main_shift.schedule_id,
            Shift.id != main_shift.id,
            Shift.replacement_for_id != main_shift.id,
            Shift.start >= main_shift.start, Shift.end <= main_shift.end,
        )
    ).all()
