import os.path
os.environ['TF_CPP_MIN_LOG_LEVEL']='3'

import sys
import json
import numpy as np
import tensorflow as tf
from object_detection.utils import dataset_util
import cv2
import random

# Default share of examples routed to the training set. make_ds() uses it as
# the per-example probability of writing to train.tfrecord; the remaining
# examples go to test.tfrecord.
TRAIN_PART = 0.9

# Parsing spec handed to tf.parse_single_example() in ds_stat().
# Scalar per-image fields are FixedLenFeature entries with a default value;
# per-object fields (one entry per bounding box) are VarLenFeature entries and
# come back as sparse tensors.
# NOTE: create_tf_example() in this file only writes height, width, filename,
# source_id, encoded, format, bbox/{xmin,xmax,ymin,ymax}, class/text and
# class/label; the other keys are not produced by this script.
features = {
    # --- per-image scalar fields ---
    'image/format' :    tf.FixedLenFeature((), tf.string, default_value='jpeg'),
    'image/filename':   tf.FixedLenFeature((), tf.string, default_value=''),
    'image/key/sha256': tf.FixedLenFeature((), tf.string, default_value=''),
    'image/source_id':  tf.FixedLenFeature((), tf.string, default_value=''),
    'image/height':     tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/width':      tf.FixedLenFeature((), tf.int64, default_value=1),
    'image/encoded':    tf.FixedLenFeature((), tf.string, default_value=''),
    # --- per-object variable-length fields ---
    'image/object/bbox/xmin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/xmax':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymin':   tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymax':   tf.VarLenFeature(tf.float32),
    'image/object/class/label': tf.VarLenFeature(tf.int64),
    'image/object/class/text':  tf.VarLenFeature(tf.string),
    'image/object/area':        tf.VarLenFeature(tf.float32),
    'image/object/is_crowd':    tf.VarLenFeature(tf.int64),
    'image/object/difficult':   tf.VarLenFeature(tf.int64),
    'image/object/group_of':    tf.VarLenFeature(tf.int64),
    'image/object/weight':      tf.VarLenFeature(tf.float32),
    'image/object/mask': tf.VarLenFeature(tf.string),
    'image/labels':    tf.FixedLenFeature((), tf.string, default_value=''),
}

# Object classes exported to the dataset. The ORDER matters:
# create_tf_example() stores (list index + 1) as the integer class label, and
# silently skips any annotated object whose 'sign_type' is not listed here.
CLASSES_NAMES = [
    '1.18.1_Russian_road_marking',
    '1.18.2_Russian_road_marking',
    '1.18.3_Russian_road_marking',
    '1.18.4_Russian_road_marking',
    '1.18.5_Russian_road_marking'
]

