import logging

import implicit
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder

from jafar.estimators.knn.item_item import ItemItem, GroupFeatureAwareItemItem
from jafar.pipelines.blocks.mapping import MISSING_ID
from jafar.storages.exceptions import StorageKeyError
from jafar_yt.utils.structarrays import DataFrame

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)


class ALS(ItemItem):
    """
    Item-item estimator backed by an implicit-feedback Alternating Least Squares model.

    ``fit`` factorizes the (user, item) interaction matrix with ``implicit`` and stores the
    item factors (plus, optionally, user factors and precomputed cosine kNN lists) in
    ``self.storage``. ``predict_vanilla`` scores (user, item) pairs as the dot product of
    their latent factors.
    """
    basket_required = True

    def __init__(self, n_features=85, regularization=21.5, n_iters=55, alpha=140., save_user_features=False,
                 save_knns=True, value_column='value', *args, **kwargs):
        """
        :param n_features: ALS dimensionality (number of latent factors)
        :param regularization: lambda (L2) coefficient passed to ALS
        :param n_iters: number of ALS iterations
        :param alpha: multiplier applied to the interaction values before fitting ALS
        :param save_user_features: also store the fitted user factors; otherwise user factors
            are computed on the fly from the stored item factors at prediction time
        :param save_knns: precompute and store cosine nearest neighbours for every item
        :param value_column: name of the interaction-strength column in X
        Remaining positional/keyword arguments are forwarded to ItemItem (which provides,
        among others, ``storage`` and ``k`` used by this class).
        """
        super(ALS, self).__init__(*args, **kwargs)
        self.n_features = int(n_features)
        self.regularization = regularization
        self.n_iters = int(n_iters)
        self.alpha = alpha
        self.save_user_features = save_user_features
        self.save_knns = save_knns
        self.value_column = value_column

    def fit(self, X, y=None):
        """
        Fit the ALS model and persist its results to storage.

        Rows with a missing user/item id or a non-positive value are dropped before fitting.
        Item factors, regularization, alpha and k are always stored; user factors and the
        item kNN arrays are stored only when enabled in the constructor.
        :param X: structarray, must contain ('user', 'item', value_column) columns
        :param y: ignored, kept for estimator API compatibility
        :return: self
        """
        X.assert_has_columns(('user', 'item', self.value_column))

        X = X[(X['user'] != MISSING_ID) & (X['item'] != MISSING_ID) & (X[self.value_column] > 0)]
        user_features, item_features = self._fit_als(X)

        self.storage.store(self.key_for('item_features'), item_features)
        self.storage.store(self.key_for('regularization'), [self.regularization])
        self.storage.store(self.key_for('alpha'), [self.alpha])
        self.storage.store(self.key_for('k'), [self.k])

        if self.save_user_features:
            self.storage.store(self.key_for('user_features'), user_features)

        if self.save_knns:
            knn_indices, knn_distances = self._calc_neighbors(item_features, self.k)
            self.storage.store(self.key_for('knn_indices'), knn_indices)
            self.storage.store(self.key_for('knn_distances'), knn_distances)

        return self

    def predict_vanilla(self, X, basket=None):
        """
        Calculate the preference score of each (user, item) pair as the dot product of
        the user and item factors.
        :param X: structarray, must contain ('user', 'item') columns
        :param basket: unused here
        :return: DataFrame with int32 'user', 'item' and float32 value_column columns
        """
        X.assert_has_columns(('user', 'item'))

        user_features = self.get_user_features(X)
        item_features = self.get_item_features(X)

        value = (user_features * item_features).sum(axis=1)

        return DataFrame.from_dict({
            'user': np.int32(X['user']),
            'item': np.int32(X['item']),
            self.value_column: np.float32(value)
        })

    def get_item_features(self, X):
        """Return the stored item-factor rows for every X['item']."""
        return self.storage.get_matrix_rows(self.key_for('item_features'), X['item'])

    def get_user_features(self, X):
        """
        Return one user-factor row per row of X.

        For each user, stored user factors are used when the storage lookup succeeds; if it
        raises StorageKeyError, the factors are computed from the items that user has in X.
        :param X: structarray with 'user' and 'item' columns
        :return: float32 array of shape (len(X), dim), dim being the stored item-factor width
        """
        # the factor dimensionality is taken from the first stored item-factor row
        dim = len(self.storage.get_matrix_rows(self.key_for('item_features'), 0))
        result = np.empty((len(X), dim), dtype=np.float32)
        for user, idx in X.arggroupby('user'):
            try:
                result[idx] = self.storage.get_matrix_rows(self.key_for('user_features'), user)
            except StorageKeyError:
                result[idx] = self.calc_user_features(X[idx]['item'])
        return result

    def calc_user_features(self, items):
        """
        Compute a single user's factor vector from the items they interacted with,
        via ``implicit.als.user_factor`` against the stored item factors.
        :param items: item ids of the user's interactions
        :return: factor vector, or scalar 0 for an empty item list (zeroes the ALS score)
        """
        if len(items) == 0:
            return 0  # ignore ALS feature
        item_features = self.storage.get_proxy(self.key_for('item_features'))
        # one-row interaction matrix with value 1 for each given item
        # NOTE(review): fit() scales values by alpha, but alpha is not applied here although it
        # is stored; confirm this asymmetry is intended.
        user_items = self.construct_sparse_matrix(np.zeros_like(items), items, np.ones_like(items),
                                                  shape=(1, max(items) + 1))
        YtY = item_features.T.dot(item_features)
        regularization = self.storage.get_proxy(self.key_for('regularization'))[0]
        return implicit.als.user_factor(item_features, YtY, user_items, 0, regularization, item_features.shape[1])

    def _fit_als(self, X):
        """
        Train implicit ALS on the alpha-scaled (user x item) interaction matrix.
        :return: (user_factors, item_factors)
        """
        M = self.construct_sparse_matrix(X['user'], X['item'], X[self.value_column]).astype('double')
        M *= self.alpha
        model = implicit.als.AlternatingLeastSquares(factors=self.n_features,
                                                     regularization=self.regularization,
                                                     iterations=self.n_iters)
        model.fit(M)

        # The factors are deliberately swapped: presumably the implicit version in use expects
        # an item x user matrix in fit(), so with a user x item M the model's "item" factors
        # are our user factors and vice versa. Confirm against the pinned implicit version.
        return model.item_factors, model.user_factors

    @staticmethod
    def _calc_neighbors(features, k, block_size=1000):
        """
        Find the k most cosine-similar rows of ``features`` for every row.

        Similarities are computed ``block_size`` rows at a time to bound memory usage.
        Neighbours of a row are not sorted by similarity, and a row is a candidate
        neighbour of itself. When there are fewer than k rows in total, the unused
        trailing slots keep index 0 and distance 1.
        :param features: 2-D array, one row per item
        :param k: number of neighbours per row
        :param block_size: number of rows passed to each cosine_similarity call
        :return: (knn_indices, knn_distances), both of shape (n, k);
                 distance = 1 - cosine similarity
        """
        n = features.shape[0]
        knn_indices = np.zeros((n, k), dtype=np.int_)
        knn_distances = np.ones((n, k))

        # The top-k is taken over the candidate axis (all n rows), so it can never exceed n.
        # (Using the block's row count here was wrong for a short final block.)
        n_neighbors = min(k, n)
        if n_neighbors == 0:
            return knn_indices, knn_distances

        for block_start in range(0, n, block_size):
            block_end = min(block_start + block_size, n)
            block_sims = cosine_similarity(features[block_start:block_end], features)

            # After argpartition along the candidate axis, the last n_neighbors positions of
            # each row hold the indices of its n_neighbors largest similarities (any order).
            top = np.argpartition(block_sims, -n_neighbors, axis=1)[:, -n_neighbors:]
            rows = np.arange(block_sims.shape[0])[:, np.newaxis]

            knn_indices[block_start:block_end, :n_neighbors] = top
            knn_distances[block_start:block_end, :n_neighbors] = 1. - block_sims[rows, top]

        return knn_indices, knn_distances


