from collections import defaultdict
from itertools import chain
from typing import AsyncIterable, ClassVar, Dict, List, Set, Tuple

from mail.beagle.beagle.core.entities.enums import OrganizationType
from mail.beagle.beagle.core.entities.external_organization import BaseExternalOrganization
from mail.beagle.beagle.core.entities.unit import ExternalKeyType, Unit
from mail.beagle.beagle.core.entities.user import User
from mail.beagle.beagle.interactions.directory import DirectoryClient
from mail.beagle.beagle.interactions.directory.entities import (
    DirectoryDepartment, DirectoryGroup, DirectoryObjectType, DirectoryUser
)
from mail.beagle.beagle.utils.graph import expand_acyclic_graph


class DirectoryOrganization(BaseExternalOrganization):
    """Organization whose structure is read from the Directory service.

    Directory departments and groups are exposed as ``Unit`` objects (external
    types ``'department'`` and ``'group'``), directory users as ``User``
    objects. Data is pulled lazily through ``DirectoryClient`` the first time it
    is needed and cached on the instance afterwards.

    Every ``_fetch_*`` helper collects the client stream into local containers
    and publishes them onto the instance only once the stream has been consumed
    completely. If the client raises midway, no partially-fetched data is kept
    and the next call refetches from scratch; overlapping fetches also never
    append into one shared list.
    """

    TYPE = OrganizationType.DIRECTORY
    DEPARTMENT_TYPE: ClassVar[str] = 'department'
    GROUP_TYPE: ClassVar[str] = 'group'

    def __init__(self, org_id: int, client: DirectoryClient):
        self.org_id: int = org_id
        self.client: DirectoryClient = client

        # Unit lists are assigned only by a successful ``_fetch_*`` call; the
        # ``*_fetched`` flags guard access to them.
        self._groups_fetched: bool = False
        self._group_units: List[Unit]
        self._departments_fetched: bool = False
        self._department_units: List[Unit]
        # parent unit key -> child unit keys. Holds direct edges until
        # ``_fetch_graph`` replaces it with the result of ``expand_acyclic_graph``
        # (presumably the transitive closure — confirm against utils.graph).
        self._graph_fetched: bool = False
        self._graph: Dict[ExternalKeyType, Set[ExternalKeyType]] = defaultdict(set)

        self._users_fetched: bool = False
        self._users: List[User]
        # unit key -> uids of users the Directory lists directly in that unit.
        self._group_users: Dict[ExternalKeyType, Set[int]] = defaultdict(set)
        self._department_users: Dict[ExternalKeyType, Set[int]] = defaultdict(set)

    @classmethod
    def group_key(cls, group_id: int) -> ExternalKeyType:
        """Return the unit external key for the Directory group ``group_id``."""
        return Unit.get_external_key(cls.GROUP_TYPE, str(group_id))

    @classmethod
    def department_key(cls, department_id: int) -> ExternalKeyType:
        """Return the unit external key for the Directory department ``department_id``."""
        return Unit.get_external_key(cls.DEPARTMENT_TYPE, str(department_id))

    def _unit_from_department(self, department: DirectoryDepartment) -> Unit:
        """Build a department ``Unit`` of this organization from a Directory department."""
        return Unit(
            org_id=self.org_id,
            external_id=str(department.department_id),
            external_type=self.DEPARTMENT_TYPE,
            name=department.name,
            uid=department.uid,
            username=department.label,
        )

    def _unit_from_group(self, group: DirectoryGroup) -> Unit:
        """Build a group ``Unit`` of this organization from a Directory group."""
        return Unit(
            org_id=self.org_id,
            external_id=str(group.group_id),
            external_type=self.GROUP_TYPE,
            name=group.name,
            uid=group.uid,
            username=group.label,
        )

    def _user_from_directory_user(self, directory_user: DirectoryUser) -> User:
        """Build a ``User`` of this organization from a Directory user."""
        return User(
            org_id=self.org_id,
            uid=directory_user.user_id,
            username=directory_user.local_part,
            first_name=directory_user.first_name,
            last_name=directory_user.last_name,
        )

    def _merge_graph_edges(self, edges: Dict[ExternalKeyType, Set[ExternalKeyType]]) -> None:
        """Union ``edges`` (parent key -> child keys) into the cached unit graph."""
        for parent_key, child_keys in edges.items():
            # ``setdefault`` instead of relying on defaultdict semantics, so this
            # works for whatever mapping type currently backs ``self._graph``.
            self._graph.setdefault(parent_key, set()).update(child_keys)

    async def _fetch_departments(self) -> None:
        """Fetch department units and parent -> child department edges.

        No-op after the first successful call. Nothing is cached if the client
        raises while streaming departments.
        """
        if self._departments_fetched:
            return
        units: List[Unit] = []
        edges: Dict[ExternalKeyType, Set[ExternalKeyType]] = defaultdict(set)
        async for department in self.client.get_departments(self.org_id):
            unit = self._unit_from_department(department)
            units.append(unit)
            if department.parent_id is not None:
                edges[self.department_key(department.parent_id)].add(unit.external_key)
        # Publish only after the whole stream was consumed successfully.
        self._department_units = units
        self._merge_graph_edges(edges)
        self._departments_fetched = True

    async def _fetch_groups(self) -> None:
        """Fetch group units and group -> member (department/group) edges.

        Members of any other Directory object type are ignored here; user
        membership is collected by ``_fetch_users``. No-op after the first
        successful call; nothing is cached if the client raises midway.
        """
        if self._groups_fetched:
            return
        units: List[Unit] = []
        edges: Dict[ExternalKeyType, Set[ExternalKeyType]] = defaultdict(set)
        async for group in self.client.get_groups(self.org_id):
            unit = self._unit_from_group(group)
            units.append(unit)
            async for member_type, member_id in self.client.get_group_members(self.org_id, group.group_id):
                if member_type == DirectoryObjectType.DEPARTMENT:
                    member_key = self.department_key(member_id)
                elif member_type == DirectoryObjectType.GROUP:
                    member_key = self.group_key(member_id)
                else:
                    continue
                edges[unit.external_key].add(member_key)
        # Publish only after every group and its members were read successfully.
        self._group_units = units
        self._merge_graph_edges(edges)
        self._groups_fetched = True

    async def _fetch_graph(self) -> None:
        """Ensure departments and groups are fetched, then expand the unit graph once."""
        if self._graph_fetched:
            return
        await self._fetch_departments()
        await self._fetch_groups()
        self._graph = expand_acyclic_graph(self._graph)
        self._graph_fetched = True

    async def _fetch_users(self) -> None:
        """Fetch users and their direct department/group memberships.

        No-op after the first successful call. Nothing is cached if the client
        raises while streaming users.
        """
        if self._users_fetched:
            return
        users: List[User] = []
        department_users: Dict[ExternalKeyType, Set[int]] = defaultdict(set)
        group_users: Dict[ExternalKeyType, Set[int]] = defaultdict(set)
        async for directory_user in self.client.get_users(self.org_id):
            user = self._user_from_directory_user(directory_user)
            users.append(user)
            for department_id in directory_user.departments:
                department_users[self.department_key(department_id)].add(user.uid)
            for group_id in directory_user.groups:
                group_users[self.group_key(group_id)].add(user.uid)
        # Publish only after the whole stream was consumed successfully.
        self._users = users
        self._department_users = department_users
        self._group_users = group_users
        self._users_fetched = True

    async def get_revision(self) -> int:
        """Return the organization revision reported by the Directory client (not cached)."""
        return await self.client.get_organization_revision(self.org_id)

    async def get_units(self) -> AsyncIterable[Unit]:
        """Yield all department units, then all group units."""
        await self._fetch_departments()
        for department_unit in self._department_units:
            yield department_unit
        await self._fetch_groups()
        for group_unit in self._group_units:
            yield group_unit

    async def get_unit_units(self) -> AsyncIterable[Tuple[ExternalKeyType, Set[ExternalKeyType]]]:
        """Yield ``(parent key, child keys)`` pairs of the expanded unit graph.

        Child sets are copies, so callers cannot mutate the cached graph.
        """
        await self._fetch_graph()
        for parent, children in self._graph.items():
            yield parent, children.copy()

    async def get_users(self) -> AsyncIterable[User]:
        """Yield every user of the organization."""
        await self._fetch_users()
        for user in self._users:
            yield user

    async def get_unit_users(self) -> AsyncIterable[Tuple[ExternalKeyType, Set[int]]]:
        """Yield ``(unit key, uids)`` for departments first, then groups.

        Only direct memberships reported on each Directory user are included.
        The uid sets are copies of the cached ones.
        """
        await self._fetch_users()
        for external_key, uids in chain(self._department_users.items(), self._group_users.items()):
            yield external_key, uids.copy()
