from typing import Callable, Dict, Optional, Tuple, List, Iterable
from sqlalchemy.orm import Query, joinedload, load_only
from fastapi import status
from fastapi_utils.cbv import cbv
from fastapi_utils.inferring_router import InferringRouter

from watcher import enums
from watcher.api.routes.base import BaseRoute
from watcher.api.schemas.base import CursorPaginationResponse
from watcher.api.schemas.shift import (
    ShiftListSchema,
    ShiftPatchSchema,
    ShiftRichListSchema,
    ShiftsUploadSchema,
    ShiftUploadSchema,
    ShiftRatingSchema,
    ShiftABCListSchema,
)
from watcher.crud.base import get_object_by_model_or_404, _add_joined_load
from watcher.crud.shift import patch_shift, query_shifts_by_schedule_in_interval
from watcher.db import Shift, Schedule, Slot, Staff
from watcher.logic.exceptions import (
    BadRequest,
    PermissionDenied,
    ShiftRatingDisabled,
)
from watcher.logic.filter import (
    OPERATOR_MAP,
)
from watcher.logic.permissions import (
    is_superuser,
    is_user_responsible_for_service_or_schedule,
    is_user_in_service,
)
from watcher.logic.timezone import now, localize
from watcher.tasks import (
    finish_shift,
    start_people_allocation,
)

from watcher.logic.service import check_staff_service

# Router that collects the ShiftRoute endpoints registered through @cbv(router).
router: InferringRouter = InferringRouter()


