import logging

import numpy as np
from sklearn.cross_validation import KFold

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


class UserBasedShuffleSplit(object):
    """
    Cross-validation for recommenders splits test and train sets
    across items of the same users: i.e., if item I of user U is
    present in a test set, user U (with some other items) must
    also be present in the corresponding train set.

    This is the base class for two kinds of ShuffleSplits, which
    first pick a subset of users with a sufficient number of items
    and then, for each of these users, move some of their items into
    the test set (the rest goes to the train set).

    Subclasses implement `get_splittable_users` and `split_user`.
    Iterating yields `n_splits` pairs `(train_idx, test_idx)` of integer
    row indices into the original `array`.
    """

    def __init__(self, array, test_user_size=0.1, test_item_size=1, n_splits=1):
        """
        Args:
            array: table-like object with 'user', 'item' and 'value' columns
                (must support `assert_has_columns`, column selection, `copy`
                and `append_column`).
            test_user_size: float -> fraction of splittable users picked for
                each test set; int -> absolute number of users.
            test_item_size: requested number of test items per picked user
                (exact interpretation is up to the subclass).
            n_splits: number of (train, test) pairs produced by iteration.
        """
        assert isinstance(test_user_size, (int, float)), 'test_user_size must be int or float'
        assert isinstance(test_item_size, int), 'test_item_size must be integer'
        array.assert_has_columns(('user', 'item', 'value'))

        self.array = array[['user', 'item', 'value']].copy()
        # Record each row's position once so `split_user` can report original
        # row indices. Doing this here (not in __iter__) keeps repeated
        # iteration from appending a second 'index' column.
        self.array = self.array.append_column(np.arange(len(self.array)), 'index')
        self.test_user_size = test_user_size
        self.test_item_size = test_item_size
        self.n_splits = n_splits

    def get_splittable_users(self):
        """
        Returns all suitable users (those who have a sufficient
        number of items). Must be implemented by subclasses.
        """
        raise NotImplementedError

    def get_subset_to_split(self, splittable_users):
        """
        Randomly pick test users among `splittable_users` (without
        replacement) and return the rows of `self.array` that belong to them.

        A float `test_user_size` is a fraction of `splittable_users`, and at
        least one user is always picked. An int is an absolute count.
        """
        if isinstance(self.test_user_size, float):
            # int() truncates: never let a small fraction select zero users,
            # which would leave every test set empty.
            n_test_users = max(1, int(self.test_user_size * len(splittable_users)))
        else:
            n_test_users = self.test_user_size
        chosen_users = np.random.choice(splittable_users, n_test_users, replace=False)
        return self.array[np.isin(self.array['user'], chosen_users)]

    def __len__(self):
        return self.n_splits

    def __iter__(self):
        """
        Yield `n_splits` pairs `(train_idx, test_idx)` of integer row indices.

        `test_idx` concatenates what `split_user` returns for each randomly
        chosen test user. `train_idx` is every remaining row, sorted.
        """
        self.splittable_users = self.get_splittable_users()
        all_rows = np.arange(len(self.array))
        for _ in range(self.n_splits):
            splittable_subset = self.get_subset_to_split(self.splittable_users)
            per_user_test = [
                self.split_user(user_rows)
                for _user, user_rows in splittable_subset.groupby('user')
            ]
            if per_user_test:
                test_idx = np.hstack(per_user_test)
            else:
                test_idx = np.array([], dtype=all_rows.dtype)
            train_idx = np.setdiff1d(all_rows, test_idx)
            yield train_idx, test_idx

    def split_user(self, subset):
        """
        Return the row indices (taken from the 'index' column of `subset`,
        the rows of a single user) that go to the test set. Must be
        implemented by subclasses.
        """
        raise NotImplementedError


class StratifiedShuffleSplit(UserBasedShuffleSplit):
    """
    User-based shuffle split whose per-user test sets follow the overall
    class ('value') proportions of the data.

    A per-class test quota is derived from the global class frequencies (the
    rarest class gets at least one item). Only users having strictly more
    than the quota of items of *every* class are eligible, so each picked
    user keeps some items of every class in the train set.
    """

    def get_splittable_users(self):
        """
        Compute the per-class test quota and return the users able to meet it.

        Side effect: sets `self.min_test_set`, an int32 array aligned with the
        sorted unique classes; `split_user` relies on it.

        Raises:
            ValueError: if there are more classes than `test_item_size`, or if
                no user has enough items of every class.
        """
        # `reduce` is not a builtin on Python 3.
        from functools import reduce

        # first get overall class counts
        unique_classes, class_counts = np.unique(self.array['value'], return_counts=True)
        # every class needs at least one test item, so the number of classes
        # must be less than or equal to test_item_size
        if len(unique_classes) > self.test_item_size:
            raise ValueError(
                "Cannot perform a stratified shuffle split: "
                "there are {} classes total, which is more than test_item_size ({})".format(
                    len(unique_classes), self.test_item_size
                )
            )
        class_proportions = class_counts / float(class_counts.sum())
        # the rarest class maps to one item; the others are scaled accordingly
        min_test_set = np.round(class_proportions * (1 / class_proportions.min()))
        # scale up, preserving proportions, if the quota is below the request
        if min_test_set.sum() < self.test_item_size:
            min_test_set *= (float(self.test_item_size) / min_test_set.sum())
        self.min_test_set = np.round(min_test_set).astype(np.int32)

        condition_text = ', '.join(
            'class {}: {} items'.format(klass, count)
            for klass, count in zip(unique_classes, self.min_test_set)
        )
        if self.min_test_set.sum() > self.test_item_size:
            # condition_text is passed as an argument (not concatenated into
            # the format string) so a '%' in a class label cannot break logging
            logger.warning(
                "Asked to %s test items per user, but due to stratified conditions %s will be selected: %s",
                self.test_item_size, self.min_test_set.sum(), condition_text
            )

        # For each class, keep users with strictly more items of that class
        # than its quota, so at least one such item stays in train.
        # NOTE(review): assumes groupby('value') yields classes in the same
        # sorted order as np.unique above, so index i matches min_test_set[i].
        users_per_class = []
        for i, (_klass, class_rows) in enumerate(self.array.groupby('value')):
            class_users, class_user_counts = np.unique(class_rows['user'], return_counts=True)
            users_per_class.append(class_users[class_user_counts > self.min_test_set[i]])
        users = reduce(np.intersect1d, users_per_class)
        if len(users) == 0:
            raise ValueError(
                "Cannot perform a stratified shuffle split: no users found satisfying "
                "the splitting condition: " + condition_text
            )
        return users

    def split_user(self, subset):
        """
        Draw `min_test_set[i]` random row indices (without replacement) from
        class i of one user's rows and return them concatenated.
        """
        # Eligible users have every class, so position i lines up with
        # min_test_set[i] (same ordering assumption as get_splittable_users).
        test_idx = [
            np.random.choice(class_rows['index'], self.min_test_set[i], replace=False)
            for i, (_klass, class_rows) in enumerate(subset.groupby('value'))
        ]
        return np.hstack(test_idx)


