import collections
import cv2
import math
import numpy as np
from functools import lru_cache

import tensorflow as tf
from tensorflow.core.framework import graph_pb2

from library.python import resource


class SuperGlueTF(object):
    """Graph-mode TensorFlow wrapper around a serialized SuperGlue matcher graph.

    Construction creates the input placeholders, imports the serialized
    ``GraphDef`` into the default graph with those placeholders mapped onto
    the graph's ``eval_*`` inputs, and opens a session.
    """

    def __init__(self, gdef):
        """Build the inference graph from ``gdef`` and start a session.

        Args:
            gdef: serialized ``GraphDef`` accepted by ``GraphDef.ParseFromString``.
        """
        self._make_placeholders()
        self._build_inference_graph(gdef)
        self._init_session()

    def _make_placeholders(self):
        """Create the placeholders that are later mapped onto the graph inputs."""
        descriptor_dim = 256

        # Static shapes shared by the inputs of both images.
        points_shape = [None, None, 2]
        scores_shape = [None, None]
        descs_shape = [None, None, descriptor_dim]
        img_size_shape = [None, 2]

        # Scalars: match confidence threshold and the training-mode switch.
        self.eval_thr_phs = tf.placeholder(tf.float32, shape=[])
        self.is_training = tf.placeholder_with_default(False, shape=[])

        # Inputs describing the first image.
        self.points0_ph = tf.placeholder(tf.float32, shape=points_shape)
        self.scores0_ph = tf.placeholder(tf.float32, shape=scores_shape)
        self.descs0_ph = tf.placeholder(tf.float32, shape=descs_shape)
        self.imgsz0_ph = tf.placeholder(tf.float32, shape=img_size_shape)

        # Inputs describing the second image.
        self.points1_ph = tf.placeholder(tf.float32, shape=points_shape)
        self.scores1_ph = tf.placeholder(tf.float32, shape=scores_shape)
        self.descs1_ph = tf.placeholder(tf.float32, shape=descs_shape)
        self.imgsz1_ph = tf.placeholder(tf.float32, shape=img_size_shape)

    def _build_inference_graph(self, gdef):
        """Parse ``gdef``, patch stateful ops and import it into the default graph.

        ``RefSwitch`` nodes become ``Switch`` (inputs whose name contains
        ``'moving_'`` get ``'/read'`` appended), and ``AssignSub`` / ``AssignAdd``
        become ``Sub`` / ``Add`` with their ``use_locking`` attribute removed.
        NOTE(review): presumably this lets a graph exported with moving-average
        update ops be imported for inference -- confirm with the model exporter.
        """
        graph_def = graph_pb2.GraphDef()
        graph_def.ParseFromString(gdef)

        # Stateful update op -> stateless arithmetic op of the same meaning.
        stateless_op = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}

        for node in graph_def.node:
            if node.op == 'RefSwitch':
                node.op = 'Switch'
                # Iterate over a snapshot; the repeated field is edited in place.
                for index, input_name in enumerate(list(node.input)):
                    if 'moving_' in input_name:
                        node.input[index] = input_name + '/read'
            elif node.op in stateless_op:
                node.op = stateless_op[node.op]
                if 'use_locking' in node.attr:
                    del node.attr['use_locking']

        tf.import_graph_def(
            graph_def,
            name='',
            input_map={
                'eval_thr': self.eval_thr_phs,
                'is_training': self.is_training,
                'eval_points0': self.points0_ph,
                'eval_scores0': self.scores0_ph,
                'eval_descs0': self.descs0_ph,
                'eval_img_sz0': self.imgsz0_ph,
                'eval_points1': self.points1_ph,
                'eval_scores1': self.scores1_ph,
                'eval_descs1': self.descs1_ph,
                'eval_img_sz1': self.imgsz1_ph,
            },
        )

        # Output tensors of the imported graph.
        fetch = tf.get_default_graph().get_tensor_by_name
        self.eval_indices0 = fetch('eval_indices0:0')
        self.eval_indices1 = fetch('eval_indices1:0')
        self.eval_mscores0 = fetch('eval_mscores0:0')
        self.eval_mscores1 = fetch('eval_mscores1:0')

    def _init_session(self):
        """Create the session and run the global and local variable initializers."""
        initializers = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
        )
        self.sess = tf.Session()
        self.sess.run(initializers)

    def forward(self, kpts0, kpts1, image_shape0, image_shape1, scores0, scores1, descs0, descs1, conf_thr):
        """Run the matcher on a single pair of images.

        Every per-image array gets a leading batch axis of size one before it
        is fed. ``image_shape*`` is indexed as ``[1]`` then ``[0]``, so a
        ``(height, width)`` shape is fed to the graph as ``(width, height)``.

        Returns:
            The session output for ``[eval_indices0, eval_indices1,
            eval_mscores0]``, in that order. ``eval_mscores1`` is not fetched.
        """
        def with_batch_axis(array):
            return np.expand_dims(array, 0)

        def swapped_size(image_shape):
            # One row holding (image_shape[1], image_shape[0]) as float32.
            return np.array([[image_shape[1], image_shape[0]]], np.float32)

        feed_dict = {
            self.eval_thr_phs: conf_thr,
            self.is_training: False,
            self.points0_ph: with_batch_axis(kpts0),
            self.scores0_ph: with_batch_axis(scores0),
            self.descs0_ph: with_batch_axis(descs0),
            self.imgsz0_ph: swapped_size(image_shape0),
            self.points1_ph: with_batch_axis(kpts1),
            self.scores1_ph: with_batch_axis(scores1),
            self.descs1_ph: with_batch_axis(descs1),
            self.imgsz1_ph: swapped_size(image_shape1),
        }

        fetches = [self.eval_indices0, self.eval_indices1, self.eval_mscores0]
        return self.sess.run(fetches, feed_dict=feed_dict)


