import numpy as np
from scipy import sparse

from jafar import estimators
from jafar.tests import JafarTestCase
from jafar.utils.structarrays import DataFrame

# Dimensions of the synthetic user/item space shared by every test below.
n_users = 10
n_items = 20
size = 20  # number of interactions stored in X

# Twenty implicit-feedback interactions: the first ten share one timestamp,
# the last ten another, and every interaction carries a value of 1.
X = DataFrame.from_dict({
    'user': [5, 6, 9, 7, 8, 4, 1, 9, 5, 8, 1, 1, 5, 3, 3, 9, 8, 1, 7, 5],
    'item': [8, 5, 4, 3, 5, 3, 8, 9, 0, 3, 6, 3, 8, 2, 1, 6, 4, 7, 6, 3],
    'timestamp': [1472896072] * 10 + [1473159038] * 10,
    'value': np.ones(size),
})

# One score per item (items 0..19), all assigned to region 1.
locality_data = DataFrame.from_dict({
    'region': [1] * 20,
    'item': range(20),
    'score': [0.09, 0.42, 0.74, 0.17, 0.73, 0.05, 0.84, 0.52, 1., 0.65,
              0.2, 0.07, 0.88, 0.05, 0.83, 0.89, 0.61, 0.73, 0.87, 0.55],
})

# Inputs for exercising `predict`: a single user holding a three-item basket,
# scored against five candidate items that all carry city_region 1.
test_user = 0
basket_items = [1, 2, 3]
test_items = [4, 5, 6, 7, 8]
X_test = DataFrame.from_dict({
    'user': [test_user] * len(test_items),
    'item': test_items,
    'city_region': [1] * len(test_items),
})
basket = DataFrame.from_dict({
    'user': [test_user] * len(basket_items),
    'item': basket_items,
})


class EstimatorStorageMixin(object):
    """
    Checks that fitted estimators write their parameters to the configured
    storage and that those parameters can be read back through the storage API.

    Concrete test cases mix this in and provide a ``storage`` attribute.
    """

    def test_popular(self):
        estimator = estimators.Popular(n_users=n_users, n_items=n_items, storage=self.storage)
        estimator.fit(X)
        for key_name in ('items', 'scores', 'popularity'):
            # Step 1: obtaining the proxy object must not raise.
            try:
                proxy = estimator.storage.get_proxy(estimator.key_for(key_name))
            except Exception as err:
                self.fail("Could not get proxy for {}: {}".format(key_name, err))
            # Step 2: the proxy must support length, slicing and indexing.
            try:
                len(proxy)
                proxy[:2]
                proxy[2:]
                proxy[0]
            except Exception as err:
                self.fail('{} proxy does not behave like a valid list: {}'.format(key_name, err))

    def test_itemitem(self):
        estimator = estimators.ItemItem(n_users=n_users, n_items=n_items, storage=self.storage)
        estimator.fit(X)
        for key_name in ('knn_distances', 'knn_indices'):
            # Reading a subset of rows must succeed and yield a numpy array.
            try:
                rows = estimator.storage.get_matrix_rows(estimator.key_for(key_name), [1, 2, 3])
            except Exception as err:
                self.fail("Failed to call `get_matrix_rows`: {}".format(err))
            self.assertIsInstance(rows, np.ndarray)


class MemoryEstimatorStorageTestCase(EstimatorStorageMixin, JafarTestCase):
    """Run the EstimatorStorageMixin checks against ``self.memory_storage``."""

    storage = property(lambda self: self.memory_storage)


class MemmapEstimatorStorageTestCase(EstimatorStorageMixin, JafarTestCase):
    """Run the EstimatorStorageMixin checks against ``self.memmap_storage``."""

    storage = property(lambda self: self.memmap_storage)


class KnnTestCase(JafarTestCase):
    """
    Verify the custom KNN sparsification against the sklearn-based variant.
    """

    def test_knn_sparsify_metric(self):
        # Random dense metric with unit self-similarity. A fixed seed keeps the
        # test reproducible, and a plain ndarray avoids the deprecated np.matrix.
        n = 100
        rng = np.random.RandomState(0)
        metric = rng.rand(n, n)
        np.fill_diagonal(metric, 1.0)

        # Sparsify to k neighbours with both the custom and sklearn algorithms.
        k = 5
        metric_, distances_, indices_ = estimators.ItemItem.knn_sparsify_metric(sparse.csc_matrix(metric), k)
        metric_sklearn, distances_sklearn, indices_sklearn = estimators.ItemItem.knn_sparsify_metric_sklearn(
            sparse.csc_matrix(metric), k)

        # Tolerances match np.allclose's defaults; np.testing reports the
        # mismatching entries on failure instead of "False is not true".
        np.testing.assert_allclose(np.asarray(metric_.todense()),
                                   np.asarray(metric_sklearn.todense()),
                                   rtol=1e-5, atol=1e-8)
        np.testing.assert_allclose(np.sort(distances_), np.sort(distances_sklearn),
                                   rtol=1e-5, atol=1e-8)
        np.testing.assert_array_equal(np.sort(indices_), np.sort(indices_sklearn))


