import tensorflow as tf
import numpy as np
from tensorflow.core.framework import graph_pb2
import cv2
import sys

# Parsing spec passed to tf.parse_single_example in main(): one entry per
# feature key of a serialized tf.train.Example in the validation TFRecord.
# Per-image scalar fields use FixedLenFeature (the default_value is used when
# the key is absent); per-object fields use VarLenFeature and are therefore
# parsed as sparse tensors (main() densifies the ones it needs).
# NOTE(review): the key layout looks like the TF Object Detection API record
# format — confirm against the script that produced the TFRecords.
feature = {
    'image/format' :    tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/filename':   tf.FixedLenFeature((), tf.string, default_value=''),
    'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/source_id':  tf.FixedLenFeature((), tf.string, default_value=''),
    'image/height':     tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/width':      tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/encoded':    tf.FixedLenFeature((), tf.string, default_value=''),
    # Box corners. main() multiplies them by image width/height, so they are
    # presumably normalized to [0, 1].
    'image/object/bbox/xmin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/xmax':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymax':   tf.VarLenFeature(tf.float32),
    # 1-based class ids (see CLASSES_NAMES below).
    'image/object/class/label': tf.VarLenFeature(tf.int64),
    'image/object/class/text':  tf.VarLenFeature(tf.string),
    'image/object/area':        tf.VarLenFeature(tf.float32),
    'image/object/is_crowd':    tf.VarLenFeature(tf.int64),
    'image/object/difficult':   tf.VarLenFeature(tf.int64),
    'image/object/group_of':    tf.VarLenFeature(tf.int64),
    'image/object/weight':      tf.VarLenFeature(tf.float32),
}

# Display names of the detector classes. Label k (1-based) corresponds to
# CLASSES_NAMES[k - 1]: compare_with_classes indexes its confusion matrix with
# `label - 1` and sizes it len(CLASSES_NAMES) + 1 (extra "Neg" row/column).
CLASSES_NAMES = [
    '1.18.1_Russian_road_marking',
    '1.18.2_Russian_road_marking',
    '1.18.3_Russian_road_marking',
    '1.18.4_Russian_road_marking',
    '1.18.5_Russian_road_marking'
]

# Default for main(): detections whose score is strictly greater than this are kept.
SCORE_THRESHOLD = 0.7
# Default for main(): a GT/detection pair is a match candidate only when its
# IoU is strictly greater than this.
IOU_THRESHOLD   = 0.6

def printProgressBar(iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '#'):
    """
    Draw (or redraw) a single-line terminal progress bar on stdout.

    Call in a loop: every call starts with a carriage return, so it overwrites
    the previous bar. A newline is written once ``iteration == total``.

    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int); a value <= 0 means
                                  "nothing to do" and renders a full bar
                                  instead of dividing by zero
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
    """
    if total > 0:
        fraction = iteration / float(total)
        # Integer floor division keeps the fill width exact (no float rounding).
        filled_length = int(length * iteration // total)
    else:
        # Empty workload (e.g. an empty TFRecord file): show it as complete.
        fraction = 1.0
        filled_length = length
    percent = "{0:.{1}f}".format(100 * fraction, decimals)
    bar = fill * filled_length + '-' * (length - filled_length)
    sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percent, '%', suffix))
    if iteration == total:
        sys.stdout.write('\n')
    sys.stdout.flush()

def build_inference_graph(image_tensor, graph_path):
    """Import a frozen detection graph and wire ``image_tensor`` into it.

    The serialized GraphDef stored at ``graph_path`` is imported into the
    current default graph with no name prefix, and its ``image_tensor`` input
    is replaced by the tensor passed in. Each output is squeezed along its
    leading axis and truncated to the graph's ``num_detections`` value.

    Args:
    image_tensor: The input image. uint8 tensor, shape=[1, None, None, 3]
    graph_path: Path to the inference graph with embedded weights
    Returns:
    detected_boxes_tensor: Detected boxes. Float tensor,
        shape=[num_detections, 4]
    detected_scores_tensor: Detected scores. Float tensor,
        shape=[num_detections]
    detected_labels_tensor: Detected labels. Int64 tensor,
        shape=[num_detections]
    """
    with open(graph_path, "rb") as graph_file:
        serialized = graph_file.read()
    graph_def = graph_pb2.GraphDef()
    graph_def.ParseFromString(serialized)

    tf.import_graph_def(graph_def, name='', input_map={'image_tensor': image_tensor})
    graph = tf.get_default_graph()

    def first_item(tensor_name):
        # Drop the leading (size-1) axis of a graph output.
        return tf.squeeze(graph.get_tensor_by_name(tensor_name), 0)

    count = tf.cast(first_item('num_detections:0'), tf.int32)
    boxes = first_item('detection_boxes:0')[:count]
    scores = first_item('detection_scores:0')[:count]
    labels = tf.cast(first_item('detection_classes:0'), tf.int64)[:count]
    return boxes, scores, labels

