import datetime
from typing import List, Dict, Union, Optional, Iterable
from dataclasses import dataclass

from fastapi import BackgroundTasks
from fastapi_utils.cbv import cbv
from fastapi_utils.inferring_router import InferringRouter

from watcher import enums
from watcher.api.schemas.shift import (
    ShiftPutSchema,
    ShiftListSchema,
    SubShiftsPutSchema,
)
from watcher.config import settings
from watcher.crud.schedule import get_schedule_or_404
from watcher.crud.shift import (
    patch_shift,
    query_shifts_by_ids,
    get_shifts_from_parallel_slots,
)
from watcher.crud.staff import query_staff_by_ids
from watcher.crud.slot import get_slots_by_ids
from watcher.crud.base import list_objects_by_model
from watcher.db import Shift, Slot, Staff
from watcher.logic.member import check_and_deprive_member
from watcher.logic.boundaries_revision import recalculated_shift_fields_changes
from watcher.logic.exceptions import (
    BadRequest,
    PermissionDenied,
)
from watcher.logic.permissions import (
    is_superuser,
    is_user_responsible_for_service_or_schedule,
    is_user_in_service,
)
from watcher.logic.service import check_staff_service
from watcher.logic.shift import (
    bind_sequence_shifts,
    get_shift_rating_differences,
    get_prev_and_next_shifts,
    is_full_shift,
    shift_total_points,
    update_ratings_for_subshifts,
    set_shift_approved,
)
from watcher.logic.timezone import localize, now
from watcher.logic.rating import (
    update_ratings_dict,
)
from watcher.tasks.people_allocation import start_people_allocation
from watcher.tasks.shift import finish_shift

from .base import BaseRoute

import logging

logger = logging.getLogger(__name__)
router = InferringRouter()


@dataclass
class ShiftData:
    start: datetime.datetime
    end: datetime.datetime
    slot: Slot


