import cv2
import numpy as np


def _nn_match_two_way_on_score(dist_mat, nn_thresh):
    """Mutual nearest-neighbour matching on a precomputed score matrix.

    Row ``i`` is paired with column ``j`` when ``dist_mat[i, j]`` is the
    smallest entry of row ``i``, row ``i`` is also the arg-min of column ``j``,
    and that entry is strictly below ``nn_thresh``.

    Returns:
      pairs - Kx2 integer array; each row is ``[row_index, col_index]``.
      scores - length-K array holding the matched ``dist_mat`` entries.
    """
    all_rows = np.arange(dist_mat.shape[0])
    best_col = np.argmin(dist_mat, axis=1)
    best_val = dist_mat[all_rows, best_col]
    best_row = np.argmin(dist_mat, axis=0)
    # A pair survives only if it beats the threshold AND the column's best
    # row points back at the same row (the match is mutual).
    survivors = (best_val < nn_thresh) & (best_row[best_col] == all_rows)
    pairs = np.column_stack([all_rows[survivors], best_col[survivors]])
    return pairs, best_val[survivors]


def _nn_match_two_way(desc1, desc2, nn_thresh):
    """
    Performs two-way (mutual) nearest neighbor matching of two sets of
    descriptors, such that the NN match from descriptor A->B must equal the NN
    match from B->A.

    Inputs:
      desc1 - N1xM numpy matrix of N1 M-dimensional descriptors.
      desc2 - N2xM numpy matrix of N2 M-dimensional descriptors.
      nn_thresh - non-negative L2 descriptor distance; only matches whose
                  distance is strictly below it are kept.

    Returns:
      pairs - Lx2 integer numpy array (L <= min(N1, N2)); row k is
              [index into desc1, index into desc2] of the k-th match.
      scores - length-L numpy array with the L2 distance of each match.

    Raises:
      ValueError - if nn_thresh is negative.
    """
    assert desc1.shape[1] == desc2.shape[1]
    if nn_thresh < 0.0:
        raise ValueError('\'nn_thresh\' should be non-negative')
    if desc1.shape[0] == 0 or desc2.shape[0] == 0:
        # np.empty() needs an explicit shape. Return empty results laid out
        # like the non-empty path so callers can still do pairs[:, 0].
        return np.empty((0, 2), dtype=np.intp), np.empty((0,), dtype=float)
    # Compute L2 distance. Assumes descriptors are unit normalized, so
    # ||a - b||^2 = 2 - 2 * a.b; clipping guards against rounding outside [-1, 1].
    dmat = np.dot(desc1, desc2.T)
    dmat = np.sqrt(2 - 2 * np.clip(dmat, -1, 1))
    return _nn_match_two_way_on_score(dmat, nn_thresh)


def _calculate_fundamental_error(pt1, pt2, F):
    """Return the algebraic epipolar residual ``|x2^T . F . x1|``.

    ``pt1`` and ``pt2`` are (x, y) coordinates lifted to homogeneous vectors
    ``[x, y, 1]``; ``F`` is a 3x3 matrix. A residual of zero means the pair
    satisfies the epipolar constraint exactly.
    """
    x2_row = np.array([pt2[0], pt2[1], 1])
    x2_f = x2_row @ F  # row vector x2^T . F, shape (3,)
    x1_col = np.array([[pt1[0]], [pt1[1]], [1]])
    residual = x2_f @ x1_col  # shape (1,)
    return abs(residual[0])


def confidence_func(fund_good_cnt, kpts0_cnt, kpts1_cnt, fund_center_err):
    """Score a match from inlier ratios and an epipolar residual.

    The score grows with the squared ratio of ``fund_good_cnt`` to each
    keypoint count. It is inversely proportional to ``fund_center_err``.
    A tiny epsilon keeps a zero residual from dividing by zero.
    """
    eps = 1e-10
    inlier_ratio0 = fund_good_cnt / kpts0_cnt
    inlier_ratio1 = fund_good_cnt / kpts1_cnt
    inverse_err = 1.0 / (fund_center_err + eps)
    # Multiplication order kept identical so results match bit-for-bit.
    return inlier_ratio0 * inlier_ratio0 * inlier_ratio1 * inlier_ratio1 * inverse_err


