import sys
import os.path
import urllib.request
import tensorflow as tf

import tarfile

import ckpt_snapshot
import svhn_conv_lstm as train

# Raise TF1 logging verbosity to INFO so tf.logging.info(...) messages
# (from this script and from the imported training module) are emitted.
tf.logging.set_verbosity(tf.logging.INFO)

import nirvana.job_context as nv

def _parse_int_list(text):
    """Parse a whitespace-separated string of integers, e.g. ``"3 3 5"``.

    Args:
        text: string parameter value read from the job parameters.

    Returns:
        A list of ``int``, one per whitespace-separated token
        (an empty list for an empty or all-whitespace string).
    """
    return [int(token) for token in text.split()]


def main():
    """Run (or resume) ``svhn_conv_lstm.train`` inside a Nirvana job.

    Steps:
      1. Create ``./train_dir`` and print its absolute path.
      2. Best-effort copy of the ``last_ckpt_state`` input to the
         ``last_ckpt_state`` output, without overwriting an existing output.
      3. Ask the job context to load the ``last_ckpt_state`` state using
         ``ckpt_snapshot.load_last_ckpt`` with ``train_dir``; ``-1`` is used
         as the default when no state is loaded.
      4. Call ``train.train`` with the job parameters, starting at the epoch
         after the loaded one (epoch 0 when the default ``-1`` was returned).
    """
    ctx = nv.context()
    inputs = ctx.get_inputs()
    outputs = ctx.get_outputs()
    params = ctx.get_parameters()

    train_dir = os.path.abspath('./train_dir/')
    tf.gfile.MakeDirs(train_dir)
    print(train_dir)

    # Best-effort: carry the previous checkpoint state forward from the job
    # input to the job output. A failure here (presumably when the input is
    # absent or the output already exists -- confirm against the Nirvana API)
    # must not abort training. Catch Exception rather than using a bare
    # ``except:`` so that KeyboardInterrupt/SystemExit still propagate, and
    # log the reason instead of silently discarding it.
    try:
        tf.gfile.Copy(inputs.get("last_ckpt_state"),
                      outputs.get("last_ckpt_state"),
                      overwrite=False)
    except Exception as exc:  # pylint: disable=broad-except
        tf.logging.warning(
            "Could not copy last_ckpt_state input to output: %s", exc)

    # NOTE(review): presumably restores the latest checkpoint into train_dir
    # and returns the last completed epoch index; -1 means "nothing restored".
    epoch_last = ctx.load_state("last_ckpt_state", ckpt_snapshot.load_last_ckpt,
                                train_dir, default_value=-1)

    train.train(
        inputs.get_list("train_data"),
        train_dir,
        params.get('image_seq_length'),
        params.get('image_size'),
        params.get('batch_size'),
        params.get('crop_bbox_type'),
        params.get('digit_seq_length'),
        params.get('digit_seq_invert'),
        params.get('digit_seq_pad_begin'),
        _parse_int_list(params.get('kernel_szs')),
        _parse_int_list(params.get('out_chns')),
        _parse_int_list(params.get('strides')),
        params.get('lr_init'),
        params.get('lr_decay_factor'),
        params.get('lr_decay_epoches'),
        epoch_last + 1,  # epoch_start: resume right after the loaded epoch
        params.get('num_epoches'),
        params.get('optimizer'),
        params.get('peephole'),
        params.get('output_keep_prob'),
        params.get('fc_keep_prob'),
        params.get('fc_additional'),
        params.get('weight_decay'),
        params.get('bn_decay'),
        params.get('bn_epsilon'),
        params.get('save_after_epoch'),
        params.get('save_every_epoch'))

if __name__ == "__main__":  # run training only when executed as a script
    main()
