import datetime
import logging

from collections import defaultdict
from typing import Optional
from sqlalchemy.orm import Session, joinedload

from watcher import enums
from watcher.config import settings
from watcher.crud.gap import get_staff_gaps_for_interval
from watcher.crud.manual_gap import get_staff_manual_gaps_for_schedule_group
from watcher.crud.schedule import set_recalculation
from watcher.crud.schedule_group import (
    query_groups_with_allocation_error,
    schedule_group_has_employed_intervals,
)
from watcher.logic.exceptions import NoShiftsFound
from watcher.logic.problem import resolve_problem
from watcher.logic.people_allocation import (
    get_participants_ratings_from_sequence,
    get_participants_last_shift_ends,
    prepare_people_allocation_start_time,
    find_staff_for_shift,
    fill_shift,
    OnDutyGap,
)
from watcher.logic.timezone import now
from watcher.db.base import dbconnect
from watcher.logic.shift import need_find_duty
from watcher.tasks.generating_shifts import sequence_shifts
from watcher.tasks.sync import notify_staff_duty
from watcher.db import (
    Schedule,
    Interval,
    Slot,
    Composition,
    CompositionParticipants,
    Gap,
    Shift,
    Staff,
    Revision,
    ManualGap,
    Problem,
)
from .base import lock_task


logger = logging.getLogger(__name__)


@lock_task(save_metrics=True, send_to_unistat=True)
@dbconnect
def start_people_allocation_for_groups_with_allocation_error(session: Session):
    """
    Запускаем перераспределение людей для групп расписаний в которых еще не было перераспределения людей или
    в прошлом запуске произошла ошибка
    """
    schedules_group_to_allocate = query_groups_with_allocation_error(session).all()

    schedules_groups = set()
    start_dates = defaultdict(lambda: now())

    for schedules_group, shift in schedules_group_to_allocate:
        schedules_groups.add(schedules_group)

        if schedules_group.last_people_allocation_at:
            start_dates[schedules_group.id] = schedules_group.last_people_allocation_at
        elif not shift.staff_id and not shift.empty:
            start_dates[schedules_group.id] = min(shift.start, start_dates[schedules_group.id])

    for schedules_group in schedules_groups:
        logger.info(f'Starting people allocation for group with error {schedules_group.id}')
        start_date = start_dates[schedules_group.id]
        logger.info(
            f'Scheduling start_people_allocation for group with '
            f'error: {schedules_group.id}, from {start_date}'
        )
        start_people_allocation.delay(
            schedules_group_id=schedules_group.id,
            start_date=start_date,
        )


def get_people_allocation_data(
    session: Session,
    schedules_group_id: int,
    start_date: datetime.datetime
) -> tuple[
    set[Schedule], set[Shift], set[Interval],
    dict[int, dict[int, set[Staff]]], dict[int, set[Staff]],
    dict[int, list[Gap | ManualGap | OnDutyGap]],
    dict[int, Problem]
]:
    query = (
        session.query(Schedule, Staff, Interval, Slot, Shift)
        .join(Interval, Interval.schedule_id == Schedule.id)
        .join(Revision, Revision.id == Interval.revision_id)
        .join(Slot, Slot.interval_id == Interval.id)
        .join(Shift, Shift.slot_id == Slot.id)
        .join(Composition, Composition.id == Slot.composition_id)
        .join(CompositionParticipants, CompositionParticipants.composition_id == Composition.id)
        .join(Staff, Staff.id == CompositionParticipants.staff_id)
        .options(
            joinedload(Shift.next).joinedload(Shift.slot),
            joinedload(Shift.prev),
            joinedload(Shift.replacement_for),
            joinedload(Shift.slot),
            joinedload(Shift.sub_shifts),
            joinedload(Schedule.schedules_group),
        )
        .filter(
            Schedule.schedules_group_id == schedules_group_id,
            Revision.state == enums.RevisionState.active,
            Shift.end >= start_date,
            Shift.status == enums.ShiftStatus.scheduled,
            ~Shift.sub_shifts.any()
        )
        .order_by(Shift.start)
    )
    schedule_groups = query.all()

    if not schedule_groups:
        raise NoShiftsFound

    intervals = set()
    schedules = set()
    shifts = set()
    schedules_participants = defaultdict(lambda: defaultdict(set))
    shifts_participants = defaultdict(set)
    staff_ids = set()
    service_ids = set()

    for schedule, staff, interval, slot, shift in schedule_groups:
        schedules.add(schedule)
        intervals.add(interval)
        shifts.add(shift)
        schedules_participants[schedule.id][slot.id].add(staff)
        shifts_participants[shift.id].add(staff)
        staff_ids.add(staff.id)
        service_ids.add(schedule.service_id)

    start = schedule_groups[0][4].start  # в первой строке Shift.start минимальный
    end = schedule_groups[-1][4].end
    participants_gaps = get_staff_gaps_for_interval(
        db=session, staff_ids=staff_ids, start=start, end=end
    )

    manual_gaps = get_staff_manual_gaps_for_schedule_group(
        db=session, staff_ids=staff_ids, start=start, end=end,
        schedule_group_id=schedules_group_id, service_ids=service_ids
    )
    all_participants_gaps = defaultdict(list, {
        staff_id: participants_gaps[staff_id] + manual_gaps[staff_id]
        for staff_id in set(list(participants_gaps) + list(manual_gaps))
    })
    problems = session.query(Problem).filter(
        Problem.shift_id.in_((obj.id for obj in shifts)),
        Problem.reason == enums.ProblemReason.nobody_on_duty,
        Problem.status == enums.ProblemStatus.new,
    )

    return (
        schedules, shifts, intervals,
        schedules_participants, shifts_participants,
        all_participants_gaps,
        {problem.shift_id: problem for problem in problems}
    )


