from typing import Iterable
import datetime
from collections import defaultdict

from sqlalchemy.orm import Session, joinedload, Query
from sqlalchemy import func

from watcher.api.schemas.gap_settings import (
    ManualGapSettingsCreateSchema,
    ManualGapSettingsPatchSchema,
)
from watcher.db import (
    ManualGapSettings,
    ManualGapSettingsServices,
    ManualGapSettingsSchedules,
    ManualGap,
    Schedule,
)

from .base import (
    _bulk_insert_refs,
    _set_fields,
    update_many_to_many_for_field,
)


def create_gap_settings(db: Session, staff_id: int, schema: ManualGapSettingsCreateSchema) -> ManualGapSettings:
    """Create and persist a ``ManualGapSettings`` row for ``staff_id``.

    Scalar fields are copied from ``schema``. Unless ``schema.all_services``
    is truthy, link rows are inserted for every id in ``schema.services`` and
    ``schema.schedules``; empty or missing lists are skipped. The session is
    committed and the refreshed object is returned.
    """
    gap_settings = ManualGapSettings(
        staff_id=staff_id,
        title=schema.title,
        comment=schema.comment,
        start=schema.start,
        end=schema.end,
        recurrence=schema.recurrence,
        all_services=schema.all_services,
    )
    db.add(gap_settings)
    # Flush so the new row gets its primary key before link rows refer to it.
    db.flush()

    if not schema.all_services:
        # (ids to link, column holding the id, link table), in insertion order.
        link_specs = (
            (schema.services, 'service_id', ManualGapSettingsServices),
            (schema.schedules, 'schedule_id', ManualGapSettingsSchedules),
        )
        for ref_ids, key_column, link_table in link_specs:
            if not ref_ids:
                continue
            _bulk_insert_refs(
                db=db,
                obj=gap_settings,
                field_key=key_column,
                to_add=ref_ids,
                table=link_table,
                related_field='gap_settings_id',
            )

    db.commit()
    db.refresh(gap_settings)
    return gap_settings


def patch_gap_settings(db: Session, obj: ManualGapSettings, schema: ManualGapSettingsPatchSchema) -> tuple[ManualGapSettings, bool]:
    """Apply a partial update to ``obj`` and report whether anything changed.

    Only the fields explicitly set on ``schema`` are applied. ``services`` and
    ``schedules`` are many-to-many link sets handled separately:

    * when ``schema.all_services`` is truthy, every link row of ``obj`` is
      deleted (settings covering all services keep no explicit links);
    * otherwise, each link set that was supplied in the patch is synchronised
      to exactly the given ids; omitted link sets are left untouched.

    :param db: active session; it is committed only when a change is detected.
    :param obj: the gap settings row, updated in place.
    :param schema: patch payload; unset fields are ignored.
    :returns: ``(obj, has_changes)``. ``has_changes`` is ``True`` when a scalar
        field changed, ``schema.all_services`` is truthy, or a supplied link
        set differs from the currently linked ids.
    """
    data = schema.dict(exclude_unset=True)
    services = data.pop('services', None)
    schedules = data.pop('schedules', None)

    # bool() so a None from _set_fields / all_services can't reach ``|=`` below
    # (``None | bool`` raises TypeError).
    has_changes = bool(_set_fields(obj=obj, update_data=data)) or bool(schema.all_services)
    # Only compare link sets that were actually sent: an omitted field pops as
    # None, and ``None != set()`` would otherwise always report a change.
    # set() makes list payloads compare by membership, not by type/order.
    if services is not None:
        has_changes |= set(services) != {service.id for service in obj.services}
    if schedules is not None:
        has_changes |= set(schedules) != {schedule.id for schedule in obj.schedules}

    if schema.all_services:
        db.query(ManualGapSettingsServices).filter(
            ManualGapSettingsServices.gap_settings_id == obj.id
        ).delete(synchronize_session=False)
        db.query(ManualGapSettingsSchedules).filter(
            ManualGapSettingsSchedules.gap_settings_id == obj.id
        ).delete(synchronize_session=False)
    else:
        to_update = []
        if services is not None:
            to_update.append((services, ManualGapSettingsServices, 'service_id'))
        if schedules is not None:
            to_update.append((schedules, ManualGapSettingsSchedules, 'schedule_id'))
        for target, table, field_key in to_update:
            # Read the link rows straight from the link table so the helper
            # can diff them against ``target``.
            current_refs = db.query(table).filter(
                table.gap_settings_id == obj.id
            ).all()
            current = {getattr(ref, field_key) for ref in current_refs}
            update_many_to_many_for_field(
                db=db, obj=obj,
                target=target, current=current,
                table=table, field_key=field_key,
                current_refs=current_refs,
                related_field='gap_settings_id',
            )
    if has_changes:
        db.commit()
        db.refresh(obj)
    return obj, has_changes


