# -*- coding: utf-8 -*-

import json
import argparse
from collections import namedtuple


class Rect:
    """Axis-aligned rectangle given by two opposite corners.

    The corners may be passed in any order; the constructor stores the
    normalised extent so that ``min_x <= max_x`` and ``min_y <= max_y``.
    """

    def __init__(self, x1, y1, x2, y2):
        # Normalise each axis independently so callers need not order corners.
        self.min_x, self.max_x = min(x1, x2), max(x1, x2)
        self.min_y, self.max_y = min(y1, y2), max(y1, y2)

    def width(self):
        """Extent along the x axis."""
        return self.max_x - self.min_x

    def height(self):
        """Extent along the y axis."""
        return self.max_y - self.min_y

    def area(self):
        """Width times height."""
        return self.width() * self.height()

    def intersects(self, rect):
        """Return the overlap of this rectangle with ``rect`` as a new Rect.

        Despite the name, this returns a rectangle rather than a boolean.
        When the two rectangles are separated along either axis, the result is
        the degenerate ``Rect(0, 0, 0, 0)`` (zero area). Rectangles that only
        touch along an edge yield a zero-area overlap on that edge.
        """
        apart_on_x = self.min_x > rect.max_x or self.max_x < rect.min_x
        apart_on_y = self.min_y > rect.max_y or self.max_y < rect.min_y
        if apart_on_x or apart_on_y:
            return Rect(0, 0, 0, 0)

        return Rect(
            max(self.min_x, rect.min_x),
            max(self.min_y, rect.min_y),
            min(self.max_x, rect.max_x),
            min(self.max_y, rect.max_y),
        )


def iou(lhs_rect, rhs_rect):
    """Return the intersection-over-union (IoU) of two rectangles.

    Args:
        lhs_rect, rhs_rect: objects exposing ``area()`` and
            ``intersects(other)``, where ``intersects`` returns a rectangle
            (with ``area()``) covering the overlap.

    Returns:
        A float in [0, 1]. Returns 0.0 when the union area is below 0.001
        (e.g. two degenerate rectangles), to avoid dividing by ~zero.
    """
    intersection_area = lhs_rect.intersects(rhs_rect).area()
    union_area = lhs_rect.area() + rhs_rect.area() - intersection_area
    if union_area < 0.001:
        return 0.
    # float() forces true division even under Python 2, where integer pixel
    # coordinates would otherwise floor every partial overlap to 0.
    return float(intersection_area) / union_area


# A detected object: its type (class label, as read from the "type" JSON
# field) and its bounding box as a Rect.
Object = namedtuple('Object', 'type, bbox')


