import asyncio
import logging
from collections import defaultdict
from dataclasses import asdict
from datetime import date, timedelta
from typing import (
    Callable, Dict, Generator, Iterable,
    Optional, Tuple, ContextManager, Union,
    List, AsyncIterable
)

from blackbox import XmlBlackbox

from sqlalchemy import func
from sqlalchemy.orm import Load, Session
from travel.avia.subscriptions.app.api.exceptions import UnknownPointsError, InvalidPriceChangeSubscription
from travel.avia.subscriptions.app.api.interactor.user_confirm import UserConfirmActor
from travel.avia.subscriptions.app.api.util_db import (
    travel_vertical_get_or_fail, user_get_or_create,
    get_users_by_credentials, delete_subscriptions_factory,
    descriptor_factory
)
from travel.avia.subscriptions.app.lib.dicts import PointKeyResolver
from travel.avia.subscriptions.app.lib.helper import context_manager, deduplication_async, hashed
from travel.avia.subscriptions.app.lib.qkey import (
    structure_from_qkey, qkey_from_params, validate_qkey, QkeyStructure, QkeyValidationError
)
from travel.avia.subscriptions.app.lib.yt_loggers.subscriptions.user_action_log import (
    Language, NationalVersion, SubscriptionType, log_user_subscribed
)
from travel.avia.subscriptions.app.model.db import (
    DBInstanceMixin, PriceChangeSubscription as Subscription,
    UserPriceChangeSubscription as UserSubscription, User, Email,
    TravelVertical, UserAuth, UserAuthType
)
from travel.avia.subscriptions.app.model.schemas import Filter, FilterMinPriceBundle, MinPrice
from travel.avia.subscriptions.app.model.storage import Storage, UpsertAction
from travel.avia.subscriptions.app.settings.urls import avia_search_link

logger = logging.getLogger(__name__)