@lock_task(lock_key=lambda schedules_group_id, *args, **kwargs: 'start_people_allocation_{}'.format(schedules_group_id))
@dbconnect
def start_people_allocation(
    session: Session,
    schedules_group_id: int,
    start_date: Optional[datetime.datetime] = None,
    push_staff: bool = False,
):
    """
    Перераспределение людей внутри группы графиков

    :param schedules_group_id: id группы графиков, для которого нужно провести перераспределение людей
    :param start_date: дата начала перераспределения, если не передать, то будет автоматически определено
    как дата создания самого старого расписания без смен.
    :param push_staff: нужно ли просить стафф обновить данные дежурств у себя

    :return:
    """
    logger.info(f'Processing start_people_allocation for group: {schedules_group_id}, start_date: {start_date}')
    current_now = now()

    if not schedule_group_has_employed_intervals(session, schedules_group_id=schedules_group_id):
        logger.info(f'Для группы смен: {schedules_group_id}, нет ни одной смены с дежурными')
        set_recalculation(
            session=session,
            obj_ids=[
                obj.id for obj in session.query(Schedule).filter(
                    Schedule.schedules_group_id == schedules_group_id
                )],
            target=False,
        )
        return

    start_date = prepare_people_allocation_start_time(session, schedules_group_id, start_date)
    try:
        (
            schedules_to_recalculate,
            shifts_to_recalculate,
            intervals_to_recalculate,
            schedules_participants,
            shifts_participants,
            participants_gaps,
            shifts_problems,
        ) = get_people_allocation_data(
            session=session,
            schedules_group_id=schedules_group_id,
            start_date=start_date,
        )
    except NoShiftsFound:
        logger.warning(f'No shifts for {schedules_group_id} found')
        set_recalculation(
            session=session,
            obj_ids=[
                obj.id for obj in session.query(Schedule).filter(
                    Schedule.schedules_group_id == schedules_group_id
                )],
            target=False,
        )
        return

    set_recalculation(
        session=session,
        obj_ids=(obj.id for obj in schedules_to_recalculate),
        target=True,
    )

    schedule = next(iter(schedules_to_recalculate))
    shifts_sequence = sequence_shifts(
        session=session,
        schedule=schedule,
        shifts=sorted(shifts_to_recalculate, key=lambda x: x.start),
        intervals=intervals_to_recalculate,
        for_group=True,
    )
    participant_ratings = get_participants_ratings_from_sequence(
        db=session,
        schedules_participants=schedules_participants,
        shifts_sequence=shifts_sequence,
        schedules_to_calculate=schedules_to_recalculate,
    )

    participant_last_shift_ends = get_participants_last_shift_ends(
        session=session,
        to_shift=shifts_sequence[0],
        schedules_participants=schedules_participants,
        schedules_to_calculate=schedules_to_recalculate,
    )

    already_found = {}
    nobody_on_duty_shifts = []

    timeout_between_shifts = schedule.schedules_group.timeout_between_shifts
    staff_login_to_id = {
        staff.login: staff.id
        for participants in shifts_participants.values()
        for staff in participants
    }

    # добавляем фиктивный гэп на последнее дежурство до распределения
    # чтобы учесть настройку с N днями перерыва между сменами
    shift_is_primary = True
    for schedule_id in participant_last_shift_ends:
        for staff_login in participant_last_shift_ends[schedule_id]:
            max_datetime = max(
                participant_last_shift_ends[schedule_id][staff_login][shift_is_primary],
                participant_last_shift_ends[schedule_id][staff_login][not shift_is_primary]
            )
            participants_gaps[staff_login_to_id[staff_login]].append(
                OnDutyGap(start=max_datetime, end=max_datetime)
            )

    # учитываем гепы тех кто уже распределен в утвержденные смены
    for shift in shifts_sequence:
        if not shift.empty and shift.approved and shift.staff:
            participants_gaps[shift.staff_id].append(
                OnDutyGap(start=shift.start, end=shift.end)
            )

    # чтобы не назначать одного человека 2 раза подряд, учитывая пустые смены
    previous_shift_end_date = shifts_sequence[0].start
    previous_shifts = []

    for index, shift in enumerate(shifts_sequence):
        shift.predicted_ratings = {
            staff.login: participant_ratings[shift.schedule_id][staff.login]
            for staff in shifts_participants[shift.id]
        }
        if shift.start > previous_shift_end_date:
            for prev_shift in previous_shifts:
                if prev_shift.staff:
                    participant_last_shift_ends[schedule.id][prev_shift.staff.login][shift.is_primary] = shift.start
        if not shift.empty:
            if previous_shift_end_date != shift.end:
                previous_shift_end_date = shift.end
                previous_shifts = []
            previous_shifts.append(shift)

            if not need_find_duty(shift=shift):
                suitable_staff = shift.staff
            else:
                suitable_staff = find_staff_for_shift(
                    shift=shift,
                    shifts_sequence=shifts_sequence[index+1:],
                    possible_participants=schedules_participants[shift.schedule_id][shift.slot_id],
                    participants_gaps=participants_gaps,
                    participant_ratings=participant_ratings[shift.schedule_id],
                    participant_last_shift_ends=participant_last_shift_ends,
                    already_found=already_found,
                    timeout_between_shifts=timeout_between_shifts,
                )
                if not suitable_staff:
                    nobody_on_duty_shifts.append(shift.id)
                else:
                    problem = shifts_problems.get(shift.id)
                    if problem:
                        resolve_problem(problem=problem)
                    if shift.start < current_now + datetime.timedelta(days=settings.STAFF_SYNC_INTERVAL):
                        # в смене выставился человек и он дежурит в пределах месяца
                        # нужно уведомить стафф
                        push_staff = True

            fill_shift(
                shift=shift,
                suitable_staff=suitable_staff,
                participant_ratings=participant_ratings,
                participant_last_shift_ends=participant_last_shift_ends,
            )
        if shift.next_id:
            shift.next.predicted_ratings = {
                staff.login: participant_ratings[shift.next.schedule_id][staff.login]
                for staff in shifts_participants[shift.next_id]
            }

    # TODO: ABC-11146. Мы же можем запустить перераспределение после замены для последних шифтов,
    #  а запишем сюда якобы недавно перераспределяли людей, но по факту там может быть перераспределен только один
    schedule.schedules_group.last_people_allocation_at = now()
    schedule.schedules_group.people_allocation_error = None

    set_recalculation(
        session=session,
        obj_ids=(obj.id for obj in schedules_to_recalculate),
        target=False,
    )

    if nobody_on_duty_shifts:
        #  создадим проблемы для смен, где не нашлось дежурного и у которых
        #  еще нет проблемы такого типа
        shifts = session.query(Shift).filter(
            Shift.id.in_(nobody_on_duty_shifts),
        ).filter(
            ~session.query(Problem.id)
            .filter(
                Problem.shift_id == Shift.id,
                Problem.reason == enums.ProblemReason.nobody_on_duty,
            )
            .exists()
        )
        problems = []
        for shift in shifts:
            logger.info(f'Creating problem nobody_on_duty for shift: {shift.id} from start_people_allocation')
            problems.append({
                'shift_id': shift.id,
                'reason': enums.ProblemReason.nobody_on_duty,
            })
        if problems:
            session.bulk_insert_mappings(Problem, problems)

    if push_staff:
        service_ids = set(schedule.service_id for schedule in schedules_to_recalculate)
        for service_id in service_ids:
            notify_staff_duty.delay(
                service_id=service_id
            )
