from collections import defaultdict
from typing import AsyncIterable, Dict, Iterable, List, Optional, Tuple, Union
from sqlalchemy.orm import Session

from travel.avia.subscriptions.app.model.db import (
    Email, TravelVertical, User, UserPromoSubscription, PromoSubscription
)
from travel.avia.subscriptions.app.api.exceptions import InvalidPromoSubscription
from travel.avia.subscriptions.app.api.util_db import (
    user_get_or_create, travel_vertical_get_or_fail,
    delete_subscriptions_factory, get_users_by_credentials,
    descriptor_factory
)
from travel.avia.subscriptions.app.lib.helper import context_manager, deduplication_async, hashed
from travel.avia.subscriptions.app.lib.yt_loggers.subscriptions.user_action_log import (
    Language, NationalVersion, SubscriptionType, log_user_subscribed
)
from travel.avia.subscriptions.app.model.storage import UpsertAction


class UserPromoSubscriptionActor:
    """
    Creates, lists and deletes users' promo subscriptions
    (UserPromoSubscription rows).

    Database access goes through injected per-model factories: each of them
    is called with a session (e.g. ``self.Email(sess)``) to obtain an
    accessor bound to that session.
    """

    def __init__(
        self,
        promo_subscription,
        travel_vertical,
        email,
        user_auth_type,
        user_auth,
        user,
        user_promo_subscription,
        user_confirm_actor,
        session_provider,
        blackbox
    ):
        """
        :param promo_subscription: factory ``(session) -> PromoSubscription accessor``
        :param travel_vertical: factory ``(session) -> TravelVertical accessor``
        :param email: factory ``(session) -> Email accessor``
        :param user_auth_type: factory ``(session) -> UserAuthType accessor``
        :param user_auth: factory ``(session) -> UserAuth accessor``
        :param user: factory ``(session) -> User accessor``
        :param user_promo_subscription: factory
            ``(session) -> UserPromoSubscription accessor``
        :param user_confirm_actor: object whose ``try_approve`` coroutine is
            awaited for every subscribed user in ``put``
        :param session_provider: zero-argument callable returning a context
            manager that yields a session
        :param blackbox: passed through to the delete routine in ``delete``
        """
        # Accessor factories are stored under the model names so call sites
        # read like ``self.Email(sess).get(...)``.
        self.PromoSubscription = promo_subscription
        self.TravelVertical = travel_vertical
        self.Email = email
        self.UserAuthType = user_auth_type
        self.UserAuth = user_auth
        self.User = user
        self.UserPromoSubscription = user_promo_subscription
        self.user_confirm_actor = user_confirm_actor
        self.session_provider = session_provider
        # The module-level ``delete_method`` is called with no arguments; its
        # result is the callable actually invoked by ``delete``.
        self.delete_method = delete_method()
        self.blackbox = blackbox

    async def put(
        self, *,
        email: Union[str, Email],
        name: str,
        credentials: Iterable[Tuple[str, str]],
        promo_subscription_code: str,
        source: str,
        travel_vertical_name: str,
        national_version: str,
        language: str,
        timezone: str,
        travel_vertical: Optional[TravelVertical] = None,
        session: Optional[Session] = None
    ) -> Dict[UpsertAction, list]:
        """
        Subscribe the users identified by ``credentials`` to a promo subscription.

        For every ``(auth_type_name, auth_value)`` pair a user is fetched or
        created, a UserPromoSubscription row is upserted (clearing its
        ``deleted_at``), and ``user_confirm_actor.try_approve`` is awaited.
        If at least one row was inserted, a "user subscribed" event is logged.

        :param email: address string (fetched or created) or an existing Email row
        :param travel_vertical: optional pre-fetched vertical; when omitted it
            is looked up by ``travel_vertical_name``
        :param session: optional session to reuse; otherwise a new one is
            obtained from ``self.session_provider``
        :returns: mapping from upsert action to the affected subscription rows
        :raises InvalidAuthType: if provided with unknown auth type
        :raises InvalidEmail: if email is invalid
        :raises InvalidPromoSubscription: if there's no promo subscription
            with provided code, national version and language
        """
        result = defaultdict(list)
        session_provider = context_manager(session) if session is not None \
            else self.session_provider
        with session_provider() as sess:
            promo_subscription = self._promo_subscription_get_or_fail(
                sess,
                promo_subscription_code,
                national_version,
                language
            )
            if travel_vertical is None:
                travel_vertical = travel_vertical_get_or_fail(
                    travel_vertical_dbs=self.TravelVertical(sess),
                    travel_vertical_name=travel_vertical_name
                )
            if isinstance(email, str):
                email_obj = self.Email(sess).get_or_create(email=email)
            else:
                email_obj = email
            # Read the address while the session is still active: after the
            # session context exits the ORM object may be expired/detached
            # (depending on how the provider commits/closes) and attribute
            # access would then fail.
            email_address = email_obj.email

            for auth_type_name, auth_value in credentials:
                user = user_get_or_create(
                    auth_type_name=auth_type_name,
                    auth_value=auth_value,
                    email=email_obj,
                    user_auth_type_dbs=self.UserAuthType(sess),
                    user_auth_dbs=self.UserAuth(sess),
                    user_dbs=self.User(sess),
                    user_timezone=timezone,
                    user_name=name,
                )
                # Re-subscribing a soft-deleted row revives it (deleted_at=None).
                action, user_promo_subscription = self.UserPromoSubscription(sess).upsert(
                    where=dict(
                        user_id=user.id,
                        promo_subscription_id=promo_subscription.id,
                    ),
                    values=dict(
                        travel_vertical_id=travel_vertical.id,
                        source=source,
                        deleted_at=None,
                    )
                )

                await self.user_confirm_actor.try_approve(
                    session=sess,
                    user=user,
                    auth_type=auth_type_name,
                    auth_value=auth_value,
                    email_obj=email_obj,
                    national_version=promo_subscription.national_version,
                )

                result[action].append(user_promo_subscription)
        # Indexing the defaultdict (rather than .get) is kept on purpose: the
        # returned mapping has always contained the INSERT key.
        if result[UpsertAction.INSERT]:
            await log_user_subscribed(
                subscription_type=SubscriptionType.PROMO,
                code=promo_subscription_code,
                national_version=NationalVersion(national_version),
                language=Language(language),
                email=hashed(email_address),
                date_range=None,
                filter_params=None,
                travel_vertical=travel_vertical_name,
                source=source,
                pending_passport=None,
                pending_session=None,
                min_price=None,
            )
        return result

    def _promo_subscription_get_or_fail(self, sess, promo_subscription_code, national_version, language):
        """
        Return the PromoSubscription matching code, national version and language.

        :raises InvalidPromoSubscription: if no such promo subscription exists
        """
        promo_subscription = self.PromoSubscription(sess).get(
            code=promo_subscription_code,
            national_version=national_version,
            language=language,
        )
        if not promo_subscription:
            raise InvalidPromoSubscription(
                'Cannot find subscription with code {}, nv {}, lang {}'.format(
                    promo_subscription_code,
                    national_version,
                    language,
                )
            )
        return promo_subscription

    async def get_subscriptions_list(
        self, *,
        credentials: Iterable[Tuple[str, str]],
        email: Optional[str] = None,
        raise_on_auth_type_absent: bool = False
    ) -> List[dict]:
        """
        List active promo subscriptions of the users matching ``credentials``.

        Entries are deduplicated by (email, subscription_code,
        national_version, language) via ``deduplication_async``.

        :param email: if given, used as the email of every returned entry
            (and passed to the user lookup)
        :param raise_on_auth_type_absent: forwarded to ``get_users_by_credentials``
        :returns: list of dicts as produced by ``subscriptions_list_by_users``
        """
        with self.session_provider() as session:
            users = get_users_by_credentials(
                session=session,
                credentials=credentials,
                email=email,
                email_dbs=self.Email(session),
                user_dbs=self.User(session),
                user_auth_dbs=self.UserAuth(session),
                user_auth_type_dbs=self.UserAuthType(session),
                raise_on_auth_type_absent=raise_on_auth_type_absent
            )
            subscriptions = deduplication_async(
                seq=self.subscriptions_list_by_users(session, users, email),
                constraint_func=self._subscription_constraint
            )
            # Consume the async generator while the session is still open.
            return [s async for s in subscriptions]

    @staticmethod
    def _subscription_constraint(subscription_dict: Dict):
        """Deduplication key for a subscription dict."""
        return (
            subscription_dict['email'],
            subscription_dict['subscription_code'],
            subscription_dict['national_version'],
            subscription_dict['language'],
        )

    async def subscriptions_list_by_users(
        self,
        session: Session,
        users: Iterable[User],
        email: Optional[str] = None
    ) -> AsyncIterable[Dict]:
        """
        Yield one dict per non-deleted UserPromoSubscription of each user.

        :param email: if given, used as the email for every entry instead of
            looking up each user's Email row
        """
        # Several users/rows often point at the same promo subscription; cache
        # lookups for the duration of this call to avoid one query per row.
        promo_subscriptions_by_id = {}
        for user in users:
            # If an email was passed in, assume it was already verified
            # before this function was called.
            email_local = email if email is not None else \
                self.Email(session).get(id=user.email_id).email
            user_promo_subscriptions = self.UserPromoSubscription(session).find(
                user_id=user.id,
                deleted_at=None
            )
            for user_promo_subscription in user_promo_subscriptions:
                promo_subscription_id = user_promo_subscription.promo_subscription_id
                if promo_subscription_id not in promo_subscriptions_by_id:
                    promo_subscriptions_by_id[promo_subscription_id] = \
                        self.PromoSubscription(session).get(id=promo_subscription_id)
                promo_subscription = promo_subscriptions_by_id[promo_subscription_id]

                yield {
                    'email': email_local,
                    'subscription_code': promo_subscription.code,
                    'national_version': promo_subscription.national_version,
                    'language': promo_subscription.language,
                    'source': user_promo_subscription.source,
                    'name': promo_subscription.name,
                    'url': promo_subscription.url,
                }

    async def delete(
        self, *,
        session: Optional[Session] = None,
        subscription_codes: List[str],
        credentials: List[Tuple[str, str]],
        email: str,
        authenticated: bool = False
    ) -> Iterable[str]:
        """
        Delete the given promo subscriptions for the users matching ``credentials``.

        Delegates to ``self.delete_method`` with per-model accessors bound
        to the session.

        :param session: optional session to reuse; otherwise a new one is
            obtained from ``self.session_provider``
        :param subscription_codes: promo subscription codes to delete
        :param authenticated: forwarded to the delete routine
        :returns: whatever the delete routine returns
        """
        session_provider = context_manager(session) if session is not None \
            else self.session_provider
        with session_provider() as sess:
            # NOTE(review): the result is returned without awaiting; if the
            # delete routine ever returns a coroutine/lazy iterable it would
            # run after this session context has exited — confirm.
            return self.delete_method(
                session=sess,
                values=subscription_codes,
                credentials=credentials,
                email=email,
                email_dbs=self.Email(sess),
                user_dbs=self.User(sess),
                user_auth_dbs=self.UserAuth(sess),
                user_auth_type_dbs=self.UserAuthType(sess),
                blackbox=self.blackbox,
                authenticated=authenticated
            )


# Module-level delete routine for promo subscriptions, produced by the generic
# factory. UserPromoSubscriptionActor.__init__ calls it with no arguments to
# obtain the function used by ``delete``, which passes subscription codes as
# ``values`` — hence the descriptor is configured with the ``code`` attribute.
delete_method = delete_subscriptions_factory(
    descriptor_cls=descriptor_factory(
        subscription_model_attr='code',
        subscription_model=PromoSubscription,
    ),
    user_subscription_model_attr='promo_subscription_id',
    user_subscription_model=UserPromoSubscription,
)
