import sys
from itertools import combinations
from collections import defaultdict
from rules import corsa_xr_metric, corsa_dt_metric

def calc_signals_match(images, signal1='corsax', signal2='dbd_score'):
    """Score how consistently two per-image signals order pairs of images.

    Every unordered pair of images inside one query group of *images* is
    compared on both signals and scored:

    * 1.0 -- both signals tie, or both order the pair the same way;
    * 0.5 -- exactly one of the two signals ties;
    * 0.0 -- anything else (e.g. the signals disagree on the order).

    A pair is ignored unless both images carry both signals.

    Args:
        images: mapping of query key -> iterable of image mappings.
        signal1: key of the first signal.
        signal2: key of the second signal.

    Returns:
        Tuple ``(wins, total, qid_counter)``: the summed pair scores (float),
        the number of pairs scored, and a ``defaultdict(int)`` counting the
        scored pairs per ``'qid'`` value of each pair's first image.

    Raises:
        KeyError: if a scored pair's first image has no ``'qid'`` key.
    """
    required = (signal1, signal2)
    score = 0.0
    scored_pairs = 0
    pairs_per_qid = defaultdict(int)

    for query in images:
        for left, right in combinations(images[query], 2):
            if not all(key in left and key in right for key in required):
                continue

            pairs_per_qid[left['qid']] += 1
            scored_pairs += 1

            s1_left, s1_right = left[signal1], right[signal1]
            s2_left, s2_right = left[signal2], right[signal2]
            tie1 = s1_left == s1_right
            tie2 = s2_left == s2_right

            if tie1 and tie2:
                score += 1.0
            elif tie1 or tie2:
                score += 0.5
            elif (s1_left > s1_right and s2_left > s2_right) or \
                    (s1_left < s1_right and s2_left < s2_right):
                # Both signals agree on which image of the pair ranks higher.
                score += 1.0

    return score, scored_pairs, pairs_per_qid


def calc_r2(images, signal_origin, signal_predicted):
    """Return the coefficient of determination (R^2) of one signal against another.

    Every image, across all query groups of *images*, that carries both
    *signal_origin* and *signal_predicted* contributes one ``(y, y_hat)``
    pair. Before scoring, the predictions are shifted by a constant so that
    their mean equals the mean of the observed values. R^2 is then
    ``1 - SS_res / SS_tot``.

    Side effect: writes the debug line ``EX <mean_y> <mean_predicted>`` to
    stdout. The two means are taken before the shift.

    Args:
        images: mapping of query key -> iterable of image mappings.
        signal_origin: key of the observed/reference signal (``y``).
        signal_predicted: key of the predicted signal (``y_hat``).

    Returns:
        R^2 as a float.

    Raises:
        ZeroDivisionError: if no image has both signals, or if all observed
            values are identical (``SS_tot == 0``).
    """
    y = []
    y_predicted = []
    for query in images:
        for img in images[query]:
            if signal_origin in img and signal_predicted in img:
                y.append(img[signal_origin])
                y_predicted.append(img[signal_predicted])

    # float() prevents Python 2 integer (floor) division when the signals
    # are ints. Without it the means, the shift and the final ratio are
    # all truncated.
    e_y = float(sum(y)) / len(y)
    e_y_predicted = float(sum(y_predicted)) / len(y_predicted)

    # Produces the same text as the former Python 2 statement
    # ``print 'EX', e_y, e_y_predicted``, but is valid on Python 2 and 3.
    sys.stdout.write('EX %s %s\n' % (e_y, e_y_predicted))

    # Shift the predictions so both series have equal expected values.
    # R^2 then measures only the spread, not a constant offset.
    shift = e_y - e_y_predicted
    y_predicted = [yp_i + shift for yp_i in y_predicted]

    ss_total = sum((y_i - e_y) ** 2 for y_i in y)
    ss_reg = sum((y_i - f_i) ** 2 for y_i, f_i in zip(y, y_predicted))

    return 1.0 - ss_reg / ss_total


def calc_corsa_xl_nodups(data):
    metric_sum = 0.0
    cnt = 0

    for img in data:
        relevance = img.get('relevance', 'IRRELEVANT')
        wideness = img.get('wideness', 'wideness_3')

        boost_weight = 0.3
        if relevance == 'RELEVANT_PLUS' and wideness == 'wideness_1':
            boost_weight = 0.8
        elif relevance == 'RELEVANT_PLUS' and wideness == 'wideness_2':
            boost_weight = 0.5

        metric_value = corsa_xr_metric(img, boost_weight)
        img['corsa-xl'] = metric_value

        metric_sum += metric_value
        cnt += 1

    if cnt == 0:
        return 0.0
    return metric_sum / cnt


def calc_corsa_dt_nodups(data):
    metric_sum = 0.0
    cnt = 0

    for img in data:
        relevance = img.get('relevance', 'IRRELEVANT')
        wideness = img.get('wideness', 'wideness_3')

        boost_weight = 0.3
        if relevance == 'RELEVANT_PLUS' and wideness == 'wideness_1':
            boost_weight = 0.8
        elif relevance == 'RELEVANT_PLUS' and wideness == 'wideness_2':
            boost_weight = 0.5

        metric_value = corsa_dt_metric(img, boost_weight)

        metric_sum += metric_value
        cnt += 1

    if cnt == 0:
        return 0.0
    return metric_sum / cnt
