import argparse
import os
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


def build_arg_parser():
    """Build the command-line parser for the SavedModel -> GraphDef converter.

    Returns:
        An ``argparse.ArgumentParser`` accepting the positional arguments
        ``input_dir``, ``output_path`` and an optional ``signature`` that
        defaults to ``'serving_default'``.
    """
    parser = argparse.ArgumentParser(description='Convert TF2 Saved Model format to GDef')
    parser.add_argument('input_dir', type=str, help='Input SavedModel directory')
    parser.add_argument('output_path', type=str, help='Output gdef path')
    # argparse only applies `default` to a positional argument when nargs='?';
    # without it the signature is mandatory and the default is never used.
    parser.add_argument('signature', type=str, nargs='?', default='serving_default',
                        help='Custom signature with TF concrete function')
    return parser


def split_output_path(output_path):
    """Split ``output_path`` into the ``(logdir, name)`` pair used by tf.io.write_graph.

    A bare filename (no directory component) maps to the current directory
    ``'.'`` rather than an empty string.

    Args:
        output_path: Destination path of the GraphDef file.

    Returns:
        Tuple ``(output_dir, filename)``.

    Raises:
        ValueError: if ``output_path`` has no filename component (e.g. it
            ends with a path separator).
    """
    output_dir, filename = os.path.split(output_path)
    if not filename:
        raise ValueError('output_path {!r} does not name a file'.format(output_path))
    return output_dir or '.', filename


def main():
    """Freeze a SavedModel signature and write it as a binary GraphDef.

    Loads the SavedModel at ``input_dir`` and converts the variables of the
    requested signature's concrete function into constants. It prints the
    frozen function's inputs and outputs, then writes the resulting GraphDef
    (binary, ``as_text=False``) to ``output_path``. Invalid arguments are
    reported through ``parser.error``, which exits the process.
    """
    parser = build_arg_parser()
    args = parser.parse_args()

    # Validate the output location before loading the model so a bad path
    # fails immediately.
    try:
        output_dir, filename = split_output_path(args.output_path)
    except ValueError as err:
        parser.error(str(err))

    model = tf.saved_model.load(args.input_dir)

    signatures = model.signatures
    if args.signature not in signatures:
        available = ', '.join(sorted(signatures.keys())) or '(none)'
        parser.error('signature {!r} not found in {}; available: {}'.format(
            args.signature, args.input_dir, available))

    frozen_func = convert_variables_to_constants_v2(signatures[args.signature])
    graph_def = frozen_func.graph.as_graph_def()

    print("-" * 60)
    print("Model inputs: ")
    print(frozen_func.inputs)
    print("Model outputs: ")
    print(frozen_func.outputs)

    tf.io.write_graph(graph_or_graph_def=graph_def,
                      logdir=output_dir,
                      name=filename,
                      as_text=False)

# Entry point: run the converter only when executed as a script, not on import.
if __name__ == "__main__":
    main()
