import mock
import numpy as np
from mock import patch
from sklearn.linear_model import LogisticRegression

from jafar.pipelines import ids
from jafar.pipelines.blocks.classifier import LogisticClassifierBlock
from jafar.storages.memory import MemoryStorage
from jafar.tests.unittests.pipelines.blocks import BlockTestCase


class ClassifierBlockTestCase(BlockTestCase):
    """Shared scenario for classifier-block tests.

    Concrete subclasses must provide:
      * ``classifier_block_class`` -- the block class under test; its
        ``create_classifier`` and ``read_classifier`` are patched here;
      * ``get_mock_classifier(n_predictions)`` -- a mock model whose ``fit`` /
        ``predict_proba`` calls are checked;
      * ``check_stored_params(pipeline)`` -- assertions on what a training run
        wrote to ``pipeline.storage``.
    """

    def _test_classifier_block(self, train):
        """Apply the block once to a small random frame and check its effects.

        Both ``create_classifier`` and ``read_classifier`` on the block class
        are patched to return the subclass's mock, so no real model is fitted
        or loaded.

        :param train: forwarded to ``block.apply``. When true, the mock's
            ``fit`` must be called and the subclass verifies stored params;
            otherwise the mock's ``predict_proba`` must be called.
        """
        predictions = self.get_predictions()
        mock_classifier = self.get_mock_classifier(predictions.shape[0])

        with patch.object(self.classifier_block_class, 'create_classifier', return_value=mock_classifier):
            with patch.object(self.classifier_block_class, 'read_classifier', return_value=mock_classifier):
                # Constructed inside the patches in case the constructor itself
                # obtains the classifier.
                block = self.classifier_block_class(features=['ft1', 'ft2'])

                # Seed the global RNG so the feature/target columns (and any
                # randomness inside the block) are reproducible.
                np.random.seed(42)
                predictions = predictions.append_column(np.random.rand(len(predictions)), 'ft1')
                predictions = predictions.append_column(np.random.rand(len(predictions)), 'ft2')
                predictions = predictions.append_column(np.random.randint(0, 2, size=len(predictions)), 'value')

                pipeline = mock.MagicMock()
                pipeline.name = 'test_pipeline'
                pipeline.storage = MemoryStorage()
                pipeline.features = ['ft1', 'ft2']
                context = self.get_test_context(pipeline)
                context.data[ids.FRAME_KEY_PREDICTIONS] = predictions

                result_context = block.apply(context, train=train)
                self.assert_same_context(result_context, context)
                # assertEquals is a deprecated alias (removed in Python 3.12).
                self.assertEqual(len(result_context.data), 1)

                if train:
                    mock_classifier.fit.assert_called()
                    self.check_stored_params(pipeline)
                else:
                    mock_classifier.predict_proba.assert_called()


class LogisticClassifierBlockTestCase(ClassifierBlockTestCase):
    """Runs the shared classifier-block scenario against LogisticClassifierBlock."""

    classifier_block_class = LogisticClassifierBlock
    # Parameters exposed by the mock classifier; a training run is expected to
    # persist exactly these values to the pipeline storage.
    coef = np.array([0.1, 0.2])
    intercept = np.array(0.1)

    def test_logistic_classifier_block_train(self):
        self._test_classifier_block(train=True)

    def test_logistic_classifier_block_predict(self):
        self._test_classifier_block(train=False)

    def get_mock_classifier(self, n_predictions):
        """Return a LogisticRegression-spec'd mock with fixed fitted parameters.

        :param n_predictions: number of rows ``predict_proba`` returns.
        """
        mocked = mock.MagicMock(spec=LogisticRegression)
        mocked.intercept_ = self.intercept
        mocked.coef_ = self.coef
        # A private seeded generator keeps the fixture reproducible regardless
        # of prior global RNG use; complementary columns make each row sum to
        # one, as a real two-class predict_proba does.
        positive = np.random.RandomState(0).rand(n_predictions)
        mocked.predict_proba.return_value = np.column_stack([1.0 - positive, positive])
        return mocked

    def check_stored_params(self, pipeline):
        """Assert that training saved the mock's coef and intercept to storage."""
        # NOTE(review): the 'ru' key segment presumably comes from the test
        # context built by get_test_context -- confirm if that fixture changes.
        coef = pipeline.storage.get_object('test_pipeline.ru.coef')
        self.assertIsNotNone(coef)
        # Same tolerances as np.allclose, but failures report the mismatch.
        np.testing.assert_allclose(coef, self.coef, rtol=1e-5, atol=1e-8)

        intercept = pipeline.storage.get_object('test_pipeline.ru.intercept')
        self.assertIsNotNone(intercept)
        np.testing.assert_allclose(intercept, self.intercept, rtol=1e-5, atol=1e-8)