class ItemItemTestCase(JafarTestCase):
    """Prediction tests for the ItemItem estimator using ``self.memory_storage``."""

    @property
    def storage(self):
        return self.memory_storage

    def _fitted_estimator(self):
        """
        Fit an ItemItem estimator on ``X``, then return a *new* instance that
        shares the same storage. Predictions therefore go through what ``fit``
        wrote to storage, not through the fitted object itself.
        """
        estimators.ItemItem(n_users=n_users, n_items=n_items, storage=self.storage).fit(X)
        return estimators.ItemItem(n_users=n_users, n_items=n_items, storage=self.storage)

    def _assert_non_increasing(self, values, msg):
        """Fail with ``msg`` unless ``values`` is sorted in non-increasing order."""
        values = list(values)
        self.assertTrue(all(a >= b for a, b in zip(values, values[1:])), msg)

    def test_stability(self):
        est = self._fitted_estimator()
        predicted = est.predict(X=X_test, basket=basket)
        self.assertEqual(X_test['item'].to_list(), predicted['item'].to_list(), 'Items must be in the same order')
        self.assertEqual(X_test['user'].to_list(), predicted['user'].to_list(), 'Users must be in the same order')

    def test_top_n(self):
        est = self._fitted_estimator()
        predicted = est.predict_top_n(X=DataFrame.from_dict(dict(user=[test_user])), n=5, basket=basket)
        self.assertEqual(len(predicted), 5, 'Must return N items')
        # Check the ordering directly rather than re-sorting with np.argsort,
        # whose relative order for tied scores is not guaranteed.
        self._assert_non_increasing(predicted['value'], 'Top N recommendations must be sorted')
        predicted = est.predict_top_n(X=DataFrame.from_dict(dict(user=[test_user])), n=5, basket=None)
        self.assertEqual(len(predicted), 0, 'Predicts nothing with empty basket')


class LocalityTestCase(JafarTestCase):
    """Prediction tests for LocalityEstimator fitted on ``locality_data``."""

    @property
    def storage(self):
        return self.memory_storage

    def get_estimator(self, threshold=0.5):
        """Build a LocalityEstimator that reads regions from the ``city_region`` column."""
        return estimators.LocalityEstimator(n_users=n_users, n_items=n_items, storage=self.storage,
                                            region_column='city_region', threshold=threshold)

    def test_stability(self):
        est = self.get_estimator()
        est.fit(locality_data)
        est = self.get_estimator()
        predicted = est.predict(X=X_test)
        self.assertEqual(X_test['item'].to_list(), predicted['item'].to_list(), 'Items must be in the same order')
        self.assertEqual(X_test['user'].to_list(), predicted['user'].to_list(), 'Users must be in the same order')

    def test_top_n(self):
        est = self.get_estimator()
        est.fit(locality_data)
        est = self.get_estimator()
        predicted = est.predict_top_n(X=DataFrame.from_dict(dict(user=[test_user], city_region=[1])),
                                      n=5, basket=basket)
        self.assertEqual(len(predicted), 5, "Should predict N=5 values")

    def test_threshold_n(self):
        threshold = 0.85
        est = self.get_estimator(threshold)
        est.fit(locality_data)
        est = self.get_estimator(threshold)
        predicted = est.predict_top_n(X=DataFrame.from_dict(dict(user=[test_user], city_region=[1])),
                                      n=5, basket=basket)
        # Only four items in locality_data score above 0.85 (items 8, 12, 15, 18),
        # so fewer than n=5 results are expected.
        self.assertEqual(len(predicted), 4)
        self.assertGreater(min(predicted['value']), threshold)

    def test_unknown_city(self):
        est = self.get_estimator()
        est.fit(locality_data)
        est = self.get_estimator(0)
        predicted = est.predict_top_n(X=DataFrame.from_dict(dict(user=[test_user], city_region=[99])),
                                      n=5, basket=basket)
        self.assertEqual(len(predicted), 0)
        # Region 99 does not appear in locality_data.
        X_unknown_city = DataFrame.from_dict(dict(user=[test_user] * len(test_items),
                                                  item=test_items,
                                                  city_region=[99] * len(test_items)))
        predicted = est.predict(X=X_unknown_city)
        self.assertListEqual(predicted['value'].to_list(), [0] * len(test_items))


class ALSTestCase(JafarTestCase):
    """Prediction tests for the ALS estimator using ``self.memory_storage``."""

    @property
    def storage(self):
        return self.memory_storage

    def _fitted_estimator(self):
        """
        Fit an ALS estimator on ``X``, then return a *new* instance that shares
        the same storage. Predictions therefore go through what ``fit`` wrote
        to storage, not through the fitted object itself.
        """
        estimators.ALS(n_users=n_users, n_items=n_items, storage=self.storage).fit(X)
        return estimators.ALS(n_users=n_users, n_items=n_items, storage=self.storage)

    def _assert_non_increasing(self, values, msg):
        """Fail with ``msg`` unless ``values`` is sorted in non-increasing order."""
        values = list(values)
        self.assertTrue(all(a >= b for a, b in zip(values, values[1:])), msg)

    def test_stability(self):
        est = self._fitted_estimator()
        predicted = est.predict(X=X_test, basket=basket)
        self.assertEqual(X_test['item'].to_list(), predicted['item'].to_list(), 'Items must be in the same order')
        self.assertEqual(X_test['user'].to_list(), predicted['user'].to_list(), 'Users must be in the same order')

    def test_top_n(self):
        est = self._fitted_estimator()
        predicted = est.predict_top_n(X=DataFrame.from_dict(dict(user=[test_user])), n=5, basket=basket)
        self.assertEqual(len(predicted), 5, 'Must return N items')
        # Check the ordering directly rather than re-sorting with np.argsort,
        # whose relative order for tied scores is not guaranteed.
        self._assert_non_increasing(predicted['value'], 'Top N recommendations must be sorted')
        predicted = est.predict_top_n(X=DataFrame.from_dict(dict(user=[test_user])), n=5, basket=None)
        self.assertEqual(len(predicted), 0, 'Predicts nothing with empty basket')
