import os
import sys
import tensorflow as tf

from google.protobuf import text_format

from object_detection.protos import eval_pb2
from object_detection.protos import image_resizer_pb2
from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2
from object_detection.protos import pipeline_pb2
from object_detection.protos import train_pb2
from object_detection.protos import optimizer_pb2
from object_detection.utils import config_util
from object_detection.utils import label_map_util

def _calc_classes_count(label_map_path):
    """Return the `num_classes` value implied by the label map at `label_map_path`.

    The result is the larger of the number of label-map items and the highest
    item id. For a label map with contiguous ids 1..N both values are N, so the
    result matches a plain item count. When ids are sparse (e.g. 1, 2, 5) the
    item count alone (3) would be smaller than the highest id (5). Upstream
    `label_map_util.create_categories_from_labelmap` sizes categories by the
    max id, so the max id is used here. NOTE(review): confirm this fork keeps
    the upstream 1-based id convention.
    """
    label_map = label_map_util.load_labelmap(label_map_path)
    item_ids = [item.id for item in label_map.item]
    # An empty label map yields 0, same as the previous len()-based result.
    max_id = max(item_ids) if item_ids else 0
    return max(len(item_ids), max_id)

def _tfrecords_count(filepath):
    """Return how many records the TFRecord file at `filepath` contains.

    The whole file is iterated via the TF1 `tf.python_io.tf_record_iterator`
    API; record payloads are discarded.
    """
    return sum(1 for _ in tf.python_io.tf_record_iterator(filepath))

def read_config(pipeline_config_path):
    """Parse the text-format pipeline config at `pipeline_config_path`.

    Returns:
      A `pipeline_pb2.TrainEvalPipelineConfig` populated from the file.
    """
    config = pipeline_pb2.TrainEvalPipelineConfig()
    with tf.gfile.GFile(pipeline_config_path, "r") as config_file:
        text_format.Merge(config_file.read(), config)
    return config

def _write_config(config, config_path):
    """Serialize `config` to protobuf text format and write it to `config_path`.

    The file is opened in "wb" mode, so any existing file is overwritten.
    Serialization happens before the file is opened, so a serialization error
    leaves an existing file untouched.
    """
    serialized = text_format.MessageToString(config)
    with tf.gfile.Open(config_path, "wb") as out_file:
        out_file.write(serialized)

def _update_num_classes(pipeline_config, num_classes):
    meta_architecture = pipeline_config.model.WhichOneof("model")
    if meta_architecture == "faster_rcnn":
        pipeline_config.model.faster_rcnn.num_classes = num_classes
    if meta_architecture == "ssd":
        pipeline_config.model.ssd.num_classes = num_classes

def _update_train_tfrecord(pipeline_config, train_tfrecord_path_list):
    input_reader_type = pipeline_config.train_input_reader.WhichOneof("input_reader")
    if input_reader_type == "tf_record_input_reader":
        pipeline_config.train_input_reader.tf_record_input_reader.ClearField("input_path")
        for path in train_tfrecord_path_list:
            pipeline_config.train_input_reader.tf_record_input_reader.input_path.append(path)
    else:
        raise TypeError("Input reader type must be `tf_record_input_reader`.")

def _update_label_map(pipeline_config, label_map_path):
    pipeline_config.train_input_reader.label_map_path = label_map_path

def _update_detection_checkpoint(pipeline_config, detection_checkpoint):
    pipeline_config.train_config.fine_tune_checkpoint_type = 'detection'
    pipeline_config.train_config.fine_tune_checkpoint = detection_checkpoint

def _update_num_steps(pipeline_config, num_steps):
    pipeline_config.train_config.num_steps = num_steps

def _update_batch_size(pipeline_config, batch_size):
    pipeline_config.train_config.batch_size = batch_size

