import json
import argparse


def _inters_area(bbox1, bbox2):
    """Return the intersection area of two axis-aligned bounding boxes.

    Each bbox is given by two opposite corners:
        [[xmin, ymin], [xmax, ymax]]
    The corners may come in either order. Boxes that do not overlap, or
    that merely touch along an edge or at a corner, give 0.0.
    """
    extents = []
    for axis in (0, 1):
        # Normalise both boxes along this axis, then intersect the intervals.
        low = max(min(bbox1[0][axis], bbox1[1][axis]),
                  min(bbox2[0][axis], bbox2[1][axis]))
        high = min(max(bbox1[0][axis], bbox1[1][axis]),
                   max(bbox2[0][axis], bbox2[1][axis]))
        extents.append(high - low)

    width, height = extents
    if width <= 0 or height <= 0:
        return 0.0
    return float(width * height)


def _area(bbox):
    """Return the area of a bbox given by two opposite corners:
        [[xmin, ymin], [xmax, ymax]]
    The corners may come in either order; the result is always a float.
    """
    first_corner = bbox[0]
    second_corner = bbox[1]
    width = abs(second_corner[0] - first_corner[0])
    height = abs(second_corner[1] - first_corner[1])
    return float(width * height)


def _iou(bbox1, bbox2):
    """Return the intersection-over-union of two bboxes.

    Both bboxes use the corner format [[xmin, ymin], [xmax, ymax]].
    A union area below 0.001 is treated as degenerate and gives 0.
    """
    intersection = _inters_area(bbox1, bbox2)
    union = _area(bbox1) + _area(bbox2) - intersection
    return 0. if union < 0.001 else intersection / union


def _load_features_objects(json_path, valid_classes_list):
    """Load the objects of every photo from a JSON file into a dict.

    The file must contain a top-level 'features_objects' list whose items
    carry a 'feature_id' and an 'objects' list.

        key   - feature_id
        value - list of objects on that photo

    When valid_classes_list is not None, only objects whose 'type' is in
    that list are kept.
    """
    with open(json_path, 'r', encoding='utf-8') as stream:
        payload = json.load(stream)

    def _select(objects):
        # No class list means "keep everything as loaded".
        if valid_classes_list is None:
            return objects
        return [obj for obj in objects if obj['type'] in valid_classes_list]

    return {
        entry['feature_id']: _select(entry['objects'])
        for entry in payload['features_objects']
    }


def _load_clusters(json_path, uid_by_feature_object_id):
    """Load clusters and index them by the uids of matched objects.

        The clusters file describes every cluster as a list of objects,
        each object being (feature_id, object_id):
            feature_id - id of the photo
            object_id  - id of the object, unique within the photo.
        Objects of the GT and test sets have been matched beforehand and
        every matched pair got a unique uid. This function receives the
        path to the clusters file and uid_by_feature_object_id - a dict
        mapping (feature_id, object_id) -> uid of the object.
        If some (feature_id, object_id) is absent from the dict, no pair
        was found for it and the object takes no part in the clustering
        evaluation.
        Returns two dicts.
            1. cluster_id_by_uid
                key   - uid of an object on a photo
                value - id of its cluster
            2. cluster_by_cluster_id
                key   - id of a cluster
                value - list of uids of the objects in the cluster
    """
    with open(json_path, 'r', encoding='utf-8') as stream:
        payload = json.load(stream)

    uid_to_cluster = {}
    members_by_cluster = {}
    for entry in payload['clusters']:
        cid = entry['cluster_id']
        # Every cluster gets an entry, even if none of its objects is matched.
        members = members_by_cluster[cid] = []
        for ref in entry['objects']:
            uid = uid_by_feature_object_id.get((ref['feature_id'], ref['object_id']))
            if uid is not None:
                assert uid not in uid_to_cluster,\
                    "One object in more than one cluster: feature_id = " + str(ref['feature_id']) + ", object_id = " + str(ref['object_id'])
                uid_to_cluster[uid] = cid
                members.append(uid)
    return uid_to_cluster, members_by_cluster


