import os.path
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'
import sys
import tensorflow as tf
import numpy as np
from tensorflow.core.framework import graph_pb2
from object_detection.utils import dataset_util
from object_detection.utils import label_map_util
import progress
import cv2
import time

# Parsing spec for the evaluation TFRecords, passed to tf.parse_single_example() in main().
# main() only reads 'image/filename', 'image/encoded', 'image/width', 'image/height',
# 'image/object/bbox/{xmin,xmax,ymin,ymax}' and 'image/object/class/label';
# the remaining keys are declared here but not read by main().
feature = {
    # Scalar per-image fields (FixedLenFeature with a default for missing keys).
    'image/format' :    tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/filename':   tf.FixedLenFeature((), tf.string, default_value=''),
    'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/source_id':  tf.FixedLenFeature((), tf.string, default_value=''),
    'image/height':     tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/width':      tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/encoded':    tf.FixedLenFeature((), tf.string, default_value=''),
    # Per-object fields: one entry per GT object, parsed as sparse tensors.
    # main() multiplies the bbox values by image width/height to get pixel coordinates.
    'image/object/bbox/xmin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/xmax':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymax':   tf.VarLenFeature(tf.float32),
    'image/object/class/label': tf.VarLenFeature(tf.int64),
    'image/object/class/text':  tf.VarLenFeature(tf.string),
    'image/object/area':        tf.VarLenFeature(tf.float32),
    'image/object/is_crowd':    tf.VarLenFeature(tf.int64),
    'image/object/difficult':   tf.VarLenFeature(tf.int64),
    'image/object/group_of':    tf.VarLenFeature(tf.int64),
    'image/object/weight':      tf.VarLenFeature(tf.float32),
}

# Default for main()'s scr_thr: detections with score <= this value are dropped before comparison.
SCORE_THRESHOLD = 0.7
# Default for main()'s cmp_thr: a GT/detection pair can only be matched when
# cmp_func (IoU by default) is strictly greater than this value.
CMP_THRESHOLD   = 0.6
# If a GT object is smaller than MIN_SIZE pixels on both sides, it is not counted
# as an error when the detector fails to find it (see compare_with_classes).
# In general, objects below this size are not added to the dataset at all (they are
# dropped at the json merge stage), so this mainly matters when checking the autosave
# set. In the usual case such objects simply never make it into the test dataset.
MIN_SIZE        = 30

