import logging
from datetime import datetime
from typing import Any, Dict, Iterable, Optional, Tuple

from django.conf import settings
from django.db import models

from staff.budget_position.models import (
    BudgetPosition,
    BudgetPositionAssignment,
    BudgetPositionAssignmentStatus,
    ReplacementType,
    Reward,
)
from staff.departments.models import HRProduct, Geography, Department, RelevanceDate
from staff.lib.db import atomic
from staff.lib.sync_tools.rollupper import Rollupper
from staff.person.models import Staff

from staff.oebs.controllers.mappers import DataMapper
from staff.oebs.models import OebsHeadcountPosition


logger = logging.getLogger(__name__)


class HeadcountPositionsMapper(DataMapper):
    """Map one OEBS headcount-position record onto ``BudgetPositionAssignment`` field values.

    ``object_dict`` is read directly for ``relevance_date``, ``modified_at``,
    ``prev_assignment_index`` and ``next_assignment_index``; presumably it is the field dict of
    an ``OebsHeadcountPosition`` row — TODO confirm against the ``DataMapper`` base.
    """

    # (has previous assignment, has next assignment) -> replacement type.
    # Any combination not listed, i.e. (False, False), falls back to
    # ReplacementType.WITHOUT_REPLACEMENT in _get_replacement_type.
    _replacement_type_mapping: Dict[Tuple[bool, bool], ReplacementType] = {
        (False, True): ReplacementType.HAS_REPLACEMENT,
        (True, False): ReplacementType.BUSY,
        (True, True): ReplacementType.HAS_REPLACEMENT_AND_BUSY,
    }

    def __init__(self, object_dict: Dict):
        super().__init__(object_dict)
        # Entries are (source key, destination field[, converter]). Presumably the DataMapper
        # base applies the converter to object_dict[source key] — verify against DataMapper.
        # The 'code' entries with lambdas that ignore their argument produce constants:
        # change_registry is always None and intranet_status is always 1.
        self.mapping = (
            ('code', 'change_registry', lambda _: None),
            ('status', 'status', self._get_status),
            ('state2', 'creates_new_position', self._get_creates_new_position),

            ('main_assignment', 'main_assignment'),
            ('assignment_id', 'oebs_assignment_id'),

            ('bonus_id', 'bonus_id'),
            ('review_id', 'review_id'),

            ('code', 'intranet_status', lambda _: 1),
            ('name', 'name'),
        )

    def __iter__(self):
        """Yield ``(field, value)`` pairs from the base mapping, then the two derived fields."""
        yield from super().__iter__()

        yield 'change_reason', self._get_change_reason()
        yield 'replacement_type', self._get_replacement_type()

    def _get_change_reason(self) -> str:
        """Return a change-reason tag of the form ``OEBS-<relevance_date>@<modified_at>``."""
        relevance_date = self.object_dict['relevance_date']
        modified_at = self.object_dict['modified_at']

        return f'OEBS-{relevance_date}@{modified_at}'

    @staticmethod
    def _get_status(status: str) -> str:
        """Convert an OEBS status member name to its ``BudgetPositionAssignmentStatus`` value.

        Lookup is by member name; an unknown name raises (KeyError for a standard Enum).
        """
        return BudgetPositionAssignmentStatus[status].value

    @staticmethod
    def _get_creates_new_position(state2: str) -> bool:
        """Return True when the OEBS ``state2`` flag is ``'NEW'``."""
        return state2 == 'NEW'

    def _get_replacement_type(self) -> str:
        """Derive the replacement-type value from the presence of previous/next assignments."""
        # None is the "no such assignment" marker used throughout this module (filters on
        # next_assignment_index=None, checks of prev_assignment_index is None). Compare with it
        # explicitly rather than by truthiness so that an index of 0 still counts as present.
        has_previous_assignment = self.object_dict['prev_assignment_index'] is not None
        has_next_assignment = self.object_dict['next_assignment_index'] is not None

        return self._replacement_type_mapping.get(
            (has_previous_assignment, has_next_assignment),
            ReplacementType.WITHOUT_REPLACEMENT,
        ).value


