import logging
from functools import reduce

import numpy as np

from jafar.estimators.impute import FeatureImputation
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.pipelines.blocks.estimator import EstimatorBlockMixin
from jafar.pipelines.misc import extract_feature_names
from jafar.utils.structarrays import get_null_index, DataFrame

logger = logging.getLogger(__name__)


class MultipleFeaturesBlock(SingleContextBlock):
    """
    Base block parameterised by one input frame and a collection of
    feature names.

    :param input_frame: key of the frame in ``context.data`` the block works on
    :param feature_names: feature names specification; stored as is for
        subclasses to interpret
    """

    def __init__(self, input_frame, feature_names, *args, **kwargs):
        # The processed frame is declared as the only input of the block.
        required_frames = [input_frame]
        super(MultipleFeaturesBlock, self).__init__(
            input_data=required_frames, *args, **kwargs
        )
        self.input_frame = input_frame
        self.feature_names = feature_names


class FeatureBlock(SingleContextBlock):
    """
    Base block parameterised by one input frame and a single feature name.

    :param input_frame: key of the frame in ``context.data`` the block works on
    :param feature_name: name of the feature the block deals with
    """

    def __init__(self, input_frame, feature_name, *args, **kwargs):
        # The processed frame is declared as the only input of the block.
        required_frames = [input_frame]
        super(FeatureBlock, self).__init__(
            input_data=required_frames, *args, **kwargs
        )
        self.input_frame = input_frame
        self.feature_name = feature_name


class ConstantFeatureBlock(FeatureBlock):
    """
    Appends a column filled with one constant value to the input frame.

    :param value: value written to every row of the new column
    :param dtype: dtype handed to ``np.full``
    Remaining arguments are forwarded to ``FeatureBlock``.

    NOTE(review): the column is passed to ``append_column`` unconditionally;
    how an already existing column of the same name is treated depends on
    ``append_column`` and is not visible here - confirm.
    """

    def __init__(self, value=None, dtype=None, *args, **kwargs):
        super(ConstantFeatureBlock, self).__init__(*args, **kwargs)
        self.value = value
        self.dtype = dtype

    def apply(self, context, train):
        """
        Add the constant column to the input frame and return the context.
        """
        source = context.data[self.input_frame]
        # One entry per row of the frame, all equal to ``self.value``.
        constant_column = np.full(len(source), self.value, self.dtype)
        context.data[self.input_frame] = source.append_column(
            data=constant_column, name=self.feature_name
        )
        return context


class CustomFeatureBlock(SingleContextBlock):
    """
    Calculates a feature with a custom function applied to every item
    obtained by iterating the input frame.
    """

    def __init__(self, input_frame, feature_name, feature_function):
        """
        :param input_frame: key of the frame in ``context.data`` to process
        :param feature_name: name of the column that receives the result
        :param feature_function: callable invoked once per item of the frame;
            its results are converted to ``np.float32``
        """
        # Run the base initializer and declare the processed frame as input
        # data, consistently with the other blocks of this module (the
        # original skipped the base ``__init__`` entirely).
        super(CustomFeatureBlock, self).__init__(input_data=[input_frame])
        self.input_frame = input_frame
        self.feature_name = feature_name
        self.feature_function = feature_function

    def apply(self, context, train):
        """
        Append the computed feature column to the input frame.

        :param context: pipeline context holding frames in ``context.data``
        :param train: unused by this block
        :return: the same context with the input frame replaced
        """
        frame = context.data[self.input_frame]
        values = [self.feature_function(item) for item in frame]
        feature = np.array(values, dtype=np.float32)
        context.data[self.input_frame] = frame.append_column(feature, self.feature_name)
        return context


