import logging

import numpy as np
from flask import current_app as app
from mongoengine import DoesNotExist

from jafar import fast_cache, jafar_mongo
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.pipelines.blocks.estimator import EstimatorBlockMixin
from jafar.pipelines.misc import get_predictions_frame
from jafar.models.yandexphone import GiftSet, YandexDistributedApp, BonusCardResources
from jafar.utils.structarrays import DataFrame

# Single prediction row emitted by AvailableGiftsSelectionBlock for Yandex-phone users that have no
# passport uid: the ID goes into the `id` column and the ITEM into the `item` column of that row
# (presumably prompting the user to log in to the launcher — confirm with the client side).
LAUNCHER_LOGIN_ACTION_ID = 'lnchr_login'
LAUNCHER_LOGIN_ACTION_ITEM = 'action:lnchr_login'

logger = logging.getLogger(__name__)


class BaseSelectionBlock(SingleContextBlock):
    """
    Base class for candidate selection blocks.

    A selection block reads the target frame (whatever the pipeline was asked to evaluate) and
    writes the "prediction frame". What that frame contains depends on the pipeline's mode:

     * "ranking" mode - the target frame already lists candidate items that only need to be
       scored and ranked, so the prediction frame is simply the target frame;
     * "generation" (top-n) mode - the target frame carries no candidate items, so the prediction
       frame is assembled from candidate recommendations produced for every user.

    Subclasses implement :meth:`create_prediction_frame`; it is only invoked in predict phase.
    """

    def __init__(self, top_n=None, target_frame=ids.FRAME_KEY_TARGET, predictions_frame=ids.FRAME_KEY_PREDICTIONS):
        self.top_n = top_n
        self.target_frame = target_frame
        self.predictions_frame = predictions_frame
        super(BaseSelectionBlock, self).__init__(
            input_data=[self.target_frame],
            output_data=[self.predictions_frame],
            destroyed_data=None,
        )

    @staticmethod
    def leave_unique_predictions(predictions):
        """
        Keep only one row per distinct (`user`, `item`) pair of ``predictions``.
        """
        unique_rows = predictions[['user', 'item']].unique(return_index=True)[1]
        return predictions[unique_rows]

    def apply(self, context, train):
        """
        Fill the prediction frame.

        Train phase: if no selection block has written the prediction frame yet, it becomes a copy
        of the target frame; otherwise nothing is done.
        Predict phase: a frame is built by :meth:`create_prediction_frame`; if a previous selection
        block already wrote one, both are concatenated and deduplicated by (user, item).
        """
        # Was the prediction frame already produced by another selection block?
        # NOTE: this logic is mostly temporary, until EstimatorBlocks, EstimatorFeatureBlocks
        # and EstimatorSelectionBlocks are combined in one entity
        already_selected = self.predictions_frame in context.data

        if train:
            # candidate selection only makes sense in the prediction pass
            if already_selected:
                logger.debug(
                    "%s frame is already in context and it's train phase, nothing to do here",
                    self.predictions_frame
                )
            else:
                logger.debug("%s is the first selection block to run in train phase", self.__class__.__name__)
                context.data[self.predictions_frame] = context.data[self.target_frame].copy()
            return context

        fresh = self.create_prediction_frame(context)
        if already_selected:
            logger.debug(
                "%s frame is already in context and it's predict phase, combining new predictions "
                "with existing ones.",
                self.predictions_frame
            )
            previous = context.data[self.predictions_frame]
            logger.debug("Filtering predictions frame: %s rows", len(fresh) + len(previous))
            fresh = self.leave_unique_predictions(DataFrame.concatenate([previous, fresh]))
            logger.debug("Leaving unique user-item combinations; %s rows left", len(fresh))
        context.data[self.predictions_frame] = fresh
        return context

    def create_prediction_frame(self, context):
        """Build the predict-phase prediction frame; must be overridden by subclasses."""
        raise NotImplementedError

    @staticmethod
    def get_empty_predictions():
        """Return an empty structured array with int32 `item` and `user` fields."""
        empty_dtype = [('item', np.int32), ('user', np.int32)]
        return np.empty(0, dtype=empty_dtype)


class DummySelectionBlock(SingleContextBlock):
    """
    Pass-through selection block: it selects nothing and, in every phase, just copies the input
    frame (target by default) into the output frame (predictions by default).
    """

    def __init__(self, input_frame=ids.FRAME_KEY_TARGET, output_frame=ids.FRAME_KEY_PREDICTIONS):
        self.input_frame, self.output_frame = input_frame, output_frame
        super(DummySelectionBlock, self).__init__(
            input_data=[self.input_frame],
            output_data=[self.output_frame],
            destroyed_data=None,
        )

    def apply(self, context, train):
        """Copy the input frame to the output frame regardless of ``train``."""
        frames = context.data
        frames[self.output_frame] = frames[self.input_frame].copy()
        return context


