import logging

import numpy as np
import re
from sklearn.metrics import accuracy_score as _accuracy_sklearn
from sklearn.metrics import log_loss as _log_loss_sklearn
from sklearn.metrics import roc_auc_score as _roc_auc_sklearn
from sklearn.preprocessing import LabelEncoder

from jafar.utils import get_index_pairs_for_strings, fast, get_index_pairs_with_keys

logger = logging.getLogger(__name__)

# NOTE: the string below follows the imports, so Python does not treat it as
# the module docstring (``__doc__``); at runtime it is a no-op expression and
# serves only as a comment for readers.
"""
This is a partial duplicate of jafar.scoring, adapted for pipelines
"""


def more_is_better(func):
    """
    Mark ``func`` as a score that should be maximised.

    Sets ``more_is_better = True`` and ``less_is_better = False`` as
    attributes on the given object (in place) and returns that same object,
    so it works as a decorator on both functions and classes.
    """
    for flag_name, flag_value in (('more_is_better', True), ('less_is_better', False)):
        setattr(func, flag_name, flag_value)
    return func


def less_is_better(func):
    """
    Mark ``func`` as a score that should be minimised (e.g. a loss).

    Sets ``more_is_better = False`` and ``less_is_better = True`` as
    attributes on the given object (in place) and returns that same object,
    so it works as a decorator on both functions and classes.
    """
    for flag_name, flag_value in (('more_is_better', False), ('less_is_better', True)):
        setattr(func, flag_name, flag_value)
    return func


@more_is_better
def avg_roc_auc_score(target, predictions):
    """
    ROC AUC score averaged per user.

    Rows are grouped by user and handed, together with the per-user row
    ranges, to ``fast.roc_auc_score_avg``. The same per-user averaging is
    used in:

     * https://github.com/sryza/aas/blob/master/ch03-recommender/src/main/scala/com/cloudera/datascience/recommender/RunRecommender.scala
     * https://github.com/bbcrd/theano-bpr/blob/master/theano_bpr/bpr.py

    :param target: true labels, indexable by field name (``user``, ``item``,
        ``value``).
    :param predictions: predicted scores aligned row-by-row with ``target``
        (identical ``user`` and ``item`` columns).
    :return: the value computed by ``fast.roc_auc_score_avg``.
    :raises AssertionError: if the user/item columns differ, or ``target``
        holds fewer than two distinct labels.
    """
    users = target['user']
    assert np.array_equal(users, predictions['user']), "User columns don't match each other"
    assert np.array_equal(target['item'], predictions['item']), "Item columns don't match each other"
    distinct_labels = np.unique(target['value'])
    assert len(distinct_labels) >= 2, 'ROC AUC is undefined in case of one label. You should probably use `add_implicit_negatives`'

    # sort by user so every user's rows form one contiguous slice
    order = np.argsort(users)
    user_ranges = get_index_pairs_for_strings(users[order])
    true_values = np.float32(target['value'][order])
    scored_values = np.float32(predictions['value'][order])
    return fast.roc_auc_score_avg(true_values, scored_values, np.int32(user_ranges))


@more_is_better
def overall_roc_auc_score(target, predictions):
    """
    ROC AUC over all rows at once, ignoring user boundaries.

    The whole ``value`` column of ``target`` is treated as one binary label
    vector and scored against the whole ``value`` column of ``predictions``
    with scikit-learn's ``roc_auc_score``.

    :param target: true labels, indexable by field name (``user``, ``item``,
        ``value``).
    :param predictions: predicted scores aligned row-by-row with ``target``.
    :return: the ROC AUC as computed by scikit-learn.
    :raises AssertionError: if the user/item columns differ, or ``target``
        holds fewer than two distinct labels.
    """
    for column, mismatch_message in (('user', "User columns don't match each other"),
                                     ('item', "Item columns don't match each other")):
        assert np.array_equal(target[column], predictions[column]), mismatch_message
    labels = target['value']
    assert len(np.unique(labels)) >= 2, \
        'ROC AUC is undefined in case of one label. You should probably use `add_implicit_negatives`'
    return _roc_auc_sklearn(np.float32(labels), np.float32(predictions['value']))


@less_is_better
def log_loss_score(target, predictions):
    """
    Logistic loss of ``predictions['value']`` against ``target['value']``.

    :param target: true labels, indexable by field name (``user``, ``item``,
        ``value``); must contain exactly two distinct labels.
    :param predictions: predicted probabilities aligned row-by-row with
        ``target``.
    :return: the log loss as computed by scikit-learn (lower is better).
    :raises AssertionError: if the user/item columns differ, or ``target``
        does not hold exactly two distinct labels.
    """
    for column, mismatch_message in (('user', "User columns don't match each other"),
                                     ('item', "Item columns don't match each other")):
        assert np.array_equal(target[column], predictions[column]), mismatch_message
    labels = target['value']
    assert len(np.unique(labels)) == 2, \
        'Logistic loss is undefined in case of one label. You should probably use `add_implicit_negatives`'
    return _log_loss_sklearn(np.float32(labels), np.float32(predictions['value']))