class PrecalculatedFeatureBlock(SingleContextBlock):
    """
    Uses a numeric column already present in the frame as a feature.
    """

    def __init__(self, input_frame, feature_name, column=None, missing_value=None):
        """
        :param input_frame: key of the frame in ``context.data`` to process
        :param feature_name: feature name
        :param column: column name (if differs from feature name)
        :param missing_value: when not None, written to the positions that
            ``get_null_index`` reports for the source column
        """
        # Run the base initializer and declare the processed frame as input
        # data, consistently with the other blocks of this module (the
        # original skipped the base ``__init__`` entirely).
        super(PrecalculatedFeatureBlock, self).__init__(input_data=[input_frame])
        self.input_frame = input_frame
        self.feature_name = feature_name
        self.column = column or feature_name
        self.missing_value = missing_value

    def apply(self, context, train):
        """
        Copy the source column as ``np.float32`` into a column named
        ``feature_name``.

        :param context: pipeline context holding frames in ``context.data``
        :param train: unused by this block
        :return: the same context with the input frame replaced
        :raises AssertionError: if the column is absent or not numeric
        """
        frame = context.data[self.input_frame]
        assert self.column in frame.dtype.names, \
            "Feature '{}' (column '{}') not found in {}".format(self.feature_name, self.column, self.input_frame)
        assert np.issubdtype(frame[self.column].dtype, np.number), \
            "Feature '{}' (column '{}') must be numeric".format(self.feature_name, self.column)

        # astype produces a copy, so the source column is left untouched
        # when missing entries are overwritten below.
        feature = frame[self.column].astype(np.float32)
        if self.missing_value is not None:
            # Nulls are looked up on the original column, not on the cast copy.
            missing_idx = get_null_index(frame[self.column])
            feature[missing_idx] = self.missing_value

        context.data[self.input_frame] = frame.append_column(feature, self.feature_name)
        return context


class AggregatedFeatureBlock(FeatureBlock):
    """
    Combines several feature columns into a single one with a user supplied
    function (mean, max, argmax, a custom function etc.).
    """

    def __init__(self, input_features, aggregation_function, *args, **kwargs):
        """
        :param input_features: specification of the features to aggregate,
            resolved against the frame with ``extract_feature_names``
        :param aggregation_function: callable receiving the frame restricted
            to the resolved features; its result is appended as the new
            column. Example: lambda frame: frame.max(axis=1)

        Remaining arguments (input frame, resulting feature name, ...) are
        forwarded to ``FeatureBlock``.
        """
        super(AggregatedFeatureBlock, self).__init__(*args, **kwargs)
        self.aggregation_function = aggregation_function
        self.input_features = input_features

    def apply(self, context, train):
        """
        Append the aggregated column to the input frame and return the context.
        """
        source = context.data[self.input_frame]
        columns = extract_feature_names(source, self.input_features)
        aggregated = self.aggregation_function(source[columns])
        context.data[self.input_frame] = source.append_column(aggregated, self.feature_name)
        return context


class ItemsStoreBlock(MultipleFeaturesBlock):
    """
    Saves selected columns of the item features frame to the context storage
    (training only).
    """

    def __init__(self, input_frame=ids.FRAME_KEY_ITEM_FEATURES, *args, **kwargs):
        """
        :param input_frame: preloaded frame with item features
        """
        super(ItemsStoreBlock, self).__init__(input_frame, *args, **kwargs)

    def apply(self, context, train):
        """
        Store the requested feature columns, ordered by item, under the
        'item_features' key. Does nothing when ``train`` is false.

        :raises AssertionError: if a requested feature is not a column of
            the item features frame
        """
        if not train:
            return context

        storage_key = self.key_for(context, 'item_features')
        items_frame = context.data[self.input_frame]
        requested = extract_feature_names(items_frame, self.feature_names)
        for requested_name in requested:
            assert requested_name in items_frame.dtype.names, \
                '"%s" is not preloaded to item features frame' % requested_name

        # Keep the column order of the frame itself, not of the request.
        ordered_names = [column for column in items_frame.columns if column in requested]
        selected = items_frame[ordered_names]
        # because items are in range(N), features can be saved in simple array
        # so that index of array is item ID
        by_item = np.argsort(items_frame['item'])
        context.storage.store(storage_key, selected[by_item])
        return context