def get_staff_manual_gaps_for_schedule_group(
    db: Session,
    start: datetime.datetime,
    end: datetime.datetime,
    schedule_group_id: int,
    staff_ids: Iterable[int],
    service_ids: Iterable[int],
) -> dict[int, list[ManualGap]]:
    """Collect the active manual gaps of ``staff_ids`` relevant to a schedule group.

    A gap is returned when all of the following hold:

    * it is active;
    * it belongs to one of ``staff_ids``;
    * it strictly overlaps the ``(start, end)`` window (touching endpoints do
      not count);
    * its settings apply to all services, OR link at least one schedule of
      ``schedule_group_id``, OR link at least one of ``service_ids``.

    Each gap's ``gap_settings`` is eager-loaded together with its
    ``schedules`` and ``services`` collections.

    :returns: a ``defaultdict(list)`` keyed by ``staff_id``; staff without
        matching gaps have no key (and read back as ``[]``).
    """
    # EXISTS tied to the outer settings row: it links a schedule of the group.
    links_group_schedule = (
        db.query(ManualGapSettingsSchedules)
        .join(Schedule, Schedule.id == ManualGapSettingsSchedules.schedule_id)
        .filter(
            ManualGapSettingsSchedules.gap_settings_id == ManualGapSettings.id,
            Schedule.schedules_group_id == schedule_group_id,
        )
        .exists()
    )
    # EXISTS tied to the outer settings row: it links one of the services.
    links_requested_service = (
        db.query(ManualGapSettingsServices)
        .filter(
            ManualGapSettingsServices.gap_settings_id == ManualGapSettings.id,
            ManualGapSettingsServices.service_id.in_(service_ids),
        )
        .exists()
    )
    settings_apply = (
        ManualGapSettings.all_services.is_(True)
        | links_group_schedule
        | links_requested_service
    )

    gaps_query = (
        db.query(ManualGap)
        .join(ManualGapSettings, ManualGapSettings.id == ManualGap.gap_settings_id)
        .filter(
            ManualGap.is_active.is_(True),
            ManualGap.staff_id.in_(staff_ids),
            ManualGap.end > start,
            ManualGap.start < end,
            settings_apply,
        )
        .options(
            joinedload(ManualGap.gap_settings).joinedload(ManualGapSettings.schedules),
            joinedload(ManualGap.gap_settings).joinedload(ManualGapSettings.services),
        )
    )

    gaps_by_staff = defaultdict(list)
    for manual_gap in gaps_query:
        gaps_by_staff[manual_gap.staff_id].append(manual_gap)
    return gaps_by_staff


def query_affected_by_schedule_deletion_gap_settings(db: Session, schedule_id: int) -> Query:
    """Build (without executing) a query for settings that link only ``schedule_id``.

    Selects active ``ManualGapSettings`` that have a link row to
    ``schedule_id`` and exactly one schedule link row in total. These are the
    settings that would be left with no schedule links if that schedule were
    removed.
    """
    linked_to_schedule = db.query(ManualGapSettingsSchedules.gap_settings_id).filter(
        ManualGapSettingsSchedules.schedule_id == schedule_id,
    )
    return (
        db.query(ManualGapSettings)
        .outerjoin(
            ManualGapSettingsSchedules,
            ManualGapSettingsSchedules.gap_settings_id == ManualGapSettings.id,
        )
        .filter(
            ManualGapSettings.is_active.is_(True),
            ManualGapSettings.id.in_(linked_to_schedule),
        )
        .group_by(ManualGapSettings.id)
        # Count all schedule links per settings row; a count of one means the
        # link to schedule_id (guaranteed by the IN filter) is the only one.
        .having(func.count(ManualGapSettingsSchedules.id) == 1)
    )
