import logging
import numpy as np

from jafar.estimators.base import BaseEstimator
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.pipelines.pipeline import PipelineConfig
from jafar.utils.structarrays import DataFrame

logger = logging.getLogger(__name__)


class ConversionRateBlock(SingleContextBlock):
    """
    Predict the conversion rate inside CPM pipelines.

    Conceptually close to a blending block, but simpler: the data is not split into folds.

    Train mode: the 'fit blocks' are trained, then the 'predict blocks' are run in predict
    mode (this computes the score for the conversion target), and finally the
    'classifier blocks' are run in train mode.
    Predict mode: only the 'predict blocks' and then the 'classifier blocks' are run,
    both in predict mode.
    """

    def __init__(self, fit_blocks, predict_blocks, classifier_blocks):
        super(ConversionRateBlock, self).__init__()
        self.fit_blocks = fit_blocks
        self.predict_blocks = predict_blocks
        self.classifier_blocks = classifier_blocks

    @staticmethod
    def get_isolated_pipeline(blocks, context):
        """Wrap `blocks` in a standalone Pipeline sharing the current pipeline's name and storage."""
        # imported locally, presumably to avoid a circular import with jafar.pipelines.pipeline
        from jafar.pipelines.pipeline import Pipeline
        outer_name = context.pipeline.name
        return Pipeline(blocks, name=outer_name, storage=context.storage)

    def apply(self, context, train):
        def run_stage(blocks, ctx, stage_train):
            # each stage gets its own isolated pipeline built from the latest context
            stage_pipeline = self.get_isolated_pipeline(blocks, ctx)
            return stage_pipeline.apply_blocks(train=stage_train, initial_context=ctx)

        if train:
            logger.info('Running train for fit blocks')
            context = run_stage(self.fit_blocks, context, True)
            logger.info('Running predict blocks for conversions')

        # The prediction stage runs in predict mode for both train and production;
        # during training it is what produces the 'score' feature for the conversion target.
        context = run_stage(self.predict_blocks, context, False)

        # The classifier stage follows the caller's mode.
        return run_stage(self.classifier_blocks, context, train)


class ConversionsFilteringBlock(SingleContextBlock):
    """
    Remove installs that are already known as conversions, to prevent information leaking.

    In train mode, every row of the installs frame whose (user, item) pair is found in the
    conversions frame (per ``DataFrame.is_in``) is dropped. This keeps the model from
    'predicting' a score for a conversion that is already present in the installs dataset.
    In predict mode the context is returned untouched.
    """

    def __init__(self, conversions_frame, installs_frame):
        super(ConversionsFilteringBlock, self).__init__(input_data=[conversions_frame, installs_frame])
        # keys of the two frames inside context.data
        self.conversions_frame = conversions_frame
        self.installs_frame = installs_frame

    def apply(self, context, train):
        if not train:
            return context

        frames = context.data
        conversions = frames[self.conversions_frame]
        installs = frames[self.installs_frame]

        logger.info('Installs count before filtering: {}'.format(len(installs)))
        already_converted = installs[['user', 'item']].is_in(conversions[['user', 'item']])
        installs = installs[~already_converted]
        logger.info('Filtered installs count: {}'.format(len(installs)))

        frames[self.installs_frame] = installs
        return context


