"""
In jafar's loose terminology a "basket" corresponds to items user already owns
(the name is a derivative from "consumer basket"). Some collaborative filtering-like
estimators require this kind of information to be able to recommend candidate items
(often by some version of the nearest neighbor algorithm).

Usually (during offline training and prediction, and online prediction as well) basket
is simply equivalent to "items installed by user". This convention, however, breaks
for cross-validation, when items available at train step become basket at test step.
Since this behavior is highly non-obvious (see https://st.yandex-team.ru/ADVISOR-1012),
the purpose of basket blocks is to make the distinction as explicit as possible.
"""

import logging

from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock

logger = logging.getLogger(__name__)


class BaseBasketBlock(SingleContextBlock):
    """
    Common parent of all basket blocks.

    Registers ``input_frame`` as the block's only input and
    ``ids.FRAME_KEY_BASKET`` as its only output, and remembers the
    input frame key for use by subclasses.
    """

    def __init__(self, input_frame):
        """
        :param input_frame: key of the frame in ``context.data`` from which
            the basket is built.
        """
        consumed_frames = [input_frame]
        produced_frames = [ids.FRAME_KEY_BASKET]
        super(BaseBasketBlock, self).__init__(
            input_data=consumed_frames,
            output_data=produced_frames,
            destroyed_data=None,
        )
        self.input_frame = input_frame


class SimpleBasketBlock(BaseBasketBlock):
    """
    Publishes the input frame as the basket without any caching.

    When ``user_apps_only`` is set and the frame contains an
    ``'is_user_app'`` field, the basket is the frame indexed by that field.
    """
    def __init__(self, input_frame, user_apps_only=True):
        """
        :param input_frame: key of the frame in ``context.data`` to use as basket.
        :param user_apps_only: if true, filter the basket by ``'is_user_app'``
            whenever that field is present in the frame.
        """
        super(SimpleBasketBlock, self).__init__(input_frame=input_frame)
        self.user_apps_only = user_apps_only

    def apply(self, context, train):
        """
        Store the basket under ``ids.FRAME_KEY_BASKET`` and return the context.

        The ``train`` flag is not used by this block.
        """
        basket_key = ids.FRAME_KEY_BASKET
        frame = context.data[self.input_frame]

        # The unfiltered frame is always published first; filtering, if it
        # applies, overwrites it afterwards.
        context.data[basket_key] = frame
        if self.user_apps_only and 'is_user_app' in frame:
            basket = context.data[basket_key]
            context.data[basket_key] = basket[basket['is_user_app']]
        return context


class CrossValidationBasketBlock(BaseBasketBlock):
    """
    This block behaves differently depending on whether it is
    train or test step:

     * at train step, data frame is used as basket and is also cached in block's memory
     * at test step, the cached value is retrieved and used as basket as well

    The cache is keyed by the context object itself, so the same context
    must be passed at both steps (and it has to be hashable).
    """

    def __init__(self, input_frame):
        """
        :param input_frame: key of the frame in ``context.data`` that is
            used as the basket at train step.
        """
        super(CrossValidationBasketBlock, self).__init__(input_frame)
        # Maps a context object to the frame it held at train step.
        self.baskets = {}

    def apply(self, context, train):
        """
        Store the basket under ``ids.FRAME_KEY_BASKET`` and return the context.

        :param context: pipeline context whose ``data`` mapping is updated.
        :param train: true at train step (cache and publish the input frame),
            false at test step (publish the previously cached frame).
        """
        if train:
            basket = context.data[self.input_frame]
            self.baskets[context] = basket
        else:
            basket = self.baskets.get(context)
            if basket is None:
                # Keep the historical behaviour (basket becomes None), but make
                # the otherwise silent cache miss visible in the logs.
                logger.warning(
                    "No basket was cached at train step for this context "
                    "(input frame %r); basket is set to None",
                    self.input_frame,
                )
        context.data[ids.FRAME_KEY_BASKET] = basket
        return context
