import mock
import numpy as np

from jafar.pipelines import ids
from jafar.pipelines.blocks import transformer
from jafar.storages.exceptions import StorageKeyError
from jafar.storages.memory import MemoryStorage
from jafar.tests.unittests.pipelines.blocks import BlockTestCase


class FitAndTransformBlockTestCaseMixin(object):
    """Shared train/predict checks for fit-and-transform pipeline blocks.

    Concrete test cases combine this mixin with a base class that supplies
    ``get_predictions``, ``get_test_context``, ``assert_same_context`` and the
    assertion helpers used below (``assertRaises``, ``assertEqual``, ``fail``),
    and override the two class attributes.
    """

    # Block class under test; instantiated as ``block_class(features=[...])``.
    block_class = None
    # Mapping of transformer attribute name -> value. The keys are the params
    # looked up in storage after a train run; the key/value pairs are written
    # to storage before the predict run.
    transformer_attributes = None

    def test_sklearn_transformer_block_train(self):
        """Run the shared scenario in train mode."""
        self._test_sklearn_transformer_block(train=True)

    def test_sklearn_transformer_block_predict(self):
        """Run the shared scenario in predict mode."""
        self._test_sklearn_transformer_block(train=False)

    def _test_sklearn_transformer_block(self, train):
        """Apply ``block_class`` to a one-feature frame and verify the result.

        :param train: value forwarded as ``train=`` to ``block.apply``. When
            falsy, the first ``apply`` against the empty storage must raise
            ``StorageKeyError``; the entries of ``transformer_attributes`` are
            then stored and ``apply`` is run again. When truthy, every key of
            ``transformer_attributes`` must be readable from storage after
            ``apply``.
        """
        predictions = self.get_predictions()
        predictions = predictions.append_column(np.ones(len(predictions)), 'test_feature')
        pipeline = mock.MagicMock()
        pipeline.name = 'test_pipeline'
        pipeline.storage = MemoryStorage()
        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_PREDICTIONS] = predictions

        block = self.block_class(features=['test_feature'])
        if not train:
            # Nothing has been stored yet, so applying the block must fail.
            with self.assertRaises(StorageKeyError):
                block.apply(context, train=train)
            # Put transformer params in the storage so the next apply can
            # find them.
            for key, value in self.transformer_attributes.items():
                pipeline.storage.store(block.key_for(context, key), value)

        result_context = block.apply(context, train=train)
        self.assert_same_context(result_context, context)
        self.assertEqual(len(result_context.data), 1)
        # In train mode, check that each expected param was stored.
        if train:
            for key in self.transformer_attributes:
                try:
                    pipeline.storage.get_object(block.key_for(context, key))
                except StorageKeyError:
                    self.fail("Attribute {} missing from storage".format(key))


class StandardScalerBlockTestCase(FitAndTransformBlockTestCaseMixin, BlockTestCase):
    """Runs the shared fit/transform checks against ``StandardScalerBlock``."""

    block_class = transformer.StandardScalerBlock
    # Attribute names the mixin looks up in (or writes to) the storage,
    # together with the placeholder values used for the predict-mode run.
    transformer_attributes = {
        'mean_': np.array([1]),
        'scale_': np.array([2]),
    }