class ItemFeaturesBlock(MultipleFeaturesBlock):
    """
    Appends item feature columns to input frame.
    The features are read from the storage entry 'item_features'
    (the key ItemsStoreBlock writes to).
    """

    def apply(self, context, train):
        """
        Look up stored features by the 'item' column of the input frame and
        append them as columns.
        """
        storage_key = self.key_for(context, 'item_features')
        stored = DataFrame(context.storage.get_proxy(storage_key))
        feature_names = extract_feature_names(stored, self.feature_names)
        frame = context.data[self.input_frame]
        items = frame['item']

        # Both branches select the same rows and columns, only the order of
        # the two indexing steps differs: columns-then-rows consumes less
        # memory but is slow, so it is used at training only.
        if train:
            item_columns = stored[feature_names][items]
        else:
            item_columns = stored[items][feature_names]

        context.data[self.input_frame] = frame.append_columns(item_columns)
        return context


class DetectUserCategoriesBlock(FeatureBlock):
    """
    Fills ``context.requested_categories`` from the basket frame when it is
    empty or None. Does nothing at training time or when categories were
    already requested.
    """

    def __init__(self, input_frame=ids.FRAME_KEY_BASKET, feature_name='category', *args, **kwargs):
        super(DetectUserCategoriesBlock, self).__init__(input_frame, feature_name, *args, **kwargs)

    def apply(self, context, train):
        """
        Set requested categories to the union of the unique values of the
        category column and ``context.default_categories`` (if any).
        """
        if train or context.requested_categories:
            return context

        basket = context.data[self.input_frame]
        detected = set(np.unique(basket[self.feature_name]))
        defaults = set(context.default_categories or [])
        context.requested_categories = list(detected | defaults)
        return context


class FeatureProductBlock(FeatureBlock):
    """
    Multiplies several feature columns element-wise and stores the product
    as a new column.

    :param input_features: iterable of column names to multiply
    :param resulting_feature: name of the column receiving the product
    :param input_frame: key of the frame in ``context.data`` to process
    """

    def __init__(self, input_features, resulting_feature, input_frame=ids.FRAME_KEY_PREDICTIONS):
        super(FeatureProductBlock, self).__init__(input_frame=input_frame, feature_name=resulting_feature)
        self.input_features = input_features

    def apply(self, context, train):
        """
        Append the product column to the input frame and return the context.

        :raises TypeError: from ``reduce`` when ``input_features`` is empty
            (no initial value is supplied)
        """
        array = context.data[self.input_frame]

        # ``reduce`` is the functools version (imported at module level):
        # the builtin exists only on Python 2.
        columns = (array[feature] for feature in self.input_features)
        data = reduce(np.multiply, columns)
        context.data[self.input_frame] = array.append_column(data, self.feature_name)
        return context


