import os
import argparse
import pandas as pd
import tensorflow as tf


def int64_feature(value):
    """Wrap a single integer ``value`` in a ``tf.train.Feature`` (Int64List)."""
    int_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int_list)


def float_feature(value):
    """Wrap a single float ``value`` in a ``tf.train.Feature`` (FloatList)."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def bytes_feature(value):
    """Wrap a single bytes ``value`` in a ``tf.train.Feature`` (BytesList)."""
    byte_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=byte_list)


def create_tf_example(item, dataset_path, label_map):
    """Build a ``tf.train.Example`` for one dataset row.

    Args:
        item: A row as produced by ``DataFrame.itertuples()``; it must expose
            ``id``, ``height``, ``elev``, ``azim``, ``pixel_width`` and
            ``pixel_height`` attributes.
        dataset_path: Root folder containing ``image/<id>.jpg`` and
            ``mask/<id>.png``.
        label_map: Mapping from ``item.height`` to an integer class label.

    Returns:
        The populated ``tf.train.Example``.

    Raises:
        OSError: If the image or mask file cannot be read.
        KeyError: If ``item.height`` is not present in ``label_map``.
    """
    image_path = os.path.join(dataset_path, "image/{}.jpg".format(item.id))
    mask_path = os.path.join(dataset_path, "mask/{}.png".format(item.id))
    # Read through context managers so the handles are closed immediately
    # instead of lingering until garbage collection (matters when
    # converting many thousands of rows).
    with open(image_path, "rb") as image_file:
        image = image_file.read()
    with open(mask_path, "rb") as mask_file:
        mask = mask_file.read()

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        "image/filename": bytes_feature(str(item.id).encode('utf8')),
        "image/encoded":  bytes_feature(image),
        "image/label":    int64_feature(label_map[item.height]),
        "mask/encoded":   bytes_feature(mask),
        "elev":           float_feature(item.elev),
        "azim":           float_feature(item.azim),
        "pixel_width":    float_feature(item.pixel_width),
        "pixel_height":   float_feature(item.pixel_height)
    }))
    return tf_example


def load_label_map(path):
    """Parse a label-map text file into a ``{name: label}`` dict.

    Each non-blank line must hold two integers separated by whitespace:
    the class name (matched against ``item.height`` elsewhere) and its label.
    Blank lines (e.g. a trailing newline at end of file) are ignored.

    Args:
        path: Path to the label-map file.

    Returns:
        Dict mapping the integer name to the integer label. A later line
        with the same name overrides an earlier one.

    Raises:
        ValueError: If a non-blank line does not contain exactly two integers.
    """
    label_map = {}
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # split() with no argument tolerates runs of spaces and tabs.
            name, label = map(int, line.split())
            label_map[name] = label
    return label_map


def main():
    """Command-line entry point: convert a dataset folder into a TFRecord.

    Reads ``<dataset>/data.csv``, builds one ``tf.train.Example`` per row via
    ``create_tf_example`` and writes the serialized examples to the path
    given by ``--tfrecord``.
    """
    parser = argparse.ArgumentParser(
        # The first positional parameter of ArgumentParser is ``prog`` (the
        # program name shown in the usage line), so the description must be
        # passed by keyword.
        description="Tool for generating a tfrecord from a dataset")
    parser.add_argument("--dataset", required=True, help="Path to folder with dataset")
    parser.add_argument("--tfrecord", required=True, help="Path to output tfrecord")
    parser.add_argument("--label_map", required=True, help="Map from class name to label")
    args = parser.parse_args()

    print("Loading label map: {}".format(args.label_map))
    label_map = load_label_map(args.label_map)

    print("Creating tfrecord: {}".format(args.tfrecord))
    data = pd.read_csv(os.path.join(args.dataset, "data.csv"))
    # The context manager closes the writer even if a row raises (e.g. a
    # missing image file), so the output file is not left unclosed.
    # NOTE(review): tf.python_io is the TF1 namespace; under TF2 the
    # equivalent is tf.io.TFRecordWriter — confirm the targeted TF version.
    with tf.python_io.TFRecordWriter(args.tfrecord) as writer:
        # Count by row position rather than the DataFrame index label.
        for count, item in enumerate(data.itertuples(), start=1):
            tf_example = create_tf_example(item, args.dataset, label_map)
            if tf_example is None:
                continue
            writer.write(tf_example.SerializeToString())
            if count % 1000 == 0:
                print("Processed {} items".format(count))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
