import sys
import tensorflow as tf
from google.protobuf.json_format import MessageToJson
import cv2

# Parsing schema handed to tf.parse_single_example for every record read below.
# Per-image fields are scalar FixedLenFeatures; the default_value is used when
# a record lacks that key. Per-object fields are VarLenFeatures, which parse to
# SparseTensors (the readers in this file densify them with
# tf.sparse_tensor_to_dense), one entry per annotated object.
# NOTE(review): key names look like the TensorFlow Object Detection API
# tf.train.Example layout — confirm against whatever wrote these records.
feature = {
    'image/format' :    tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/filename':   tf.FixedLenFeature((), tf.string, default_value=''),
    'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/source_id':  tf.FixedLenFeature((), tf.string, default_value=''),
    'image/height':     tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/width':      tf.FixedLenFeature((), tf.int64, default_value=1),
    # Raw encoded image bytes; the readers in this file decode them as JPEG.
    'image/encoded':    tf.FixedLenFeature((), tf.string, default_value=''),
    # Box edges. The readers in this file multiply them by image width/height
    # to get pixels, so they are expected to be normalized to [0, 1].
    'image/object/bbox/xmin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/xmax':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymax':   tf.VarLenFeature(tf.float32),
    'image/object/class/label': tf.VarLenFeature(tf.int64),
    'image/object/class/text':  tf.VarLenFeature(tf.string),
    'image/object/area':        tf.VarLenFeature(tf.float32),
    'image/object/is_crowd':    tf.VarLenFeature(tf.int64),
    'image/object/difficult':   tf.VarLenFeature(tf.int64),
    'image/object/group_of':    tf.VarLenFeature(tf.int64),
    'image/object/weight':      tf.VarLenFeature(tf.float32),
}

def show_data(data_path):
    """Display each image of a TFRecord file with its bounding boxes drawn.

    Reads ``data_path`` once (``num_epochs=1``) through the TF1 queue-runner
    pipeline, decodes ``image/encoded`` as a JPEG and draws every object's
    box in magenta. The boxes come from ``image/object/bbox/*`` scaled by the
    stored ``image/width``/``image/height``. Each box is also printed as
    ``(x_min, x_max, y_min, y_max)`` in pixels.

    Press any key to advance to the next image, or Esc to quit. The loop also
    ends cleanly (printing "Epoch ended") once every record has been shown.

    Args:
        data_path: path to a single TFRecord file.
    """
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)

    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)

    example_features = tf.parse_single_example(example, features=feature)

    width = tf.to_float(example_features['image/width'])
    height = tf.to_float(example_features['image/height'])
    x_min = tf.sparse_tensor_to_dense(example_features['image/object/bbox/xmin'] * width)
    x_max = tf.sparse_tensor_to_dense(example_features['image/object/bbox/xmax'] * width)
    y_min = tf.sparse_tensor_to_dense(example_features['image/object/bbox/ymin'] * height)
    y_max = tf.sparse_tensor_to_dense(example_features['image/object/bbox/ymax'] * height)
    # Force 3 channels so grayscale JPEGs can still be converted to BGR below.
    image = tf.image.decode_jpeg(example_features['image/encoded'], channels=3)

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    with tf.Session() as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            while not coord.should_stop():
                rgb_image, bbx1, bbx2, bby1, bby2 = sess.run([image, x_min, x_max, y_min, y_max])
                # TF decodes to RGB but OpenCV displays BGR; cvtColor also
                # yields a fresh array that is safe to draw on.
                show_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
                for bbox in zip(bbx1, bbx2, bby1, bby2):
                    print(bbox)
                    # OpenCV drawing functions require integer pixel coordinates.
                    left, right, top, bottom = (int(round(float(v))) for v in bbox)
                    cv2.rectangle(show_image, (left, top), (right, bottom), (255, 0, 255), 2)
                cv2.imshow("show", show_image)
                # Mask to the low byte: some platforms set extra bits in waitKey's result.
                if cv2.waitKey() & 0xFF == 27:  # Esc
                    break
        except tf.errors.OutOfRangeError:
            # Raised by the input pipeline once the single epoch is exhausted.
            print("Epoch ended")
        finally:
            coord.request_stop()
            cv2.destroyAllWindows()
        coord.join(threads)

