from collections import defaultdict
from typing import (
    Dict, Iterable, Optional, Tuple,
    Callable, ContextManager, AsyncIterable
)
from sqlalchemy.orm import Session

from travel.avia.subscriptions.app.api.util_db import get_users_by_credentials
from travel.avia.subscriptions.app.api.interactor.user_price_change_subscription import (
    UserPriceChangeSubscriptionActor
)
from travel.avia.subscriptions.app.api.interactor.user_promo_subscription import (
    UserPromoSubscriptionActor
)
from travel.avia.subscriptions.app.lib.helper import deduplication_async


# A single subscription entry: string keys mapped to string values.
_SubscriptionEntry = Dict[str, str]
# Return type of UserSubscriptionListActor.list: email -> its subscription entries.
SubscriptionListType = Dict[str, Iterable[_SubscriptionEntry]]


class UserSubscriptionListActor:
    """List the subscriptions of users found by their credentials.

    For every registered actor ('price', then 'promo') the subscriptions of
    the users returned by ``get_users_by_credentials`` are fetched, optionally
    filtered by language/source, passed through ``deduplication_async`` keyed
    on ``(email, subscription_code)``, and grouped by email.
    """

    def __init__(
        self,
        email,
        user,
        user_auth,
        user_auth_type,
        session_provider: Callable[[], ContextManager[Session]],
        user_price_change_subscription_actor: UserPriceChangeSubscriptionActor,
        user_promo_subscription_actor: UserPromoSubscriptionActor
    ):
        """
        :param email: called with a session; the result is passed as ``email_dbs``
        :param user: called with a session; the result is passed as ``user_dbs``
        :param user_auth: called with a session; the result is passed as
            ``user_auth_dbs``
        :param user_auth_type: called with a session; the result is passed as
            ``user_auth_type_dbs``
        :param session_provider: zero-argument callable returning a context
            manager that yields a SQLAlchemy ``Session``
        :param user_price_change_subscription_actor: source of 'price' subscriptions
        :param user_promo_subscription_actor: source of 'promo' subscriptions
        """
        # Factories invoked with the open session inside list().
        self.Email, self.User = email, user
        self.UserAuth, self.UserAuthType = user_auth, user_auth_type
        self.session_provider = session_provider
        # Insertion order decides the order in which subscription types are emitted.
        self.actors = dict(
            price=user_price_change_subscription_actor,
            promo=user_promo_subscription_actor,
        )

    async def list(
        self, *,
        credentials: Iterable[Tuple[str, str]],
        source: Optional[str] = None,
        language: Optional[str] = None
    ) -> SubscriptionListType:
        """Return the subscriptions of the users matching ``credentials``, grouped by email.

        :param credentials: pairs forwarded as-is to ``get_users_by_credentials``
        :param source: when not ``None``, keep only subscriptions whose
            ``source`` equals it
        :param language: when not ``None``, keep only subscriptions whose
            ``language`` equals it
        :return: a ``defaultdict(list)`` mapping email to a list of dicts with
            keys ``subscription_type``, ``subscription_code``, ``name``, ``url``
        """
        by_email = defaultdict(list)
        with self.session_provider() as db_session:
            found_users = list(get_users_by_credentials(
                session=db_session,
                credentials=credentials,
                email_dbs=self.Email(db_session),
                user_dbs=self.User(db_session),
                user_auth_dbs=self.UserAuth(db_session),
                user_auth_type_dbs=self.UserAuthType(db_session)
            ))

            # The async streams below are consumed while the session is still open.
            for kind, actor in self.actors.items():
                raw_stream = actor.subscriptions_list_by_users(
                    session=db_session,
                    users=found_users
                )
                unique_stream = deduplication_async(
                    seq=self._filter_by(raw_stream, language, source),
                    constraint_func=self._constraint
                )
                async for item in unique_stream:
                    by_email[item['email']].append(
                        self._as_list_entry(kind, item)
                    )

        return by_email

    @staticmethod
    async def _filter_by(
        subscriptions: AsyncIterable[Dict],
        language: Optional[str] = None,
        source: Optional[str] = None
    ) -> AsyncIterable[Dict]:
        """Yield the subscriptions that match ``language`` and ``source``.

        A ``None`` criterion is not checked. The ``source`` key is only read
        when the language check has passed.
        """
        async for subscription in subscriptions:
            language_ok = language is None or subscription['language'] == language
            if language_ok and (source is None or subscription['source'] == source):
                yield subscription

    @staticmethod
    def _constraint(subscription):
        """Return the de-duplication key ``(email, subscription_code)``."""
        email = subscription['email']
        code = subscription['subscription_code']
        return email, code

    @staticmethod
    def _as_list_entry(subscription_type, subscription):
        """Build the public list entry for one subscription of the given type."""
        return {
            'subscription_type': subscription_type,
            'subscription_code': subscription['subscription_code'],
            'name': subscription['name'],
            'url': subscription['url'],
        }
