import numpy as np
from scipy import sparse
from sklearn.preprocessing import LabelEncoder

from jafar.estimators.base import GroupFeatureAwareEstimator
from jafar.estimators.knn.base import BaseKNN
from jafar.storages.exceptions import StorageKeyError
from jafar.storages.memory import MemoryStorage
from jafar.utils import get_index_pairs
from jafar.utils.structarrays import DataFrame


class ItemItem(BaseKNN):
    """
    Item-item k-NN recommender that scores items from the baskets users already own.

    Scores are obtained by multiplying a row-normalised users x items basket matrix by an
    items x items similarity matrix rebuilt from the neighbour indices/distances kept in
    `self.storage` (see `reconstruct_item_item_distance_matrix` and `sparsify_basket`).
    """
    basket_required = True

    def _get_prediction_matrix(self, basket, user_encoder):
        """
        :param basket: input/target frame containing items users already own. Should
                       be previously compressed by `compress_basket` method.
        :param user_encoder: an instance of sklearn's LabelEncoder, obtained from `compress_basket` as well.
        :return: sparse users x items prediction matrix compressed along rows (number of rows
                 is equal to number of user_encoder classes)
        """
        basket_items = list(np.int32(np.unique(basket['item'])))
        metric = self.reconstruct_item_item_distance_matrix(basket_items)
        user_item = self.sparsify_basket(basket, n_users=len(user_encoder.classes_))
        # `np.dot` is not aware of scipy sparse matrices; use the sparse matrix's own
        # `dot` so the product is an explicit sparse matrix multiplication.
        return user_item.dot(metric)

    def _get_top_n_predictions_from_frame(self, frame, n):
        """
        Keep at most `n` highest-valued predictions per user and drop zero-valued ones.

        :param frame: prediction frame with 'user', 'item' and 'value' columns
        :param n: maximum number of rows to keep for each user
        :return: frame of `self.predictions_dtype` with the surviving predictions;
                 an empty frame of the same dtype if `frame` has no rows
        """
        if len(frame) == 0:
            # nothing to rank; also avoids concatenating an empty list of chunks below
            return DataFrame.from_structarray(np.empty(0, dtype=self.predictions_dtype))

        # np.lexsort treats the LAST key as the primary one: rows are ordered by user,
        # and within each user by value in descending order (hence the negation).
        idx = np.lexsort((-frame['value'], frame['user']))
        frame = frame[idx]

        result = []
        # then split by user-continuous chunks and leave top-n for each user
        # (each pair is used as the [start, stop) slice bounds of one user's rows)
        index_pairs = get_index_pairs(frame['user'])

        for pair in index_pairs:
            i, j = pair[0], pair[1]
            user_predictions = frame[i: j][:n]
            result.append(DataFrame.from_dict(dict(
                user=user_predictions['user'],
                item=user_predictions['item'],
                value=user_predictions['value']), self.predictions_dtype))
        result = DataFrame.concatenate(result)
        # discard zero-similarity results
        return result[result['value'] != 0]

    def predict(self, X, basket=None):
        """
        Score the (user, item) pairs listed in `X`.

        :param X: frame with 'user' and 'item' columns
        :param basket: frame containing items users already own; when it is None or empty
                       every pair gets a score of 0
        :return: frame with 'user', 'item' and float32 'value' columns, one row per row of `X`
        """
        X.assert_has_columns(('user', 'item'))
        if basket is None or len(basket) == 0:
            # if basket is empty, leave the original user-item frame, but set scores to 0
            return X.append_column(np.zeros(len(X), dtype=np.float32), 'value')

        basket, user_encoder = self.compress_basket(X, basket)
        original_user_ids = X['user'].copy()
        # prediction matrix rows are indexed by encoded user ids, not by the original ones
        compressed_user_ids = user_encoder.transform(X['user']).ravel()
        prediction_matrix = self._get_prediction_matrix(basket, user_encoder)

        if len(X) > 0:
            # pick one matrix cell per requested (user, item) pair
            prediction = np.array(prediction_matrix[np.int32(compressed_user_ids), np.int32(X['item'])]).ravel()
        else:
            prediction = np.array([], dtype=np.float32)
        return DataFrame.from_dict(dict(
            user=original_user_ids,
            item=X['item'],
            value=prediction.astype(np.float32)
        ))

    def predict_top_n(self, X, n, basket=None):
        """
        Return up to `n` best-scored items for every user listed in `X`.

        :param X: frame with a 'user' column
        :param n: number of recommendations to return per user (must be an int)
        :param basket: frame containing items users already own; when it is None or empty
                       an empty prediction frame is returned
        :return: frame of `self.predictions_dtype` with 'user', 'item' and 'value' columns
        """
        X.assert_has_columns(('user',))
        assert isinstance(n, int), 'n must be integer'

        if basket is None or len(basket) == 0:
            # if basket is empty, return an empty frame
            return DataFrame.from_structarray(np.empty(0, dtype=self.predictions_dtype))

        basket, user_encoder = self.compress_basket(X, basket)
        original_user_ids = X['user'].copy()
        # COO format exposes the stored entries as parallel row/col/data arrays
        prediction_matrix = self._get_prediction_matrix(basket, user_encoder).tocoo()

        # make a prediction frame (rows are decoded back to the original user ids)
        predictions = DataFrame.from_dict(dict(
            user=user_encoder.inverse_transform(prediction_matrix.row).ravel(),
            item=prediction_matrix.col,
            value=prediction_matrix.data
        ))
        # select only rows corresponding to requested users
        predictions = predictions[np.in1d(predictions['user'], original_user_ids)]
        return self._get_top_n_predictions_from_frame(predictions, n)

    def reconstruct_item_item_distance_matrix(self, items):
        """
        Rebuild the item-item similarity matrix (aka "metric") for a specific set of items.

        Rows for `items` are filled from the stored neighbour tables with
        similarity = 1 - distance; every item additionally gets a diagonal entry of 1.

        :param items: list of item indices whose neighbour rows should be loaded
        :return: sparse float32 CSR matrix of shape (n_items, n_items)
        """
        items_neighbors = self.storage.get_matrix_rows(self.key_for('knn_indices'), items)
        items_distances = self.storage.get_matrix_rows(self.key_for('knn_distances'), items)

        # prefer the number of neighbours recorded in storage; fall back to the
        # estimator's own setting when no such key was stored
        try:
            k = self.storage.get_object(self.key_for('k'))[0]
        except StorageKeyError:
            k = self.k

        # every requested item contributes k (row, neighbour, similarity) triples
        row_idx = np.repeat(items, k)
        col_idx = items_neighbors.ravel(order='C')
        value = (1 - items_distances).ravel(order='C')

        # add diagonal entries
        row_idx = np.hstack([row_idx, np.arange(self.n_items)])
        col_idx = np.hstack([col_idx, np.arange(self.n_items)])
        value = np.hstack([value, np.ones(self.n_items)])

        # make a matrix
        return sparse.csr_matrix((value, (row_idx, col_idx)), shape=(self.n_items, self.n_items), dtype=np.float32)

    def sparsify_basket(self, basket, n_users):
        """
        Turn a basket frame into a sparse users x items matrix with per-row scaling.

        Each stored entry starts as 1 and every row is then divided by its row sum;
        rows that sum to zero are left untouched.

        :param basket: frame with (encoded) 'user' and 'item' columns
        :param n_users: number of rows of the resulting matrix
        :return: sparse float32 matrix of shape (n_users, self.n_items)
        """
        M = self.construct_sparse_matrix(
            basket['user'], basket['item'], np.ones(len(basket)),
            shape=(n_users, self.n_items)
        ).astype('float32')
        item_counts = M.sum(axis=1)
        item_counts += (item_counts == 0)  # to avoid division by zero
        item_counts = 1.0 / item_counts
        # a sparse column vector so that `multiply` scales each row of M by its own factor
        item_counts = sparse.csc_matrix(item_counts)
        M = M.multiply(item_counts)
        return M


