import os
import shutil
import tarfile

import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile

import nirvana.job_context as nv


def pack_tarfile(path, filename_list):
    """Write a gzip-compressed tar archive at *path* holding every entry of *filename_list*.

    Each entry is stored under its basename only, so the archive is flat at the
    top level. Two sources with the same basename therefore end up as duplicate
    member names. Directories are added by ``tarfile`` together with their contents.
    """
    with tarfile.open(path, mode="w:gz") as archive:
        for source in filename_list:
            flat_name = os.path.basename(source)
            archive.add(source, arcname=flat_name)


def unpack_tarfile(path, output_path='.'):
    """Extract the gzip-compressed tar archive at *path* into *output_path*.

    Before extracting anything, every member name is resolved against
    *output_path*. If any name would land outside that directory (absolute
    names, ``..`` components, or existing symlinks in the destination path),
    nothing is extracted.

    Only member names are checked; symlink/hardlink *targets* inside the archive
    are not. Well-formed archives with relative member names are unaffected.

    Raises:
        ValueError: if a member would be extracted outside *output_path*.
    """
    root = os.path.realpath(output_path)
    # Trailing separator so that e.g. "/out" does not accept "/outside/x" as a prefix.
    prefix = os.path.join(root, "")
    with tarfile.open(path, "r:gz") as tar:
        for member in tar.getmembers():
            target = os.path.realpath(os.path.join(root, member.name))
            if target != root and not target.startswith(prefix):
                raise ValueError(
                    "Refusing to extract '{}' outside of '{}'".format(member.name, output_path)
                )
        tar.extractall(path=output_path)


def has_input_snapshot(name):
    """Return what the Nirvana job context's inputs report via ``has(name)``."""
    job_inputs = nv.context().get_inputs()
    return job_inputs.has(name)


def has_output_snapshot(name):
    """Tell whether output *name* is known to the Nirvana job and its file already exists.

    ``has(name)`` is checked first, as ``has_input_snapshot`` and
    ``dump_snapshot`` do. Without it, whatever ``get()`` yields for an unknown
    output (possibly ``None`` or an exception) would reach ``os.path.exists``.
    """
    outputs = nv.context().get_outputs()
    return outputs.has(name) and os.path.exists(outputs.get(name))


def has_snapshot(name):
    """Tell whether snapshot *name* is available as an input or as an existing output.

    The input side is consulted first; the output side is only checked when the
    input check is falsy.
    """
    found = has_input_snapshot(name)
    if not found:
        found = has_output_snapshot(name)
    return found


def get_input_snapshot_path(name):
    """Return the value the Nirvana job context's inputs give for ``get(name)``."""
    job_inputs = nv.context().get_inputs()
    return job_inputs.get(name)


def get_output_snapshot_path(name):
    """Return the value the Nirvana job context's outputs give for ``get(name)``."""
    job_outputs = nv.context().get_outputs()
    return job_outputs.get(name)


def dump_snapshot(name, path):
    """Pack the top-level entries of directory *path* into a .tar.gz and publish it as output *name*.

    The archive is first written to ``tmp_snapshot_<name>.tar.gz`` in the current
    working directory, then moved onto the path the Nirvana context gives for
    output *name*.

    Raises:
        Exception: if the Nirvana outputs do not report an output called *name*.
    """
    if not nv.context().get_outputs().has(name):
        raise Exception("Output '{}' not exists".format(name))
    tmp_name = "tmp_snapshot_{}.tar.gz".format(name)
    filename_list = [os.path.join(path, filename) for filename in os.listdir(path)]
    try:
        pack_tarfile(tmp_name, filename_list)
        # shutil.move falls back to copy+delete when the destination is on a
        # different filesystem than the cwd; os.rename would fail with EXDEV there.
        shutil.move(tmp_name, get_output_snapshot_path(name))
    finally:
        # After a successful move the temp file is gone. After a failure, remove the
        # partial archive so it cannot be swept into a later dump of the cwd.
        if os.path.exists(tmp_name):
            os.remove(tmp_name)


def load_snapshot(name, output_path):
    """Unpack snapshot *name* into *output_path*, preferring an existing output over an input.

    Raises:
        Exception: if neither an output nor an input snapshot named *name* is available.
    """
    # Checked in priority order: a snapshot already written by this job wins.
    sources = (
        (has_output_snapshot, get_output_snapshot_path),
        (has_input_snapshot, get_input_snapshot_path),
    )
    for is_available, locate in sources:
        if is_available(name):
            unpack_tarfile(locate(name), output_path)
            return
    raise Exception('no snapshot available')