def tfrecords_count(filepath):
    """Return how many records the TFRecord file at ``filepath`` contains.

    The file is iterated once from start to end; record payloads are discarded.
    """
    return sum(1 for _ in tf.python_io.tf_record_iterator(filepath))

def inters_area(bbox1, bbox2):
    """Calculate area of intersection of two bbox.

    Args:
    bbox1: first rectangle, [ymin, xmin, ymax, xmax]
    bbox2: second rectangle, [ymin, xmin, ymax, xmax]

    The corner order within each axis does not matter: the smaller and larger
    coordinate of each pair are used. Returns 0.0 when the boxes do not
    overlap, including when they only touch along an edge.
    """
    y_lo = max(min(bbox1[0], bbox1[2]), min(bbox2[0], bbox2[2]))
    x_lo = max(min(bbox1[1], bbox1[3]), min(bbox2[1], bbox2[3]))
    y_hi = min(max(bbox1[0], bbox1[2]), max(bbox2[0], bbox2[2]))
    x_hi = min(max(bbox1[1], bbox1[3]), max(bbox2[1], bbox2[3]))

    width, height = x_hi - x_lo, y_hi - y_lo
    if width <= 0 or height <= 0:
        return 0.0
    return float(width * height)

def area(bbox):
    """Calculate area of bbox.

    Args:
    bbox: bbox, [ymin, xmin, ymax, xmax]

    Side lengths are absolute differences, so swapped corners still yield a
    non-negative area. The result is always a float.
    """
    ymin, xmin, ymax, xmax = bbox[0], bbox[1], bbox[2], bbox[3]
    return float(abs(xmax - xmin) * abs(ymax - ymin))

def iou(bbox1, bbox2):
    """Return the intersection-over-union of two [ymin, xmin, ymax, xmax] boxes.

    A union smaller than 0.001 (degenerate boxes) yields 0.0 rather than
    dividing by a near-zero value.
    """
    overlap = inters_area(bbox1, bbox2)
    union = area(bbox1) + area(bbox2) - overlap
    return 0. if union < 0.001 else overlap / union

def compare_with_classes(bboxes_gt, cl_gt, bboxes_test, cl_test, iou_thr):
    """Build a confusion matrix between ground-truth and detected boxes.

    Boxes are matched greedily by descending IoU: every (gt, test) pair whose
    IoU is strictly above ``iou_thr`` is a candidate, and a candidate is
    accepted only if neither of its boxes has been matched already. Each
    accepted pair increments cell [test class, gt class]. An unmatched
    detection is counted in the "Neg" column of its class row, and an
    unmatched ground-truth box is counted in the "Neg" row of its class column.

    Layout of the returned (N + 1) x (N + 1) matrix, N = len(CLASSES_NAMES):

        #     | CL1 | CL2 | .... | CLN | Neg |   <-- GT
        # CL1 |     |     | .... |     |     |
        # CL2 |     |     | .... |     |     |
        # ... | ............................ |
        # CLN |     |     | .... |     |     |
        # Neg |     |     | .... |     |     |
        #  |
        # TEST

    Args:
    bboxes_gt: ground-truth boxes, each [ymin, xmin, ymax, xmax]
    cl_gt: 1-based class label of each ground-truth box
    bboxes_test: detected boxes, each [ymin, xmin, ymax, xmax]
    cl_test: 1-based class label of each detected box (only the first
        len(bboxes_test) entries are read)
    iou_thr: IoU a pair must exceed to be considered a match
    Returns:
    int numpy array of shape (N + 1, N + 1) holding the counts above.

    NOTE(review): labels are used as `label - 1` indices, so a label of 0
    would silently land in the Neg row/column; labels are assumed to be 1..N.
    """
    cl_cnt = len(CLASSES_NAMES)
    # Builtin int rather than np.int: that alias (which meant builtin int)
    # was removed in NumPy 1.24, so the dtype is unchanged.
    stat = np.zeros((cl_cnt + 1, cl_cnt + 1), int)

    if len(bboxes_gt) == 0 and len(bboxes_test) == 0:
        return stat

    candidates = []
    for idx_gt, bbox_gt in enumerate(bboxes_gt):
        for idx_test, bbox_test in enumerate(bboxes_test):
            iou_val = iou(bbox_gt, bbox_test)
            if iou_val > iou_thr:
                candidates.append((iou_val, idx_gt, idx_test))

    # Best overlaps first, so each box is claimed by its highest-IoU free partner.
    candidates.sort(key=lambda item: item[0], reverse=True)

    fnd_gt = [False] * len(bboxes_gt)
    fnd_test = [False] * len(bboxes_test)

    for _, idx_gt, idx_test in candidates:
        if fnd_gt[idx_gt] or fnd_test[idx_test]:
            continue
        fnd_gt[idx_gt] = True
        fnd_test[idx_test] = True
        stat[cl_test[idx_test] - 1, cl_gt[idx_gt] - 1] += 1

    # Unmatched detections -> false positives (Neg column).
    for idx_test, matched in enumerate(fnd_test):
        if not matched:
            stat[cl_test[idx_test] - 1, cl_cnt] += 1
    # Unmatched ground truth -> misses (Neg row).
    for idx_gt, matched in enumerate(fnd_gt):
        if not matched:
            stat[cl_cnt, cl_gt[idx_gt] - 1] += 1
    return stat

