from typing import Any, Dict, Set

from staff.departments.models import Department

from staff.headcounts.headcounts_summary.query_builder import query_results


class DepartmentInfoUpdater:
    """Fill department attributes into the grouping info of a result tree.

    The tree is walked through each node's ``next_level_grouping`` and
    ``children`` mappings. Ids are read from every non-empty
    ``grouping_instance_info``, the matching ``Department`` rows are loaded
    with a single queryset, and each matched node gets a freshly built
    ``GroupingInstanceInfo``.

    NOTE(review): ids are collected from every node regardless of what the
    node is grouped by; presumably callers only use this class when the
    grouping instances are departments -- confirm against the caller.
    """

    # Columns requested from the Department queryset via ``values()``.
    _fields = (
        'id',
        'parent_id',
        'instance_class',
        'url',
        'name',
        'name_en',
    )

    def set_departments_info_into_result(self, results: query_results.Result) -> None:
        """Load departments referenced anywhere in ``results`` and apply them in place."""
        departments = self._fetch_departments(self._gather_department_instances_id(results))
        self._set_departments_in_result(departments, results)

    def _set_departments_in_result(self, departments: Dict[int, Dict[str, Any]], results: query_results.Result) -> None:
        """Replace the grouping info of ``results`` and of all its descendants.

        A node is left untouched when it has no grouping info or when its id
        is missing from ``departments``. The rebuilt info keeps the node's id,
        takes the remaining attributes from the department row and always has
        ``chief`` set to ``None``.
        """
        info = results.grouping_instance_info
        if info and info.id in departments:
            row = departments[info.id]
            results.grouping_instance_info = query_results.GroupingInstanceInfo(
                id=info.id,
                parent_id=row['parent_id'],
                url=row['url'],
                name=row['name'],
                name_en=row['name_en'],
                instance_class=row['instance_class'],
                chief=None,
            )

        # Descend into the next grouping level first, then into the children.
        for subtree in (results.next_level_grouping, results.children):
            for nested in subtree.values():
                self._set_departments_in_result(departments, nested)

    def _fetch_departments(self, ids: Set[int]) -> Dict[int, Dict[str, Any]]:
        """Return the ``_fields`` values of departments with the given ids, keyed by id."""
        rows = Department.all_types.filter(id__in=ids).values(*self._fields)
        return {row['id']: row for row in rows}

    def _gather_department_instances_id(self, results: query_results.Result) -> Set[int]:
        """Collect the grouping-info ids of ``results`` and of all its descendants."""
        collected = set()

        info = results.grouping_instance_info
        if info:
            collected.add(info.id)

        for subtree in (results.next_level_grouping, results.children):
            for nested in subtree.values():
                collected.update(self._gather_department_instances_id(nested))

        return collected
