import logging
import datetime

from collections import defaultdict
from dateutil.parser import parse
from typing import Tuple, Iterable
from sqlalchemy import or_, and_
from sqlalchemy.orm import joinedload, Session

from watcher.celery_app import app
from watcher import enums
from watcher.crud.base import query_objects_by_ids
from watcher.crud.event import update_events_statuses
from watcher.db import (
    Composition,
    CompositionToRole,
    CompositionToScope,
    CompositionParticipants,
    Gap,
    Event,
    Interval,
    Member,
    Problem,
    Role,
    Service,
    Slot,
    Schedule,
    Shift,
    SchedulesGroup,
)
from watcher.logic.holidays import is_weekday
from watcher.logic.timezone import make_localized_datetime, now, today, localize
from watcher.tasks.composition import update_composition
from watcher.tasks.generating_shifts import revision_shift_boundaries
from watcher.tasks.problem import create_problems_for_staff_has_gap_shifts, resolve_shifts_problems
from watcher.tasks.process_delete_members import process_delete_members
from watcher.tasks.shift import finish_shift
from watcher.tasks.people_allocation import start_people_allocation
from watcher.db.base import dbconnect
from .base import lock_task

logger = logging.getLogger(__name__)

DT_BATCH_SIZE = 500


def get_event_type(event: Event) -> str:
    return f'{event.kind}_{event.table}'


