from typing import Any, Dict, List

from staff.headcounts.headcounts_summary.query_builder.aliases import Aliases
from staff.headcounts.headcounts_summary.query_builder.query_params import RelatedEntity
from staff.headcounts.headcounts_summary.query_builder.query_results import (
    GroupingInstanceInfo,
    Result,
    SummaryResults,
)
from staff.headcounts.headcounts_summary.query_builder.related_entity_chains import RelatedEntityChains


class RowPath:
    """Cursor over the grouping levels of one result row.

    The cursor starts *before* the first grouping. ``next()`` steps onto the
    following grouping, after which ``current_grouping`` and ``current_id``
    describe it. ``has_next_grouping()`` tells whether a following grouping
    exists and the row carries a non-None id for it.
    """

    def __init__(self, aliases: Aliases, groupings: List[RelatedEntity], row: Dict[str, Any]) -> None:
        # NOTE(review): ``aliases`` is stored but no method of this class reads it.
        self._aliases = aliases
        self._groupings = groupings
        self._row = row
        # How many groupings have been stepped over; 0 means "before the first one".
        self._current_grouping_idx = 0

    def has_next_grouping(self) -> bool:
        """Return True when a following grouping exists and its id in the row is not None.

        Raises ``KeyError`` if the row lacks the following grouping's pk field.
        """
        upcoming = self._current_grouping_idx
        if upcoming < len(self._groupings):
            pk_field = self._identity_field(self._groupings[upcoming])
            return self._row[pk_field] is not None
        return False

    def next(self) -> None:
        """Advance the cursor onto the following grouping (no bounds check)."""
        self._current_grouping_idx = self._current_grouping_idx + 1

    @property
    def current_id(self) -> int:
        """Value of the current grouping's pk field in the row; requires ``next()`` first."""
        assert self._current_grouping_idx > 0
        return self.row[self._identity_field(self.current_grouping)]

    def _identity_field(self, grouping: RelatedEntity) -> str:
        """Return the row key that holds ``grouping``'s primary key."""
        field_name = grouping.pk_field_name
        return field_name

    @property
    def current_grouping(self) -> RelatedEntity:
        """Grouping the cursor is on; requires ``next()`` to have been called."""
        assert self._current_grouping_idx > 0
        position = self._current_grouping_idx - 1
        return self._groupings[position]

    @property
    def row(self) -> Dict[str, Any]:
        """The raw row this cursor walks over."""
        return self._row


class ResultsMapper:
    """Build a tree of ``Result`` nodes from flat query rows.

    Each row is walked through ``groupings`` in order, stopping at the first
    grouping whose id in the row is None. For every grouping reached, the
    chain returned by ``RelatedEntityChains.get_chain`` is materialised
    below the current node (missing nodes are created). The row's summary
    is then attached to the last node reached, or to ``root`` if no grouping
    was reached.
    """

    def __init__(self, groupings: List[RelatedEntity], aliases: Aliases, paths: RelatedEntityChains) -> None:
        self._groupings = groupings
        self._aliases = aliases
        self._chains = paths
        # Synthetic root; first-grouping nodes hang off root.next_level_grouping.
        self.root = Result.default()

    def _create_node(self, node_id: int, row_path: RowPath) -> Result:
        """Create an empty ``Result`` for ``node_id`` tagged with the row's current grouping."""
        grouping_instance_info = GroupingInstanceInfo.default()
        grouping_instance_info.id = node_id
        return Result.default(
            result_id=node_id,
            grouping=row_path.current_grouping.value,
            grouping_instance_info=grouping_instance_info,
        )

    def map_to_result(self, rows: List[Dict[str, Any]]) -> None:
        """Merge ``rows`` into the tree rooted at ``self.root``.

        Nothing is returned; the tree is mutated in place. A later row that
        lands on the same node overwrites that node's summaries.
        """
        for row in rows:
            row_path = RowPath(self._aliases, self._groupings, row)

            current_node = self.root
            while row_path.has_next_grouping():
                row_path.next()
                current_node = self._go_through_grouping(current_node, row_path)

            # Two independent instances on purpose: these fields are semantically
            # distinct, and sharing one object would let any in-place update of
            # ``summary`` silently leak into ``summary_without_children``.
            current_node.summary = SummaryResults.from_dict(row)
            current_node.summary_without_children = SummaryResults.from_dict(row)

    def _go_through_grouping(self, current_node: Result, row_path: RowPath) -> Result:
        """Descend from ``current_node`` along the chain of the row's current grouping.

        The first chain element is looked up in ``current_node.next_level_grouping``
        and each following one in the previous node's ``children``. Missing nodes
        are created. Returns the node for the last chain element, or
        ``current_node`` itself when the chain is empty.
        """
        result = current_node
        where_to_check = result.next_level_grouping

        for entity_id in self._chains.get_chain(row_path.current_grouping, row_path.current_id):
            if entity_id not in where_to_check:
                where_to_check[entity_id] = self._create_node(entity_id, row_path)

            result = where_to_check[entity_id]
            # Deeper elements of the same chain nest under ``children``.
            where_to_check = result.children

        return result
