import mock
import numpy as np

from jafar.pipelines import ids
from jafar.tests.unittests.pipelines.blocks import BlockTestCase
from jafar.pipelines.blocks.target import ImplicitNegativeTargetBlock, ConversionTargetBlock
from jafar.storages.memory import MemoryStorage
from jafar.utils.structarrays import DataFrame

class InstallTargetBlockTestCase(BlockTestCase):
    """Tests ``ImplicitNegativeTargetBlock`` applied to an installs frame."""

    def test_install_target_block(self):
        """Run the train/predict scenario for the install target block."""
        self._test_install_target_block()

    def _test_install_target_block(self):
        """Apply the block in train and predict mode and check the context.

        In train mode the test expects a target frame to be added next to the
        input frame, holding twice as many rows as the installs frame and
        split evenly between ``value == 1.0`` and ``value == 0.0`` rows.
        In predict mode the test expects the context to contain only the
        input frame.
        """
        # Ten (user, item, value) rows: users 0-4 with two installs each,
        # every row carrying value 1.
        installs = DataFrame.from_structarray(np.array([
            (0, 0, 1),
            (0, 1, 1),
            (1, 1, 1),
            (1, 2, 1),
            (2, 2, 1),
            (2, 3, 1),
            (3, 3, 1),
            (3, 4, 1),
            (4, 4, 1),
            (4, 0, 1),
        ], dtype=[
            ('user', np.int32),
            ('item', np.int32),
            ('value', np.float32),
        ]))
        pipeline = mock.MagicMock()
        pipeline.storage = MemoryStorage()
        block = ImplicitNegativeTargetBlock(input_frame=ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS)

        # train mode
        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS] = installs
        # Pre-store the user/item counts under the block's own keys
        # (presumably read back by the block in apply -- confirm against the
        # block implementation).
        pipeline.storage.store(block.key_for(context, 'n_items'), np.array([5]))
        pipeline.storage.store(block.key_for(context, 'n_users'), np.array([5]))
        result_context = block.apply(context, train=True)
        self.assert_same_context(result_context, context)
        # Exactly two frames are expected: the input frame and the target.
        self.assertEqual(len(result_context.data), 2)
        target = result_context.data[ids.FRAME_KEY_TARGET]
        self.assertIsNotNone(target)
        self.assertIn(ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS, result_context.data)

        # check fields of target frame
        self.check_target_frame(target, result_context)

        # check implicit negatives: as many 0.0 rows as 1.0 rows, so the
        # target is twice the size of the installs frame
        self.assertEqual(target.shape[0], installs.shape[0] * 2)
        positive_target = target[target['value'] == 1.0]
        negative_target = target[target['value'] == 0.0]
        self.assertEqual(positive_target.shape[0], negative_target.shape[0])

        # predict mode
        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS] = installs
        pipeline.storage.store(block.key_for(context, 'n_items'), np.array([5]))
        pipeline.storage.store(block.key_for(context, 'n_users'), np.array([5]))
        result_context = block.apply(context, train=False)
        self.assert_same_context(result_context, context)
        self.assertIn(ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS, result_context.data)
        # Only the input frame remains: no target frame in predict mode.
        self.assertEqual(len(result_context.data), 1)


class ConversionTargetBlockTestCase(BlockTestCase):
    """Tests ``ConversionTargetBlock`` with and without deduplication."""

    def test_conversion_target_block_deduplicated(self):
        """Run the scenario with ``deduplicated=True``."""
        self._test_conversion_target_block(deduplicated=True)

    def test_conversion_target_block(self):
        """Run the scenario with ``deduplicated=False``."""
        self._test_conversion_target_block(deduplicated=False)

    def _test_conversion_target_block(self, deduplicated):
        """Apply the block in train and predict mode and check the context.

        In train mode the test expects a target frame to be added next to the
        input frame. When ``deduplicated`` is true it additionally expects one
        target row per (user, item) pair, with value 1.0 exactly for the pairs
        that have an event with value 1 in the input. When ``deduplicated`` is
        false only the generic target-frame checks are run.
        In predict mode the test expects the context to contain only the
        input frame.

        :param deduplicated: forwarded to ``ConversionTargetBlock`` and used
            to decide whether the deduplication assertions are made.
        """
        dtype = [
            ('user', np.int32),
            ('item', np.int32),
            ('value', np.float32),
            ('timestamp', np.int32),
        ]
        # Two events per (user, item) pair; for users 1, 3 and 4 the later
        # event has value 1, for users 0 and 2 both events have value 0.
        events = DataFrame.from_structarray(np.array([
            (0, 0, 0, 1489676105),
            (0, 0, 0, 1489676110),
            (1, 1, 0, 1489676106),
            (1, 1, 1, 1489676111),
            (2, 2, 0, 1489676107),
            (2, 2, 0, 1489676112),
            (3, 3, 0, 1489676108),
            (3, 3, 1, 1489676113),
            (4, 4, 0, 1489676109),
            (4, 4, 1, 1489676114),
        ], dtype=dtype))
        pipeline = mock.MagicMock()
        pipeline.storage = MemoryStorage()
        block = ConversionTargetBlock(input_frame=ids.FRAME_KEY_CONVERSIONS_PROMO, deduplicated=deduplicated)

        # train mode
        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_CONVERSIONS_PROMO] = events
        # Pre-store the user/item counts under the block's own keys
        # (presumably read back by the block in apply -- confirm against the
        # block implementation).
        pipeline.storage.store(block.key_for(context, 'n_items'), np.array([5]))
        pipeline.storage.store(block.key_for(context, 'n_users'), np.array([5]))
        result_context = block.apply(context, train=True)

        self.assert_same_context(result_context, context)
        # Exactly two frames are expected: the input frame and the target.
        self.assertEqual(len(result_context.data), 2)
        target = result_context.data[ids.FRAME_KEY_TARGET]
        self.assertIsNotNone(target)
        self.assertIn(ids.FRAME_KEY_CONVERSIONS_PROMO, result_context.data)

        # check fields of target frame
        self.check_target_frame(target, result_context)

        # check deduplication
        if deduplicated:
            self.assertListEqual(target['user'].to_list(), [0, 1, 2, 3, 4])
            self.assertListEqual(target['item'].to_list(), [0, 1, 2, 3, 4])
            self.assertListEqual(target['value'].to_list(), [0.0, 1.0, 0.0, 1.0, 1.0])

        # predict mode
        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_CONVERSIONS_PROMO] = events
        pipeline.storage.store(block.key_for(context, 'n_items'), np.array([5]))
        pipeline.storage.store(block.key_for(context, 'n_users'), np.array([5]))
        result_context = block.apply(context, train=False)
        self.assert_same_context(result_context, context)
        self.assertIn(ids.FRAME_KEY_CONVERSIONS_PROMO, result_context.data)
        # Only the input frame remains: no target frame in predict mode.
        self.assertEqual(len(result_context.data), 1)
