import logging

import numpy as np

from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.utils import add_implicit_negatives
from jafar.utils.structarrays import DataFrame

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Default value for the `implicit_negative_factor` constructor argument:
# 1 means "sample as many implicit negatives as there are true positives".
DEFAULT_IMPLICIT_NEGATIVE_FACTOR = 1


class BaseTargetBlock(SingleContextBlock):
    """
    Common base for blocks that build the main `target` frame, i.e. the frame
    that serves as an input to a classifier. The frame has to carry a target
    column/variable (binary, for now) so that the classifier can be trained.

    Subclasses pick the kind of user-item interaction frame they consume by
    implementing `create_target_frame`.
    """

    def __init__(self, input_frame, target_frame):
        """
        :param input_frame: key of the frame this block declares as its input
        :param target_frame: key under which the created target frame is stored
        in `context.data`
        """
        super(BaseTargetBlock, self).__init__(
            input_data=[input_frame],
            output_data=[target_frame],
            destroyed_data=None,
        )
        self.input_frame = input_frame
        self.target_frame = target_frame

    def apply(self, context, train):
        """
        Store the freshly created target frame in `context.data` when training.
        On the prediction pass there is no target, so the context is returned
        untouched.

        :param context: pipeline context exposing a `data` mapping
        :param train: truthy for the training pass, falsy for prediction
        :return: the same `context` object
        """
        if train:
            context.data[self.target_frame] = self.create_target_frame(context)
        return context

    def create_target_frame(self, context):
        """Build and return the target frame; must be overridden by subclasses."""
        raise NotImplementedError


class ImplicitNegativeTargetBlockMixin(object):
    """
    Target-frame builder for inputs that contain only positive user-item
    interactions (for example, items users install). To complete the target
    frame, "implicit negative" targets are sampled: for a given user, items
    this user has _no_ interactions with.

    The mixin forwards `input_frame` and `target_frame` to the next class in
    the MRO and reads `self.input_frame`, which it does not set itself.
    """

    def __init__(self, input_frame, target_frame=ids.FRAME_KEY_TARGET,
                 implicit_negative_factor=DEFAULT_IMPLICIT_NEGATIVE_FACTOR):
        """
        :param input_frame: key of the frame holding the positive interactions
        :param target_frame: key under which the target frame is stored
        :param implicit_negative_factor: proportion of implicit negative samples to choose.
        default value is 1 (choose exactly the same amount as true positive items)
        """
        super(ImplicitNegativeTargetBlockMixin, self).__init__(
            input_frame=input_frame,
            target_frame=target_frame,
        )
        self.implicit_negative_factor = implicit_negative_factor

    def create_target_frame(self, context):
        """
        Copy the input frame and hand the copy to `add_implicit_negatives`,
        returning whatever that helper produces.

        :param context: pipeline context exposing `data` and `n_items`
        """
        # work on a copy so the frame stored in the context is not passed to the helper
        positives = context.data[self.input_frame].copy()
        return add_implicit_negatives(
            positives,
            n_items=context.n_items,
            factor=self.implicit_negative_factor,
            max_attempts=1000,
        )


class ImplicitNegativeTargetBlock(ImplicitNegativeTargetBlockMixin, BaseTargetBlock):
    """
    Concrete target block combining the implicit-negative sampling mixin with
    the base target block; adds no behaviour of its own.
    """


class ConversionTargetBlockMixin(object):
    """
    This block specifically operates on "conversions" dataset. It doesn't require
    implicit negative sampling because conversion dataset already contains negative
    targets, some of which may be deduplicated and discarded.

    The mixin forwards `input_frame` and `target_frame` to the next class in
    the MRO and reads `self.input_frame`, which it does not set itself.
    """

    def __init__(self, input_frame, target_frame, deduplicated=False):
        """
        :param input_frame: key of the conversions frame in `context.data`
        :param target_frame: key under which the target frame is stored
        :param deduplicated: when truthy, collapse the frame to one row per
        user-item pair before returning it
        """
        super(ConversionTargetBlockMixin, self).__init__(input_frame=input_frame, target_frame=target_frame)
        self.deduplicated = deduplicated

    @staticmethod
    def deduplicate_target(frame):
        """
        leave one row for each user-item pair
        'value': 0 - item was recommended, but not installed; 1 - recommended and installed

        A pair gets value 1.0 when the sum of its 'value' entries is positive
        and 0.0 otherwise.

        :param frame: frame with 'user', 'item' and 'value' fields
        :return: new DataFrame with 'user', 'item' and 'value' columns only
        """

        # keep just the user-item-value fields since all other information
        # will be lost after deduplication
        frame = frame[['user', 'item', 'value']]

        rows = [
            (key[0], key[1], 1.0 if np.sum(group['value']) > 0.0 else 0.0)
            for key, group in frame.groupby(['user', 'item'])
        ]
        if rows:
            # Unpack instead of indexing: on Python 3 `zip` returns an iterator
            # that cannot be subscripted; unpacking works on Python 2 and 3.
            users, items, values = zip(*rows)
        else:
            # No groups at all: `zip(*[])` yields nothing, so fall back to
            # empty columns instead of failing on a missing element.
            users, items, values = (), (), ()
        return DataFrame.from_dict(
            {'user': users, 'item': items, 'value': values},
            dtype=frame.dtype
        )

    def create_target_frame(self, context):
        """
        Return a copy of the input frame, deduplicated if requested, and log
        the sum of its 'value' column as the positive targets count.

        :param context: pipeline context exposing a `data` mapping
        """
        target_frame = context.data[self.input_frame].copy()
        if self.deduplicated:
            target_frame = self.deduplicate_target(target_frame)
        logger.info('Positive targets count: {}'.format(np.sum(target_frame['value'])))
        return target_frame


class ConversionTargetBlock(ConversionTargetBlockMixin, BaseTargetBlock):
    """
    Concrete conversions target block: the conversion mixin on top of the base
    target block, with `target_frame` defaulting to `ids.FRAME_KEY_TARGET`.
    """

    def __init__(self, input_frame, target_frame=ids.FRAME_KEY_TARGET, deduplicated=False):
        """
        :param input_frame: key of the conversions frame
        :param target_frame: key under which the target frame is stored
        :param deduplicated: forwarded to the mixin's constructor
        """
        ConversionTargetBlockMixin.__init__(
            self,
            input_frame=input_frame,
            target_frame=target_frame,
            deduplicated=deduplicated,
        )
