import logging
import operator
import os

import numpy as np
from catboost import CatBoostClassifier, Pool
from flask import current_app as app
from sklearn.linear_model import LogisticRegression

from jafar import fast_cache
from jafar.arranger.model import ArrangerModel
from jafar.pipelines import ids
from jafar.pipelines.blocks import SingleContextBlock
from jafar.pipelines.misc import extract_feature_names

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def get_xboost_features_importance(classifier, features):
    """
    Return booster feature importances labelled with readable feature names.

    The booster (presumably an XGBoost model, given ``get_fscore``) reports
    scores under positional names like ``'f0'``, ``'f1'``; the integer after
    the leading character indexes into ``features``. The result is ordered by
    score, highest first, e.g. with ``features=['rating', 'popular', 'ii']``:
    {'f0': 22, 'f1': 33, 'f2': 11} -> [('popular', 33), ('rating', 22), ('ii', 11)]

    :param classifier: object whose ``booster()`` exposes ``get_fscore()``
    :param features: sequence of feature names, indexed by booster position
    :return: list of ``(feature_name, score)`` tuples
    """
    fscore = classifier.booster().get_fscore()
    # Ascending sort followed by a reversal (not sort(reverse=True)) so that
    # equal scores come out in reverse of their original relative order.
    ranked = sorted(fscore.items(), key=lambda pair: pair[1])[::-1]
    return [(features[int(key[1:])], value) for key, value in ranked]


class BaseClassifierBlock(SingleContextBlock):
    """
    Common base of classifier blocks.

    A classifier block works on one input frame of ``context.data`` and,
    depending on the ``train`` flag given to ``apply``, either fits a model
    (``fit``) or scores the frame (``predict``). Subclasses implement both.
    """

    def __init__(self, features, predictions_column='value', input_frame=ids.FRAME_KEY_PREDICTIONS, *args, **kwargs):
        """
        Classifier blocks operate on a single input frame ("predictions" by default).
        TODO: should that be a subject to change?

        :param features: non-empty collection of feature names used by subclasses
        :param predictions_column: name of the target/score column
        :param input_frame: key of the frame in ``context.data`` this block reads
        :param args: forwarded to SingleContextBlock
        :param kwargs: forwarded to SingleContextBlock
        """
        super(BaseClassifierBlock, self).__init__(
            input_data=[input_frame],
            destroyed_data=None,
            *args,
            **kwargs
        )
        assert features, 'Columns must be non-empty'
        self.input_frame = input_frame
        self.predictions_column = predictions_column
        self.features = features

    def apply(self, context, train):
        """Dispatch to ``fit`` when training, otherwise to ``predict``."""
        handler = self.fit if train else self.predict
        return handler(context)

    def fit(self, context):
        """Train on ``context``; must be provided by subclasses."""
        raise NotImplementedError

    def predict(self, context):
        """Score ``context``; must be provided by subclasses."""
        raise NotImplementedError


class SklearnClassifierBlock(BaseClassifierBlock):
    """
    fits classifier or predicts with classifier

    Subclasses implement ``create_classifier`` (build an unfitted estimator
    exposing ``fit`` / ``predict_proba``), ``store_classifier`` (persist a
    fitted one) and ``read_classifier`` (restore it for prediction).
    """

    def __init__(self, features, predictions_column='value', input_frame=ids.FRAME_KEY_PREDICTIONS,
                 classifier_kwargs=None, sample_weight=None, fit_params=None, *args, **kwargs):
        """
        :param features: feature names, resolved against the input frame with ``extract_feature_names``
        :param predictions_column: target column when fitting, output score column when predicting
        :param input_frame: key of the frame in ``context.data`` to read and update
        :param classifier_kwargs: keyword arguments used by ``create_classifier``
        :param sample_weight: optional name of the column holding per-row training weights
        :param fit_params: extra keyword arguments passed to the classifier's ``fit``
        """
        # The base class also stores ``input_frame`` on the instance.
        super(SklearnClassifierBlock, self).__init__(features, predictions_column, input_frame, *args, **kwargs)
        # Private copies: ``_prepare_fit_params`` writes the training data into
        # ``self.fit_params``; without a copy that would mutate the caller's
        # dict (and leak into any other block built from the same dict).
        self.classifier_kwargs = dict(classifier_kwargs or {})
        self.fit_params = dict(fit_params or {})
        self.sample_weight = sample_weight

    def fit(self, context):
        """
        Fit a fresh classifier on the input frame and persist it with
        ``store_classifier``. Returns ``context``.
        """
        self._prepare_fit_params(context)

        logger.debug('Training classifier with next params: %s', str(self.fit_params.keys()))

        classifier = self.create_classifier()
        classifier.fit(**self.fit_params)

        self.store_classifier(classifier, context)
        return context

    def _prepare_fit_params(self, context):
        """
        Put the training inputs into ``self.fit_params``: ``X`` (float32 2-D
        array of the extracted feature columns), ``y`` (target column) and
        ``sample_weight`` (weight column or None).
        """
        predictions = context.data[self.input_frame]
        extracted_features = extract_feature_names(predictions, self.features)

        target = predictions[self.predictions_column].to_array()
        features = predictions[extracted_features].copy().to_2d_array().astype(np.float32)
        if self.sample_weight is not None:
            sample_weight = predictions[self.sample_weight].copy()
        else:
            sample_weight = None

        self.fit_params.update(X=features, y=target, sample_weight=sample_weight)

    def predict(self, context):
        """
        Score the input frame and append the probability from column 1 of
        ``predict_proba`` as ``predictions_column`` (float32). An empty frame
        gets an empty score column without loading the classifier.
        """
        predictions = context.data[self.input_frame]
        extracted_features = extract_feature_names(predictions, self.features)

        if len(predictions) == 0:
            score = np.empty(0, dtype=np.float32)
        else:
            classifier = self.read_classifier(context)
            score = classifier.predict_proba(
                predictions[extracted_features].to_pandas()
            )[:, 1].astype(np.float32)

        context.data[self.input_frame] = predictions.append_column(score, self.predictions_column)
        return context

    def create_classifier(self):
        """
        build an unfitted classifier
        """
        raise NotImplementedError

    def store_classifier(self, classifier, context):
        """
        save trained classifier to the storage
        """
        raise NotImplementedError

    def read_classifier(self, context):
        """
        read fitted classifier from storage
        """
        raise NotImplementedError


