#!/usr/bin/python3
# -*- coding: utf-8 -*-

import collections
import functools

import cv2 as cv
import numpy as np

from .superglue_matcher_tf import superglue_match_tf
from .nn_mutual_matcher import nn_mutual


def center(box):
    """Return the centre of ``box`` as a homogeneous column vector.

    ``box`` must expose ``min_x``, ``max_x``, ``min_y`` and ``max_y``.  The
    result is a float32 array of shape (3, 1) holding ``[cx, cy, 1.0]``, so it
    can be left-multiplied directly by a 3x3 matrix (e.g. a fundamental matrix).
    """
    coords = (
        (box.min_x + box.max_x) / 2,
        (box.min_y + box.max_y) / 2,
        1.0,
    )
    return np.array(coords, dtype=np.float32).reshape(-1, 1)


def match(first_descriptors, second_descriptors, r=0.85):
    """Return mutually consistent descriptor matches that pass a ratio test.

    A FLANN k-d tree matcher (5 trees, 50 checks) finds the two nearest
    neighbours of every descriptor in both directions.  A correspondence is
    kept in a direction only if its best distance is below ``r`` times the
    second-best distance.  Only pairs accepted in *both* directions are
    returned.

    Args:
        first_descriptors: descriptor array of the first image (one row per
            keypoint), or ``None``/empty when nothing was detected.
        second_descriptors: descriptor array of the second image, same rules.
        r: ratio-test threshold; smaller values are stricter.

    Returns:
        A set of ``(i, j)`` tuples where ``i`` indexes ``first_descriptors``
        and ``j`` indexes ``second_descriptors``.  Empty if either input is
        ``None`` or has no rows.
    """
    # FLANN raises on empty input, and detectors commonly yield ``None``
    # descriptors when no keypoints were found.
    if first_descriptors is None or second_descriptors is None:
        return set()
    if len(first_descriptors) == 0 or len(second_descriptors) == 0:
        return set()

    FLANN_INDEX_KDTREE = 1
    flann = cv.FlannBasedMatcher({'algorithm': FLANN_INDEX_KDTREE, 'trees': 5}, {'checks': 50})

    def ratio_matches(query, train):
        """Return ``(query_idx, train_idx)`` pairs that pass the ratio test."""
        accepted = set()
        for neighbours in flann.knnMatch(query, train, k=2):
            # knnMatch may return fewer than k neighbours for a query (e.g.
            # when ``train`` has a single row); the ratio test needs two.
            if len(neighbours) < 2:
                continue
            best, runner_up = neighbours[0], neighbours[1]
            if best.distance < r * runner_up.distance:
                accepted.add((best.queryIdx, best.trainIdx))
        return accepted

    forward = ratio_matches(first_descriptors, second_descriptors)
    # The backward pass queries with the second image, so swap each pair back
    # into (first index, second index) order before intersecting.
    backward = {(i, j) for j, i in ratio_matches(second_descriptors, first_descriptors)}

    return forward & backward


# Immutable bundle of per-image data handed to a matcher. Within this module,
# ``knn_match`` reads ``descriptors`` (fed to ``match``), ``points`` (indexed by
# match indices) and ``detections`` (objects with ``box``, ``type`` and ``id``);
# ``feature`` and ``confidence`` are not read here.
MatchContext = collections.namedtuple(
    'MatchContext',
    (
        'feature',
        'points',
        'descriptors',
        'confidence',
        'detections',
    ),
)