def build_inference_graph(image_tensor, graph_path):
    """Load a frozen inference graph and connect it to the input image.

    The serialized GraphDef stored at ``graph_path`` is imported into the
    default graph without a name prefix, with its ``image_tensor`` input
    replaced by ``image_tensor``. Each output has its leading (batch)
    dimension squeezed away and is truncated to the number of detections
    reported by the graph's ``num_detections`` output.

    Args:
        image_tensor: The input image. uint8 tensor, shape=[1, None, None, 3]
        graph_path: Path to the frozen inference graph (serialized GraphDef
            with embedded weights).
    Returns:
        detected_boxes_tensor: Detected boxes. Float tensor,
            shape=[num_detections, 4]
        detected_scores_tensor: Detected scores. Float tensor,
            shape=[num_detections]
        detected_labels_tensor: Detected labels. Int64 tensor,
            shape=[num_detections]
    """
    graph_def = graph_pb2.GraphDef()
    with open(graph_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    # Feed our decoded image in place of the graph's own 'image_tensor' placeholder.
    tf.import_graph_def(graph_def, name='', input_map={'image_tensor': image_tensor})

    g = tf.get_default_graph()

    # Scalar count of valid detections; the other outputs are padded beyond it.
    num_detections_tensor = tf.squeeze(g.get_tensor_by_name('num_detections:0'), 0)
    num_detections_tensor = tf.cast(num_detections_tensor, tf.int32)

    detected_boxes_tensor = tf.squeeze(g.get_tensor_by_name('detection_boxes:0'), 0)
    detected_boxes_tensor = detected_boxes_tensor[:num_detections_tensor]

    detected_scores_tensor = tf.squeeze(g.get_tensor_by_name('detection_scores:0'), 0)
    detected_scores_tensor = detected_scores_tensor[:num_detections_tensor]

    # The graph emits class ids as floats; cast to int64 to match the GT label dtype.
    detected_labels_tensor = tf.squeeze(g.get_tensor_by_name('detection_classes:0'), 0)
    detected_labels_tensor = tf.cast(detected_labels_tensor, tf.int64)
    detected_labels_tensor = detected_labels_tensor[:num_detections_tensor]

    return detected_boxes_tensor, detected_scores_tensor, detected_labels_tensor

def tfrecords_count(filepath):
    """Return how many records the TFRecord file at ``filepath`` contains.

    The whole file is iterated once to count them.
    """
    return sum(1 for _ in tf.python_io.tf_record_iterator(filepath))

def inters_area(bbox1, bbox2):
    """Calculate the area of intersection of two bboxes.

    Corner order inside a box does not matter: min/max are taken per axis,
    so a box given as [ymax, xmax, ymin, xmin] gives the same result.

    Args:
        bbox1: first rectangle, [ymin, xmin, ymax, xmax]
        bbox2: second rectangle, [ymin, xmin, ymax, xmax]
    Returns:
        Intersection area as a float; 0.0 when the boxes do not overlap or
        only touch along an edge.
    """
    # Overlap along x: right-most left edge .. left-most right edge.
    x_lo = max(min(bbox1[1], bbox1[3]), min(bbox2[1], bbox2[3]))
    x_hi = min(max(bbox1[1], bbox1[3]), max(bbox2[1], bbox2[3]))
    # Overlap along y: lowest bottom edge .. highest top edge.
    y_lo = max(min(bbox1[0], bbox1[2]), min(bbox2[0], bbox2[2]))
    y_hi = min(max(bbox1[0], bbox1[2]), max(bbox2[0], bbox2[2]))

    w = x_hi - x_lo
    h = y_hi - y_lo
    if w <= 0 or h <= 0:
        return 0.0
    return float(w * h)

def area(bbox):
    """Return the area of a bbox given as [ymin, xmin, ymax, xmax].

    Absolute side lengths are used, so swapped corners still yield a
    non-negative area. The result is always a float.
    """
    ymin, xmin, ymax, xmax = bbox[0], bbox[1], bbox[2], bbox[3]
    return float(abs(xmax - xmin) * abs(ymax - ymin))

def iou(bbox1, bbox2):
    """Intersection over union of two [ymin, xmin, ymax, xmax] boxes.

    Returns 0. when the union area is below 0.001 (degenerate boxes).
    """
    inter = inters_area(bbox1, bbox2)
    union = area(bbox1) + area(bbox2) - inter
    return 0. if union < 0.001 else inter / union

def iom(bbox1, bbox2):
    """Intersection over the smaller of the two box areas.

    Equals 1.0 when one box lies completely inside the other. Returns 0.
    when the smaller area is below 0.001 (degenerate box).
    """
    smaller = min(area(bbox1), area(bbox2))
    if smaller < 0.001:
        return 0.
    return inters_area(bbox1, bbox2) / smaller

def compare_with_classes(bboxes_gt, cl_gt, bboxes_test, cl_test, cmp_thr, classes_names, cmp_func = iou):
    """Build a confusion matrix between ground-truth and detected boxes.

    Every (GT, detection) pair whose similarity ``cmp_func`` is strictly
    greater than ``cmp_thr`` is a match candidate. Candidates are taken
    greedily from the highest similarity down, and each box is used at most once.

    Matrix layout, with N = len(classes_names) (labels are 1-based; label k
    maps to row/column k - 1):

            | CL1 | CL2 | .... | CLN | Neg |   <-- GT
        CL1 |     |     | .... |     |     |
        CL2 |     |     | .... |     |     |
        ... | ............................ |
        CLN |     |     | .... |     |     |
        Neg |     |     | .... |     |     |
         ^
         TEST

    - matched pair -> stat[test_label - 1, gt_label - 1] += 1
    - unmatched detection (false positive) -> column "Neg"
    - unmatched GT box (miss) -> row "Neg", unless the box is smaller than
      MIN_SIZE on both sides

    Args:
        bboxes_gt: GT boxes, rows [ymin, xmin, ymax, xmax] in pixels.
        cl_gt: GT labels, one per GT box.
        bboxes_test: detected boxes, same layout.
        cl_test: detected labels, one per detected box.
        cmp_thr: similarity threshold for a match.
        classes_names: class names; only its length is used.
        cmp_func: box similarity function, IoU by default.
    Returns:
        Integer ndarray of shape (N + 1, N + 1).
    """
    cl_cnt = len(classes_names)
    # np.int was just an alias of the builtin int and is gone in NumPy >= 1.24.
    stat = np.zeros((cl_cnt + 1, cl_cnt + 1), dtype=int)

    if len(bboxes_gt) == 0 and len(bboxes_test) == 0:
        return stat

    # Collect candidate pairs above the threshold. append() keeps this linear;
    # `list + [item]` copied the whole list on every hit.
    candidates = []
    for idx_gt, bbox_gt in enumerate(bboxes_gt):
        for idx_test, bbox_test in enumerate(bboxes_test):
            cmp_val = cmp_func(bbox_gt, bbox_test)
            if cmp_val > cmp_thr:
                candidates.append((cmp_val, idx_gt, idx_test))

    # Best candidates first (stable sort keeps the original order among ties).
    candidates.sort(key=lambda x: x[0], reverse=True)

    fnd_gt = [False] * len(bboxes_gt)
    fnd_test = [False] * len(bboxes_test)

    # Greedy one-to-one matching.
    for _, idx_gt, idx_test in candidates:
        if fnd_gt[idx_gt] or fnd_test[idx_test]:
            continue
        fnd_gt[idx_gt] = True
        fnd_test[idx_test] = True
        stat[cl_test[idx_test] - 1, cl_gt[idx_gt] - 1] += 1

    # Detections left without a GT partner are false positives.
    for idx_test, found in enumerate(fnd_test):
        if not found:
            stat[cl_test[idx_test] - 1, cl_cnt] += 1

    # GT boxes left without a detection are misses, unless they are tiny.
    for idx_gt, found in enumerate(fnd_gt):
        if found:
            continue
        bbox = bboxes_gt[idx_gt]
        w = abs(bbox[3] - bbox[1])
        h = abs(bbox[2] - bbox[0])
        if w < MIN_SIZE and h < MIN_SIZE:
            continue
        stat[cl_cnt, cl_gt[idx_gt] - 1] += 1
    return stat

def create_tf_example(feature_id, image, bboxes_gt, cl_gt, bboxes_test, cl_test):
    """Pack an image with its GT and detected boxes into a tf.train.Example.

    Args:
        feature_id: record identifier, written to 'image/filename' and
            'image/source_id'. ``bytes`` values (what ``sess.run`` returns
            for a tf.string tensor) are stored as-is; anything else is
            formatted with ``str.format`` and UTF-8 encoded.
        image: image array indexed as [height, width, channels], in the
            channel order ``cv2.imencode`` expects.
        bboxes_gt: GT boxes, array of shape [N, 4], rows
            [ymin, xmin, ymax, xmax] in pixels.
        cl_gt: GT labels, one per GT box.
        bboxes_test: detected boxes, array of shape [M, 4], same layout.
        cl_test: detected labels, one per detected box.
    Returns:
        tf.train.Example with the JPEG-encoded image plus the GT
        ('image/gt/...') and detected ('image/test/...') boxes, whose
        coordinates are divided by the image width/height.
    Raises:
        ValueError: if OpenCV fails to encode the image as JPEG.
    """
    height   = image.shape[0] # Image height
    width    = image.shape[1] # Image width

    # "{}".format(b"x") yields the text "b'x'", which used to end up in the
    # record; keep bytes identifiers untouched instead.
    if isinstance(feature_id, bytes):
        feature_id_bytes = feature_id
    else:
        feature_id_bytes = "{}".format(feature_id).encode('utf8')

    ok, encoded_image_data = cv2.imencode(".jpg", image)
    if not ok:
        raise ValueError("cv2.imencode failed for image {}".format(
            feature_id_bytes.decode('utf8', 'replace')))
    encoded_image_data = encoded_image_data.tobytes()

    image_format = 'jpeg'.encode('utf8')

    # Box columns are [ymin, xmin, ymax, xmax]; store them relative to image size.
    gt_xmins = bboxes_gt[:,1] / width
    gt_xmaxs = bboxes_gt[:,3] / width
    gt_ymins = bboxes_gt[:,0] / height
    gt_ymaxs = bboxes_gt[:,2] / height

    test_xmins = bboxes_test[:,1] / width
    test_xmaxs = bboxes_test[:,3] / width
    test_ymins = bboxes_test[:,0] / height
    test_ymaxs = bboxes_test[:,2] / height

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(feature_id_bytes),
        'image/source_id': dataset_util.bytes_feature(feature_id_bytes),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),

        'image/gt/bbox/xmin': dataset_util.float_list_feature(gt_xmins),
        'image/gt/bbox/xmax': dataset_util.float_list_feature(gt_xmaxs),
        'image/gt/bbox/ymin': dataset_util.float_list_feature(gt_ymins),
        'image/gt/bbox/ymax': dataset_util.float_list_feature(gt_ymaxs),
        'image/gt/class/label': dataset_util.int64_list_feature(cl_gt),

        'image/test/bbox/xmin': dataset_util.float_list_feature(test_xmins),
        'image/test/bbox/xmax': dataset_util.float_list_feature(test_xmaxs),
        'image/test/bbox/ymin': dataset_util.float_list_feature(test_ymins),
        'image/test/bbox/ymax': dataset_util.float_list_feature(test_ymaxs),
        'image/test/class/label': dataset_util.int64_list_feature(cl_test),
    }))
    return tf_example