def load_objects(path):
    """Load detected objects from a JSON file.

    The file is read with the following layout (only these keys are used)::

        {"features_objects": [
            {"feature_id": ...,
             "objects": [
                 {"object_id": ..., "type": ...,
                  "bbox": [[x1, y1], [x2, y2]]},
                 ...]},
            ...]}

    Args:
        path: path to the JSON file.

    Returns:
        dict mapping feature_id -> {object_id: Object(type, Rect)}. Every
        feature listed in the file gets an entry, even one with no objects.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(path) as json_file:
        data = json.load(json_file)

    objects = {}
    for feature_objects in data['features_objects']:
        object_by_id = objects.setdefault(feature_objects['feature_id'], {})
        for item in feature_objects['objects']:
            object_type = item['type']
            corner_1, corner_2 = item['bbox'][0], item['bbox'][1]
            bbox = Rect(corner_1[0], corner_1[1], corner_2[0], corner_2[1])
            object_by_id[item['object_id']] = Object(object_type, bbox)
    return objects


def load_valid_classes(path):
    """Read a text file with one class name per line and return them as a set.

    Lines are taken verbatim (no whitespace stripping), as produced by
    ``str.splitlines``.
    """
    # Close the file deterministically instead of leaking the handle.
    with open(path) as classes_file:
        return set(classes_file.read().splitlines())


def filter_valid_classes(objects, valid_classes):
    """Keep only objects whose ``type`` is in ``valid_classes``.

    Args:
        objects: dict feature_id -> {object_id: object with a ``type`` attr}.
        valid_classes: container of accepted type values.

    Returns:
        A new dict of the same shape. Features that end up with no objects are
        omitted entirely.
    """
    kept_by_feature = {}
    for feature_id, object_by_id in objects.items():
        kept = {
            object_id: obj
            for object_id, obj in object_by_id.items()
            if obj.type in valid_classes
        }
        if kept:
            kept_by_feature[feature_id] = kept
    return kept_by_feature


def filter_small_feature_object_ids(objects, min_bbox_size):
    """Collect ids of objects that are large enough along at least one axis.

    Args:
        objects: dict feature_id -> {object_id: object with a ``bbox``
            exposing ``width()`` and ``height()``}.
        min_bbox_size: minimal extent required along x OR y.

    Returns:
        set of ``(feature_id, object_id)`` tuples. Despite the name, these are
        the objects that are kept, not the small ones.
    """
    return {
        (feature_id, object_id)
        for feature_id, object_by_id in objects.items()
        for object_id, obj in object_by_id.items()
        if obj.bbox.width() >= min_bbox_size
        or obj.bbox.height() >= min_bbox_size
    }


def load_matches(path, conf_thr=0.):
    """Load object matches between pairs of features from a JSON file.

    The file is read with the layout::

        {"features_pairs": [
            {"feature_id_1": ..., "feature_id_2": ...,
             "matches": [
                 {"object_id_1": ..., "object_id_2": ...,
                  "confidence": ...},   # "confidence" is optional
                 ...]},
            ...]}

    Args:
        path: path to the JSON file.
        conf_thr: matches whose ``confidence`` is below this value are
            dropped. A match without ``confidence`` is treated as +inf and
            therefore always kept.

    Returns:
        set of ``((feature_id, object_id), (feature_id, object_id))`` pairs.
        The endpoint with the smaller feature_id comes first, so a pair listed
        in either direction yields the same tuple. When both feature ids are
        equal, the second endpoint is put first.
    """
    # Close the file deterministically instead of leaking the handle.
    with open(path) as json_file:
        data = json.load(json_file)

    matches = set()
    for features_pair in data['features_pairs']:
        feature_id_1 = features_pair['feature_id_1']
        feature_id_2 = features_pair['feature_id_2']
        for match in features_pair['matches']:
            confidence = match.get('confidence', float('inf'))
            if confidence < conf_thr:
                continue

            feature_object_1 = (feature_id_1, match['object_id_1'])
            feature_object_2 = (feature_id_2, match['object_id_2'])

            # Canonical orientation: lower feature_id first.
            if feature_id_1 < feature_id_2:
                matches.add((feature_object_1, feature_object_2))
            else:
                matches.add((feature_object_2, feature_object_1))

    return matches


class ObjectIdTransformer:
    """Greedy one-to-one mapping between GT objects and test objects.

    Objects are paired only within the same feature, only when their
    ``type`` values are equal, and only when their IoU is not below
    ``iou_thresh``. Within a feature, candidate pairs are visited from highest
    to lowest IoU. A pair is accepted only if neither of its objects has
    already been paired.

    Attributes:
        test_to_gt: dict (feature_id, test_object_id) -> (feature_id, gt_object_id).
        gt_to_test: the inverse of ``test_to_gt``.
    """

    def __init__(self, gt_objects, test_objects, iou_thresh):
        self.test_to_gt = {}
        self.gt_to_test = {}

        all_feature_ids = set(gt_objects.keys()).union(test_objects.keys())
        for feature_id in all_feature_ids:
            candidates = self._candidate_pairs(
                gt_objects.get(feature_id, {}),
                test_objects.get(feature_id, {}),
                iou_thresh,
            )
            self._assign_greedily(feature_id, candidates)

    @staticmethod
    def _candidate_pairs(gt_object_by_id, test_object_by_id, iou_thresh):
        """List (iou, gt_object_id, test_object_id) for every admissible pair."""
        candidates = []
        for gt_id, gt_obj in gt_object_by_id.items():
            for test_id, test_obj in test_object_by_id.items():
                if gt_obj.type != test_obj.type:
                    continue
                overlap = iou(gt_obj.bbox, test_obj.bbox)
                if overlap < iou_thresh:
                    continue
                candidates.append((overlap, gt_id, test_id))
        return candidates

    def _assign_greedily(self, feature_id, candidates):
        """Accept candidates best-IoU-first, skipping already-paired objects."""
        # Python's sort is stable (also with reverse=True), so equal-IoU
        # candidates keep their discovery order; that decides ties.
        ranked = sorted(candidates, key=lambda candidate: candidate[0], reverse=True)

        paired_gt = set()
        paired_test = set()
        for _, gt_id, test_id in ranked:
            if gt_id in paired_gt or test_id in paired_test:
                continue

            gt_key = (feature_id, gt_id)
            test_key = (feature_id, test_id)
            self.test_to_gt[test_key] = gt_key
            self.gt_to_test[gt_key] = test_key

            paired_gt.add(gt_id)
            paired_test.add(test_id)


def filter_gt_matches(gt_matches, transformer):
    """Keep only GT matches whose both endpoints map to some test object.

    Args:
        gt_matches: iterable of ``(gt_feature_object_1, gt_feature_object_2)``.
        transformer: object with a ``gt_to_test`` mapping (e.g. an
            ObjectIdTransformer).

    Returns:
        set of the surviving GT match pairs.
    """
    return {
        (gt_end_1, gt_end_2)
        for gt_end_1, gt_end_2 in gt_matches
        if gt_end_1 in transformer.gt_to_test
        and gt_end_2 in transformer.gt_to_test
    }


def transform_test_matches(test_matches, transformer):
    """Rewrite test matches in terms of the GT objects they were mapped to.

    A test match is dropped when either endpoint has no GT counterpart.

    Args:
        test_matches: iterable of ``(test_feature_object_1, test_feature_object_2)``.
        transformer: object with a ``test_to_gt`` mapping (e.g. an
            ObjectIdTransformer).

    Returns:
        set of ``(gt_feature_object_1, gt_feature_object_2)`` pairs.
    """
    translated = set()
    for test_end_1, test_end_2 in test_matches:
        try:
            gt_end_1 = transformer.test_to_gt[test_end_1]
            gt_end_2 = transformer.test_to_gt[test_end_2]
        except KeyError:
            # At least one endpoint is unmapped: the match cannot be compared.
            continue
        translated.add((gt_end_1, gt_end_2))
    return translated


def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or 0.0 if the denominator is 0."""
    if not denominator:
        return 0.0
    # float() forces true division even under Python 2.
    return float(numerator) / denominator