class EventScheduler:
    """
    Группирует события по типу/id сущности (если нужно) и ставит
    таски на дальнейшую обработку
    """
    def __init__(self, db: Session):
        self.events_by_type = defaultdict(list)
        self.db = db
        self.skip_events_ids = set()

    def _split_by_batch(self, data: dict, size: int = 50) -> dict:
        result = {}
        for key, value in data.items():
            result[key] = value
            if len(result) >= size:
                yield result
                result = {}
        if result:
            yield result

    def _get_ids_from_events(self, events: list[Event]) -> dict:
        obj_ids = defaultdict(list)
        for event in events:
            obj_id = event.obj_id
            if not obj_id:
                logger.error(f'Got no obj_id from event: {event.id}')
            else:
                obj_ids[obj_id].append(event.id)
        return obj_ids

    def schedule_insert_holidays_holiday(self, events: list[Event]) -> None:
        """
        Добавление новых праздников/выходных
        """
        holidays_dates = defaultdict(list)
        for event in events:
            date = event.object_data['date']
            if is_weekday(parse(event.object_data['date'])):
                holidays_dates[date].append(event.id)
            else:
                self.skip_events_ids.add(event.id)

        if holidays_dates:
            for batch in self._split_by_batch(data=holidays_dates):
                process_new_and_deleted_holidays.delay(obj_to_event=batch)

    def schedule_delete_holidays_holiday(self, events: list[Event]) -> None:
        """
        Удаление праздников/выходных
        """
        holidays_dates = defaultdict(list)
        for event in events:
            date = event.old_keys['date']
            holidays_dates[date].append(event.id)

        if holidays_dates:
            for batch in self._split_by_batch(data=holidays_dates):
                process_new_and_deleted_holidays.delay(obj_to_event=batch)

    def schedule_update_services_service(self, events: list[Event]) -> None:
        """
        Обрабатываем обновление статуса сервиса
        """
        obj_ids = self._get_ids_from_events(events)
        services = query_objects_by_ids(
            db=self.db, object_ids=obj_ids,
            model=Service,
        )
        to_close = {}
        to_delete = {}
        for service in services:
            service_events = obj_ids[service.id]
            if service.state == enums.ServiceState.closed:
                to_close[service.id] = service_events
            elif service.state == enums.ServiceState.deleted:
                to_delete[service.id] = service_events
            else:
                self.skip_events_ids.update(service_events)

        for obj_to_event, task in (
            (to_close, process_close_services),
            (to_delete, process_delete_services)
        ):
            if obj_to_event:
                for batch in self._split_by_batch(data=obj_to_event):
                    task.delay(obj_to_event=batch)

    def schedule_update_services_servicemember(self, events: list[Event]) -> None:
        obj_ids = self._get_ids_from_events(events)
        members = query_objects_by_ids(
            db=self.db, object_ids=obj_ids,
            model=Member,
        )
        to_delete = defaultdict(dict)
        to_check = defaultdict(dict)
        for member in members:
            member_events = obj_ids[member.id]
            if member.state == enums.MemberState.deprived:
                to_delete[member.service_id][member.id] = member_events
            elif member.state == enums.MemberState.active:
                #  таска проверит не стал ли он активным только что
                to_check[member.service_id][member.id] = member_events
            else:
                self.skip_events_ids.update(member_events)

        for events, task in (
            (to_check, process_update_members),
            (to_delete, process_delete_members)
        ):
            obj_to_event = {}
            for service_id, members in events.items():
                if len(members) > 500:
                    for batch in self._split_by_batch(data=members):
                        task.delay(obj_to_event=batch)
                else:
                    obj_to_event.update(members)

                    if len(obj_to_event) > 500:
                        task.delay(obj_to_event=obj_to_event)
                        obj_to_event = {}

            if obj_to_event:
                task.delay(obj_to_event=obj_to_event)

    def schedule_update_duty_gap(self, events: list[Event]) -> None:
        """
        Обновился gap, либо изменилась длина, либо статус
        """
        obj_ids = self._get_ids_from_events(events)
        if obj_ids:
            for batch in self._split_by_batch(data=obj_ids):
                process_update_gaps.delay(obj_to_event=batch)

    def schedule_insert_duty_gap(self, events: list[Event]) -> None:
        obj_ids = self._get_ids_from_events(events)
        if obj_ids:
            for batch in self._split_by_batch(data=obj_ids):
                process_new_gaps.delay(obj_to_event=batch)

    def schedule_events(self) -> Tuple[set, set]:
        scheduled_events = set()
        for event_type, events in self.events_by_type.items():
            processor = getattr(self, f'schedule_{event_type}', None)
            if not processor:
                logger.warning(f'Unsupported event_type {event_type}')
                self.skip_events_ids.update(obj.id for obj in events)
                continue
            processor(events)
            scheduled_events.update(obj.id for obj in events)

        return scheduled_events, self.skip_events_ids

    def add_insert_services_servicemember(self, event: Event) -> None:
        """
        Если участник сразу активный - обработаем как обновление
        если нет - ничего не делаем, потом обработаем когда активируется
        """
        if event.object_data.get('state') == enums.MemberState.active:
            # добавился новый участник сразу активный
            event.kind = 'update'
            self.add_default_event(event=event)
        else:
            self.skip_events_ids.add(event.id)

    def add_default_event(self, event: Event) -> None:
        self.events_by_type[get_event_type(event)].append(event)

    def add_events(self, events: list[Event]) -> None:
        for event in events:
            getattr(
                self,
                f'add_{get_event_type(event)}',
                getattr(self, 'add_default_event'),
            )(event)


@lock_task(save_metrics=True, send_to_unistat=True)
@dbconnect
def schedule_events(session: Session):
    """
    Берет все не обработанные события и
    обрабатывает их (ставя другие таски)
    при необходимости
    """

    scheduler = EventScheduler(db=session)
    events = session.query(Event).filter(
        Event.source == enums.EventSource.logbroker,
        or_(
            and_(
                Event.state == enums.EventState.new,
                Event.created_at < now() - datetime.timedelta(minutes=10)
            ),
            and_(
                Event.state == enums.EventState.scheduled,
                Event.created_at < now() - datetime.timedelta(hours=2)
            ),
        )
    )

    scheduler.add_events(events)
    scheduled_events, skip_events_ids = scheduler.schedule_events()
    if scheduled_events:
        update_events_statuses(
            db=session,
            query=session.query(Event).filter(~Event.id.in_(skip_events_ids)),
            ids=scheduled_events,
            state=enums.EventState.scheduled
        )
    if skip_events_ids:
        update_events_statuses(db=session, ids=skip_events_ids, state=enums.EventState.processed)


