from collections import OrderedDict

import numpy as np
from flask import current_app as app

from jafar import estimators
from jafar.pipelines import Pipeline, ids
from jafar.pipelines.blocks import (
    data, mapping, estimator, selection, features, classifier, cross_validation, filtering, transformer, permutation
)
from jafar.pipelines.misc import RegexpFeature


class VanishPipelineFactory(object):
    """
    Callable factory that assembles the base 'vanish' pipeline.

    Calling an instance stores the pipeline config on the factory and builds a
    ``Pipeline`` from the blocks produced by ``get_blocks``. Subclasses adjust
    the result by overriding the individual ``get_*`` hooks.
    """
    name = 'vanish'
    launches = ids.FRAME_KEY_LAUNCHES
    # Name of the launches column handed to ALS as ``value_column``.
    als_value_column = 'als_value'
    als_features_prefix = 'als_embedding'
    als_features_prefix_regexp = RegexpFeature('{}.*'.format(als_features_prefix))
    user_als_features_prefix_regexp = RegexpFeature('mean_{}.*'.format(als_features_prefix))

    def __call__(self, pipeline_config, storage, top_n=None, **kwargs):
        """
        Build the pipeline for the given configuration.

        :param pipeline_config: config object; its ``online`` and ``cross_validation``
            attributes are read while the blocks are assembled
        :param storage: forwarded to ``Pipeline`` as ``storage``
        :param top_n: not used
        :param kwargs: forwarded to ``Pipeline`` unchanged
        :return: the constructed ``Pipeline``
        """
        # The get_* hooks read the config through self.config, so it has to be
        # set before the block generators are consumed.
        self.config = pipeline_config
        return Pipeline(list(self.get_blocks()), name=self.name, storage=storage, **kwargs)

    def get_blocks(self):
        """
        Yield the top-level ``(name, block)`` pairs of the pipeline.

        Order: data blocks, one mapping block wrapping the nested blocks and,
        in online mode only, banned-items filtering followed by permutation.
        """
        for entry in self.get_data_blocks():
            yield entry

        # With cross validation enabled the nested blocks are wrapped in a
        # CrossValidationBlock; otherwise they are nested in the mapping directly.
        if self.config.cross_validation:
            nested = self.get_mapped_blocks()
        else:
            nested = self.get_cv_blocks()
        yield 'mapping', mapping.MappingBlock(frames=list(self.get_mapping_frames()),
                                              nested_blocks=list(nested))

        if not self.config.online:
            return
        yield 'filter_banned', filtering.BannedItemsFilteringBlock(blacklist_component='arranger')
        yield 'permutation', permutation.PermutationBlock(input_data=ids.FRAME_KEY_PREDICTIONS,
                                                          config_frame=ids.FRAME_KEY_RANGE,
                                                          features=self.als_features_prefix_regexp)

    def get_data_blocks(self):
        """
        Yield the data-reading blocks; nothing is yielded in online mode.
        """
        if self.config.online:
            return
        yield 'read_apps_data', data.OfflineReadDataBlock(output_data=[self.launches])

    @property
    def als_params(self):
        """
        Keyword parameters for the ALS estimator.

        Iteration count and embedding size come from the Flask app config,
        the remaining values are fixed here.
        """
        flask_config = app.config
        return dict(
            n_iters=flask_config['ARRANGER_ALS_ITERATIONS'],
            n_features=flask_config['ARRANGER_ALS_EMBEDDING_SIZE'],
            value_column=self.als_value_column,
            alpha=1,
            regularization=0.01,
            save_knns=False,
        )

    def get_feature_blocks(self):
        """
        Yield the feature blocks applied to the predictions frame:
        item ALS embeddings followed by their per-user mean.
        """
        item_embedding = features.ItemALSEmbeddingFeatureBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                               estimator_class=estimators.ALS,
                                                               feature_name=self.als_features_prefix)
        yield 'feature_als', item_embedding

        user_mean_embedding = features.MeanByUserFeatureBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                              feature_names=[self.als_features_prefix_regexp])
        yield 'user_als_embedding', user_mean_embedding

    def get_mapped_blocks(self):
        """
        Yield a single cross-validation block wrapping ``get_cv_blocks``.

        Launches are both the split source and the target frame; splits are
        grouped by the 'user' column.
        """
        cv_block = cross_validation.CrossValidationBlock(nested_blocks=list(self.get_cv_blocks()),
                                                         split_data_sources=[self.launches],
                                                         scorers=self.cv_scorers,
                                                         target_frame=self.launches,
                                                         groups_column='user',
                                                         cv=0.99)
        yield 'cross_validation', cv_block

    def get_estimator_blocks(self):
        """
        Yield the blocks that prepare the ALS value column and fit ALS on launches.
        """
        # The ALS value column is computed as the launches 'value' plus one.
        value_block = features.CustomFeatureBlock(input_frame=self.launches,
                                                  feature_name=self.als_value_column,
                                                  feature_function=lambda x: x['value'] + 1)
        yield 'als_value_plus_one', value_block

        fit_block = estimator.EstimatorBlock(input_frame=self.launches,
                                             estimator_class=estimators.ALS,
                                             estimator_params=self.als_params)
        yield 'fit_als_overall', fit_block

    def get_cv_blocks(self):
        """
        Yield the blocks nested inside the mapping (or inside cross validation).

        Offline: estimator blocks, a dummy selection over launches, feature
        blocks, classifier. Online: the estimator blocks are skipped and the
        dummy selection is created without an input frame.
        """
        if self.config.online:
            yield 'skip_selection', selection.DummySelectionBlock()
        else:
            for entry in self.get_estimator_blocks():
                yield entry
            yield 'skip_selection', selection.DummySelectionBlock(input_frame=self.launches)

        for entry in self.get_feature_blocks():
            yield entry

        nn_settings = {'epochs': 5,
                       'lr': 1e-3,
                       'top_n': 3,
                       'disable_cuda': False,
                       'batch_size': 8}
        yield 'apply_arranger_model', classifier.NNClassifierBlock(
            features=self.get_classifier_features(),
            classifier_kwargs=nn_settings)

    def get_classifier_features(self):
        """
        Return the feature selectors passed to the classifier block.
        """
        item_selector = self.als_features_prefix_regexp
        user_selector = self.user_als_features_prefix_regexp
        return [item_selector, user_selector]

    def get_mapping_frames(self):
        """
        Return the set of frame keys handed to the mapping block.
        """
        return {self.launches,
                ids.FRAME_KEY_TARGET,
                ids.FRAME_KEY_PREDICTIONS}

    @property
    def cv_scorers(self):
        """
        Scorer names for cross validation, formatted as ``<metric>@<PROPER_TOP>``.
        """
        top_n = app.config['PROPER_TOP']
        return ['{}@{}'.format(metric, top_n) for metric in ('NDCGA', 'precision', 'MAP')]


