import numpy as np
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.preprocessing import LabelEncoder

from jafar.estimators.base import BaseEstimator
from jafar.utils.structarrays import DataFrame


def sigmoid(x):
    """
    Logistic sigmoid 1 / (1 + exp(-x)), as used in logistic regression.

    Computed in place: the array ``x`` is overwritten with the result, and the same
    array object is returned. ``x`` must be a float ndarray, because the in-place
    ``exp`` cannot write into an integer array.
    """
    np.negative(x, out=x)
    np.exp(x, out=x)
    np.add(x, 1, out=x)
    np.reciprocal(x, out=x)
    return x


def logit_item_importance(classifier):
    """
    Calculates item importance as the class prediction for a sample in which only one item is present.

    Each item j is scored as the sample whose only non-zero feature is item j (with value 1).
    The classifier's own ``predict_proba`` is used for this. That keeps the result consistent
    with how ``FeatureImputation.impute`` produces probabilities, including:
    - the binary case (one coefficient row, two probability columns);
    - the multinomial / one-vs-rest choice.

    :param classifier: fitted classifier exposing ``coef_`` (rows x n_items) and ``predict_proba``.
    :return: C-contiguous array of shape (n_classes, n_items); entry [k, j] is the predicted
        probability of class k (in ``classifier.classes_`` order) for a sample containing only item j.
    """
    from scipy import sparse

    n_items = classifier.coef_.shape[1]
    # One one-hot row per item; kept sparse so memory stays O(n_items).
    one_hot_items = sparse.identity(n_items, format='csr')
    proba = np.asarray(classifier.predict_proba(one_hot_items))  # (n_items, n_classes)
    # Stored layout is (n_classes, n_items) so that importance[class_idx, items] works.
    return np.ascontiguousarray(proba.T)


def normalize_classes(prob):
    """
    Divides every row of ``prob`` by that row's sum, in place, and returns the same array.

    Rows are taken along axis 0; for a 2-D array each row sums to 1 afterwards.
    """
    row_totals = np.sum(prob, axis=1)
    np.divide(prob, row_totals.reshape((len(prob), -1)), out=prob)
    return prob


