import unittest

import numpy as np
from mock import patch

from jafar import advisor_mongo, jafar_mongo
from jafar.pipelines import ids
from jafar.pipelines.pipeline import PipelineConfig
from jafar.pipelines.predefined import predefined_pipelines
from jafar.storages.memory import MemoryStorage
from jafar.tests import JafarTestCase
from jafar.tests.fixtures.profile import fake_user_profile, user_id as fake_user_id, apps as fake_user_apps
from jafar.tests.mocks.datasets import mock_get_dataset_processor, FIXTURE_COUNTRIES
from jafar.utils.structarrays import DataFrame


@patch('jafar.pipelines.blocks.data.get_dataset_processor', mock_get_dataset_processor)
class PredefinedPipelineTestCase(JafarTestCase):
    """
    Shared scaffolding for tests of predefined pipelines.

    Concrete test cases are expected to provide two class attributes:
    `pipeline_name` (a key of `predefined_pipelines`) and
    `train_dataframe_key` (the frame key under which the data frame
    is put into the prediction context).
    """
    top_n = 200

    def tearDown(self):
        # drop `advisor_mongo.db.profile` before the parent cleanup runs
        advisor_mongo.db.profile.drop()
        super(PredefinedPipelineTestCase, self).tearDown()

    @property
    def pipeline_creator(self):
        """
        Factory registered in `predefined_pipelines` under `self.pipeline_name`.
        """
        creator = predefined_pipelines[self.pipeline_name]
        return creator

    def get_training_config(self):
        """
        Config used for training: offline, `score` recommendation mode.
        """
        settings = dict(recommendation_mode='score', online=False)
        return PipelineConfig(**settings)

    def get_prediction_config(self, online, mode):
        """
        Config used for prediction with the given `online` flag
        and recommendation `mode`.
        """
        settings = dict(recommendation_mode=mode, online=online)
        return PipelineConfig(**settings)

    def train_pipeline(self, storage):
        """
        Creates a pipeline with the training config on top of `storage`,
        trains it on the first fixture country and returns it.
        """
        training_config = self.get_training_config()
        trained = self.pipeline_creator(training_config, storage, self.top_n)
        trained.train(country=FIXTURE_COUNTRIES[0])
        return trained

    @staticmethod
    def get_frames_for_prediction(mode):
        """
        Builds a `(target_frame, data_frame)` pair for the given mode
        out of the mocked 'advisor_mongo' dataset.

        The numpy random generator is re-seeded with a fixed value,
        so the sampled target frame is the same on every call.
        """
        processor = mock_get_dataset_processor('advisor_mongo')
        df = processor.get_data(country=FIXTURE_COUNTRIES[0])
        np.random.seed(1234)

        if mode != 'score':
            # only users are needed; the whole data frame serves as basket,
            # so nothing has to be excluded from it
            users_only = DataFrame.from_dict({
                'user': np.random.choice(df['user'], 100),
            })
            return users_only, df

        # NOTE: sampling order (user, item, view_count) matters for
        # reproducibility under the fixed seed
        sampled_users = np.random.choice(df['user'], 100)
        sampled_items = np.random.choice(df['item'], 100)
        sampled_views = np.random.randint(0, 10, 100)
        target_frame = DataFrame.from_dict({
            'user': sampled_users,
            'item': sampled_items,
            'view_count': sampled_views,
        })

        # target user-item interactions must not stay in the data frame
        all_pairs = df[['user', 'item']]
        held_out = all_pairs.is_in(target_frame[['user', 'item']])
        return target_frame, df[~held_out]

    def get_prediction_context(self, pipeline, target_frame, data_frame=None):
        """
        Creates pipeline's initial context for the first fixture country.

        `data_frame` is optional: it may be absent in case of
        online pipeline config.
        """
        frames = {ids.FRAME_KEY_TARGET: target_frame}
        if data_frame is not None:
            frames[self.train_dataframe_key] = data_frame

        country = FIXTURE_COUNTRIES[0]
        return pipeline.create_initial_context(country=country, frames=frames)