class EstimatorFeatureBlock(EstimatorBlockMixin, FeatureBlock):
    """
    Feature block whose values are produced by an estimator.

    The estimator is obtained through ``self.create_estimator`` (not defined
    in this class; presumably supplied by ``EstimatorBlockMixin``) and the
    ``estimator_result_name`` column of its output is appended to the input
    frame under ``feature_name``.

    NOTE: if the estimator is itself some kind of ML model, it
    might be better to use `BlendingFeatureBlock` instead.
    """

    def __init__(self, estimator_class, input_frame, feature_name,
                 estimator_result_name='value', estimator_params=None,
                 *args, **kwargs):
        """
        :param estimator_class: estimator class; must expose a
            ``basket_required`` attribute
        :param input_frame: key of the frame in ``context.data`` to process
        :param feature_name: name of the resulting column
        :param estimator_result_name: column of the estimator output to take
        :param estimator_params: parameters for the estimator (defaults to
            an empty dict)
        :raises ValueError: if ``estimator_class`` has no ``basket_required``
        """
        super(EstimatorFeatureBlock, self).__init__(input_frame, feature_name, *args, **kwargs)
        self.estimator_class = estimator_class
        self.estimator_params = estimator_params or {}
        self.estimator_result_name = estimator_result_name

        # Whether a basket frame is needed depends on the specific estimator.
        if hasattr(estimator_class, 'basket_required'):
            if estimator_class.basket_required:
                self.input_data.append(ids.FRAME_KEY_BASKET)
        else:
            raise ValueError("Estimator {} doesn't have `basket_required` attribute".format(estimator_class))

    def estimate_predictions(self, context, estimator, predictions, basket=None):
        """
        Hook producing the estimator output; subclasses override it to call
        a different estimator method.
        """
        return estimator.predict(predictions, basket=basket)

    def apply(self, context, train):
        """
        Run the estimator on the input frame and append the result column.

        :raises AssertionError: if the 'user' or 'item' columns of the
            estimator output differ from those of the input frame
        """
        logger.debug("Calculating '%s' feature for %s", self.feature_name, context)
        estimator = self.create_estimator(context, self.estimator_class, self.estimator_params)

        # Collect the frames the estimator needs.
        frame = context.data[self.input_frame]
        basket = context.data[ids.FRAME_KEY_BASKET] if self.estimator_class.basket_required else None

        estimated = self.estimate_predictions(context, estimator, frame, basket)

        # The result is attached positionally, so the order of users/items
        # must not have changed.
        for id_column in ('user', 'item'):
            assert np.array_equal(estimated[id_column], frame[id_column])

        result_column = estimated[self.estimator_result_name]
        context.data[self.input_frame] = frame.append_column(result_column, self.feature_name)
        return context


class HybridEstimatorFeatureBlock(EstimatorFeatureBlock):
    """
    Estimator feature block that additionally hands user and item feature
    frames to the estimator.

    :param user_features_frame: key of the user features frame in
        ``context.data`` (looked up with ``get``)
    :param item_features_frame: key of the item features frame in
        ``context.data`` (looked up with ``get``)
    """

    def __init__(self, estimator_class, input_frame, feature_name,
                 user_features_frame=None, item_features_frame=None,
                 *args, **kwargs):
        super(HybridEstimatorFeatureBlock, self).__init__(
            estimator_class, input_frame, feature_name, *args, **kwargs
        )
        self.user_features_frame = user_features_frame
        self.item_features_frame = item_features_frame

    def estimate_predictions(self, context, estimator, predictions, basket=None):
        """
        Call ``estimator.predict`` with the feature frames found in the context.
        """
        user_features = context.data.get(self.user_features_frame)
        item_features = context.data.get(self.item_features_frame)
        return estimator.predict(
            predictions,
            user_features=user_features,
            item_features=item_features,
            basket=basket
        )


class ImputeItemImportanceBlock(EstimatorFeatureBlock):
    """
    Estimator feature block that takes its values from the estimator's
    ``item_importance`` method.

    :param item_feature_name: name of the resulting column (passed to the
        parent as ``feature_name``)
    :param user_feature_name: second argument of ``item_importance``
    :param estimator_class: estimator class, ``FeatureImputation`` by default
    :param input_frame: key of the frame in ``context.data`` to process
    """

    def __init__(self, item_feature_name, user_feature_name, estimator_class=FeatureImputation,
                 input_frame=ids.FRAME_KEY_PREDICTIONS, *args, **kwargs):
        super(ImputeItemImportanceBlock, self).__init__(
            estimator_class=estimator_class,
            input_frame=input_frame,
            feature_name=item_feature_name,
            *args, **kwargs
        )
        self.user_feature_name = user_feature_name

    def estimate_predictions(self, context, estimator, predictions, basket=None):
        """
        Return ``estimator.item_importance`` for the predictions; the basket
        is not used.
        """
        user_feature = self.user_feature_name
        return estimator.item_importance(predictions, user_feature)


