import logging
import os
import os.path
import sys
import tarfile
import urllib.request

import tensorflow as tf
import ytensorflow as ytf

import nirvana_dl
import nirvana_dl.snapshot

import export_graphs
import train
import tfconfig_utils

# Logger for this script's own messages.
logger = logging.getLogger(name=__name__)

# Let TensorFlow's own log output through at INFO level and above.
tf.logging.set_verbosity(tf.logging.INFO)

def _resolve_detection_ckpt(params, train_dir):
    """Return the detection checkpoint prefix to pass to the config, or None.

    If the ``detection_ckpt_url`` parameter is set, the archive at that URL is
    downloaded to ``./detection_ckpt.tgz`` (unless that file already exists),
    unpacked into ``<train_dir>/detection_ckpt`` and then deleted.
    Independently, if ``<train_dir>/detection_ckpt/checkpoint`` exists (e.g. a
    checkpoint prepared by hand for a local launch), that directory is used.

    Args:
        params: Job parameters object; only ``params.get`` is used.
        train_dir: Training directory that will hold ``detection_ckpt/``.

    Returns:
        ``<train_dir>/detection_ckpt/model.ckpt`` when a checkpoint was
        downloaded or found locally, otherwise ``None``.
    """
    ckpt_dir = os.path.join(train_dir, 'detection_ckpt')
    ckpt_prefix = os.path.join(ckpt_dir, 'model.ckpt')
    detection_ckpt = None

    detection_ckpt_url = params.get('detection_ckpt_url')
    if detection_ckpt_url:
        archive_path = os.path.abspath('./detection_ckpt.tgz')
        if not tf.gfile.Exists(archive_path):
            with urllib.request.urlopen(detection_ckpt_url) as response:
                with open(archive_path, 'wb') as archive_file:
                    archive_file.write(response.read())
        # NOTE(review): extractall trusts member paths inside the archive;
        # only point detection_ckpt_url at archives you control.
        with tarfile.open(archive_path, 'r') as archive:
            archive.extractall(ckpt_dir)
        tf.gfile.Remove(archive_path)
        detection_ckpt = ckpt_prefix

    # Local launch: a checkpoint already prepared under train_dir is used.
    if tf.gfile.Exists(os.path.join(ckpt_dir, 'checkpoint')):
        detection_ckpt = ckpt_prefix

    return detection_ckpt


def run_train():
    """Render the training config from job parameters, train, then export.

    Steps:
      1. Create ``./faces_resnet50/`` as the training directory.
      2. Resolve an optional detection checkpoint (see
         ``_resolve_detection_ckpt``).
      3. Render ``template.config`` (looked up under ``$SOURCE_CODE_PATH``,
         default ``./``) into ``train.config`` via
         ``tfconfig_utils.update_template``.
      4. Run ``train.train`` and then ``export_graphs.export``.
    """
    source_path = os.getenv('SOURCE_CODE_PATH', "./")
    template_config_path = os.path.join(source_path, 'template.config')

    params = nirvana_dl.params()
    gpu_count = params.get('gpu-count')

    # NOTE(review): checkpoints written here are not under
    # nirvana_dl.snapshot.get_snapshot_path(); confirm that is intended.
    train_dir = os.path.abspath('./faces_resnet50/')
    tf.gfile.MakeDirs(train_dir)

    detection_ckpt = _resolve_detection_ckpt(params, train_dir)

    # TODO(review): `inputs` and `outputs` are not defined anywhere in this
    # module. They used to come from the Nirvana job context
    # (`nv.context().get_inputs()` / `.get_outputs()`), which is no longer
    # imported, so the calls below raise NameError until they are wired up
    # again. The earlier version also copied the `all_ckpt_state` and
    # `last_ckpt_state` inputs to the outputs before training (presumably to
    # resume training); restore that together with the context.
    train_config = 'train.config'
    tfconfig_utils.update_template(
        template_config_path,
        params.get("num_epoches"),
        params.get("optimizer"),
        params.get("lr_init"),
        params.get("lr_decay_epoches"),
        params.get("lr_decay_factor"),
        params.get("weight_decay"),
        inputs.get_list("train_data"),
        inputs.get("label_map"),
        detection_ckpt,
        gpu_count,
        train_config)

    train.train(train_dir, train_config, gpu_count)
    export_graphs.export(
        inputs.get("label_map"),
        outputs.get("all_ckpt_state"),
        outputs.get("gdef_tgz"),
        train_config)



def main(argv=None):
    """Run training and graph export, then log completion.

    Args:
        argv: Ignored. It is accepted so the function also works when the app
            runner passes the leftover command-line arguments (as
            ``tf.app.run`` does). Calling ``main()`` with no arguments still
            works.
    """
    del argv  # Unused; accepted only for runner compatibility.
    run_train()
    logger.info("Done")


if __name__ == '__main__':
    # Register the directory flags, each defaulting to the value returned by
    # the matching ytf.app.defaults getter, then start the ytf app with
    # DEBUG-level logging.
    _DIR_FLAGS = (
        ('checkpoint_dir', ytf.app.defaults.checkpoint_dir, 'Checkpoint directory.'),
        ('logs_dir', ytf.app.defaults.logs_dir, 'Logs directory.'),
        ('data_dir', ytf.app.defaults.data_dir, 'Data directory.'),
        ('output_dir', ytf.app.defaults.output_dir, 'Output directory.'),
    )
    for _flag_name, _default_fn, _flag_help in _DIR_FLAGS:
        # Each default getter is called right before its own flag is
        # registered, preserving the original call order.
        ytf.app.flags.DEFINE(_flag_name, _default_fn(), _flag_help)
    ytf.app.run(level=logging.DEBUG)
