#!/usr/bin/python3
# -*- coding: utf-8 -*-

import yt.wrapper as yt
import argparse
import json
import copy
import math
import cv2
import itertools
import numpy as np

from maps.wikimap.mapspro.services.mrc.eye.experiments.signs_map.pylibs.detection import Detection

from .spectral import spectral


def load_feature_object_ids(object_file, min_box_size):
    """Collect ``(feature_id, detection_id)`` pairs for sufficiently large detections.

    Reads ``object_file`` as JSON, walks ``["features_objects"]`` and parses every
    entry of each feature's ``'objects'`` list with ``Detection.from_dict``. A
    detection is kept when ``detection.box.max_size()`` is at least ``min_box_size``.

    Returns a list of ``(feature_id, detection.id)`` tuples in file order.
    """
    selected = []
    for feature in json.load(object_file)["features_objects"]:
        fid = feature['feature_id']
        detections = (Detection.from_dict(raw) for raw in feature['objects'])
        selected.extend(
            (fid, det.id) for det in detections if det.box.max_size() >= min_box_size
        )
    return selected


def load_feature_objects(object_file, min_sz):
    """Load detection bboxes keyed by ``(feature_id, object_id)``.

    An object is kept when the absolute width OR the absolute height of its bbox
    (difference between ``bbox[1]`` and ``bbox[0]``) is at least ``min_sz``.
    The stored value is the object's ``'bbox'`` exactly as read from the JSON.
    """
    features = json.load(object_file)["features_objects"]
    kept = {}
    for feature in features:
        feature_id = feature['feature_id']
        for obj in feature['objects']:
            key = (feature_id, obj['object_id'])
            bbox = obj['bbox']
            first, second = bbox[0], bbox[1]
            width = abs(second[0] - first[0])
            height = abs(second[1] - first[1])
            if width >= min_sz or height >= min_sz:
                kept[key] = bbox
    return kept


def good_pts_div_sampson_func(match):
    """Score a match from its good-point count, hull terms and Sampson distance.

    Returns ``0.`` when ``match['good_cnt']`` is below 15. Otherwise the score is
    ``good_cnt / 2500`` times ``0.01 / max(sampson_dist, 0.01)``, divided by one
    penalty per image: a ``pt_in_hull`` value at or above ``-e`` gives penalty 1,
    a lower value gives ``1 + ln|pt_in_hull|``.
    """
    max_pts = 2500
    min_sampson_dist = 1e-2
    min_good_cnt = 15

    def hull_penalty(value):
        if value >= -math.e:
            return 1.
        return 1. + math.log(abs(value))

    good_cnt = match['good_cnt']
    if good_cnt < min_good_cnt:
        return 0.
    penalty0 = hull_penalty(match["pt_in_hull0"])
    penalty1 = hull_penalty(match["pt_in_hull1"])
    sampson = max(match["sampson_dist"], min_sampson_dist)
    return (good_cnt / max_pts) * (min_sampson_dist / sampson) * (1.0 / penalty0 / penalty1)


def load_matches(match_file, feature_object_ids, conf_fn):
    """Read pairwise object matches from ``match_file``.

    For every entry of ``['features_pairs']`` and each of its ``'matches'``,
    ``conf_fn`` is applied to the raw ``'confidence'`` value. A match is dropped
    when the transformed value is ``None``, or when either endpoint
    ``(feature_id, object_id)`` is not contained in ``feature_object_ids``.

    Returns a dict ``{((fid1, oid1), (fid2, oid2)): confidence}``.
    """
    result = {}
    for pair_entry in json.load(match_file)['features_pairs']:
        fid1, fid2 = pair_entry['feature_id_1'], pair_entry['feature_id_2']
        for match in pair_entry['matches']:
            foid1 = (fid1, match['object_id_1'])
            foid2 = (fid2, match['object_id_2'])
            confidence = conf_fn(match['confidence'])
            if confidence is None:
                continue
            if foid1 in feature_object_ids and foid2 in feature_object_ids:
                result[foid1, foid2] = confidence
    return result