def printProgressBar(iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '#'):
    """
    Call in a loop to create terminal progress bar.

    Each call writes a carriage return followed by the bar, so successive
    calls overwrite the same terminal line. A newline is emitted once
    ``iteration == total``. A ``total`` of 0 is drawn as a complete bar
    instead of raising ZeroDivisionError.
    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int)
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
    """
    if total:
        fraction = iteration / float(total)
        filledLength = int(length * iteration // total)
    else:
        # Nothing to process: show the bar as finished rather than dividing by zero.
        fraction = 1.0
        filledLength = length

    percent = "{0:.{1}f}".format(100 * fraction, decimals)
    bar = fill * filledLength + '-' * (length - filledLength)
    sys.stdout.write('\r%s |%s| %s%s %s' % (prefix, bar, percent, '%', suffix))
    if iteration == total:
        # Finish the line so later output does not overwrite the completed bar.
        sys.stdout.write('\n')
    sys.stdout.flush()

def show(json_path, image_folder):
    """
    Display every annotated image, warped and with its boxes drawn on top.

    For each entry of ``classification_task_results`` in the JSON file the
    image ``<image_folder>/<feature_id>.jpg`` is loaded, warped with the
    entry's 3x3 homography into an ``out_size`` x ``out_size`` image, and the
    bounding boxes of its 'solutions' are drawn on the warped image together
    with their 'sign_type' text. The function blocks on a key press after
    each image.
    @params:
        json_path    - Required  : path to the annotation JSON file (Str)
        image_folder - Required  : folder holding the <feature_id>.jpg images (Str)
    @raises:
        IOError if an image cannot be read by OpenCV.
    """
    with open(json_path) as f:
        data = json.load(f)

    for result in data['classification_task_results']:
        feature_id = result['feature_id']
        out_size   = result['var_out_size']

        image_path = "{}/{}.jpg".format(image_folder, feature_id)
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError("Unable to read image file: {}".format(image_path))

        H = np.array(result['homography']).reshape((3, 3))
        img_bv = cv2.warpPerspective(image, H, (out_size, out_size))

        for solution in result['solutions']:
            bbox = solution['bbox']
            cl_name = solution['sign_type']
            first_corner  = (bbox[0][0], bbox[0][1])
            second_corner = (bbox[1][0], bbox[1][1])
            cv2.rectangle(img_bv, first_corner, second_corner, (0, 0, 255), 1)
            # Label is anchored one pixel above the first corner of the box.
            cv2.putText(img_bv, cl_name, (bbox[0][0], bbox[0][1] - 1), 0, 0.5, (0, 0, 255), 2)
        cv2.imshow("img_bv", img_bv)
        cv2.waitKey()

def print_classes(json_path):
    """
    Count the annotated objects per 'sign_type' and print the result.

    Walks every entry of ``classification_task_results`` in the JSON file,
    tallies the 'sign_type' of each object in its 'solutions' list, shows a
    progress bar while doing so, and finally prints the
    ``{sign_type: count}`` dict.
    @params:
        json_path - Required  : path to the annotation JSON file (Str)
    """
    with open(json_path) as f:
        data = json.load(f)

    results = data['classification_task_results']

    cnt = len(results)
    cl_names = {}
    for idx, result in enumerate(results):
        feature_id = result['feature_id']
        printProgressBar(idx+1, cnt, prefix = "", suffix = feature_id)
        for solution in result['solutions']:
            cl_name = solution['sign_type']
            cl_names[cl_name] = cl_names.get(cl_name, 0) + 1
    print(cl_names)

def create_tf_example(feature_id, image, bboxes, cl_names):
    """
    Build a tf.train.Example for one image and its annotated boxes.

    Boxes are given as ``[[x0, y0], [x1, y1]]`` in pixel coordinates of
    ``image`` (corners in any order). Coordinates are normalized by the image
    width/height and clamped to [0.001, 0.999]. A box is dropped when its
    class is not in CLASSES_NAMES or when its clamped width or height is
    below 5% of the image.
    @params:
        feature_id - Required  : identifier stored as filename and source_id
        image      - Required  : image array; shape[0] is height, shape[1] is width
        bboxes     - Required  : list of boxes, parallel to cl_names
        cl_names   - Required  : list of class name strings, parallel to bboxes
    @returns:
        tf.train.Example, or None when no box survives the filtering.
    """
    height = image.shape[0]
    width = image.shape[1]

    # JPEG-encode the pixels in memory; the success flag is not inspected.
    _, jpeg_buffer = cv2.imencode(".jpg", image)
    encoded_image_data = jpeg_buffer.tobytes()

    image_format = 'jpeg'.encode('utf8')

    img_w = float(width)
    img_h = float(height)

    # One entry per kept box in each of the parallel lists below.
    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    classes_text = []
    classes = []

    for cl, bbox in zip(cl_names, bboxes):
        # Objects of classes we do not export are skipped.
        if cl not in CLASSES_NAMES:
            continue
        cl_idx = CLASSES_NAMES.index(cl)

        # Horizontal extent: order the corners, normalize, clamp.
        x_a = float(bbox[0][0])
        x_b = float(bbox[1][0])
        xmin = max(min(x_a, x_b) / img_w, 0.001)
        xmax = min(max(x_a, x_b) / img_w, 0.999)
        if xmax - xmin < 0.05:
            continue

        # Vertical extent, handled the same way.
        y_a = float(bbox[0][1])
        y_b = float(bbox[1][1])
        ymin = max(min(y_a, y_b) / img_h, 0.001)
        ymax = min(max(y_a, y_b) / img_h, 0.999)
        if ymax - ymin < 0.05:
            continue

        xmins.append(xmin)
        xmaxs.append(xmax)
        ymins.append(ymin)
        ymaxs.append(ymax)
        classes_text.append(cl.encode('utf8'))
        # Stored label ids are 1-based.
        classes.append(cl_idx + 1)

    if not classes:
        return None

    name_bytes = "{}".format(feature_id).encode('utf8')
    feature_map = {
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(name_bytes),
        'image/source_id': dataset_util.bytes_feature(name_bytes),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))

