from jafar.utils import get_index_pairs, fast, sort_by_users, check_frame
from jafar.estimators import Popular

from sklearn.metrics import roc_auc_score as _roc_auc_sklearn
import numpy as np
import pandas as pd
import logging

logger = logging.getLogger(__name__)


def avg_roc_auc_score(recommender, X, y=None):
    """
    ROC AUC computed separately for every user, then averaged.

    Instead of pooling all rows into one label vector, the rows of ``X`` are
    grouped by user and the per-user scores are combined by
    ``fast.roc_auc_score_avg``.
    Similar approach is used in:

     * https://github.com/sryza/aas/blob/master/ch03-recommender/src/main/scala/com/cloudera/datascience/recommender/RunRecommender.scala
     * https://github.com/bbcrd/theano-bpr/blob/master/theano_bpr/bpr.py

    :param recommender: fitted model exposing ``predict(X)`` that returns a
        frame with a ``value`` column.
    :param X: DataFrame with a ``value`` column holding the true labels; it
        must contain at least two distinct labels.
    :param y: ignored; kept so the signature matches the other scorers.
    :return: the averaged score produced by ``fast.roc_auc_score_avg``.
    """
    check_frame(X, ('value',))
    assert X['value'].nunique() >= 2, 'ROC AUC is undefined in case of one label. You should probably use `add_implicit_negatives`'
    scored = recommender.predict(X)

    # sort_by_users reorders X; apply the same row permutation to the
    # predictions so both stay aligned row by row (assumes predict() returns
    # rows in X's original order).
    by_user, user_bounds, permutation = sort_by_users(X, return_index=True)
    scored = scored.iloc[permutation]

    labels = np.float32(by_user['value'])
    scores = np.float32(scored['value'])
    bounds = np.int32(user_bounds)
    return fast.roc_auc_score_avg(labels, scores, bounds)


def overall_roc_auc_score(recommender, X, y=None):
    """
    Global (non-averaged) ROC AUC score.

    All rows of ``X`` are pooled together: the ``value`` column is treated as
    one binary label vector and the recommender's predicted ``value`` column
    as one score vector, regardless of which user each row belongs to.

    :param recommender: fitted model exposing ``predict(X)``; its result is
        expected to have a ``value`` column presumably row-aligned with ``X``.
    :param X: DataFrame with a ``value`` column holding the true labels; it
        must contain at least two distinct labels.
    :param y: ignored; kept so the signature matches the other scorers.
    :return: ROC AUC as computed by ``sklearn.metrics.roc_auc_score``.
    :raises AssertionError: if ``X['value']`` has fewer than two distinct values.
    """
    check_frame(X, ('value',))
    assert X['value'].nunique() >= 2, 'ROC AUC is undefined in case of one label. You should probably use `add_implicit_negatives`'
    prediction = recommender.predict(X)
    return _roc_auc_sklearn(X['value'], prediction['value'])

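# Parameterized scorers: callable objects configured at construction time
# (e.g. with the top-k cutoff) and then called like the scoring functions above.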


class UnexpectednessScore(object):
    """
    Measures how much a recommender's top-``k`` lists differ from a popularity
    baseline.

    For every user in the evaluation frame, the recommender's top-``k`` items
    are compared by ``fast.unexpectedness_score`` with the top-``k`` items of a
    ``Popular`` model fitted on the same frame.
    """

    @property
    def __name__(self):
        # A string attribute (like a plain function's ``__name__``), so scorer
        # objects and the module's scoring functions can be named uniformly.
        return 'unexpectedness'

    def __init__(self, k):
        """
        :param k: number of top recommendations to compare per user.
        """
        self.k = k

    def __call__(self, recommender, X, y=None):
        """
        Score ``recommender`` on the users present in ``X``.

        :param recommender: fitted model exposing ``predict_top_n(users, k)``
            and ``n_users`` / ``n_items`` attributes.
        :param X: DataFrame with at least a ``user`` column; it is also used to
            fit the ``Popular`` baseline.
        :param y: ignored; kept so the signature matches the other scorers.
        :return: 0.0 for a ``Popular`` recommender, otherwise the value
            computed by ``fast.unexpectedness_score``.
        """
        # A Popular model would just be compared against a Popular baseline.
        if isinstance(recommender, Popular):
            return 0.0
        users = pd.DataFrame(dict(user=np.array(list(set(X['user'])))))

        predicted = recommender.predict_top_n(users, self.k)
        expected = Popular(
            n_users=recommender.n_users,
            n_items=recommender.n_items
        ).fit(X).predict_top_n(users, self.k)

        expected_index_pairs = get_index_pairs(expected['user'])
        predicted_index_pairs = get_index_pairs(predicted['user'])

        return fast.unexpectedness_score(
            np.int32(expected['item']),
            np.int32(expected_index_pairs),
            np.int32(predicted['item']),
            np.int32(predicted_index_pairs),
        )


class MAPScore(object):
    """
    Mean Average Precision at ``k`` for top-n recommendations.

    Rows of the evaluation frame with ``value > threshold`` are treated as the
    relevant items. For every user that has such rows, the recommender's
    top-``k`` list is compared against them by ``fast.map_score``.
    """

    @property
    def __name__(self):
        # A string attribute (like a plain function's ``__name__``), so scorer
        # objects and the module's scoring functions can be named uniformly.
        return 'map'

    def __init__(self, k, threshold=0):
        """
        :param k: number of top recommendations to evaluate per user.
        :param threshold: rows of ``X`` whose ``value`` is strictly greater
            than this are kept as relevant; the rest are dropped.
        """
        self.k = k
        self.threshold = threshold

    def __call__(self, recommender, X, y=None):
        """
        Score ``recommender`` against the relevant interactions in ``X``.

        :param recommender: fitted model exposing ``predict_top_n(users, k)``.
        :param X: DataFrame with at least ``user``, ``item`` and ``value``
            columns.
        :param y: ignored; kept so the signature matches the other scorers.
        :return: the value computed by ``fast.map_score``.
        """
        X = X[X['value'] > self.threshold]
        X_sorted, x_index_pairs = sort_by_users(X)
        users = pd.DataFrame(dict(user=np.array(list(set(X_sorted['user'])))))

        predicted = recommender.predict_top_n(users, self.k)
        predicted_index_pairs = get_index_pairs(predicted['user'])

        return fast.map_score(
            np.int32(X_sorted['item']),
            np.int32(x_index_pairs),
            np.int32(predicted['item']),
            np.int32(predicted_index_pairs),
            self.k
        )


class NDCGScore(object):
    """
    Normalized Discounted Cumulative Gain at ``k`` for top-n recommendations.

    Rows of the evaluation frame with ``value > threshold`` are treated as the
    relevant items. For every user that has such rows, the recommender's
    top-``k`` list is compared against them by ``fast.ndcg_score``.
    """

    @property
    def __name__(self):
        # A string attribute (like a plain function's ``__name__``), so scorer
        # objects and the module's scoring functions can be named uniformly.
        return 'ndcg'

    def __init__(self, k, threshold=0):
        """
        :param k: number of top recommendations to evaluate per user.
        :param threshold: rows of ``X`` whose ``value`` is strictly greater
            than this are kept as relevant; the rest are dropped.
        """
        self.k = k
        self.threshold = threshold

    def __call__(self, recommender, X, y=None):
        """
        Score ``recommender`` against the relevant interactions in ``X``.

        :param recommender: fitted model exposing ``predict_top_n(users, k)``.
        :param X: DataFrame with at least ``user``, ``item`` and ``value``
            columns.
        :param y: ignored; kept so the signature matches the other scorers.
        :return: the value computed by ``fast.ndcg_score``.
        """
        X = X[X['value'] > self.threshold]
        X_sorted, x_index_pairs = sort_by_users(X)
        users = pd.DataFrame(dict(user=np.array(list(set(X_sorted['user'])))))

        predicted = recommender.predict_top_n(users, self.k)
        predicted_index_pairs = get_index_pairs(predicted['user'])

        return fast.ndcg_score(
            np.int32(X_sorted['item']),
            np.int32(x_index_pairs),
            np.int32(predicted['item']),
            np.int32(predicted_index_pairs),
            self.k
        )


class MRRScore(object):
    """
    Mean Reciprocal Rank at ``k`` for top-n recommendations.

    Rows of the evaluation frame with ``value > threshold`` are treated as the
    relevant items. For every user that has such rows, the recommender's
    top-``k`` list is compared against them by ``fast.mrr_score``.
    """

    @property
    def __name__(self):
        # A string attribute (like a plain function's ``__name__``), so scorer
        # objects and the module's scoring functions can be named uniformly.
        return 'mrr'

    def __init__(self, k, threshold=0):
        """
        :param k: number of top recommendations to evaluate per user.
        :param threshold: rows of ``X`` whose ``value`` is strictly greater
            than this are kept as relevant; the rest are dropped.
        """
        self.k = k
        self.threshold = threshold

    def __call__(self, recommender, X, y=None):
        """
        Score ``recommender`` against the relevant interactions in ``X``.

        :param recommender: fitted model exposing ``predict_top_n(users, k)``.
        :param X: DataFrame with at least ``user``, ``item`` and ``value``
            columns.
        :param y: ignored; kept so the signature matches the other scorers.
        :return: the value computed by ``fast.mrr_score``.
        """
        X = X[X['value'] > self.threshold]
        X_sorted, x_index_pairs = sort_by_users(X)
        users = pd.DataFrame(dict(user=np.array(list(set(X_sorted['user'])))))

        predicted = recommender.predict_top_n(users, self.k)
        predicted_index_pairs = get_index_pairs(predicted['user'])

        return fast.mrr_score(
            np.int32(X_sorted['item']),
            np.int32(x_index_pairs),
            np.int32(predicted['item']),
            np.int32(predicted_index_pairs),
            self.k
        )


# Ready-made scorer instances evaluating the top 10 recommendations per user;
# the ranking metrics count every row with a positive value as relevant.
unexpectedness_score = UnexpectednessScore(k=10)
map_score = MAPScore(threshold=0, k=10)
ndcg_score = NDCGScore(threshold=0, k=10)
mrr_score = MRRScore(threshold=0, k=10)