@patch('jafar.pipelines.blocks.data.get_dataset_processor', mock_get_dataset_processor)
class ScoringTestMixin(object):
    """
    Includes tests for `score` pipeline mode.

    Relies on the host test case for `train_pipeline`, `pipeline_creator`,
    `get_prediction_config`, `get_frames_for_prediction`,
    `get_prediction_context` and `top_n`.
    """

    def test_offline_scoring(self):
        """
        Checks that pipeline can run in score mode.
        In `score` mode target frame should contain user-item
        interactions, and the resulting `value` column will
        contain recommender scores for each.

        Returns `(result_target_frame, prediction_frame)` taken from
        the prediction context.
        """
        storage = MemoryStorage()
        self.train_pipeline(storage)
        config = self.get_prediction_config(online=False, mode='score')
        pipeline = self.pipeline_creator(config, storage, self.top_n)

        # make a prediction
        target_frame, data_frame = self.get_frames_for_prediction(mode='score')
        context = self.get_prediction_context(pipeline, target_frame, data_frame)

        # NOTE: the predictions frame should generally be compared with the
        # original target frame, but right now the pipeline may filter some
        # (unknown or banned) items out of the target, so the resulting
        # target frame from the context is used instead.
        prediction_context = pipeline.apply_blocks(train=False, initial_context=context)
        result_target_frame, prediction_frame = (
            prediction_context.data[ids.FRAME_KEY_TARGET],
            prediction_context.data[ids.FRAME_KEY_PREDICTIONS]
        )

        # prediction frame must contain the same user and item columns
        self.assertListEqual(result_target_frame['item'].to_list(), prediction_frame['item'].to_list())
        self.assertListEqual(result_target_frame['user'].to_list(), prediction_frame['user'].to_list())
        self.assertIn('value', prediction_frame.dtype.names)
        # `assertEqual` instead of the deprecated `assertEquals` alias
        self.assertEqual(prediction_frame['value'].dtype, np.float32)
        return result_target_frame, prediction_frame

    @patch('jafar.clickhouse._execute', lambda *args: [])
    def test_online_scoring(self):
        """
        Pipelines running in online mode will try to load
        user data from mongo via OnlineReadDataBlock.

        No data frame is passed to the context here; every target row
        is assigned to the fake user created by `fake_user_profile`.

        Returns `(result_target_frame, prediction_frame)` taken from
        the prediction context.
        """
        fake_user_profile()

        storage = MemoryStorage()
        self.train_pipeline(storage)
        config = self.get_prediction_config(online=True, mode='score')
        pipeline = self.pipeline_creator(config, storage, self.top_n)

        # make a prediction
        target_frame, _ = self.get_frames_for_prediction(mode='score')
        # for online prediction, we're only interested in users from database
        target_frame['user'] = fake_user_id
        context = self.get_prediction_context(pipeline, target_frame)

        prediction_context = pipeline.apply_blocks(train=False, initial_context=context)
        result_target_frame, prediction_frame = (
            prediction_context.data[ids.FRAME_KEY_TARGET],
            prediction_context.data[ids.FRAME_KEY_PREDICTIONS]
        )

        self.assertIn('value', prediction_frame.dtype.names)
        # `assertEqual` instead of the deprecated `assertEquals` alias
        self.assertEqual(prediction_frame['value'].dtype, np.float32)
        self.assertListEqual(result_target_frame['item'].to_list(), prediction_frame['item'].to_list())
        return result_target_frame, prediction_frame


