from itertools import groupby
from typing import List, Optional, Tuple

from django.db.models.query import QuerySet

from staff.departments.models import InstanceClass
from staff.departments.tree_lib.tree_builder import TreeBuilder
from staff.departments.tree_lib.abstract_entity_info import AbstractEntityInfo
from staff.lib.models.mptt import filter_by_heirarchy


class Pager(object):
    """Page through entities grouped by the hierarchy node they belong to.

    A page is built from the rows of the provider's "entities quantity by
    department" query, starting at row ``skip``. Departments are taken in
    query order until their accumulated entity quantity reaches
    ``MAX_ROWS_COUNT``, or until ``MAX_ROWS_COUNT`` departments have been
    taken. When ``url`` is given, only the node with that url and its
    descendants are included. The field linking entities to that node is
    chosen by ``entity_dep_field``.
    """

    # Soft limit on entities per page; also the hard limit on departments per page.
    MAX_ROWS_COUNT = 100

    def __init__(self, info_provider: AbstractEntityInfo, url=None, skip=0):
        """
        :param info_provider: supplies every queryset used for paging.
        :param url: optional department url restricting paging to that subtree.
        :param skip: number of quantity-query rows consumed by previous pages
            (the continuation token returned by ``get_grouped_entities``).
        """
        self.info_provider = info_provider
        self.url = url
        self.skip = skip
        self._department = None  # lazily loaded and cached by ``department``

    @property
    def department(self) -> Optional[dict]:
        """Return the department row matching ``url`` (cached), or None when no url is set."""
        if self.url and self._department is None:
            self._department = self.info_provider.departments_query().get(url=self.url)
        return self._department

    @property
    def entity_dep_field(self) -> str:
        """Name of the entity field linking an entity to the paged hierarchy.

        Returns ``'geography'`` when ``url`` points at a geography node and the
        provider's value-stream field name when it points at a value-stream
        node. Otherwise (including when no url is set) it returns ``'department'``.
        """
        if not self.url:
            return 'department'

        instance_class = self.department.get('instance_class')
        if instance_class == InstanceClass.GEOGRAPHY.value:
            return 'geography'
        if instance_class == InstanceClass.VALUESTREAM.value:
            return self.info_provider.value_stream_field_name
        return 'department'

    def _filter_by_url(self, qs):
        """Restrict ``qs`` to the ``url`` node and its descendants; no-op without a url."""
        if not self.url:
            return qs
        return filter_by_heirarchy(
            qs,
            mptt_objects=[self.department],
            by_children=True,
            include_self=True,
            filter_prefix=self.entity_dep_field + '__',
        )

    def _entities_quantity_qs(self):
        """Iterator over the ``(department_id, qty)`` rows for the current page.

        One row beyond ``MAX_ROWS_COUNT`` is fetched so the caller can tell
        whether another page exists even when this page ends exactly on the
        last row it is allowed to take.
        """
        entity_dep_field_name = f'{self.entity_dep_field}_id'
        qs = self.info_provider.entities_quantity_by_department_query(entity_dep_field_name)
        entities_qs = self._filter_by_url(qs.values_list(entity_dep_field_name, 'qty'))

        start = self.skip
        stop = self.skip + self.MAX_ROWS_COUNT + 1  # +1: look-ahead row for has_next
        return entities_qs[start:stop].iterator()

    def _get_departments_page(self) -> Tuple[List[int], bool]:
        """Pick the department ids for the current page.

        :returns: ``(departments_ids, has_next)``. ``has_next`` is True when
            at least one more quantity row exists after this page.
        """
        departments_ids = []
        entities_total = 0

        for dep_id, qty in self._entities_quantity_qs():
            page_is_full = (
                entities_total >= self.MAX_ROWS_COUNT
                or len(departments_ids) >= self.MAX_ROWS_COUNT
            )
            if page_is_full:
                # A row exists past a full page, so there is a next page.
                return departments_ids, True
            entities_total += qty
            departments_ids.append(dep_id)

        return departments_ids, False

    def _get_departments(self, departments_ids: List[int], entities_map: dict) -> List:
        """Build the department tree output for ``departments_ids`` with their entities attached."""
        departments = self.info_provider.departments_query().filter(id__in=departments_ids)
        return TreeBuilder(self.info_provider).get_for_info_list(departments, entities_map)

    def count(self) -> int:
        """Total number of entities matching the url filter (all pages)."""
        entities_qs = self.info_provider.total_entities_count_query()
        return self._filter_by_url(entities_qs).count()

    def get_aggregate_qs(self, fields: Optional[List[str]] = None) -> QuerySet:
        """Provider's aggregate queryset over ``fields``, restricted by the url filter."""
        entities_qs = self.info_provider.total_entities_aggregate_query(fields)
        return self._filter_by_url(entities_qs)

    def _get_departments_entities(self, departments_ids: List[int]) -> dict:
        """Map each department id to the list of its filled entities, in query order."""
        dep_id_key = self.entity_dep_field + '_id'
        order_fields = [self.entity_dep_field] + self.info_provider.order_entities_by_fields()

        entities_qs = (
            self.info_provider.full_entities_query()
            .filter(**{dep_id_key + '__in': departments_ids})
            .order_by(*order_fields)
        )

        # Ordering by the FK uses the related model's default ordering, which is
        # not guaranteed to keep each department's rows contiguous. Merge groups
        # instead of overwriting, so no entities are dropped.
        entities_map = {}
        filled = self.info_provider.fill_list(entities_qs)
        for dep_id, entities in groupby(filled, lambda entity: entity[dep_id_key]):
            entities_map.setdefault(dep_id, []).extend(entities)
        return entities_map

    def get_grouped_entities(self) -> Tuple[list, Optional[int]]:
        """Return ``(departments, continuation_token)`` for the current page.

        The token is the ``skip`` value for the next page, or None on the last page.
        """
        departments_ids, has_next = self._get_departments_page()
        entities_map = self._get_departments_entities(departments_ids)
        departments = self._get_departments(departments_ids, entities_map)
        # ``skip`` counts rows of the quantity query, so advance by the number of
        # rows consumed, not by the size of the tree builder's output.
        continuation_token = self.skip + len(departments_ids) if has_next else None

        return departments, continuation_token
