from datetime import datetime
from sqlalchemy.orm import Session, load_only
from typing import Collection, List, Optional, Iterable, Type, TypeVar, Tuple, Union

from blackbox import XmlBlackbox

from travel.avia.subscriptions.app.api.consts import PASSPORT_AUTH_TYPE, TOKEN_AUTH_TYPE
from travel.avia.subscriptions.app.api.exceptions import (
    InvalidUserAuthType, InvalidTravelVertical,
    EmailNotFound, NoAccess
)
from travel.avia.subscriptions.app.model.db import (
    DBInstanceMixin, Email, TravelVertical,
    UserAuthType, UserAuth, User
)
from travel.avia.subscriptions.app.model.storage import Storage, TBase


def user_auth_type_get_or_fail(
    auth_type_dbs: Storage[UserAuthType],
    auth_type_name: str
) -> UserAuthType:
    user_auth_type = auth_type_dbs.get(
        name=auth_type_name,
    )

    if user_auth_type is None:
        raise InvalidUserAuthType(
            'Unrecognized auth type: {}'.format(auth_type_name)
        )

    return user_auth_type


def travel_vertical_get_or_fail(
    travel_vertical_dbs: Storage[TravelVertical],
    travel_vertical_name: str
) -> TravelVertical:
    travel_vertical = travel_vertical_dbs.get(
        name=travel_vertical_name,
    )
    if not travel_vertical:
        raise InvalidTravelVertical(
            f'Cannot find travel vertical with name {travel_vertical_name}'
        )

    return travel_vertical


def email_get_or_fail(email_dbs: Storage[Email], email: str) -> Union[DBInstanceMixin, Email]:
    email = email_dbs.get(email=email)
    if not email:
        raise EmailNotFound('Unknown email')
    return email


def user_get_or_create(
    *,
    auth_type_name: str,
    auth_value: str,
    email: Email,
    user_auth_type_dbs: Storage[UserAuthType],
    user_auth_dbs: Storage[UserAuth],
    user_dbs: Storage[User],
    user_timezone: str,
    user_name: str,
) -> User:
    user_auth_type = user_auth_type_get_or_fail(
        auth_type_dbs=user_auth_type_dbs,
        auth_type_name=auth_type_name
    )
    user_auth = user_auth_dbs.get_or_create(
        user_auth_type_id=user_auth_type.id,
        auth_value=auth_value
    )
    _, user = user_dbs.upsert(
        where=dict(
            email_id=email.id,
            user_auth_id=user_auth.id,
        ),
        values=dict(
            user_timezone=user_timezone,
            name=user_name,
        ),
    )

    return user


def get_user_auth(
    user_auth_dbs: Storage[UserAuth],
    user_auth_type_dbs: Storage[UserAuthType],
    auth_type_name: str,
    auth_value: str,
    raise_on_auth_type_absent: bool = False
) -> Optional[UserAuth]:
    try:
        user_auth_type = user_auth_type_get_or_fail(
            auth_type_dbs=user_auth_type_dbs,
            auth_type_name=auth_type_name
        )
    except InvalidUserAuthType:
        if raise_on_auth_type_absent:
            raise
        return

    return user_auth_dbs.get(
        user_auth_type_id=user_auth_type.id,
        auth_value=auth_value,
    )


def get_users_by_user_auth(
    session: Session,
    user_dbs: Storage[User],
    user_auth: UserAuth,
    auth_type_name: str,
    email_obj: Email = None
) -> List[User]:
    if auth_type_name == TOKEN_AUTH_TYPE:
        # Для токена всегда существует только один пользователь
        user = user_dbs.get(user_auth_id=user_auth.id, deleted_at=None)
        if email_obj is not None and user.email_id != email_obj.id:
            return []

        # Токен имеет доступ ко всем подтвержденным пользователям текущей почты
        return (session.query(User)
                .filter_by(email_id=user.email_id, deleted_at=None)
                .filter(User.approved_at.isnot(None))
                .all())

    query = (session.query(User)
             .filter_by(user_auth_id=user_auth.id, deleted_at=None)
             .filter(User.approved_at.isnot(None)))

    if email_obj:
        query = query.filter_by(email_id=email_obj.id)

    return query.all()


def get_users_by_credentials(
    *,
    session: Session,
    credentials: Iterable[Tuple[str, str]],
    email: Optional[str] = None,
    email_dbs: Storage[Email],
    user_dbs: Storage[User],
    user_auth_dbs: Storage[UserAuth],
    user_auth_type_dbs: Storage[UserAuthType],
    raise_on_auth_type_absent: bool = False
):
    email_obj = email_dbs.get(email=email) if email else None
    if email_obj is None and email is not None:
        return []

    for auth_type_name, auth_value in credentials:
        user_auth = get_user_auth(
            user_auth_dbs=user_auth_dbs,
            user_auth_type_dbs=user_auth_type_dbs,
            auth_type_name=auth_type_name,
            auth_value=auth_value,
            raise_on_auth_type_absent=raise_on_auth_type_absent
        )
        if user_auth is None:
            continue

        yield from get_users_by_user_auth(
            session, user_dbs, user_auth,
            auth_type_name, email_obj
        )