class LogisticClassifierBlock(SklearnClassifierBlock):
    """
    Logistic-regression classifier block. Only the fitted ``coef_`` and
    ``intercept_`` arrays are stored; ``read_classifier`` rebuilds an
    estimator around them.
    """

    def create_classifier(self):
        """
        Build an unfitted LogisticRegression. It uses the 'sag' solver unless
        ``classifier_kwargs`` specifies another one.
        """
        # Merge instead of passing solver='sag' next to **classifier_kwargs:
        # that form raises TypeError ("got multiple values for keyword
        # argument 'solver'") as soon as the caller configures a solver.
        params = {'solver': 'sag'}
        params.update(self.classifier_kwargs)
        return LogisticRegression(**params)

    def read_classifier(self, context):
        """
        read fitted logistic classifier from storage
        """
        classifier = LogisticRegression(**self.classifier_kwargs)
        classifier.intercept_ = context.storage.get_object(self.key_for(context, 'intercept')).ravel()
        coef = context.storage.get_object(self.key_for(context, 'coef')).ravel()
        # Restore coefficients as a single row: shape (1, n_features).
        classifier.coef_ = coef.reshape((1, len(coef)))
        # NOTE(review): ``classes_`` is not restored; newer scikit-learn
        # releases read it inside ``predict_proba`` -- confirm against the
        # installed version.
        return classifier

    def store_classifier(self, classifier, context):
        """
        save trained classifier to the storage (coefficients and intercept only)
        """
        logger.info('Store logistic regression data: intercept=%s; coefficients=%s',
                    classifier.intercept_, classifier.coef_)
        context.storage.store(self.key_for(context, 'coef'), classifier.coef_)
        context.storage.store(self.key_for(context, 'intercept'), classifier.intercept_)


class DummyClassifierBlock(BaseClassifierBlock):
    """
    Pass-through "classifier": the value of a single feature column is copied
    into the predictions column as the score. Nothing is learned.
    """

    def __init__(self, feature, predictions_column='value', input_frame=ids.FRAME_KEY_PREDICTIONS):
        """
        :param feature: name of the column whose values become the score
        :param predictions_column: name of the output column
        :param input_frame: key of the frame in ``context.data`` to read and update
        """
        super(DummyClassifierBlock, self).__init__([feature], predictions_column, input_frame)
        self.feature = feature

    def fit(self, context):
        # Nothing to train: the feature itself is the prediction.
        return context

    def predict(self, context):
        """Append a copy of ``self.feature`` under ``predictions_column``."""
        source = context.data[self.input_frame]
        available = source.dtype.names
        assert self.feature in available, "Feature '{}' not found in predictions frame".format(self.feature)
        scored = source.append_column(source[self.feature], self.predictions_column)
        context.data[self.input_frame] = scored
        return context


def get_absolute_path(filename):
    """
    Join ``filename`` onto the ``DATASET_PATH`` directory from the current
    Flask app's config. The result is absolute only if DATASET_PATH is.
    """
    dataset_dir = app.config['DATASET_PATH']
    return os.path.join(dataset_dir, filename)


