import os
import sys
import tensorflow as tf

from google.protobuf import text_format

from object_detection.protos import eval_pb2
from object_detection.protos import image_resizer_pb2
from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2
from object_detection.protos import pipeline_pb2
from object_detection.protos import train_pb2
from object_detection.utils import config_util

def _read_config(pipeline_config_path):
    """Load a text-format pipeline config from disk.

    Args:
        pipeline_config_path: Path to a text-format ``TrainEvalPipelineConfig``
            file; it is opened with ``tf.gfile.GFile``.

    Returns:
        A ``pipeline_pb2.TrainEvalPipelineConfig`` populated from the file.
    """
    with tf.gfile.GFile(pipeline_config_path, "r") as config_file:
        config_text = config_file.read()
    # Parse after the file handle is released; the message starts empty, so
    # Merge yields exactly the file's contents.
    parsed_config = pipeline_pb2.TrainEvalPipelineConfig()
    text_format.Merge(config_text, parsed_config)
    return parsed_config

def _write_config(config, config_path):
    """Write a protobuf message to ``config_path`` in protobuf text format.

    The file is opened in ``"wb"`` mode, so any existing file at
    ``config_path`` is replaced.

    Args:
        config: Protobuf message to serialize (e.g. a pipeline config).
        config_path: Destination path, opened with ``tf.gfile``.
    """
    # Serialize first: if formatting fails, the destination is never opened
    # (and therefore never truncated).
    serialized = text_format.MessageToString(config)
    with tf.gfile.GFile(config_path, "wb") as out_file:
        out_file.write(serialized)

def _update_num_classes(pipeline_config, num_classes):
    """Set ``num_classes`` on the model's meta-architecture config.

    Args:
        pipeline_config: ``TrainEvalPipelineConfig`` message, modified in place.
        num_classes: Number of object classes to store in the selected
            meta-architecture's ``num_classes`` field.

    Raises:
        ValueError: If the ``model`` oneof is unset or selects a
            meta-architecture other than ``faster_rcnn`` or ``ssd``.
    """
    model_config = pipeline_config.model
    meta_architecture = model_config.WhichOneof("model")
    if meta_architecture == "faster_rcnn":
        model_config.faster_rcnn.num_classes = num_classes
    elif meta_architecture == "ssd":
        model_config.ssd.num_classes = num_classes
    else:
        # Silently skipping here would leave the template's class count in the
        # written config; fail loudly instead, as _update_train_tfrecord does
        # for unsupported input readers.
        raise ValueError(
            "Unsupported model meta-architecture {!r}; expected `faster_rcnn` "
            "or `ssd`.".format(meta_architecture))

def _update_train_tfrecord(pipeline_config, train_tfrecord_path):
    """Replace the training reader's input paths with a single TFRecord path.

    Args:
        pipeline_config: ``TrainEvalPipelineConfig`` message, modified in place.
        train_tfrecord_path: The one path to leave in
            ``train_input_reader.tf_record_input_reader.input_path``.

    Raises:
        TypeError: If the training input reader is not a
            ``tf_record_input_reader``.
    """
    train_reader = pipeline_config.train_input_reader
    if train_reader.WhichOneof("input_reader") != "tf_record_input_reader":
        raise TypeError("Input reader type must be `tf_record_input_reader`.")

    record_reader = train_reader.tf_record_input_reader
    # Drop every path inherited from the template so exactly one remains.
    record_reader.ClearField("input_path")
    record_reader.input_path.append(train_tfrecord_path)

def _update_label_map(pipeline_config, label_map_path):
    """Set ``label_map_path`` on the training input reader.

    Only ``train_input_reader`` is changed; no other reader in the config is
    touched.

    Args:
        pipeline_config: ``TrainEvalPipelineConfig`` message, modified in place.
        label_map_path: Label map path to store.
    """
    train_reader = pipeline_config.train_input_reader
    train_reader.label_map_path = label_map_path

def update_template(template_path, num_classes, train_tfrecord_path, label_map_path, config_out_path):
    """Fill a pipeline config template with dataset values and save the result.

    Reads the template, then updates these fields:

    * the model's class count
    * the training TFRecord path
    * the training label map path

    Finally it writes the updated config to ``config_out_path``.

    Args:
        template_path: Text-format pipeline config used as the starting point.
        num_classes: Number of object classes for the model.
        train_tfrecord_path: TFRecord file for the training input reader.
        label_map_path: Label map file for the training input reader.
        config_out_path: Where the filled-in config is written.

    Raises:
        Errors from the update helpers propagate unchanged (e.g. ``TypeError``
        when the template's training reader is not a TFRecord reader). Nothing
        is written in that case.
    """
    config = _read_config(template_path)
    # Applied in a fixed order, all to the same in-memory message.
    field_updates = (
        (_update_num_classes, num_classes),
        (_update_train_tfrecord, train_tfrecord_path),
        (_update_label_map, label_map_path),
    )
    for apply_update, value in field_updates:
        apply_update(config, value)
    _write_config(config, config_out_path)