def authenticate(
    *,
    credentials: Iterable[Tuple[str, str]],
    email_obj: Email,
    user_dbs: Storage[User],
    user_auth_dbs: Storage[UserAuth],
    user_auth_type_dbs: Storage[UserAuthType],
    blackbox: XmlBlackbox,
):
    passport_uid = None

    for auth_type, auth_value in credentials:
        passport_uid = auth_value if auth_type == PASSPORT_AUTH_TYPE else None
        user_auth_type = user_auth_type_dbs.get(name=auth_type)
        if user_auth_type is None:
            continue

        user_auth = user_auth_dbs.get(
            user_auth_type_id=user_auth_type.id,
            auth_value=auth_value,
        )
        if user_auth is None:
            continue

        user = user_dbs.get(
            email_id=email_obj.id,
            user_auth_id=user_auth.id,
        )
        if (user is None or user.approved_at is None
                or user.deleted_at is not None):
            continue

        return

    if passport_uid is not None:
        if email_obj.email in blackbox.list_emails(passport_uid):
            return

    raise NoAccess('Credentials are not allowed for that email')


TDescriptor = TypeVar('TDescriptor')


def delete_subscriptions_factory(
    user_subscription_model: Type[TBase],
    user_subscription_model_attr: str,
    descriptor_cls: Type[TDescriptor]
):
    """
    Логика удаления подписок для текущей схемы базы очень схожа
    для разных типов подписок, поэтому часть с удалением выделена
    в отдельный утильный фабричный метод.
    Удаление происходит по следующему алгоритму:
    1) проверяем, что сущесвтует данный email
    2) пытаемся аутенфицировать пользователя хотя бы по одному credential
    3) с помощью дескриптора находим все подписки с заданными значениями
    (например. для промо - это promo code, для подписок на изменение цен -
    это qkey)
    4) дальше находим все пользовательские подписки с заданными
    подтвержденными неудаленными пользвателями и id-шниками самих
    подписок
    5) с помощью дескриптора находим название удаленной подписки
    и кладем в результат

    :param user_subscription_model: модель пользовательской подписки
    :param user_subscription_model_attr: аттрибут, содержащий связь
    пользовательской подписки с самой подпиской, на деле это id
    :param descriptor_cls: класс дескриптора
    """
    class DeleteUserSubscriptions:
        def __call__(
            self, *,
            session: Session,
            values: List,
            credentials: List[Tuple[str, str]],
            email: Union[str, Email],
            email_dbs: Storage[Email],
            user_dbs: Storage[User],
            user_auth_dbs: Storage[UserAuth],
            user_auth_type_dbs: Storage[UserAuthType],
            blackbox: XmlBlackbox,
            authenticated: bool = False
        ):
            # Найти нужный email и произвести аутенфикацию можно
            # до процедуры удаления, чтобы повторно не делать это
            # для каждого вызова
            if isinstance(email, str):
                email_obj = email_get_or_fail(email_dbs, email)
            else:
                email_obj = email

            if not authenticated:
                authenticate(
                    credentials=credentials,
                    email_obj=email_obj,
                    user_dbs=user_dbs,
                    user_auth_dbs=user_auth_dbs,
                    user_auth_type_dbs=user_auth_type_dbs,
                    blackbox=blackbox
                )

            descriptor = descriptor_cls(session, values)
            now = datetime.utcnow()
            deleted_subscriptions_names = set()
            user_subscriptions = self._get_user_subscriptions(
                session, descriptor.all(), email_obj
            )

            for user_subscription in user_subscriptions:
                user_subscription.deleted_at = now
                attr_value = getattr(
                    user_subscription,
                    user_subscription_model_attr
                )
                deleted_subscriptions_names.add(descriptor[attr_value])

            return deleted_subscriptions_names

        def _get_user_subscriptions(self, sess, subscriptions, email_obj) -> Iterable[TBase]:
            subscriptions_ids = list(self._get_ids(subscriptions))
            users = (sess.query(User)
                     .options(load_only('id'))
                     .filter(User.approved_at.isnot(None))
                     .filter(User.deleted_at.is_(None))
                     .filter_by(email_id=email_obj.id)
                     .all())
            users_ids = list(self._get_ids(users))
            subscription_a = getattr(user_subscription_model, user_subscription_model_attr)
            user_a = user_subscription_model.user_id

            return (sess.query(user_subscription_model)
                    .filter(subscription_a.in_(subscriptions_ids))
                    .filter(user_a.in_(users_ids))
                    .all())

        @staticmethod
        def _get_ids(collection) -> Iterable[int]:
            return (e.id for e in collection)

    return DeleteUserSubscriptions


def descriptor_factory(
    subscription_model: Type[TBase],
    subscription_model_attr: str
):
    """
    Дескриптор для подписки

    :param subscription_model: модель подписки
    :param subscription_model_attr: аттрибут подписки, содержащий значение
    (например. для промо - это promo code, для подписок на изменение цен -
    это qkey)
    """
    class SubscriptionsDescriptor:
        def __init__(self, session: Session, values: Collection[str]):
            self._attr = getattr(subscription_model, subscription_model_attr)
            self._subscriptions = (
                session.query(subscription_model)
                .filter(self._attr.in_(values))
                .all()
            )
            self._dict = self._dict()

        def all(self) -> Collection[TBase]:
            return self._subscriptions

        def _dict(self):
            return {
                promo_s.id: promo_s
                for promo_s in self._subscriptions
            }

        def __getitem__(self, item):
            return getattr(self._dict[item], subscription_model_attr)

    return SubscriptionsDescriptor