def compare(bboxes_gt, bboxes_test, iou_thr):
    """Count true positives between ground-truth and detected boxes, ignoring class.

    Candidate pairs with IoU strictly above ``iou_thr`` are visited in
    descending IoU order. A pair counts only if neither of its boxes has been
    matched already, so every box takes part in at most one match.

    Args:
    bboxes_gt: ground-truth boxes, each [ymin, xmin, ymax, xmax]
    bboxes_test: detected boxes, each [ymin, xmin, ymax, xmax]
    iou_thr: IoU a pair must exceed to be considered a match
    Returns:
    number of matched pairs (int), at most min(len(bboxes_gt), len(bboxes_test)).
    """
    if len(bboxes_gt) == 0 or len(bboxes_test) == 0:
        return 0

    candidates = []
    for idx_gt, bbox_gt in enumerate(bboxes_gt):
        for idx_test, bbox_test in enumerate(bboxes_test):
            iou_val = iou(bbox_gt, bbox_test)
            if iou_val > iou_thr:
                candidates.append((iou_val, idx_gt, idx_test))

    candidates.sort(key=lambda item: item[0], reverse=True)

    fnd_gt = [False] * len(bboxes_gt)
    fnd_test = [False] * len(bboxes_test)

    TP = 0
    for _, idx_gt, idx_test in candidates:
        if fnd_gt[idx_gt] or fnd_test[idx_test]:
            continue
        # Claim both boxes. Without this, one box overlapping several others
        # would be counted once per overlapping pair.
        fnd_gt[idx_gt] = True
        fnd_test[idx_test] = True
        TP += 1
    return TP

def _print_metrics(stat):
    """Print the confusion matrix plus overall and per-class precision/recall.

    ``stat`` is the (N + 1) x (N + 1) matrix summed over all images in main():
    rows are detected classes, columns are ground-truth classes, and the last
    row/column is "Neg" (unmatched).
    """
    print(stat)
    test_cnt = np.sum(stat, axis = 1).astype(float)  # detections per detected class
    gt_cnt   = np.sum(stat, axis = 0).astype(float)  # GT boxes per GT class
    diag     = np.diag(stat).astype(float)           # matched with the right class

    # A class with no detections / no GT boxes gives 0/0: report nan without
    # spamming RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        precision = diag / test_cnt * 100.
        recall = diag / gt_cnt * 100.
        # [0:-1] drops the Neg entry: missed GT boxes are not detections and
        # false positives are not GT boxes.
        total_precision = np.sum(diag) / np.sum(test_cnt[0:-1]) * 100.
        total_recall = np.sum(diag) / np.sum(gt_cnt[0:-1]) * 100.

    print("Precision: {:.2f}%".format(total_precision))
    for idx, cl in enumerate(CLASSES_NAMES):
        print("  {}: {:.2f}%".format(cl, precision[idx]))
    print()
    print("Recall:    {:.2f}%".format(total_recall))
    for idx, cl in enumerate(CLASSES_NAMES):
        print("  {}: {:.2f}%".format(cl, recall[idx]))


