from itertools import groupby
from typing import List, Optional, Tuple

from staff.departments.tree_lib import Pager
from staff.departments.tree_lib.abstract_entity_info import AbstractEntityInfo


class LeafPager(Pager):
    """Pager that splits entity rows into pages of ``MAX_ROWS_COUNT`` rows, grouped by department.

    ``skip`` is the number of entity rows already shown on earlier pages. The
    continuation token returned by :meth:`get_grouped_entities` is the ``skip``
    value to use for the next page, or ``None`` when this page was not filled.
    ``skip``, ``MAX_ROWS_COUNT`` and ``entity_dep_field`` are presumably provided
    by the base ``Pager`` — confirm there.
    """

    def __init__(self, info_provider: AbstractEntityInfo, url=None, skip=0):
        super().__init__(info_provider, url, skip)

    def _entities_quantity_qs(self):
        """Return an iterator of ``(department_id, qty)`` pairs.

        Rows come from ``info_provider.entities_quantity_by_department_query``
        and are narrowed by ``self._filter_by_url``.
        """
        entity_dep_field_name = f'{self.entity_dep_field}_id'
        qs = self.info_provider.entities_quantity_by_department_query(entity_dep_field_name)
        entities_qs = qs.values_list(entity_dep_field_name, 'qty')
        entities_qs_filtered = self._filter_by_url(entities_qs)

        # Stream rows: the caller stops reading as soon as the page is full.
        return entities_qs_filtered.iterator()

    def _get_departments_page_by_leaf(self) -> Tuple[List[int], int, Optional[int]]:
        """Pick the departments whose entities fall on the current page.

        Walks the per-department quantities in query order while keeping a
        running row count. Every department whose cumulative count exceeds
        ``self.skip`` is collected. The walk stops once the collected
        departments cover ``self.MAX_ROWS_COUNT`` rows past ``self.skip``.

        Returns:
            ``(departments_ids, start_token, continuation_token)``:

            - ``start_token`` is how many rows of the first collected department
              were already shown on earlier pages (0 if the page starts at a
              department boundary).
            - ``continuation_token`` is the ``skip`` for the next page, or
              ``None`` if the page was not filled.
        """
        entities_qs = self._entities_quantity_qs()
        count = 0
        departments_ids = []
        start_token = 0
        continuation_token = None

        for dep_id, qty in entities_qs:
            processed = count  # rows belonging to departments before this one
            count += qty

            if count <= self.skip:
                # Every row of this department was shown on an earlier page.
                continue

            departments_ids.append(dep_id)

            if self.skip > processed:
                # The page starts in the middle of this department, so skip its first rows.
                start_token = self.skip - processed

            # NOTE(review): when count lands exactly on skip + MAX_ROWS_COUNT and no
            # rows remain, a token is still returned and the next page comes back empty.
            if count >= self.skip + self.MAX_ROWS_COUNT:
                continuation_token = self.skip + self.MAX_ROWS_COUNT
                break

        return departments_ids, start_token, continuation_token

    def _get_departments_entities_paged(self, departments_ids: List[int], start_token: int) -> dict:
        """Load the current page of entities for ``departments_ids``, keyed by department id.

        Entities are ordered by department tree position (``tree_id``, ``lft``)
        and then by ``info_provider.order_entities_by_fields()``. The slice drops
        the first ``start_token`` rows and keeps at most ``MAX_ROWS_COUNT`` rows.

        Returns:
            ``{department_id: [entity, ...]}`` built from ``info_provider.fill_list``,
            with keys in the order the departments are first seen.
        """
        entity_dep_field_id = f'{self.entity_dep_field}_id'
        tree_order = [f'{self.entity_dep_field}__tree_id', f'{self.entity_dep_field}__lft']
        order_fields = tree_order + self.info_provider.order_entities_by_fields()

        # NOTE(review): start_token is an offset into this ordering. It therefore assumes
        # _entities_quantity_qs yields departments in the same tree order and counts the
        # same entities as full_entities_query — confirm in the info provider.
        entities_qs = (
            self.info_provider.full_entities_query()
            .filter(**{f'{entity_dep_field_id}__in': departments_ids})
            .order_by(*order_fields)
        )[start_token:start_token + self.MAX_ROWS_COUNT]

        entities_map = {}
        # groupby only merges *adjacent* rows. Accumulating with setdefault keeps every
        # row of a department even if fill_list does not return them contiguously,
        # instead of silently overwriting the earlier group.
        for dep_id, entities in groupby(
            self.info_provider.fill_list(entities_qs),
            lambda entity: entity[entity_dep_field_id],
        ):
            entities_map.setdefault(dep_id, []).extend(entities)

        return entities_map

    def get_grouped_entities(self) -> Tuple[list, Optional[int]]:
        """Return the current page of departments with their entities, plus the next-page token.

        Returns:
            ``(departments, continuation_token)``. ``departments`` is what
            ``self._get_departments`` builds from the page's department ids and
            entity map. ``continuation_token`` is the ``skip`` for the next page,
            or ``None`` if this page was not filled.
        """
        departments_ids, start_token, continuation_token = self._get_departments_page_by_leaf()
        entities_map = self._get_departments_entities_paged(departments_ids, start_token)
        departments = self._get_departments(departments_ids, entities_map)
        return departments, continuation_token