def nn_mutual(pair_dist_thr, all_valid_pairs, first_context, second_context):
    """Match detected objects between two views using epipolar geometry.

    Steps:
      1. Match keypoint descriptors of the two contexts with mutual nearest
         neighbours (distance threshold ``pair_dist_thr``).
      2. Estimate a fundamental matrix from those matches with OpenCV RANSAC.
      3. For every object type present in both views, score object pairs by
         the algebraic epipolar residual of their bounding-box centres.

    Args:
      pair_dist_thr: descriptor L2 distance threshold for keypoint matching.
      all_valid_pairs: if True, emit every same-type object pair. Otherwise
        keep only pairs whose residual is the minimum of both its row and its
        column (mutual best).
      first_context, second_context: objects exposing ``feature.size()``
        (with ``height``/``width``), ``points`` (Nx2), ``confidence``,
        ``descriptors`` (Nx256) and ``detections``. Each detection has
        ``type``, ``id`` and ``box.min_x/max_x/min_y/max_y``.

    Returns:
      A list of ``(obj0.id, obj1.id, confidence)`` tuples. The list is empty
      when no object type is shared, when too few keypoint matches exist, or
      when the fundamental matrix cannot be estimated.
    """
    def _extract_context(context):
        # Validates the context layout and returns its keypoint data.
        tmp = context.feature.size()
        image_shape = (tmp.height, tmp.width)
        assert (context.points.shape[1] == 2) and (context.points.shape[0] == context.confidence.shape[0])
        assert (context.descriptors.shape[1] == 256) and (context.points.shape[0] == context.descriptors.shape[0])

        return context.points, image_shape, context.confidence, context.descriptors

    def _extract_objects(objects0, objects1):
        # Groups objects by type, keeping only types present in both views.
        common_types = {x.type for x in objects0}.intersection({x.type for x in objects1})
        objects_by_type = {t: [[], []] for t in common_types}
        for side, objs in enumerate((objects0, objects1)):
            for obj in objs:
                if obj.type in common_types:
                    objects_by_type[obj.type][side].append(obj)
        return common_types, objects_by_type

    def _box_center(obj):
        # Returns the centre of the object's axis-aligned bounding box.
        box = obj.box
        return ((box.min_x + box.max_x) / 2, (box.min_y + box.max_y) / 2)

    # OpenCV's RANSAC fundamental-matrix estimator needs at least 8
    # correspondences, so bail out at this many or fewer.
    FIND_FUND_MAT_PTS_CNT_MIN = 7

    common_types, objects_by_type = _extract_objects(first_context.detections, second_context.detections)

    if not common_types:
        return []

    kpts0, _, _, descs0 = _extract_context(first_context)
    kpts1, _, _, descs1 = _extract_context(second_context)

    kpt_pairs, _ = _nn_match_two_way(descs0, descs1, pair_dist_thr)

    if kpt_pairs.shape[0] <= FIND_FUND_MAT_PTS_CNT_MIN:
        return []

    valid_pts0 = kpts0[kpt_pairs[:, 0]]
    valid_pts1 = kpts1[kpt_pairs[:, 1]]

    assert valid_pts0.shape[0] == valid_pts1.shape[0], "Something really going wrong"

    F, inlier_mask = cv2.findFundamentalMat(valid_pts0, valid_pts1, cv2.FM_RANSAC, 3, 0.99)
    # Estimation can fail, leaving no usable 3x3 matrix or mask. Dereferencing
    # the mask in that case would raise, so report "no matches" instead.
    if F is None or inlier_mask is None or F.shape != (3, 3):
        return []
    good_cnt = int(np.count_nonzero(inlier_mask))

    n_kpts0 = kpts0.shape[0]
    n_kpts1 = kpts1.shape[0]
    matches = []
    for t in common_types:
        objs0, objs1 = objects_by_type[t]
        centers1 = [_box_center(obj1) for obj1 in objs1]

        # errs[i, j] is the epipolar residual between objs0[i] and objs1[j].
        errs = np.empty((len(objs0), len(objs1)))
        for i, obj0 in enumerate(objs0):
            center0 = _box_center(obj0)
            for j, center1 in enumerate(centers1):
                errs[i, j] = _calculate_fundamental_error(center0, center1, F)

        if all_valid_pairs:
            candidate_pairs = ((i, j) for i in range(len(objs0)) for j in range(len(objs1)))
        else:
            # Keep only pairs that are mutually the lowest-residual choice.
            candidate_pairs, _ = _nn_match_two_way_on_score(errs, float("inf"))

        for i, j in candidate_pairs:
            confidence = confidence_func(good_cnt, n_kpts0, n_kpts1, errs[i, j])
            matches.append((objs0[i].id, objs1[j].id, confidence))

    return matches