@patch('jafar.pipelines.blocks.data.get_dataset_processor', mock_get_dataset_processor)
class GenerationTestMixin(object):
    """
    Includes tests for `generate` pipeline mode.

    Relies on the host test case for `train_pipeline`, `pipeline_creator`,
    `get_prediction_config`, `get_frames_for_prediction`,
    `get_prediction_context` and `top_n`.
    """

    def test_offline_generation(self):
        """
        Checks that pipeline can run in `generate` mode.
        In this mode, target frame will only contain `user` column
        and the resulting `item` column will contain candidate recommendations.

        Returns the predictions frame produced by `predict_top_n`.
        """
        storage = MemoryStorage()
        self.train_pipeline(storage)
        config = self.get_prediction_config(online=False, mode='generate')
        pipeline = self.pipeline_creator(config, storage, self.top_n)

        # make a prediction
        target_frame, data_frame = self.get_frames_for_prediction(mode='generate')
        context = self.get_prediction_context(pipeline, target_frame, data_frame)
        predictions = pipeline.predict_top_n(context)

        # non-deprecated assertion names are used throughout
        # (`assertEquals`/`assertNotEquals` are deprecated aliases)
        self.assertNotEqual(len(predictions), 0)
        # check that predictions frame has `item` and `value` columns
        self.assertIn('value', predictions.dtype.names)
        self.assertEqual(predictions['value'].dtype, np.float32)
        self.assertIn('item', predictions.dtype.names)
        # `np.object_` instead of the `np.object` alias, which newer
        # NumPy releases no longer provide
        self.assertEqual(predictions['item'].dtype, np.object_)

        # also check that predictions don't intersect with data frame
        self.assertEqual(len(
            data_frame[['user', 'item']].intersect1d(predictions[['user', 'item']])
        ), 0)
        return predictions

    @patch('jafar.clickhouse._execute', lambda *args: [])
    def test_online_generation(self):
        """
        Checks `generate` mode with an online config: the target frame
        holds only the fake user created by `fake_user_profile`, and no
        data frame is passed to the context.

        Generated items must not intersect with the fake user's apps.
        Returns the predictions frame produced by `predict_top_n`.
        """
        fake_user_profile()

        storage = MemoryStorage()
        self.train_pipeline(storage)
        config = self.get_prediction_config(online=True, mode='generate')
        pipeline = self.pipeline_creator(config, storage, self.top_n)

        # make a prediction
        target_frame = DataFrame.from_dict({'user': [fake_user_id]})
        context = self.get_prediction_context(pipeline, target_frame)
        predictions = pipeline.predict_top_n(context)

        self.assertNotEqual(len(predictions), 0)
        self.assertIn('value', predictions.dtype.names)
        self.assertIn('item', predictions.dtype.names)
        self.assertEqual(predictions['value'].dtype, np.float32)
        self.assertEqual(predictions['item'].dtype, np.object_)
        # none of the fake user's apps may be recommended back
        self.assertEqual(len(set(fake_user_apps).intersection(set(predictions['item']))), 0)
        return predictions


class UserFeaturesMixin(object):
    """
    Adds a 'user features' frame to the prediction context.
    """

    def get_prediction_context(self, pipeline, target_frame, data_frame=None):
        """
        Extends the parent's context with user features taken from
        the mocked 'advisor_mongo' dataset for the first fixture country.
        """
        parent = super(UserFeaturesMixin, self)
        context = parent.get_prediction_context(pipeline, target_frame, data_frame)

        processor = mock_get_dataset_processor('advisor_mongo')
        user_features = processor.get_user_features(country=FIXTURE_COUNTRIES[0])
        context.data[ids.FRAME_KEY_USER_FEATURES] = user_features
        return context


class SonyaNeighborsTestCase(GenerationTestMixin, PredefinedPipelineTestCase):
    """
    Runs generation tests against the 'sonya' predefined pipeline.
    """
    train_dataframe_key = ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS
    pipeline_name = 'sonya'

    def get_training_config(self):
        """
        This pipeline is trained offline in `generate` mode.
        """
        settings = dict(recommendation_mode='generate', online=False)
        return PipelineConfig(**settings)


class KanoTestCase(
        UserFeaturesMixin,
        ScoringTestMixin,
        GenerationTestMixin,
        PredefinedPipelineTestCase):
    """
    Runs scoring and generation tests against the 'kano' predefined
    pipeline, with a user features frame added to the context.
    """
    train_dataframe_key = ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS
    pipeline_name = 'kano'


class LocalTestCase(
        UserFeaturesMixin,
        ScoringTestMixin,
        GenerationTestMixin,
        PredefinedPipelineTestCase):
    """
    Runs scoring and generation tests against the 'local' predefined
    pipeline, with a user features frame added to the context.
    """
    train_dataframe_key = ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS
    pipeline_name = 'local'

    def tearDown(self):
        # drop `jafar_mongo.db.location_stats` before the parent cleanup runs
        jafar_mongo.db.location_stats.drop()
        super(LocalTestCase, self).tearDown()

    def get_training_config(self):
        """
        This pipeline is trained offline in `generate` mode.
        """
        settings = dict(recommendation_mode='generate', online=False)
        return PipelineConfig(**settings)

    @unittest.skip("Needs geobase mock to work")
    def test_online_generation(self):
        # overrides the inherited test so that it is skipped
        pass