@cbv(router)
class SubShiftRoute(BaseRoute):
    model = Shift
    default_ordering = '-start',
    joined_load = ('sub_shifts', 'next', 'prev', 'sub_shifts.staff', 'sub_shifts.slot')
    joined_load_list = ('staff', 'schedule', 'replacement_for', 'slot', )

    @router.put('/{shift_id}')
    def put(self, shift_id: int, sub_shifts: SubShiftsPutSchema, background_tasks: BackgroundTasks) -> List[ShiftListSchema]:
        main_shift = self.get_object(object_id=shift_id)
        sub_shifts = sorted(sub_shifts.sub_shifts, key=lambda x: x.end)
        self._validate_schema(main_shift, sub_shifts)
        self._check_put_permissions(main_shift, sub_shifts)

        # текущий предыдущий и следующий шифты
        parallel_slot_shifts = get_shifts_from_parallel_slots(db=self.session, main_shift=main_shift)
        prev_shift, next_shift = get_prev_and_next_shifts(main_shift, parallel_slot_shifts)
        db_subshifts = {shift.id: shift for shift in main_shift.sub_shifts}

        if not sub_shifts:
            # удаляем все подсмены
            # так же правим ссылки на next/prev
            if db_subshifts:
                self._remove_subshifts(
                    main_shift=main_shift,
                    prev_shift=prev_shift,
                    next_shift=next_shift,
                    db_subshifts=db_subshifts,
                    parallel_slot_shifts=parallel_slot_shifts,
                )
            return [main_shift]

        # если полностью перекрывает - обновляем весь шифт
        # так же правим ссылки на next/prev и удаляем подсмены
        # если были
        if is_full_shift(main_shift, sub_shifts):
            full_shift = sub_shifts[0]
            full_shift.remove_fields('replacement_for_id', 'next_id')
            patch_shift(
                db=self.session, db_obj=main_shift,
                shift=full_shift, author_id=self.current_user.id,
                commit=False,
            )
            self._remove_subshifts(
                main_shift=main_shift,
                prev_shift=prev_shift,
                next_shift=next_shift,
                db_subshifts=db_subshifts,
                parallel_slot_shifts=parallel_slot_shifts,
            )
            return [main_shift]

        # заполняем новыми подшифтами пустые ячейки
        self._set_sub_shift_parameters(main_shift, sub_shifts)
        self._fill_empty_spaces(main_shift, sub_shifts)
        exist_sub_shifts = {shift.id: shift for shift in sub_shifts if shift.id}
        new_sub_shifts = [shift for shift in sub_shifts if not shift.id]
        # берем время с которого нужно начать реаллокацию людей, если нужно
        people_allocation_start_date = self._get_people_allocation_start_time(db_shifts=db_subshifts, sub_shifts=sub_shifts)

        remain_shifts = self._put_exist_shifts(db_subshifts, exist_sub_shifts)
        new_shifts = self._create_new_shifts(main_shift, new_sub_shifts)
        all_subshifts = remain_shifts + new_shifts
        all_shifts = all_subshifts + parallel_slot_shifts
        all_shifts.sort(key=lambda x: (x.start, x.slot_id))

        bind_sequence_shifts(all_shifts)
        all_shifts[0].prev = prev_shift
        all_shifts[-1].next = next_shift

        set_shift_approved(main_shift, approved=True, author_id=self.current_user.id)
        staff_map = {
            staff.id: staff
            for staff in self.session.query(Staff).filter(
                Staff.id.in_(shift.staff_id for shift in all_shifts)
            )
        }
        update_ratings_for_subshifts(
            prev_shift=prev_shift, sub_shifts=all_shifts,
            staff_map=staff_map
        )
        self.session.commit()

        # если основной шифт был активным, то его нужно перевести в планируемые
        # и отозвать роль если нужно
        if main_shift.status == enums.ShiftStatus.active:
            main_shift.status = enums.ShiftStatus.scheduled
            check_and_deprive_member(session=self.session, shift=main_shift)

        if people_allocation_start_date:
            main_shift.schedule.recalculation_in_process = True
            self.session.add(main_shift)
            kwargs = {
                settings.FORCE_TASK_DELAY: True,
                'schedules_group_id': main_shift.schedule.schedules_group_id,
                'start_date': people_allocation_start_date,
                'push_staff': True,
            }
            background_tasks.add_task(
                start_people_allocation.delay,
                **kwargs
            )
        return list(
            list_objects_by_model(
                db=self.session, model=self.model,
                joined_load=self.joined_load_list,
            ).filter(
                self.model.id.in_(obj.id for obj in all_subshifts)
            ).order_by(self.model.start)
        )

    def _remove_subshifts(
        self, main_shift: Shift, next_shift: Optional[Shift],
        prev_shift: Optional[Shift], db_subshifts: Iterable[int],
        parallel_slot_shifts: Iterable[Shift],
    ):
        all_shifts = [main_shift] + parallel_slot_shifts
        all_shifts.sort(key=lambda x: (x.start, x.slot_id))

        bind_sequence_shifts(all_shifts)
        all_shifts[0].prev = prev_shift
        all_shifts[-1].next = next_shift
        self.session.commit()

        shifts_to_remove = [
            (shift.id, shift.status)
            for shift in query_shifts_by_ids(
                db=self.session,
                shift_ids=db_subshifts,
            )
        ]

        for shift_id, status in shifts_to_remove:
            if status == enums.ShiftStatus.active:
                finish_shift(session=self.session, shift_id=shift_id, _lock=False)

        ids_to_remove = [shift[0] for shift in shifts_to_remove]
        self._remove_shifts(ids=ids_to_remove)

    def _create_new_shifts(
        self,
        main_shift: Shift,
        sub_shifts: List[ShiftPutSchema]
    ) -> List[Shift]:
        new_shifts = []

        for shift in sub_shifts:
            db_shift = main_shift.copy(is_primary=main_shift.is_primary)
            shift.status = enums.ShiftStatus.scheduled
            db_shift = patch_shift(
                db=self.session,
                db_obj=db_shift,
                shift=shift,
                author_id=self.current_user.id,
                recalculate_rating=False,
                commit=False,
            )
            new_shifts.append(db_shift)

        self.session.add_all(new_shifts)
        return new_shifts

    def _put_exist_shifts(
        self,
        unchecked_db_subshifts: Dict[int, Shift],
        sub_shifts: Dict[int, ShiftPutSchema],
    ) -> List[Shift]:
        remain_shifts = []

        for shift_id, sub_shift in sub_shifts.items():
            patched_shift = patch_shift(
                db=self.session,
                db_obj=unchecked_db_subshifts.pop(shift_id),
                shift=sub_shift,
                author_id=self.current_user.id,
                recalculate_rating=False,
                commit=False,
            )
            self.session.add(patched_shift)
            remain_shifts.append(patched_shift)

        # в словаре остались только те шифты, которые нужно удалить
        ids_to_remove = [shift.id for shift in unchecked_db_subshifts.values()]
        self._remove_shifts(ids=ids_to_remove)
        return remain_shifts

    def _remove_shifts(self, ids: list[int]) -> None:
        self.session.query(Shift).filter(
            Shift.next_id.in_(ids)
        ).update(
            {Shift.next_id: None},
            synchronize_session=False,
        )
        self.session.query(Shift).filter(
            Shift.id.in_(ids)
        ).delete(synchronize_session=False)

    def _calculate_ratings_difference(
        self,
        db_shifts: Dict[int, Shift],
        sub_shifts: List[ShiftPutSchema]
    ) -> Dict[str, float]:
        rating_differences = {}
        query_watchers = query_staff_by_ids(
            db=self.session,
            staff_ids=[schema.staff_id for schema in sub_shifts if schema.staff_id]
        )
        schema_watchers = {watcher.id: watcher for watcher in query_watchers}

        slots = get_slots_by_ids(
            db=self.session,
            slot_ids=[schema.slot_id for schema in sub_shifts]
        )
        schema_slots = {slot.id: slot for slot in slots}

        for shift in sub_shifts:
            # патчим существующую подсмену
            if shift.id and shift.id in db_shifts:
                update_ratings_dict(
                    rating_differences,
                    get_shift_rating_differences(self.session, db_shifts[shift.id], shift.staff_id)
                )

            # добавляем новую подсмену
            if not shift.id and shift.staff_id:
                orm_shift = ShiftData(
                    start=shift.start,
                    end=shift.end,
                    slot=schema_slots[shift.slot_id]
                )

                update_ratings_dict(
                    rating_differences,
                    {schema_watchers[shift.staff_id].login: shift_total_points(orm_shift)}
                )

        shift_ids_to_delete = set(db_shifts.keys()) - set([schema.id for schema in sub_shifts])

        # нужно уменьшить рейтинг у людей, чьи подсмены удаляют
        update_ratings_dict(rating_differences, {
            db_shifts[shift_id].staff.login: -shift_total_points(db_shifts[shift_id])
            for shift_id in shift_ids_to_delete
            if db_shifts[shift_id].staff_id
        })

        return rating_differences

    def _get_people_allocation_start_time(
        self,
        db_shifts: Dict[int, Shift],
        sub_shifts: List[ShiftPutSchema]
    ) -> Union[datetime.datetime, None]:
        start_date = None

        for shift in sub_shifts:
            if shift.approved and (
                (not shift.id and shift.staff_id) or
                (shift.id and recalculated_shift_fields_changes(db_shifts[shift.id], shift))
            ):
                start_date = shift.start
                break

        if not start_date:
            return

        return localize(start_date)

    def _set_sub_shift_parameters(self, main_shift: Shift, sub_shifts: List[ShiftPutSchema]) -> None:
        """
        Редактирует схемам replacement_for_id
        Если у подшифтов не передан slot_id - он будет равен родительскому slot_id
        """
        for shift in sub_shifts:
            if not shift.slot_id:
                shift.slot_id = main_shift.slot_id

            shift.replacement_for_id = main_shift.id
            shift.approved = True

    def _fill_empty_spaces(self, main_shift: Shift, sub_shifts: List[ShiftPutSchema]) -> None:
        """

        :param main_shift: основной шифт от которого будут создаваться новые подшифты для пробелов
        :param sub_shifts: отсортированные по времени окончания подшифты
        """
        # для новых подшифтов схемы создаются по подобию основного шифта
        template = self._get_schema_from_object(main_shift)

        # добавим фиктивные подшифты в самом начале и конце,
        # для того чтобы в цикле обработать случаи, когда не передали подсмены для начала и конца смены
        start_shift, end_shift = template.copy(), template.copy()
        start_shift.end = main_shift.start
        end_shift.start = main_shift.end

        sub_shifts_copy = [start_shift, ]
        sub_shifts_copy.extend(sub_shifts[:])
        sub_shifts_copy.append(end_shift)

        new_sub_shifts = []
        for i, shift in enumerate(sub_shifts_copy):
            if i + 1 < len(sub_shifts_copy) and shift.end < sub_shifts_copy[i + 1].start:
                schema = template.copy()
                schema.start = shift.end
                schema.end = sub_shifts_copy[i + 1].start
                schema.approved = True
                new_sub_shifts.append(schema)

        # если были пробелы в промежутках, то нужно их добавить
        if new_sub_shifts:
            sub_shifts.extend(new_sub_shifts)
            sub_shifts.sort(key=lambda x: x.start)

    def _get_schema_from_object(self, main_shift: Shift) -> ShiftPutSchema:
        schema = ShiftPutSchema.from_orm(main_shift)
        schema.id = None
        schema.replacement_for_id = main_shift.id

        return schema

    def _validate_schema(self, main_shift: Shift, schemas: List[ShiftPutSchema]) -> None:
        """
        :param main_shift: основная смена
        :param schemas: отсортированные по времени подсмены
        """
        if main_shift.replacement_for_id is not None:
            raise BadRequest(message={
                'ru': 'Нельзя создать подсмены у подсмены',
                'en': 'You can\'t create sub-shifts for a sub-shift',
            })
        if main_shift.end < now():
            raise BadRequest(message={
                'ru': 'Можно редактировать только смены в будущем',
                'en': 'Only future shifts can be edited'
            })
        for i, shift in enumerate(schemas):
            if shift.start < main_shift.start or shift.end > main_shift.end:
                raise BadRequest(message={
                    'ru': 'Подсмены должны лежать в промежутке основной смены',
                    'en': 'Substitutions should lie in the interval of the main shift'
                })

            if i + 1 < len(schemas) and shift.end > schemas[i+1].start:
                raise BadRequest(message={
                    'ru': 'Подсмены не должны пересекаться',
                    'en': 'Sub shifts should not overlap'
                })

        schedule = get_schedule_or_404(self.session, main_shift.schedule_id)
        staff_ids = [shift.staff_id for shift in schemas if shift.staff_id]
        check_staff_service(db=self.session, staff_ids=staff_ids, expected=schedule.service_id)

        self._check_schema_ids_existent(schemas)

    def _check_put_permissions(self, main_shift: Shift, subshifts: List[ShiftPutSchema]) -> None:
        if not is_superuser(staff=self.current_user):
            if not self._taking_subshifts_to_yourself(main_shift, subshifts):
                if not is_user_responsible_for_service_or_schedule(
                    db=self.session,
                    schedule=main_shift.schedule,
                    staff=self.current_user,
                ):
                    raise PermissionDenied(message={
                        'ru': 'Нет разрешения на изменение подсмены',
                        'en': 'No permission to change subshifts',
                    })

    def _taking_subshifts_to_yourself(self, main_shift: Shift, subshifts: List[ShiftPutSchema]):
        db_subshifts = {shift.id: shift for shift in main_shift.sub_shifts}
        user_in_service = is_user_in_service(
            db=self.session,
            staff_id=self.current_user.id,
            service_id=main_shift.schedule.service_id,
        )
        for subshift in subshifts:
            if not subshift.id and subshift.id not in db_subshifts:
                return False
            changed_fields = self.get_changed_fields(db_subshifts[subshift.id], subshift)
            if changed_fields != ['staff_id'] or subshift.staff_id != self.current_user.id or not user_in_service:
                return False
        return True

    def _check_schema_ids_existent(self, schemas: List[ShiftPutSchema]) -> None:
        """ Проверяем все ли переданные id существуют в базе """
        schema_ids = [schema.id for schema in schemas if schema.id]
        shifts = query_shifts_by_ids(self.session, schema_ids).all()
        schema_ids = set(schema_ids)
        shift_ids = set([shift.id for shift in shifts])

        difference_ids = schema_ids.difference(shift_ids)

        if difference_ids:
            raise BadRequest(
                message={
                    'ru': f'Подшифтов с id={difference_ids} не существует',
                    'en': f'Subshifts with id={difference_ids} not exists',
                }
            )