class ClickToConversionEstimator(BaseEstimator):
    """
    Learn and predict the install/click conversion rate of items.

    This is an offline-only estimator: online serving should use fresh statistics from
    mongo, not the values stored in the snapshot by `fit`.
    """

    basket_required = False

    def fit(self, X):
        """
        Compute a per-item install/click ratio and store it under the 'click_to_install' key.

        Rows with ``value == 1`` are counted as installs of their item, and rows with
        ``click == 1`` as clicks. Items whose ratio is not finite (no clicks) get the global
        ratio ``#installs / #clicks`` instead. If there are no clicks at all, they get 0.0.

        :param X: frame with at least 'item', 'click' and 'value' columns.
        :return: self
        """
        X.assert_has_columns(['item', 'click', 'value'])

        # calculate install/click ratio for each app
        installs = np.array(X[X['value'] == 1]['item'], dtype=np.int32)
        installs_counts = np.bincount(installs, minlength=self.n_items).astype(np.float32)
        clicks = np.array(X[X['click'] == 1]['item'], dtype=np.int32)
        clicks_counts = np.bincount(clicks, minlength=self.n_items).astype(np.float32)

        # Items without clicks yield x/0 (inf) or 0/0 (nan). Both are imputed below, so the
        # corresponding numpy RuntimeWarnings are expected and silenced here.
        with np.errstate(divide='ignore', invalid='ignore'):
            click_to_install = installs_counts / clicks_counts

        # there are no clicks for most of the items - replace them with the average conversion
        if len(clicks) > 0:
            mean_conversion = float(len(installs)) / len(clicks)
        else:
            # Without any clicks the global ratio is undefined; fall back to 0.0 instead of
            # failing with ZeroDivisionError.
            logger.warning("No clicks found in training data; using 0.0 as mean 'click to install'")
            mean_conversion = 0.0

        infinite_mask = ~np.isfinite(click_to_install)
        logger.info("Imputing {} missing 'click to install' values with mean: {}".format(
            np.count_nonzero(infinite_mask), mean_conversion))
        click_to_install[infinite_mask] = mean_conversion
        self.storage.store(self.key_for('click_to_install'), click_to_install)
        return self

    def predict(self, X, basket=None):
        """
        Score every (user, item) row of X with the stored conversion rate of its item.

        :param X: frame with at least 'user' and 'item' columns.
        :param basket: not used by this estimator.
        :return: DataFrame with 'user', 'item' and float32 'value' columns.
        """
        X.assert_has_columns(('user', 'item'))
        click_to_install = self.storage.get_proxy(self.key_for('click_to_install'))
        return DataFrame.from_dict(dict(
            user=X['user'],
            item=X['item'],
            value=click_to_install[X['item'].astype(np.int32)].astype(np.float32)
        ))


class PretrainedPipelineScorerBlock(SingleContextBlock):
    """
    Score items with a pipeline that was already trained.

    The pre-trained pipeline's 'value' output is copied into `predictions_column` of the
    'predictions' frame.

    :param pipeline_creator: callable ``(config, storage, top_n=...)`` returning the
        pre-trained pipeline.
    :param predictions_column: name of the column added to the predictions frame.
    :param online: forwarded to ``PipelineConfig(online=...)``.
    :param top_n: forwarded to ``pipeline_creator``.
    """

    def __init__(self, pipeline_creator, predictions_column, online, top_n):
        super(PretrainedPipelineScorerBlock, self).__init__()
        self.top_n = top_n
        self.pipeline_creator = pipeline_creator
        self.predictions_column = predictions_column
        self.online = online

    def apply(self, context, train):
        assert self.predictions_column not in context.data[ids.FRAME_KEY_TARGET].columns, \
            "Target already contains '{}'".format(self.predictions_column)

        # ask pre-trained pipeline to score items in 'target' frame
        # note: the pipeline will copy it to 'predictions' frame because this is how pipelines work
        outer = context.pipeline
        config = PipelineConfig(recommendation_mode='score', online=self.online)
        pretrained_pipeline = self.pipeline_creator(config, context.storage, top_n=self.top_n)

        # Reuse the outer pipeline's name (presumably so that names/storage keys derived
        # from it match), then temporarily swap the pre-trained pipeline into the context.
        pretrained_pipeline.name = outer.name
        context.pipeline = pretrained_pipeline
        try:
            predictions = pretrained_pipeline.predict(context)
        finally:
            # Always restore the outer pipeline, even if predict() raises, so the shared
            # context is not left pointing at the pre-trained pipeline.
            context.pipeline = outer

        # copy 'value' calculated by the pre-trained pipeline into its own column
        predictions = predictions.append_column(predictions['value'], self.predictions_column)
        if train:
            # restore values from 'target' frame since it was overwritten by pre-trained pipeline
            predictions['value'] = context.data[ids.FRAME_KEY_TARGET]['value']

        context.data[ids.FRAME_KEY_PREDICTIONS] = predictions
        return context