def main(graph_def_path, val_rec_path, scr_thr = SCORE_THRESHOLD, iou_thr = IOU_THRESHOLD):
    """Evaluate a frozen detection graph on a TFRecord validation set.

    Runs the graph stored at ``graph_def_path`` on every image of
    ``val_rec_path`` (a single epoch). It keeps detections with score strictly
    above ``scr_thr`` and matches them against the ground-truth boxes with IoU
    threshold ``iou_thr`` (see compare_with_classes). It then prints the
    accumulated confusion matrix together with overall and per-class
    precision and recall.

    Args:
    graph_def_path: path to the serialized inference GraphDef
    val_rec_path: path to the validation TFRecord file
    scr_thr: minimum (exclusive) detection score to keep a detection
    iou_thr: minimum (exclusive) IoU for a GT/detection match
    """
    rec_cnt = tfrecords_count(val_rec_path)

    filename_queue = tf.train.string_input_producer([val_rec_path], num_epochs=1)

    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)

    example_features = tf.parse_single_example(example, features=feature)

    lbls_gt = tf.sparse_tensor_to_dense(example_features['image/object/class/label'])

    # Scale the stored box corners by image width/height so ground truth is in
    # the same pixel units as the detections converted in the loop below.
    width = tf.to_float(example_features['image/width'])
    height = tf.to_float(example_features['image/height'])
    x_max = tf.expand_dims(tf.sparse_tensor_to_dense(example_features['image/object/bbox/xmax'] * width), 1)
    x_min = tf.expand_dims(tf.sparse_tensor_to_dense(example_features['image/object/bbox/xmin'] * width), 1)
    y_min = tf.expand_dims(tf.sparse_tensor_to_dense(example_features['image/object/bbox/ymin'] * height), 1)
    y_max = tf.expand_dims(tf.sparse_tensor_to_dense(example_features['image/object/bbox/ymax'] * height), 1)
    boxes_gt = tf.concat([y_min, x_min, y_max, x_max], axis = 1)

    img_enc = example_features['image/encoded']
    # channels=3 forces RGB output so grayscale JPEGs still match the
    # [1, None, None, 3] input expected by build_inference_graph.
    image = tf.expand_dims(tf.image.decode_jpeg(img_enc, channels=3), axis=0)
    # Only the image size is needed on the Python side; fetch the shape
    # instead of copying the decoded pixels out of the session.
    image_shape = tf.shape(image)

    boxes, scores, lbls = build_inference_graph(image, graph_def_path)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    cl_cnt = len(CLASSES_NAMES)
    # Builtin int instead of np.int (alias removed in NumPy 1.24).
    stat = np.zeros((cl_cnt + 1, cl_cnt + 1), int)
    with tf.Session() as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            idx = 0
            while True:
                boxes_, scores_, lbls_, shape_, boxes_gt_, lbls_gt_ = sess.run(
                    [boxes, scores, lbls, image_shape, boxes_gt, lbls_gt])
                # Apply the same score mask to boxes AND labels so they stay
                # aligned even if detections are not sorted by score.
                keep = scores_ > scr_thr
                boxes_ = boxes_[keep]
                lbls_ = lbls_[keep]
                # Detections are presumably normalized [ymin, xmin, ymax, xmax];
                # convert them to pixels like the ground truth.
                boxes_[:, 0::2] *= int(shape_[1])  # y_min -- y_max
                boxes_[:, 1::2] *= int(shape_[2])  # x_min -- x_max

                printProgressBar(idx + 1, rec_cnt, prefix = "", suffix = "")
                idx += 1
                stat += compare_with_classes(boxes_gt_, lbls_gt_, boxes_, lbls_, iou_thr)
        except tf.errors.OutOfRangeError:
            # The single-epoch input queue is exhausted: every record was read.
            printProgressBar(rec_cnt, rec_cnt, prefix = "", suffix = "Epoch ended")
        finally:
            coord.request_stop()
        coord.join(threads)

    _print_metrics(stat)

if __name__ == '__main__':
    # Usage: <script> GRAPH_PB VAL_TFRECORD [SCORE_THRESHOLD [IOU_THRESHOLD]]
    # Validate argv up front instead of failing with a bare IndexError.
    if not 3 <= len(sys.argv) <= 5:
        sys.exit("usage: {} GRAPH_PB VAL_TFRECORD [SCORE_THRESHOLD [IOU_THRESHOLD]]".format(sys.argv[0]))
    main(sys.argv[1], sys.argv[2],
         float(sys.argv[3]) if len(sys.argv) > 3 else SCORE_THRESHOLD,
         float(sys.argv[4]) if len(sys.argv) > 4 else IOU_THRESHOLD)