def calculate_quality(gt_matches, test_matches, feature_object_ids):
    """Compute precision, recall and F1 score of test matches against GT.

    Args:
        gt_matches: collection of GT match pairs.
        test_matches: collection of test match pairs, expressed in GT ids.
        feature_object_ids: ids of GT objects that may contribute false
            negatives. A missed GT match counts as a false negative only when
            both of its endpoints are in this set.

    Returns:
        ``(precision, recall, f1_score)`` as floats, where
        precision = TP / len(test_matches) and recall = TP / (TP + FN).
        Any ratio with a zero denominator (no test matches, no countable GT
        matches, or precision + recall == 0) is reported as 0.0 instead of
        raising ZeroDivisionError.
    """
    true_positive = 0
    false_negative = 0

    for gt_match in gt_matches:
        if gt_match in test_matches:
            # Found GT matches count regardless of object size.
            true_positive += 1
        elif gt_match[0] in feature_object_ids and gt_match[1] in feature_object_ids:
            false_negative += 1

    precision = _safe_ratio(true_positive, len(test_matches))
    recall = _safe_ratio(true_positive, true_positive + false_negative)
    f1_score = _safe_ratio(2 * precision * recall, precision + recall)

    return precision, recall, f1_score


def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(description="Calculate quality values")

    parser.add_argument('--gt-objects', required=True,
                        help='Path to input json file with ground truth objects')
    parser.add_argument('--test-objects', required=True,
                        help='Path to input json file with test data objects')
    parser.add_argument('--iou-thresh', default=0.5, type=float,
                        help='Threshold value for IoU of GT and test objects (default: 0.5)')
    parser.add_argument('--min-bbox-size', default=0, type=int,
                        help='Minimal size of GT object (default: 0 px)')
    parser.add_argument('--gt-matches', required=True,
                        help='Path to input json file with ground truth')
    parser.add_argument('--test-matches', required=True,
                        help='Path to input json file with test data')
    parser.add_argument('--valid-classes-list', default=None,
                        help='Path to text file with valid traffic sign classes, one name by line')
    parser.add_argument('--confidence-threshold', default=0.0, type=float,
                        help='Confidence threshold for tested matches')
    return parser


def main():
    """Evaluate test matches against GT matches and print the metrics.

    Steps:
      1. Load GT and test objects; if a class list is given, keep only
         objects of those classes on both sides.
      2. Load GT matches and test matches (test matches are filtered by the
         confidence threshold).
      3. Map test objects onto GT objects by type and IoU.
      4. Drop GT matches that have an endpoint without a test counterpart,
         and rewrite test matches in GT ids (dropping those with an
         endpoint without a GT counterpart).
      5. Collect GT objects whose bbox is large enough along at least one
         axis; these are passed to calculate_quality, which counts a missed
         GT match as a false negative only if both endpoints are among them.
      6. Print precision, recall and F1 as percentages.
    """
    args = _build_arg_parser().parse_args()

    gt_objects = load_objects(args.gt_objects)
    test_objects = load_objects(args.test_objects)

    if args.valid_classes_list is not None:
        valid_classes = load_valid_classes(args.valid_classes_list)
        gt_objects = filter_valid_classes(gt_objects, valid_classes)
        test_objects = filter_valid_classes(test_objects, valid_classes)

    gt_matches = load_matches(args.gt_matches)
    test_matches = load_matches(args.test_matches, args.confidence_threshold)

    transformer = ObjectIdTransformer(gt_objects, test_objects, args.iou_thresh)
    comparable_gt_matches = filter_gt_matches(gt_matches, transformer)
    comparable_test_matches = transform_test_matches(test_matches, transformer)
    large_enough_ids = filter_small_feature_object_ids(gt_objects, args.min_bbox_size)

    precision, recall, f1_score = calculate_quality(
        comparable_gt_matches, comparable_test_matches, large_enough_ids)

    print("Matching:")
    report_lines = (
        ("  precision: = {:.2f}%", precision),
        ("  recall:    = {:.2f}%", recall),
        ("  F1 score:  = {:.2f}%", f1_score),
    )
    for template, value in report_lines:
        print(template.format(100. * value))


if __name__ == '__main__':
    main()
