import numpy as np

from jafar.cross_validation import BlendingKFold
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.pipelines.context import TemporaryPipelineContext
from jafar.storages.memory import MemoryStorage
from jafar.utils.structarrays import DataFrame


class BlendingFeatureBlock(SingleContextBlock):
    """
    Block that computes "blended" (out-of-fold) features.

    In train mode the block does the following:
    - splits the unique interactions of the fit/predict frames into N folds
    - applies the fit and predict blocks consecutively to each data chunk
    - collects the feature values produced for the different prediction folds

    In predict mode only the predict blocks are applied.
    """

    def __init__(self, fit_blocks, predict_blocks, features_names,
                 fit_frame, predict_frame, interaction_fields=('user', 'item'),
                 n_folds=3, fit_with_full_data=True, with_basket=True):
        """
        :param fit_blocks: a sequence of blocks (normally a single block) that fit internal state of estimator
        :param predict_blocks: a sequence of blocks that calculate the feature based on the fitted estimator
        :param features_names: names of the features appended by predict blocks
        :param fit_frame: name of the frame which is used to fit estimator
        :param predict_frame: name of the frame to append the feature to
        :param interaction_fields: common fields of fit and predict frame used as a basis for split
        :param n_folds: number of folds - the more folds the more train data in each fold
        :param fit_with_full_data: True to also apply the fit blocks to the initial (full) context,
            False if we do not want to leave the fitted estimator in persistent storage
        :param with_basket: True to restrict the ids.FRAME_KEY_BASKET frame to the fit rows of each fold
        """

        super(BlendingFeatureBlock, self).__init__(
            input_data=[fit_frame, predict_frame], output_data=None, destroyed_data=None
        )
        self.fit_frame = fit_frame
        self.predict_frame = predict_frame
        # Stored as lists (like features_names / interaction_fields below) so that
        # `fit_blocks + predict_blocks` works even when the caller passes tuples
        # or two different sequence types.
        self.fit_blocks = list(fit_blocks)
        self.predict_blocks = list(predict_blocks)
        self.features_names = list(features_names)
        self.interaction_fields = list(interaction_fields)
        self.n_folds = n_folds
        self.fit_with_full_data = fit_with_full_data
        self.with_basket = with_basket

    def get_blocks(self):
        """
        :return: a new list with the fit blocks followed by the predict blocks
        """
        return self.fit_blocks + self.predict_blocks

    def blend_indices_generator(self, fit_data, predict_data):
        """
        Splitting so that data from fit data is not leaked to predict data.

        :param fit_data: frame used for fitting; must contain the interaction fields
        :param predict_data: frame used for prediction; must contain the interaction fields
        :return: generator of (fit_mask, predict_mask) pairs, one pair per fold; fit_mask marks
            the rows of fit_data whose interaction belongs to the fold's fit part, predict_mask
            marks the rows of predict_data whose interaction does NOT belong to it
        """
        # find unique user/apps interactions
        interactions = DataFrame.concatenate([fit_data[self.interaction_fields], predict_data[self.interaction_fields]])
        interactions = interactions[self.interaction_fields].unique()

        # generate fit/predict data values (similar to train/predict for cross-validation)
        # TODO: should cross-validation iterator be customizable here?
        # NOTE(review): np.in1d is deprecated since NumPy 2.0 in favour of np.isin -
        # verify the equivalence for these frames before switching.
        for fit_idx, _ in BlendingKFold(interactions, self.n_folds):
            fit_interactions = interactions[fit_idx]
            yield (
                np.in1d(fit_data[self.interaction_fields], fit_interactions),
                ~np.in1d(predict_data[self.interaction_fields], fit_interactions)
            )

    def get_isolated_pipeline(self, blocks, context):
        """
        Takes a bunch of blocks and wraps them inside a separate Pipeline
        that reuses the name of the context's pipeline and the context's storage.

        :param blocks: blocks the new pipeline consists of
        :param context: pipeline context providing the pipeline name and the storage
        :return: the newly created Pipeline
        """
        # imported locally rather than at module level
        # (presumably to avoid a circular import - TODO confirm)
        from jafar.pipelines.pipeline import Pipeline

        return Pipeline(blocks, name=context.pipeline.name, storage=context.storage)

    def apply(self, context, train):
        """
        Applies the block to the context.

        :param context: pipeline context holding the fit and predict frames in context.data
        :param train: if true, run the blending loop over the folds; otherwise only apply the predict blocks
        :return: in train mode the given context with the predict frame replaced by a frame that has
            the blended feature columns; otherwise the result of applying the predict blocks
        """
        if not train:
            # just apply blocks to predict the feature
            blend_predict_pipeline = self.get_isolated_pipeline(self.predict_blocks, context)
            return blend_predict_pipeline.apply_blocks(train=train, initial_context=context)

        # keep the initial context and apply fit steps to it in case we want to have fitted estimator
        if self.fit_with_full_data:
            blend_fit_pipeline = self.get_isolated_pipeline(self.fit_blocks, context)
            blend_fit_pipeline.apply_blocks(train=train, initial_context=context)

        # splitting installs/predictions into N folds
        fit_data = context.data[self.fit_frame]
        predict_data = context.data[self.predict_frame]
        # add zero-filled feature columns to the predict frame; they are filled in fold by fold below
        predict_data = predict_data.append_columns(
            DataFrame.from_dict({feature: np.zeros(len(predict_data), np.float32)
                                 for feature in self.features_names})
        )

        # run blending loop
        blend_pipeline = self.get_isolated_pipeline(self.fit_blocks + self.predict_blocks, context)
        for fit_idx, predict_idx in self.blend_indices_generator(fit_data, predict_data):
            # create temporary context for blend pass based on split data and initial context
            blend_data = context.data.copy()
            blend_data.update({
                self.fit_frame: fit_data[fit_idx],
                self.predict_frame: predict_data[predict_idx]
            })

            if self.with_basket:
                # NOTE(review): fit_idx is a mask built for fit_frame; this assumes the basket
                # frame is row-aligned with fit_frame - confirm against callers.
                blend_data[ids.FRAME_KEY_BASKET] = blend_data[ids.FRAME_KEY_BASKET][fit_idx]

            # use temporary storage - we do not want persistently store estimators trained on part of the data
            blend_context = TemporaryPipelineContext(
                context.pipeline, context.country, context.requested_categories,
                data=blend_data
            )
            # copy existing stuff from the main storage
            blend_context.temporary_storage = MemoryStorage.from_dict(context.storage.to_dict())

            result_context = blend_pipeline.apply_blocks(train=train, initial_context=blend_context)
            # collect fold predictions and assign them to corresponding indices
            for feature_name in self.features_names:
                predict_data[feature_name][
                    predict_idx
                ] = result_context.data[self.predict_frame][feature_name]

        # update context data
        context.data[self.predict_frame] = predict_data
        return context
