import logging
from time import time

import numpy as np
from flask import current_app as app

from jafar.data_providers.launcher.blacklist import blacklist_caches
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock

logger = logging.getLogger(__name__)


class BannedItemsFilteringBlock(SingleContextBlock):
    """
    This block filters FRAME_KEY_PREDICTIONS against multiple
    sources of forbidden/banned items. For now these include:

     * recently (180 days by default) removed apps, if present
     * recently (180 days by default) disliked apps, if present
     * items listed in FRAME_DEFAULT_ITEMS, if present
     * blacklisted items (jafar.data_providers.launcher.blacklist_caches)

    Nothing is filtered in train mode.
    """

    def __init__(self, blacklist_component, disliked_items_ttl=None, removed_items_ttl=None):
        """
        :param blacklist_component: key used to pick a cache from `blacklist_caches`
        :param disliked_items_ttl: how long a dislike stays in effect, in seconds
            (it is subtracted from `time()`); defaults to app.config['DISLIKED_ITEMS_TTL']
        :param removed_items_ttl: how long a removal stays in effect, in seconds;
            defaults to app.config['REMOVED_ITEMS_TTL']
        """
        if disliked_items_ttl is None:
            disliked_items_ttl = app.config['DISLIKED_ITEMS_TTL']
        if removed_items_ttl is None:
            removed_items_ttl = app.config['REMOVED_ITEMS_TTL']
        self.disliked_items_ttl = disliked_items_ttl
        self.removed_items_ttl = removed_items_ttl
        self.blacklist_component = blacklist_component
        super(BannedItemsFilteringBlock, self).__init__(
            input_data=[ids.FRAME_KEY_PREDICTIONS], output_data=None, destroyed_data=None
        )

    def get_disliked_items(self, context):
        """
        Returns items from FRAME_KEY_DISLIKED_ITEMS whose timestamp
        is newer than `disliked_items_ttl` seconds ago.
        """
        disliked_items = context.data[ids.FRAME_KEY_DISLIKED_ITEMS]
        # presumably timestamps are unix seconds, since they are compared to time() -- TODO confirm
        return disliked_items['item'][disliked_items['timestamp'] > time() - self.disliked_items_ttl]

    def get_removed_items(self, context):
        """
        Returns items from FRAME_KEY_REMOVED_ITEMS whose timestamp
        is newer than `removed_items_ttl` seconds ago.
        """
        removed_items = context.data[ids.FRAME_KEY_REMOVED_ITEMS]
        return removed_items['item'][removed_items['timestamp'] > time() - self.removed_items_ttl]

    def get_blacklists(self, context):
        """
        Unlike disliked/removed items, which are read from the context frames,
        blacklists come from the pre-loaded cache selected by `blacklist_component`.
        The cache is queried with the context country and the clids taken from
        FRAME_KEY_CLIDS (an empty dict when that frame is absent).
        """
        if ids.FRAME_KEY_CLIDS in context.data:
            clids_frame = context.data[ids.FRAME_KEY_CLIDS]
            clids = dict(clids_frame[['name', 'value']])
        else:
            clids = {}
        blacklist_cache = blacklist_caches[self.blacklist_component]
        return blacklist_cache.get_blacklists(country=context.country, clids=clids)

    def apply(self, context, train):
        """
        Drops banned rows from FRAME_KEY_PREDICTIONS and returns the context.
        In train mode the context is returned untouched.
        """
        if train:
            return context

        items_to_filter = []
        predictions = context.data[ids.FRAME_KEY_PREDICTIONS]

        # every per-user source is optional: use only the frames present in the context
        if ids.FRAME_KEY_DISLIKED_ITEMS in context.data:
            items_to_filter.append(self.get_disliked_items(context))
        if ids.FRAME_KEY_REMOVED_ITEMS in context.data:
            items_to_filter.append(self.get_removed_items(context))
        if ids.FRAME_DEFAULT_ITEMS in context.data:
            items_to_filter.append(context.data[ids.FRAME_DEFAULT_ITEMS]['item'])

        items = predictions['item']
        # builtin `bool` is used as dtype: the `np.bool` alias was removed in NumPy 1.24
        if items_to_filter:
            items_to_filter = np.concatenate(items_to_filter)
            banned_items_idx = np.isin(items, items_to_filter)
        else:
            banned_items_idx = np.zeros_like(items, dtype=bool)
        for blacklist in self.get_blacklists(context):
            for feature, filter_function in blacklist.get_filtering_rules():
                # a rule can only be checked if its feature column is in predictions
                if feature not in predictions:
                    continue
                filter_function = np.vectorize(filter_function, otypes=[bool])
                banned_items_idx |= filter_function(predictions[feature])
            # NOTE(review): excluded items are un-banned whatever the reason they were
            # banned for (dislikes, removals, earlier blacklists) -- confirm this is intended
            banned_items_idx &= ~np.isin(items, blacklist.get_excluded_items())

        if np.any(banned_items_idx):
            logger.debug("Filtering out items from blacklist: %s rows", np.sum(banned_items_idx))
            predictions = predictions[~banned_items_idx]
        else:
            logger.debug("Nothing filtered")
        context.data[ids.FRAME_KEY_PREDICTIONS] = predictions
        return context


