import mock
import numpy as np

from jafar.pipelines.blocks.mapping import MappingBlock
from jafar.storages.memory import MemoryStorage
from jafar.tests.unittests.pipelines.blocks import BlockTestCase
from jafar.utils.structarrays import DataFrame

class SklearnTransformerBlockTestCase(BlockTestCase):
    """Tests for ``MappingBlock`` (item/user column mapping).

    NOTE(review): the class name looks copied from an sklearn transformer
    test; this case only exercises ``MappingBlock``. Consider renaming it
    (left unchanged here so existing test selectors keep working).
    """

    def test_mapping_block(self):
        """Single item column, unknown items are not routed anywhere."""
        self._test_mapping_block(missing=False, extra_item_column=False)

    def test_mapping_block_with_missing(self):
        """Single item column, unknown items routed to a separate context."""
        self._test_mapping_block(missing=True, extra_item_column=False)

    def test_mapping_block_multicolumn(self):
        """Two item columns, unknown items are not routed anywhere."""
        self._test_mapping_block(missing=False, extra_item_column=True)

    def test_mapping_block_with_missing_multicolumn(self):
        """Two item columns, unknown items routed to a separate context."""
        self._test_mapping_block(missing=True, extra_item_column=True)

    def _test_mapping_block(self, missing, extra_item_column):
        """Run ``MappingBlock`` through a train pass and then a test pass.

        Args:
            missing: when True the block is built with
                ``blocks_for_missing_items=[]`` and the test pass is expected
                to yield a second context holding exactly the unknown-item
                rows, which reappear in the final output; when False
                (``None`` is passed) the final output is expected to contain
                only the known-item rows.
            extra_item_column: when True a second item column
                (``'item_extra'``) is added and mapped alongside ``'item'``.
        """
        data = self.get_data_frame()
        data_unknown = self.get_data_frame_unknown_items()

        if extra_item_column:
            # case when there is more than one item columns to map
            item_columns = ('item', 'item_extra')
            data = self.add_extra_item_column(data, 'item_extra')
            data_unknown = self.add_extra_item_column(data_unknown, 'item_extra')
        else:
            item_columns = ('item',)
        mapped_columns = item_columns + ('user',)

        pipeline = mock.MagicMock()
        pipeline.name = 'test_mapping'
        pipeline.storage = MemoryStorage()
        context = self.get_test_context(pipeline)
        context.data['data'] = data
        block = MappingBlock(nested_blocks=[], frames=['data'], item_columns=item_columns,
                             blocks_for_missing_items=[] if missing else None)

        # run train
        contexts = self.contexts_from_apply_init(block.apply_init(context, train=True))
        self.assertEqual(len(contexts), 1)
        result_context = contexts[0]
        self.assertEqual(len(result_context.data['data']), len(data))
        self._assert_columns_are_int32(result_context.data['data'], mapped_columns)
        block.apply_complete(contexts, train=True)

        # run test (with some unknown items)
        context.data['data'] = DataFrame.concatenate([data, data_unknown])
        contexts = self.contexts_from_apply_init(block.apply_init(context, train=False))
        self.assertEqual(len(contexts), 2 if missing else 1)
        result_context = contexts[0]
        # The first context must hold only the rows whose items were seen in training.
        self.assertEqual(len(result_context.data['data']), len(data))
        self._assert_columns_are_int32(result_context.data['data'], mapped_columns)
        if missing:
            missing_context = contexts[1]
            self.assertTrue((missing_context.data['data'] == data_unknown).all())
        final_context = block.apply_complete(contexts, train=False)
        expected_data = DataFrame.concatenate([data, data_unknown]) if missing else data
        self.assertTrue((final_context.data['data'] == expected_data).all())

    def _assert_columns_are_int32(self, frame, columns):
        """Assert that every column named in ``columns`` of ``frame`` is int32."""
        for column in columns:
            self.assertEqual(frame[column].dtype, np.int32)

    @staticmethod
    def contexts_from_apply_init(result):
        """Return the first element of every item yielded by ``apply_init``.

        NOTE(review): assumes ``apply_init`` yields tuples whose first element
        is the context; confirm against ``MappingBlock.apply_init``.
        """
        return [item[0] for item in result]

    @staticmethod
    def _frame_from_rows(rows):
        """Build a ``DataFrame`` with ``user``/``item``/``value`` columns from ``rows``."""
        # Builtin ``object`` is used for the string columns: the ``np.object``
        # alias was deprecated in NumPy 1.20 and removed in NumPy 1.24.
        return DataFrame.from_structarray(np.array(rows, dtype=[
            ('user', object),
            ('item', object),
            ('value', np.int32),
        ]))

    @classmethod
    def get_data_frame(cls):
        """Training fixture: 3 users x 3 items (``item1``..``item3``)."""
        return cls._frame_from_rows([
            ('user1', 'item3', 0),
            ('user2', 'item2', 0),
            ('user3', 'item1', 0),
            ('user1', 'item1', 1),
            ('user2', 'item2', 1),
            ('user3', 'item3', 1),
        ])

    @classmethod
    def get_data_frame_unknown_items(cls):
        """Fixture whose items (``item4``..``item6``) never occur in training."""
        return cls._frame_from_rows([
            ('user1', 'item4', 0),
            ('user2', 'item5', 0),
            ('user3', 'item6', 1),
        ])

    @staticmethod
    def add_extra_item_column(array, column):
        """Append the ``'item'`` column, in reversed row order, as ``column``.

        Returns the result of ``append_column``; callers rebind to it.
        """
        return array.append_column(array['item'][::-1], column)