class OnlineTrendingSelectionBlock(SingleContextBlock):
    """
    Loads trending apps from a mongo collection and turns them into predictions for the target.

    The collection is ``app.config['TRENDING_MONGO_COLLECTION']`` formatted with ``trending_by``;
    apps are sorted (descending) by ``app.config['TRENDING_MONGO_FIELD']`` formatted with
    ``trend_days``, and only apps whose ``average_daily_counts`` exceeds a threshold are kept.

    :param trending_by: counted event; must be one of ``app.config['TRENDING_COUNT_STATS']``.
    :param trend_days: trend window; must be one of ``app.config['TRENDING_INTERVALS_IN_DAYS']``.
    :param top_n: maximum number of trending apps to load.
    :param average_daily_counts_threshold: exclusive lower bound on ``average_daily_counts``;
        when omitted it is 100.0 for ``trending_by == 'launches'`` and 10.0 otherwise.
    :param input_frame: context key of the target frame.
    :param output_frame: context key the predictions are written to.
    """

    def __init__(self, trending_by, trend_days, top_n, average_daily_counts_threshold=None,
                 input_frame=ids.FRAME_KEY_TARGET, output_frame=ids.FRAME_KEY_PREDICTIONS):

        trend_days_choice = app.config['TRENDING_INTERVALS_IN_DAYS']
        assert trend_days in trend_days_choice, 'Trend days count must be one of {}'.format(trend_days_choice)
        # document field the trending apps are sorted by
        self.trend_field = app.config['TRENDING_MONGO_FIELD'].format(trend_days)

        trending_by_choice = app.config['TRENDING_COUNT_STATS']
        assert trending_by in trending_by_choice, \
            'Trending event must be one of {}'.format(trending_by_choice)
        self.mongo_collection = app.config['TRENDING_MONGO_COLLECTION'].format(trending_by)

        if average_daily_counts_threshold is not None:
            self.average_daily_counts_threshold = average_daily_counts_threshold
        elif trending_by == 'launches':
            # ignore apps with less than 100 average launches a day
            self.average_daily_counts_threshold = 100.0
        else:
            # ignore apps with less than 10 average installs a day
            self.average_daily_counts_threshold = 10.0

        self.top_n = top_n
        self.input_frame = input_frame
        self.output_frame = output_frame
        super(OnlineTrendingSelectionBlock, self).__init__(
            input_data=[self.input_frame], output_data=[self.output_frame], destroyed_data=None
        )

    def apply(self, context, train):
        """
        Predict phase: write trending apps combined with the target frame (via
        ``get_predictions_frame``) into the output frame. Train phase: return the context untouched.
        """
        if train:
            return context

        @fast_cache.memoize(5 * 60)
        def get_trending_data(mongo_collection, mean_counts_threshold, trend_field, top_n):
            # at most `top_n` apps above the threshold, highest `trend_field` first
            trending_apps = jafar_mongo.db[mongo_collection].find({
                'average_daily_counts': {'$gt': mean_counts_threshold}}).sort(trend_field, -1).limit(top_n)
            # builtin `object` instead of `np.object`, which was deprecated in NumPy 1.20
            # and removed in NumPy 1.24 (AttributeError there)
            return DataFrame.from_structarray(np.array(
                [(item['item'],) for item in trending_apps],
                dtype=[('item', object)])
            )

        target = context.data[self.input_frame]
        trending_data = get_trending_data(
            self.mongo_collection, self.average_daily_counts_threshold, self.trend_field, self.top_n)
        context.data[self.output_frame] = get_predictions_frame(target, trending_data)
        return context