class UserPriceChangeSubscriptionActor:
    def __init__(
        self,
        travel_vertical: Callable[[Session], Storage[TravelVertical]],
        email: Callable[[Session], Storage[Email]],
        user_auth_type: Callable[[Session], Storage[UserAuthType]],
        user_auth: Callable[[Session], Storage[UserAuth]],
        user: Callable[[Session], Storage[User]],
        price_change_subscription: Callable[[Session], Storage[Subscription]],
        user_price_change_subscription: Callable[[Session], Storage[UserSubscription]],
        session_provider: Callable[[], ContextManager[Session]],
        point_key_resolver: PointKeyResolver,
        user_confirm_actor: UserConfirmActor,
        blackbox: XmlBlackbox
    ):
        self.TravelVertical = travel_vertical
        self.Email = email
        self.UserAuthType = user_auth_type
        self.UserAuth = user_auth
        self.User = user
        self.PriceChangeSubscription = price_change_subscription
        self.UserPriceChangeSubscription = user_price_change_subscription
        self.user_confirm_actor = user_confirm_actor
        self.point_key_resolver = point_key_resolver
        self.blackbox = blackbox
        self.delete_method = delete_method()
        self.session_provider = session_provider

    @staticmethod
    def expand_qkey_date_range(
        qkey: Union[str, QkeyStructure],
        date_range: int = 1,
        only_relevant: bool = True
    ) -> Generator[str, None, None]:
        qkey_struct = structure_from_qkey(qkey) if isinstance(qkey, str) else qkey

        for i in range(date_range):
            date_forward = qkey_struct.date_forward.date() + timedelta(days=i)
            if only_relevant and date.today() > date_forward:
                continue
            yield qkey_from_params(
                point_from_key=qkey_struct.point_from_key,
                point_to_key=qkey_struct.point_to_key,
                date_forward=date_forward,
                date_backward=(
                    qkey_struct.date_backward.date() + timedelta(days=i)
                    if qkey_struct.date_backward else None
                ),
                klass=qkey_struct.klass,
                adults=qkey_struct.adults,
                children=qkey_struct.children,
                infants=qkey_struct.infants,
                national_version=qkey_struct.national_version,
            )

    async def put(
        self, *,
        credentials: Iterable[Tuple[str, str]],
        source: str,
        travel_vertical_name: str,
        email: Union[str, Email],
        name: str,
        timezone: str,
        language: str,
        qid: Optional[str] = None,
        qkey: Optional[str] = None,
        filter_: Optional[Filter] = None,
        date_range: int = 1,
        min_price: Optional[MinPrice] = None,
        travel_vertical: Optional[TravelVertical] = None,
        session: Optional[Session] = None
    ) -> Dict[UpsertAction, list]:

        session_provider = context_manager(session) if session is not None \
            else self.session_provider
        try:
            qkey = validate_qkey(qid, qkey)
        except QkeyValidationError as e:
            raise InvalidPriceChangeSubscription(e)
        qkey_struct = structure_from_qkey(qkey)
        national_version = qkey_struct.national_version
        result = defaultdict(list)

        # Если данных в справочнике нет, то будет выкидываться
        # ошибка UnknownPointsError
        await self._get_points_titles(qkey_struct)

        with session_provider() as session:
            if travel_vertical is None:
                travel_vertical = travel_vertical_get_or_fail(
                    travel_vertical_dbs=self.TravelVertical(session),
                    travel_vertical_name=travel_vertical_name
                )
            if isinstance(email, str):
                email_obj = self.Email(session).get_or_create(email=email)
            else:
                email_obj = email

            for auth_type_name, auth_value in credentials:
                user = user_get_or_create(
                    auth_type_name=auth_type_name,
                    auth_value=auth_value,
                    email=email_obj,
                    user_auth_type_dbs=self.UserAuthType(session),
                    user_auth_dbs=self.UserAuth(session),
                    user_dbs=self.User(session),
                    user_timezone=timezone,
                    user_name=name
                )

                # создаем объект подписки и пользовательской подписки
                action, user_subscription = await self._add_subscription(
                    session, qkey, qkey_struct, source, language, user,
                    travel_vertical, filter_, date_range, min_price
                )
                result[action].append(user_subscription)
                for qk in self.expand_qkey_date_range(qkey_struct, date_range):
                    subscription = self.PriceChangeSubscription(session).get(qkey=qk)
                    if subscription is not None:
                        self._add_filter_to_subscription(subscription, filter_)

                await self.user_confirm_actor.try_approve(
                    session=session,
                    user=user,
                    auth_type=auth_type_name,
                    auth_value=auth_value,
                    email_obj=email_obj,
                    national_version=national_version,
                )

            if len(result[UpsertAction.INSERT]) > 0:
                await log_user_subscribed(
                    subscription_type=SubscriptionType.PRICE_CHANGE,
                    code=str(qid),
                    national_version=NationalVersion(national_version),
                    language=Language(language),
                    email=hashed(email_obj.email),
                    date_range=None,
                    filter_params=asdict(filter_) if filter_ else None,
                    travel_vertical=travel_vertical_name,
                    source=source,
                    pending_passport=None,
                    pending_session=None,
                    min_price=asdict(min_price) if min_price else None,
                )
            return result

    async def subscribed_on_direction(
        self, *,
        credentials: Iterable[Tuple[str, str]],
        email: str,
        qid: Optional[str] = None,
        qkey: Optional[str] = None,
    ) -> Dict:
        qkey = validate_qkey(qid, qkey)
        qkey_params = asdict(structure_from_qkey(qkey))
        date_forward = qkey_params.pop('date_forward')
        date_backward = qkey_params.pop('date_backward')
        original_qkey = filter_ = date_range = None
        subscribed = False

        # Только релевантные запросы
        if date.today() > date_forward.date():
            return {
                'subscribed': subscribed,
                'filter': filter_,
                'qkey': original_qkey,
                'date_range': date_range
            }

        with self.session_provider() as session:
            user_ids = {
                user.id for user in get_users_by_credentials(
                    session=session,
                    credentials=credentials,
                    email=email,
                    email_dbs=self.Email(session),
                    user_dbs=self.User(session),
                    user_auth_dbs=self.UserAuth(session),
                    user_auth_type_dbs=self.UserAuthType(session),
                ) if user
            }

            date_range_exp = func.make_interval(0, 0, 0, UserSubscription.date_range)
            # https://docs.sqlalchemy.org/en/13/orm/query.html
            # filter_by: the keyword expressions are extracted from
            # the primary entity of the query, or the last entity
            # that was the target of a call to Query.join().
            query = (session
                     .query(UserSubscription, Subscription)
                     .options(Load(UserSubscription).load_only('date_range', 'applied_filters'),
                              Load(Subscription).load_only('qkey'))
                     .distinct(UserSubscription.price_change_subscription_id)
                     .join(Subscription)
                     .filter_by(**qkey_params)
                     .filter(date_forward - date_range_exp <= Subscription.date_forward)
                     .filter(Subscription.date_forward <= date_forward)
                     .filter(UserSubscription.user_id.in_(user_ids))
                     .order_by(UserSubscription.price_change_subscription_id,
                               UserSubscription.updated_at.desc()))

            # Если указан дата возвращения, то она также входит
            # в date_range, и сдвинута на столько же дней, на сколько
            # сдвинута дата вылета
            if date_backward is not None:
                query = (
                    query.filter(
                        Subscription.date_backward.isnot(None)
                    ).filter(
                        date_backward - date_range_exp <= Subscription.date_backward
                    ).filter(
                        Subscription.date_backward <= date_backward
                    ).filter(
                        Subscription.date_backward - Subscription.date_forward == date_backward - date_forward
                    )
                )
            else:
                query = query.filter(Subscription.date_backward.is_(None))

            subscription_pair = query.first()

            if subscription_pair is not None:
                user_subscription, subscription = subscription_pair
                subscribed = True
                original_qkey = subscription.qkey
                date_range = user_subscription.date_range
                applied_filter = (user_subscription.applied_filters[0]
                                  if user_subscription.applied_filters else None)
                if applied_filter is not None:
                    filter_ = applied_filter.filter_url_postfix

            return {
                'subscribed': subscribed,
                'filter': filter_,
                'qkey': original_qkey,
                'date_range': date_range,
            }

    async def get_subscriptions_list(
        self, *,
        credentials: Iterable[Tuple[str, str]],
        email: Optional[str] = None,
        raise_on_auth_type_absent: bool = True
    ) -> List[dict]:
        with self.session_provider() as session:
            users = get_users_by_credentials(
                session=session,
                credentials=credentials,
                email=email,
                email_dbs=self.Email(session),
                user_dbs=self.User(session),
                user_auth_dbs=self.UserAuth(session),
                user_auth_type_dbs=self.UserAuthType(session),
                raise_on_auth_type_absent=raise_on_auth_type_absent
            )
            subscriptions = deduplication_async(
                seq=self.subscriptions_list_by_users(session, users, email),
                constraint_func=lambda s: s['subscription_code']
            )
            return [s async for s in subscriptions]

    async def subscriptions_list_by_users(
        self,
        session: Session,
        users: Iterable[User],
        email: Optional[str] = None
    ) -> AsyncIterable[Dict]:
        for user in users:
            # Если передали email, то считаем, что его проверили
            # до вызова этой функции
            email_local = email if email is not None else \
                self.Email(session).get(id=user.email_id).email
            user_subscriptions = self.UserPriceChangeSubscription(session).find(
                user_id=user.id,
                deleted_at=None
            )

            for user_subscription in user_subscriptions:
                subscription = self.PriceChangeSubscription(session).get(
                    id=user_subscription.price_change_subscription_id
                )

                point_from, point_to = await self._get_points_title_no_error(subscription)

                if not point_from:
                    point_from = subscription.point_from_key

                if not point_to:
                    point_to = subscription.point_to_key

                date_str = subscription.date_forward.strftime('%d.%m')

                filter_fragment = None
                if len(user_subscription.applied_filters) > 0:
                    # Всегда берем первый фрагмент, потому что
                    # у свежей подписки он должен быть один
                    applied_filter = user_subscription.applied_filters[0]
                    filter_fragment = applied_filter.filter_url_postfix if applied_filter else None

                # TODO: нужно ли добавлять utm?
                url = avia_search_link(
                    qkey_struct=structure_from_qkey(subscription.qkey),
                    language=user_subscription.lang,
                    filter_fragment=filter_fragment
                )

                yield {
                    'email': email_local,
                    'subscription_code': subscription.qkey,
                    'national_version': subscription.national_version,
                    'language': user_subscription.lang,
                    'source': user_subscription.source,
                    'name': f'{point_from}-{point_to}, {date_str}',
                    'url': url,
                }

    async def delete(
        self, *,
        session: Optional[Session] = None,
        subscription_codes: List[str],
        credentials: List[Tuple[str, str]],
        email: str,
        authenticated: bool = False
    ) -> Iterable[str]:
        session_provider = context_manager(session) if session is not None \
            else self.session_provider
        with session_provider() as session:
            return self.delete_method(
                session=session,
                values=subscription_codes,
                credentials=credentials,
                email=email,
                email_dbs=self.Email(session),
                user_dbs=self.User(session),
                user_auth_dbs=self.UserAuth(session),
                user_auth_type_dbs=self.UserAuthType(session),
                blackbox=self.blackbox,
                authenticated=authenticated
            )

    async def _add_subscription(
        self,
        session: Session,
        qkey: str,
        qkey_struct: QkeyStructure,
        source: str,
        language: str,
        user: DBInstanceMixin,
        travel_vertical: DBInstanceMixin,
        filter_: Optional[Filter] = None,
        date_range: int = 1,
        min_price: Optional[MinPrice] = None
    ) -> Tuple[UpsertAction, DBInstanceMixin]:
        subscription_values = asdict(qkey_struct)
        subscription_values['qkey'] = qkey

        _, subscription = self.PriceChangeSubscription(session).upsert(
            where=dict(qkey=qkey),
            values=subscription_values
        )
        action, user_subscription = self.UserPriceChangeSubscription(session).upsert(
            where=dict(
                user_id=user.id,
                price_change_subscription_id=subscription.id,
                lang=language
            ),
            values=dict(
                source=source,
                date_range=date_range,
                travel_vertical_id=travel_vertical.id,
                deleted_at=None
            )
        )

        if filter_ is not None:
            if len(user_subscription.applied_filters) == 0:
                user_subscription.applied_filters.append(filter_)
            else:
                user_subscription.applied_filters[0] = filter_
        else:
            user_subscription.applied_filters = []

        if min_price is not None:
            user_subscription.last_seen_min_price = min_price

        return action, user_subscription

    @staticmethod
    def _add_filter_to_subscription(
        subscription, filter_: Optional[Filter]
    ):
        if filter_ is not None:
            used_filters = [bundle.filter for bundle in subscription.filtered_minprices]
            if filter_ not in used_filters:
                subscription.filtered_minprices.append(
                    FilterMinPriceBundle(filter=filter_)
                )

    async def _get_points_title_no_error(self, points_info: Union[QkeyStructure, Subscription]) -> Tuple[str, str]:
        """
        :return: point from title, point to title
        :raises: UnknownPointsError
        """
        point_from, point_to = await asyncio.gather(
            self.point_key_resolver.resolve(points_info.point_from_key),
            self.point_key_resolver.resolve(points_info.point_to_key)
        )

        return (
            point_from.TitleDefault if point_from else None,
            point_to.TitleDefault if point_to else None,
        )

    async def _get_points_titles(self, points_info: Union[QkeyStructure, Subscription]) -> Tuple[str, str]:
        """
        :return: point from title, point to title
        :raises: UnknownPointsError
        """
        point_from, point_to = await self._get_points_title_no_error(points_info)

        if point_from is None or point_to is None:
            error_msg = 'Subscriptions error: '
            if point_from is None:
                error_msg += f'Point from key: {points_info.point_from_key} not found in dicts. '
            if point_to is None:
                error_msg += f'Point to key: {points_info.point_to_key} not found in dicts'
            logger.error(error_msg)
            raise UnknownPointsError(error_msg)

        return point_from, point_to


# Deletion factory for price-change subscriptions, built once at import time.
# `UserPriceChangeSubscriptionActor.__init__` calls `delete_method()` and keeps
# the result on the instance, so this name has to stay defined at module level
# (it is looked up when an actor is constructed, hence it may follow the class).
# The factory is configured with the user-subscription model and its
# `price_change_subscription_id` attribute, plus a descriptor built for the
# subscription model's `qkey` attribute.
delete_method = delete_subscriptions_factory(
    user_subscription_model=UserSubscription,
    user_subscription_model_attr='price_change_subscription_id',
    descriptor_cls=descriptor_factory(
        subscription_model=Subscription,
        subscription_model_attr='qkey'
    )
)
