import logging

import numpy as np
from numpy.lib.recfunctions import unstructured_to_structured
from scipy.sparse import issparse
from sklearn.preprocessing import StandardScaler, FunctionTransformer, OneHotEncoder, LabelEncoder
from sklearn.utils.validation import column_or_1d

from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.pipelines.misc import extract_feature_names
from jafar.storages import make_key
from jafar.utils.structarrays import get_null_index, get_null_value

# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)


class BaseSklearnTransformerBlock(SingleContextBlock):
    """
    Pipeline block wrapping an sklearn Transformer class.

    `fit` trains a fresh transformer on the selected features and saves the
    listed fitted attributes to the pipeline storage; `transform` rebuilds a
    transformer from those stored attributes and applies it to the frame.
    Subclasses decide when each step runs (`apply`) and how the transformed
    values are written back (`attach_transformed_columns`).
    """

    def __init__(self, features, transformer_class, transformer_attributes, transformer_kwargs=None,
                 input_frame=ids.FRAME_KEY_PREDICTIONS, *args, **kwargs):
        """
        :param features: non-empty list of features the transformer is applied to
        :param transformer_class: sklearn's Transformer class (not an instance)
        :param transformer_attributes: names of fitted-transformer attributes that
            are saved to (and later restored from) the pipeline storage
        :param transformer_kwargs: dictionary of constructor parameters for
            `transformer_class`; None means no parameters
        :param input_frame: key of the frame in `context.data` to operate on
        """
        super(BaseSklearnTransformerBlock, self).__init__(
            *args, input_data=[input_frame], output_data=None, destroyed_data=None, **kwargs
        )
        assert features, 'features must be non-empty'
        self.features = features
        self.input_frame = input_frame
        self.transformer_class = transformer_class
        self.transformer_kwargs = transformer_kwargs or {}
        self.transformer_attributes = transformer_attributes

    @property
    def transformer_name(self):
        """Name of the wrapped transformer class; used to namespace storage keys."""
        return self.transformer_class.__name__

    def key_for(self, context, name=None):
        """Return the parent's storage key for `name`, prefixed with the transformer name."""
        namespaced = make_key(self.transformer_name, name)
        return super(BaseSklearnTransformerBlock, self).key_for(context, namespaced)

    def fit(self, context):
        """
        Fit a new transformer on the selected features of the input frame and
        store every attribute listed in `transformer_attributes`.

        :return: the same `context` (the frame itself is not modified here)
        """
        source = context.data[self.input_frame]
        feature_names = extract_feature_names(source, self.features)
        matrix = source[feature_names].to_2d_array()

        fitted = self.transformer_class(**self.transformer_kwargs)
        fitted = fitted.fit(matrix)
        for attr_name in self.transformer_attributes:
            storage_key = self.key_for(context, attr_name)
            context.storage.store(storage_key, getattr(fitted, attr_name))
        return context

    def transform(self, context):
        """
        Restore the transformer from storage, apply it to the selected features
        and write the result back through `attach_transformed_columns`.

        An empty input frame is left untouched.

        :return: the `context`, with the input frame replaced by the transformed one
        """
        source = context.data[self.input_frame]
        if len(source) == 0:
            # nothing to transform
            return context
        feature_names = extract_feature_names(source, self.features)
        matrix = source[feature_names].to_2d_array()

        # Re-create the transformer and inject the fitted state saved by `fit`.
        restored = self.transformer_class(**self.transformer_kwargs)
        for attr_name in self.transformer_attributes:
            stored_value = context.storage.get_object(self.key_for(context, attr_name))
            setattr(restored, attr_name, stored_value)

        output = restored.transform(matrix)
        updated_frame = self.attach_transformed_columns(
            context.data[self.input_frame], output, feature_names
        )
        context.data[self.input_frame] = updated_frame
        return context

    def attach_transformed_columns(self, frame, transformed, columns):
        """
        Write `transformed` values into `frame` and return the resulting frame.
        Must be provided by a subclass or mixin.

        :param frame: the input frame
        :param transformed: output of the transformer's `transform`
        :param columns: names of the features that were transformed
        """
        raise NotImplementedError

    def apply(self, context, train):
        """Run the block on `context`; must be provided by a subclass."""
        raise NotImplementedError


