from itertools import groupby
from typing import Awaitable, Callable, Dict, Iterable, Set, TypeVar

from sendr_utils import alist

from mail.beagle.beagle.core.actions.base import BaseDBAction
from mail.beagle.beagle.core.actions.unit.get_dependent_uids import GetDependentUIDsUnitAction
from mail.beagle.beagle.core.entities.external_organization import BaseExternalOrganization
from mail.beagle.beagle.core.entities.unit import ExternalKeyType, Unit
from mail.beagle.beagle.core.entities.unit_unit import UnitUnit
from mail.beagle.beagle.core.entities.unit_user import UnitUser
from mail.beagle.beagle.core.entities.user import User

# Type of a stored membership row (e.g. UnitUser or UnitUnit) reconciled by
# SyncUnitMembersAction._sync_one_unit_members.
_Stored = TypeVar('_Stored')
# Type of the external value a stored row is compared by (e.g. a uid or a unit external key).
_External = TypeVar('_External')


class SyncUnitMembersAction(BaseDBAction):
    """Reconcile stored unit membership of an organization with an external organization.

    For every unit passed to the constructor two kinds of links are synchronised:

    * unit -> user links (``UnitUser`` rows), compared by ``uid``;
    * unit -> child unit links (``UnitUnit`` rows), compared by the child unit's ``external_key``.

    Stored links missing from the external data are deleted; external links missing from storage
    are created. ``handle`` returns the result of ``GetDependentUIDsUnitAction`` run for the ids of
    all units whose links changed.
    """

    def __init__(self,
                 org_id: int,
                 external_organization: BaseExternalOrganization,
                 units: Iterable[Unit],
                 users: Iterable[User],
                 ):
        """
        :param org_id: organization whose membership rows are synchronised.
        :param external_organization: source of the external unit membership.
        :param units: stored units of the organization; each must already have a ``unit_id``.
        :param users: currently unused by this action; kept for interface compatibility.
        """
        super().__init__()
        self.org_id: int = org_id
        self.external_organization: BaseExternalOrganization = external_organization
        self.units_by_unit_id: Dict[int, Unit] = {}
        # NOTE(review): if two units share an external_key, the later one wins here.
        self.units_by_external_key: Dict[ExternalKeyType, Unit] = {}
        for unit in units:
            assert unit.unit_id is not None
            self.units_by_unit_id[unit.unit_id] = self.units_by_external_key[unit.external_key] = unit

    @staticmethod
    async def _sync_one_unit_members(unit: Unit,
                                     stored_members: Iterable[_Stored],
                                     external_members: Set[_External],
                                     stored_to_external: Callable[[_Stored], _External],
                                     create_func: Callable[[Unit, _External], Awaitable[None]],
                                     delete_func: Callable[[_Stored], Awaitable[None]],
                                     ) -> bool:
        """Reconcile one unit's stored members against its external members.

        :param unit: unit being synchronised; passed to ``create_func``.
        :param stored_members: members currently stored for ``unit``.
        :param external_members: members reported externally for ``unit``; not modified.
        :param stored_to_external: maps a stored member to the value it is compared by.
        :param create_func: awaited for every external member that is not stored yet.
        :param delete_func: awaited for every stored member that is not reported externally.
        :return: True if at least one member was created or deleted.
        """
        # Work on a private copy: removing from the caller's collection would silently
        # destroy data owned by the caller / external organization.
        not_yet_stored: Set[_External] = set(external_members)
        changed = False
        for stored in stored_members:
            external = stored_to_external(stored)
            if external in not_yet_stored:
                # Already stored. Dropping it here also means a second stored row mapping to the
                # same value falls into the delete branch, which de-duplicates storage.
                not_yet_stored.discard(external)
            else:
                changed = True
                await delete_func(stored)
        for external in not_yet_stored:
            changed = True
            await create_func(unit, external)
        return changed

    def _unit_user_to_external(self, unit_user: UnitUser) -> int:
        """Return the value a stored unit -> user link is compared by: the user's uid."""
        return unit_user.uid

    async def _create_unit_user(self, unit: Unit, uid: int) -> None:
        """Store a link between ``unit`` and user ``uid``."""
        assert unit.unit_id is not None
        await self.storage.unit_user.create(UnitUser(
            org_id=self.org_id,
            unit_id=unit.unit_id,
            uid=uid,
        ))

    def _unit_unit_to_external(self, unit_unit: UnitUnit) -> ExternalKeyType:
        """Return the external key of the child unit referenced by a stored unit -> unit link."""
        # NOTE(review): raises KeyError if the child unit was not passed to the constructor.
        return self.units_by_unit_id[unit_unit.unit_id].external_key

    async def _create_unit_unit(self, unit: Unit, child_external_key: ExternalKeyType) -> None:
        """Store a link making the unit with ``child_external_key`` a child of ``unit``."""
        # NOTE(review): raises KeyError if the external organization reports a child key
        # that does not belong to any unit passed to the constructor.
        child_unit_id = self.units_by_external_key[child_external_key].unit_id
        assert unit.unit_id is not None and child_unit_id is not None
        await self.storage.unit_unit.create(UnitUnit(
            org_id=self.org_id,
            unit_id=child_unit_id,
            parent_unit_id=unit.unit_id,
        ))

    async def sync_unit_users(self) -> Set[int]:
        """Synchronise unit -> user links of every known unit.

        :return: ids of the units whose user links were created or deleted.
        """
        # external_key -> uids, as yielded by the external organization.
        external_unit_users = {
            external_key: external_uids
            async for external_key, external_uids in self.external_organization.get_unit_users()
        }
        # groupby only merges adjacent rows, so the query must be ordered by the grouping key.
        stored_unit_users = {
            unit_id: list(unit_users)
            for unit_id, unit_users in groupby(
                await alist(self.storage.unit_user.find(org_id=self.org_id, order_by='unit_id')),
                key=lambda unit_user: unit_user.unit_id,
            )
        }
        changed_unit_ids: Set[int] = set()
        # Only units passed to the constructor are reconciled; stored rows of other units are kept.
        for unit in self.units_by_unit_id.values():
            assert unit.unit_id is not None
            changed = await self._sync_one_unit_members(
                unit=unit,
                stored_members=stored_unit_users.get(unit.unit_id, []),
                external_members=external_unit_users.get(unit.external_key, set()),
                stored_to_external=self._unit_user_to_external,
                create_func=self._create_unit_user,
                delete_func=self.storage.unit_user.delete,
            )
            if changed:
                changed_unit_ids.add(unit.unit_id)
        return changed_unit_ids

    async def sync_unit_units(self) -> Set[int]:
        """Synchronise parent -> child unit links of every known unit.

        :return: ids of the parent units whose child links were created or deleted.
        """
        # parent external_key -> child external keys, as yielded by the external organization.
        external_unit_units = {
            external_key: child_keys
            async for external_key, child_keys in self.external_organization.get_unit_units()
        }
        # groupby only merges adjacent rows, so the query must be ordered by the grouping key.
        stored_unit_units = {
            parent_unit_id: list(unit_units)
            for parent_unit_id, unit_units in groupby(
                await alist(self.storage.unit_unit.find(
                    org_id=self.org_id,
                    order_by='parent_unit_id',
                )),
                key=lambda unit_unit: unit_unit.parent_unit_id,
            )
        }
        changed_unit_ids: Set[int] = set()
        # Only units passed to the constructor are reconciled; stored rows of other units are kept.
        for unit in self.units_by_unit_id.values():
            assert unit.unit_id is not None
            changed = await self._sync_one_unit_members(
                unit=unit,
                stored_members=stored_unit_units.get(unit.unit_id, []),
                external_members=external_unit_units.get(unit.external_key, set()),
                stored_to_external=self._unit_unit_to_external,
                create_func=self._create_unit_unit,
                delete_func=self.storage.unit_unit.delete,
            )
            if changed:
                changed_unit_ids.add(unit.unit_id)
        return changed_unit_ids

    async def handle(self) -> Set[int]:
        """Run both synchronisations and resolve dependents of the changed units.

        :return: result of ``GetDependentUIDsUnitAction`` for the ids of all changed units.
        """
        changed_unit_ids: Set[int] = set()
        changed_unit_ids.update(await self.sync_unit_users())
        changed_unit_ids.update(await self.sync_unit_units())
        return await GetDependentUIDsUnitAction(org_id=self.org_id, unit_ids=changed_unit_ids).run()