def _map_objects(gt_objects, test_objects, iou_thr):
    """Return the correspondence between object ids in GT and in the test set.

        The result is a dict: key - id of an object in GT, value - id of
        the object in the test set it was matched with.

        Only objects of the same 'type' whose IoU is strictly greater than
        iou_thr are candidates. Candidates are taken greedily in order of
        decreasing IoU, and every GT / test object is used at most once.
    """
    candidates = []
    for gt_idx, gt_obj in enumerate(gt_objects):
        for tst_idx, tst_obj in enumerate(test_objects):
            if gt_obj['type'] == tst_obj['type']:
                score = _iou(gt_obj['bbox'], tst_obj['bbox'])
                if iou_thr < score:
                    candidates.append((score, gt_idx, tst_idx))

    # Best overlaps first; ties keep their discovery order (stable sort).
    ranked = sorted(candidates, key=lambda candidate: candidate[0], reverse=True)

    gt_to_test = {}
    taken_test_ids = set()
    for _, gt_idx, tst_idx in ranked:
        gt_id = gt_objects[gt_idx]['object_id']
        tst_id = test_objects[tst_idx]['object_id']
        if gt_id in gt_to_test or tst_id in taken_test_ids:
            continue
        gt_to_test[gt_id] = tst_id
        taken_test_ids.add(tst_id)

    return gt_to_test


def _get_filtered_feature_objects_id(features_objects, min_sz):
    """Return the set of (feature_id, object_id) for objects that are large enough.

        features_objects - dict, key - feature_id, value - list of objects
            on that photo; every object has 'object_id' and 'bbox'
            ([[x0, y0], [x1, y1]])
        min_sz - minimal bbox size; an object is dropped only when BOTH its
            width and its height are smaller than min_sz. None disables the
            size filter, so every object is kept (the same meaning min_sz
            has in _calculate_detector_quality).
    """
    valid_ids = set()
    for fid, objects in features_objects.items():
        for obj in objects:
            if min_sz is not None:
                bbox = obj['bbox']
                width = abs(bbox[1][0] - bbox[0][0])
                height = abs(bbox[1][1] - bbox[0][1])
                if width < min_sz and height < min_sz:
                    continue
            valid_ids.add((fid, obj['object_id']))
    return valid_ids


def _calculate_detector_quality(objects_map_by_feature_ids, gt_objects_by_fids, test_objects_by_fids, min_sz):
    """Compute the detector quality as a (precision, recall) pair.

            objects_map_by_feature_ids - dict with the object_id correspondence maps between GT and the test set
                key - feature_id
                value - dict
                    key - object_id in the GT set,
                    value - object_id in the test set,
            gt_objects_by_fids - dict of the objects of the GT dataset
                key - feature_id
                value - objects on that feature_id
            test_objects_by_fids - dict of the objects of the test dataset
                key - feature_id
                value - objects on that feature_id
            min_sz - when not None, unmatched GT objects whose width and
                height are both below min_sz are not counted as misses

        Returns (0., 0.) after printing a message when there are no test
        objects at all; asserts that the set of relevant objects is not empty.
    """
    def _is_small(obj):
        bbox = obj['bbox']
        width = abs(bbox[1][0] - bbox[0][0])
        height = abs(bbox[1][1] - bbox[0][1])
        return min_sz > width and min_sz > height

    tp = 0
    fp = 0
    fn = 0
    for fid, matches in objects_map_by_feature_ids.items():
        matched = len(matches)
        tp += matched
        fp += len(test_objects_by_fids.get(fid, [])) - matched

        gt_objects = gt_objects_by_fids.get(fid, [])
        if min_sz is None:
            fn += len(gt_objects) - matched
        else:
            # small objects are deliberately not penalised as misses
            fn += sum(
                1 for obj in gt_objects
                if not _is_small(obj) and obj['object_id'] not in matches
            )

    if tp + fp == 0:
        print("Set of test objects is empty")
        return 0., 0.

    assert 0 < tp + fn, "Set of relevant objects is empty"
    return tp / (tp + fp), tp / (tp + fn)