class ALSExplanationFeatureBlock(EstimatorFeatureBlock):
    """
    Estimator feature block that reads the 'similar_to' column of the
    estimator's ``get_explanation`` output.
    """

    def __init__(self, *args, **kwargs):
        super(ALSExplanationFeatureBlock, self).__init__(
            estimator_result_name='similar_to', *args, **kwargs
        )

    def estimate_predictions(self, context, estimator, predictions, basket=None):
        """
        Return the estimator's explanation for the predictions.
        """
        explanation = estimator.get_explanation(predictions, basket=basket)
        return explanation


class DropFeaturesBlock(MultipleFeaturesBlock):
    """
    Removes the given feature columns from the input frame.
    """

    def apply(self, context, train):
        """
        Replace the input frame with a version lacking the resolved columns.
        """
        source = context.data[self.input_frame]
        to_drop = extract_feature_names(source, self.feature_names)
        context.data[self.input_frame] = source.drop_columns(to_drop)
        return context


class ALSVanillaFeatureBlock(EstimatorFeatureBlock):
    """
    Estimator feature block that uses the estimator's ``predict_vanilla``
    method instead of ``predict``.
    """

    def estimate_predictions(self, context, estimator, predictions, basket=None):
        """
        Return ``estimator.predict_vanilla`` for the predictions and basket.
        """
        vanilla = estimator.predict_vanilla(predictions, basket=basket)
        return vanilla


class ALSEmbeddingFeatureBlock(EstimatorFeatureBlock):
    """
    Base block appending embedding values as several columns named
    '<feature_name>_<index>'.

    Subclasses provide the embeddings through ``get_embeddings``.
    """

    def get_embeddings(self, estimator, frame):
        """
        Return the embeddings for ``frame``; must be overridden.
        """
        raise NotImplementedError

    def apply(self, context, train):
        """
        Append one column per entry of the transposed embeddings to the
        input frame (``replace=True`` is passed to ``append_columns``).
        """
        logger.debug("Calculating multiple '%s' features for %s", self.feature_name, context)
        estimator = self.create_estimator(context, self.estimator_class, self.estimator_params)

        frame = context.data[self.input_frame]
        raw_embeddings = self.get_embeddings(estimator, frame)

        # Iterating over ``.T`` yields one entry per embedding component.
        columns = {}
        for number, data in enumerate(raw_embeddings.T):
            columns['{}_{}'.format(self.feature_name, number)] = data
        embedding_frame = DataFrame.from_dict(columns)

        context.data[self.input_frame] = frame.append_columns(embedding_frame, replace=True)
        return context


class ItemALSEmbeddingFeatureBlock(ALSEmbeddingFeatureBlock):
    """
    Embedding block backed by the estimator's ``get_item_features``.
    """

    def get_embeddings(self, estimator, frame):
        item_embeddings = estimator.get_item_features(frame)
        return item_embeddings


class UserALSEmbeddingFeatureBlock(ALSEmbeddingFeatureBlock):
    """
    Embedding block backed by the estimator's ``get_user_features``.
    """

    def get_embeddings(self, estimator, frame):
        user_embeddings = estimator.get_user_features(frame)
        return user_embeddings


class MeanByUserFeatureBlock(MultipleFeaturesBlock):
    """
    For each requested feature appends a column holding the per-user mean
    of that feature.
    """

    def apply(self, context, train):
        """
        Compute means over the groups returned by ``arggroupby('user')`` and
        append them to the input frame with ``replace=False``.
        """
        source = context.data[self.input_frame]
        columns = extract_feature_names(source, self.feature_names)
        selected = source[columns]
        # Uninitialised buffer with the same fields as the selection; every
        # row is expected to be filled by the group loop below.
        means = DataFrame(np.empty(len(source), dtype=selected.dtype))

        for _, rows in source.arggroupby('user'):
            # All rows of a group receive the same tuple of column means.
            means[rows] = tuple(selected[rows].mean(axis=0))

        # NOTE(review): the return value of rename_columns is discarded, so
        # this relies on it renaming in place - confirm.
        renamed = dict((column, 'mean_' + column) for column in means.columns)
        means.rename_columns(renamed)

        context.data[self.input_frame] = source.append_columns(means, replace=False)
        return context
