import logging
import numpy as np
from scipy.sparse import csr_matrix
from sklearn.base import BaseEstimator as BaseSklearnEstimator
from sklearn.preprocessing import LabelEncoder

from jafar.storages import make_key
from jafar.storages.memory import MemoryStorage

logger = logging.getLogger(__name__)

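# NOTE: integers can't be NaN, so a sentinel value has to be reserved for missing users/items.
# (No such constant is defined right below this comment -- TODO confirm where it lives.)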


class BaseEstimator(BaseSklearnEstimator):
    """
    Common base class for estimators working on "user X item" data.

    Keeps the matrix dimensions, a storage backend and an optional key
    prefix, and provides helpers for building storage keys and sparse
    matrices. `fit`, `predict` and `predict_top_n` are left to subclasses
    and raise `NotImplementedError` here.
    """
    # Structured dtype spec: int32 `user` and `item` fields plus a float32 `value` field.
    predictions_dtype = [('user', np.int32), ('item', np.int32), ('value', np.float32)]

    def __init__(self, n_users, n_items, storage=None, key_prefix=None):
        """
        :param n_users: number of rows in the default sparse matrix shape.
        :param n_items: number of columns in the default sparse matrix shape.
        :param storage: storage backend; a fresh `MemoryStorage` is created
            when None is given.
        :param key_prefix: optional prefix prepended to every key built
            by `key_for`.
        """
        self.n_users = n_users
        self.n_items = n_items
        # Compare against None explicitly: a storage object that happens to be
        # falsy (e.g. an empty container-like storage) must not be silently
        # replaced with a new MemoryStorage.
        self.storage = storage if storage is not None else MemoryStorage()
        self.key_prefix = key_prefix

    def key_for(self, name):
        """
        Builds a storage key for `name` scoped to this estimator class.

        The key parts are: the optional `key_prefix`, the literal 'estimator',
        the lowercased class name and `name`; they are combined by `make_key`.
        """
        class_name = self.__class__.__name__.lower()
        if self.key_prefix:
            return make_key(self.key_prefix, 'estimator', class_name, name)
        else:
            return make_key('estimator', class_name, name)

    def construct_sparse_matrix(self, users, items, values, shape=None):
        """
        Builds a CSR matrix from parallel users/items/values sequences.

        `users` become row indices and `items` column indices (both cast to
        int32); `values` are the matrix entries. Note that scipy sums the
        values of duplicate (user, item) pairs.

        :param shape: matrix shape; defaults to `(n_users, n_items)` when None.
        :raises AssertionError: if the three sequences differ in length.
        """
        if not (len(users) == len(items) == len(values)):
            # Raised explicitly (instead of via the `assert` statement) so the
            # check is not stripped under `python -O`; the exception type and
            # message are kept for backward compatibility.
            raise AssertionError('Users/items/values arrays should be of the same length')
        # `is None` rather than truthiness, so array-like shapes don't fail
        # with an "ambiguous truth value" error.
        if shape is None:
            shape = (self.n_users, self.n_items)
        return csr_matrix(
            (values, (np.int32(users), np.int32(items))), shape=shape
        )

    def compress_basket(self, X, basket):
        """
        Renumerates user ids to 0..len(unique_users) - 1, where the unique
        users are collected from both `X` and `basket`. Order is preserved.
        This can be useful when converting basket to "user X item" sparse
        matrix to reduce the number of rows.

        The input `basket` is not modified: a copy with the encoded 'user'
        field is returned together with the fitted `LabelEncoder`, which can
        be used to map the ids back.
        """
        basket = basket.copy()
        user_encoder = LabelEncoder().fit(np.union1d(np.unique(X['user']), np.unique(basket['user'])))
        basket['user'] = user_encoder.transform(basket['user'])
        return basket, user_encoder

    def fit(self, X):
        """Fits the estimator; must be implemented by subclasses."""
        raise NotImplementedError

    def predict(self, X, basket=None):
        """Produces predictions; must be implemented by subclasses."""
        raise NotImplementedError

    def predict_top_n(self, X, n, basket=None):
        """Produces top-`n` predictions; must be implemented by subclasses."""
        raise NotImplementedError



class GroupFeatureAwareEstimator(BaseEstimator):
    """
    Estimator base that adds helpers for label-encoding "group features"
    and persisting the encoder classes in the estimator storage.
    """
    # A single feature name or a list/tuple of names; None by default
    # (presumably overridden by subclasses -- TODO confirm).
    group_feature_name = None

    def get_group_features_list(self):
        """
        Returns `group_feature_name` as a list: a list/tuple is copied into
        a new list, anything else is wrapped into a one-element list.
        """
        names = self.group_feature_name
        if isinstance(names, (list, tuple)):
            return list(names)
        return [names]

    def create_group_feature_encoder(self, group_feature_names, dataset):
        """
        Fits a `LabelEncoder` on `dataset[group_feature_names]`, stores its
        classes under a key derived from the '_'-joined feature names and
        returns the fitted encoder.
        """
        encoder = LabelEncoder()
        encoder.fit(dataset[group_feature_names])
        joined_names = '_'.join(group_feature_names)
        storage_key = self.key_for('{}_encoder_classes'.format(joined_names))
        self.storage.store(storage_key, encoder.classes_)
        return encoder

    def load_group_feature_encoder(self, group_feature_name):
        """
        Restores an encoder by attaching the classes proxy read from the
        storage (no re-fitting happens here).

        NOTE(review): despite the singular parameter name, the value is
        '_'-joined, so a plain string would be joined character by character;
        presumably callers pass the same list of names that was given to
        `create_group_feature_encoder` -- confirm against callers.
        """
        joined_names = '_'.join(group_feature_name)
        storage_key = self.key_for('{}_encoder_classes'.format(joined_names))
        encoder = LabelEncoder()
        encoder.classes_ = self.storage.get_proxy(storage_key)
        return encoder

    def get_group_feature_idx(self, group_feature_encoder, group_features=None):
        """
        Converts list of group features to an array of integer indices.
        If `group_features` is None, returns all indices.

        Features unknown to the encoder are dropped before encoding; since
        the filtering relies on `np.intersect1d`, the result corresponds to
        the sorted unique known features rather than to the input order.
        """
        classes = group_feature_encoder.classes_
        if group_features is None:
            return np.arange(len(classes))

        requested = np.array(group_features, dtype=classes.dtype)
        known = np.intersect1d(requested, classes)
        return group_feature_encoder.transform(known)