def _update_lr_scheme(pipeline_config, lr_init, lr_decay_steps, lr_decay_factor, weight_decay, optimizer_type = "adam_optimizer"):
    if (optimizer_type == "momentum_optimizer"):
        pipeline_config.train_config.optimizer.momentum_optimizer.learning_rate.exponential_decay_learning_rate.initial_learning_rate = lr_init
        pipeline_config.train_config.optimizer.momentum_optimizer.learning_rate.exponential_decay_learning_rate.decay_steps = lr_decay_steps
        pipeline_config.train_config.optimizer.momentum_optimizer.learning_rate.exponential_decay_learning_rate.decay_factor = lr_decay_factor
    elif (optimizer_type == "adamw_optimizer"):
        pipeline_config.train_config.optimizer.adamw_optimizer.weight_decay = weight_decay
        pipeline_config.train_config.optimizer.adamw_optimizer.learning_rate.exponential_decay_learning_rate.initial_learning_rate = lr_init
        pipeline_config.train_config.optimizer.adamw_optimizer.learning_rate.exponential_decay_learning_rate.decay_steps = lr_decay_steps
        pipeline_config.train_config.optimizer.adamw_optimizer.learning_rate.exponential_decay_learning_rate.decay_factor = lr_decay_factor
    else:## (optimizer_type == "adam_optimizer"):
        pipeline_config.train_config.optimizer.adam_optimizer.learning_rate.exponential_decay_learning_rate.initial_learning_rate = lr_init
        pipeline_config.train_config.optimizer.adam_optimizer.learning_rate.exponential_decay_learning_rate.decay_steps = lr_decay_steps
        pipeline_config.train_config.optimizer.adam_optimizer.learning_rate.exponential_decay_learning_rate.decay_factor = lr_decay_factor

def _epochs_to_steps(epochs, records_count, batch_size):
    """Convert an epoch count into a whole, positive number of training steps.

    Returns ``int(epochs * records_count // batch_size)`` clamped to at least 1.
    `int()` keeps the value integral even when `epochs` is fractional, because
    it is written to integer proto fields. The lower bound matters for two
    reasons. Upstream TrainConfig documents `num_steps == 0` as "train
    indefinitely". A `decay_steps` of 0 would make the exponential decay
    degenerate. Both can happen when the dataset has fewer records than
    `batch_size`.
    """
    return max(1, int(epochs * records_count // batch_size))


def update_template(template_path,
                    num_epoches,
                    optimizer_type,
                    lr_init, lr_decay_epoches, lr_decay_factor,
                    weight_decay,
                    train_tfrecord_path_list,
                    label_map_path,
                    detection_checkpoint,
                    batch_size,
                    config_out_path):
    """Fill a template pipeline config with dataset/training settings and save it.

    Args:
      template_path: text-format pipeline config used as the starting point.
      num_epoches: number of passes over the training records. It is converted
        to `num_steps` using the total record count and `batch_size`.
      optimizer_type: optimizer to configure, forwarded to `_update_lr_scheme`.
      lr_init: initial learning rate.
      lr_decay_epoches: decay interval in epochs, converted to steps the same
        way as `num_epoches`.
      lr_decay_factor: exponential decay factor.
      weight_decay: only used when `optimizer_type` is "adamw_optimizer".
      train_tfrecord_path_list: TFRecord files that replace the training input
        paths. Every file is fully read once to count its records.
      label_map_path: label map used for training; it also determines
        `num_classes`.
      detection_checkpoint: optional fine-tune checkpoint. If it is None or
        empty, the template's checkpoint settings are left unchanged.
      batch_size: training batch size.
      config_out_path: where the resulting text-format config is written.

    Raises:
      TypeError: if the template's train input reader is not a
        `tf_record_input_reader`.
      ZeroDivisionError: if `batch_size` is 0.
    """
    pipeline_config = read_config(template_path)
    _update_num_classes(pipeline_config, _calc_classes_count(label_map_path))
    _update_train_tfrecord(pipeline_config, train_tfrecord_path_list)
    _update_label_map(pipeline_config, label_map_path)
    if detection_checkpoint:
        _update_detection_checkpoint(pipeline_config, detection_checkpoint)

    tfrecords_cnt = sum(_tfrecords_count(path) for path in train_tfrecord_path_list)
    print("tfrecords_cnt: {}".format(tfrecords_cnt))
    print("tfrecords_cnt / batch_size: {}".format(tfrecords_cnt // batch_size))

    num_steps = _epochs_to_steps(num_epoches, tfrecords_cnt, batch_size)
    lr_decay_steps = _epochs_to_steps(lr_decay_epoches, tfrecords_cnt, batch_size)

    _update_num_steps(pipeline_config, num_steps)
    _update_batch_size(pipeline_config, batch_size)
    _update_lr_scheme(pipeline_config, lr_init, lr_decay_steps, lr_decay_factor,
                      weight_decay, optimizer_type)
    _write_config(pipeline_config, config_out_path)

