import numpy as np

from jafar.tests import JafarTestCase
from jafar.pipelines.context import PipelineContext
from jafar.utils.structarrays import DataFrame

class BlockTestCase(JafarTestCase):
    """Base test case with small fixture frames and context assertions for block tests.

    Fixture frames are tiny structured arrays wrapped in ``DataFrame`` over
    ``n_users`` users and ``n_items`` items (ids ``0 .. n - 1``).
    """

    n_users = 5
    n_items = 5

    # Shared layout of interaction-like frames (installs, targets). Kept as an
    # immutable ``np.dtype`` so fixtures and ``check_target_frame`` agree.
    _INTERACTION_DTYPE = np.dtype([
        ('user', np.int32),
        ('item', np.int32),
        ('value', np.float32),
    ])

    def get_test_context(self, pipeline):
        """Return a ``PipelineContext`` for ``pipeline`` matching the fixture dataset."""
        # RU and SOCIAL are the only country/category in fixture dataset
        return PipelineContext(pipeline=pipeline, country='RU', requested_categories=['SOCIAL'])

    def assert_context_is_empty(self, context):
        """Assert that ``context.data`` holds no entries."""
        self.assertEqual(len(context.data), 0)

    def assert_same_context(self, context1, context2):
        """Assert that two contexts agree on country, categories and overall equality."""
        self.assertEqual(context1.country, context2.country)
        self.assertEqual(context1.requested_categories, context2.requested_categories)
        # The field checks above give clearer failure messages; the final
        # comparison is expected to subsume them (relies on the context's __eq__).
        self.assertEqual(context1, context2)

    def check_target_frame(self, target_frame, context):
        """Assert that ``target_frame`` has ``user``/``item``/``value`` fields with the expected dtypes.

        :param target_frame: structured frame to validate.
        :param context: accepted for interface compatibility; not inspected here.
        """
        for field in ('user', 'item', 'value',):
            self.assertIn(field, target_frame.dtype.names)
        self.assertEqual(target_frame['user'].dtype, np.int32)
        self.assertEqual(target_frame['item'].dtype, np.int32)
        self.assertEqual(target_frame['value'].dtype, np.float32)

    def get_installs(self):
        """Return a sample install frame: each user ``u`` installed items ``u`` and ``(u + 1) % 5``."""
        return DataFrame.from_structarray(np.array([
            (0, 0, 1),
            (0, 1, 1),
            (1, 1, 1),
            (1, 2, 1),
            (2, 2, 1),
            (2, 3, 1),
            (3, 3, 1),
            (3, 4, 1),
            (4, 4, 1),
            (4, 0, 1),
        ], dtype=self._INTERACTION_DTYPE))

    def get_target(self):
        """Return a sample target frame with one positive and one negative row per user.

        The positive row (value 1) for user ``u`` is item ``u``, which is also
        in that user's installs (the basket), so filtering out basket items can
        be tested.
        """
        return DataFrame.from_structarray(np.array([
            (0, 0, 1),
            (0, 4, 0),
            (1, 1, 1),
            (1, 3, 0),
            (2, 2, 1),
            (2, 0, 0),
            (3, 3, 1),
            (3, 1, 0),
            (4, 4, 1),
            (4, 2, 0),
        ], dtype=self._INTERACTION_DTYPE))

    def get_predictions(self):
        """Return a sample ``predictions`` frame with one ``(user, item)`` row per user."""
        return DataFrame.from_structarray(np.array([
            (0, 4),
            (1, 3),
            (2, 2),
            (3, 1),
            (4, 0),
        ], dtype=[
            ('user', np.int32),
            ('item', np.int32),
        ]))

    def get_user_features(self):
        """Return a sample ``user_features`` frame with ``age`` and ``gender`` columns."""
        # NOTE(review): user 1 has no row and user 5 is outside range(n_users);
        # presumably deliberate to exercise missing/unknown users — confirm.
        return DataFrame.from_structarray(np.array([
            (0, 10, 0),
            (2, 24, 0),
            (3, 35, 1),
            (4, 68, 0),
            (5, 29, 1),
        ], dtype=[
            ('user', np.int32),
            ('age', np.float32),
            ('gender', np.float32),
        ]))