class FitBlock(BaseSklearnTransformerBlock):
    """
    Fit-only block: trains the transformer and stores its attributes, but
    never transforms the data. Can be useful when the transformer needs to
    be placed before cross-validation.
    """

    def apply(self, context, train):
        """Fit at train time; pass `context` through unchanged otherwise."""
        if not train:
            return context
        return self.fit(context)


class TransformBlock(BaseSklearnTransformerBlock):
    """
    Transform-only block: restores a previously fitted transformer from
    storage and applies it.

    NOTE: this is kinda "encoder block"-"decoder block" situation
    we tried to avoid since MappingBlock has been made composite,
    but making a composite TransformerBlock is hardly better.
    """

    def apply(self, context, train):
        """Apply the stored transformer; `train` is ignored."""
        transformed_context = self.transform(context)
        return transformed_context


class FitAndTransformBlock(BaseSklearnTransformerBlock):
    """
    FitBlock and TransformBlock in one: fits at train time and transforms
    at both train and predict time.
    """

    def apply(self, context, train):
        """Fit first when `train` is truthy, then always transform."""
        fitted_context = self.fit(context) if train else context
        return self.transform(fitted_context)


class DimensionPreservingMixin(object):
    """
    This version of transformer block doesn't change dimensionality
    of its input: it transforms specified `columns` into the same
    number of columns. Therefore transformed columns can have the same
    names as before.
    """

    def attach_transformed_columns(self, frame, transformed, columns):
        """
        Replace each of `columns` in `frame` with the matching column of `transformed`.

        :param frame: frame providing `replace_column(values, name)`
        :param transformed: transformer output; either 2-dimensional with one
            column per entry of `columns`, or a flat vector when there is a
            single column
        :param columns: names of the columns to overwrite, in output order
        :return: the frame returned by the last `replace_column` call
        :raises AssertionError: if the number of transformed columns differs
            from `len(columns)`
        """
        if np.ndim(transformed) == 1:
            # Some transformers (e.g. label encoders) return a flat vector for a
            # single feature; treat it as one column instead of failing on `shape[1]`.
            transformed = np.reshape(transformed, (-1, 1))

        assert transformed.shape[1] == len(columns), \
            ("Expected {} dimensions, got {} instead. Perhaps an instance of "
             "DimensionChangingMixin should be used instead?".format(
                len(columns), transformed.shape[1]))

        # since we're preserving original names, just replace the existing columns
        # NOTE: this may fail if dtype has been changed, but we'll think about it later
        for i, feature in enumerate(columns):
            frame = frame.replace_column(transformed[:, i], feature)

        return frame


class DimensionChangingMixin(object):
    """
    This version of transformer block changes the dimensionality
    of its input: it transforms k `columns` into m columns (m can be
    less than or greater than k). Column name preservation is impossible
    in this case, and DimensionChangingMixin will name the
    output columns {output_prefix}_0, {output_prefix}_1 and so on.
    """

    def __init__(self, features, output_prefix, *args, **kwargs):
        """
        :param features: list of features, forwarded to the next class in the MRO
        :param output_prefix: prefix for the names of the generated columns
        """
        self.output_prefix = output_prefix
        super(DimensionChangingMixin, self).__init__(features, *args, **kwargs)

    def attach_transformed_columns(self, frame, transformed, columns):
        """
        Append the transformed values to `frame` as columns named
        `{output_prefix}_{i}`, with `i` starting at 0.

        :param frame: frame providing `append_columns(structured_array)`
        :param transformed: transformer output: a dense 2-dimensional array, a
            scipy sparse matrix, or a flat vector (treated as a single column)
        :param columns: names of the transformed input features (not used here)
        :return: whatever `frame.append_columns` returns
        """
        if issparse(transformed):
            # `toarray` gives a plain ndarray; `todense` would give the legacy `np.matrix`
            transformed = transformed.toarray()
        elif np.ndim(transformed) == 1:
            # a flat vector (e.g. label encoder output) becomes a single column
            transformed = np.reshape(transformed, (-1, 1))

        # `range` instead of `xrange` so the block runs on both Python 2 and 3
        names = ['{}_{}'.format(self.output_prefix, i) for i in range(transformed.shape[1])]
        transformed = unstructured_to_structured(transformed, names=names)
        # NOTE: `append_columns` replaces original columns. perhaps we shouldn't do that, but since it's
        # called _transformer_ block, this should help avoiding confusion
        frame = frame.append_columns(transformed)
        return frame