def compare(bboxes_gt, bboxes_test, cmp_thr, cmp_func = iou):
    """Count class-agnostic true positives between GT and detected boxes.

    Every (GT, detection) pair whose similarity ``cmp_func`` is strictly
    greater than ``cmp_thr`` is a candidate. Candidates are consumed greedily
    from the highest similarity down, and each GT box and each detection can
    take part in at most one match.

    Args:
        bboxes_gt: GT boxes, rows [ymin, xmin, ymax, xmax].
        bboxes_test: detected boxes, same layout.
        cmp_thr: similarity threshold for a match.
        cmp_func: box similarity function, IoU by default.
    Returns:
        Number of one-to-one matches (0 if either list is empty).
    """
    if len(bboxes_gt) == 0 or len(bboxes_test) == 0:
        return 0

    candidates = []
    for idx_gt, bbox_gt in enumerate(bboxes_gt):
        for idx_test, bbox_test in enumerate(bboxes_test):
            cmp_val = cmp_func(bbox_gt, bbox_test)
            if cmp_val > cmp_thr:
                candidates.append((cmp_val, idx_gt, idx_test))

    candidates.sort(key=lambda x: x[0], reverse=True)

    fnd_gt = [False] * len(bboxes_gt)
    fnd_test = [False] * len(bboxes_test)

    tp = 0
    for _, idx_gt, idx_test in candidates:
        if fnd_gt[idx_gt] or fnd_test[idx_test]:
            continue
        # Mark both boxes as used. Without this (the previous behaviour) every
        # above-threshold pair was counted, so one GT could match many detections.
        fnd_gt[idx_gt] = True
        fnd_test[idx_test] = True
        tp += 1
    return tp

