import mock
import numpy as np

from jafar.pipelines import ids
from jafar.pipelines.blocks.features import EstimatorFeatureBlock
from jafar.storages.memory import MemoryStorage
from jafar.tests.unittests.pipelines.blocks import BlockTestCase


class EstimatorFeatureBlockTestCase(BlockTestCase):
    """Tests for ``EstimatorFeatureBlock`` in both train and predict modes.

    The estimator class handed to the block is a ``MagicMock``, so the tests
    only verify how the block wires the estimator's ``predict`` output into
    the predictions frame, not any real estimation logic.
    """

    # Name passed to the block as ``feature_name``; the tests expect a column
    # with this name in the resulting predictions frame.
    feature_name = 'feature'

    def test_estimator_feature_block_train(self):
        """Run the shared scenario with ``train=True``."""
        self._test_estimator_feature_block(train=True)

    def test_estimator_feature_block_predict(self):
        """Run the shared scenario with ``train=False``."""
        self._test_estimator_feature_block(train=False)

    def get_block(self, estimator_class_mock):
        """Build the block under test around the given (mocked) estimator class.

        :param estimator_class_mock: object passed to the block as
            ``estimator_class``.
        :return: an ``EstimatorFeatureBlock`` reading from the predictions
            frame and configured with ``self.feature_name``.
        """
        return EstimatorFeatureBlock(
            input_frame=ids.FRAME_KEY_PREDICTIONS,
            estimator_class=estimator_class_mock,
            feature_name=self.feature_name,
        )

    def _test_estimator_feature_block(self, train):
        """Shared scenario: apply the block and check the produced feature column.

        :param train: forwarded to ``block.apply`` as the ``train`` flag.
        """
        installs = self.get_installs()
        predictions = self.get_predictions()
        pipeline = mock.MagicMock()
        pipeline.storage = MemoryStorage()

        estimator_class = mock.MagicMock()
        # MagicMock does not define ``__name__`` by default; set it explicitly.
        estimator_class.__name__ = 'test'
        block = self.get_block(estimator_class)
        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS] = installs
        context.data[ids.FRAME_KEY_BASKET] = installs
        context.data[ids.FRAME_KEY_PREDICTIONS] = predictions
        pipeline.storage.store(block.key_for(context, 'n_items'), np.array([self.n_items]))
        pipeline.storage.store(block.key_for(context, 'n_users'), np.array([self.n_users]))

        # The mocked estimator has no configured ``predict`` result yet, so
        # applying the block is expected to raise an AssertionError. A copy of
        # the context is used to keep the original available for the real run.
        with self.assertRaises(AssertionError):
            block.apply(context.copy(), train=train)

        # Configure the mocked estimator: ``predict`` returns the predictions
        # frame extended with a random float32 ``value`` column.
        estimated = predictions.copy()
        estimated = estimated.append_column(np.random.rand(predictions.shape[0]).astype('float32'), 'value')
        estimator_class.return_value.predict.return_value = estimated

        result_context = block.apply(context, train=train)
        self.assert_same_context(result_context, context)
        # Exactly the three frames stored above are expected in the result.
        self.assertEqual(len(result_context.data), 3)

        # `assert_called_with` can't be used here because of numpy array comparison
        estimator_class.return_value.predict.assert_called()
        result_frame = result_context.data[ids.FRAME_KEY_PREDICTIONS]
        for field, dtype in [('user', np.int32), ('item', np.int32), (self.feature_name, np.float32)]:
            self.assertIn(field, result_frame.dtype.names)
            self.assertEqual(result_frame[field].dtype, dtype)
        # The estimator's ``value`` column must show up under ``feature_name``.
        self.assertTrue(np.allclose(estimated['value'], result_frame[self.feature_name]))