def dump(dump_file, clusters):
    """Write ``clusters`` to ``dump_file`` as JSON indented by 4 spaces.

    Each cluster is an iterable of ``(feature_id, object_id)`` pairs. Output shape:
    ``{"clusters": [{"cluster_id": i, "objects": [{"feature_id": ..., "object_id": ...}]}]}``
    with ``cluster_id`` counting from 0 in input order.
    """
    serialized = []
    for cluster_id, cluster in enumerate(clusters):
        objects = [{'feature_id': foid[0], 'object_id': foid[1]} for foid in cluster]
        serialized.append({'cluster_id': cluster_id, 'objects': objects})
    json.dump({'clusters': serialized}, dump_file, indent=4)


def find_cluster_by_foid(test_clusters, foid):
    """Return the index of the first cluster whose ``'foids'`` contains ``foid``, or -1."""
    return next(
        (position for position, candidate in enumerate(test_clusters) if foid in candidate['foids']),
        -1,
    )


def expand_clusters_by_unused(from_clusters, foids, deep_copy=False):
    """Add a singleton cluster for every foid not yet present in any cluster.

    Args:
        from_clusters: list of cluster dicts with ``'foids'`` (and ``'fids'``) sets.
        foids: iterable of ``(feature_id, object_id)`` tuples that must be covered.
        deep_copy: when True, work on a deep copy instead of mutating ``from_clusters``.

    Returns:
        The (possibly same) list with ``{'foids': {foid}, 'fids': {foid[0]}}``
        appended, in ``foids`` order, for every previously uncovered foid.
        Duplicate foids produce only one singleton.
    """
    result = copy.deepcopy(from_clusters) if deep_copy else from_clusters

    # Build the covered set once; the previous per-foid linear scan over all
    # clusters was O(len(foids) * len(clusters)).
    covered = set()
    for cluster in result:
        covered.update(cluster['foids'])

    for foid in foids:
        if foid not in covered:
            result.append({'foids': {foid}, 'fids': {foid[0]}})
            covered.add(foid)
    return result


def greedy(matches, feature_objects, conf_thr):
    """Greedily cluster matched objects, strongest match first.

    Args:
        matches: dict ``{(foid0, foid1): confidence}`` with comparable confidences;
            foids are ``(feature_id, object_id)`` tuples.
        feature_objects: mapping/iterable whose keys are all foids to cover.
        conf_thr: pairs are processed in descending confidence order and
            processing stops at the first confidence below this threshold.

    Each processed pair creates a cluster, extends one, or merges two, unless that
    would put two objects of the same feature into one cluster. Objects that end
    up in no cluster become singleton clusters.

    Returns:
        A list of sets of foids.
    """
    clusters = []
    ranked = sorted(matches.items(), key=lambda item: item[1], reverse=True)
    for (foid0, foid1), conf in ranked:
        if conf < conf_thr:
            break
        fid0 = foid0[0]
        fid1 = foid1[0]
        assert fid0 != fid1
        cl_idx0 = find_cluster_by_foid(clusters, foid0)
        cl_idx1 = find_cluster_by_foid(clusters, foid1)
        if cl_idx0 == -1 and cl_idx1 == -1:
            clusters.append({'foids': {foid0, foid1}, 'fids': {fid0, fid1}})
        elif cl_idx0 == -1:
            cluster = clusters[cl_idx1]
            if fid0 not in cluster['fids']:
                cluster['foids'].add(foid0)
                cluster['fids'].add(fid0)
        elif cl_idx1 == -1:
            cluster = clusters[cl_idx0]
            if fid1 not in cluster['fids']:
                cluster['foids'].add(foid1)
                cluster['fids'].add(fid1)
        elif cl_idx0 != cl_idx1:
            cluster0 = clusters[cl_idx0]
            cluster1 = clusters[cl_idx1]
            if cluster0['fids'].isdisjoint(cluster1['fids']):
                cluster0['fids'] |= cluster1['fids']
                cluster0['foids'] |= cluster1['foids']
                # Delete by the known index rather than list.remove(), which would
                # search by (expensive, fragile) dict/set value equality.
                del clusters[cl_idx1]

    clusters = expand_clusters_by_unused(clusters, list(feature_objects))
    return [cluster['foids'] for cluster in clusters]