@lru_cache(maxsize=1)
def get_superglue_cnn():
    """Return the shared ``SuperGlueTF`` instance, building it on the first call.

    The serialized graph is looked up in the bundled resources; the single-entry
    cache makes every later call return the same object.
    """
    serialized_graph = resource.find("/maps/mrc/superglue_matcher/models/superglue.gdef")
    return SuperGlueTF(serialized_graph)


def _nn_match_two_way_on_score(dist_mat):
    idx = np.argmin(dist_mat, axis=1)
    scores = dist_mat[np.arange(dist_mat.shape[0]), idx]
    idx2 = np.argmin(dist_mat, axis=0)
    keep = np.arange(len(idx)) == idx2[idx]
    idx = idx[keep]
    scores = scores[keep]
    idx1 = np.arange(dist_mat.shape[0])[keep]
    idx2 = idx
    return np.column_stack([idx1, idx2])


# Keypoint bundle of one image. As filled by make_keypoints_from_match_context:
# coords are the keypoint coordinates, image_shape is (height, width), and
# scores / descriptors are aligned with coords row by row.
Keypoints = collections.namedtuple(
    "Keypoints",
    ("coords", "image_shape", "scores", "descriptors"),
)


def make_keypoints_from_match_context(context):
    """Pack the keypoint data of a match context into a ``Keypoints`` tuple.

    Points, confidences and descriptors are reordered by descending confidence.

    Args:
        context: object exposing ``feature.size()`` (with ``height`` / ``width``),
            ``points`` (N x 2), ``confidence`` (N) and ``descriptors`` (N x 256).

    Returns:
        ``Keypoints(coords, image_shape, scores, descriptors)`` where
        ``image_shape`` is ``(height, width)``.

    Raises:
        AssertionError: if the array shapes do not agree as described above.
    """
    frame_size = context.feature.size()

    points = context.points
    confidence = context.confidence
    assert points.shape[1] == 2
    assert points.shape[0] == confidence.shape[0]

    descriptors = context.descriptors
    assert descriptors.shape[1] == 256
    assert points.shape[0] == descriptors.shape[0]

    # Ascending argsort, reversed -> most confident keypoint first.
    by_confidence_desc = confidence.argsort()[::-1]
    return Keypoints(
        coords=points[by_confidence_desc],
        image_shape=(frame_size.height, frame_size.width),
        scores=confidence[by_confidence_desc],
        descriptors=descriptors[by_confidence_desc],
    )