class EstimatorSelectionBlock(EstimatorBlockMixin, BaseSelectionBlock):
    """
    Selection block that asks an estimator for the top-n candidates of every distinct target user.

    :param estimator_class: estimator class to instantiate; when its ``basket_required`` flag is
        set, the basket frame becomes an additional input of this block.
    :param top_n: number of candidates requested per user.
    :param estimator_params: parameters handed to ``create_estimator`` (``{}`` when falsy).
    :param drop_value: whether to drop the ``value`` column from the estimator output.
    """

    def __init__(self, estimator_class, top_n, estimator_params=None, drop_value=True,
                 target_frame=ids.FRAME_KEY_TARGET, predictions_frame=ids.FRAME_KEY_PREDICTIONS,
                 *args, **kwargs):
        self.estimator_class = estimator_class
        self.top_n = top_n
        self.drop_value = drop_value
        self.estimator_params = estimator_params or {}
        self.target_frame = target_frame
        self.predictions_frame = predictions_frame

        inputs = [self.target_frame]
        if estimator_class.basket_required:
            inputs.append(ids.FRAME_KEY_BASKET)

        # NOTE: super() deliberately starts *after* BaseSelectionBlock in the MRO, so that
        # BaseSelectionBlock.__init__ is skipped and its input_data requirements are replaced by `inputs`
        super(BaseSelectionBlock, self).__init__(
            input_data=inputs, output_data=[self.predictions_frame], destroyed_data=None,
            *args, **kwargs
        )

    def get_estimator_predictions(self, context, estimator, target, basket):
        """Ask ``estimator`` for ``self.top_n`` items per row of ``target``."""
        return estimator.predict_top_n(target, self.top_n, basket=basket)

    def create_prediction_frame(self, context):
        """Run the estimator on the first target row of every distinct user."""
        targets = context.data[self.target_frame]
        if self.estimator_class.basket_required:
            basket = context.data[ids.FRAME_KEY_BASKET]
        else:
            basket = None

        first_row_per_user = targets['user'].unique(return_index=True)[1]
        per_user_targets = targets[first_row_per_user]
        estimator = self.create_estimator(context, self.estimator_class, self.estimator_params)
        result = self.get_estimator_predictions(context, estimator, per_user_targets, basket)
        assert isinstance(result, DataFrame), "%r not returning Dataframe" % self.estimator_class

        return result.drop_columns(['value']) if self.drop_value else result


class CategoryAwareEstimatorSelectionBlock(EstimatorSelectionBlock):
    """
    Estimator selection block for category-aware estimators: candidates are requested per group
    feature value, optionally restricted to ``context.requested_categories``.
    """

    def get_estimator_predictions(self, context, estimator, target, basket):
        """Call ``predict_top_n_per_group_feature``, limiting groups to the requested categories if any."""
        call_args = {'X': target, 'n': self.top_n, 'basket': basket}

        # NOTE: category-aware estimators accept an explicit list of categories to recommend from
        categories = context.requested_categories
        if categories is not None:
            call_args['group_feature_values'] = [(category,) for category in categories]

        return estimator.predict_top_n_per_group_feature(**call_args)


class BasketNeighborSelectionBlock(EstimatorSelectionBlock):
    """
    Selects candidate items by taking all neighbors of the basket items.

    Not every estimator can act as a nearest-neighbor searcher: the estimator class must provide
    a ``get_neighbors`` method (currently, in practice, subclasses of the BaseKNN estimator).
    """

    # noinspection PyUnusedLocal
    def __init__(self, estimator_class, estimator_params=None, *args, **kwargs):
        # NOTE: `estimator_params` is accepted but not forwarded (the noinspection marker above
        # suggests this is intentional); create_prediction_frame builds the estimator without params.
        assert hasattr(estimator_class, 'get_neighbors'), \
            "Estimator {} doesn't support neighbor selection".format(estimator_class)
        super(BasketNeighborSelectionBlock, self).__init__(
            estimator_class, estimator_params=None, top_n=None,  # number of neighbors is determined by estimator
            *args, **kwargs
        )

    def apply(self, context, train):
        """
        Train phase: the prediction frame becomes a copy of the target frame (as in BaseSelectionBlock).
        Predict phase: the prediction frame is replaced by the basket neighbors.
        """
        if train:
            context.data[self.predictions_frame] = context.data[self.target_frame].copy()
            return context
        context.data[self.predictions_frame] = self.create_prediction_frame(context)
        return context

    def create_prediction_frame(self, context):
        """Return unique (user, item) neighbors of the basket items."""
        basket = context.data[ids.FRAME_KEY_BASKET]
        # restrict items for neighbor searching by requested categories;
        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0 (same result for 1-D input)
        if context.requested_categories is not None:
            basket = basket[np.isin(basket['category'], context.requested_categories)]

        estimator = self.create_estimator(context, self.estimator_class)
        predictions = estimator.get_neighbors(basket)
        # NOTE: predictions will contain a special 'similar_to' reference column
        predictions = self.leave_unique_predictions(predictions)
        return predictions


class HybridEstimatorSelectionBlock(EstimatorSelectionBlock):
    """
    Estimator selection block for hybrid estimators that also consume user and item features.

    :param user_features_frame: context key of the user features frame (None → no user features).
    :param item_features_frame: context key of the item features frame (None → no item features).
    """

    def __init__(self, estimator_class, estimator_params=None, top_n=None,
                 user_features_frame=None, item_features_frame=None, *args, **kwargs):
        # EstimatorSelectionBlock.__init__ takes (estimator_class, top_n, estimator_params, ...), i.e. a
        # different order than this signature. Forward in the parent's order so top_n and
        # estimator_params are not swapped; extra positional args still map to drop_value, etc.
        super(HybridEstimatorSelectionBlock, self).__init__(
            estimator_class, top_n, estimator_params, *args, **kwargs
        )
        self.user_features_frame = user_features_frame
        self.item_features_frame = item_features_frame

    def get_estimator_predictions(self, context, estimator, target, basket):
        """Ask for ``self.top_n`` items per target row, passing the feature frames found in the context."""
        return estimator.predict_top_n(
            X=target,
            n=self.top_n,
            user_features=context.data.get(self.user_features_frame),
            item_features=context.data.get(self.item_features_frame),
            basket=basket
        )


