import sys
import os.path
import tarfile
import tensorflow as tf

import nirvana.job_context as nv

def _tar_add_file(tar, filepath):
    """Add ``filepath`` to ``tar``, storing it under its final path component.

    Args:
        tar: An open ``tarfile.TarFile`` in a writable mode.
        filepath: Path of the file (or directory) to add; any leading
            directory components are dropped from the member name.
    """
    member_name = os.path.basename(filepath)
    tar.add(filepath, arcname=member_name)

def _tar_add_ckpt(tar, ckpt_prefix):
    """Archive the files belonging to the checkpoint at ``ckpt_prefix``.

    Adds ``<prefix>.index``, ``<prefix>.meta`` and
    ``<prefix>.data-00000-of-00001`` to ``tar`` (in that order), each under
    its basename.
    """
    # NOTE(review): only one data shard is archived; a checkpoint saved with
    # more shards would presumably be missing its other .data-* files.
    file_templates = ('{}.index', '{}.meta', '{}.data-00000-of-00001')
    for template in file_templates:
        _tar_add_file(tar, template.format(ckpt_prefix))

def dump_last_ckpt(state, f):
    """Serialize one checkpoint into ``f`` as an uncompressed tar stream.

    Args:
        state: Checkpoint path prefix (presumably ``<dir>/model.ckpt-<step>``,
            matching what ``load_last_ckpt`` looks for). Its ``.index``,
            ``.meta`` and single-shard ``.data`` files are archived under
            their basenames.
        f: Writable binary file object receiving the archive. It is not
            closed by this function.
    """
    # An empty name together with ``fileobj`` makes tarfile stream into ``f``.
    # The ``with`` block writes the end-of-archive blocks on success and
    # releases the TarFile even if adding one of the files raises.
    with tarfile.open('', 'w', f) as tar:
        _tar_add_ckpt(tar, state)

def _ckpt_step(ckpt_index_path):
    """Return the integer step of a ``.../model.ckpt-<step>.index`` path.

    Only the basename is inspected, so a directory component that happens to
    contain ``model.ckpt-`` cannot be mistaken for the step. Returns None if
    the basename does not have that form or the step is not an integer.
    """
    expr = 'model.ckpt-'
    prefix, _ = os.path.splitext(ckpt_index_path)
    basename = os.path.basename(prefix)
    if not basename.startswith(expr):
        return None
    try:
        return int(basename[len(expr):])
    except ValueError:
        # e.g. a stray "model.ckpt-foo.index"; ignore rather than abort restore
        return None

def load_last_ckpt(f, train_dir):
    """Restore checkpoint files from a tar stream and mark the newest one.

    Every member of the archive read from ``f`` is extracted into
    ``train_dir``. The ``model.ckpt-<step>.index`` file with the largest step
    is then selected, and a ``checkpoint`` state file pointing at its prefix
    is written into ``train_dir``.

    Args:
        f: Readable binary file object containing an uncompressed tar archive
            (as produced by ``dump_last_ckpt``). It is not closed here.
        train_dir: Local directory to extract into.

    Returns:
        The largest checkpoint step found, or None if ``train_dir`` holds no
        parsable ``model.ckpt-<step>.index`` file (no state file is written
        in that case).
    """
    # NOTE(review): members are extracted without sanitizing their names.
    # Archives from dump_last_ckpt hold only basenames, but a crafted archive
    # ("../x" or absolute member names) could write outside train_dir. Treat
    # ``f`` as trusted, or validate member names before extracting.
    with tarfile.open('', 'r', f) as tar:
        for member in tar:
            tar.extract(member, train_dir)

    ckpt_paths = tf.gfile.Glob(os.path.join(train_dir, 'model.ckpt-*.index'))

    latest_step = None
    latest_prefix = None
    for ckpt_path in ckpt_paths:
        step = _ckpt_step(ckpt_path)
        if step is None:
            continue
        # Strict ">" keeps the first path seen among equal steps.
        if latest_step is None or step > latest_step:
            latest_step = step
            latest_prefix = os.path.splitext(ckpt_path)[0]
    if latest_step is None:
        return None

    # Text-format checkpoint state pointing at the newest restored prefix.
    with open(os.path.join(train_dir, 'checkpoint'), 'wt') as state_file:
        state_file.write('model_checkpoint_path: "{}"\n'.format(latest_prefix))
        state_file.write('all_model_checkpoint_paths: "{}"\n'.format(latest_prefix))
    return latest_step

def update_all_gdef(all_gdef_tar_path, gdef_file_path):
    """Append a GraphDef file to an accumulating (uncompressed) tar archive.

    tarfile mode ``'a'`` appends to ``all_gdef_tar_path`` when it exists and
    creates it when it does not. A separate existence check is therefore
    unnecessary, and dropping it also removes the race between checking and
    opening.

    Args:
        all_gdef_tar_path: Local path of the archive, which may not exist yet.
        gdef_file_path: Path of the file to add; it is stored under its
            basename.
    """
    # ``with`` guarantees the archive file is released even if adding fails.
    with tarfile.open(all_gdef_tar_path, 'a') as tar:
        _tar_add_file(tar, gdef_file_path)