@app.task
@dbconnect
def process_update_members(session: Session, obj_to_event: dict[int: Iterable]):
    """
    Обрабатывает событие: добавление участника в сервис
    Добавление участника в сервис определяется изменение статуса у ServiceMember
    с REQUESTED на ACTIVE

    :param obj_to_event
        {member.id: [member_events, ]}
    """

    # выбираем составы по ролям/скопам
    query_compositions = (
        session.query(Member, Role, Composition)
        .filter(Member.id.in_(obj_to_event.keys()))
        .join(Composition, Member.service_id == Composition.service_id)
        .join(CompositionToRole, CompositionToRole.composition_id == Composition.id, isouter=True)
        .join(CompositionToScope, CompositionToScope.composition_id == Composition.id, isouter=True)
        .join(Role, Role.id == Member.role_id)
        .filter(
            Composition.autoupdate.is_(True),
            or_(
                Member.role_id == CompositionToRole.role_id,
                Role.scope_id == CompositionToScope.scope_id,
                Composition.full_service.is_(True),
            ),
        )
    )

    # для исключения лишнего
    query_compositions = (
        query_compositions
        .options(
            joinedload(Composition.excluded_roles),
            joinedload(Composition.excluded_scopes),
            joinedload(Composition.excluded_staff),
            joinedload(Member.role),
            joinedload(Member.staff),
            joinedload(Role.scope),
            joinedload(Composition.participants),
        )
    )

    composition_for_update_set = set()
    for member, role, composition in query_compositions.all():
        if (
            role in composition.excluded_roles
            or role.scope in composition.excluded_scopes
            or member.staff in composition.excluded_staff
            or member.staff in composition.participants
        ):
            continue

        if composition not in composition_for_update_set:
            # создаёт тут соответствующий event и потом запускает
            update_composition.delay(composition_id=composition.id)
        composition_for_update_set.add(composition)

    ids = []
    for event_list in obj_to_event.values():
        ids.extend(event_list)

    update_events_statuses(db=session, ids=ids, state=enums.EventState.processed)


@app.task
@dbconnect
def process_close_services(session: Session, obj_to_event: dict[int: Iterable]):
    """
    Обрабатывает событые закрытия сервиса: все schedule c ScheduleState.active
    переводятся в ScheduleState.disabled
    :param obj_to_event
        {service.id: [close_service_events.id, ]}
    """

    services = query_objects_by_ids(
        db=session, object_ids=obj_to_event.keys(),
        model=Service,
    ).filter(Service.state == enums.ServiceState.closed)
    service_ids = set(service.id for service in services)
    if not service_ids:
        return

    session.query(Schedule).filter(
        Schedule.service_id.in_(service_ids),
        Schedule.state == enums.ScheduleState.active
    ).update(
        {Schedule.state: enums.ScheduleState.disabled},
        synchronize_session=False,
    )

    ids = []
    for event_list in obj_to_event.values():
        ids.extend(event_list)

    update_events_statuses(db=session, ids=ids, state=enums.EventState.processed)


@app.task
@dbconnect
def process_delete_services(session: Session, obj_to_event: dict[int: Iterable]):
    """
    Обрабатывает событые удаления сервиса
    - завершаем активные шифты
    - удаляем связанные расписания
    - перераспределяем людей с текущего дня для групп, в которых остались расписания
    :param obj_to_event
        {service.id: [delete_service_events.id, ]}
    """
    services = query_objects_by_ids(
        db=session, object_ids=obj_to_event.keys(),
        model=Service,
    ).filter(Service.state == enums.ServiceState.deleted).all()
    service_ids = set(service.id for service in services)
    if not service_ids:
        return

    # убираем роль у текущих дежурных (finish_shift)
    shifts = (
        session.query(Shift)
        .join(Schedule, Shift.schedule_id == Schedule.id)
        .filter(
            Schedule.service_id.in_(service_ids),
            Shift.status == enums.ShiftStatus.active,
        )
        .all()
    )
    for shift in shifts:
        finish_shift(shift_id=shift.id)

    schedules_groups = (
        session.query(SchedulesGroup)
        .join(Schedule, Schedule.schedules_group_id == SchedulesGroup.id)
        .filter(Schedule.service_id.in_(service_ids))
        .all()
    )
    schedules_group_ids = set(group.id for group in schedules_groups)

    # удаление расписаний
    session.query(Schedule).filter(Schedule.service_id.in_(service_ids)).delete(synchronize_session=False)

    # перераспределение людей с текущего дня для групп, в которых останутся расписания
    remain_schedules_groups = (
        session.query(SchedulesGroup)
        .join(Schedule, Schedule.schedules_group_id == SchedulesGroup.id)
        .filter(SchedulesGroup.id.in_(schedules_group_ids))
        .all()
    )
    for group in remain_schedules_groups:
        start_people_allocation.delay(schedules_group_id=group.id, start_date=today())

    ids = []
    for event_list in obj_to_event.values():
        ids.extend(event_list)

    update_events_statuses(db=session, ids=ids, state=enums.EventState.processed)


