import numpy as np

from jafar.cross_validation import StratifiedShuffleSplit, FixedClassShuffleSplit
from jafar.tests import JafarTestCase
from jafar.utils.structarrays import DataFrame


class ShuffleSplitMixin(object):
    """Shared fixture and size assertions for the shuffle-split test cases.

    Intended to be mixed into a ``unittest``-style test case: ``check_sizes``
    calls ``self.assertEqual``, which this mixin does not define itself.
    """

    # Exclusive upper bounds for the random user / item ids, and the number
    # of rows generated by ``get_frame``.
    n_users = 100
    n_items = 100
    n = 1000

    def get_frame(self):
        """Build a reproducible random frame with ``user`` and ``item`` columns.

        Each column holds ``self.n`` integers drawn from ``[0, self.n_users)``
        and ``[0, self.n_items)`` respectively.
        """
        # Re-seeding the global NumPy RNG makes the frame identical on every
        # call; it also fixes the sequence of any np.random draw made
        # afterwards by the caller.
        np.random.seed(1234)
        return DataFrame.from_dict({
            'user': np.random.randint(0, self.n_users, self.n),
            'item': np.random.randint(0, self.n_items, self.n),
        })

    def check_sizes(self, df, train_idx, test_idx, test_item_size,
                    test_user_size=10, train_user_size=None):
        """Assert the user counts of a split and the per-user test-fold size.

        :param df: frame with a ``user`` column, indexable by ``train_idx``
            and ``test_idx`` and supporting ``groupby('user')``.
        :param train_idx: index selecting the training rows of ``df``.
        :param test_idx: index selecting the test rows of ``df``.
        :param test_item_size: number of rows every test user must have in
            the test fold.
        :param test_user_size: expected number of distinct users in the test
            fold (defaults to 10, the value the tests in this module request).
        :param train_user_size: expected number of distinct users in the
            training fold; defaults to ``self.n_users``.
        """
        if train_user_size is None:
            train_user_size = self.n_users

        test_rows = df[test_idx]
        # ``assertEqual`` rather than the ``assertEquals`` alias, which was
        # deprecated for years and removed in Python 3.12.
        self.assertEqual(len(np.unique(test_rows['user'])), test_user_size)
        self.assertEqual(len(np.unique(df[train_idx]['user'])), train_user_size)
        for _user, subset in test_rows.groupby('user'):
            self.assertEqual(len(subset), test_item_size)


class StratifiedShuffleSplitTestCase(ShuffleSplitMixin, JafarTestCase):
    """Checks on the first test fold produced by ``StratifiedShuffleSplit``.

    Every test labels the shared random frame with a ``value`` column, takes
    the first split, and asserts which ``value`` labels each held-out user
    ends up with.
    """

    def _first_split(self, df, test_item_size):
        """Return the first ``(train_idx, test_idx)`` pair yielded for ``df``."""
        # test_user_size=10 matches the number of test users that
        # ``check_sizes`` expects.
        splitter = StratifiedShuffleSplit(df, test_user_size=10, test_item_size=test_item_size)
        return next(iter(splitter))

    def test_one_class(self):
        """With a single label, every held-out row must carry that label."""
        df = self.get_frame()
        df = df.append_column(np.ones(self.n), 'value')
        test_item_size = 2
        train_idx, test_idx = self._first_split(df, test_item_size)
        self.check_sizes(df, train_idx, test_idx, test_item_size)
        for _user, subset in df[test_idx].groupby('user'):
            self.assertTrue(np.all(subset['value'] == 1))

    def test_two_uniform_classes(self):
        """Two equally likely labels, two test items: one of each per user."""
        df = self.get_frame()
        # Drawn after get_frame() seeded the global RNG, so the labels are
        # reproducible.
        df = df.append_column(np.random.choice([0, 1], self.n), 'value')
        test_item_size = 2
        train_idx, test_idx = self._first_split(df, test_item_size)
        self.check_sizes(df, train_idx, test_idx, test_item_size)
        for _user, subset in df[test_idx].groupby('user'):
            # ``assertEqual``: the ``assertEquals`` alias was removed in
            # Python 3.12.
            self.assertEqual(sorted(subset['value']), [0, 1])

    def test_two_classes_skewed_proportion(self):
        """Labels drawn 1/3 vs 2/3, three test items: one 0 and two 1s per user."""
        df = self.get_frame()
        df = df.append_column(np.random.choice([0, 1], self.n, p=[1 / 3., 2 / 3.]), 'value')
        test_item_size = 3
        train_idx, test_idx = self._first_split(df, test_item_size)
        self.check_sizes(df, train_idx, test_idx, test_item_size)
        for _user, subset in df[test_idx].groupby('user'):
            self.assertEqual(sorted(subset['value']), [0, 1, 1])


class FixedClassShuffleSplitTestCase(ShuffleSplitMixin, JafarTestCase):
    """Checks on the first test fold produced by ``FixedClassShuffleSplit``."""

    def test_simple(self):
        """With ``selected_value=1``, every held-out row must have value 1."""
        frame = self.get_frame()
        # Drawn after get_frame() seeded the global RNG, so the labels are
        # reproducible.
        labels = np.random.choice([0, 1], self.n)
        frame = frame.append_column(labels, 'value')

        items_per_user = 1
        splitter = FixedClassShuffleSplit(
            frame,
            selected_value=1,
            test_user_size=10,
            test_item_size=items_per_user,
        )
        train_idx, test_idx = next(iter(splitter))

        self.check_sizes(frame, train_idx, test_idx, items_per_user)
        for _user, held_out in frame[test_idx].groupby('user'):
            only_selected = np.all(held_out['value'] == 1)
            self.assertTrue(only_selected)