class StandardScalerBlock(DimensionPreservingMixin, FitAndTransformBlock):
    """
    Fits sklearn's StandardScaler on `features` and replaces them in place
    with the scaled values. The fitted `mean_` and `scale_` are stored.
    """

    def __init__(self, features, *args, **kwargs):
        """
        :param features: list of features to scale; remaining arguments are
            forwarded to the parent constructor
        """
        stored_attributes = ['mean_', 'scale_']
        super(StandardScalerBlock, self).__init__(
            features,
            *args,
            transformer_class=StandardScaler,
            transformer_attributes=stored_attributes,
            **kwargs
        )


class LogarithmicBlock(DimensionPreservingMixin, FitAndTransformBlock):
    """
    Replaces `features` in place with log(1 + x), applied through sklearn's
    FunctionTransformer. No fitted attributes are stored.
    """

    def __init__(self, features, *args, **kwargs):
        """
        :param features: list of features to transform; remaining arguments
            are forwarded to the parent constructor
        """
        super(LogarithmicBlock, self).__init__(
            features,
            transformer_class=FunctionTransformer,
            transformer_attributes=[],
            # `np.log1p` computes log(1 + x) without the precision loss of
            # `np.log(x + 1)` for small x and, unlike a lambda, can be pickled.
            transformer_kwargs={'func': np.log1p},
            *args, **kwargs
        )


class LabelEncoderWithDefaultClass(LabelEncoder):
    """
    Similar to sklearn label encoder, but ensures that we can apply transform to default values:
    the null value of the label dtype is always one of the fitted classes.
    """

    def fit(self, y):
        """
        Fit the encoder on `y`, then add the dtype's null value to `classes_`
        if it is not already there.

        :param y: labels; a 1d array or a column vector
        :return: self
        """
        y = column_or_1d(y, warn=False)
        super(LabelEncoderWithDefaultClass, self).fit(y)
        null_value = get_null_value(y.dtype)
        if null_value not in self.classes_:
            logger.info("Adding default class '%s' to '%s'.", null_value, self.classes_)
            # sklearn's LabelEncoder looks labels up with `np.searchsorted`, which
            # needs `classes_` sorted. Appending the null value at the end breaks
            # that whenever it is not the largest label, so insert it in sorted
            # position (`np.append` promotes the dtype, `np.unique` re-sorts).
            self.classes_ = np.unique(np.append(self.classes_, null_value))
        return self


class FitLabelEncoderBlock(DimensionPreservingMixin, FitBlock):
    """
    Fits a label encoder on a single feature and stores its `classes_`,
    without transforming the data.
    """

    def __init__(self, feature, ensure_default_label=False, *args, **kwargs):
        """
        :param feature: name of the single feature to fit the encoder on
        :param ensure_default_label: when truthy, fit LabelEncoderWithDefaultClass
            instead of the plain LabelEncoder
        """
        if ensure_default_label:
            encoder_class = LabelEncoderWithDefaultClass
        else:
            encoder_class = LabelEncoder
        super(FitLabelEncoderBlock, self).__init__(
            *args,
            features=[feature],
            transformer_class=encoder_class,
            transformer_attributes=['classes_'],
            **kwargs
        )

    @property
    def transformer_name(self):
        """
        Always the plain LabelEncoder name, whichever encoder class is used, so
        storage keys do not depend on `ensure_default_label`.
        """
        return LabelEncoder.__name__