# top-n scoring functions

def _top_n_score(target, predictions, func, top_n):
    """
    Shared driver for the top-n ranking scorers.

    Keeps only target rows with ``value > 0`` and orders both frames by user
    and then by descending ``value``. Item ids are encoded as integers with
    a single ``LabelEncoder`` fitted on both frames. The encoded items and
    the per-user row ranges (from ``get_index_pairs_for_strings``) are then
    passed to ``func`` together with ``top_n``.

    :param target: ground-truth rows with ``user``/``item``/``value`` fields
        (presumably a numpy structured array; it must support boolean and
        integer-array indexing).
    :param predictions: predicted rows with the same fields.
    :param func: metric called as
        ``func(target_items, target_ranges, predicted_items, predicted_ranges, top_n)``.
    :param top_n: cutoff forwarded unchanged to ``func``.
    :return: whatever ``func`` returns.
    :raises AssertionError: if ``target`` has no positive rows.
    """
    def _by_user_then_value_desc(rows):
        # NOTE: cannot use ascending/descending directions in np.sort, hence np.lexsort
        return rows[np.lexsort((-rows['value'], rows['user']))]

    rows_before = len(target)
    # for top-n scores we only need positive items
    positives = target.copy()
    positives = positives[positives['value'] > 0]
    logger.debug(
        "Computing top-n score, leaving only positive items in target frame "
        "(before: %s rows, after: %s rows)", rows_before, len(positives)
    )
    assert len(positives) > 0, "Score is undefined in case of no positive items (with `value`=1)"

    positives = _by_user_then_value_desc(positives)
    predictions = _by_user_then_value_desc(predictions)

    encoder = LabelEncoder().fit(np.concatenate([positives['item'], predictions['item']]))
    target_items = np.int32(encoder.transform(positives['item']))
    target_ranges = np.int32(get_index_pairs_for_strings(positives['user']))
    predicted_items = np.int32(encoder.transform(predictions['item']))
    predicted_ranges = np.int32(get_index_pairs_for_strings(predictions['user']))
    return func(target_items, target_ranges, predicted_items, predicted_ranges, top_n)


class BaseScore(object):
    """
    Base class for callable top-n scorers.

    Subclasses set ``score`` to a ranking function. Calling an instance
    forwards ``(target, predictions)`` together with that function and
    ``top_n`` to ``_top_n_score``.
    """
    # ranking function supplied by subclasses; None on the base class
    score = None

    def __init__(self, top_n):
        # the ``k`` of a "metric@k" scorer: how many top results count
        self.top_n = top_n

    def __call__(self, target, predictions):
        metric, cutoff = self.score, self.top_n
        return _top_n_score(target, predictions, metric, cutoff)


@more_is_better
class MAPScore(BaseScore):
    """
    Top-n scorer backed by ``fast.map_score`` (judging by the name,
    presumably mean average precision over the top ``top_n`` results).
    """

    score = fast.map_score


@more_is_better
class NDCGScore(BaseScore):
    """
    Top-n scorer backed by ``fast.ndcg_score`` (judging by the name,
    presumably NDCG over the top ``top_n`` results).
    """

    score = fast.ndcg_score


@more_is_better
class NDCGScoreArranger(BaseScore):
    """
    NDCGScore that actually uses relevance values from target.

    The gain of each predicted item is the ``value`` stored for that
    user/item pair in ``target``. Predicted items the user has no target row
    for contribute zero gain. The result is the mean of per-user DCG/IDCG
    over all users present in ``target``; users with no predictions count
    as 0.
    """

    def __call__(self, target, predictions):
        """
        :param target: ground-truth rows with ``user``/``item``/``value``
            fields (presumably a numpy structured array; it must support
            integer-array indexing and slicing).
        :param predictions: predicted rows with the same fields.
        :return: mean NDCG@``self.top_n`` over the users in ``target``.
        :raises AssertionError: if ``target`` is empty, or a user's top
            ``top_n`` target values sum to zero (IDCG == 0).
        """
        logger.debug("Computing top-n score")
        assert len(target) > 0, "Target should be non-zero"
        # NOTE: cannot use ascending/descending directions in np.sort, hence np.lexsort
        target = target[np.lexsort((-target['value'], target['user']))]
        predictions = predictions[np.lexsort((-predictions['value'], predictions['user']))]
        # user -> (start, end) row range in the user-sorted predictions
        predict_pairs = dict(get_index_pairs_with_keys(predictions['user']))
        total = 0.0
        n_users = 0
        for key, (target_start, target_end) in get_index_pairs_with_keys(target['user']):
            # count every target user, so users without predictions score 0
            # (the old code divided by the last zero-based enumerate index,
            # which was off by one and crashed for a single user)
            n_users += 1
            if key not in predict_pairs:  # no prediction for this user
                continue
            user_target = target[target_start:target_end]
            gains = dict(zip(user_target['item'], user_target['value']))
            predictions_start, predictions_end = predict_pairs[key]
            ranked_items = predictions['item'][predictions_start:predictions_end][:self.top_n]
            # np.log (natural log) instead of log2: the base cancels in dcg / idcg.
            # Items absent from the user's target rows have zero relevance.
            dcg = np.sum([gains.get(item, 0) / np.log(rank + 2) for rank, item in enumerate(ranked_items)])
            # target rows are sorted by descending value, so this is the ideal ranking
            idcg = np.sum([value / np.log(rank + 2) for rank, value in enumerate(user_target['value'][:self.top_n])])
            assert idcg != 0, "User has all zero counters"
            total += dcg / idcg
        return total / n_users


