import numpy as np

from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock


class ImpressionDiscountBlock(SingleContextBlock):
    """
    Implements impression discounting
    as described in this paper:
    http://mitultiwari.net/docs/papers/impression_discounting_kdd2014.pdf

    At inference time the ``value`` field of every prediction is multiplied
    by a factor computed from the counter (``stat_name``) that the usage
    stats frame holds for the same (user, item) pair. Pairs that are missing
    from the usage stats are treated as having a count of zero.
    """
    # Names of the supported discount functions (see ``get_discounts``).
    discount_model_choices = 'exp', 'logexp', 'inverse'
    # TODO: next values are empirically calculated by hand (this has to be automated)
    defaults = {
        'model': 'inverse',
        'params': (0.0001, 0.0033)  # regression weights: w0, w1
    }
    # Factor assigned to entries whose view count is zero.
    min_discount = 0.0035

    def __init__(self, discount_model=None, discount_params=None, stat_name='rec_view_count', max_view_count=None):
        """
        :param discount_model: one of ``discount_model_choices``.
        :param discount_params: pair of regression weights ``(w0, w1)``.
            When both ``discount_model`` and ``discount_params`` are omitted
            the class-level ``defaults`` are used; otherwise only
            ``discount_model`` is validated here.
        :param stat_name: name of the usage stats column holding the counts.
        :param max_view_count: when not None, predictions whose count is
            greater than or equal to this value are dropped in ``apply``.
        """
        super(ImpressionDiscountBlock, self).__init__(
            input_data=[ids.FRAME_KEY_USAGE_STATS, ids.FRAME_KEY_PREDICTIONS],
        )
        if not discount_model and not discount_params:
            # assume defaults
            discount_model = self.defaults['model']
            discount_params = self.defaults['params']
        else:
            assert discount_model in self.discount_model_choices, \
                'discount_model expected one of {}, got {} instead'.format(self.discount_model_choices, discount_model)
        self.discount_model = discount_model
        self.discount_params = discount_params
        self.stat_name = stat_name
        self.max_view_count = max_view_count

    def get_discounts(self, view_counts):
        """
        Compute the discount factor for every entry of ``view_counts``.

        :param view_counts: array (or array-like) of view counts.
        :return: ``np.float32`` array of the same shape. Entries with a zero
            count get ``min_discount``; the others get the value of the
            configured model:
            ``exp``: exp(w1 * n + w0),
            ``logexp``: exp(w1 * log(n) + w0),
            ``inverse``: w1 / n + w0.
        :raises ValueError: if ``self.discount_model`` is not one of
            ``discount_model_choices``.
        """
        view_counts = np.asarray(view_counts)
        w0, w1 = self.discount_params
        if self.discount_model not in self.discount_model_choices:
            raise ValueError(
                'discount_model expected one of {}, got {} instead'.format(
                    self.discount_model_choices, self.discount_model))

        # Zero view counts always map to ``min_discount``, so start from that
        # value and evaluate the model only where the count is non-zero.
        # This keeps division by zero and log(0) from ever being computed
        # (no RuntimeWarning, no failure under np.seterr(all='raise')).
        discounts = np.full(view_counts.shape, self.min_discount, dtype=np.float64)
        viewed = view_counts != 0
        counts = view_counts[viewed]
        if self.discount_model == 'exp':
            discounts[viewed] = np.exp(w1 * counts + w0)
        elif self.discount_model == 'logexp':
            discounts[viewed] = np.exp(w1 * np.log(counts) + w0)
        else:  # 'inverse'
            discounts[viewed] = w1 / counts + w0
        return discounts.astype(np.float32)

    def apply(self, context, train):
        """
        Discount the predictions stored in ``context``.

        In train mode the context is returned untouched. Otherwise the
        ``value`` field of the predictions frame is scaled in place by the
        factors from ``get_discounts`` and, if ``max_view_count`` is set,
        predictions whose count reached that limit are filtered out.

        :param context: object whose ``data`` mapping holds the usage stats
            and predictions frames (both are indexed by field name and
            iterated record by record here).
        :param train: truthy when the pipeline runs in training mode.
        :return: the same ``context`` object.
        """
        if train:
            # TODO: calculate impression discount model parameters here
            return context
        predictions = context.data[ids.FRAME_KEY_PREDICTIONS]
        usage_stats = context.data[ids.FRAME_KEY_USAGE_STATS]

        assert self.stat_name in usage_stats.dtype.names, \
            '"%s" column not found in usage stats, ' \
            'Check the OnlineUsageStats block for "rec_view_count" in counters being collected' % self.stat_name

        def pair_key(record):
            # (user, item) identifies a record in both frames
            return record['user'], record['item']

        counts_mapping = {pair_key(stat): stat[self.stat_name] for stat in usage_stats}
        # pairs absent from the usage stats count as never viewed
        view_counts = np.array(
            [counts_mapping.get(pair_key(prediction), 0) for prediction in predictions],
            dtype=np.int32,
        )
        predictions['value'] *= self.get_discounts(view_counts)

        if self.max_view_count is not None:
            # keep only predictions that stayed below the view limit
            context.data[ids.FRAME_KEY_PREDICTIONS] = predictions[view_counts < self.max_view_count]

        return context
