from unittest import TestCase

import numpy as np

from jafar.pipelines import ids
from jafar.pipelines.blocks import permutation
from jafar.pipelines.misc import RegexpFeature
from jafar.pipelines.pipeline import Pipeline
from jafar_yt.utils.structarrays import DataFrame


class PermutationTestCase(TestCase):
    """Run a one-block pipeline holding a ``PermutationBlock`` on a four-row frame.

    Subclasses may override ``get_blocks`` / ``get_context`` to exercise the
    same pipeline wiring with different inputs.
    """

    def setUp(self):
        """Create ``self.pipeline`` from the (name, block) pairs of ``get_blocks``."""
        super(PermutationTestCase, self).setUp()
        blocks = list(self.get_blocks())
        self.pipeline = Pipeline(blocks=blocks)

    @staticmethod
    def get_blocks():
        """Yield the ``(name, block)`` pairs that make up the pipeline under test."""
        prefix = 'als_embedding'
        # Feature selector matching every column whose name starts with the prefix.
        embedding_features = RegexpFeature('{}.*'.format(prefix))
        block = permutation.PermutationBlock(
            input_data=ids.FRAME_KEY_PREDICTIONS,
            config_frame=ids.FRAME_KEY_RANGE,
            features=embedding_features,
        )
        yield 'permutation', block

    def get_context(self):
        """Return an initial pipeline context with a predictions frame and a range frame.

        The predictions frame has four rows: rows 0-1 share the embedding
        ``[1, 0]`` and rows 2-3 share ``[0, 1]``; every row has ``value == 1``.
        The range frame holds a single ``left=0, right=4`` row.
        """
        n_users = 4
        embeddings = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
        item_names = ['item_' + str(idx) for idx in range(n_users)]
        predictions_frame = DataFrame.from_dict({
            'item': item_names,
            'user': ['user'] * len(item_names),
            'als_embedding_0': embeddings[:, 0],
            'als_embedding_1': embeddings[:, 1],
            'value': np.ones(n_users),
        })
        range_frame = DataFrame.from_dict({'left': [0], 'right': [n_users]})
        return self.pipeline.create_initial_context(
            frames={
                ids.FRAME_KEY_PREDICTIONS: predictions_frame,
                ids.FRAME_KEY_RANGE: range_frame,
            },
            country='RU',
        )

    def test_permutation(self):
        """Assert that consecutive output values at positions (0, 1) and (2, 3) differ by 1."""
        context = self.get_context()
        final_context = self.pipeline.apply_blocks(False, context)
        values = final_context.data[ids.FRAME_KEY_PREDICTIONS]['value']
        # Absolute step between neighbouring rows of the resulting 'value' column.
        gaps = np.abs(np.diff(values))
        # Intent per the original author: pairs (0, 1) and (2, 3) end up together.
        # NOTE(review): this presumably relies on PermutationBlock rewriting
        # 'value' (the input column is all ones) - confirm against the block.
        observed = (gaps[0], gaps[2])
        self.assertEqual(observed, (1, 1))


class PermutationSelectiveTestCase(PermutationTestCase):
    """Run the permutation pipeline on 100 items with a ``[10, 20)`` range frame.

    The test checks that the ``[RANGE_LEFT:RANGE_RIGHT]`` slice of the ranking
    (items ordered by descending ``'value'``) changes order after the pipeline
    runs while still containing exactly the same items.
    """

    # Bounds written to the range frame ('left' / 'right') and reused by the
    # test to slice the ranking, so the two can never drift apart.
    RANGE_LEFT = 10
    RANGE_RIGHT = 20

    @staticmethod
    def prepare_recommendations(context):
        """Return the ``'item'`` column of the predictions frame, highest ``'value'`` first.

        Side effect: the predictions frame stored in ``context`` is sorted in
        place by ``'value'``. The returned object is a slice of that frame, so
        callers that need a stable snapshot should copy it.
        """
        predictions = context.data[ids.FRAME_KEY_PREDICTIONS]
        predictions.sort(order=('value',))
        return predictions[::-1]['item']

    def get_context(self):
        """Return an initial context with 100 items whose ``'value'`` equals their index.

        The single embedding column also equals the index, and the range frame
        holds one ``left=RANGE_LEFT, right=RANGE_RIGHT`` row.
        """
        n_users = 100
        packages = ['item_' + str(i) for i in range(n_users)]
        predictions_frame = DataFrame.from_dict({
            'item': packages,
            'user': ['user'] * n_users,
            'als_embedding_0': np.arange(n_users),
            'value': np.arange(n_users),
        })
        range_frame = DataFrame.from_dict({
            'left': [self.RANGE_LEFT],
            'right': [self.RANGE_RIGHT],
        })
        return self.pipeline.create_initial_context(
            frames={
                ids.FRAME_KEY_PREDICTIONS: predictions_frame,
                ids.FRAME_KEY_RANGE: range_frame,
            },
            country='RU',
        )

    def test_permutation(self):
        """Assert the ranked slice is reordered but keeps the same set of items."""
        window = slice(self.RANGE_LEFT, self.RANGE_RIGHT)
        context = self.get_context()
        # Copy the snapshots: prepare_recommendations returns a slice of a frame
        # that is sorted in place, so without a copy a later sort (or an
        # in-place pipeline step) could silently alter the "before" data.
        before_permutation = np.array(self.prepare_recommendations(context)[window])
        final_context = self.pipeline.apply_blocks(False, context)
        after_permutation = np.array(self.prepare_recommendations(final_context)[window])
        # The order inside the window must have changed...
        self.assertTrue(np.any(before_permutation != after_permutation))
        # ...while the items themselves stay the same; comparing the sets
        # directly makes a failure report which items differ.
        self.assertEqual(set(before_permutation), set(after_permutation))