class FeatureImputation(BaseEstimator):
    """
    Imputes missing user features from users' positive user-item interactions.

    A classifier (``classifier_class``) is fitted on the sparse matrix of positive interactions.
    It predicts which of ``feature_columns`` a user belongs to. Only users whose feature values
    pass ``threshold`` are used as training examples. The fitted coefficients, intercepts and
    per-item importance are persisted through ``self.storage``.
    """
    basket_required = False
    classifier_class = LogisticRegression

    def __init__(self, feature_columns=None, threshold=0.8, classifier_params=None, **kwargs):
        """
        :param feature_columns: list of feature column names in the user features frame.
            A single column is treated as a binary feature. With several columns,
            each column is one class.
        :param threshold: confidence threshold for selecting training users.
            Single column: users with value > threshold or < 1 - threshold are used.
            Several columns: users whose maximal column value exceeds threshold are used.
        :param classifier_params: keyword arguments for ``classifier_class``; defaults to ``{'C': 3}``.
        :param kwargs: passed through to ``BaseEstimator``.
        """
        super(FeatureImputation, self).__init__(**kwargs)
        self.feature_columns = feature_columns
        self.threshold = threshold
        self.classifier_params = classifier_params or {'C': 3}

    def fit(self, X, y):
        """
        :param X: user-item interactions frame with 'user', 'item' and 'value' columns.
        :param y: user features frame that contains 'user' and all ``feature_columns``.
        :return: self
        :raises AssertionError: if the training users do not cover every expected class.
        """
        X.assert_has_columns(('user', 'item', 'value'))
        y.assert_has_columns(['user'] + self.feature_columns)

        self.classes = np.array(range(len(self.feature_columns)))

        # Leave only positive interactions
        X = X[(X['value'] == 1)]
        interactions_matrix = self.construct_sparse_matrix(X['user'], X['item'], X['value'])

        # `mask` is indexed by user id (see _create_target); keep only rows of users with a target.
        target, mask = self._create_target(y)
        interactions_matrix = interactions_matrix[mask, :]

        classifier = self.classifier_class(**self.classifier_params)
        classifier.fit(interactions_matrix, target)
        # noinspection PyTypeChecker
        self._store_classifier(classifier)
        return self

    def impute(self, X, y):
        """
        Fills in feature values for users that have NaN in every one of ``feature_columns``.

        :param X: user-item interactions frame with 'user' and 'item' columns. If it also has
            a 'value' column, only rows with value == 1 are used.
        :param y: user features frame; it is modified in place.
        :return: ``y``, with predicted class probabilities written into ``feature_columns``
            for the imputed users (unchanged if there is nothing to impute).
        """
        mask = np.ones(len(y), dtype=bool)
        for column in self.feature_columns:
            mask &= np.isnan(y[column])
        if not np.any(mask):  # nothing to impute
            return y
        impute_users = y['user'][mask]
        X.assert_has_columns(('user', 'item'))

        # leave only positive items and users to be imputed
        x_mask = np.isin(X['user'], impute_users)
        if 'value' in X.dtype.names:
            x_mask &= (X['value'] == 1)
        X = X[x_mask]

        # compress users to reduce sparse matrix size
        user_encoder = LabelEncoder().fit(impute_users)
        impute_users = user_encoder.transform(impute_users)
        users = user_encoder.transform(X['user'])

        n_users = len(user_encoder.classes_)
        interactions_matrix = self.construct_sparse_matrix(
            users=users,
            items=X['item'],
            values=np.ones(len(X)),
            shape=(n_users, self.n_items))[impute_users, :]
        classifier = self._load_classifier()
        impute = classifier.predict_proba(interactions_matrix)
        # For a single (binary) feature, column 0 is the class _create_target assigns to
        # users with feature > 0.5, so the written value is the probability of a high value.
        for idx, column in enumerate(self.feature_columns):
            y[column][mask] = impute[:, idx]
        return y

    @staticmethod
    def calc_item_importance(classifier):
        """ Override this when using classifier other than LogisticRegression """
        return logit_item_importance(classifier)

    def item_importance(self, predictions, feature):
        """
        Looks up the stored per-item importance for one feature.

        :param predictions: frame with 'user' and 'item' columns.
        :param feature: one of ``feature_columns``.
        :return: DataFrame with 'user', 'item' and 'value' (the item's importance for ``feature``).
        :raises ValueError: if ``feature`` is not in ``feature_columns``.
        """
        try:
            idx = self.feature_columns.index(feature)
        except ValueError:
            raise ValueError('%s not in estimator.feature_columns=%r' % (feature, self.feature_columns))
        items = predictions['item']
        importance = self.storage.get_proxy(self.key_for('item_importance'))
        return DataFrame.from_dict({
            'user': predictions['user'],
            'item': items,
            'value': importance[idx, items]
        })

    def _create_target(self, y):
        """
        Builds the training target from confidently labelled users.

        :return: (target, mask). ``mask`` is a boolean array over user ids (length ``n_users``)
            selecting users with a target; ``target`` holds their class labels.
            Single feature: users with value > 0.5 get class 0, the rest class 1.
            Several features: the class is the index of the column with the maximal value.
        :raises AssertionError: if the selected users do not cover every expected class.
        """
        target = np.full(shape=self.n_users, fill_value=np.nan)
        n_features = len(self.feature_columns)
        if n_features == 1:
            feature = y[self.feature_columns[0]]
            y = y[(feature > self.threshold) | (feature < (1 - self.threshold))]
            # noinspection PyTypeChecker
            target[y['user']] = np.where(y[self.feature_columns[0]] > 0.5, 0.0, 1.0)
        else:
            y_filtered = y[np.max(y[self.feature_columns], axis=1) > self.threshold]
            target[y_filtered['user']] = np.argmax(y_filtered[self.feature_columns], axis=1)
        mask = ~np.isnan(target)
        target = target[mask]
        target_classes = len(np.unique(target))
        # A single column is a binary problem (two classes); otherwise one class per column.
        expected_classes = 2 if n_features == 1 else n_features
        if target_classes != expected_classes:
            # Explicit raise instead of `assert` so the check survives `python -O`;
            # the AssertionError type is kept for existing callers.
            raise AssertionError(
                'Could not fit estimator: fit data has no example with one of categories. '
                'Got %d targets for %d feature classes' % (target_classes, n_features))
        return target, mask

    def _store_classifier(self, classifier):
        """Persists item importance, ``coef_`` and ``intercept_`` of a fitted classifier."""
        self.storage.store(self.key_for('item_importance'), self.calc_item_importance(classifier))
        self.storage.store(self.key_for('coef'), classifier.coef_)
        self.storage.store(self.key_for('intercept'), classifier.intercept_)

    def _load_classifier(self):
        """Rebuilds a classifier from the persisted ``coef_`` and ``intercept_`` for prediction."""
        classifier = self.classifier_class(**self.classifier_params)
        classifier.intercept_ = self.storage.get_proxy(self.key_for('intercept'))
        classifier.coef_ = self.storage.get_proxy(self.key_for('coef'))
        # scikit-learn >= 0.22 reads `classes_` in predict_proba (to choose OvR vs multinomial).
        # A binary model has a single intercept/coefficient row but two classes.
        n_rows = np.size(classifier.intercept_)
        classifier.classes_ = np.arange(2 if n_rows == 1 else n_rows)
        return classifier


class FeatureImputationCV(FeatureImputation):
    """FeatureImputation whose regularization strength C is selected by LogisticRegressionCV."""
    classifier_class = LogisticRegressionCV

    def __init__(self, classifier_params=None, **kwargs):
        """
        :param classifier_params: keyword arguments for LogisticRegressionCV; defaults to
            ``{'Cs': <10 values log-spaced from 1e-4 to 1e4>}``.
        :param kwargs: passed through to FeatureImputation.
        """
        # np.logspace takes base-10 exponents: this spans C = 1e-4 .. 1e4.
        classifier_params = classifier_params or {'Cs': np.logspace(-4, 4, 10)}
        super(FeatureImputationCV, self).__init__(classifier_params=classifier_params, **kwargs)