class GroupFeatureAwareALS(ALS, GroupFeatureAwareItemItem):
    """
    ALS variant whose item neighbours are computed separately within each group of items
    sharing the same value(s) of the group features (see ``get_group_features_list``).

    ``fit`` stores only neighbour data: the per-group kNN lists built from the ALS item
    factors, the item -> group-value index mapping, and k.
    """
    def fit(self, X, y=None):
        """
        Fit ALS and store per-group item nearest neighbours.

        Rows with a missing user/item id or a non-positive value are dropped before fitting.
        :param X: structarray with 'user', 'item', value_column and the group feature columns
        :param y: ignored; accepted for consistency with ALS.fit
        :return: self
        """
        X.assert_has_columns(('user', 'item', self.value_column))

        X = X[(X['user'] != MISSING_ID) & (X['item'] != MISSING_ID) & (X[self.value_column] > 0)]
        _, item_vecs = self._fit_als(X)
        # NOTE: knn_indices are pre-filled with zeros and not MISSING_ID values
        # because they will eventually be converted to sparse matrix (no negative values allowed).
        knn_indices = np.zeros((self.n_items, self.k), dtype=np.int32)
        knn_distances = np.zeros((self.n_items, self.k), dtype=np.float32)

        group_feature_names = self.get_group_features_list()
        group_feature_encoder = self.create_group_feature_encoder(group_feature_names, X)
        group_feature_item_indices = np.zeros(self.n_items, dtype=np.int32)
        # unique (group feature value(s), item) pairs
        group_items = DataFrame(np.unique(X[group_feature_names + ['item']]))

        for group_feature_value, frame in group_items.groupby(group_feature_names):
            group_feature_value_idx = group_feature_encoder.transform(
                np.array(tuple(group_feature_value), dtype=group_feature_value.dtype))

            # Row i of grouped_vectors, and therefore of the group's kNN arrays,
            # belongs to item group_item_ids[i].
            group_item_ids = np.asarray(frame['item'])
            grouped_vectors = item_vecs[group_item_ids]

            group_knn_indices, group_knn_distances = self._calc_neighbors(grouped_vectors, self.k)

            # Neighbour indices are positions within the group: map them back to item ids
            # and insert the group's rows into the overall arrays.
            n_neighbors = group_knn_indices.shape[1]
            knn_indices[group_item_ids, :n_neighbors] = group_item_ids[group_knn_indices]
            knn_distances[group_item_ids, :n_neighbors] = group_knn_distances

            # keep track of the item -> group value relation
            group_feature_item_indices[group_item_ids] = group_feature_value_idx

        self.storage.store(self.key_for('knn_indices'), knn_indices)
        self.storage.store(self.key_for('knn_distances'), knn_distances)
        self.storage.store(self.key_for('_'.join(group_feature_names) + '_item_indices'), group_feature_item_indices)
        self.storage.store(self.key_for('k'), [self.k])

        return self


class CategoryAwareALS(GroupFeatureAwareALS):
    """
    GroupFeatureAwareALS grouped by the ``category`` feature.

    ``group_feature_name`` is presumably consumed by ``get_group_features_list`` in
    GroupFeatureAwareItemItem; ``base_estimator`` points at the plain ALS estimator.
    """
    base_estimator = ALS
    group_feature_name = 'category'
