import logging

from jafar.exceptions import RecommenderNotReady
from jafar.pipelines import ids
from jafar.pipelines.blocks.composite import CompositeBlock
from jafar.pipelines.blocks.cross_validation import CrossValidationBlock, CrossValidationEarlyExit
from jafar.pipelines.context import PipelineContext
from jafar.storages.exceptions import StorageKeyError
from jafar.storages.memory import MemoryStorage
from jafar.utils import memory_usage

logger = logging.getLogger(__name__)


class PipelineConfig(object):
    """
    Enum-like object that contains flags denoting
    pipeline's behaviour. These flags are applied
    pipeline-wise and are set at pipeline's initialization.
    """
    # the only values accepted for `recommendation_mode`
    recommendation_mode_choices = ('score', 'generate')

    def __init__(self, recommendation_mode, online, cross_validation=False):
        """
        There are currently 3 flags that can affect pipeline behaviour:

        :param recommendation_mode: can be either 'score' or 'generate':

          * 'score': pipeline will take user-item rows from target frame and
            compute a ranking score for each. Target frame must contain 'item'
            column in this mode, otherwise exception will be thrown.
          * 'generate': pipeline will take users from target frame and
            generate recommendations for each user.

        :param online: True/False (any value is coerced with `bool`).
        This mostly has to do with data loading, which is different
        in production and train environments.

        :param cross_validation: True/False, stored as passed. If False,
        it doesn't affect the pipeline. Otherwise, two things happen:

         * a special CrossValidationBlock gets inserted just after the `cv_split_at`
           block. It functions as a filter, selecting the relevant fold
           of the dataset for cross-validation.
         * "read data" blocks are replaced with caching versions of themselves,
           allowing to avoid repeated data reading operations.

        :raises AssertionError: if `recommendation_mode` is not one of
        `recommendation_mode_choices`.
        """
        # Raise explicitly instead of using an `assert` statement: asserts are
        # stripped under `python -O`, which would let an unknown mode through.
        # AssertionError is kept so callers see the same exception type as before.
        if recommendation_mode not in self.recommendation_mode_choices:
            raise AssertionError(
                "Unknown recommendation mode, choose one of {}".format(self.recommendation_mode_choices)
            )
        self.recommendation_mode = recommendation_mode
        self.online = bool(online)
        self.cross_validation = cross_validation


def iterate_all_blocks(blocks):
    """
    Go through all blocks, including nested blocks of composite blocks.

    Yields (name, block) pairs in pre-order: a CompositeBlock or
    CrossValidationBlock is yielded first, then everything returned by its
    `get_blocks()`, and only then the siblings that follow it.

    :param blocks: iterable of (name, block) pairs.
    """
    # stack of partially consumed iterators; the innermost level is on top
    pending = [iter(blocks)]
    while pending:
        for name, block in pending[-1]:
            yield name, block
            if isinstance(block, (CompositeBlock, CrossValidationBlock)):
                # descend into the nested blocks before resuming this level
                pending.append(iter(block.get_blocks()))
                break
        else:
            # current level is exhausted, resume the enclosing one
            pending.pop()


def _apply_blocks(train, context, blocks):
    for name, block in blocks:
        try:
            if train:
                logger.debug('Memory usage: peak %(peak)d MB, RSS %(rss)d MB', memory_usage())
            logger.debug('Applying block "%s"', name)
            if isinstance(block, CompositeBlock):
                # split step of composite block
                result_contexts = []
                for child_context, nested_blocks in block.apply_init(context, train):
                    result_contexts.append(_apply_blocks(train, child_context, nested_blocks))
                context = block.apply_complete(result_contexts, train)
            else:
                context = block.apply(context, train)

            # if we've just applied cross-validation block, pipeline stops there
            if isinstance(block, CrossValidationBlock):
                logger.info("Early stopping due to cross-validation")
                raise CrossValidationEarlyExit(context)
        except StorageKeyError as e:
            raise RecommenderNotReady('Error applying blocks for context (key error): ' + repr(e),
                                      context=context, error=e)
    return context