def _check_bboxes(fn, bboxes, img_width, img_height):
    """Assert that every pixel-space box is non-empty and lies inside the image.

    Args:
        fn: the record's ``image/filename``, used only in the failure message.
        bboxes: iterable of ``(x_min, x_max, y_min, y_max)`` tuples in pixels.
        img_width: image width in pixels.
        img_height: image height in pixels.

    Raises:
        AssertionError: for the first box that is empty or leaves the image.
    """
    for bbox in bboxes:
        x_lo, x_hi, y_lo, y_hi = bbox
        # Upper bounds are inclusive: a normalized coordinate of exactly 1.0
        # scales to the full width/height and still lies on the image edge.
        assert x_lo < x_hi and 0. <= x_lo and x_hi <= img_width, "Invalid bbox: {} {}".format(fn, bbox)
        assert y_lo < y_hi and 0. <= y_lo and y_hi <= img_height, "Invalid bbox: {} {}".format(fn, bbox)


def stat_data(data_path, check_image=True):
    """Validate a TFRecord file and print per-class object counts.

    Reads ``data_path`` once (``num_epochs=1``) through the TF1 queue-runner
    pipeline. For each record it:

    - prints ``image/filename``;
    - tallies every entry of ``image/object/class/text``;
    - asserts that all bounding boxes are non-empty and inside the image.
      The boxes are scaled to pixels by the stored width/height.

    At the end it prints the number of records checked, a blank line, and one
    ``<class> <count>`` line per class.

    Args:
        data_path: path to a single TFRecord file.
        check_image: when True, also decode ``image/encoded`` as a JPEG and
            assert that its shape matches ``image/height``/``image/width``.
            Records whose JPEG fails to decode are reported and skipped
            (not counted). When False, the image bytes are not fetched.

    Raises:
        AssertionError: on an image-size mismatch or an invalid bounding box.
    """
    filename_queue = tf.train.string_input_producer([data_path], num_epochs=1)

    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)

    example_features = tf.parse_single_example(example, features=feature)

    height = tf.to_float(example_features['image/height'])
    width = tf.to_float(example_features['image/width'])

    # Everything evaluated per record; sess.run on a dict returns a dict
    # with the same keys holding the evaluated values.
    fetches = {
        'class_name': tf.sparse_tensor_to_dense(example_features['image/object/class/text'], default_value=''),
        'filename': example_features['image/filename'],
        'height': height,
        'width': width,
        'x_min': tf.sparse_tensor_to_dense(example_features['image/object/bbox/xmin'] * width),
        'x_max': tf.sparse_tensor_to_dense(example_features['image/object/bbox/xmax'] * width),
        'y_min': tf.sparse_tensor_to_dense(example_features['image/object/bbox/ymin'] * height),
        'y_max': tf.sparse_tensor_to_dense(example_features['image/object/bbox/ymax'] * height),
    }
    if check_image:
        fetches['image'] = tf.image.decode_jpeg(example_features['image/encoded'])

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    img_cnt = 0
    data_stat = {}
    with tf.Session() as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        try:
            while True:
                if check_image:
                    try:
                        values = sess.run(fetches)
                    except tf.errors.InvalidArgumentError:
                        # The reader has presumably already advanced past the
                        # bad record, so the next run moves on to the following one.
                        print("Unable decode jpeg???")
                        continue
                else:
                    values = sess.run(fetches)

                fn = values['filename']
                img_height = values['height']
                img_width = values['width']
                print(fn)

                for name in values['class_name']:
                    data_stat[name] = data_stat.get(name, 0) + 1

                if check_image:
                    decoded_image = values['image']
                    assert img_height == decoded_image.shape[0] and img_width == decoded_image.shape[1], \
                        "Image sizes don't equal size in tfrecord: {}".format(fn)

                _check_bboxes(fn,
                              zip(values['x_min'], values['x_max'], values['y_min'], values['y_max']),
                              img_width, img_height)
                img_cnt += 1
        except tf.errors.OutOfRangeError:
            # Raised by the input pipeline once the single epoch is exhausted.
            print("Epoch ended")
        finally:
            coord.request_stop()
        coord.join(threads)

    print("Images count: {}".format(img_cnt))
    print()
    for name, count in data_stat.items():
        print(name, count)

if __name__ == '__main__':
    # Usage: <script> <path/to/file.tfrecord>
    if len(sys.argv) < 2:
        sys.exit("Usage: {} <path/to/file.tfrecord>".format(sys.argv[0]))
    # Collect statistics without decoding images. Call show_data(sys.argv[1])
    # instead to visually inspect the boxes.
    stat_data(sys.argv[1], False)