class CatBoostClassifierBlock(SklearnClassifierBlock):
    """
    Classifier block backed by CatBoost.

    With ``local=True`` the model is fitted in-process. Otherwise the training
    data, a CatBoost column-description file and the model storage key are
    written under ``DATASET_PATH`` for an external (Nirvana) fitting step.
    """

    def __init__(self, categorical_features=None, auxiliary_features=('user', 'item'), local=True, *args, **kwargs):
        """
        Classifier block with CatBoost
        :param categorical_features: list of categorical features names (None: no categorical features)
        :param auxiliary_features: columns exported alongside the features when dumping data for
            external fitting, marked 'Auxiliary' in the column description (None: no such columns)
        :param local: whether to fit inside block or in external Nirvana block
        :param args: forwarded to SklearnClassifierBlock
        :param kwargs: forwarded to SklearnClassifierBlock
        """
        super(CatBoostClassifierBlock, self).__init__(*args, **kwargs)
        self.categorical_features = categorical_features
        self.auxiliary_features = list(auxiliary_features or ())
        self.local = local

    def fit(self, context):
        """
        Fit in-process when ``local``; otherwise only dump the training data.
        Non-local fitting asserts that DEPLOYMENT_TYPE equals the NIRVANA config value.
        """
        logger.debug("Fitting CatBoost with next features: %s", ' '.join(self.features))
        if self.local:
            logger.debug('Fitting Catboost locally')
            return super(CatBoostClassifierBlock, self).fit(context)

        # Catboost is fitted out of this context
        # fitted model will be stored in storage
        assert app.config['DEPLOYMENT_TYPE'] == app.config['NIRVANA'], "Catboost classifier block can be fitted " \
                                                                       "only in Nirvana"
        logger.debug('Dumping data for Catboost')
        self._dump_data(context)

        return context

    def create_classifier(self):
        """Build an unfitted CatBoostClassifier seeded with 42 unless overridden."""
        # Merge instead of passing random_seed=42 next to **classifier_kwargs,
        # which raises TypeError when the caller also sets random_seed.
        params = {'random_seed': 42}
        params.update(self.classifier_kwargs)
        return CatBoostClassifier(**params)

    def store_classifier(self, classifier, context):
        """Save the fitted classifier object to ``context.storage``."""
        context.storage.store(self.key_for(context, 'catboost.bin'), classifier)

    def read_classifier(self, context):
        """
        read fitted catboost classifier from ``context.storage``, memoized
        through ``fast_cache.memoize`` by storage key
        """
        @fast_cache.memoize()
        def get_cached_classifier(storage_key):
            return context.storage.get_object(storage_key)

        return get_cached_classifier(self.key_for(context, 'catboost.bin'))

    def _prepare_fit_params(self, context):
        """
        Put a CatBoost ``Pool`` (features, target, categorical features and
        optional weights) into ``self.fit_params['X']``.
        """
        predictions = context.data[self.input_frame]
        extracted_features = extract_feature_names(predictions, self.features)

        target = predictions[self.predictions_column].to_array()
        features = predictions[extracted_features].copy()
        if self.sample_weight is not None:
            sample_weight = predictions[self.sample_weight].to_array()
        else:
            sample_weight = None

        # Same guard as in _get_catboost_data: no categorical features configured.
        if self.categorical_features:
            cat_features = extract_feature_names(features, self.categorical_features)
        else:
            cat_features = None

        # CatBoost rejects a separate ``sample_weight`` argument to fit() when
        # X is a Pool, so the weights travel inside the Pool instead.
        X = Pool(data=features.to_pandas(),
                 label=target,
                 cat_features=cat_features,
                 weight=sample_weight)

        self.fit_params.update(X=X, sample_weight=None)

    def _dump_data(self, context):
        """
        Write the training data, the column description and the model
        storage key into files under ``DATASET_PATH``.
        """
        column_description, data = self._get_catboost_data(context)
        data.to_tsv(get_absolute_path('catboost_data.tsv'))

        with open(get_absolute_path('column_description.tsv'), 'w') as cd:
            cd.write('\n'.join(['\t'.join(map(str, row)) for row in column_description]))

        with open(get_absolute_path('cat_key'), 'w') as ck:
            ck.write(self.key_for(context, 'catboost.bin'))  # key of catboost model to store it in storage later

    def _get_catboost_data(self, context):
        """
        Select the exported columns and build the CatBoost column description.

        :return: ``(column_description, data)``, where column_description is a
            list of ``(column index, role)`` pairs with roles 'Target',
            'Weight', 'Categ' and 'Auxiliary'
        """
        predictions = context.data[self.input_frame]
        columns = (extract_feature_names(predictions, self.features) + [self.predictions_column] +
                   self.auxiliary_features)
        if self.sample_weight is not None:
            columns += [self.sample_weight]

        logger.debug("CatBoost columns %s", ' '.join(columns))
        data = predictions[columns]

        column_description = [(data.columns.index(self.predictions_column), 'Target')]
        if self.sample_weight is not None:
            column_description.append((data.columns.index(self.sample_weight), 'Weight'))

        if self.categorical_features:
            column_description.extend([(data.columns.index(column), 'Categ')
                                       for column in extract_feature_names(data, self.categorical_features)])

        if self.auxiliary_features:
            column_description.extend([(data.columns.index(column), 'Auxiliary')
                                       for column in self.auxiliary_features])

        return column_description, data


