import sys
import os.path
import tarfile
import tensorflow as tf

import nirvana.job_context as nv

def _tar_add_file(tar, filepath):
    """Add ``filepath`` to the open archive ``tar`` under its bare file name.

    Args:
        tar: archive object exposing ``add(name, arcname)``.
        filepath: path of the file to store; its directory part is dropped
            so the member sits at the archive root.
    """
    tar.add(filepath, arcname=os.path.basename(filepath))

def _tar_add_ckpt(tar, ckpt_prefix):
    """Add the files of one checkpoint to the open archive ``tar``.

    The index file, the meta file and the single data shard
    (``00000-of-00001``) that share ``ckpt_prefix`` are added, in that order.

    Args:
        tar: archive object opened for writing.
        ckpt_prefix: checkpoint path without any of the suffixes above.
    """
    name_templates = ('{}.index', '{}.meta', '{}.data-00000-of-00001')
    for template in name_templates:
        _tar_add_file(tar, template.format(ckpt_prefix))

def _dump_last_ckpt(state, f):
    """Write the checkpoint with prefix ``state`` into ``f`` as a tar archive.

    Args:
        state: checkpoint path prefix whose files are archived.
        f: writable binary file object that receives the uncompressed
            archive. It is not closed here; tarfile leaves a caller-supplied
            file object open.
    """
    # The name is empty because the archive goes to the supplied file object.
    # The ``with`` block releases the TarFile even if adding a file raises.
    with tarfile.open('', 'w', f) as last_ckpt:
        _tar_add_ckpt(last_ckpt, state)

def _load_last_ckpt(f, train_dir):
    """Unpack a dumped checkpoint from ``f`` into ``train_dir``.

    After extraction the directory is scanned for ``model.ckpt-<N>.index``
    files and the ``checkpoint`` file in ``train_dir`` is rewritten so that
    it names the checkpoint with the largest ``N``.

    Args:
        f: readable binary file object holding the tar archive.
        train_dir: directory that receives the extracted files.

    Returns:
        The largest checkpoint index found, or 0 when the directory holds no
        ``model.ckpt-*.index`` file after extraction (so callers can always
        do arithmetic on the result).

    Raises:
        ValueError: if a matching file name has a non-numeric step part.
    """
    # NOTE(review): members are extracted without path sanitising; that is
    # only safe while the archive comes from this job's own earlier dump.
    with tarfile.open('', 'r', f) as last_ckpt:
        for member in last_ckpt:
            last_ckpt.extract(member, train_dir)

    ckpt_paths = tf.gfile.Glob(os.path.join(train_dir, 'model.ckpt-*.index'))
    if not ckpt_paths:
        return 0

    expr = 'model.ckpt-'
    ckpt_idx_max = -1
    ckpt_idx_max_prefix = ''
    for ckpt in ckpt_paths:
        prefix, _ = os.path.splitext(ckpt)
        # Parse the step from the file name only, so a train_dir whose own
        # path contains 'model.ckpt-' cannot confuse the lookup.
        ckpt_idx = int(os.path.basename(prefix)[len(expr):])
        if ckpt_idx_max < ckpt_idx:
            ckpt_idx_max = ckpt_idx
            ckpt_idx_max_prefix = prefix

    ckpt_main = os.path.join(train_dir, 'checkpoint')
    with open(ckpt_main, "wt") as ckpt_file:
        ckpt_file.write('model_checkpoint_path: "{}"\n'.format(ckpt_idx_max_prefix))
        ckpt_file.write('all_model_checkpoint_paths: "{}"\n'.format(ckpt_idx_max_prefix))
    return ckpt_idx_max

def _check_last_all_ckpt_idx(all_ckpt_tar_path, start_idx):
    """Return the highest checkpoint index stored in the cumulative archive.

    Args:
        all_ckpt_tar_path: path of the tar archive that collects checkpoints.
        start_idx: lower bound for the result; returned unchanged when the
            archive does not exist or holds no checkpoint with a larger index.

    Returns:
        The larger of ``start_idx`` and the biggest ``N`` among archive
        members named ``model.ckpt-N.<suffix>``.

    Raises:
        ValueError: if a ``model.ckpt-`` member has a non-numeric step part.
    """
    if not tf.gfile.Exists(all_ckpt_tar_path):
        return start_idx

    ckpt_idx_max = start_idx
    # The ``with`` block closes the archive even when a member name fails to
    # parse.
    with tarfile.open(all_ckpt_tar_path, 'r') as tar:
        for member in tar:
            if not member.name.startswith('model.ckpt-'):
                continue
            # Drop the last extension, then read the digits after the first
            # '-' in the remaining name.
            prefix, _ = os.path.splitext(member.name)
            ckpt_idx = int(prefix[prefix.index('-') + 1:])
            if ckpt_idx > ckpt_idx_max:
                ckpt_idx_max = ckpt_idx
    return ckpt_idx_max

def _update_all_ckpt(all_ckpt_tar_path, ckpt_prefix):
    """Append the checkpoint with prefix ``ckpt_prefix`` to the cumulative archive.

    The archive is opened for appending when it already exists and created
    otherwise.

    Args:
        all_ckpt_tar_path: path of the tar archive that collects checkpoints.
        ckpt_prefix: checkpoint path prefix whose files are added.
    """
    mode = 'a' if tf.gfile.Exists(all_ckpt_tar_path) else 'w'
    # The ``with`` block closes the archive file even if adding a checkpoint
    # file raises (for example because one of the files is missing).
    with tarfile.open(all_ckpt_tar_path, mode) as tar:
        _tar_add_ckpt(tar, ckpt_prefix)

class CkptSnapshot(object):
    def __init__(self, train_dir, num_steps):
        self.num_steps = num_steps
        self._nv_ctx = nv.context()
        outputs = self._nv_ctx.get_outputs()
        params = self._nv_ctx.get_parameters()
        self._save_after_steps = params.get('save_after_steps')
        self._last_all_ckpt_idx = self._save_after_steps
        self._save_every_steps = params.get('save_every_steps')
        self._all_ckpt_tar_path = outputs.get('all_ckpt_state')

        self._last_ckpt_idx = self._nv_ctx.load_state("last_ckpt_state", _load_last_ckpt, train_dir, default_value=0)
        self._last_all_ckpt_idx = _check_last_all_ckpt_idx(
                                            self._all_ckpt_tar_path,
                                            self._save_after_steps - self._save_every_steps)
        self._last_all_ckpt_idx += self._save_every_steps

    def __call__(self, ckpt_prefix):
        ckpt_idx = int(ckpt_prefix.split('-')[-1])
        if (ckpt_idx > self._last_ckpt_idx + 10):
             self._nv_ctx.dump_state("last_ckpt_state", ckpt_prefix, _dump_last_ckpt)
        if (ckpt_idx > self._last_all_ckpt_idx or ckpt_idx >= self.num_steps - 1):
            self._last_all_ckpt_idx = ckpt_idx + self._save_every_steps
            _update_all_ckpt(self._all_ckpt_tar_path, ckpt_prefix)