def superglue_match_points(pair_conf_thr, first_keypoints, second_keypoints):
    """Return the coordinates of keypoints that SuperGlue matched between two images.

    Args:
        pair_conf_thr: confidence threshold passed to the matcher.
        first_keypoints: ``Keypoints``-like object (``coords``, ``image_shape``,
            ``scores``, ``descriptors``) of the first image.
        second_keypoints: same for the second image.

    Returns:
        ``(valid_pts0, valid_pts1)``: two coordinate arrays of equal length where
        row ``k`` of the first is matched to row ``k`` of the second. Both are
        empty (zero rows) when the matcher found no match.

    Raises:
        AssertionError: if the two match index arrays returned by the matcher
            are not mutually consistent.
    """
    sglueCNN = get_superglue_cnn()
    indices0, indices1, _ = sglueCNN.forward(
        first_keypoints.coords, second_keypoints.coords,
        first_keypoints.image_shape, second_keypoints.image_shape,
        first_keypoints.scores, second_keypoints.scores,
        first_keypoints.descriptors, second_keypoints.descriptors,
        pair_conf_thr
    )
    # Drop the batch axis (a single image pair is fed).
    indices0 = indices0[0]
    indices1 = indices1[0]

    # -1 marks an unmatched keypoint. Index arrays are built directly so that
    # the "no match at all" case yields empty integer indices instead of a
    # 1-D float array that cannot be sliced with [:, 0].
    matched0 = np.flatnonzero(indices0 != -1)
    matched1 = indices0[matched0]

    for idx0, idx1 in zip(matched0, matched1):
        assert idx0 == indices1[idx1], "Invalid matches0[{}] = {} but matches1[{}] = {}".format(
            idx0, indices0[idx0], idx1, indices1[idx1]
        )

    valid_pts0 = first_keypoints.coords[matched0]
    valid_pts1 = second_keypoints.coords[matched1]
    return valid_pts0, valid_pts1