@app.task
@dbconnect
def process_new_gaps(session: Session, obj_to_event: dict[int: Iterable]):
    query_shifts = (
        session.query(Gap, Shift, Schedule)
        .filter(Gap.id.in_(obj_to_event.keys()))
        .join(Shift, Shift.staff_id == Gap.staff_id)
        .join(Schedule, Shift.schedule_id == Schedule.id)
        .filter(
            Gap.start < Shift.end,
            Gap.end > Shift.start,
            Gap.status == enums.GapStatus.active,
            Shift.end > now(),
            Gap.end - Gap.start >= Schedule.length_of_absences,
        )
    )
    problems = []
    groups_to_relocate = defaultdict(list)

    for gap, shift, schedule in query_shifts.all():
        if not schedule.recalculate or (shift.approved and shift.staff_id is not None):
            logger.info(f'Creating problem staff_has_gap for shift: {shift.id} from process_new_gaps')
            problems.append({
                'shift_id': shift.id,
                'staff_id': shift.staff_id,
                'reason': enums.ProblemReason.staff_has_gap,
                'duty_gap_id': gap.id
            })
        else:
            shift.staff_id = None
            groups_to_relocate[schedule.schedules_group_id].append(localize(shift.start))

    if problems:
        session.bulk_insert_mappings(Problem, problems)

    for group_id, dates in groups_to_relocate.items():
        logger.info(f'Starting people allocation for {group_id} from process_new_gaps')
        start_people_allocation.delay(
            schedules_group_id=group_id,
            start_date=min(dates),
        )

    ids = []
    for event_list in obj_to_event.values():
        ids.extend(event_list)

    update_events_statuses(db=session, ids=ids, state=enums.EventState.processed)