def clusterize_by_triangles(matches, feature_objects, conf_thr):
    """Cluster detections via epipolar-consistent triangles, then strong pairwise matches.

    Args:
        matches: dict mapping ``(foid0, foid1)`` to a match dict. Its ``'F'``,
            ``'good_cnt'`` and ``'sampson_dist'`` entries are read, so the raw match
            record must be the confidence value (``--conf-fn none``).
        feature_objects: dict mapping ``foid = (feature_id, object_id)`` to a bbox
            ``[[x0, y0], [x1, y1]]`` (corners in any order).
        conf_thr: not used; accepted so all clustering options share one signature.

    The match graph is split into connected components. Inside each component:

    1. Triples of mutually matched objects are ranked by their weakest
       ``good_cnt``; those with at least ``TRIANGLES_GOOD_CNT_THR`` are checked for
       epipolar consistency. In each of the three images, the epipolar lines
       induced by the other two bbox centres must intersect inside that image's
       bbox. Consistent triangles are merged into clusters.
    2. Remaining pairs with ``good_cnt >= MATCHES_GOOD_CNT_THR`` and
       ``sampson_dist <= SAMPSON_DIST_THR`` extend or merge clusters, strongest first.

    A merge is skipped whenever it would put two objects of the same feature into
    one cluster. Objects that end up in no cluster become singleton clusters.

    Returns:
        A list of sets of foids.
    """
    TRIANGLES_GOOD_CNT_THR = 40  # min good_cnt of a triangle's weakest edge
    MATCHES_GOOD_CNT_THR = 40  # min good_cnt of a pairwise match
    SAMPSON_DIST_THR = 4  # max sampson_dist of a pairwise match

    def bbox_center(bbox):
        return (bbox[0][0] + bbox[1][0]) / 2., (bbox[0][1] + bbox[1][1]) / 2.

    def lines_intersection(line0, line1):
        # Lines are (a, b, c) with a*x + b*y + c = 0; the intersection comes from
        # their cross product. Returns ((x, y), det), or (None, None) when the
        # lines are (nearly) parallel.
        det = line0[0] * line1[1] - line0[1] * line1[0]
        if abs(det) < 1e-3:
            return None, None
        x = (line0[1] * line1[2] - line1[1] * line0[2]) / det
        y = (line1[0] * line0[2] - line0[0] * line1[2]) / det
        return (x, y), det

    def pt_in_bbox(pt, bbox):
        # Inclusive containment test; bbox corners may be given in any order.
        for i in range(2):
            if pt[i] < min(bbox[0][i], bbox[1][i]) or pt[i] > max(bbox[0][i], bbox[1][i]):
                return False
        return True

    def merge_into_clusters(clusters, foids):
        """Try to bring all ``foids`` together into one cluster (in place).

        Existing clusters that contain any of ``foids`` are merged into the first
        of them (in ``foids`` order), and not-yet-clustered foids are added to it.
        If none of ``foids`` is clustered yet, a new cluster is appended. Nothing
        changes when the result would hold two objects of the same feature.
        """
        found = []  # distinct indices of clusters containing some foid, in foids order
        fresh = []  # foids that are not in any cluster yet
        for foid in foids:
            idx = find_cluster_by_foid(clusters, foid)
            if idx == -1:
                fresh.append(foid)
            elif idx not in found:
                found.append(idx)

        if not found:
            clusters.append({'foids': set(fresh), 'fids': {foid[0] for foid in fresh}})
            return
        if not fresh and len(found) == 1:
            return  # already in the same cluster

        # All involved clusters and fresh objects must come from distinct features.
        seen_fids = set()
        for idx in found:
            if not seen_fids.isdisjoint(clusters[idx]['fids']):
                return
            seen_fids |= clusters[idx]['fids']
        for foid in fresh:
            if foid[0] in seen_fids:
                return
            seen_fids.add(foid[0])

        target = clusters[found[0]]
        for idx in found[1:]:
            target['foids'] |= clusters[idx]['foids']
            target['fids'] |= clusters[idx]['fids']
        for foid in fresh:
            target['foids'].add(foid)
            target['fids'].add(foid[0])
        # Delete absorbed clusters from the highest index down so earlier indices stay valid.
        for idx in sorted(found[1:], reverse=True):
            del clusters[idx]

    def split_components(matches, feature_object_ids):
        # Connected components of the match graph. Each component holds its foids
        # (as a list) and its own matches; objects without matches get an empty
        # single-object component.
        components = []
        for foids, conf in matches.items():
            foid0 = foids[0]
            foid1 = foids[1]
            cl_idx0 = find_cluster_by_foid(components, foid0)
            cl_idx1 = find_cluster_by_foid(components, foid1)
            if cl_idx0 == -1 and cl_idx1 == -1:
                components.append({'foids': {foid0, foid1}, 'matches': {foids: conf}})
            elif cl_idx0 == -1:
                component = components[cl_idx1]
                component['foids'].add(foid0)
                component['matches'][foids] = conf
            elif cl_idx1 == -1:
                component = components[cl_idx0]
                component['foids'].add(foid1)
                component['matches'][foids] = conf
            elif cl_idx0 == cl_idx1:
                components[cl_idx0]['matches'][foids] = conf
            else:
                component0 = components[cl_idx0]
                component1 = components[cl_idx1]
                component0['foids'] |= component1['foids']
                component0['matches'].update(component1['matches'])
                component0['matches'][foids] = conf
                del components[cl_idx1]

        foids_rest = set(feature_object_ids)
        for comp in components:
            foids_rest = foids_rest.difference(comp['foids'])
            comp['foids'] = list(comp['foids'])

        for foid in foids_rest:
            components.append({'foids': [foid], 'matches': {}})

        return components

    def extract_triangles(matches, feature_objects):
        # Every triple of objects whose three pairs are all matched. For each pair
        # the matrix is stored in both directions: FIJ is the match's 'F' when the
        # match key is (I, J), and its transpose when the key is (J, I).
        def oriented_fundamentals(foid0, foid1):
            F01, F10, good_cnt = None, None, None
            if (foid0, foid1) in matches:
                match = matches[(foid0, foid1)]
                F01 = np.array(match['F'])
                F10 = np.transpose(F01)
                good_cnt = match['good_cnt']
            if (foid1, foid0) in matches:
                match = matches[(foid1, foid0)]
                F10 = np.array(match['F'])
                F01 = np.transpose(F10)
                good_cnt = match['good_cnt']
            return F01, F10, good_cnt

        triangles = {}
        for foid0, foid1, foid2 in itertools.combinations(feature_objects, 3):
            F01, F10, good_cnt01 = oriented_fundamentals(foid0, foid1)
            if F01 is None:
                continue
            F02, F20, good_cnt02 = oriented_fundamentals(foid0, foid2)
            if F02 is None:
                continue
            F12, F21, good_cnt12 = oriented_fundamentals(foid1, foid2)
            if F12 is None:
                continue
            triangles[(foid0, foid1, foid2)] = {'F01': F01, 'F10': F10, 'F02': F02, 'F20': F20, 'F12': F12, 'F21': F21,
                                                'good_cnt01': good_cnt01, 'good_cnt02': good_cnt02, 'good_cnt12': good_cnt12,
                                                'min_good_cnt': min(good_cnt01, good_cnt02, good_cnt12)}
        return triangles

    def by_triangles(triangles):
        clusters = []
        ranked = sorted(triangles.items(), key=lambda item: item[1]['min_good_cnt'], reverse=True)
        for foids, data in ranked:
            if data['min_good_cnt'] < TRIANGLES_GOOD_CNT_THR:
                break
            assert foids[0][0] != foids[1][0] and foids[0][0] != foids[2][0] and foids[1][0] != foids[2][0]

            bboxes = []  # bboxes normalized to [[xmin, ymin], [xmax, ymax]]
            centers = []  # bbox centres shaped (1, 1, 2) for cv2
            for foid in foids:
                bbox = feature_objects[foid]
                cx, cy = bbox_center(bbox)
                centers.append(np.array([[[cx, cy]]]))
                bboxes.append([[min(bbox[0][0], bbox[1][0]), min(bbox[0][1], bbox[1][1])],
                               [max(bbox[0][0], bbox[1][0]), max(bbox[0][1], bbox[1][1])]])

            F01 = np.array(data['F01'])
            F02 = np.array(data['F02'])
            F12 = np.array(data['F12'])

            # linesIJ: epipolar line on image I induced by the bbox centre on image J.
            # computeCorrespondEpilines(pts, 1, F) maps points of F's first image to
            # lines on its second image; whichImage=2 maps the other way.
            # NOTE(review): assumes FIJ satisfies x_J^T FIJ x_I = 0 — confirm
            # against the matcher that produced 'F'.
            lines10 = cv2.computeCorrespondEpilines(centers[0], 1, F01)
            lines20 = cv2.computeCorrespondEpilines(centers[0], 1, F02)
            lines01 = cv2.computeCorrespondEpilines(centers[1], 2, F01)
            lines21 = cv2.computeCorrespondEpilines(centers[1], 1, F12)
            lines02 = cv2.computeCorrespondEpilines(centers[2], 2, F02)
            lines12 = cv2.computeCorrespondEpilines(centers[2], 2, F12)

            inters = [
                lines_intersection(lines01.ravel(), lines02.ravel()),  # on image 0
                lines_intersection(lines10.ravel(), lines12.ravel()),  # on image 1
                lines_intersection(lines20.ravel(), lines21.ravel()),  # on image 2
            ]
            consistent = all(
                pt is not None and pt_in_bbox(pt, bbox)
                for (pt, _), bbox in zip(inters, bboxes)
            )
            if not consistent:
                continue

            merge_into_clusters(clusters, foids)
        return clusters

    def by_matches(matches, clusters):
        for foids, data in sorted(matches.items(), key=lambda item: item[1]['good_cnt'], reverse=True):
            assert foids[0][0] != foids[1][0]
            if data['good_cnt'] < MATCHES_GOOD_CNT_THR:
                break
            if data['sampson_dist'] > SAMPSON_DIST_THR:
                continue
            merge_into_clusters(clusters, foids)
        return clusters

    def clusterize_component(component):
        comp_feature_objects = {foid: feature_objects[foid] for foid in component['foids']}
        clusters = by_triangles(extract_triangles(component['matches'], comp_feature_objects))
        return by_matches(component['matches'], clusters)

    clusters = []
    for component in split_components(matches, feature_objects):
        clusters += clusterize_component(component)
    clusters = expand_clusters_by_unused(clusters, list(feature_objects))
    return [cluster['foids'] for cluster in clusters]


