import sys
import os
import tfconfig_utils
import tensorflow as tf
from tensorflow.core.framework import graph_pb2

import tarfile
import tempfile

from object_detection.utils import label_map_util
from object_detection import exporter

def _load_label_map(label_map_path):
    """Return the class names from a label map, ordered by class id.

    The mapping returned by ``label_map_util.get_label_map_dict`` is read as
    ``name -> id``; each name is stored at list position ``id - 1``.

    NOTE(review): the list is sized by the number of entries, so this
    presumably expects ids to be exactly 1..N -- confirm against the label
    maps actually in use.

    Args:
        label_map_path: Path handed to ``label_map_util.get_label_map_dict``.

    Returns:
        A list with one slot per label map entry, holding the class names.
    """
    name_to_id = label_map_util.get_label_map_dict(label_map_path)
    ordered_names = [None] * len(name_to_id)
    for class_name, class_id in name_to_id.items():
        # Ids are shifted down by one to become zero-based list positions.
        ordered_names[class_id - 1] = class_name
    return ordered_names

def _add_classes_to_graph(inp_path, label_map_path, out_path):
    """Write a copy of a serialized graph with a ``class_names`` constant added.

    The graph stored at ``inp_path`` is imported unchanged (no name prefix),
    a string constant named ``class_names`` holding the label map's class
    names is added next to it, and the result is serialized to ``out_path``.

    Args:
        inp_path: Path of a binary-serialized ``GraphDef`` to read.
        label_map_path: Label map path, forwarded to ``_load_label_map``.
        out_path: Path the augmented serialized ``GraphDef`` is written to.
    """
    classes_names = _load_label_map(label_map_path)

    graph_def = graph_pb2.GraphDef()
    with open(inp_path, "rb") as f:
        graph_def.ParseFromString(f.read())

    # Build inside a private graph rather than the process-wide default one:
    # nodes already present in the default graph must not end up in the
    # output, and nothing created here should be left behind for the caller.
    with tf.Graph().as_default() as g:
        tf.import_graph_def(graph_def, name='')
        tf.constant(classes_names, tf.string, name="class_names")
        augmented_def = g.as_graph_def()

    with tf.gfile.GFile(out_path, "wb") as f:
        f.write(augmented_def.SerializeToString())


def export(label_map_path, all_ckpt_tar_path, gdef_tgz_path, train_config):
    """Export every checkpoint in a tar archive as a graph with class names.

    For each distinct ``model.ckpt-<step>`` prefix found among the archive's
    member names, the ``.index``, ``.meta`` and ``.data-00000-of-00001``
    members are extracted, ``exporter.export_inference_graph`` is run on
    them, the class names from the label map are added to the resulting
    frozen graph, and that graph is stored as ``<step>.gdef`` in the output
    archive.

    Args:
        label_map_path: Label map path, forwarded to ``_add_classes_to_graph``.
        all_ckpt_tar_path: Path of the tar archive holding the checkpoints.
        gdef_tgz_path: Path of the gzip-compressed tar archive to create.
        train_config: Passed to ``tfconfig_utils.read_config`` to obtain the
            pipeline configuration used for the export.
    """
    pipeline_config = tfconfig_utils.read_config(train_config)

    # A private per-run directory keeps concurrent runs, and leftovers from a
    # crashed run, from sharing the same scratch paths. The two sub-paths do
    # not exist until the extraction / exporter create them.
    work_dir = tempfile.mkdtemp(prefix='ckpt_export_')
    ckpt_temp_path = os.path.join(work_dir, 'ckpt_temp')
    export_temp_path = os.path.join(work_dir, 'export_temp')

    exported = set()
    try:
        # Context managers close both archives even when an export fails.
        with tarfile.open(all_ckpt_tar_path, 'r') as src_tar, \
                tarfile.open(gdef_tgz_path, 'w:gz') as gdef_tgz:
            for fn in src_tar.getnames():
                if not fn.startswith('model.ckpt-'):
                    continue
                # Each checkpoint shows up once per member file; strip the
                # last extension and handle every prefix only once.
                ckpt_prefix, _ = os.path.splitext(fn)
                if ckpt_prefix in exported:
                    continue

                try:
                    for suffix in ('index', 'meta', 'data-00000-of-00001'):
                        src_tar.extract('{}.{}'.format(ckpt_prefix, suffix),
                                        ckpt_temp_path)

                    checkpt_idx = ckpt_prefix[ckpt_prefix.index('-') + 1:]
                    print('Export checkpoint {}'.format(checkpt_idx))

                    ckpt_file_prefix = os.path.join(ckpt_temp_path,
                                                    ckpt_prefix)
                    exporter.export_inference_graph('image_tensor',
                                                    pipeline_config,
                                                    ckpt_file_prefix,
                                                    export_temp_path,
                                                    None, False)

                    graph_path = os.path.join(export_temp_path,
                                              'frozen_inference_graph.pb')
                    # Reset the default graph before the frozen graph is
                    # loaded again to receive the class names.
                    tf.reset_default_graph()

                    gdef_name = '{}.gdef'.format(checkpt_idx)
                    graph_with_names_path = os.path.join(export_temp_path,
                                                         gdef_name)
                    _add_classes_to_graph(graph_path, label_map_path,
                                          graph_with_names_path)
                    gdef_tgz.add(graph_with_names_path, gdef_name)

                    exported.add(ckpt_prefix)
                finally:
                    # Scratch data is per checkpoint; remove it whether or
                    # not this export succeeded.
                    for scratch in (export_temp_path, ckpt_temp_path):
                        if tf.gfile.Exists(scratch):
                            tf.gfile.DeleteRecursively(scratch)
                    tf.reset_default_graph()
    finally:
        if tf.gfile.Exists(work_dir):
            tf.gfile.DeleteRecursively(work_dir)