class FixedClassShuffleSplit(UserBasedShuffleSplit):
    """
    Shuffle split that puts only rows whose 'value' equals `selected_value`
    into the test sets (useful when implicit negatives should not be part of
    cross-validation).

    A user is eligible only with strictly more than `test_item_size` such
    rows, so at least one of them remains in the train set.
    """

    def __init__(self, array, selected_value, test_user_size=0.1, test_item_size=1, n_splits=1):
        super(FixedClassShuffleSplit, self).__init__(array, test_user_size, test_item_size, n_splits)
        # the class value whose items are eligible for the test set
        self.selected_value = selected_value

    def get_splittable_users(self):
        """Users with more than `test_item_size` rows of the selected class."""
        selected_rows = self.array[self.array['value'] == self.selected_value]
        candidates, n_selected = np.unique(selected_rows['user'], return_counts=True)
        has_enough = n_selected > self.test_item_size
        return candidates[has_enough]

    def split_user(self, subset):
        """Randomly draw `test_item_size` selected-class row indices of one user."""
        is_selected = subset['value'] == self.selected_value
        pool = subset[is_selected]['index']
        return np.random.choice(pool, self.test_item_size, replace=False)


class BlendingKFold(object):
    """
    This cross-validation class performs the following operations:

     1. Groups data by users.
     2. For each user with n >= k (number of folds) items, splits the items
        into k contiguous chunks of (almost) equal size n / k, one test chunk
        per fold.
     3. If a user has 1 < n < k items, they are split into n randomly chosen
        folds in a leave-one-out fashion (one item goes to the test set,
        n - 1 go to train), and in the other k - n folds all items just go
        to the train set.
     4. A user with a single item always stays in the train set.

    This scheme guarantees that for any item in a test set, the train set will
    contain at least one item that belongs to the same user (which is a
    requirement for certain recommenders like ItemItem).

    Iterating yields `n_folds` pairs `(train_mask, test_mask)` of complementary
    boolean arrays over the rows of `array`.
    """

    def __init__(self, array, n_folds):
        """
        Args:
            array: table-like object with 'user' and 'item' columns that
                supports `arggroupby('user')` and `len()`.
            n_folds: number of folds; must be at least 2.

        Raises:
            ValueError: if `n_folds` < 2 (a single fold would leave test items
                without any train items of the same user).
        """
        array.assert_has_columns(('user', 'item'))
        if n_folds < 2:
            raise ValueError('n_folds must be at least 2, got {}'.format(n_folds))

        self.array = array
        self.n_folds = n_folds

    def __len__(self):
        return self.n_folds

    def __iter__(self):
        # store only test indices per fold; train masks are their complements
        test_indices = [[] for _ in range(self.n_folds)]

        for _user, group_idx in self.array.arggroupby('user'):
            # NOTE(review): assumes arggroupby yields integer row positions
            group_idx = np.asarray(group_idx)
            n_items = len(group_idx)
            if n_items <= 1:
                # user has only one item: unable to split, item goes to train
                continue

            if n_items >= self.n_folds:
                folds_with_split = np.arange(self.n_folds)
            else:
                # only `n_items` folds can have both train and test items;
                # choose them randomly
                folds_with_split = np.random.choice(self.n_folds, n_items, replace=False)

            # Contiguous, near-equal chunks (the first n % k chunks get one
            # extra item) -- the same partition as an unshuffled k-fold.
            chunks = np.array_split(group_idx, len(folds_with_split))
            for fold, chunk in zip(folds_with_split, chunks):
                test_indices[fold].append(chunk)

        for fold_indices in test_indices:
            # convert int indices to a boolean mask
            test_mask = np.zeros(len(self.array), dtype=bool)
            if fold_indices:
                test_mask[np.hstack(fold_indices)] = True
            yield ~test_mask, test_mask