class TransformLabelEncoderBlock(DimensionPreservingMixin, TransformBlock):
    """
    Applies a previously fitted LabelEncoder (restored from its stored
    `classes_`) to `features`, replacing them in place.
    """

    def __init__(self, features, *args, **kwargs):
        """
        :param features: list of features to encode; remaining arguments are
            forwarded to the parent constructor
        """
        stored_attributes = ['classes_']
        super(TransformLabelEncoderBlock, self).__init__(
            features,
            *args,
            transformer_class=LabelEncoder,
            transformer_attributes=stored_attributes,
            **kwargs
        )

    @property
    def transformer_name(self):
        """Name of sklearn's LabelEncoder class; used to namespace storage keys."""
        return LabelEncoder.__name__


class TransformAppendLabelEncoderBlock(DimensionChangingMixin, TransformLabelEncoderBlock):
    """
    TransformLabelEncoderBlock whose output is attached by DimensionChangingMixin
    (it comes first in the MRO): encoded values go into columns named
    `{output_prefix}_{i}` rather than replacing the input columns by name.
    """


class LabelEncoderBlock(DimensionPreservingMixin, FitAndTransformBlock):
    """
    Fits sklearn's LabelEncoder on `features` at train time and replaces them
    in place with the encoded values. The fitted `classes_` are stored.
    """

    def __init__(self, features, *args, **kwargs):
        """
        :param features: list of features to encode; remaining arguments are
            forwarded to the parent constructor
        """
        stored_attributes = ['classes_']
        super(LabelEncoderBlock, self).__init__(
            features,
            *args,
            transformer_class=LabelEncoder,
            transformer_attributes=stored_attributes,
            **kwargs
        )

    @property
    def transformer_name(self):
        """Name of sklearn's LabelEncoder class; used to namespace storage keys."""
        return LabelEncoder.__name__


class OneHotEncoderBlock(DimensionChangingMixin, FitAndTransformBlock):
    """
    Fits sklearn's OneHotEncoder on `features` and attaches the encoded
    output as columns named `{output_prefix}_{i}`.
    """

    def __init__(self, features, output_prefix, *args, **kwargs):
        """
        :param features: list of features to encode
        :param output_prefix: prefix for the names of the generated columns
        """
        # fitted state of the encoder that is saved to the pipeline storage
        stored_attributes = ['feature_indices_', 'n_values_', 'active_features_']
        super(OneHotEncoderBlock, self).__init__(
            features, output_prefix,
            *args,
            transformer_class=OneHotEncoder,
            transformer_attributes=stored_attributes,
            **kwargs
        )


class FillMissingBlock(SingleContextBlock):
    """
    Overwrites the null entries of selected frame columns with fixed values.
    """

    def __init__(self, fill_values, input_frame=ids.FRAME_KEY_PREDICTIONS):
        """
        :param fill_values: a dict-like object containing {column: fill value} mapping
        :param input_frame: frame to apply transformation to
        """
        super(FillMissingBlock, self).__init__(
            input_data=[input_frame], output_data=None, destroyed_data=None
        )
        self.input_frame = input_frame
        self.fill_values = fill_values

    def apply(self, context, train):
        """
        Fill the null entries of every frame column that has a configured fill value.

        :param context: pipeline context holding the frame under `input_frame`
        :param train: not used; the same fill is applied at train and predict time
        :return: the `context`, with the filled frame stored back under `input_frame`
        """
        frame = context.data[self.input_frame]
        # Iterate over field names only: this works on Python 2 and 3 (no
        # `iteritems`) and does not depend on the shape of the field tuples.
        for column in frame.dtype.fields:
            if column in self.fill_values:
                # positions reported by `get_null_index` are overwritten in place
                idx = get_null_index(frame[column])
                frame[column][idx] = self.fill_values[column]
        context.data[self.input_frame] = frame
        return context
