def get_yuid(rec):
    """Return the value stored under the ``"yuid"`` key of ``rec``."""
    yuid = rec["yuid"]
    return yuid


def get_all_crypta_slices_of_assessors(assessor_yuids, crypta_graph, yuid2cid):
    """Collect the crypta slices that contain at least one assessor yuid.

    Args:
        assessor_yuids: iterable of yuids belonging to one assessor login.
        crypta_graph: mapping ``crypta_id -> iterable of rows``; every row
            must expose its yuid under the ``"yuid"`` key.
        yuid2cid: mapping ``yuid -> crypta_id`` used as a reverse lookup.

    Returns:
        dict mapping each crypta id reachable from ``assessor_yuids`` to the
        list of yuids in that crypta slice. Yuids that are absent from
        ``yuid2cid`` are ignored.
    """
    result = {}
    for yuid in assessor_yuids:
        if yuid not in yuid2cid:
            continue
        crypta_id = yuid2cid[yuid]
        # Build a real list rather than ``map(...)``: on Python 3 ``map`` is a
        # one-shot iterator, but callers take ``len()`` of a slice and walk it
        # more than once.
        result[crypta_id] = [row["yuid"] for row in crypta_graph[crypta_id]]
    return result


def get_prec_rec_for_single_rlogin(assessors_crypta_slices, assessor_yuids_set):
    """Count precision / recall numerators and denominators for one rlogin.

    A group of ``k`` yuids is counted as ``k - 1`` device linkages.

    Args:
        assessors_crypta_slices: mapping ``crypta_id -> collection of yuids``
            for every crypta id related to the rlogin.
        assessor_yuids_set: set of yuids known for the rlogin (used as the
            golden source).

    Returns:
        Tuple ``(prec_a, prec_b, rec_a, rec_b, cover_a)`` where
        ``prec_a``/``rec_a`` are the correctly predicted linkages,
        ``prec_b`` is the number of predicted linkages, ``rec_b`` is the
        number of real linkages and ``cover_a`` is 1.0 when at least one
        crypta slice is present, otherwise 0.0.
    """
    matched_links = 0.0
    predicted_links = 0.0
    for slice_yuids in assessors_crypta_slices.values():
        # Linkages inside this slice that the rlogin confirms; never negative.
        confirmed = assessor_yuids_set.intersection(slice_yuids)
        matched_links += max(len(confirmed) - 1, 0)
        # Every linkage the slice claims, whether confirmed or not.
        predicted_links += len(slice_yuids) - 1

    # Real linkages according to the rlogin devices.
    real_links = max(len(assessor_yuids_set) - 1, 0)

    coverage = 1.0 if assessors_crypta_slices else 0.0

    # Precision and recall share the same numerator.
    return matched_links, predicted_links, matched_links, real_links, coverage


class LoginMetrics:
    """Per-login container for precision / recall counters and yuid sets.

    The ``*_a`` values are numerators and the ``*_b`` values denominators of
    the corresponding ratio; this class only stores what it is given.
    """

    def __init__(self, login, prec_a, prec_b, rec_a, rec_b, cover):
        self.login = login
        # prec_a: how many of our yuid predictions for the login are correct.
        # prec_b: how many yuid predictions we made for the login.
        self.prec_a, self.prec_b = prec_a, prec_b
        # rec_a: how many of the real login linkages we found.
        # rec_b: how many real linkages the login has.
        self.rec_a, self.rec_b = rec_a, rec_b
        self.cover = cover
        # Start out empty; expected to be filled in by whoever builds the
        # metrics object.
        self.crypta_ids = set()
        self.login_yuids = set()
        self.crypta_yuids = set()


def get_metrics_per_login(rlogin_graph, crypta_graph):
    """Compute a ``LoginMetrics`` object for every login worth scoring.

    Args:
        rlogin_graph: mapping ``login -> sized collection of rows``; each row
            exposes its yuid under the ``"yuid"`` key.
        crypta_graph: mapping ``crypta_id -> iterable of rows`` with the same
            row shape.

    Returns:
        dict mapping login to its ``LoginMetrics``. A login is skipped when it
        has fewer than two rows and its crypta slices contain fewer than two
        distinct yuids.
    """
    # Reverse the cid -> yuid mapping for quick lookup. If a yuid is listed
    # under several crypta ids, the one iterated last wins.
    # Iterating keys (instead of the Python 2-only ``iteritems``) works on
    # both Python 2 and 3 without copying the mapping.
    yuid2cid = {}
    for crypta_id in crypta_graph:
        for yuid_row in crypta_graph[crypta_id]:
            yuid2cid[yuid_row["yuid"]] = crypta_id

    login_to_metrics = {}

    # Count metrics per login to calculate average metrics later.
    for login in rlogin_graph:
        assessor_rows = rlogin_graph[login]
        assessor_yuids = set(get_yuid(row) for row in assessor_rows)
        assessors_crypta_slices = get_all_crypta_slices_of_assessors(assessor_yuids, crypta_graph, yuid2cid)

        all_assessor_crypta_yuids = set()
        for slice_yuids in assessors_crypta_slices.values():
            all_assessor_crypta_yuids.update(slice_yuids)

        # We are not interested in splices of a single device
        # (1 rlogin yuid <-> 1 crypta yuid).
        if len(assessor_rows) < 2 and len(all_assessor_crypta_yuids) < 2:
            continue

        precision_a, precision_b, recall_a, recall_b, cover_a = get_prec_rec_for_single_rlogin(
            assessors_crypta_slices, assessor_yuids
        )

        login_metrics = LoginMetrics(login, precision_a, precision_b, recall_a, recall_b, cover_a)
        login_metrics.login_yuids = assessor_yuids
        # Snapshot the keys as a list: same type as Python 2's ``keys()`` and
        # not a live view of the dict on Python 3.
        login_metrics.crypta_ids = list(assessors_crypta_slices)
        login_metrics.crypta_yuids = all_assessor_crypta_yuids
        login_to_metrics[login] = login_metrics

    return login_to_metrics