class GroupFeatureAwareItemItem(ItemItem, GroupFeatureAwareEstimator):
    """
    Item-item recommender that fits a separate `ItemItem` model for every value of the
    group feature(s) and merges the per-group neighbour tables into one.
    """
    group_feature_name = None

    def fit(self, X):
        """
        Fit one `ItemItem` sub-estimator per group-feature value and store the merged result.

        Stores 'knn_indices', 'knn_distances', the per-item group index array and 'k'
        in `self.storage`.

        :param X: frame with 'user', 'item', 'value' and the group feature columns;
                  only rows with value == 1 and non-null group features are used
        :return: self
        """
        group_feature_names = self.get_group_features_list()

        X.assert_has_columns(['user', 'item', 'value'] + group_feature_names)
        X = X.drop_null(columns=group_feature_names)

        # leave only positive interactions and relevant columns
        idx = X['value'] == 1
        X = X[['user', 'item', 'value'] + group_feature_names][idx]

        # NOTE: knn_indices are pre-filled with zeros and not MISSING_ID values
        # because they will eventually be converted to sparse matrix (no negative values allowed).
        knn_indices = np.zeros((self.n_items, self.k), dtype=np.int32)
        # Distances are pre-filled with ones: similarity is later computed as
        # `1 - distance`, so unfilled (padding) slots must contribute 0. A zero distance
        # here would turn every padding slot into a similarity of 1 to item 0.
        knn_distances = np.ones((self.n_items, self.k), dtype=np.float32)

        group_feature_encoder = self.create_group_feature_encoder(group_feature_names, X)
        # for each item, keep track of corresponding category
        # this will be used for groupby purposes, so integers are fine
        group_feature_item_indices = np.zeros(self.n_items, dtype=np.int32)

        for group_feature_value, frame in X.groupby(group_feature_names):
            group_feature_value_idx = group_feature_encoder.transform(
                np.array(tuple(group_feature_value), dtype=group_feature_value.dtype))

            # creating a sub-category estimator here. to be able
            # to do that, we need to encode user and items as labels
            item_encoder = LabelEncoder().fit(frame['item'])
            frame['item'] = item_encoder.transform(frame['item'])
            user_encoder = LabelEncoder().fit(frame['user'])
            frame['user'] = user_encoder.transform(frame['user'])

            estimator = ItemItem(
                n_users=len(user_encoder.classes_),
                n_items=len(item_encoder.classes_),
                storage=MemoryStorage(),
                k=self.k,
                similarity=self.similarity,
                similarity_params=self.similarity_params
            ).fit(frame)
            group_feature_knn_indices = estimator.storage.get_object(estimator.key_for('knn_indices'))
            group_feature_knn_distances = estimator.storage.get_object(estimator.key_for('knn_distances'))

            # decode knn_indices back to original labels
            n_group_items, n_neighbors = group_feature_knn_indices.shape
            group_feature_knn_indices = item_encoder.inverse_transform(
                group_feature_knn_indices.reshape(-1, 1)
            ).reshape(n_group_items, n_neighbors)

            # insert sub-category items/distances into overall arrays; only the first
            # n_neighbors columns are written, the rest keep their padding values
            row_idx = item_encoder.inverse_transform(np.arange(estimator.n_items))
            knn_indices[row_idx, :n_neighbors] = group_feature_knn_indices
            knn_distances[row_idx, :n_neighbors] = group_feature_knn_distances

            # keep track of category-item relation
            group_feature_item_indices[row_idx] = group_feature_value_idx

        self.storage.store(self.key_for('knn_indices'), knn_indices)
        self.storage.store(self.key_for('knn_distances'), knn_distances)
        self.storage.store(self.key_for('_'.join(group_feature_names) + '_item_indices'), group_feature_item_indices)
        self.storage.store(self.key_for('k'), [self.k])

        return self

    def predict_top_n_per_group_feature(self, X, n, basket=None, group_feature_values=None):
        """
        :param X: frame containing users to recommend to
        :param n: number of recommendations to return: per user when `group_feature_values`
                  is None, per user and per group feature value otherwise
        :param basket: frame containing items users already own; when it is None or empty
                       an empty prediction frame is returned
        :param group_feature_values: list of group feature values (e.g. categories)
                                     to restrict recommendations to (by default all categories are used)
        :return: frame of `self.predictions_dtype` with the selected predictions
        """
        group_feature_names = self.get_group_features_list()

        X.assert_has_columns(('user',))

        empty_predictions = DataFrame.from_structarray(np.empty(0, dtype=self.predictions_dtype))

        # the None check must come before any attribute access on `basket`
        if basket is None:
            return empty_predictions
        basket.assert_has_columns(['user', 'item'] + group_feature_names)
        if len(basket) == 0:
            return empty_predictions

        if group_feature_values is not None:
            basket = basket[basket[group_feature_names].is_in(
                np.array(group_feature_values, dtype=basket[group_feature_names].dtype)
            )]
            if len(basket) == 0:
                # nothing is owned within the requested groups, so nothing can be scored
                return empty_predictions

        basket, user_encoder = self.compress_basket(X, basket)
        original_user_ids = X['user'].copy()
        prediction_matrix = self._get_prediction_matrix(basket, user_encoder).tocoo()

        group_feature_item_indices = self.storage.get_proxy(
            self.key_for('_'.join(group_feature_names) + '_item_indices')
        )
        # make a prediction frame
        raw_predictions = DataFrame.from_dict(dict(
            user=user_encoder.inverse_transform(prediction_matrix.row).ravel(),
            item=prediction_matrix.col,
            value=prediction_matrix.data,
            # add category index for groupby
            group_feature_idx=group_feature_item_indices[prediction_matrix.col]
        ))
        # select only rows corresponding to requested users
        raw_predictions = raw_predictions[np.in1d(raw_predictions['user'], original_user_ids)]

        # seed with an empty frame so that the final concatenation never gets an empty list
        predictions = [empty_predictions]

        if group_feature_values is None:
            # select only overall top-n predictions
            if len(raw_predictions) > 0:
                predictions.append(self._get_top_n_predictions_from_frame(raw_predictions, n))
        else:
            # leave top-n items for each user per category
            for group_feature_idx, frame in raw_predictions.groupby('group_feature_idx'):
                if len(frame) > 0:
                    predictions.append(self._get_top_n_predictions_from_frame(frame, n))

        return DataFrame.concatenate(predictions)


class CategoryAwareItemItem(GroupFeatureAwareItemItem):
    """
    `GroupFeatureAwareItemItem` with `group_feature_name` set to 'category'.

    NOTE(review): presumably `get_group_features_list` derives the grouping columns
    from this attribute -- confirm against `GroupFeatureAwareEstimator`.
    """
    group_feature_name = 'category'