def load_label_map(label_map_path):
    """Return the class names from a label map, ordered by label id.

    Label id ``k`` is stored at index ``k - 1``, so ids are expected to be
    exactly 1..N (NOTE(review): gaps leave ``None`` entries and ids above N
    raise IndexError).
    """
    id_by_name = label_map_util.get_label_map_dict(label_map_path)
    names = [None] * len(id_by_name)
    for name, label_id in id_by_name.items():
        names[label_id - 1] = name
    return names

def main(graph_def_path, val_rec_path, label_map_path, out_stat_path, scr_thr = SCORE_THRESHOLD, cmp_thr = CMP_THRESHOLD, out_rec_path = None):
    """Evaluate a frozen detection graph on a TFRecord dataset.

    The graph at ``graph_def_path`` is run on every record of
    ``val_rec_path``. Detections with score > ``scr_thr`` are kept and
    matched against the GT boxes by ``compare_with_classes`` (IoU >
    ``cmp_thr``).

    The summed confusion matrix and overall precision/recall are appended to
    ``out_stat_path`` (prefixed by the graph path). Per-class
    precision/recall is printed to stdout.

    Args:
        graph_def_path: path to the frozen inference graph.
        val_rec_path: evaluation TFRecord file.
        label_map_path: label map defining the class names.
        out_stat_path: text file the statistics are appended to.
        scr_thr: detection score threshold.
        cmp_thr: IoU threshold for a GT/detection match.
        out_rec_path: optional TFRecord to which every image is written
            together with its GT and kept detections (see create_tf_example).
    """
    classes_names = load_label_map(label_map_path)

    # Counted up front only to size the progress bar.
    rec_cnt = tfrecords_count(val_rec_path)

    # TF1 queue-based input: a single pass (num_epochs=1) over the records,
    # one parsed example per sess.run() call.
    filename_queue = tf.train.string_input_producer([val_rec_path], num_epochs=1)

    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)

    example_features = tf.parse_single_example(example, features=feature)

    lbls_gt = tf.sparse_tensor_to_dense(example_features['image/object/class/label'])

    # GT coordinates are multiplied by the image size (pixels) and stacked as
    # [ymin, xmin, ymax, xmax] rows, the same column order used for detections below.
    img_w = tf.to_float(example_features['image/width'])
    img_h = tf.to_float(example_features['image/height'])
    x_max = tf.expand_dims(tf.sparse_tensor_to_dense(example_features['image/object/bbox/xmax'] * img_w), 1)
    x_min = tf.expand_dims(tf.sparse_tensor_to_dense(example_features['image/object/bbox/xmin'] * img_w), 1)
    y_min = tf.expand_dims(tf.sparse_tensor_to_dense(example_features['image/object/bbox/ymin'] * img_h), 1)
    y_max = tf.expand_dims(tf.sparse_tensor_to_dense(example_features['image/object/bbox/ymax'] * img_h), 1)
    boxes_gt = tf.concat([y_min, x_min, y_max, x_max], axis = 1)

    filename = example_features['image/filename']
    img_enc = example_features['image/encoded']
    # Add a batch dimension: the detector expects [1, H, W, C].
    image = tf.expand_dims(tf.image.decode_jpeg(img_enc), axis=0)

    boxes, scores, lbls = build_inference_graph(image, graph_def_path)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    cl_cnt = len(classes_names)
    # np.int / np.float were aliases of the builtins and are gone in NumPy >= 1.24.
    stat = np.zeros((cl_cnt + 1, cl_cnt + 1), dtype=int)

    out_rec_writer = None
    if out_rec_path is not None:
        out_rec_writer = tf.python_io.TFRecordWriter(out_rec_path)
    try:
        with tf.Session() as sess:
            sess.run(init_op)

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            try:
                idx = 0
                while True:
                    (boxes_, scores_, lbls_, image_, filename_,
                     boxes_gt_, lbls_gt_) = sess.run(
                        [boxes, scores, lbls, image, filename, boxes_gt, lbls_gt])

                    keep = scores_ > scr_thr
                    boxes_ = boxes_[keep]
                    lbls_ = lbls_[keep]
                    # image_ is [1, H, W, C]: shape[1] is the height, shape[2] the width.
                    boxes_[:, 0::2] = boxes_[:, 0::2] * image_.shape[1] # y_min -- y_max
                    boxes_[:, 1::2] = boxes_[:, 1::2] * image_.shape[2] # x_min -- x_max

                    progress.printProgressBar(idx, rec_cnt, prefix = "", suffix = filename_)
                    idx += 1
                    stat += compare_with_classes(boxes_gt_, lbls_gt_, boxes_, lbls_, cmp_thr, classes_names)

                    if out_rec_writer is not None:
                        # Channel axis reversed before cv2 encoding (presumably
                        # decode_jpeg's RGB -> OpenCV's BGR; confirm).
                        tf_example = create_tf_example(filename_, image_[0][:, :, ::-1], boxes_gt_, lbls_gt_, boxes_, lbls_)
                        out_rec_writer.write(tf_example.SerializeToString())
            except tf.errors.OutOfRangeError:
                # The single input epoch is exhausted: normal end of evaluation.
                progress.printProgressBar(rec_cnt, rec_cnt, prefix = "", suffix = "Eval dataset ended")
            finally:
                coord.request_stop()
            coord.join(threads)
    finally:
        # Close the output record even if evaluation fails midway.
        if out_rec_writer is not None:
            out_rec_writer.close()

    test_cnt = np.sum(stat, axis = 1)   # row sums: per detected class
    gt_cnt   = np.sum(stat, axis = 0)   # column sums: per GT class
    diag     = np.diag(stat).astype(float)
    precision = np.divide(diag * 100., test_cnt.astype(float), out=np.zeros_like(diag), where=test_cnt!=0)
    recall = np.divide(diag * 100., gt_cnt.astype(float), out=np.zeros_like(diag), where=gt_cnt!=0)

    # Overall figures: test_cnt[:-1] drops the "Neg" row (missed GT objects) and
    # gt_cnt[:-1] drops the "Neg" column (unmatched detections). An empty
    # denominator yields 0, as for the per-class values, instead of nan.
    total_tp = np.sum(diag)
    det_total = np.sum(test_cnt[:-1])
    gt_total = np.sum(gt_cnt[:-1])
    total_precision = total_tp / det_total * 100. if det_total else 0.
    total_recall = total_tp / gt_total * 100. if gt_total else 0.

    with open(out_stat_path, 'a') as of:
        of.write("{}\n".format(graph_def_path))
        for row in range(stat.shape[0]):
            name = classes_names[row] if row < len(classes_names) else 'negatives'
            line = ";".join(["{}".format(name)] + ["{}".format(v) for v in stat[row]])
            print(line)
            of.write('{}\n'.format(line))
        of.write("Precision: {:.2f}%\n".format(total_precision))
        of.write("Recall:    {:.2f}%\n".format(total_recall))
        of.write('\n')

    for idx, cl in enumerate(classes_names):
        print("{:<25}: {:.2f}% / {:.2f}%".format(cl, precision[idx], recall[idx]))
    print("Precision {:.2f}% / Recall {:.2f}%".format(total_precision, total_recall))
    print()

if __name__ == '__main__':
    # Expected arguments: GRAPH_DEF VAL_RECORD LABEL_MAP OUT_STAT [OUT_RECORD]
    if len(sys.argv) not in (5, 6):
        sys.exit("Usage: {} GRAPH_DEF VAL_RECORD LABEL_MAP OUT_STAT [OUT_RECORD]".format(sys.argv[0]))
    main(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4],
         out_rec_path = sys.argv[5] if len(sys.argv) == 6 else None)