class NNClassifierBlock(SklearnClassifierBlock):
    """
    Classifier block backed by ``ArrangerModel``.

    Model inputs are the 'user' column followed by the extracted feature
    columns. Only the model's ``state_dict`` is persisted in storage.
    """
    # Storage key suffix under which the model weights are stored.
    nn_weights_storage_path = 'arranger_nn_model.bin'
    # Name of the file ``_dump_data`` writes the weights storage key into.
    nn_model_key = 'nn_key'

    def _prepare_fit_params(self, context):
        """
        Put the training inputs into ``self.fit_params``. ``X`` is the frame of
        'user' plus the extracted feature columns; ``y`` is the target column.
        ``self.sample_weight`` is not read here.
        """
        predictions = context.data[self.input_frame]
        extracted_features = extract_feature_names(predictions, self.features)
        target = predictions[self.predictions_column].to_array()
        self.fit_params.update(X=predictions[['user'] + extracted_features], y=target)

    def predict(self, context):
        """
        Append the model's ``predict_proba`` output as ``predictions_column``.
        An empty frame gets an empty float32 column without loading the model.
        """
        predictions = context.data[self.input_frame]
        extracted_features = extract_feature_names(predictions, self.features)

        if len(predictions) == 0:
            score = np.empty(0, dtype=np.float32)
        else:
            classifier = self.read_classifier(context)
            score = classifier.predict_proba(
                predictions[['user'] + extracted_features]
            )
        context.data[self.input_frame] = predictions.append_column(score, self.predictions_column)
        return context

    def create_classifier(self):
        """Build an unfitted ArrangerModel sized from the prepared ``X``."""
        # X holds 'user' plus the features; the size argument counts feature
        # columns only, hence the -1.
        return ArrangerModel(len(self.fit_params['X'].columns) - 1, **self.classifier_kwargs)

    def store_classifier(self, classifier, context):
        """Save the model's ``state_dict`` to ``context.storage``."""
        model_weights = classifier.state_dict()
        context.storage.store(self.key_for(context, self.nn_weights_storage_path), model_weights)

    def read_classifier(self, context):
        """
        read fitted neural network weights from ``context.storage`` (memoized
        through ``fast_cache.memoize`` by storage key) and rebuild the model
        """
        @fast_cache.memoize()
        def get_cached_weights(storage_key):
            model_weights = context.storage.get_object(storage_key)
            embedding_size = self._get_embedding_size_from_weights(model_weights)
            return model_weights, embedding_size

        model_weights, embedding_size = get_cached_weights(self.key_for(context, self.nn_weights_storage_path))
        # NOTE(review): unlike create_classifier, classifier_kwargs are not
        # passed here; if they change the architecture, load_state_dict will
        # fail on shape mismatch -- confirm.
        classifier = ArrangerModel(embedding_size)
        classifier.load_state_dict(model_weights)
        return classifier

    def _dump_data(self, context):
        """
        Write 'user', the target and the feature columns to ``nn_data.tsv``
        under ``DATASET_PATH``, plus the weights storage key to ``nn_model_key``.
        """
        predictions = context.data[self.input_frame]
        columns = extract_feature_names(predictions, self.features) + [self.predictions_column, 'user']
        data = predictions[columns]
        data.to_tsv(get_absolute_path('nn_data.tsv'))

        # NOTE(review): this path is relative to the current working directory,
        # not resolved with get_absolute_path -- confirm that is intended.
        with open(self.nn_model_key, 'w') as f:
            # key for nn model weights to store it in storage later
            f.write(self.key_for(context, self.nn_weights_storage_path))

    def _get_embedding_size_from_weights(self, weights):
        """
        Return the input size of the model described by ``weights``.

        First weights matrix in 'weights' is of size N x input_size, so its
        second dimension is returned.
        """
        # Iterate instead of ``weights.values()[0]``: dict views are not
        # subscriptable on Python 3.
        for first_matrix in weights.values():
            return first_matrix.shape[1]
        raise ValueError('Cannot infer model input size from empty weights')