def make_ds(json_path, image_folder, tfrecord_folder, train_part = TRAIN_PART):
    """
    Convert the annotated images into train/test TFRecord files.

    For each entry of ``classification_task_results`` in the JSON file the
    image ``<image_folder>/<feature_id>.jpg`` is loaded, warped with the
    entry's 3x3 homography into an ``out_size`` x ``out_size`` image and
    turned into a tf.train.Example via create_tf_example(). Entries without
    any usable box are skipped. Every remaining example is written to
    ``<tfrecord_folder>/train.tfrecord`` with probability ``train_part`` and
    to ``<tfrecord_folder>/test.tfrecord`` otherwise, so the split is random
    and differs between runs.
    @params:
        json_path       - Required  : path to the annotation JSON file (Str)
        image_folder    - Required  : folder holding the <feature_id>.jpg images (Str)
        tfrecord_folder - Required  : existing folder receiving the .tfrecord files (Str)
        train_part      - Optional  : probability of an example going to the train set (Float)
    @raises:
        AssertionError if an image file does not exist.
        IOError if an existing image file cannot be decoded by OpenCV.
    """
    with open(json_path) as f:
        data = json.load(f)

    results = data['classification_task_results']
    cnt = len(results)

    # The writers are context managers: both files get closed even if an
    # image is missing or an annotation is malformed part-way through.
    with tf.python_io.TFRecordWriter(tfrecord_folder + "/train.tfrecord") as train_writer, \
         tf.python_io.TFRecordWriter(tfrecord_folder + "/test.tfrecord") as test_writer:
        for idx, result in enumerate(results):
            feature_id = result['feature_id']
            objects    = result['solutions']
            out_size   = result['var_out_size']

            printProgressBar(idx+1, cnt, prefix = "", suffix = feature_id)

            image_path = "{}/{}.jpg".format(image_folder, feature_id)
            # Kept as an assert so the failure type stays AssertionError.
            assert os.path.isfile(image_path), "Unable to find image file: {}".format(image_path)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise IOError("Unable to read image file: {}".format(image_path))

            H = np.array(result['homography']).reshape((3, 3))
            img_bv = cv2.warpPerspective(image, H, (out_size, out_size))

            bboxes = [obj['bbox'] for obj in objects]
            cl_names = [obj['sign_type'] for obj in objects]
            tf_example = create_tf_example(feature_id, img_bv, bboxes, cl_names)
            if tf_example is None:
                continue

            writer = train_writer if random.random() < train_part else test_writer
            writer.write(tf_example.SerializeToString())

def ds_stat(tfrecord_path):
    """
    Print per-class object counts for one TFRecord file.

    Reads the file once through a queue-based input pipeline
    (``num_epochs=1``), parses each record with the module-level ``features``
    spec and tallies the values of 'image/object/class/text' and
    'image/object/class/label'. Output is one "<name>: <count> (<share>%)"
    line per class name, a blank line, then one "<label>: <count>" line per
    integer label.
    @params:
        tfrecord_path - Required  : path to the .tfrecord file to summarize (Str)
    """
    filename_queue = tf.train.string_input_producer([tfrecord_path], num_epochs=1)

    reader = tf.TFRecordReader()
    _, example = reader.read(filename_queue)

    example_features = tf.parse_single_example(example, features=features)
    class_names = tf.sparse_tensor_to_dense(example_features['image/object/class/text'], default_value='')
    class_lbls = tf.sparse_tensor_to_dense(example_features['image/object/class/label'])

    # NOTE(review): the local-variables initializer is presumably needed for
    # the epoch counter created by num_epochs=1 - confirm against the TF docs.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    cl_names = {}
    cl_lbls = {}
    with tf.Session()  as sess:
        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        try:
            # Runs until the single epoch is exhausted and the read raises.
            while True:
                class_names_, class_lbls_ = sess.run([class_names, class_lbls])
                for cl_name in class_names_:
                    cl_names[cl_name] = cl_names.get(cl_name, 0) + 1
                for cl_lbl in class_lbls_:
                    cl_lbls[cl_lbl] = cl_lbls.get(cl_lbl, 0) + 1
        except tf.errors.OutOfRangeError:
            print("Epoch ended")
        finally:
            coord.request_stop()
        coord.join(threads)

    cl_all = sum(cl_names.values())
    for name, count in cl_names.items():
        print("{}: {} ({:.2f}%)".format(name, count, float(count * 100)/float(cl_all)))
    print()
    for lbl, count in cl_lbls.items():
        print("{}: {}".format(lbl, count))

if __name__ == '__main__':
    # Command line: <json_path> <image_folder> <tfrecord_folder>
    # Builds the train/test TFRecord files, then prints class statistics for each.
    if len(sys.argv) < 4:
        # Exit with a usage hint instead of a bare IndexError.
        sys.exit("Usage: {} <json_path> <image_folder> <tfrecord_folder>".format(sys.argv[0]))

    json_path, image_folder, tfrecord_folder = sys.argv[1:4]
    make_ds(json_path, image_folder, tfrecord_folder)

    print("Train part: ")
    ds_stat(tfrecord_folder + "/train.tfrecord")

    print("Test part: ")
    ds_stat(tfrecord_folder + "/test.tfrecord")
