import datetime

from typing import Optional, Iterable
from collections import defaultdict

from sqlalchemy.orm import Session, Query

from watcher.db import Gap
from watcher import enums


def query_true_gaps(db: Session, query: Optional[Query] = None) -> Query:
    """Narrow a ``Gap`` query down to "true" gaps.

    Here a true gap is one whose ``status`` equals ``enums.GapStatus.active``
    and whose ``work_in_absence`` flag is not set.

    :param db: database session; only used to build ``db.query(Gap)`` when
        no ``query`` is supplied.
    :param query: an existing query to narrow (presumably one selecting
        ``Gap`` rows). Defaults to ``db.query(Gap)``.
    :return: the filtered, not-yet-executed query.
    """
    # Test for None explicitly: the truthiness of a Query object is not part
    # of its documented API, so ``not query`` expresses the intent only by
    # accident.
    if query is None:
        query = db.query(Gap)

    return query.filter(
        Gap.status == enums.GapStatus.active,
        # SQL NOT; if the column is nullable, NULL rows are excluded too
        # (three-valued logic) — NOTE(review): confirm that is intended.
        ~Gap.work_in_absence,
    )


def query_staffs_gaps(db: Session, staff_ids: Iterable[int]) -> Query:
    """Build a query for the true gaps of the given staff members.

    The staff filter is applied first; the result is then passed through
    ``query_true_gaps``.

    :param db: database session.
    :param staff_ids: ids of the staff members whose gaps are wanted.
    :return: an unexecuted query ordered by ``Gap.start``.
    """
    per_staff = db.query(Gap).filter(Gap.staff_id.in_(staff_ids))
    true_gaps = query_true_gaps(db=db, query=per_staff)
    return true_gaps.order_by(Gap.start)


def get_staff_gaps_for_interval(
    db: Session, start: datetime.datetime,
    end: datetime.datetime, staff_ids: Iterable[int]
) -> dict[int, list[Gap]]:
    """Group, per staff member, the true gaps overlapping ``(start, end)``.

    A gap is selected when it ends strictly after ``start`` and begins
    strictly before ``end``. Gaps that only touch a boundary are left out.

    :param db: database session.
    :param start: lower bound of the interval.
    :param end: upper bound of the interval.
    :param staff_ids: ids of the staff members to look up.
    :return: a ``defaultdict(list)`` mapping staff id to that member's
        overlapping gaps, in ``Gap.start`` order. Members with no
        overlapping gap have no key, but indexing one yields ``[]``.
    """
    overlapping = query_staffs_gaps(db=db, staff_ids=staff_ids).filter(
        # Column on the left: builds the same SQL expressions as the
        # reflected forms ``start < Gap.end`` and ``end > Gap.start``.
        Gap.end > start,
        Gap.start < end,
    )

    gaps_by_staff = defaultdict(list)
    for staff_gap in overlapping:
        gaps_by_staff[staff_gap.staff_id].append(staff_gap)
    return gaps_by_staff