def _calculate_cluster_quality(cluster_id_by_uid_1, cluster_by_cluster_id_1, cluster_id_by_uid_2):
    """Measure how well the clusters of clustering 1 are preserved in clustering 2.

        cluster_id_by_uid_1 - dict, key - uid of an object, value - id of
            its cluster in clustering 1
        cluster_by_cluster_id_1 - dict, key - id of a cluster in clustering 1,
            value - list of uids of the objects in that cluster
        cluster_id_by_uid_2 - dict, key - uid of an object, value - id of
            its cluster in clustering 2

    For every object of clustering 1 the score is the share of its
    cluster-mates that sit in one cluster with it in clustering 2 as well;
    an object that is alone in its cluster scores 1. The result is the
    mean score over all objects of clustering 1.

    An object missing from clustering 2 is treated as sharing a cluster with
    nobody there (previously this raised KeyError). An empty clustering 1
    gives 0.0 (previously ZeroDivisionError).
    """
    if not cluster_id_by_uid_1:
        return 0.0

    # Unique marker, so that a legitimate cluster id (even None) never
    # compares equal to "absent from clustering 2".
    absent = object()

    quality = 0
    for uid, cluster_id in cluster_id_by_uid_1.items():
        cluster = cluster_by_cluster_id_1[cluster_id]
        if len(cluster) == 1:
            quality += 1
            continue
        other_cluster_id = cluster_id_by_uid_2.get(uid, absent)
        if other_cluster_id is absent:
            # Not clustered on the other side: none of its mates can agree.
            continue
        agreeing = sum(
            1 for mate in cluster
            if mate != uid and cluster_id_by_uid_2.get(mate, absent) == other_cluster_id
        )
        quality += agreeing / (len(cluster) - 1)
    return quality / len(cluster_id_by_uid_1)


def _calculate_clusterization_quality(clusters_gt_path, clusters_test_path, objects_map_by_feature_ids, gt_feature_object_valid_ids, test_feature_object_valid_ids):
    """Compute the clustering quality as a (precision, recall) pair.

        clusters_gt_path, clusters_test_path - paths to the JSON files with
            GT and test clusters
        objects_map_by_feature_ids - dict, key - feature_id, value - dict
            mapping object_id in GT to object_id in the test set
        gt_feature_object_valid_ids, test_feature_object_valid_ids -
            collections of (feature_id, object_id) allowed to take part in
            the evaluation on the GT and on the test side

    Every matched pair whose both sides are valid gets a shared uid; the
    clusters of both files are then loaded in terms of these uids and
    compared in both directions. Asserts that at least one such pair exists.
    """
    next_uid = 0
    gt_uid_by_key = {}
    test_uid_by_key = {}
    for fid, objects_map in objects_map_by_feature_ids.items():
        for gt_object_id, test_object_id in objects_map.items():
            gt_key = (fid, gt_object_id)
            test_key = (fid, test_object_id)
            if gt_key not in gt_feature_object_valid_ids:
                continue
            if test_key not in test_feature_object_valid_ids:
                continue
            gt_uid_by_key[gt_key] = next_uid
            test_uid_by_key[test_key] = next_uid
            next_uid += 1

    objects_cnt = next_uid
    assert 0 < objects_cnt, "There are no matched objects in GT and test data"

    gt_by_uid, gt_by_cluster = _load_clusters(clusters_gt_path, gt_uid_by_key)
    if objects_cnt != len(gt_by_uid):
        print("Some objects don't belong to any GT clusters.")

    test_by_uid, test_by_cluster = _load_clusters(clusters_test_path, test_uid_by_key)
    if objects_cnt != len(test_by_uid):
        print("Some objects don't belong to any test clusters.")

    # precision: test clusters checked against GT; recall: the other way round
    precision = _calculate_cluster_quality(test_by_uid, test_by_cluster, gt_by_uid)
    recall = _calculate_cluster_quality(gt_by_uid, gt_by_cluster, test_by_uid)
    return precision, recall