class DomestosPipelineFactory(VanishPipelineFactory):
    """
    'domestos' pipeline: same as the base pipeline, except that the
    'user_als_embedding' feature block is a ``UserALSEmbeddingFeatureBlock``
    instead of the per-user mean of item embeddings.
    """
    name = 'domestos'
    als_features_prefix = 'als_embedding'
    user_als_features_prefix = 'user_embedding'
    # Built from the prefix (as the base class does for item embeddings) so the
    # selector cannot drift from the feature name given to the block below.
    user_als_features_prefix_regexp = RegexpFeature('{}.*'.format(user_als_features_prefix))

    def get_feature_blocks(self):
        """
        Yield the parent's feature blocks with 'user_als_embedding' replaced.

        The parent's pairs are loaded into an ``OrderedDict``; assigning to an
        existing key keeps its position, so the block order is unchanged.
        """
        blocks = OrderedDict(super(DomestosPipelineFactory, self).get_feature_blocks())
        blocks['user_als_embedding'] = features.UserALSEmbeddingFeatureBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                                             estimator_class=estimators.ALS,
                                                                             feature_name=self.user_als_features_prefix)
        # items() exists on both Python 2 and Python 3, unlike iteritems().
        for name, block in blocks.items():
            yield name, block


class ToiletDuckPipelineFactory(VanishPipelineFactory):
    """
    'duck' pipeline: the base pipeline extended with item category features.

    Offline it additionally reads, filters, label-encodes and stores the item
    features frame; the feature stage adds the category, a per-(user, category)
    size feature and a one-hot encoding before the parent's feature blocks.
    """
    name = 'duck'
    item_features = ['category']
    item_features_frame = ids.FRAME_KEY_ARRANGER_ITEM_FEATURES

    def get_data_blocks(self):
        """
        Yield the offline data blocks; nothing is yielded in online mode.
        """
        if self.config.online:
            return

        yield 'read_apps_data', data.OfflineReadDataBlock(
            output_data=[self.launches, self.item_features_frame],
            item_features=self.item_features
        )
        # NOTE(review): 'DatarameFilteringBlock' looks like a misspelling of
        # 'Dataframe...'; it is kept because it must match the name exported by
        # jafar.pipelines.blocks.data - confirm there before renaming.
        yield 'leave_intersection', data.DatarameFilteringBlock(
            input_data1_key=self.launches,
            input_data2_key=self.item_features_frame,
            field='item'
        )
        encoder = transformer.LabelEncoderBlock(input_frame=self.item_features_frame,
                                                features=self.item_features)
        yield 'label_encode_category', encoder
        # NOTE(review): feature_names is hard-coded here while the encoder above
        # uses self.item_features; the two coincide for this class.
        store = features.ItemsStoreBlock(input_frame=self.item_features_frame,
                                         feature_names=['category'])
        yield 'store_item_features', store

    def get_feature_blocks(self):
        """
        Yield the category feature blocks followed by the parent's feature blocks.
        """
        def category_aggregation(frame):
            # One output value per row of the frame: the size of the
            # ('user', 'category') group the row belongs to.
            # Presumably arggroupby yields (group key, row indices) pairs - verify.
            sizes = np.zeros(frame.shape[0])
            for _, row_idx in frame.arggroupby(('user', 'category')):
                sizes[row_idx] = len(row_idx)
            return sizes

        predictions = ids.FRAME_KEY_PREDICTIONS

        yield 'feature_category', features.ItemFeaturesBlock(input_frame=predictions,
                                                             feature_names=self.item_features)
        yield 'category_size', features.AggregatedFeatureBlock(input_frame=predictions,
                                                               input_features=['user'] + self.item_features,
                                                               feature_name='category_size',
                                                               aggregation_function=category_aggregation)
        yield 'normalize_category_size', transformer.StandardScalerBlock(features=['category_size'])
        yield 'onehot_category', transformer.OneHotEncoderBlock(features=self.item_features,
                                                                output_prefix='category')

        # Embedding features from the next class in the MRO come last.
        for entry in super(ToiletDuckPipelineFactory, self).get_feature_blocks():
            yield entry

    def get_mapping_frames(self):
        """
        Return the base mapping frames plus the item features frame.
        """
        return {self.launches,
                ids.FRAME_KEY_TARGET,
                ids.FRAME_KEY_PREDICTIONS,
                self.item_features_frame}

    def get_classifier_features(self):
        """
        Return the embedding selectors plus every feature matching 'category_.*'.
        """
        selected = [self.als_features_prefix_regexp,
                    self.user_als_features_prefix_regexp]
        selected.append(RegexpFeature('category_.*'))
        return selected