@cbv(router)
class ShiftRoute(BaseRoute):
    """Class-based FastAPI view serving duty shifts (``Shift``).

    ``model``, ``joined_load`` and ``joined_load_list`` are presumably read by
    ``BaseRoute`` helpers (``get_object`` / ``list_objects``) to pick the model
    and the relationships eager-loaded for single-object and list responses.
    ``current`` is an annotated class attribute, which ``cbv`` exposes as an
    optional request parameter; ``filter_objects`` uses it to keep only the
    shifts running right now.

    NOTE: keep the method order. Literal paths (``/all_fields``, ``/for_abc``)
    must be registered before ``/{shift_id}``, and the ``list`` endpoint must
    stay below every class-body use of ``list[...]``: once it is defined, the
    name ``list`` in the class namespace refers to that method, not the builtin.
    """

    model = Shift
    current: Optional[bool] = None
    joined_load = ('staff', 'schedule', 'replacement_for', 'slot', )
    joined_load_list = ('staff', 'schedule', 'replacement_for', 'slot', )

    @router.get('/all_fields')
    def all_fields(self) -> CursorPaginationResponse[ShiftRichListSchema]:
        """List shifts, additionally eager-loading schedule service and schedules group."""
        # Assigning on the instance (not the class) leaves the class-level
        # default untouched for other requests.
        self.joined_load_list = self.joined_load_list + (
            'schedule.service',
            'schedule.schedules_group',
        )
        return self.list_objects()

    @router.get('/for_abc')
    def get_shifts_for_abc(self) -> CursorPaginationResponse[ShiftABCListSchema]:
        """List shifts with a reduced column set.

        Only the shift columns in ``load_params`` are loaded, plus
        ``staff_id``/``login``/``uid`` of the staff and ``name``/``slug``/
        ``service_id`` of the schedule. The usual shift filters and
        pagination apply.
        """
        load_params = ('schedule_id', 'start', 'end', 'is_primary', 'approved', 'staff_id', )
        query = self.session.query(Shift).options(
            joinedload('staff').load_only('staff_id', 'login', 'uid', ),
            joinedload('schedule').load_only('name', 'slug', 'service_id', ),
            load_only(*load_params),
        )
        filtered_query = self.filter_objects(query=query)
        return self.paginate_objects(query=filtered_query)

    @router.get('/{shift_id}')
    def retrieve(self, shift_id: int) -> ShiftListSchema:
        """Return a single shift by id."""
        return self.get_object(object_id=shift_id)

    @router.get('/{shift_id}/subshifts')
    def retrieve_subshifts(self, shift_id: int) -> list[ShiftListSchema]:
        """Return all shifts whose ``replacement_for_id`` is ``shift_id`` (not paginated)."""
        return _add_joined_load(
            query=self.session.query(Shift).filter(Shift.replacement_for_id == shift_id),
            joined_load=self.joined_load_list,
        ).all()

    @router.get('/{shift_id}/ratings', response_model=list[ShiftRatingSchema])
    def staff_ratings(self, shift_id: int):
        """Return the shift's predicted staff ratings, sorted by rating, then login.

        :raises ShiftRatingDisabled: if the shift has no slot or its slot has
            ``points_per_hour == 0``.

        Logins in ``predicted_ratings`` without a matching ``Staff`` row are
        skipped instead of failing the whole response.
        """
        shift = self.get_object(object_id=shift_id)
        slot = shift.slot
        # slot_id is optional (see upload), so a shift may have no slot; treat
        # that like disabled ratings instead of failing with AttributeError.
        if slot is None or slot.points_per_hour == 0:
            raise ShiftRatingDisabled()

        predicted_ratings = shift.predicted_ratings or {}
        staff_by_login = {}
        if predicted_ratings:
            staff_by_login = {
                staff.login: staff
                for staff in self.session.query(Staff).filter(
                    Staff.login.in_(list(predicted_ratings))
                )
            }

        ratings_sorted = sorted(
            (rating, login) for login, rating in predicted_ratings.items()
        )
        return [
            {'staff': staff_by_login[login], 'rating': rating}
            for rating, login in ratings_sorted
            if login in staff_by_login
        ]

    @staticmethod
    def _get_operator(field_map: Dict[str, str]) -> Tuple[Optional[Callable], Optional[str]]:
        """Return ``(operator, value)`` for the first entry of ``field_map``.

        The entry's key is looked up in ``OPERATOR_MAP`` (``KeyError`` for an
        unknown filter name). An empty mapping yields ``(None, None)``.
        """
        if not field_map:
            return None, None

        filter_name, date = next(iter(field_map.items()))
        return OPERATOR_MAP[filter_name], date

    def filter_objects(self, query: Query) -> Query:
        """Filter shifts by the request's filters.

        With ``current`` set, keep only staffed shifts whose ``[start, end)``
        contains the current time. Then apply ``BaseRoute``'s generic filters
        and drop shifts that have sub-shifts (presumably those are represented
        by their sub-shifts instead).
        """
        if self.current:
            now_datetime = now()
            query = query.filter(
                now_datetime >= Shift.start,
                now_datetime < Shift.end,
                Shift.staff_id.isnot(None)
            )

        query = super().filter_objects(query=query)
        return query.filter(~Shift.sub_shifts.any())

    @router.get('/')
    def list(self) -> CursorPaginationResponse[ShiftListSchema]:
        """List shifts (paginated, filtered by ``filter_objects``)."""
        return self.list_objects()

    @router.patch('/{shift_id}')
    def patch(self, shift_id: int, shift: ShiftPatchSchema) -> ShiftListSchema:
        """Partially update a shift after validation and permission checks.

        Ratings are recalculated by ``patch_shift``. If a different staff was
        assigned, people allocation is re-run for the schedule's group starting
        at the shift's end.
        """
        db_obj = self.get_object(object_id=shift_id)
        initial_staff_id = db_obj.staff_id
        self._validate_transition_data(db_obj, shift)
        self._check_patch_permissions(db_obj, shift)
        db_obj = patch_shift(
            db=self.session, db_obj=db_obj, shift=shift,
            author_id=self.current_user.id, recalculate_rating=True,
        )
        self.session.commit()
        if shift.staff_id and shift.staff_id != initial_staff_id:
            # Reassigning the on-duty person changes ratings, so people have
            # to be re-allocated for subsequent shifts.
            start_people_allocation.delay(
                schedules_group_id=db_obj.schedule.schedules_group_id,
                start_date=db_obj.end,
                push_staff=True,
            )
        return db_obj

    @router.post('/upload', status_code=status.HTTP_204_NO_CONTENT)
    def upload(self, shifts: ShiftsUploadSchema):
        """Bulk-upload shifts into a schedule.

        Uploaded shifts are sorted by start and chained via ``next``. Shifts
        already returned by ``query_shifts_by_schedule_in_interval`` for the
        uploaded interval are deleted when ``shifts.replace`` is set; otherwise
        the request is rejected. An empty upload is a no-op.
        """
        schedule = get_object_by_model_or_404(
            db=self.session, model=Schedule,
            object_id=shifts.schedule_id,
        )
        self._check_upload_permissions(schedule=schedule)
        if not shifts.shifts:
            # Nothing to do; the interval computation below needs at least one shift.
            return

        sorted_shifts = sorted(shifts.shifts, key=lambda x: x.start)
        # Shifts may overlap, so the last one by start is not necessarily the
        # one that ends last; use the latest end to cover the whole upload.
        interval_end = max(shift.end for shift in sorted_shifts)
        existing_shifts = query_shifts_by_schedule_in_interval(
            db=self.session,
            schedule_id=schedule.id,
            start=sorted_shifts[0].start,
            end=interval_end,
        ).all()
        staff_logins_ids = self._get_staff_ids_to_upload(shifts=sorted_shifts)
        slots_by_ids = self._get_slots_by_shifts(shifts=sorted_shifts)
        self._validate_upload_data(
            shifts=sorted_shifts,
            schedule=schedule,
            existing_shifts=existing_shifts,
            replace=shifts.replace,
            staff_ids=staff_logins_ids.values(),
            slots=slots_by_ids.values(),
        )
        prev_shift = next_shift = None
        if existing_shifts:
            # NOTE(review): assumes existing_shifts are ordered by start, so the
            # first/last neighbours are the ones to re-link the new chain to.
            prev_shift = existing_shifts[0].prev
            next_shift = existing_shifts[-1].next
            self._remove_existing_shifts(shifts=existing_shifts)

        self._create_shifts_to_upload(
            shifts=sorted_shifts, schedule_id=schedule.id,
            prev_shift=prev_shift, next_shift=next_shift,
            staff_logins_ids=staff_logins_ids, slots_by_ids=slots_by_ids
        )

    def _validate_transition_data(self, db_obj: Shift, shift: ShiftPatchSchema):
        """Reject assigning staff to a shift that is (or stays) empty.

        The shift counts as empty if the patch sets ``empty`` truthy, or leaves
        ``empty`` unset (``None``) while the stored shift is empty.

        :raises BadRequest: if a staff is assigned to an empty shift.
        """
        if shift.staff_id and (shift.empty or (shift.empty is None and db_obj.empty)):
            raise BadRequest(message={
                'ru': 'В пустых сменах не должно быть дежурных',
                'en': 'There should be no attendants in the empty shifts'
            })

    def _check_patch_permissions(self, db_obj: Shift, shift: ShiftPatchSchema) -> None:
        """Ensure the current user may apply ``shift`` to ``db_obj``.

        Allowed without further checks:
        - superusers;
        - patches that change nothing;
        - the shift's own staff changing only ``approved``: approving at any
          time, un-approving only while more than ``schedule.pin_shifts``
          remains before the shift starts;
        - a member of the schedule's service assigning themselves
          (changing only ``staff_id`` to their own id).
        Anyone else must be responsible for the service or the schedule.

        :raises PermissionDenied: if none of the above applies.
        """
        if is_superuser(staff=self.current_user):
            return
        changed_fields = self.get_changed_fields(db_obj, shift)
        if not changed_fields:
            return

        if changed_fields == ['approved'] and db_obj.staff_id == self.current_user.id:
            if shift.approved or db_obj.start - now() > db_obj.schedule.pin_shifts:
                return

        # The service-membership lookup hits the DB, so only run it when the
        # patch is a self-assignment.
        if (
            changed_fields == ['staff_id']
            and shift.staff_id == self.current_user.id
            and is_user_in_service(
                db=self.session,
                staff_id=self.current_user.id,
                service_id=db_obj.schedule.service_id,
            )
        ):
            return

        if not is_user_responsible_for_service_or_schedule(
            db=self.session,
            schedule=db_obj.schedule,
            staff=self.current_user,
        ):
            raise PermissionDenied(message={
                'ru': 'Нет разрешения на изменение смены',
                'en': 'No permission to change shift',
            })

    def _check_upload_permissions(self, schedule: Schedule):
        """Allow uploads only for superusers and people responsible for the schedule/service.

        :raises PermissionDenied: otherwise.
        """
        if is_superuser(staff=self.current_user):
            return
        if not is_user_responsible_for_service_or_schedule(
            db=self.session,
            schedule=schedule,
            staff=self.current_user,
        ):
            raise PermissionDenied(message={
                'ru': 'Нет разрешения на загрузку смен',
                'en': 'No permission to upload shifts',
            })

    def _get_staff_ids_to_upload(self, shifts: List[ShiftUploadSchema]) -> Dict[str, int]:
        """Resolve the uploaded staff logins to ``{login: staff_id}``.

        :raises BadRequest: if any non-empty login has no ``Staff`` row.
        """
        input_staff_logins = {shift.staff_login for shift in shifts if shift.staff_login}
        if not input_staff_logins:
            return {}

        staffs = self.session.query(Staff).filter(Staff.login.in_(input_staff_logins)).all()
        staff_logins_ids = {staff.login: staff.id for staff in staffs}
        if len(staffs) < len(input_staff_logins):
            raise BadRequest(message={
                'ru': 'Присутствуют неизвестные логины',
                'en': 'Some logins are not found',
            })
        return staff_logins_ids

    def _get_slots_by_shifts(self, shifts: List[ShiftUploadSchema]) -> Dict[int, Slot]:
        """Load the slots referenced by the uploaded shifts as ``{slot_id: Slot}``.

        Each slot's ``interval`` is eager-loaded for the ownership check in
        ``_validate_upload_data``.

        :raises BadRequest: if any referenced slot id does not exist; otherwise
            it would surface later as a KeyError or a failed commit.
        """
        slot_ids = {shift.slot_id for shift in shifts if shift.slot_id}
        if not slot_ids:
            return {}

        slots = self.session.query(Slot).filter(Slot.id.in_(slot_ids)).options(joinedload(Slot.interval)).all()
        slots_by_ids = {slot.id: slot for slot in slots}
        if len(slots_by_ids) < len(slot_ids):
            raise BadRequest(message={
                'ru': 'Присутствуют неизвестные слоты',
                'en': 'Some slots are not found',
            })
        return slots_by_ids

    def _validate_upload_data(
        self, shifts: List[ShiftUploadSchema], schedule: Schedule,
        existing_shifts: Optional[List[Shift]], replace: bool, staff_ids: Iterable[int],
        slots: Iterable[Slot]
    ) -> None:
        """Validate uploaded shifts before anything is written.

        Checks, in order:
        - the uploaded staff belong to the schedule's service
          (delegated to ``check_staff_service``);
        - shifts ending now or later may only be uploaded into a schedule
          with ``recalculate`` disabled;
        - empty shifts have no staff, and non-empty shifts have one;
        - every referenced slot belongs to this schedule;
        - if shifts already exist in the interval, ``replace`` must be set
          (the replacement itself is done by ``_remove_existing_shifts``).

        ``shifts`` must be non-empty.

        :raises BadRequest: on the first failed local check
            (``check_staff_service`` may raise its own error).
        """
        if staff_ids:
            check_staff_service(db=self.session, staff_ids=staff_ids, expected=schedule.service_id)

        # Use the latest end, not the end of the last shift by start, since
        # shifts may overlap.
        latest_end = max(shift.end for shift in shifts)
        if localize(latest_end) >= now() and schedule.recalculate:
            raise BadRequest(message={
                'ru': 'Нельзя загрузить будущие смены, т.к. график участвует в пересчётах смен и перераспределении '
                      'людей',
                'en': 'It is impossible to load future shifts, because the schedule is involved in '
                      'the recalculation of shifts and the redistribution of people',
            })
        for shift in shifts:
            if shift.staff_login and shift.empty:
                raise BadRequest(message={
                    'ru': 'У пустой смены не может быть дежурного',
                    'en': 'An empty shift cant have a staff'
                })
            if not (shift.staff_login or shift.empty):
                raise BadRequest(message={
                    'ru': 'У непустой смены должен быть дежурный',
                    'en': 'An non-empty shift must have a staff'
                })

        for slot in slots:
            if slot.interval.schedule_id != schedule.id:
                raise BadRequest(message={
                    'ru': 'Slot не принадлежит расписанию',
                    'en': 'The slot does not belong to the schedule',
                })

        if existing_shifts and not replace:
            raise BadRequest(message={
                'ru': 'Для перезаписи существующих смен установите параметр replace=True',
                'en': 'To overwrite existing shifts, set the replace=True parameter'
            })

    def _remove_existing_shifts(
        self, shifts: List[Shift]
    ) -> None:
        """Delete ``shifts`` together with their sub-shifts; returns nothing.

        Active shifts among them are finished first via ``finish_shift``.
        ``next_id`` references to the deleted rows are then cleared, presumably
        so the delete does not violate the self-referencing link. Both bulk
        statements use ``synchronize_session=False``, so objects already
        loaded in the session are not updated.
        """
        to_delete = []
        active_shifts = []
        for shift in shifts:
            for candidate in (shift, *shift.sub_shifts):
                to_delete.append(candidate.id)
                if candidate.status == enums.ShiftStatus.active:
                    active_shifts.append(candidate.id)

        # De-duplicate (order-preserving) in case a sub-shift was also passed
        # directly, so it is not finished twice.
        to_delete = list(dict.fromkeys(to_delete))
        if not to_delete:
            return

        for shift_id in dict.fromkeys(active_shifts):
            finish_shift(session=self.session, shift_id=shift_id, _lock=False)

        self.session.query(Shift).filter(
            Shift.next_id.in_(to_delete)
        ).update(
            {Shift.next_id: None},
            synchronize_session=False,
        )

        self.session.query(Shift).filter(
            Shift.id.in_(to_delete)
        ).delete(synchronize_session=False)

    def _create_shifts_to_upload(
        self, shifts: List[ShiftUploadSchema], schedule_id: int,
        prev_shift: Optional[Shift], next_shift: Optional[Shift],
        staff_logins_ids: Dict[str, int], slots_by_ids: Dict[int, Slot]
    ):
        """Create ``Shift`` rows from the (start-sorted) upload and commit them.

        Consecutive new shifts are chained via ``next``. The chain is attached
        after ``prev_shift`` and before ``next_shift`` when those are given.
        ``is_primary`` defaults to the slot's value when the upload omits it.
        Returns the created shifts.
        """
        new_shifts = []
        for shift in shifts:
            data = shift.dict()
            data['staff_id'] = staff_logins_ids.get(data.pop('staff_login'))
            new_shift = Shift(**data, schedule_id=schedule_id)
            if new_shift.slot_id and new_shift.is_primary is None:
                new_shift.is_primary = slots_by_ids[new_shift.slot_id].is_primary
            new_shifts.append(new_shift)

        for current_shift, following_shift in zip(new_shifts, new_shifts[1:]):
            current_shift.next = following_shift
        if prev_shift:
            new_shifts[0].prev = prev_shift
        if next_shift:
            new_shifts[-1].next = next_shift
        self.session.add_all(new_shifts)
        self.session.commit()
        return new_shifts
