from sqlalchemy.orm import Session
from sqlalchemy.orm.query import Query
from typing import List, Iterable, Optional

from watcher import enums
from watcher.crud.revision import get_current_revision
from watcher.db import Interval, Revision, Schedule, Slot


def query_intervals_by_revision(db: Session, revision_id: int, query: Optional[Query] = None) -> Query:
    """Build a query for the intervals of one revision, sorted by ``Interval.order``.

    Args:
        db: Session used to create the base ``Interval`` query when *query*
            is not supplied.
        revision_id: Value matched against ``Interval.revision_id``.
        query: Optional existing query to narrow instead of starting a new one.

    Returns:
        The un-executed query with the revision filter and ordering applied.
    """
    # Compare against None explicitly: only an omitted query should trigger
    # the default, regardless of how the passed object behaves in a boolean
    # context.
    if query is None:
        query = db.query(Interval)

    return query.filter(Interval.revision_id == revision_id).order_by(Interval.order)


def query_intervals_with_slots_by_revision(db: Session, revision_id: int) -> Query:
    """Build a query of ``(Interval, Slot)`` pairs for a single revision.

    Slots are attached through an outer join on ``Slot.interval_id``, so an
    interval without any slot is still part of the result. Rows are sorted by
    ``Interval.order``.

    Args:
        db: Session the query is created on.
        revision_id: Value matched against ``Interval.revision_id``.

    Returns:
        The un-executed query.
    """
    interval_slot_pairs = db.query(Interval, Slot).outerjoin(Slot, Slot.interval_id == Interval.id)
    for_revision = interval_slot_pairs.filter(Interval.revision_id == revision_id)
    return for_revision.order_by(Interval.order)


def query_intervals_by_schedule(db: Session, schedule_id: int, query: Optional[Query] = None) -> Query:
    """Build a query for the intervals that belong to one schedule.

    Args:
        db: Session used to create the base ``Interval`` query when *query*
            is not supplied.
        schedule_id: Value matched against ``Interval.schedule_id``.
        query: Optional existing query to narrow instead of starting a new one.

    Returns:
        The un-executed query with the schedule filter applied. No ordering
        is added here.
    """
    # Compare against None explicitly: only an omitted query should trigger
    # the default, regardless of how the passed object behaves in a boolean
    # context.
    if query is None:
        query = db.query(Interval)

    return query.filter(Interval.schedule_id == schedule_id)


def get_intervals_by_ids(db: Session, interval_ids: Iterable[int]) -> List[Interval]:
    """Load the intervals whose ``id`` is among *interval_ids*.

    Args:
        db: Session the query is executed on.
        interval_ids: Any iterable of interval ids; it is consumed exactly once.

    Returns:
        The matching ``Interval`` rows, or an empty list when *interval_ids*
        is empty (no query is issued in that case).
    """
    # Materialise once so that one-shot iterables (e.g. generators) can be
    # checked for emptiness and are handed to ``in_()`` as a plain list.
    ids = list(interval_ids)
    if not ids:
        # An empty IN list can never match a row; skip the round trip.
        return []

    return db.query(Interval).filter(Interval.id.in_(ids)).all()


def query_intervals_by_current_revision(db: Session, schedule_id: int) -> Query:
    """Build the interval query for the current revision of a schedule.

    The revision is looked up with ``get_current_revision`` and its ``id`` is
    then passed to ``query_intervals_by_revision``.

    NOTE(review): the ``.id`` access assumes ``get_current_revision`` always
    returns a revision object for the given schedule -- confirm against that
    helper.

    Args:
        db: Session used both for the revision lookup and the interval query.
        schedule_id: Schedule whose current revision is looked up.

    Returns:
        The un-executed query produced by ``query_intervals_by_revision``.
    """
    revision_id = get_current_revision(db=db, schedule_id=schedule_id).id
    return query_intervals_by_revision(db=db, revision_id=revision_id)


def query_interval_by_schedules(db: Session, schedule_ids: List[int], query: Optional[Query] = None) -> Query:
    """Build a query for intervals of several schedules, limited to active revisions.

    The query joins ``Revision`` on ``Interval.revision_id`` and keeps only
    rows whose revision state equals ``enums.RevisionState.active``.

    Args:
        db: Session used to create the base ``Interval`` query when *query*
            is not supplied.
        schedule_ids: Ids matched against ``Interval.schedule_id``.
        query: Optional existing query to narrow instead of starting a new one.

    Returns:
        The un-executed query with the join and both filters applied.
    """
    # Compare against None explicitly: only an omitted query should trigger
    # the default, regardless of how the passed object behaves in a boolean
    # context.
    if query is None:
        query = db.query(Interval)

    return (
        query.join(Revision, Revision.id == Interval.revision_id)
        .filter(Interval.schedule_id.in_(schedule_ids), Revision.state == enums.RevisionState.active)
    )


def query_intervals_by_service(db: Session, service_id: int, query: Optional[Query] = None) -> Query:
    """Build a query for the intervals whose schedule belongs to one service.

    The query joins ``Schedule`` on ``Interval.schedule_id`` and filters on
    ``Schedule.service_id``.

    Args:
        db: Session used to create the base ``Interval`` query when *query*
            is not supplied.
        service_id: Value matched against ``Schedule.service_id``.
        query: Optional existing query to narrow instead of starting a new one.

    Returns:
        The un-executed query with the join and filter applied.
    """
    # Compare against None explicitly: only an omitted query should trigger
    # the default, regardless of how the passed object behaves in a boolean
    # context.
    if query is None:
        query = db.query(Interval)

    return (
        query
        .join(Schedule, Interval.schedule_id == Schedule.id)
        .filter(Schedule.service_id == service_id)
    )