@app.task
@dbconnect
def process_update_gaps(session: Session, obj_to_event: dict[int: Iterable]):
    """
    Тут может изменится длина/даты или статус гэпа стать неактивным.
    Если гэп стал неактивный, нужно найти проблемы, которые к нему относились и пометить решеными.
    Если гэп изменились длина/даты:
        - проверяем разрешились ли какие-то проблемы
        - запускаем проверку для стаффов из шифтов create_problems_for_staff_has_gap_shifts,
            тк не знаем какие конкретно были даты/длина раньше
        - занулить данные
    """
    # попытка снять все проблемы с шифтов
    query_gaps_for_resolve = (
        session.query(Shift)
        .join(Problem, Problem.shift_id == Shift.id)
        .join(Gap, Gap.id == Problem.duty_gap_id)
        .filter(
            Gap.id.in_(obj_to_event.keys()),
            Shift.status != enums.ShiftStatus.completed,    # незавершённые
            Problem.status != enums.ProblemStatus.resolved,   # с активными проблемами
        )
    )

    schedule_ids = set(shift.schedule_id for shift in query_gaps_for_resolve.all())
    for schedule_id in schedule_ids:
        logger.info(f'Sending resolve_shifts_problems for {schedule_id}')
        resolve_shifts_problems.delay(schedule_id=schedule_id)

    # проверка nobody_on_duty проблем
    # возьмем все смены, где нет дежурного и удаленный геп с ними пересекается
    query_nobody_on_duty_for_resolve = (
        session.query(Schedule, Shift)
        .join(Interval, Interval.schedule_id == Schedule.id)
        .join(Slot, Slot.interval_id == Interval.id)
        .join(CompositionParticipants, CompositionParticipants.composition_id == Slot.composition_id)
        .join(Gap, Gap.staff_id == CompositionParticipants.staff_id)
        .join(Shift, Shift.schedule_id == Schedule.id)
        .filter(
            Gap.id.in_(obj_to_event.keys()),
            Gap.status == enums.GapStatus.deleted,
            Gap.start < Shift.end,
            Gap.end > Shift.start,
            Shift.staff_id.is_(None),
            Shift.empty.is_(False),
            Shift.status != enums.ShiftStatus.completed,
            ~Shift.sub_shifts.any(),
        )
    )
    groups_to_relocate = defaultdict(list)
    for schedule, shift in query_nobody_on_duty_for_resolve.all():
        groups_to_relocate[schedule.schedules_group_id].append(shift.start)

    for group_id, dates in groups_to_relocate.items():
        logger.info(f'Starting people allocation for {group_id} from process_update_gaps')
        start_people_allocation.delay(
            schedules_group_id=group_id,
            start_date=min(dates),
        )

    # если у стаффа не будет ни одного шифта, то нам ничего проверять на проблемы не нужно
    query_shift_active_gap = (
        session.query(Shift)
        .join(Gap, Shift.staff_id == Gap.staff_id)
        .filter(
            Gap.id.in_(obj_to_event.keys()),
            Gap.status == enums.GapStatus.active,
            Shift.status != enums.ShiftStatus.completed,
            Gap.start < Shift.end,
            Gap.end > Shift.start,
            ~session.query(Problem.id)
            .filter(
                Problem.shift_id == Shift.id,
                Problem.reason == enums.ProblemReason.nobody_on_duty,
            )
            .exists()
        )
    )

    # создание проблем с пересчётом
    possible_shifts = query_shift_active_gap.all()
    if len(possible_shifts) > 0:
        staff_ids = set([shift.staff_id for shift in possible_shifts])
        logger.info(f'Sending create_problems_for_staff_has_gap_shifts for {staff_ids}')
        create_problems_for_staff_has_gap_shifts.delay(
            staff_ids=list(staff_ids),
        )

    ids = []
    for event_list in obj_to_event.values():
        ids.extend(event_list)

    update_events_statuses(db=session, ids=ids, state=enums.EventState.processed)


@app.task
@dbconnect
def process_new_and_deleted_holidays(session: Session, obj_to_event: dict[int: Iterable]):
    now_date = now()
    schedule_holiday = defaultdict(list)
    for date_string in obj_to_event:
        logger.info(f'Processing holiday {date_string}')
        date = parse(date_string)
        holiday_start = make_localized_datetime(date)
        holiday_end = make_localized_datetime(date + datetime.timedelta(days=1))
        # прошедшие праздники не имеет смысла пересчитывать
        if holiday_end < now_date:
            continue

        shifts = (
            session.query(Shift)
            .join(Slot, Slot.id == Shift.slot_id)
            .join(Interval, Interval.id == Slot.interval_id)
            .filter(
                Shift.status != enums.ShiftStatus.completed,
                or_(
                    Interval.unexpected_holidays == enums.IntervalUnexpectedHolidays.remove,
                    Interval.weekend_behaviour == enums.IntervalWeekendsBehaviour.extend,
                ),
                holiday_start < Shift.end,
                holiday_end > Shift.start,
                Shift.empty.is_(False),
                Shift.approved.is_(False),
            )
        ).all()

        schedule_ids = set(shift.schedule_id for shift in shifts)
        for schedule_id in schedule_ids:
            schedule_holiday[schedule_id].append(date_string)

    for schedule_id, holidays in schedule_holiday.items():
        # найдем самый ранний из дней и запустим пересчет с него

        date_from = sorted(holidays)[0]
        logger.info(f'Sending revision_shift_boundaries for schedule: {schedule_id}, from: {date_from}, from_holiday=True')

        revision_shift_boundaries.delay(
            schedule_id=schedule_id,
            date_from=sorted(holidays)[0],
            from_holiday=True,
        )

    event_ids = []
    for events in obj_to_event.values():
        event_ids.extend(events)
    update_events_statuses(db=session, ids=event_ids, state=enums.EventState.processed)
