from typing import Iterable, Optional, Set

from mail.beagle.beagle.core.actions.base import BaseDBAction
from mail.beagle.beagle.core.entities.not_fetched import NotFetchedType


class GetDependentUIDsUnitAction(BaseDBAction):
    """Collect UIDs of mail lists subscribed by the given units or by their parent units.

    The starting units are expanded with every ``parent_unit_id`` returned by
    ``storage.unit_unit.find`` for them; the result is the set of ``mail_list.uid``
    values of all ``unit_subscription`` rows belonging to any unit in that
    expanded set. All lookups are scoped to ``org_id``.
    """

    # NOTE(review): presumably consumed by BaseDBAction to run ``handle`` inside a
    # DB transaction — confirm against the base class.
    transact = True

    def __init__(self,
                 org_id: int,
                 unit_id: Optional[int] = None,
                 unit_ids: Optional[Iterable[int]] = None,
                 ):
        """
        :param org_id: organization every storage lookup is filtered by.
        :param unit_id: a single starting unit.
        :param unit_ids: several starting units (consumed once and stored as a set).
        :raises ValueError: unless exactly one of ``unit_id`` / ``unit_ids`` is given.
        """
        super().__init__()
        self.org_id: int = org_id
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O``, which would leave ``self.unit_ids`` unset when neither
        # argument is passed, or silently ignore ``unit_ids`` when both are.
        if (unit_id is None) == (unit_ids is None):
            raise ValueError('Exactly one of unit_id, unit_ids must not be None')
        self.unit_ids: Set[int]
        if unit_id is not None:
            self.unit_ids = {unit_id}
        elif unit_ids is not None:
            self.unit_ids = set(unit_ids)

    async def handle(self) -> Set[int]:
        """Return UIDs of mail lists subscribed by ``self.unit_ids`` or their parents.

        :returns: set of ``mail_list.uid`` values; empty when ``self.unit_ids`` is empty.
        """
        if not self.unit_ids:
            # Nothing can depend on an empty set of units: skip both queries instead
            # of relying on how storage interprets an empty ``unit_ids`` filter.
            return set()

        # NOTE(review): whether unit_unit yields only direct parents or every
        # ancestor depends on what that table stores — confirm.
        parent_unit_ids: Set[int] = {
            unit_unit.parent_unit_id
            async for unit_unit in self.storage.unit_unit.find(
                org_id=self.org_id,
                unit_ids=list(self.unit_ids),
            )
        }
        # The starting units subscribe on their own behalf too, so include them.
        unit_ids = parent_unit_ids | self.unit_ids

        mail_list_uids: Set[int] = set()
        async for unit_subscription in self.storage.unit_subscription.find(
            org_id=self.org_id,
            unit_ids=list(unit_ids),
        ):
            # mail_list is expected to be loaded here; the assert also narrows the
            # type away from NotFetchedType for the type checker.
            assert not isinstance(unit_subscription.mail_list, NotFetchedType)
            mail_list_uids.add(unit_subscription.mail_list.uid)

        return mail_list_uids