# Clustering strategies selectable via --type; main() calls each one as
# fn(matches, feature_objects, min_confidence).
options = dict(
    greedy=greedy,
    spectral=spectral,
    by_triangles=clusterize_by_triangles,
)

# Transforms applied by load_matches to each match's raw 'confidence' value
# (selected via --conf-fn); a None result makes load_matches drop the match.
conf_fns = dict(
    none=lambda item: item,
    good_pts_div_sampson=good_pts_div_sampson_func,
)


def get_args():
    """Parse the command line of the clustering script and return the namespace."""
    parser = argparse.ArgumentParser(description='Make detection clustering')
    add = parser.add_argument

    add('--yt-proxy', dest='yt_proxy', default='hahn', help='name of yt proxy')
    add('--match-file', type=str, required=True, help='input json file with matches')
    add('--object-file', dest='object_file', type=str, required=True,
        help='input yt json file with objects')
    add('--cluster-file', dest='cluster_file', type=str, required=True,
        help='output json file with clusters')
    add('--min-confidence', dest='min_confidence', type=float, default=0,
        help='min matching confidence')
    add('--min-box-size', dest='min_box_size', type=int, default=0,
        help='min matching box size in pixels')
    add('--type', type=str, choices=list(options), help='cluster type', required=True)
    add('--conf-fn', dest='conf_fn', type=str, choices=list(conf_fns),
        help='confidence function type', default='none')

    return parser.parse_args()


def main():
    """Load objects (from YT) and matches, cluster them, and write the clusters JSON."""
    args = get_args()

    clusterize = options[args.type]

    yt_client = yt.YtClient(proxy=args.yt_proxy)
    feature_objects = load_feature_objects(yt_client.read_file(args.object_file), args.min_box_size)

    # JSON text is UTF-8 (RFC 8259); open explicitly so reading does not depend on
    # the platform's locale encoding.
    with open(args.match_file, encoding='utf-8') as match_file:
        matches = load_matches(match_file, feature_objects, conf_fns[args.conf_fn])

    clusters = clusterize(matches, feature_objects, args.min_confidence)

    with open(args.cluster_file, 'w', encoding='utf-8') as out:
        dump(out, clusters)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