class Pipeline(object):
    """
    Pipeline representation of app recommendation algorithm
    """

    def __init__(self, blocks, name='', storage=None):
        """
        :param blocks: (name, block) pairs making up the pipeline. The
        collection is iterated here and again on every run and `repr`,
        so it has to be re-iterable (e.g. a list).
        :param name: pipeline name.
        :param storage: storage object kept on `self.storage`; a new
        MemoryStorage is created when None is passed.
        """

        self.blocks = blocks
        # for easy block access: name -> block, nested blocks included
        # (if a name occurs more than once, the last occurrence wins)
        self.block_dict = dict(iterate_all_blocks(self.blocks))
        self.name = name
        # `is not None` rather than `or`: an explicitly passed storage must be
        # kept even if it evaluates as falsy (e.g. an empty container-like object)
        self.storage = storage if storage is not None else MemoryStorage()
        # pass `self` as a lazy logging argument so the textual representation
        # is only built when debug logging is actually enabled
        logger.debug('Pipeline "%s" initialized:\n%s', name, self)

    def create_initial_context(self, country=None, requested_categories=None, default_categories=None, frames=None):
        """
        create a PipelineContext bound to this pipeline

        :param frames: optional mapping whose items are copied
        into `context.data`.
        """
        context = PipelineContext(self, country, requested_categories, default_categories)
        # `items()` is available on both Python 2 and Python 3 dicts
        for key, value in (frames or {}).items():
            context.data[key] = value
        return context

    def apply_blocks(self, train, initial_context):
        """
        run the pipeline's blocks starting from `initial_context`

        If the run is cut short by CrossValidationEarlyExit, the context
        carried by that exception is returned instead.
        """
        try:
            return _apply_blocks(train, initial_context, self.blocks)
        except CrossValidationEarlyExit as e:
            return e.context

    def train(self, country, context=None):
        """
        train the pipeline on data for the specified country

        :param context: optional ready-made initial context; when None,
        one is created with `create_initial_context(country)`.
        :return: the context resulting from applying the blocks.
        """
        # `is not None` rather than `or`: keep a caller-supplied context
        # even if it happens to evaluate as falsy
        if context is not None:
            initial_context = context
        else:
            initial_context = self.create_initial_context(country)
        context = self.apply_blocks(train=True, initial_context=initial_context)
        logger.info('Completed training pipeline')
        return context

    def _predict(self, context, required_columns):
        """
        shared implementation of `predict` and `predict_top_n`: checks the
        target frame, applies the blocks in non-train mode and returns
        a copy of the predictions frame

        :raises AssertionError: if the context has no target frame.
        """
        # raised explicitly (not via `assert`) so the check survives `python -O`;
        # the exception type and message are the same as before
        if ids.FRAME_KEY_TARGET not in context.data:
            raise AssertionError("Have to put a target frame in initial context")
        context.data[ids.FRAME_KEY_TARGET].assert_has_columns(required_columns)
        context = self.apply_blocks(False, context)
        return context.data[ids.FRAME_KEY_PREDICTIONS].copy()

    def predict(self, context):
        """
        predict scores for user-items pairs in target frame
        """
        return self._predict(context, ('user', 'item'))

    def predict_top_n(self, context):
        """
        select items for the users in target frame and calculates scores for them
        """
        # NOTE: now this is almost the same as `predict` method, do we need two of them?
        # probably after feature/ADVISOR-1087 gets merged we won't
        return self._predict(context, ('user',))

    def set_params(self, params):
        """
        pass parameters to blocks

        :param params: mapping of block name -> params for that block's
        `set_params`; an unknown block name raises KeyError.
        """
        logger.debug("Setting block params: %s", params)
        # `items()` is available on both Python 2 and Python 3 dicts
        for block_name, block_params in params.items():
            self.block_dict[block_name].set_params(block_params)

    def __repr__(self):
        """
        one line per block, `("<name>", <block class name>)`, with nested
        blocks indented by four spaces per nesting level
        """

        def iterate_blocks_with_levels(blocks, level):
            for name, block in blocks:
                yield level, name, block
                if isinstance(block, (CompositeBlock, CrossValidationBlock)):
                    for element in iterate_blocks_with_levels(block.get_blocks(), level + 1):
                        yield element

        rows = [
            '{indent}("{name}", {block})'.format(
                indent='    ' * level, name=name,
                block=block.__class__.__name__
            )
            for level, name, block in iterate_blocks_with_levels(self.blocks, level=0)
        ]
        # the first row is never indented and every row ends with ")",
        # so joining with newlines needs no extra stripping
        return '\n'.join(rows)