@more_is_better
class NDCGScoreArrangerPositive(NDCGScoreArranger):
    """
    ``NDCGScoreArranger`` evaluated only on predictions whose ``value`` is
    strictly positive; all other prediction rows are dropped first.
    """

    def __call__(self, target, predictions):
        keep = predictions['value'] > 0
        parent_call = super(NDCGScoreArrangerPositive, self).__call__
        return parent_call(target, predictions[keep])


@more_is_better
class MRRScore(BaseScore):
    """
    Top-n scorer backed by ``fast.mrr_score`` (judging by the name,
    presumably mean reciprocal rank over the top ``top_n`` results).
    """

    score = fast.mrr_score


@more_is_better
class PrecisionScore(BaseScore):
    """
    Top-n scorer backed by ``fast.precision_score`` (judging by the name,
    presumably precision@``top_n``).
    """

    score = fast.precision_score


@more_is_better
class RecallScore(BaseScore):
    """
    Top-n scorer backed by ``fast.recall_score`` (judging by the name,
    presumably recall@``top_n``).
    """

    score = fast.recall_score


@more_is_better
def accuracy_score(target, predictions):
    """
    Best accuracy over a grid of decision thresholds.

    Thresholds are ``np.linspace`` (its default of 50 points) between the
    minimum and maximum of ``predictions['value']``. A row is predicted
    positive (1) when its score is strictly greater than the threshold.
    The threshold that wins is logged at DEBUG level.

    :param target: true labels under the ``value`` field.
    :param predictions: predicted scores under the ``value`` field.
    :return: the highest accuracy found on the grid.
    """
    predicted_scores = predictions['value']
    thresholds = np.linspace(predicted_scores.min(), predicted_scores.max())
    accuracies = [
        _accuracy_sklearn(target['value'], (predicted_scores > cut).astype(np.int32))
        for cut in thresholds
    ]
    best = np.argmax(accuracies)
    logger.debug("Best accuracy is reached at threshold %s", thresholds[best])
    return accuracies[best]


def get_scorer(name):
    """
    Look up a scorer by name.

    Top-n scorers (``MAP``, ``NDCG``, ``NDCGA``, ``NDCGA_pos``, ``MRR``,
    ``precision`` and ``recall``) must be requested as ``"<scorer>@k"``,
    e.g. ``"MAP@10"``, where ``k`` denotes the top-k results taken into
    account. For these, an instance constructed with ``k`` is returned.

    The remaining scorers (``avg_roc_auc_score``, ``overall_roc_auc_score``,
    ``log_loss``, ``accuracy_score``) are plain functions and are returned
    unchanged.

    :param name: scorer name, optionally with an ``@k`` postfix.
    :return: a callable ``scorer(target, predictions)``.
    :raises ValueError: if ``name`` is unknown, is a top-n scorer without a
        valid ``@k`` postfix, or has trailing/leading garbage.
    """
    top_n_scorers = {
        'MAP': MAPScore,
        'NDCG': NDCGScore,
        'NDCGA': NDCGScoreArranger,
        'NDCGA_pos': NDCGScoreArrangerPositive,
        'MRR': MRRScore,
        'precision': PrecisionScore,
        'recall': RecallScore,
    }
    plain_scorers = {
        'avg_roc_auc_score': avg_roc_auc_score,
        'overall_roc_auc_score': overall_roc_auc_score,
        'log_loss': log_loss_score,
        'accuracy_score': accuracy_score,
    }
    # Anchored at both ends. An unanchored search would accept malformed
    # names such as 'MAP@10x' or 'xMAP@10' and would silently turn
    # 'MAP@1.5' into MAP@1.
    match = re.match(r'(MAP|NDCG|NDCGA|NDCGA_pos|MRR|precision|recall)@(\d+)\Z', name)
    if match:
        scorer_name, k = match.groups()
        return top_n_scorers[scorer_name](int(k))
    # Bare top-n names ('MAP') are rejected here. The class alone is not a
    # scorer: it has to be constructed with k.
    if name in plain_scorers:
        return plain_scorers[name]
    raise ValueError(
        "Unknown scorer '{}'. (Note that scorers like MAP/NDCG/MRR need "
        "'@k' postfix, like 'MAP@10')".format(name)
    )