class OebsHeadcountPositionException(Exception):
    """Raised by the headcount rollup for zombie or otherwise invalid OEBS headcount positions."""


class OebsHeadcountPositionRollupper(Rollupper):
    """Roll OEBS headcount positions up into ``BudgetPosition`` / ``BudgetPositionAssignment``.

    ``run_rollup`` performs, in order:

    1. create missing ``BudgetPosition`` rows and refresh changed headcounts;
    2. link not-yet-linked ``OebsHeadcountPosition`` rows to existing DIS assignments;
    3. the generic rollup implemented by the ``Rollupper`` base class;
    4. deactivate DIS assignments no OEBS row refers to and record a ``RelevanceDate`` entry.

    Steps 1, 2 and 4 are skipped when ``dry_run`` is true.
    """

    model = OebsHeadcountPosition
    queryset = OebsHeadcountPosition.objects.order_by('code', 'assignment_index')
    data_mapper_class = HeadcountPositionsMapper
    link_field_name = 'dis_budget_position_assignment'
    key_field_name = 'id'
    create_absent = True

    # TODO: vacancy
    # Each entry is (OEBS field, {DIS field: name of the mapping method}). The inner dict is
    # presumably passed by the Rollupper base as **kwargs to generic_rollup, which calls each
    # method as method(oebs_value, oebs_instance).
    rollup_rel_fields = (
        ('code', {'budget_position_id': '_get_budget_position_id'}),
        ('current_login', {'person_id': '_get_person_id'}),
        ('department_id', {'department_id': '_get_department_id'}),
        ('hr_product_id', {'value_stream_id': '_get_value_stream_id'}),
        ('geography_oebs_code', {'geography_id': '_get_geography_id'}),
        ('reward_id', {'reward_id': '_get_reward_id'}),
    )

    # Mapper output fields not derived from OEBS data; an update touching only these is
    # discarded by generate_update_values.
    _non_oebs_data_fields = frozenset({'change_reason', 'change_registry'})

    # Lookup tables filled lazily by the _get_*_id methods; reset at the start of every run.
    _budget_position_mapping: Optional[Dict[int, int]] = None
    _person_mapping: Optional[Dict[str, int]] = None
    _department_mapping: Optional[Dict[int, int]] = None
    _value_stream_mapping: Optional[Dict[int, int]] = None
    _geography_mapping: Optional[Dict[str, int]] = None
    _reward_mapping: Optional[Dict[int, int]] = None

    def create_dis_instance(self, oebs_instance: OebsHeadcountPosition) -> BudgetPositionAssignment:
        """Build an unsaved DIS assignment for ``oebs_instance``.

        The new assignment's ``previous_assignment`` is the first assignment of the same budget
        position that has no ``next_assignment`` (None if there is none).

        Raises:
            OebsHeadcountPositionException: if ``oebs_instance.has_errors`` is set, i.e. it was
                flagged as a zombie by ``_link_new_budget_position_assignments``.
        """
        if oebs_instance.has_errors:
            logger.info('Zombie assignment %s', oebs_instance.id)
            raise OebsHeadcountPositionException(f'Zombie assignment {oebs_instance.id}')

        previous_assignment_query = BudgetPositionAssignment.objects.filter(
            budget_position__code=oebs_instance.code,
            next_assignment=None,
        )

        return BudgetPositionAssignment(previous_assignment=previous_assignment_query.first())

    def run_rollup(self, object_id=None, dry_run=False):
        """Run the full headcount rollup; see the class docstring for the steps."""
        # The _get_*_id caches are stored on the instance. Drop anything cached by a previous
        # run so that budget positions created below are visible to _get_budget_position_id.
        self._reset_mapping_caches()

        logger.info('Rolling up BudgetPosition headcount...')
        if not dry_run:
            self._ensure_budget_positions_exists()

        logger.info('Linking BudgetPositionAssignment`s created by registry...')
        if not dry_run:
            self._unlink_occupied_budget_position_assignments_without_oebs_assignment_id()
            self._link_new_budget_position_assignments()

        logger.info('Rolling up BudgetPositionAssignment...')
        super().run_rollup(object_id, dry_run)

        logger.info('Deactivating BudgetPositionAssignment...')
        if not dry_run:
            change_reason = f'Not in OEBS @{datetime.now()}'
            # Active DIS assignments (intranet_status != 0) that no OEBS row points at.
            # oebs_headcount_position is presumably the reverse side of
            # OebsHeadcountPosition.dis_budget_position_assignment — TODO confirm.
            removed_ids = list(
                BudgetPositionAssignment.objects
                .exclude(intranet_status=0)
                .filter(oebs_headcount_position=None)
                .values_list('pk', flat=True)
            )

            with atomic():
                for assignment in BudgetPositionAssignment.objects.select_for_update().filter(id__in=removed_ids):
                    assignment.intranet_status = 0
                    assignment.change_reason = change_reason
                    # 'version' and 'modified_at' are presumably maintained by the model's save().
                    assignment.save(update_fields=['intranet_status', 'version', 'change_reason', 'modified_at'])

            relevance_date = OebsHeadcountPosition.objects.all().values_list('relevance_date', flat=True).first()
            RelevanceDate.objects.update_or_create(
                model_name=BudgetPositionAssignment.__name__,
                defaults={
                    'relevance_date': relevance_date,
                    'updated_entities': self.logger.rolled,
                    'failed_entities': self.logger.errors,
                    'skipped_entities': self.logger.skipped,
                },
            )

    def _reset_mapping_caches(self) -> None:
        """Forget the lookup tables cached by the ``_get_*_id`` methods."""
        self._budget_position_mapping = None
        self._person_mapping = None
        self._department_mapping = None
        self._value_stream_mapping = None
        self._geography_mapping = None
        self._reward_mapping = None

    def generate_update_values(
        self,
        oebs_instance: models.Model,
        dis_instance: models.Model,
    ) -> Iterable[Tuple[str, Any]]:
        """Return the base class's update values, or ``[]`` if only non-OEBS fields changed."""
        updated_values = list(super().generate_update_values(oebs_instance, dis_instance))
        has_oebs_changes = any(
            field_name not in self._non_oebs_data_fields
            for field_name, _ in updated_values
        )

        return updated_values if has_oebs_changes else []

    def generic_rollup(
        self,
        oebs_instance: OebsHeadcountPosition,
        dis_instance: BudgetPositionAssignment,
        field_name: str,
        dry_run: bool,
        **kwargs,
    ) -> bool:
        """Map ``oebs_instance.<field_name>`` onto DIS fields via the methods named in kwargs.

        ``kwargs`` maps a DIS field name to the name of a method of this class, which is called
        as ``method(oebs_value, oebs_instance)``. Changed values are set on ``dis_instance``
        (not saved here). ``dry_run`` is not used by this implementation.

        Returns:
            True if at least one DIS field value changed.
        """
        actually_updated = False
        oebs_value = getattr(oebs_instance, field_name)

        for destination_field, mapping_function in kwargs.items():
            result_value = getattr(self, mapping_function)(oebs_value, oebs_instance)
            current_value = getattr(dis_instance, destination_field)

            if result_value != current_value:
                actually_updated = True
                setattr(dis_instance, destination_field, result_value)

        return actually_updated

    @staticmethod
    def _ensure_budget_positions_exists() -> None:
        """Create missing ``BudgetPosition`` rows and sync changed headcounts.

        The source of truth is the OEBS rows with ``next_assignment_index=None``.
        """
        oebs_headcount = OebsHeadcountPosition.objects.filter(next_assignment_index=None)
        new_budget_position_data = dict(oebs_headcount.values_list('code', 'headcount'))
        current_budget_position_data = dict(BudgetPosition.objects.all().values_list('code', 'headcount'))

        new_budget_positions = set(new_budget_position_data) - set(current_budget_position_data)
        logger.info('%s new BudgetPosition`s', len(new_budget_positions))

        if new_budget_positions:
            # TODO: Verify update signal being sent
            positions_to_create = [
                BudgetPosition(code=code, headcount=new_budget_position_data[code])
                for code in new_budget_positions
            ]
            BudgetPosition.objects.bulk_create(positions_to_create)

        current_budget_positions = set(new_budget_position_data) & set(current_budget_position_data)
        positions_to_update = {
            code: new_budget_position_data[code]
            for code in current_budget_positions
            if new_budget_position_data[code] != current_budget_position_data[code]
        }
        logger.info('%s BudgetPosition`s has new headcount', len(positions_to_update))

        if positions_to_update:
            # Fetch all changed positions in one query instead of one get() per code; save()
            # per instance is kept so model save logic/signals still run for each one.
            for budget_position in BudgetPosition.objects.filter(code__in=list(positions_to_update)):
                budget_position.headcount = positions_to_update[budget_position.code]
                budget_position.save(update_fields=['headcount'])

    def _get_budget_position_id(self, position_code: int, _: OebsHeadcountPosition) -> int:
        """Return the ``BudgetPosition`` id for ``position_code``.

        Raises KeyError for an unknown code (e.g. a new code during a dry run, where
        ``_ensure_budget_positions_exists`` is skipped).
        """
        if self._budget_position_mapping is None:
            self._budget_position_mapping = dict(BudgetPosition.objects.values_list('code', 'id'))

        return self._budget_position_mapping[position_code]

    def _get_person_id(self, login: Optional[str], _: OebsHeadcountPosition) -> Optional[int]:
        """Return the ``Staff`` id for ``login``; None when login is None or unknown."""
        if login is None:
            return None

        if self._person_mapping is None:
            self._person_mapping = dict(Staff.objects.values_list('login', 'id'))

        person_id = self._person_mapping.get(login, None)

        if person_id is None:
            logger.info('Person %s not found. Probably not yet hired.', login)

        return person_id

    def _get_value_stream_id(self, hr_product_id: int, _: OebsHeadcountPosition) -> Optional[int]:
        """Return the value stream id of ``HRProduct`` ``hr_product_id``; None (logged) if absent."""
        if self._value_stream_mapping is None:
            self._value_stream_mapping = dict(HRProduct.objects.values_list('id', 'value_stream_id'))

        value_stream_id = self._value_stream_mapping.get(hr_product_id, None)

        if value_stream_id is None:
            logger.error('Product %s not found', hr_product_id)

        return value_stream_id

    def _get_geography_id(self, oebs_geography_code: str, _: OebsHeadcountPosition) -> Optional[int]:
        """Return the department instance id of the ``Geography`` with this OEBS code, or None."""
        if self._geography_mapping is None:
            self._geography_mapping = dict(Geography.objects.values_list('oebs_code', 'department_instance_id'))

        geography_id = self._geography_mapping.get(oebs_geography_code, None)

        if geography_id is None:
            logger.error('Geography %s not found', oebs_geography_code)

        return geography_id

    def _get_department_id(self, oebs_department_id: int, _: OebsHeadcountPosition) -> Optional[int]:
        """Return ``oebs_department_id`` if such a ``Department`` exists.

        Otherwise return ``settings.NON_EXISTING_OEBS_DEPARTMENT_ID``.
        """
        if self._department_mapping is None:
            # Identity map: used only to test whether a Department id exists.
            self._department_mapping = dict(Department.objects.values_list('id', 'id'))

        department_id = self._department_mapping.get(oebs_department_id, None)

        if department_id is None:
            logger.info('Department %s not found', oebs_department_id)
            department_id = settings.NON_EXISTING_OEBS_DEPARTMENT_ID

        return department_id

    def _get_reward_id(self, oebs_reward_id: int, _: OebsHeadcountPosition) -> int:
        """Return the id of the ``Reward`` linked to OEBS reward ``oebs_reward_id``.

        Raises KeyError when no Reward is linked to that OEBS reward.
        """
        if self._reward_mapping is None:
            # NOTE(review): Rewards without an oebs_instance produce a None key here, so an
            # oebs_reward_id of None would resolve to an arbitrary such Reward instead of raising.
            # Confirm whether OEBS reward_id can be None.
            self._reward_mapping = dict(Reward.objects.values_list('oebs_instance__pk', 'id'))

        return self._reward_mapping[oebs_reward_id]

    @staticmethod
    def _unlink_occupied_budget_position_assignments_without_oebs_assignment_id():
        """Drop DIS links of OEBS rows lacking an assignment id on positions with unlinked rows.

        The unlinked rows are then picked up again by ``_link_new_budget_position_assignments``,
        which ``run_rollup`` calls right after this method.

        Raises:
            OebsHeadcountPositionException: if a row without ``assignment_id`` is not the last
                one on its position (has a ``next_assignment_index``).
        """
        invalid_assignments = OebsHeadcountPosition.objects.filter(
            assignment_id=None,
            next_assignment_index__isnull=False,
        )

        if invalid_assignments.exists():
            logger.error(
                'Budget position assignment without OEBS assignmentID can only be for the last assignment on position.',
            )
            raise OebsHeadcountPositionException('Invalid OebsHeadcountPosition found')

        # Codes of positions that have at least one OEBS row without a DIS link (used as a subquery).
        missing_links = (
            OebsHeadcountPosition.objects
            .filter(dis_budget_position_assignment=None)
            .values_list('code', flat=True)
        )

        affected_assignments = OebsHeadcountPosition.objects.filter(assignment_id=None, code__in=missing_links)
        links_dropped = affected_assignments.update(dis_budget_position_assignment=None)
        logger.info('%s links dropped', links_dropped)

    def _link_new_budget_position_assignments(self):
        """Link unlinked OEBS rows to existing DIS assignments, walking each position's chain.

        Rows are processed in ``(code, assignment_index)`` order. A row whose code already has a
        linked row with a higher ``assignment_index`` is flagged ``has_errors`` (zombie). For the
        remaining rows, the first one per code is handled by
        ``_set_dis_assignment_id_for_first_unlinked`` and each following one by
        ``_set_dis_assignment_id_for_next_in_chain``.
        """
        not_linked_query = OebsHeadcountPosition.objects.filter(dis_budget_position_assignment=None)
        logger.info('%s new budget position assignments', not_linked_query.count())

        # code -> highest assignment_index among already-linked OEBS rows. order_by() clears any
        # default model ordering, which could otherwise leak into the GROUP BY and split groups.
        last_linked = dict(
            OebsHeadcountPosition.objects
            .exclude(dis_budget_position_assignment=None)
            .order_by()
            .values('code')
            .annotate(max=models.Max('assignment_index'))
            .values_list('code', 'max')
        )

        # DIS assignment pk -> pk of its next assignment (None at the end of a chain).
        dis_assignment_links = dict(BudgetPositionAssignment.objects.all().values_list('pk', 'next_assignment__id'))
        # code -> values of a DIS assignment of that position with no next assignment.
        last_dis_assignments_query = BudgetPositionAssignment.objects.filter(next_assignment=None)
        last_dis_assignments = {
            assignment['budget_position__code']: assignment
            for assignment in last_dis_assignments_query.values('budget_position__code', 'pk', 'person__login')
        }

        current_code = None
        current_dis_assignment = None

        for not_linked in not_linked_query.order_by('code', 'assignment_index'):
            if not_linked.code in last_linked and last_linked[not_linked.code] > not_linked.assignment_index:
                # Sorting by index means such rows precede the non-zombie rows of the same code,
                # so current_code is deliberately left untouched here.
                logger.info('Zombie assignment detected for %s %s', not_linked.code, not_linked.current_login)
                not_linked.has_errors = True
                not_linked.save(update_fields=['has_errors'])
            elif current_code != not_linked.code:
                current_code = not_linked.code
                current_dis_assignment = self._set_dis_assignment_id_for_first_unlinked(
                    not_linked,
                    dis_assignment_links,
                    last_dis_assignments,
                )
            else:
                current_dis_assignment = self._set_dis_assignment_id_for_next_in_chain(
                    not_linked,
                    current_dis_assignment,
                    dis_assignment_links,
                )

    @staticmethod
    def _set_dis_assignment_id_for_first_unlinked(
        not_linked: OebsHeadcountPosition,
        dis_assignment_links: Dict[int, Optional[int]],
        last_dis_assignments: Dict[int, Dict[str, Any]],
    ) -> Optional[int]:
        """Link the first unlinked OEBS row of a position, if a matching DIS assignment exists.

        * No previous OEBS row: link to the position's last DIS assignment only if that
          assignment has no person; otherwise leave the row unlinked.
        * Previous OEBS row exists: link to the DIS assignment that follows the previous row's
          DIS assignment (possibly None).

        Returns:
            The DIS assignment id to continue the chain from for following rows of the same
            code. It is None in the first case even when the row was linked, because the
            assignment it was linked to has no next assignment.
        """
        current_dis_assignment = None

        if not_linked.prev_assignment_index is None:
            last_dis_assignment = last_dis_assignments.get(not_linked.code)
            if last_dis_assignment is None:
                logger.info('Brand new assignment for %s %s', not_linked.code, not_linked.current_login)
            elif last_dis_assignment['person__login'] is None:
                logger.info(
                    'Existing assignment for %s: %s %s',
                    not_linked.code,
                    last_dis_assignment['pk'],
                    not_linked.current_login,
                )
                not_linked.dis_budget_position_assignment_id = last_dis_assignment['pk']
                not_linked.save(update_fields=['dis_budget_position_assignment_id'])
            else:
                logger.info(
                    'New assignment for %s after %s %s',
                    not_linked.code,
                    last_dis_assignment['pk'],
                    not_linked.current_login,
                )
        else:
            prev_assignment = OebsHeadcountPosition.objects.get(
                code=not_linked.code,
                assignment_index=not_linked.prev_assignment_index,
            )
            current_dis_assignment = dis_assignment_links.get(prev_assignment.dis_budget_position_assignment_id)
            logger.info(
                'Existing assignment for %s: %s %s',
                not_linked.code,
                current_dis_assignment,
                not_linked.current_login,
            )
            not_linked.dis_budget_position_assignment_id = current_dis_assignment
            not_linked.save(update_fields=['dis_budget_position_assignment_id'])

        return current_dis_assignment

    @staticmethod
    def _set_dis_assignment_id_for_next_in_chain(
        not_linked: OebsHeadcountPosition,
        current_dis_assignment: Optional[int],
        dis_assignment_links: Dict[int, Optional[int]],
    ) -> Optional[int]:
        """Link a subsequent OEBS row to the DIS assignment following ``current_dis_assignment``.

        When ``current_dis_assignment`` is None the row is only logged and left unlinked.

        Returns:
            The DIS assignment id the row was linked to (None if the DIS chain ended), or None
            when there was no chain to follow.
        """
        if current_dis_assignment is None:
            logger.info('New assignment for %s %s far ahead of dis', not_linked.code, not_linked.current_login)
        else:
            current_dis_assignment = dis_assignment_links.get(current_dis_assignment)
            logger.info(
                'Existing assignment in chain for %s: %s %s',
                not_linked.code,
                current_dis_assignment,
                not_linked.current_login,
            )
            not_linked.dis_budget_position_assignment_id = current_dis_assignment
            not_linked.save(update_fields=['dis_budget_position_assignment_id'])

        return current_dis_assignment