class InstallsFilteringBlock(SingleContextBlock):
    """
    This block cleans up FRAME_KEY_PREDICTIONS by dropping every
    (user, item) pair that is already present in
    FRAME_KEY_ADVISOR_MONGO_INSTALLS (the ones user already have).
    It does nothing in train mode.
    """

    def __init__(self):
        super(InstallsFilteringBlock, self).__init__(
            input_data=[ids.FRAME_KEY_PREDICTIONS], destroyed_data=None
        )

    @staticmethod
    def filter(context, predictions):
        """
        Returns predictions without the (user, item) pairs
        found in the installs frame.
        """
        installs = context.data[ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS]
        predicted_pairs = predictions[['user', 'item']]
        installed_pairs = installs[['user', 'item']]
        already_installed = predicted_pairs.is_in(installed_pairs)
        return predictions[~already_installed]

    def apply(self, context, train):
        # installs are filtered out at prediction time only
        if train:
            return context

        unfiltered = context.data[ids.FRAME_KEY_PREDICTIONS]
        logger.debug("Filtering predictions frame: %s rows", len(unfiltered))
        remaining = self.filter(context, unfiltered)
        logger.debug("Filtering out items which are already in installs frame; %s rows left", len(remaining))
        context.data[ids.FRAME_KEY_PREDICTIONS] = remaining
        return context


class CustomFilteringBlock(SingleContextBlock):
    """
    Passes every input frame through a user-supplied function and
    stores the result back under the same key.

    With `train_only=True` the block is a no-op outside of training;
    otherwise it runs in both modes.
    """

    def __init__(self, input_data, filter_function, train_only=False):
        super(CustomFilteringBlock, self).__init__(
            input_data=input_data, destroyed_data=None
        )
        self.filter = filter_function
        self.train_only = train_only

    def apply(self, context, train):
        # a train-only block has nothing to do at prediction time
        if self.train_only and not train:
            return context

        for frame_key in self.input_data:
            source = context.data[frame_key]
            logger.debug("Filtering %s frame: %s rows", frame_key, len(source))
            filtered = self.filter(source)
            # the function may drop rows, but must keep the frame dtype intact
            assert source.dtype == filtered.dtype, 'Filtering block should not change frame type'
            logger.debug("Filtered %s frame: %s rows left", frame_key, len(filtered))
            context.data[frame_key] = filtered
        return context


