import sys
import os.path

import tfconfig_utils
import saver_with_events
import ckpt_snapshot

import functools
import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2

from object_detection import trainer
from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.utils import config_util

def create_saver(train_dir, num_steps,
                 var_list=None,
                 reshape=False,
                 sharded=False,
                 max_to_keep=5,
                 keep_checkpoint_every_n_hours=10000.0,
                 name=None,
                 restore_sequentially=False,
                 saver_def=None,
                 builder=None,
                 defer_build=False,
                 allow_empty=False,
                 write_version=saver_pb2.SaverDef.V2,
                 pad_step_number=False,
                 save_relative_paths=False,
                 filename=None):
    """Build a SaverWithEvents with a CkptSnapshot registered as its save event.

    Everything after ``train_dir``/``num_steps`` mirrors the keyword arguments
    (and defaults) of ``tf.train.Saver``. That lets this function be bound with
    ``functools.partial(create_saver, train_dir, num_steps)`` and handed to a
    trainer as a saver factory.

    Args:
        train_dir: Training directory, passed to ``ckpt_snapshot.CkptSnapshot``.
        num_steps: Step count, passed to ``ckpt_snapshot.CkptSnapshot``
            (``train`` in this file supplies ``train_config.num_steps``).
        var_list, reshape, sharded, max_to_keep, keep_checkpoint_every_n_hours,
        name, restore_sequentially, saver_def, builder, defer_build,
        allow_empty, write_version, pad_step_number, save_relative_paths,
        filename: Forwarded unchanged to ``saver_with_events.SaverWithEvents``.

    Returns:
        The constructed ``saver_with_events.SaverWithEvents`` instance.
    """
    # Create the snapshot object first; it is passed as ``after_save_event``.
    snapshot_event = ckpt_snapshot.CkptSnapshot(train_dir, num_steps)

    # Saver options are relayed verbatim; only after_save_event is added here.
    forwarded = dict(
        var_list=var_list,
        reshape=reshape,
        sharded=sharded,
        max_to_keep=max_to_keep,
        keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
        name=name,
        restore_sequentially=restore_sequentially,
        saver_def=saver_def,
        builder=builder,
        defer_build=defer_build,
        allow_empty=allow_empty,
        write_version=write_version,
        pad_step_number=pad_step_number,
        save_relative_paths=save_relative_paths,
        filename=filename,
    )
    return saver_with_events.SaverWithEvents(after_save_event=snapshot_event,
                                             **forwarded)

def train(train_dir, pipeline_config_path, num_clones=1):
    """Train a detection model described by a pipeline config file.

    Loads the pipeline config and copies it into ``train_dir`` as
    ``pipeline.config``. It then builds the model, input and optional
    graph-rewriter factories and calls ``object_detection.trainer.train`` as a
    single local worker. ``create_saver`` (bound to ``train_dir`` and
    ``train_config.num_steps``) is passed to the trainer as ``create_saver_fn``.

    Args:
        train_dir: Output directory for training artifacts and the copied
            pipeline config. Created (with parents) if it does not exist.
        pipeline_config_path: Path to a pipeline config readable by
            ``config_util.get_configs_from_pipeline_file``.
        num_clones: Number of model clones, forwarded to ``trainer.train``.

    Side effects:
        Writes ``<train_dir>/pipeline.config`` (overwriting an existing copy)
        and resets the default TensorFlow graph afterwards, also when
        training raises.
    """
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    # Copy fails with NotFoundError when the destination directory is missing;
    # MakeDirs succeeds as a no-op if train_dir already exists.
    tf.gfile.MakeDirs(train_dir)
    tf.gfile.Copy(pipeline_config_path,
                  os.path.join(train_dir, 'pipeline.config'),
                  overwrite=True)

    model_config = configs['model']
    train_config = configs['train_config']
    input_config = configs['train_input_config']

    model_fn = functools.partial(
        model_builder.build,
        model_config=model_config,
        is_training=True)

    def get_next(config):
        # Build the dataset from the input config and return the next-element
        # tensors of an initializable iterator over it.
        return dataset_builder.make_initializable_iterator(
            dataset_builder.build(config)).get_next()

    create_input_dict_fn = functools.partial(get_next, input_config)

    # The graph rewriter is optional and only built when the config has one.
    graph_rewriter_fn = None
    if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=True)

    create_saver_fn = functools.partial(
        create_saver, train_dir, train_config.num_steps)

    try:
        # The positional arguments below configure a single, non-distributed
        # worker. Their meanings are assumed from the upstream
        # object_detection trainer.train signature; confirm against the
        # trainer version in use.
        trainer.train(
            create_input_dict_fn,
            model_fn,
            train_config,
            '',               # master (assumed): empty -> no remote server
            0,                # task (assumed)
            num_clones,
            1,                # worker_replicas (assumed)
            False,            # clone_on_cpu (assumed)
            0,                # ps_tasks (assumed)
            'lonely_worker',  # worker_job_name (assumed)
            True,             # is_chief (assumed)
            train_dir,
            graph_hook_fn=graph_rewriter_fn,
            create_saver_fn=create_saver_fn)
    finally:
        # Clear graph state so a later train() call in the same process starts
        # clean, even if this run failed.
        tf.reset_default_graph()