def superglue_match_tf(pair_conf_thr, all_valid_pairs, first_context, second_context):
    """Match same-type detections of two frames via SuperGlue keypoint geometry.

    Keypoints of both contexts are matched with SuperGlue, a fundamental matrix
    is estimated from the matches with RANSAC, and detections of the same type
    are then scored by the Sampson distance between their box centers.

    Args:
        pair_conf_thr: selects how keypoint matches are filtered.
            ``>= 0``: used as the SuperGlue confidence threshold; at least 15
            RANSAC inliers are required.
            ``< 0``: SuperGlue runs with threshold 0.0 and only the
            ``int(-pair_conf_thr)`` best-scoring matches are given to RANSAC.
        all_valid_pairs: if true, every (first, second) pair of same-type
            detections is reported; otherwise only pairs that are mutually
            nearest by Sampson distance.
        first_context: object with ``detections`` (items having ``type``,
            ``id`` and ``box.min_x/max_x/min_y/max_y``) plus the keypoint
            attributes read by ``make_keypoints_from_match_context``.
        second_context: same for the second frame.

    Returns:
        List of ``(first_id, second_id, confidence)`` tuples, where
        ``confidence`` is a dict with ``sampson_dist``, ``pt_in_hull0``,
        ``pt_in_hull1`` and the per-frame-pair statistics (inlier count, score
        sums, keypoint counts, convex hull areas, image shapes and ``F``).
        An empty list is returned when the frames share no detection type,
        when there are 7 or fewer keypoint matches, when no single 3x3
        fundamental matrix could be estimated, or when there are too few
        inliers.

    Raises:
        AssertionError: if the matcher output is inconsistent or the keypoint
            arrays have unexpected shapes.
    """
    FIND_FUND_MAT_PTS_CNT_MIN = 7
    RANSAC_REPROJ_THRESHOLD = 3.0
    RANSAC_CONFIDENCE = 0.99
    GOOD_PTS_CNT_MIN = 15  # inliers required when pair_conf_thr >= 0

    def _extract_objects(objects0, objects1):
        # Group the detections of both frames by the types present in both.
        types0 = set(obj.type for obj in objects0)
        types1 = set(obj.type for obj in objects1)
        common = types0.intersection(types1)
        grouped = {t: [[], []] for t in common}
        for side, objs in enumerate((objects0, objects1)):
            for obj in objs:
                if obj.type in common:
                    grouped[obj.type][side].append(obj)
        return common, grouped

    def area(pts):
        # Shoelace formula over consecutive polygon vertices (wraps around).
        lines = np.hstack([pts, np.roll(pts, -1, axis=0)])
        return 0.5 * abs(sum(x1 * y2 - x2 * y1 for x1, y1, x2, y2 in lines))

    def box_center(obj):
        box = obj.box
        return ((box.min_x + box.max_x) / 2, (box.min_y + box.max_y) / 2)

    sglueCNN = get_superglue_cnn()

    common_types, objects_by_type = _extract_objects(first_context.detections, second_context.detections)
    if not common_types:
        return []

    kpts0, image_shape0, scores0, descs0 = make_keypoints_from_match_context(first_context)
    kpts1, image_shape1, scores1, descs1 = make_keypoints_from_match_context(second_context)

    indices0, indices1, pairs_scores0 = sglueCNN.forward(
        kpts0, kpts1, image_shape0, image_shape1, scores0, scores1, descs0, descs1,
        pair_conf_thr if pair_conf_thr > 0. else 0.0
    )
    # Drop the batch axis (a single image pair is fed).
    indices0 = indices0[0]
    indices1 = indices1[0]
    pairs_scores0 = pairs_scores0[0]

    # Collect matched keypoint index pairs; -1 marks an unmatched keypoint.
    match_idx = []
    match_scores = []
    for idx0, idx1 in enumerate(indices0):
        if idx1 == -1:
            continue
        match_idx.append([idx0, idx1])
        match_scores.append(pairs_scores0[idx0])
        assert idx0 == indices1[idx1], "Invalid matches0[{}] = {} but matches1[{}] = {}".format(
            idx0, indices0[idx0], idx1, indices1[idx1]
        )
        assert pairs_scores0[idx0] >= pair_conf_thr, "super glue return pairs with score less than threshold"

    if len(match_idx) <= FIND_FUND_MAT_PTS_CNT_MIN:
        return []

    match_idx = np.array(match_idx)
    match_scores = np.array(match_scores)

    threshold_mode = 0. <= pair_conf_thr
    if threshold_mode:
        min_inliers = GOOD_PTS_CNT_MIN
    else:
        # Top-k mode: keep only the int(-pair_conf_thr) best-scoring matches.
        best_first = match_scores.argsort()[::-1][:int(-pair_conf_thr)]
        match_idx = match_idx[best_first]
        match_scores = match_scores[best_first]
        # The hulls below need at least one inlier.
        min_inliers = 1

    valid_pts0 = kpts0[match_idx[:, 0]]
    valid_pts1 = kpts1[match_idx[:, 1]]

    F, inlier_mask = cv2.findFundamentalMat(
        valid_pts0, valid_pts1, cv2.FM_RANSAC, RANSAC_REPROJ_THRESHOLD, RANSAC_CONFIDENCE
    )
    # Estimation may fail (nothing returned), e.g. when top-k truncation leaves
    # too few points; per the OpenCV docs the 7-point case may also return
    # several stacked solutions. Neither is usable for the Sampson distance.
    if F is None or inlier_mask is None or F.shape != (3, 3):
        return []

    inliers = inlier_mask.ravel() != 0
    good_cnt = int(np.sum(inliers))
    if good_cnt < min_inliers:
        return []

    good_pts_scores = np.sum(match_scores[inliers])
    bad_pts_scores = np.sum(match_scores[~inliers])
    convex_hull0 = np.squeeze(cv2.convexHull(valid_pts0[inliers]), axis=1)
    convex_hull1 = np.squeeze(cv2.convexHull(valid_pts1[inliers]), axis=1)

    # Statistics shared by every reported object pair of this frame pair.
    feature_pairs_data = {
        'good_cnt': good_cnt,
        'good_pts_scores': float(good_pts_scores),
        'bad_pts_scores': float(bad_pts_scores),
        'kpts0_cnt': int(kpts0.shape[0]),
        'kpts1_cnt': int(kpts1.shape[0]),
        'convex_hull_area0': float(area(convex_hull0)),
        'convex_hull_area1': float(area(convex_hull1)),
        'image_shape0': image_shape0,
        'image_shape1': image_shape1,
        'F': F
    }

    def sampson_error(center0, center1):
        # Square root of the Sampson distance between homogeneous box centers.
        return math.sqrt(cv2.sampsonDistance(center0 + (1.,), center1 + (1.,), F))

    matches = []
    for t in common_types:
        objs0, objs1 = objects_by_type[t]
        centers0 = [box_center(obj) for obj in objs0]
        centers1 = [box_center(obj) for obj in objs1]

        if all_valid_pairs:
            # Every cross pair, scored lazily in row-major order.
            scored_pairs = (
                (i0, i1, sampson_error(c0, c1))
                for i0, c0 in enumerate(centers0)
                for i1, c1 in enumerate(centers1)
            )
        else:
            # Only pairs that are each other's nearest neighbour.
            errs = np.ones((len(objs0), len(objs1)))
            for i0, c0 in enumerate(centers0):
                for i1, c1 in enumerate(centers1):
                    errs[i0, i1] = sampson_error(c0, c1)
            scored_pairs = (
                (i0, i1, errs[i0, i1]) for i0, i1 in _nn_match_two_way_on_score(errs)
            )

        for i0, i1, sampson_dist in scored_pairs:
            confidence = {'sampson_dist': sampson_dist,
                          'pt_in_hull0': cv2.pointPolygonTest(convex_hull0, centers0[i0], True),
                          'pt_in_hull1': cv2.pointPolygonTest(convex_hull1, centers1[i1], True)}
            confidence.update(feature_pairs_data)
            matches.append((objs0[i0].id, objs1[i1].id, confidence))

    return matches