class MrMusclePipelineFactory(ToiletDuckPipelineFactory, DomestosPipelineFactory):
    """
    'mr_muscle' pipeline: combines the ToiletDuck and Domestos factories.

    With this base order the MRO is ToiletDuck -> Domestos -> Vanish, so the
    ``super().get_feature_blocks()`` call inside ToiletDuck resolves to the
    Domestos implementation, and ``user_als_features_prefix_regexp`` is the
    one defined on Domestos.
    """
    name = 'mr_muscle'


class GeneralPipelineFactory(VanishPipelineFactory):
    """
    'general' pipeline: replaces the ALS/classifier stage of the base pipeline
    with a ``SummarisingPopular`` estimator exposed as the 'value' feature.
    """
    name = 'general'

    def get_cv_blocks(self):
        """
        Yield the nested blocks of the pipeline.

        Offline the popularity estimator is fitted on launches before the dummy
        selection; online only the dummy selection (without an input frame) is
        emitted. In both modes the estimator feature block comes last.
        """
        popular = estimators.SummarisingPopular

        if self.config.online:
            yield 'skip_selection', selection.DummySelectionBlock()
        else:
            yield 'fit_popularity', estimator.EstimatorBlock(input_frame=self.launches,
                                                             estimator_class=popular)
            yield 'skip_selection', selection.DummySelectionBlock(input_frame=self.launches)

        value_feature = features.EstimatorFeatureBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                       estimator_class=popular,
                                                       feature_name='value')
        yield 'feature_popularity', value_feature


