from collections import OrderedDict
from itertools import chain

import numpy as np

from jafar import estimators
from jafar.pipelines import Pipeline, ids
from jafar.pipelines.blocks import (
    basket, classifier, data, discount, estimator, features, filtering,
    mapping, merge, selection, target, transformer,
    blend)
from jafar.pipelines.predefined import shortcuts


class SuperPipelineFactory(object):
    """
    Base pipeline factory for recommenders with different selection blocks.
    Only base pipeline is trained, subclasses are used at runtime only.

    A factory instance is callable: ``factory(pipeline_config, storage, top_n)``
    collects ``(name, block)`` pairs from the ``get_*_blocks`` generators and
    wraps them into a ``Pipeline``. Subclasses customise the pipeline by
    overriding those generators and supply candidate selection through
    ``_selection_blocks``.
    """
    name = 'super'

    # Structured dtype of the item features; its field names are exposed as a
    # plain list through the ``gp_features`` property.
    gp_feature_dtype = np.dtype([
        ('category', 'S40'),
        ('publisher', 'U150'),
        ('title', 'U150'),
    ])

    # User feature columns, grouped by the imputation estimator that uses them.
    age_features = [
        'crypta_age_0_17',
        'crypta_age_18_24',
        'crypta_age_25_34',
        'crypta_age_35_44',
        'crypta_age_45_99',
    ]
    age_estimator_params = {
        'feature_columns': age_features,
        'threshold': 0.7
    }
    locality_features = [
        'lbs_region_city'
    ]

    gender_features = [
        'crypta_gender_male'
    ]
    gender_estimator_params = {
        'feature_columns': gender_features,
        'threshold': 0.9
    }

    loyalty_features = [
        'crypta_loyalty'
    ]
    loyalty_estimator_params = {
        'feature_columns': loyalty_features,
        'threshold': 0.8
    }

    # Columns used for user filtering and as the base of the classifier
    # feature list (see ``get_classifier_features``).
    crypta_user_features = age_features + gender_features + loyalty_features
    collaborative_features = ['popular', 'ii', 'ii_overall', 'popular_overall', 'locality']

    # Frame holding installs; most blocks below take it as their input frame.
    data_frame = ids.FRAME_KEY_ADVISOR_MONGO_INSTALLS
    # Frames requested from the offline reader.
    dataset_frames = [ids.FRAME_KEY_LOCALITY, data_frame, ids.FRAME_KEY_USER_FEATURES, ids.FRAME_KEY_ITEM_FEATURES]

    def __call__(self, pipeline_config, storage, top_n, **kwargs):
        """
        Build a ``Pipeline`` out of the blocks produced by ``get_blocks``.

        :param pipeline_config: object exposing ``online`` and
            ``recommendation_mode`` attributes, read by the block generators
        :param storage: forwarded to ``Pipeline`` as ``storage``
        :param top_n: kept on the instance for selection blocks that need it
        :param kwargs: forwarded to ``Pipeline`` untouched
        :return: ``Pipeline`` named after ``self.name``
        """
        # The generators read self.config / self.top_n, so the factory keeps
        # the arguments of its most recent call as instance state.
        self.config = pipeline_config
        self.top_n = top_n
        blocks = list(self.get_blocks())
        return Pipeline(blocks, name=self.name, storage=storage, **kwargs)

    @property
    def gp_features(self):
        """Item feature names, taken from the fields of ``gp_feature_dtype``."""
        return list(self.gp_feature_dtype.names)

    def get_pre_mapping_blocks(self):
        """
        Yield the data-loading blocks that precede the mapping block.

        Online mode reads user data and usage stats; offline mode reads the
        dataset frames and then filters users and intersecting items.
        """
        if self.config.online:
            yield 'read_user_data', data.OnlineReadDataBlock(installs_frame=self.data_frame)
            yield 'read_usage_stats', data.OnlineUsageStatsData(counters=['rec_view_count'])
        else:
            yield 'read_data', data.OfflineReadDataBlock(
                output_data=self.dataset_frames,
                item_features=self.gp_features
            )
            yield 'sample_users', filtering.UserFeaturesFilteringBlock(input_frame=self.data_frame,
                                                                       user_features=self.crypta_user_features)
            yield 'leave_intersection_item_features', data.DatarameFilteringBlock(
                input_data1_key=self.data_frame,
                input_data2_key=ids.FRAME_KEY_ITEM_FEATURES,
                field='item'
            )
            yield 'leave_intersection_locality', data.DatarameFilteringBlock(
                input_data1_key=ids.FRAME_KEY_LOCALITY,
                input_data2_key=ids.FRAME_KEY_ITEM_FEATURES,
                field='item',
                filter_data2=False
            )

    def get_blocks(self):
        """
        Yield every top-level ``(name, block)`` pair of the pipeline, in order:
        pre-mapping blocks, the mapping block (which nests the blocks from
        ``get_mapped_blocks``), after-mapping blocks and, online only, the
        banned-items filter.
        """
        for block in self.get_pre_mapping_blocks():
            yield block
        yield 'mapping', mapping.MappingBlock(frames=list(self.get_mapping_frames()),
                                              nested_blocks=list(self.get_mapped_blocks()),
                                              item_columns=('item', 'similar_to'))
        for block in self.get_after_mapping_blocks():
            yield block
        if self.config.online:
            yield 'filter_banned', filtering.BannedItemsFilteringBlock(blacklist_component='recommendations')

    def get_mapping_frames(self):
        """Return the set of frame keys handed to the mapping block."""
        return {self.data_frame,
                ids.FRAME_KEY_TARGET,
                ids.FRAME_KEY_USAGE_STATS,
                ids.FRAME_KEY_PREDICTIONS,
                ids.FRAME_KEY_USER_FEATURES,
                ids.FRAME_KEY_ITEM_FEATURES,
                ids.FRAME_KEY_LOCALITY}

    def get_main_blocks(self):
        """
        Yield estimator, selection, feature and classifier blocks (in that
        order); online mode additionally appends publisher/title features to
        the predictions frame.
        """
        for block in chain(self.get_estimator_blocks(),
                           self.get_selection_blocks(),
                           self.get_features_blocks(),
                           self.get_classifier_blocks()):
            yield block
        if self.config.online:
            yield 'feature_publisher_title', features.ItemFeaturesBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                                        feature_names=['publisher', 'title'])

    def get_mapped_blocks(self):
        """Yield the blocks nested inside the mapping block: preparation blocks first, then the main blocks."""
        # Main blocks are materialised before the preparation blocks are
        # created; the yield order below is preparation first.
        main_blocks = list(self.get_main_blocks())
        # NOTE: since this pipeline includes binary classifier, negative target values are required.
        # hence, ImplicitNegativeTargetBlock
        yield 'install_target_with_implicit_negatives', target.ImplicitNegativeTargetBlock(input_frame=self.data_frame)
        yield 'store_item_features', features.ItemsStoreBlock(feature_names=self.gp_features)
        # NOTE(review): 'feature_category' is also the name of a block yielded by
        # get_item_features_blocks, so the nested block list contains this name
        # twice - confirm the pipeline tolerates duplicate block names.
        yield 'feature_category', features.ItemFeaturesBlock(input_frame=self.data_frame, feature_names=['category'])
        yield 'set_basket', basket.SimpleBasketBlock(input_frame=self.data_frame)
        yield 'fit_category_label_encoder', transformer.FitLabelEncoderBlock(feature='category',
                                                                             ensure_default_label=True,
                                                                             input_frame=ids.FRAME_KEY_BASKET)
        yield 'detect_user_categories', features.DetectUserCategoriesBlock()

        for block in main_blocks:
            yield block

    def get_estimator_blocks(self):
        """
        Yield estimator blocks: the collaborative/locality estimator blocks
        are produced in offline mode only, the crypta imputation blocks always.
        """
        if not self.config.online:
            yield 'fit_popular', estimator.EstimatorBlock(input_frame=self.data_frame,
                                                          estimator_class=estimators.CategoryAwarePopular)
            yield 'fit_itemitem', estimator.EstimatorBlock(input_frame=self.data_frame,
                                                           estimator_class=estimators.CategoryAwareItemItem)
            yield 'fit_popular_overall', estimator.EstimatorBlock(input_frame=self.data_frame,
                                                                  estimator_class=estimators.Popular)
            yield 'fit_itemitem_overall', estimator.EstimatorBlock(input_frame=self.data_frame,
                                                                   estimator_class=estimators.ItemItem)
            yield 'fit_locality', estimator.EstimatorBlock(input_frame=ids.FRAME_KEY_LOCALITY,
                                                           estimator_class=estimators.LocalityEstimator)
        for block in self.get_crypta_estimator_blocks():
            yield block

    def get_crypta_estimator_blocks(self):
        """Yield one feature-imputation estimator block per crypta feature group (gender, age, loyalty)."""
        imputation_groups = (
            ('gender', self.gender_estimator_params),
            ('age', self.age_estimator_params),
            ('loyalty', self.loyalty_estimator_params),
        )
        for key_prefix, estimator_params in imputation_groups:
            yield 'impute_%s' % key_prefix, estimator.FeatureImputationEstimatorBlock(
                input_frame=self.data_frame,
                key_prefix=key_prefix,
                estimator_class=estimators.FeatureImputationCV,
                estimator_params=estimator_params
            )

    def _selection_blocks(self):
        """
        Yield candidate-selection ``(name, block)`` pairs for "generate" mode.

        The base factory has no selection strategy of its own; concrete
        factories override this hook.

        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError(
            '%s does not define _selection_blocks(), which is required when '
            'recommendation_mode is "generate"' % type(self).__name__
        )

    def get_selection_blocks(self):
        """
        Yield selection blocks: in "generate" mode the subclass-provided
        selection followed by basket filtering, otherwise a dummy selection.
        """
        if self.config.recommendation_mode == 'generate':
            for block in self._selection_blocks():
                yield block
            yield 'filter_by_basket', filtering.InstallsFilteringBlock()
        else:
            yield 'skip_selection', selection.DummySelectionBlock()

    def get_features_blocks(self):
        """
        Yield feature blocks: collaborative features (item-item and popular,
        category-aware and overall), crypta user features, item features and
        finally the locality feature.
        """
        yield 'feature_itemitem', shortcuts.get_blending_itemitem_block(self.data_frame, fit_estimator=False,
                                                                        feature_name='ii',
                                                                        estimator_class=estimators.CategoryAwareItemItem)
        yield 'feature_popular', features.EstimatorFeatureBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                                estimator_class=estimators.CategoryAwarePopular,
                                                                feature_name='popular')
        yield 'feature_itemitem_overall', shortcuts.get_blending_itemitem_block(self.data_frame,
                                                                                fit_estimator=False,
                                                                                feature_name='ii_overall',
                                                                                estimator_class=estimators.ItemItem)
        yield 'feature_popular_overall', features.EstimatorFeatureBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                                        estimator_class=estimators.Popular,
                                                                        feature_name='popular_overall')
        for block in chain(self.get_crypta_features_blocks(), self.get_item_features_blocks()):
            yield block

        yield 'feature_locality', features.EstimatorFeatureBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                                 estimator_class=estimators.LocalityEstimator,
                                                                 feature_name='locality')

    def get_crypta_features_blocks(self):
        """
        Yield the block joining user features into the predictions frame,
        followed by importance/product blocks for age, loyalty and gender.
        """
        yield 'merge_user_features', merge.LeftHashJoinBlock(left_frame=ids.FRAME_KEY_PREDICTIONS,
                                                             right_frame=ids.FRAME_KEY_USER_FEATURES,
                                                             result_frame=ids.FRAME_KEY_PREDICTIONS,
                                                             join_columns=['user'],
                                                             how='left',
                                                             right_columns=self.user_features,
                                                             overwrite=True)
        for block in chain(self.get_importance_feature_blocks(self.age_features, self.age_estimator_params, 'age'),
                           self.get_importance_feature_blocks(self.loyalty_features, self.loyalty_estimator_params,
                                                              'loyalty'),
                           self.get_importance_feature_blocks(self.gender_features, self.gender_estimator_params,
                                                              'gender')):
            yield block

    @property
    def user_features(self):
        """User feature columns joined into predictions: age, loyalty, gender and locality."""
        return self.age_features + self.loyalty_features + self.gender_features + self.locality_features

    def get_item_features_blocks(self):
        """Yield blocks adding the item category to predictions and label-encoding it."""
        yield 'feature_category', features.ItemFeaturesBlock(input_frame=ids.FRAME_KEY_PREDICTIONS,
                                                             feature_names=['category'])
        yield 'label_encode_category', transformer.TransformAppendLabelEncoderBlock(features=['category'],
                                                                                    output_prefix='category')

    @staticmethod
    def get_importance_feature_blocks(user_features, estimator_params, key_prefix):
        """
        Yield, for every user feature, an ``importance_<feature>`` block
        producing ``item_<feature>`` and a ``product_<feature>`` block
        combining the user feature with that item feature.

        :param user_features: iterable of user feature names
        :param estimator_params: params passed to the imputation estimator
        :param key_prefix: key prefix passed to the importance block
        """
        for feature in user_features:
            item_feature = 'item_%s' % feature
            yield 'importance_%s' % feature, features.ImputeItemImportanceBlock(
                item_feature_name=item_feature,
                user_feature_name=feature,
                key_prefix=key_prefix,
                estimator_class=estimators.FeatureImputationCV,
                estimator_params=estimator_params
            )
            yield 'product_%s' % feature, features.FeatureProductBlock(input_features=[feature, item_feature],
                                                                       resulting_feature='product_%s' % feature)

    def get_classifier_blocks(self):
        """Yield the CatBoost classifier block fed with ``get_classifier_features``."""
        # classifier kwargs are used only in tests and local training (which is not default)
        yield 'classifier', classifier.CatBoostClassifierBlock(features=self.get_classifier_features(),
                                                               categorical_features=['category_0'],
                                                               classifier_kwargs={'iterations': 10})

    def get_after_mapping_blocks(self):
        """Yield blocks that follow the mapping block; only the impression discount, in online mode."""
        if self.config.online:
            yield 'impression discount', discount.ImpressionDiscountBlock()

    def get_classifier_features(self):
        """
        Return the classifier feature list: collaborative features, crypta
        user features, their ``item_``/``product_`` counterparts and the
        encoded category column.
        """
        crypta_item_features = ['item_%s' % feature for feature in self.crypta_user_features]
        crypta_item_user_features = ['product_%s' % feature for feature in self.crypta_user_features]
        return (
            self.collaborative_features +
            self.crypta_user_features +
            crypta_item_features +
            crypta_item_user_features +
            ['category_0']
        )


class SuperKanoPipelineFactory(SuperPipelineFactory):
    """Factory whose candidates come from category-aware item-item and popular estimators."""

    def _selection_blocks(self):
        """Yield one category-aware selection block per candidate estimator."""
        candidate_sources = (
            ('select_candidates_ii', estimators.CategoryAwareItemItem),
            ('select_candidates_pop', estimators.CategoryAwarePopular),
        )
        for block_name, estimator_class in candidate_sources:
            selector = selection.CategoryAwareEstimatorSelectionBlock(
                top_n=self.top_n, estimator_class=estimator_class
            )
            yield block_name, selector


class SuperSonyaPipelineFactory(SuperPipelineFactory):
    """
    Factory that selects basket neighbours with the category-aware item-item
    estimator and ranks by the single 'ii' feature through a dummy classifier.
    """

    def _selection_blocks(self):
        """Yield the basket-neighbour selection block."""
        neighbour_selection = selection.BasketNeighborSelectionBlock(
            estimator_class=estimators.CategoryAwareItemItem
        )
        yield 'select_basket_ii', neighbour_selection

    def get_features_blocks(self):
        """Yield only the item-item estimator feature, stored as 'ii'."""
        ii_feature = features.EstimatorFeatureBlock(
            input_frame=ids.FRAME_KEY_PREDICTIONS,
            estimator_class=estimators.CategoryAwareItemItem,
            feature_name='ii',
        )
        yield 'feature_item', ii_feature

    def get_classifier_blocks(self):
        """Yield a dummy classifier block configured with the 'ii' feature."""
        passthrough = classifier.DummyClassifierBlock(feature='ii')
        yield 'dummy_classifier', passthrough


class SuperLocalPipelineFactory(SuperPipelineFactory):
    """Factory that selects candidates with the locality estimator."""

    def _selection_blocks(self):
        """
        Yield a block left-joining 'lbs_region_city' from user features into
        the target frame, then the locality-based selection block.
        """
        target_frame = ids.FRAME_KEY_TARGET
        region_join = merge.LeftHashJoinBlock(
            left_frame=target_frame,
            right_frame=ids.FRAME_KEY_USER_FEATURES,
            result_frame=target_frame,
            join_columns=['user'],
            right_columns=['lbs_region_city'],
            how='left',
        )
        yield 'join_user_region', region_join

        locality_selection = selection.EstimatorSelectionBlock(
            top_n=self.top_n, estimator_class=estimators.LocalityEstimator
        )
        yield 'select_locality', locality_selection



class TrendingKanoPipelineFactory(SuperKanoPipelineFactory):
    """
    Kano factory variant that, in online mode, adds a trending selection block
    before the mapping block and applies basket filtering after it.
    """

    def __init__(self, trending_by, trend_days):
        """
        :param trending_by: first argument passed to ``OnlineTrendingSelectionBlock``
        :param trend_days: second argument passed to ``OnlineTrendingSelectionBlock``
        """
        self.trend_days = trend_days
        self.trending_by = trending_by

    def get_pre_mapping_blocks(self):
        """Yield the inherited pre-mapping blocks, plus trending selection when online."""
        inherited = super(TrendingKanoPipelineFactory, self).get_pre_mapping_blocks()
        for pair in inherited:
            yield pair
        if not self.config.online:
            return
        trending = selection.OnlineTrendingSelectionBlock(self.trending_by, self.trend_days)
        yield 'select_trending', trending

    def get_selection_blocks(self):
        """Yield nothing in "generate" mode; otherwise only the dummy selection block."""
        if self.config.recommendation_mode == 'generate':
            return
        yield 'skip_selection', selection.DummySelectionBlock()

    def get_after_mapping_blocks(self):
        """Yield the inherited after-mapping blocks followed by basket filtering."""
        inherited = super(TrendingKanoPipelineFactory, self).get_after_mapping_blocks()
        for pair in inherited:
            yield pair
        basket_filter = filtering.InstallsFilteringBlock()
        yield 'filter_by_basket', basket_filter

    def get_mapping_frames(self):
        """Return the inherited mapping frames extended with the basket frame."""
        base_frames = super(TrendingKanoPipelineFactory, self).get_mapping_frames()
        return base_frames.union({ids.FRAME_KEY_BASKET})


class SmokePipelineFactory(SuperKanoPipelineFactory):
    """
    Kano factory variant that uses a blended ALS feature ('vanilla_als')
    in place of the two item-item features.
    """
    collaborative_features = ['popular', 'popular_overall', 'vanilla_als']
    name = 'smoke'

    def get_estimator_blocks(self):
        """Yield an ALS estimator block followed by the inherited estimator blocks."""
        # NOTE(review): unlike the inherited fit_* blocks, which are yielded in
        # offline mode only, this block is yielded in online mode as well -
        # confirm that is intended.
        yield 'fit_als_overall', estimator.EstimatorBlock(input_frame=self.data_frame,
                                                          estimator_class=estimators.ALS)
        for block in super(SmokePipelineFactory, self).get_estimator_blocks():
            yield block

    def get_features_blocks(self):
        """
        Yield the ALS blending feature block, then every inherited feature
        block except 'feature_itemitem' and 'feature_itemitem_overall'.

        :raises KeyError: if the parent stops yielding either item-item block
        """
        yield 'feature_als_overall', blend.BlendingFeatureBlock(
            fit_blocks=[
                ('fit_als_blend',
                 estimator.EstimatorBlock(input_frame=self.data_frame,
                                          estimator_class=estimators.ALS,
                                          estimator_params={'save_user_features': True}))
            ],
            predict_blocks=[
                ('feature_vanilla_als_blend', features.ALSVanillaFeatureBlock(
                    input_frame=ids.FRAME_KEY_PREDICTIONS,
                    estimator_class=estimators.ALS,
                    feature_name='vanilla_als',
                    estimator_params={'save_user_features': not self.config.online}
                ))
            ],
            features_names=('vanilla_als',), fit_frame=self.data_frame,
            predict_frame=ids.FRAME_KEY_PREDICTIONS,
            n_folds=10, fit_with_full_data=False,
        )
        # An ordered mapping keeps the parent's block order while allowing
        # removal by name.
        blocks = OrderedDict(super(SmokePipelineFactory, self).get_features_blocks())
        blocks.pop('feature_itemitem')
        blocks.pop('feature_itemitem_overall')

        # items() exists on both Python 2 and Python 3, whereas iteritems()
        # is Python 2 only; iteration order and yielded pairs are the same.
        for item in blocks.items():
            yield item


# Ready-to-use factory instances. Each one is callable as
# ``create_x_pipeline(pipeline_config, storage, top_n, **kwargs)`` and returns
# a Pipeline (see SuperPipelineFactory.__call__).
create_smoke_pipeline = SmokePipelineFactory()
create_kano_pipeline = SuperKanoPipelineFactory()
create_sonya_pipeline = SuperSonyaPipelineFactory()
create_local_pipeline = SuperLocalPipelineFactory()


# Trending variants differ only in their constructor arguments:
# ``trending_by`` ('installs' or 'launches') and ``trend_days`` (14 or 28).
# TODO (dmitryka): leave just one of the following pipelines
create_trending_installs_2_weeks_kano_pipeline = TrendingKanoPipelineFactory('installs', 14)
create_trending_installs_4_weeks_kano_pipeline = TrendingKanoPipelineFactory('installs', 28)
create_trending_launches_2_weeks_kano_pipeline = TrendingKanoPipelineFactory('launches', 14)
create_trending_launches_4_weeks_kano_pipeline = TrendingKanoPipelineFactory('launches', 28)