def calculate_quality(objects_gt_path, objects_test_path, clusters_gt_path, clusters_test_path, iou_thr, min_sz, valid_classes_path):
    """Evaluate detection and clustering quality and print the results.

        objects_gt_path - path to the JSON file with ground truth objects
        objects_test_path - path to the JSON file with test objects; None
            means the detection quality is not evaluated and the object ids
            of the GT and test clusterings are taken to coincide
        clusters_gt_path, clusters_test_path - paths to the JSON files with
            GT / test clusters; when either is None the clustering quality
            is not evaluated
        iou_thr - IoU threshold used to match GT and test objects
        min_sz - minimal bbox size of the objects that are taken into account
        valid_classes_path - path to a text file with one class name per
            line, or None to keep objects of every class
    """
    valid_classes_list = None
    if valid_classes_path is not None:
        # Close the file deterministically and read it as UTF-8, the same
        # encoding the JSON files with the object types are read in.
        with open(valid_classes_path, 'r', encoding='utf-8') as classes_file:
            valid_classes_list = classes_file.read().splitlines()

    gt_objects_by_fids = _load_features_objects(objects_gt_path, valid_classes_list)
    if (objects_test_path is not None):
        test_objects_by_fids = _load_features_objects(objects_test_path, valid_classes_list)
        fids = set(gt_objects_by_fids.keys()).union(test_objects_by_fids.keys())

        objects_map_by_feature_ids = {}
        for fid in fids:
            gt_objects = gt_objects_by_fids.get(fid, [])
            test_objects = test_objects_by_fids.get(fid, [])
            objects_map_by_feature_ids[fid] = _map_objects(gt_objects, test_objects, iou_thr)

        precision, recall = _calculate_detector_quality(objects_map_by_feature_ids, gt_objects_by_fids, test_objects_by_fids, min_sz)
        print("Detection:")
        print("  precision: = {:.2f}%". format(100. * precision))
        print("  recall:    = {:.2f}%". format(100. * recall))
    else:
        test_objects_by_fids = gt_objects_by_fids
        # No file with test objects was given, so the detection quality is
        # not evaluated and the object ids of the GT and test clusterings
        # coincide. Build a trivial identity map so that the clustering
        # evaluation does not need a second code path.
        objects_map_by_feature_ids = {
            fid: {obj['object_id']: obj['object_id'] for obj in gt_objects}
            for fid, gt_objects in gt_objects_by_fids.items()
        }

    if (clusters_gt_path is None or clusters_test_path is None):
        return
    gt_feature_object_valid_ids = _get_filtered_feature_objects_id(gt_objects_by_fids, min_sz)
    test_feature_object_valid_ids = _get_filtered_feature_objects_id(test_objects_by_fids, min_sz)
    precision, recall = _calculate_clusterization_quality(clusters_gt_path, clusters_test_path, objects_map_by_feature_ids, gt_feature_object_valid_ids, test_feature_object_valid_ids)
    print("Clusterization:")
    print("  precision: = {:.2f}%". format(100. * precision))
    print("  recall:    = {:.2f}%". format(100. * recall))
    # F1 is undefined when precision and recall are both zero; report 0
    # instead of failing with ZeroDivisionError after the other lines printed.
    pr_sum = precision + recall
    f1_percent = 100. * 2 * precision * recall / pr_sum if pr_sum > 0 else 0.
    print("  F1 score:  = {:.2f}%".format(f1_percent))


def main():
    """Parse the command-line arguments and run the quality evaluation."""
    cli = argparse.ArgumentParser(description="Calculate quality values")

    # (flag, add_argument keyword settings), in the order shown by --help
    options = (
        ('--gt-objects', {
            'required': True,
            'help': 'Path to input json file with ground truth',
        }),
        ('--test-objects', {
            'default': None,
            'help': 'Path to input json file with test data',
        }),
        ('--iou-thresh', {
            'default': 0.5,
            'type': float,
            'help': 'Threshold value for IoU of GT and test objects (default: 0.5)',
        }),
        ('--min-bbox-size', {
            'default': 30,
            'type': int,
            'help': 'Minimal size of GT object (default: 30 px)',
        }),
        ('--gt-clusters', {
            'default': None,
            'help': 'Path to input json file with ground truth clusters',
        }),
        ('--test-clusters', {
            'default': None,
            'help': 'Path to input json file with test data clusters',
        }),
        ('--valid-classes-list', {
            'default': None,
            'help': 'Path to text file with valid traffic sign classes, one name by line',
        }),
    )
    for flag, settings in options:
        cli.add_argument(flag, **settings)

    parsed = cli.parse_args()
    calculate_quality(
        objects_gt_path=parsed.gt_objects,
        objects_test_path=parsed.test_objects,
        clusters_gt_path=parsed.gt_clusters,
        clusters_test_path=parsed.test_clusters,
        iou_thr=parsed.iou_thresh,
        min_sz=parsed.min_bbox_size,
        valid_classes_path=parsed.valid_classes_list,
    )

# Run the command-line interface only when this file is executed as a script.
if __name__ == '__main__':
    main()