class ConstantPipelineFactory(VanishPipelineFactory):
    """
    'constant' pipeline built around the ``ConstantPopular`` estimator.

    Unlike the base factory it always wraps the nested blocks in cross
    validation and never adds the online filtering/permutation blocks.
    """
    name = 'constant'

    @property
    def cv_scorers(self):
        """
        Scorer names for cross validation, formatted as ``<metric>@<PROPER_TOP>``.
        """
        top_n = app.config['PROPER_TOP']
        return ['{}@{}'.format(metric, top_n) for metric in ('NDCGA_pos',)]

    def get_blocks(self):
        """
        Yield the data blocks followed by one mapping block that wraps the
        cross-validation block.
        """
        for entry in self.get_data_blocks():
            yield entry

        nested = list(self.get_mapped_blocks())
        yield 'mapping', mapping.MappingBlock(frames=list(self.get_mapping_frames()),
                                              nested_blocks=nested)

    def get_mapped_blocks(self):
        """
        Yield a single cross-validation block over launches.

        The ``cv`` callable returns exactly one split whose first index list is
        empty and whose second covers every index of the data (presumably a
        (train, test) pair - verify against CrossValidationBlock).
        """
        cv_block = cross_validation.CrossValidationBlock(nested_blocks=list(self.get_cv_blocks()),
                                                         split_data_sources=[self.launches],
                                                         scorers=self.cv_scorers,
                                                         target_frame=self.launches,
                                                         cv=lambda data: [([], list(range(len(data))))])
        yield 'cross_validation', cv_block

    def get_cv_blocks(self):
        """
        Yield the nested blocks of the pipeline.

        Offline the constant estimator is fitted on launches before the dummy
        selection; online only the dummy selection (without an input frame) is
        emitted. In both modes the estimator feature block comes last.
        """
        constant = estimators.ConstantPopular

        if self.config.online:
            yield 'skip_selection', selection.DummySelectionBlock()
        else:
            yield 'fit_popularity', estimator.EstimatorBlock(input_frame=self.launches,
                                                             estimator_class=constant)
            yield 'skip_selection', selection.DummySelectionBlock(input_frame=self.launches)

        value_feature = features.EstimatorFeatureBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                       estimator_class=constant,
                                                       feature_name='value')
        yield 'feature_popularity', value_feature


# Module-level factory instances. Each one is callable as
# ``factory(pipeline_config, storage, top_n=None, **kwargs)`` and returns a Pipeline.
# NOTE: a call stores the config on the shared instance (``self.config``), so
# the instances carry state between calls.
create_vanish_pipeline = VanishPipelineFactory()
create_domestos_pipeline = DomestosPipelineFactory()
create_toilet_duck_pipeline = ToiletDuckPipelineFactory()
create_mr_muscle_pipeline = MrMusclePipelineFactory()
create_general_pipeline = GeneralPipelineFactory()
create_constant_pipeline = ConstantPipelineFactory()
