import mock
import numpy as np

from jafar.pipelines import ids
from jafar.pipelines.blocks.selection import EstimatorSelectionBlock
from jafar.storages.memory import MemoryStorage
from jafar.tests.unittests.pipelines.blocks import BlockTestCase
from jafar.utils.structarrays import DataFrame


class SingleEstimatorSelectionBlockTestCase(BlockTestCase):
    """Tests for ``EstimatorSelectionBlock`` driven by one mocked estimator class.

    The estimator class is a ``MagicMock``; its instance's ``predict_top_n``
    returns a small random (user, item, value) frame. The tests therefore
    exercise the block's wiring around the estimator, not a real model.
    """

    # ``top_n`` passed to every block built by ``get_block``.
    top_n = 200

    def test_estimator_selection_block_train(self):
        """Train mode: the target frame is passed through as predictions."""
        self._test_estimator_selection_block(train=True)

    def test_estimator_selection_block_predict(self):
        """Predict mode: predictions come from the estimator's ``predict_top_n``."""
        self._test_estimator_selection_block(train=False)

    def get_block(self, estimator_class):
        """Build the block under test around ``estimator_class`` using ``self.top_n``."""
        return EstimatorSelectionBlock(estimator_class=estimator_class, top_n=self.top_n)

    def get_estimator_class(self, name='test', prediction_size=10):
        """Create a mocked estimator class and a random frame for it to predict.

        :param name: value assigned to the mock's ``__name__``.
        :param prediction_size: number of random (user, item, value) rows.
        :returns: ``(estimator_class, estimated)``. ``estimated`` is a
            ``DataFrame`` with int32 ``user``/``item`` and float32 ``value``
            fields. It is NOT attached to the mock here; the caller wires it
            up as the ``predict_top_n`` return value.
        """
        estimator_class = mock.MagicMock()
        estimator_class.__name__ = name
        users = np.random.randint(0, self.n_users, size=prediction_size)
        items = np.random.randint(0, self.n_items, size=prediction_size)
        values = np.random.rand(prediction_size)
        # Materialise the rows: on Python 3 ``zip`` is a lazy iterator, which
        # ``np.array`` cannot convert into a structured array.
        rows = list(zip(users, items, values))
        estimated = DataFrame.from_structarray(np.array(
            rows,
            dtype=[
                ('user', np.int32),
                ('item', np.int32),
                ('value', np.float32),
            ]
        ))
        return estimator_class, estimated

    def _test_estimator_selection_block(self, train):
        """Apply the block once and check the predictions it writes to the context.

        :param train: if true, apply in train mode and expect the target frame
            to be passed through as predictions. Otherwise expect the mocked
            estimator's ``predict_top_n`` output.
        """
        basket = self.get_installs()
        target_frame = self.get_target()
        pipeline = mock.MagicMock()
        pipeline.storage = MemoryStorage()

        estimator_class, estimated_values = self.get_estimator_class(name='test')
        predict_top_n = estimator_class.return_value.predict_top_n
        predict_top_n.return_value = estimated_values

        block = self.get_block(estimator_class)
        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_BASKET] = basket
        context.data[ids.FRAME_KEY_TARGET] = target_frame
        # Store dataset sizes under the block's own keys
        # (presumably read by the block when it builds the estimator).
        pipeline.storage.store(block.key_for(context, 'n_items'), np.array([self.n_items]))
        pipeline.storage.store(block.key_for(context, 'n_users'), np.array([self.n_users]))

        result_context = block.apply(context, train=train)
        self.assert_same_context(result_context, context)
        # basket + target + predictions
        self.assertEqual(len(result_context.data), 3)
        predictions = result_context.data[ids.FRAME_KEY_PREDICTIONS]
        self.assertIsNotNone(predictions)

        if train:
            # Train mode should only pass the target frame through as predictions.
            self.assertTrue((predictions == target_frame).all())
            return

        # ``assert_called_with`` is unusable here: comparing numpy array
        # arguments with ``==`` is ambiguous. ``.called`` also works on mock
        # releases that predate ``assert_called`` (where it is a silent no-op).
        self.assertTrue(predict_top_n.called)

        for field, dtype in [('user', np.int32), ('item', np.int32)]:
            self.assertIn(field, predictions.dtype.names)
            self.assertEqual(predictions[field].dtype, dtype)

        # The predictions must hold exactly the (user, item) pairs the
        # estimator produced. Both sides are sorted so row order does not matter.
        # NOTE(review): the estimate is compared unfiltered; if the block is
        # meant to drop pairs already in the basket, revisit this expectation.
        order = ['user', 'item']
        estimated_values.sort(order=order)
        predictions.sort(order=order)
        for field in order:
            # Per-field comparison reports shape and value mismatches explicitly,
            # unlike a structured-array ``==`` followed by ``.all()``.
            np.testing.assert_array_equal(predictions[field], estimated_values[field])


# Fixture data: (bank short name, Android package name) pairs.
# A bank can appear more than once, e.g. 'vtb' is listed with two packages.
# NOTE(review): not used in the visible code; presumably consumed by other tests.
mock_bank_apps = [
    ('alfabank', 'ru.alfabank.mobile.android'),
    ('vtb', 'ru.vtb24.mobilebanking.android'),
    ('raiff', 'ru.raiffeisennews'),
    ('otkritie', 'com.openbank'),
    # second package for the same bank
    ('vtb', 'com.bssys.VTBClient'),
]

# Fixture data: tuple key -> list of Android package names.
# The keys look like (MCC, MNC) mobile network codes: 250 is Russia's MCC,
# and the packages match the MTS, MegaFon, Tele2 and Beeline apps.
# Presumably matched against a device's operator; confirm against callers.
mock_operator_apps = {
    (250, 1): ['ru.mts.mymts'],
    (250, 2): ['ru.megafon.mlk'],
    (250, 20): ['ru.tele2.mytele2'],
    # one operator may map to several apps
    (250, 99): ['ru.beeline.services', 'ru.beeline.card'],
}