class UserSamplingBlock(SingleContextBlock):
    """
    Keeps only the rows of a random subset of users in `input_frame`.
    Sampling happens in train mode only; otherwise the context is untouched.
    """

    def __init__(self, input_frame, sample):
        """
        :param input_frame: key of the frame to sample; it must have a 'user' column
        :param sample: share of distinct users to keep, strictly between 0 and 1
        """
        super(UserSamplingBlock, self).__init__(input_data=[input_frame])
        assert 0 < sample < 1, 'Sample must be between 0 and 1'
        self.input_frame = input_frame
        self.sample = sample

    def apply(self, context, train):
        if train:
            frame = context.data[self.input_frame]
            users = np.unique(frame['user'])
            # the number of users to keep is rounded down, so it may be zero for small frames
            sampled_users = np.random.choice(users, int(len(users) * self.sample), replace=False)
            # np.isin replaces np.in1d, which is deprecated since NumPy 2.0
            context.data[self.input_frame] = frame[np.isin(frame['user'], sampled_users)]
            logger.debug(
                "Sampling %s users from %s: %s rows left (out of %s)",
                self.sample, self.input_frame, len(context.data[self.input_frame]), len(frame)
            )
        return context


class ItemCountFilteringBlock(CustomFilteringBlock):
    """
    Filters out users who have less or equal to `item_count` items,
    i.e. keeps only users with more than `item_count` rows in the frame.
    The block is registered as train-only.
    """

    def __init__(self, input_data, item_count):
        """
        :param input_data: keys of the frames to filter; each must have a 'user' column
        :param item_count: users need strictly more rows than this to be kept
        """
        super(ItemCountFilteringBlock, self).__init__(
            input_data=input_data,
            filter_function=self.filter_by_item_count,
            train_only=True
        )
        self.item_count = item_count

    def filter_by_item_count(self, frame):
        """
        Returns the rows of `frame` that belong to users
        having more than `item_count` rows.
        """
        logger.debug('Dropping users with less or equal to %s items', self.item_count)
        users, counts = np.unique(frame['user'], return_counts=True)
        # np.isin replaces np.in1d, which is deprecated since NumPy 2.0
        idx = np.isin(frame['user'], users[counts > self.item_count])
        return frame[idx]


class UserFeaturesFilteringBlock(SingleContextBlock):
    """
    Leaves in `input_frame` only the rows of users that have a non-NaN
    value for at least one of `user_features` in FRAME_KEY_USER_FEATURES.
    Works in train mode only.
    """

    def __init__(self, input_frame, user_features):
        super(UserFeaturesFilteringBlock, self).__init__(input_data=[input_frame, ids.FRAME_KEY_USER_FEATURES])
        self.input_frame = input_frame
        self.user_features = user_features

    def apply(self, context, train):
        if not train:
            return context

        frame = context.data[self.input_frame]
        features_frame = context.data[ids.FRAME_KEY_USER_FEATURES]
        total_users = len(features_frame)
        logger.debug("Filtering only users that have one of features: %s", ','.join(self.user_features))

        # a user qualifies as soon as one of the requested features is not NaN
        feature_values = features_frame[self.user_features].to_2d_array()
        has_any_feature = np.any(~np.isnan(feature_values), axis=1)
        kept_users = features_frame['user'][has_any_feature]

        filtered = frame[frame['user'].is_in(kept_users)]
        context.data[self.input_frame] = filtered
        logger.debug(
            "Leaving %d out of %d users from %s: %s rows left out of %d",
            len(kept_users), total_users, self.input_frame, len(context.data[self.input_frame]), len(frame)
        )
        return context


class LeaveTopBlock(SingleContextBlock):
    """
    Sorts `input_frame` by `feature_name` in descending order and
    keeps the first `top_n` rows. Runs regardless of the train flag.
    """

    def __init__(self, top_n, input_frame, feature_name='value'):
        super(LeaveTopBlock, self).__init__()
        self.top_n = top_n
        self.input_frame = input_frame
        self.feature_name = feature_name

    def apply(self, context, train):
        source = context.data[self.input_frame]
        # negating the values turns the ascending argsort into a descending order
        descending_order = np.argsort(-source[self.feature_name])
        context.data[self.input_frame] = source[descending_order][:self.top_n]
        return context
