import mock
import numpy as np

from jafar.pipelines import ids
from jafar.pipelines.blocks.merge import JoinBlock, LeftHashJoinBlock
from jafar.tests.unittests.pipelines.blocks import BlockTestCase


class MergeBlockTestCase(BlockTestCase):
    """Shared helpers for merge-block tests.

    Subclasses set ``block_class`` to the merge block class under test.
    """
    block_class = None

    def _get_merged_frame(self, block):
        """Apply ``block`` to a context holding the target and user-feature frames.

        The context comes from ``get_test_context`` (with a mocked pipeline) and is
        populated with ``get_target()`` under ``FRAME_KEY_TARGET`` and
        ``get_user_features()`` under ``FRAME_KEY_USER_FEATURES``.

        :param block: merge block instance to apply.
        :returns: the frame the block stored under ``FRAME_KEY_PREDICTIONS``.
        """
        pipeline = mock.MagicMock()

        context = self.get_test_context(pipeline)
        context.data[ids.FRAME_KEY_TARGET] = self.get_target()
        context.data[ids.FRAME_KEY_USER_FEATURES] = self.get_user_features()

        # merge blocks have the same behavior in train and predict
        result_context = block.apply(context, train=False)
        self.assert_same_context(result_context, context)
        # presumably the two input frames plus the one result frame -- assumes
        # get_test_context starts with empty data; confirm in BlockTestCase
        self.assertEqual(len(result_context.data), 3)

        return result_context.data[ids.FRAME_KEY_PREDICTIONS]

    def compare_frames(self, array_one, array_two):
        """Assert two frames are equal row by row, treating NaN == NaN.

        Rows and fields are compared positionally, so row order and field order
        both matter. Lengths are checked explicitly first because ``zip`` would
        otherwise silently ignore extra rows or fields on either side.

        :param array_one: actual frame (sequence of rows).
        :param array_two: expected frame (sequence of rows).
        """
        self.assertEqual(len(array_one), len(array_two),
                         'frames have %d and %d rows' % (len(array_one), len(array_two)))
        for row_index, (row_one, row_two) in enumerate(zip(array_one, array_two)):
            self.assertEqual(len(row_one), len(row_two),
                             'row %d has %d fields, expected %d'
                             % (row_index, len(row_one), len(row_two)))
            for value_one, value_two in zip(row_one, row_two):
                # NaN != NaN, so a missing value on both sides counts as a match
                if not (np.isnan(value_one) and np.isnan(value_two)):
                    self.assertEqual(value_one, value_two,
                                     'row %d: %r != %r' % (row_index, value_one, value_two))


