import logging
from collections import defaultdict
from functools import partial
from functools import reduce

import numpy as np
from sklearn import cross_validation

from jafar.cross_validation import FixedClassShuffleSplit
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.pipelines.scoring import get_scorer
from jafar.storages.memory import MemoryStorage
from jafar.utils.structarrays import DataFrame

logger = logging.getLogger(__name__)


class CrossValidationEarlyExit(Exception):
    """
    Exception that carries a `context` object along with it.

    The context is stored on the `context` attribute and is also forwarded
    to `Exception.__init__`, so that it shows up in `args`/`str()` and the
    exception can be re-created with its original argument when unpickled.
    """

    def __init__(self, context):
        # Without forwarding the argument, `args` stays empty and unpickling
        # would call the class with no arguments, failing on the missing `context`.
        super(CrossValidationEarlyExit, self).__init__(context)
        self.context = context


class CrossValidationBlock(SingleContextBlock):
    """
    Pipeline block that runs a group of nested blocks under cross-validation.

    On `apply`, the frames listed in `split_data_sources` are split into
    train/test folds, the nested blocks are trained and applied on every fold
    inside a separate pipeline, and the aggregated scores are written to
    `context.data[ids.FRAME_KEY_CV_RESULTS]`.
    """

    def __init__(self, nested_blocks, split_data_sources, scorers,
                 cv=None, target_frame=ids.FRAME_KEY_TARGET, predictions_frame=ids.FRAME_KEY_PREDICTIONS,
                 groups_column=None, statistics=(np.mean, np.std), seed=42):
        """
        :param nested_blocks: blocks handed to the inner pipeline that is run on every fold
        :param split_data_sources: keys of `context.data` frames to be split into folds;
            every key must be present in `ids.all_frame_keys`
        :param scorers: iterable of scorer names, each resolved with `get_scorer`
        :param cv: split specification accepted by `parse_cv`; when None,
            a pre-configured `FixedClassShuffleSplit` factory is used
        :param target_frame: key of the frame with true values in the test context
        :param predictions_frame: key of the frame with predictions in the test context
        :param groups_column: column name passed to `parse_cv` for group-aware splits
        :param statistics: callables applied to the list of per-fold scores;
            each one must expose `__name__` (it is used to build result keys)
        :param seed: value passed to `np.random.seed` before folds are generated
        :raises ValueError: if `split_data_sources` contains unknown frame keys
        :raises TypeError: if `cv` is not recognized by `parse_cv`
        """
        super(CrossValidationBlock, self).__init__()
        unknown_data_sources = set(split_data_sources).difference(set(ids.all_frame_keys))
        if unknown_data_sources:
            raise ValueError("Unrecognized data sources: {}, the following are available: {}".format(
                unknown_data_sources, ids.all_frame_keys)
            )

        self.nested_blocks = nested_blocks
        if cv is not None:
            self.cv = self.parse_cv(cv, groups_column)
        else:
            # default splitter: called later with the interactions frame as the only positional argument
            self.cv = partial(FixedClassShuffleSplit,
                              selected_value=1,
                              test_user_size=0.1,
                              test_item_size=5,
                              n_splits=1)
        self.split_data_sources = split_data_sources
        self.scorers = {name: get_scorer(name) for name in scorers}
        self.target_frame = target_frame
        self.predictions_frame = predictions_frame
        self.statistics = statistics
        self.seed = seed

        # list of (train, test) index dictionaries, generated lazily on first `apply`
        self.indices = None

    def get_blocks(self):
        """
        Returns the blocks that are cross-validated by this block.
        """
        return self.nested_blocks

    @staticmethod
    def parse_cv(cv, groups_column=None):
        """
        Converts a `cv` specification into a callable that takes a data frame
        and returns an iterable of (train, test) index pairs.

        * float within [0, 1]: a single shuffle split with `test_size = 1 - cv`
        * any other float or an int: K-fold split with `int(cv)` folds
        * callable: returned unchanged

        If `groups_column` is set, numeric specifications use the label-aware
        splitters fed with `data[groups_column]`.

        :raises TypeError: if `cv` is none of the above
        """
        if isinstance(cv, float) and 0 <= cv <= 1:
            if groups_column:
                return lambda data: cross_validation.LabelShuffleSplit(data[groups_column], test_size=1 - cv, n_iter=1)
            return lambda data: cross_validation.ShuffleSplit(len(data), test_size=1 - cv, n_iter=1)

        if isinstance(cv, (float, int)):
            if groups_column:
                return lambda data: cross_validation.LabelKFold(data[groups_column], n_folds=int(cv))
            return lambda data: cross_validation.KFold(len(data), n_folds=int(cv), shuffle=True)

        if callable(cv):
            return cv

        raise TypeError("Unrecognized cv value: {}".format(cv))

    def get_data_subset(self, context):
        """
        Returns a subset of `context.data` dictionary
        containing only cross-validated data frames.
        """
        return {key: context.data[key] for key in self.split_data_sources}

    def get_interactions(self, context):
        """
        Builds a single interactions frame (columns `user`, `item`, `value`)
        by folding all cross-validated frames together with `DataFrame.union1d`.
        """
        data_subset = self.get_data_subset(context)
        # for interaction-based split, we need just one cv object
        # but first we have to make an interaction dataframe
        interaction_frames = [
            frame[['user', 'item', 'value']].copy() for frame in data_subset.values()
        ]
        return reduce(DataFrame.union1d, interaction_frames)

    def filter_data_by_interactions(self, context, interactions):
        """
        For every cross-validated frame, computes which of its
        (`user`, `item`, `value`) rows are present in `interactions`.

        Returns a "frame key: result of `is_in`" dictionary.
        """
        indices = {}
        data_subset = self.get_data_subset(context)
        for frame_key, array in data_subset.items():
            indices[frame_key] = array[['user', 'item', 'value']].is_in(interactions)
            # NOTE(review): the logged row count is the length of the whole frame,
            # not of the part selected by `interactions`
            logger.debug(
                "Splitted frame %s for context %s: %s rows (interactions strategy)",
                frame_key, context, len(array)
            )
        return indices

    def generate_indices(self, context):
        """
        Splits the interactions of `context` into folds and returns a list of
        (train, test) pairs, each element being a "frame key: index" dictionary.

        Re-seeds the global NumPy random generator with `self.seed` first.
        """
        # presumably the splitters draw from the global NumPy RNG, so this makes folds reproducible
        np.random.seed(self.seed)
        indices = []
        # first get interactions frame
        interactions = self.get_interactions(context)
        # run cv object on it to split interactions into folds
        for train_idx, test_idx in self.cv(interactions):
            # convert interaction index into a bunch of indices for multiple frames
            # notice the resulting index will be a "frame: array of indices" dictionary
            train_idx = self.filter_data_by_interactions(context, interactions[train_idx])
            test_idx = self.filter_data_by_interactions(context, interactions[test_idx])
            indices.append((train_idx, test_idx))
        return indices

    def get_data_fold(self, data, idx):
        """
        Returns a shallow copy of `data` in which every cross-validated frame
        is replaced by its part selected with `idx[frame_key]`.
        Frames that are not cross-validated are shared with `data`.
        """
        data_fold = data.copy()
        for frame_key in self.split_data_sources:
            data_fold[frame_key] = data[frame_key][idx[frame_key]]
        return data_fold

    def apply_single_cv(self, context):
        """
        Runs one round of cross-validation on the context and stores the result
        as a one-row frame under `ids.FRAME_KEY_CV_RESULTS`. Returns the same context.
        """
        cv_results = self.cross_validate(
            params=None,
            country=context.country,
            requested_categories=context.requested_categories,
            data=context.data,
            storage_content=context.storage.to_dict(),
            pipeline_name=context.pipeline.name
        )
        # each value is wrapped into a list to form a single-row frame
        context.data[ids.FRAME_KEY_CV_RESULTS] = DataFrame.from_dict({
            key: [value] for key, value in cv_results.items()
        })
        return context

    def apply(self, context, train):
        """
        Generates fold indices on first use (they are cached in `self.indices`)
        and runs cross-validation on the context. The `train` flag is not used here.
        """
        if not self.indices:
            self.indices = self.generate_indices(context)

        # isolated part of main pipeline is run multiple times
        return self.apply_single_cv(context)

    # NOTE: someday in the brave new world contexts will become picklable, and we won't need to pass
    # their components (like country/requested_categories) separately in this function
    def cross_validate(self, params, country, requested_categories, data, storage_content, pipeline_name):
        """
        Performs a single round of cross-validation over the folds stored in
        `self.indices` (they have to be generated beforehand, see `apply`).

        For every fold a train pass and a test pass of the nested blocks are run,
        then every scorer compares the target and predictions frames of the test context.

        :param params: parameters set on the inner pipeline, skipped when None
        :returns: dictionary with "<scorer name>__<statistic name>" keys holding
            each statistic computed over the per-fold scores of each scorer
        """
        from jafar.pipelines.pipeline import Pipeline

        cv_pipeline = Pipeline(self.nested_blocks, name=pipeline_name,
                               storage=MemoryStorage.from_dict(storage_content))
        if params is not None:
            logger.info('Trial with params: %s', params)
            cv_pipeline.set_params(params)

        scores_dict = defaultdict(list)
        for train_idx, test_idx in self.indices:

            # train step
            train_data = self.get_data_fold(data, train_idx)
            train_context = cv_pipeline.create_initial_context(
                country=country,
                requested_categories=requested_categories,
                frames=train_data
            )
            # NOTE: pipeline's `train` method should accept context analogous to `predict`
            logger.info("Cross-validation train pass")
            cv_pipeline.apply_blocks(train=True, initial_context=train_context)

            # predict step
            test_data = self.get_data_fold(data, test_idx)
            # set basket equal to train basket
            if ids.FRAME_KEY_BASKET in train_data:
                test_data[ids.FRAME_KEY_BASKET] = train_data[ids.FRAME_KEY_BASKET]
            test_context = cv_pipeline.create_initial_context(
                country=country,
                requested_categories=requested_categories,
                frames=test_data
            )
            logger.info("Cross-validation test pass")
            test_context = cv_pipeline.apply_blocks(train=False, initial_context=test_context)

            # take two frames (target and predictions) to compare
            target = test_context.data[self.target_frame]
            predictions = test_context.data[self.predictions_frame]

            for name, scorer in self.scorers.items():
                score = scorer(target, predictions)
                scores_dict[name].append(score)
                logger.info("Scorer %s is %f", name, score)

        # aggregate per-fold scores: one entry per (scorer, statistic) pair
        result = {}
        for name in self.scorers:
            for stat in self.statistics:
                result[name + '__' + stat.__name__] = stat(scores_dict[name])

        return result