def is_yandex_phone(user_features):
    """
    Tell whether the first row of ``user_features`` describes a Yandex phone.

    The (device_manufacturer, device_model) pair of row 0 is looked up in
    ``app.config['YPHONE_DEVICE_MODELS']``; an empty frame is never a Yandex phone.
    """
    if not len(user_features):
        return False
    model = user_features['device_model'][0]
    manufacturer = user_features['device_manufacturer'][0]
    known_devices = app.config['YPHONE_DEVICE_MODELS']
    return (manufacturer, model) in known_devices


class AvailableGiftsSelectionBlock(BaseSelectionBlock):
    """
    Selects all not yet activated gifts of the requesting user; ``top_n`` is ignored.

    The prediction frame has ``user``, ``id``, ``item`` and ``code`` columns:

     * devices that are not Yandex phones get an empty frame;
     * users without a passport uid get a single launcher-login action row;
     * otherwise there is one row per non-activated gift of the user's ``GiftSet``.
    """

    def create_prediction_frame(self, context):
        target_frame = context.data[self.target_frame]
        assert 'passport_uid' in target_frame, 'No passport_uid in target frame'

        passport_uid = target_frame['passport_uid'][0]
        device_id = target_frame['user'][0]

        # builtin `object` instead of `np.object` (deprecated in NumPy 1.20, removed in 1.24)
        dtype = [('user', object), ('id', object), ('item', object), ('code', object)]

        if not is_yandex_phone(context.data[ids.FRAME_KEY_USER_FEATURES]):
            return DataFrame(np.array([], dtype=dtype))

        if not passport_uid:
            result = [(device_id, LAUNCHER_LOGIN_ACTION_ID, LAUNCHER_LOGIN_ACTION_ITEM, '')]
            return DataFrame(np.array(result, dtype=dtype))

        result = []
        gift_set = GiftSet.objects(passport_uid=passport_uid).first()
        if gift_set is None:
            return DataFrame(np.array(result, dtype=dtype))

        for given_gift in gift_set.gifts:
            if given_gift.activated:
                continue
            try:
                gift = given_gift.gift
            except DoesNotExist as exc:
                # Do not touch `given_gift.gift` again in here: dereferencing the dangling
                # reference a second time would raise the same DoesNotExist out of the handler.
                logger.warning('Referenced gift does not exist (%s), skip', exc)
                continue
            result.append((device_id,
                           gift.id,
                           gift.package_name,
                           given_gift.promocode or ''))

        return DataFrame(np.array(result, dtype=dtype))


class YandexDistributedAppSelectionBlock(BaseSelectionBlock):
    """
    Selects all apps marked as ``show_in_feed`` that also have bonus card resources; ``top_n`` is ignored.

    Devices that are not Yandex phones get an empty frame. Columns: ``user``, ``item`` (package
    name) and ``offer_id``.
    """

    def create_prediction_frame(self, context):
        # builtin `object` instead of `np.object` (deprecated in NumPy 1.20, removed in 1.24)
        dtype = [('user', object), ('item', object), ('offer_id', object)]

        # The memoized function takes no arguments, so its cached result must be user-independent:
        # only (package_name, offer_id) pairs are cached and the user is attached per request below.
        # Named differently from the former `get_distributed_apps`, whose cached tuples had another
        # shape (assumes fast_cache keys entries by function name — TODO confirm).
        @fast_cache.memoize(5 * 60)
        def get_distributed_app_offers():
            apps_with_resources = set(BonusCardResources.objects.distinct('package_name'))
            return [
                (distributed_app.package_name, distributed_app.offer_id)
                for distributed_app in YandexDistributedApp.objects(show_in_feed=True)
                if distributed_app.package_name in apps_with_resources
            ]

        if not is_yandex_phone(context.data[ids.FRAME_KEY_USER_FEATURES]):
            return DataFrame(np.array([], dtype=dtype))

        target_frame = context.data[self.target_frame]
        user = target_frame['user'][0]

        prediction = [(user, package_name, offer_id) for package_name, offer_id in get_distributed_app_offers()]
        return DataFrame(np.array(prediction, dtype=dtype))