def knn_match(first_context, second_context, *, min_matches=15, max_epipolar_distance=10.0):
    """Pair detections across two views using epipolar geometry.

    Descriptor correspondences from ``match`` are used to estimate a
    fundamental matrix ``F`` with RANSAC.  For every detection in
    ``first_context`` (in order), the epipolar line ``F @ center(box)`` in the
    second view is computed.  The same-``type``, not-yet-paired detection of
    ``second_context`` whose box centre lies closest to that line is chosen,
    provided the distance is below ``max_epipolar_distance``.  Pairing is
    greedy: earlier first-view detections claim candidates first.

    The inputs are not modified.

    Args:
        first_context: ``MatchContext`` of the first view.
        second_context: ``MatchContext`` of the second view.
        min_matches: minimum number of descriptor correspondences required
            before attempting to estimate ``F``; fewer yields ``[]``.
        max_epipolar_distance: point-to-line distance threshold, in the units
            of ``points``/box coordinates (presumably pixels — confirm).

    Returns:
        A list of ``(first_id, second_id, 1.0)`` tuples.  It is empty when
        there are too few matches or ``F`` cannot be estimated.
    """
    # Materialise once so both point lists are built from the same ordering.
    matches = list(match(first_context.descriptors, second_context.descriptors))

    if len(matches) < min_matches:
        return []

    src_pts = np.float32([first_context.points[i] for i, _ in matches])
    dst_pts = np.float32([second_context.points[j] for _, j in matches])

    # OpenCV convention: [p2; 1]^T F [p1; 1] = 0, so F @ p1 is the epipolar
    # line of p1 in the second image.
    F, _ = cv.findFundamentalMat(src_pts, dst_pts, cv.FM_RANSAC, 0.5, 0.999999)

    if F is None:
        return []

    candidates = second_context.detections
    # Track already-paired candidates instead of deleting them from the
    # caller's list (the old behaviour silently mutated second_context).
    used = set()
    pairs = []

    for first in first_context.detections:
        line = F @ center(first.box)
        norm = float(np.hypot(line[0, 0], line[1, 0]))
        if norm == 0.0:
            # Degenerate line: the distance would be nan/inf, and no candidate
            # could have matched anyway.
            continue

        best_index = None
        best_distance = max_epipolar_distance

        for index, second in enumerate(candidates):
            if index in used or second.type != first.type:
                continue

            distance = abs((line.T @ center(second.box)).item()) / norm
            if distance < best_distance:
                best_distance = distance
                best_index = index

        if best_index is None:
            continue

        used.add(best_index)
        pairs.append((first.id, candidates[best_index].id, 1.0))

    return pairs


MUTUAL_PAIRS_DISTANCE_THRESHOLD = 0.7
ALL_VALID_PAIRS = False

# Threshold sweep shared by the ``nn_mutual`` and ``superglue_tf`` variants.
# Each entry is the exact key suffix; ``float()`` of it is the threshold value
# bound into the matcher.
_THRESHOLD_SWEEP = (
    '0.30', '0.35', '0.40', '0.45', '0.50', '0.55', '0.60',
    '0.65', '0.70', '0.75', '0.80', '0.85', '0.90',
)


# Registry of matcher callables keyed by name.  ``functools.partial`` entries
# pre-bind the matcher's first two positional arguments: a threshold and a
# boolean flag (``True`` for names ending in ``_all``; presumably the same
# meaning as ``ALL_VALID_PAIRS`` — confirm against the matcher modules).
# Insertion order: knn_match, nn_mutual default, nn_mutual ``_all`` sweep,
# nn_mutual plain sweep, then the same four groups for superglue_tf, followed
# by the ``m1000`` variant.
options = {
    'knn_match': knn_match,
    'nn_mutual': functools.partial(nn_mutual, MUTUAL_PAIRS_DISTANCE_THRESHOLD, ALL_VALID_PAIRS),
    **{
        'nn_mutual_' + t + '_all': functools.partial(nn_mutual, float(t), True)
        for t in _THRESHOLD_SWEEP
    },
    **{
        'nn_mutual_' + t: functools.partial(nn_mutual, float(t), False)
        for t in _THRESHOLD_SWEEP
    },
    # Note: the default superglue entry hard-codes 0.45 / True rather than
    # reusing the nn_mutual module constants above.
    'superglue_tf': functools.partial(superglue_match_tf, 0.45, True),
    **{
        'superglue_tf_' + t + '_all': functools.partial(superglue_match_tf, float(t), True)
        for t in _THRESHOLD_SWEEP
    },
    **{
        'superglue_tf_' + t: functools.partial(superglue_match_tf, float(t), False)
        for t in _THRESHOLD_SWEEP
    },

    'superglue_tf_m1000_all': functools.partial(superglue_match_tf, -1000, True),
}