def median(lst):
    """Return the median of ``lst``, or ``None`` when it is empty.

    The input is not modified. For an odd number of items the middle element
    itself is returned; for an even number the result is the float average of
    the two middle elements.
    """
    ordered = sorted(lst)
    count = len(ordered)
    if count < 1:
        return None
    # Floor division keeps the index an int on Python 3 as well; plain ``/``
    # yields a float there and list indexing would raise TypeError.
    middle = count // 2
    if count % 2 == 1:
        return ordered[middle]
    return float(ordered[middle - 1] + ordered[middle]) / 2.0

def mean(lst):
    """Return the arithmetic mean of ``lst`` as a float, or ``None`` if it is empty."""
    if not lst:
        return None
    total = sum(lst)
    return total / float(len(lst))


def _summarise_ratio_pairs(pairs):
    """Summarise ``(numerator, denominator)`` pairs.

    Returns ``(cumulative, mean_pes, median_pes, mean_opt, median_opt)``:
    the ratio of the summed numerators to the summed denominators, then the
    mean / median of the per-pair ratios counted pessimistically (a zero
    denominator counts as ratio 0) and optimistically (pairs with a zero
    denominator are left out).
    """
    # Materialise once so the data can be traversed several times even when a
    # one-shot iterator is passed in.
    pairs = list(pairs)

    numerator_total = sum(pair[0] for pair in pairs)
    denominator_total = sum(pair[1] for pair in pairs)
    # float() avoids integer truncation on Python 2 when counters are ints.
    cumulative = float(numerator_total) / denominator_total if denominator_total != 0 else 0

    pessimistic = [(pair[0] / float(pair[1]) if pair[1] != 0 else 0) for pair in pairs]
    optimistic = [pair[0] / float(pair[1]) for pair in pairs if pair[1] != 0]

    return cumulative, mean(pessimistic), median(pessimistic), mean(optimistic), median(optimistic)


def average_metrics(precisions, recalls, yuids_threshold=None):
    """Aggregate per-login precision and recall counters.

    Args:
        precisions: iterable of ``(correctly_detected, detected)`` pairs.
        recalls: iterable of ``(correctly_detected, correct)`` pairs.
        yuids_threshold: optional number; when truthy every result key is
            prefixed with ``"max<threshold>."``.

    Returns:
        Tuple ``(precision, recall, f1_score)``. ``precision`` and ``recall``
        are dicts with the keys ``cumulative``, ``mean.pes``, ``median.pes``,
        ``mean.opt`` and ``median.opt`` (prefixed as described above); mean
        and median values are ``None`` when there is nothing to average.
        ``f1_score`` is computed from the optimistic means and is 0.0 when
        either of them is missing or zero.
    """
    # Precision.
    # Pessimistic: every missing answer counts as 0.
    # Optimistic: a missing answer stays missing. If we did not predict a
    # splice at all (no detections), we were not incorrect; that shows up as
    # a recall problem instead. Wrong predictions (detections with nothing
    # correct) still count as precision 0, so the filter is on the number of
    # detections, not on the number of correct ones.
    (
        precision_cumulative,
        precision_mean_pes,
        precision_median_pes,
        precision_mean_opt,
        precision_median_opt,
    ) = _summarise_ratio_pairs(precisions)

    # Recall.
    # Optimistic: if there was no real splice but we detected one, it is a
    # precision problem, not a recall one, so such logins are left out.
    (
        recall_cumulative,
        recall_mean_pes,
        recall_median_pes,
        recall_mean_opt,
        recall_median_opt,
    ) = _summarise_ratio_pairs(recalls)

    t = "max%d." % yuids_threshold if yuids_threshold else ""

    precision = {
        t + "cumulative": precision_cumulative,
        t + "mean.pes": precision_mean_pes,
        t + "median.pes": precision_median_pes,
        t + "mean.opt": precision_mean_opt,
        t + "median.opt": precision_median_opt,
    }

    recall = {
        t + "cumulative": recall_cumulative,
        t + "mean.pes": recall_mean_pes,
        t + "median.pes": recall_median_pes,
        t + "mean.opt": recall_mean_opt,
        t + "median.opt": recall_median_opt,
    }

    # The truthiness checks also cover the ``None`` means of empty inputs.
    if precision_mean_opt and recall_mean_opt and (precision_mean_opt + recall_mean_opt) != 0:
        f1_score = 2 * precision_mean_opt * recall_mean_opt / (precision_mean_opt + recall_mean_opt)
    else:
        f1_score = 0.0

    return precision, recall, f1_score


def similarity(rlogin_graph, crypta_graph):
    """Return the Jaccard similarity of the yuid sets of two graphs.

    Args:
        rlogin_graph: mapping whose values are iterables of rows, each row
            exposing its yuid under the ``"yuid"`` key.
        crypta_graph: mapping of the same shape.

    Returns:
        ``len(intersection) / len(union)`` of all yuids found in the two
        graphs as a float, or 0.0 when neither graph contains any yuid.
    """
    rlogin_yuids = set()
    for rows in rlogin_graph.values():
        rlogin_yuids.update(row["yuid"] for row in rows)

    crypta_yuids = set()
    for rows in crypta_graph.values():
        crypta_yuids.update(row["yuid"] for row in rows)

    union = rlogin_yuids | crypta_yuids
    if not union:
        # Nothing to compare; avoid dividing by zero.
        return 0.0
    intersect = rlogin_yuids & crypta_yuids
    return len(intersect) / float(len(union))
