from typing import Dict, List, Optional

from django.db.models import Q

from staff.departments.models import Department, InstanceClass

from staff.headcounts.headcounts_summary.query_builder import (
    Aliases,
    HierarchyValues,
    QueryParams,
    RelatedEntity,
    RelatedEntityChains,
)


class RelatedEntityChainsBuilder:
    """Build, for each grouping, a map from department id to its ancestor chain.

    For every grouping in ``query_params.groupings`` the builder:

    1. picks the subtree roots, taken from the user's filters when present or
       otherwise from the parentless nodes of the grouping's instance class;
    2. loads every ``Department.all_types`` row inside those subtrees together
       with its ``parent_id``;
    3. stores, per department id, the list of ids from the topmost loaded
       ancestor down to the department itself.
    """

    def __init__(self, query_params: QueryParams, aliases: Aliases) -> None:
        self._query_params = query_params
        # NOTE(review): stored but not read by any method of this class.
        self._aliases = aliases

    def create_chains(self) -> RelatedEntityChains:
        """Return a ``RelatedEntityChains`` holding one chain map per grouping."""
        result = RelatedEntityChains()

        for grouping in self._query_params.groupings:
            roots = self._gather_roots(grouping)
            departments = self._get_parents(roots)
            chains = self._build_chains(departments)
            result.set_chains(grouping, chains)

        return result

    def _gather_roots(self, grouping: RelatedEntity) -> List[HierarchyValues]:
        """Return the subtree roots to load for ``grouping``.

        User-supplied filters for the grouping take precedence. Without them,
        every parentless node of the grouping's instance class is used.
        """
        _, user_filters = self._query_params.filters_pair_for_grouping(grouping)

        if user_filters:
            return self._take_roots_from_user_filters(user_filters)
        return self._take_roots_from_db(grouping)

    def _take_roots_from_user_filters(self, user_filter: List[HierarchyValues]) -> List[HierarchyValues]:
        """Reduce the user's filter values to non-overlapping subtree roots.

        Sorting by ``(tree_id, lft)`` puts every node after its ancestors in the
        nested-set numbering used by ``_get_parents``. That ordering lets
        ``_merge_hierarchy_values`` drop nested values in a single pass.
        """
        sorted_filters = sorted(user_filter, key=lambda values: (values.tree_id, values.lft))
        return self._merge_hierarchy_values(sorted_filters)

    def _take_roots_from_db(self, grouping: RelatedEntity) -> List[HierarchyValues]:
        """Return every parentless department whose instance class matches ``grouping``.

        Raises:
            KeyError: if ``grouping`` is not one of the three mapped entities.
        """
        related_entity_to_instance_class = {
            RelatedEntity.department: InstanceClass.DEPARTMENT.value,
            RelatedEntity.value_stream: InstanceClass.VALUESTREAM.value,
            RelatedEntity.geography: InstanceClass.GEOGRAPHY.value,
        }
        qs = (
            Department.all_types
            .filter(instance_class=related_entity_to_instance_class[grouping], parent=None)
            .values_list('tree_id', 'lft', 'rght', 'id')
        )

        return [
            HierarchyValues(tree_id, lft, rght, dep_id)
            for tree_id, lft, rght, dep_id in qs
        ]

    def _get_parents(self, roots: List[HierarchyValues]) -> Dict[int, Optional[int]]:
        """Map every department inside the subtrees of ``roots`` to its parent id.

        The rows are ordered by ``lft``, so within a tree each parent comes
        before its children; ``_build_chains`` relies on that order. A root's
        ``parent_id`` may be ``None`` or may point outside the loaded set.

        An empty ``roots`` list yields an empty mapping. Without that guard the
        empty ``Q()`` below would match every row of ``Department.all_types``.
        """
        if not roots:
            return {}

        filters = Q()
        for root in roots:
            # Nested-set containment: the root itself and all of its
            # descendants have lft/rght inside the root's interval in its tree.
            filters |= Q(lft__gte=root.lft, rght__lte=root.rght, tree_id=root.tree_id)

        return dict(Department.all_types.filter(filters).order_by('lft').values_list('id', 'parent_id'))

    def _build_chains(self, departments_parents: Dict[int, Optional[int]]) -> Dict[int, List[int]]:
        """Turn a ``department id -> parent id`` map into ``department id -> chain``.

        Each chain lists ids from the topmost loaded ancestor down to the
        department itself. ``departments_parents`` must yield parents before
        their children, as ``_get_parents`` does via ``order_by('lft')``. A
        department whose parent is absent from the map starts a new chain.
        """
        chains: Dict[int, List[int]] = {}
        for department_id, parent_id in departments_parents.items():
            chains[department_id] = [*chains.get(parent_id, []), department_id]
        return chains

    def _merge_hierarchy_values(self, values: List[HierarchyValues]) -> List[HierarchyValues]:
        """Keep only values that have an id and are not children of the last kept value.

        ``values`` must already be sorted by ``(tree_id, lft)``. With that order,
        a value nested under any kept value is also nested under the most
        recently kept one, so a single ``is_child`` check per value suffices.
        """
        result: List[HierarchyValues] = []
        for value in values:
            if value.id and (not result or not value.is_child(result[-1])):
                result.append(value)

        return result
