import logging

from jafar.estimators.impute import FeatureImputation
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock

logger = logging.getLogger(__name__)


class EstimatorBlockMixin(object):
    """
    Mixin with helpers for blocks that build estimators.

    The host class must provide ``key_for(context)``; its result is handed to
    the estimator as ``key_prefix``.
    """

    def create_estimator(self, context, estimator_class, estimator_params=None):
        """
        Build an estimator wired to the dimensions and storage of ``context``.

        :param context: pipeline context; ``n_users``, ``n_items`` and
            ``storage`` are read from it.
        :param estimator_class: callable invoked with the keyword arguments
            ``n_users``, ``n_items``, ``storage`` and ``key_prefix`` plus any
            ``estimator_params``.
        :param estimator_params: optional mapping of extra keyword arguments;
            ``None`` or an empty mapping adds nothing. Keys colliding with the
            fixed arguments raise ``TypeError``.
        :return: whatever ``estimator_class(...)`` returns.
        """
        users = context.n_users
        items = context.n_items
        backend = context.storage
        prefix = self.key_for(context)
        extra = estimator_params if estimator_params else {}
        return estimator_class(
            n_users=users,
            n_items=items,
            storage=backend,
            key_prefix=prefix,
            **extra
        )


class EstimatorBlock(EstimatorBlockMixin, SingleContextBlock):
    """
    Pipeline block that fits an estimator on a single input frame.

    ``input_frame`` is declared as the block's only input data; ``output_data``
    and ``destroyed_data`` are passed to ``SingleContextBlock`` as ``None``.
    """

    def __init__(self, input_frame, estimator_class, estimator_params=None, *args, **kwargs):
        """
        :param input_frame: key in ``context.data`` holding the training data.
        :param estimator_class: estimator type built through ``create_estimator``.
        :param estimator_params: optional extra keyword arguments for the estimator.
        Remaining positional and keyword arguments are forwarded to
        ``SingleContextBlock.__init__``.
        """
        # Extra positionals always bind positionally in the base-class call;
        # listing ``*args`` first just makes that visible.
        super(EstimatorBlock, self).__init__(
            *args,
            input_data=[input_frame],
            output_data=None,
            destroyed_data=None,
            **kwargs
        )
        self.input_frame = input_frame
        self.estimator_class = estimator_class
        self.estimator_params = estimator_params if estimator_params else {}

    def fit(self, context):
        """
        Create a fresh estimator and fit it on ``context.data[self.input_frame]``.

        :return: the fitted estimator.
        """
        model = self.create_estimator(context, self.estimator_class, self.estimator_params)
        model.fit(context.data[self.input_frame])
        return model

    def apply(self, context, train):
        """
        Fit the estimator when ``train`` is truthy; otherwise do nothing.

        The fitted estimator is not attached to ``context`` here; presumably it
        persists its state through ``storage``/``key_prefix`` (TODO confirm).

        :return: ``context``, unchanged by this block.
        """
        if not train:
            return context
        self.fit(context)
        return context


class HybridEstimatorBlock(EstimatorBlock):
    """
    Estimator block that additionally passes optional user/item feature frames
    to ``estimator.fit`` as ``user_features`` / ``item_features``.
    """

    def __init__(self, input_frame, estimator_class,
                 user_features_frame=None, item_features_frame=None, estimator_params=None,
                 **kwargs):
        """
        :param input_frame: key in ``context.data`` holding the interaction data.
        :param estimator_class: estimator type built through ``create_estimator``.
        :param user_features_frame: optional key of the user-features frame.
        :param item_features_frame: optional key of the item-features frame.
        :param estimator_params: optional extra keyword arguments for the estimator.
        :param kwargs: forwarded through ``EstimatorBlock`` to the base block,
            consistent with the other blocks in this module.
        """
        # NOTE(review): only ``input_frame`` is declared as block input data;
        # confirm whether pipeline dependency resolution also needs the
        # feature frames listed.
        super(HybridEstimatorBlock, self).__init__(
            input_frame, estimator_class, estimator_params, **kwargs
        )
        self.user_features_frame = user_features_frame
        self.item_features_frame = item_features_frame

    @staticmethod
    def _optional_frame(context, frame_key):
        """Return ``context.data.get(frame_key)``, or ``None`` when no key is configured."""
        if frame_key is None:
            return None
        return context.data.get(frame_key)

    def fit(self, context):
        """
        Create a fresh estimator and fit it on the input frame plus any
        configured feature frames (missing/unset frames are passed as ``None``).

        :return: the fitted estimator.
        """
        estimator = self.create_estimator(context, self.estimator_class, self.estimator_params)
        estimator.fit(
            context.data[self.input_frame],
            user_features=self._optional_frame(context, self.user_features_frame),
            item_features=self._optional_frame(context, self.item_features_frame)
        )
        return estimator


class FeatureImputationEstimatorBlock(EstimatorBlock):
    """
    Fits an imputation estimator and overwrites the features frame with its output.

    ``fit`` trains on ``X=context.data[input_frame]`` and
    ``y=context.data[features_frame]``. ``apply`` replaces
    ``context.data[features_frame]`` with ``estimator.impute(X=..., y=...)``.
    """

    def __init__(self, input_frame=ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS,
                 estimator_class=FeatureImputation, features_frame=ids.FRAME_KEY_USER_FEATURES,
                 *args, **kwargs):
        """
        :param input_frame: key in ``context.data`` used as ``X``.
        :param estimator_class: imputation estimator type.
        :param features_frame: key in ``context.data`` used as ``y`` and
            overwritten with the imputed result in ``apply``.
        Extra positional arguments continue ``EstimatorBlock``'s positional
        signature (the first one is ``estimator_params``); keyword arguments
        are forwarded unchanged.
        """
        self.features_frame = features_frame
        # Pass input_frame/estimator_class positionally. With keywords plus a
        # non-empty *args, the first extra positional would also bind to
        # ``input_frame`` and raise TypeError ("multiple values").
        super(FeatureImputationEstimatorBlock, self).__init__(
            input_frame, estimator_class, *args, **kwargs
        )

    def fit(self, context):
        """
        Create a fresh estimator and fit it with ``X`` = input frame and
        ``y`` = features frame.

        :return: the fitted estimator.
        """
        features_data = context.data[self.features_frame]
        estimator = self.create_estimator(context, self.estimator_class, self.estimator_params)
        estimator.fit(X=context.data[self.input_frame], y=features_data)
        return estimator

    def apply(self, context, train):
        """
        Impute features and store them back under ``self.features_frame``.

        :param train: when truthy, a new estimator is fitted first.
        :return: ``context`` with ``context.data[self.features_frame]`` replaced.
        """
        if train:
            estimator = self.fit(context)
        else:
            # NOTE(review): this estimator is not fitted here; presumably it
            # restores state saved under its key_prefix during training.
            # Confirm against FeatureImputation.
            estimator = self.create_estimator(context, self.estimator_class, self.estimator_params)
        input_data = context.data[self.input_frame]
        features = context.data[self.features_frame]
        # NOTE(review): features_frame is overwritten in place but is not
        # declared as output data of this block. Confirm the pipeline does not
        # rely on that declaration.
        context.data[self.features_frame] = estimator.impute(X=input_data, y=features)
        return context