def init(train_config, environment):
    """Create the model/graph directories and, under Nirvana, restore snapshots into them.

    Args:
        train_config: mapping with at least the keys "model_path" and "gdef_path",
            each a directory path.
        environment: environment name; only "nirvana" enables snapshot handling.

    Side effects:
        Creates both directories, including missing parents. When *environment* is
        "nirvana", sets ``YTF_NIRVANA=1`` in ``os.environ`` (save_model/save_gdef
        check it) and unpacks any available "state"/"gdef" snapshots.
    """
    for key in ("model_path", "gdef_path"):
        directory = train_config[key]
        if not os.path.isdir(directory):
            # makedirs (unlike mkdir) also creates missing parent directories.
            os.makedirs(directory)
    if environment == "nirvana":
        os.environ["YTF_NIRVANA"] = "1"
        # NOTE(review): the model directory is restored from a snapshot named
        # "state", but save_model dumps it under "model". Confirm which output
        # name the Nirvana operation actually declares.
        if has_snapshot("state"):
            load_snapshot("state", train_config["model_path"])
        if has_snapshot("gdef"):
            load_snapshot("gdef", train_config["gdef_path"])


def next_epoch(train_config):
    """Return the epoch to train next, based on checkpoint index files in "model_path".

    Index files are expected to be named like ``model-<epoch>.index``; the epoch
    is the text after the last "-" in the part before the first ".".

    Returns:
        One more than the largest epoch found, or 0 when no ``.index`` file exists.

    Raises:
        RuntimeError: if ``train_config["model_path"]`` is not a directory.
    """
    model_dir = train_config["model_path"]
    if not os.path.isdir(model_dir):
        raise RuntimeError("Folder with models does not exist: {}".format(model_dir))
    # Match the ".index" extension exactly. A bare endswith("index") would also
    # pick up names like "notes-index", whose "epoch" then fails to parse.
    epochs = [
        int(name.split(".")[0].split("-")[-1])
        for name in os.listdir(model_dir)
        if name.endswith(".index")
    ]
    if not epochs:
        return 0
    return 1 + max(epochs)


def save_gdef(sess, epoch, path, output_node_names=None):
    """Freeze the default graph and write it to ``<path>/model-<epoch>.gdef``.

    Variables reachable from *output_node_names* are converted to constants
    using their current values in *sess*. The serialized GraphDef is then written
    via ``gfile``. When ``YTF_NIRVANA`` is "1", *path* is also published as the
    "gdef" snapshot.

    Args:
        sess: TensorFlow session holding the variable values to freeze.
        epoch: value embedded in the output file name.
        path: directory the .gdef file is written into.
        output_node_names: graph node names to keep as outputs of the frozen
            graph. Defaults to ``["inference_softmax", "class_names"]``.
    """
    if output_node_names is None:
        output_node_names = ["inference_softmax", "class_names"]
    out_graph_path = os.path.join(path, "model-{}.gdef".format(epoch))
    graph_def = tf.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,
        graph_def,
        list(output_node_names),
    )
    with gfile.GFile(out_graph_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())
    if os.environ.get("YTF_NIRVANA") == "1":
        dump_snapshot("gdef", path)


def save_model(sess, saver, epoch, path):
    """Replace everything in directory *path* with a fresh checkpoint ``<path>/model-<epoch>``.

    All existing entries of *path* are deleted first, so only the newly saved
    checkpoint remains. When ``YTF_NIRVANA`` is "1", *path* is then published as
    the "model" snapshot.

    Args:
        sess: TensorFlow session whose variables are saved.
        saver: object providing ``save(sess, prefix, global_step=...)``,
            presumably a ``tf.train.Saver`` (verify against callers).
        epoch: passed as ``global_step`` and appended to the checkpoint prefix.
        path: checkpoint directory; it must already exist.
    """
    for name in os.listdir(path):
        entry = os.path.join(path, name)
        # os.remove cannot delete directories. Without this branch, clearing would
        # abort halfway through with some files already gone.
        if os.path.isdir(entry) and not os.path.islink(entry):
            shutil.rmtree(entry)
        else:
            os.remove(entry)
    model_path = os.path.join(path, 'model')
    saver.save(sess, model_path, global_step=epoch)
    # NOTE(review): init() restores the model directory from a snapshot named
    # "state", while this dumps it as "model". Confirm the intended output name.
    if os.environ.get("YTF_NIRVANA") == "1":
        dump_snapshot("model", path)


def restore(sess, saver, epoch, path):
    """Load the checkpoint with prefix ``<path>/model-<epoch>`` into *sess* via *saver*.

    The prefix is built the same way save_model builds its ``model`` prefix
    before the saver appends ``-<global_step>``.
    """
    checkpoint_prefix = "{}-{}".format(os.path.join(path, "model"), epoch)
    saver.restore(sess, checkpoint_prefix)
