import mock

from jafar.pipelines import ids
from jafar.tests.unittests.pipelines.blocks import BlockTestCase
from jafar.pipelines.blocks.estimator import EstimatorBlock


class EstimatorBlockTestCase(BlockTestCase):
    """Unit tests for ``EstimatorBlock`` in both training and prediction mode."""

    def test_estimator_block_train(self):
        """Apply the block with ``train=True`` and check the estimator is built and fitted."""
        self._test_estimator_block(train=True)

    def test_estimator_block_predict(self):
        """Apply the block with ``train=False`` and check the estimator is built."""
        self._test_estimator_block(train=False)

    def get_block(self, estimator_class_mock):
        """Return an ``EstimatorBlock`` reading the advisor mongo installs frame.

        :param estimator_class_mock: mock standing in for the estimator class
            the block is given as ``estimator_class``.
        :return: the configured ``EstimatorBlock``.
        """
        return EstimatorBlock(input_frame=ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS, estimator_class=estimator_class_mock)

    def _test_estimator_block(self, train):
        """Run the block on a test context and check how the estimator is used.

        :param train: forwarded to ``block.apply``; when true, additionally
            assert that the estimator's ``fit`` was called with the installs frame.
        """
        installs = self.get_installs()
        pipeline = mock.MagicMock()

        estimator_class = mock.MagicMock()
        # MagicMock does not auto-create dunder attributes such as __name__, so
        # set one explicitly; presumably EstimatorBlock reads the class name
        # (e.g. when building its storage key) — verify against the block.
        estimator_class.__name__ = 'test'
        block = self.get_block(estimator_class)
        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS] = installs
        result_context = block.apply(context, train=train)
        self.assert_same_context(result_context, context)
        # apply is expected to leave exactly one entry in context.data.
        self.assertEqual(len(result_context.data), 1)
        # Every call of a MagicMock returns the same ``return_value`` object, so
        # if create_estimator returns ``estimator_class(...)`` this is the same
        # estimator mock the block used during apply.
        estimator = block.create_estimator(context, estimator_class)

        # NOTE(review): assert_called_with only checks the most recent call,
        # which is the create_estimator call just above rather than the one made
        # inside apply; consider asserting before that call if apply always
        # instantiates the estimator class.
        estimator_class.assert_called_with(
            n_users=context.n_users,
            n_items=context.n_items,
            storage=context.storage,
            key_prefix=block.key_for(context),
            **block.estimator_params
        )
        if train:
            estimator.fit.assert_called_with(context.data[ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS])