class JoinBlockTestCase(MergeBlockTestCase):
    """Tests joining the target frame with user features on the ``user`` column.

    Judging by the expected rows below, the target fixture covers users 0-4. The
    user-feature fixture lacks user 1 and adds user 5. The fixtures are defined in
    BlockTestCase, so confirm there.
    """
    block_class = JoinBlock

    def _join(self, how, **kwargs):
        """Build ``self.block_class`` for a target/user-features join on ``user`` and apply it.

        :param how: join type passed through to the block ('left', 'inner', ...).
        :param kwargs: extra block arguments such as ``left_columns`` / ``right_columns``.
        :returns: the merged frame produced by ``_get_merged_frame``.
        """
        # no trailing comma after **kwargs: that is a SyntaxError on Python 2
        return self._get_merged_frame(self.block_class(
            left_frame=ids.FRAME_KEY_TARGET,
            right_frame=ids.FRAME_KEY_USER_FEATURES,
            result_frame=ids.FRAME_KEY_PREDICTIONS,
            join_columns=['user'],
            how=how,
            **kwargs
        ))

    def test_left_join(self):
        merged_frame = self._join('left')
        self.assertItemsEqual(merged_frame.dtype.names, ('user', 'item', 'value', 'age', 'gender'))
        self.assertEqual(len(merged_frame), 10)
        # user 1 has no features, so its feature columns are NaN
        result = np.array([(0, 0, 1.0, 10.0, 0.0),
                           (0, 4, 0.0, 10.0, 0.0),
                           (1, 1, 1.0, np.nan, np.nan),
                           (1, 3, 0.0, np.nan, np.nan),
                           (2, 2, 1.0, 24.0, 0.0),
                           (2, 0, 0.0, 24.0, 0.0),
                           (3, 3, 1.0, 35.0, 1.0),
                           (3, 1, 0.0, 35.0, 1.0),
                           (4, 4, 1.0, 68.0, 0.0),
                           (4, 2, 0.0, 68.0, 0.0)],
                          dtype=[('user', '<i4'), ('item', '<i4'), ('value', '<f4'),
                                 ('age', '<f4'), ('gender', '<f4')])
        self.compare_frames(merged_frame, result)

    def test_inner_join(self):
        merged_frame = self._join('inner')
        self.assertItemsEqual(merged_frame.dtype.names, ('user', 'item', 'value', 'age', 'gender'))
        self.assertEqual(len(merged_frame), 8)
        result = np.array([(0, 0, 1.0, 10.0, 0.0),
                           (0, 4, 0.0, 10.0, 0.0),
                           (2, 2, 1.0, 24.0, 0.0),
                           (2, 0, 0.0, 24.0, 0.0),
                           (3, 3, 1.0, 35.0, 1.0),
                           (3, 1, 0.0, 35.0, 1.0),
                           (4, 4, 1.0, 68.0, 0.0),
                           (4, 2, 0.0, 68.0, 0.0)],
                          dtype=[('user', '<i4'), ('item', '<i4'), ('value', '<f4'),
                                 ('age', '<f4'), ('gender', '<f4')])
        self.compare_frames(merged_frame, result)

    def test_right_join(self):
        merged_frame = self._join('right')
        self.assertItemsEqual(merged_frame.dtype.names, ('user', 'item', 'value', 'age', 'gender'))
        self.assertEqual(len(merged_frame), 9)
        # user 5 exists only on the right side; 'item' is float here so it can hold NaN
        result = np.array([(0, 0, 1.0, 10.0, 0.0),
                           (0, 4, 0.0, 10.0, 0.0),
                           (2, 2, 1.0, 24.0, 0.0),
                           (2, 0, 0.0, 24.0, 0.0),
                           (3, 3, 1.0, 35.0, 1.0),
                           (3, 1, 0.0, 35.0, 1.0),
                           (4, 4, 1.0, 68.0, 0.0),
                           (4, 2, 0.0, 68.0, 0.0),
                           (5, np.nan, np.nan, 29, 1)],
                          dtype=[('user', '<i4'), ('item', '<f4'), ('value', '<f4'),
                                 ('age', '<f4'), ('gender', '<f4')])
        self.compare_frames(merged_frame, result)

    def test_inner_join_filtered_right_columns(self):
        merged_frame = self._join('inner', right_columns=['age'])
        self.assertItemsEqual(merged_frame.dtype.names, ('user', 'item', 'value', 'age'))
        self.assertEqual(len(merged_frame), 8)
        result = np.array([(0, 0, 1.0, 10.0),
                           (0, 4, 0.0, 10.0),
                           (2, 2, 1.0, 24.0),
                           (2, 0, 0.0, 24.0),
                           (3, 3, 1.0, 35.0),
                           (3, 1, 0.0, 35.0),
                           (4, 4, 1.0, 68.0),
                           (4, 2, 0.0, 68.0)],
                          dtype=[('user', '<i4'), ('item', '<i4'), ('value', '<f4'),
                                 ('age', '<f4')])
        self.compare_frames(merged_frame, result)

    def test_inner_join_filtered_left_columns(self):
        merged_frame = self._join('inner', left_columns=['user', 'value'])
        self.assertItemsEqual(merged_frame.dtype.names, ('user', 'value', 'age', 'gender'))
        self.assertEqual(len(merged_frame), 8)
        result = np.array([(0, 1.0, 10.0, 0.0),
                           (0, 0.0, 10.0, 0.0),
                           (2, 1.0, 24.0, 0.0),
                           (2, 0.0, 24.0, 0.0),
                           (3, 1.0, 35.0, 1.0),
                           (3, 0.0, 35.0, 1.0),
                           (4, 1.0, 68.0, 0.0),
                           (4, 0.0, 68.0, 0.0)],
                          dtype=[('user', '<i4'), ('value', '<f4'),
                                 ('age', '<f4'), ('gender', '<f4')])
        self.compare_frames(merged_frame, result)


class LeftHashJoinBlockTestCase(JoinBlockTestCase):
    """Run the JoinBlock test suite against LeftHashJoinBlock."""
    block_class = LeftHashJoinBlock

    def test_right_join(self):
        """ Right join is not supported """
        # Only building and applying the block is wrapped. Wrapping the whole
        # inherited test would also swallow the AssertionError raised by a failing
        # result check (e.g. the row count), so a block that silently returned a
        # wrong join for how='right' would still pass.
        with self.assertRaises(AssertionError):
            self._get_merged_frame(self.block_class(
                left_frame=ids.FRAME_KEY_TARGET,
                right_frame=ids.FRAME_KEY_USER_FEATURES,
                result_frame=ids.FRAME_KEY_PREDICTIONS,
                join_columns=['user'],
                how='right',
            ))
