import os
from tensorflow.python.client import device_lib
if "YTF_NIRVANA" in os.environ:
    import nirvana.job_context as nv

def get_available_gpus():
    """Return the names of the local devices TensorFlow reports as GPUs.

    Names are listed in the order device_lib.list_local_devices() yields
    them; the list is empty when no GPU device is reported.
    """
    local_devices = device_lib.list_local_devices()
    gpu_devices = [dev for dev in local_devices if dev.device_type == 'GPU']
    return [dev.name for dev in gpu_devices]

def _env_flag(name, default):
    """Return True unless env var `name` (or `default` when unset) parses to integer 0."""
    # os.getenv returns a str when the variable is set, so it must be turned
    # into an int before comparing with 0 -- otherwise "0" counts as enabled.
    return 0 != int(os.getenv(name, default))

def _dataset_path(filename_env, default_filename):
    """Return the absolute path ./data/<YTF_DATASET_SUBFOLDER>/<file name from `filename_env`>."""
    return os.path.abspath("./data/{}/{}".format(
        os.getenv('YTF_DATASET_SUBFOLDER', ''),
        os.getenv(filename_env, default_filename)))

def _ensure_dir(path):
    """Create directory `path`, including missing parent directories, if it does not exist."""
    if not os.path.isdir(path):
        os.makedirs(path)

def init_local_params():
    """Build the training configuration from YTF_* environment variables.

    Every tunable value falls back to the default written below when its
    variable is unset. Integer switches are enabled unless set to 0. The
    model, logs and snapshot output folders are created (together with any
    missing parents) if they do not exist yet. The resulting dict is printed
    as sorted "key value" lines and returned.

    Returns:
        dict: the configuration, keyed by parameter name.

    Raises:
        ValueError: if a numeric/switch variable holds a non-numeric string.
        OSError: if an output folder path exists but is not a directory, or
            cannot be created.
    """
    params = {}
    # Image params
    params["width"] = int(os.getenv('YTF_IMAGE_WIDTH', '64'))
    params["height"] = int(os.getenv('YTF_IMAGE_HEIGHT', '64'))
    params["chns"] = 3
    # Data source
    params["images_filepath"] = _dataset_path('YTF_DATASET_IMAGE_FILENAME', 'images.dat')
    params["edges_filepath"] = _dataset_path('YTF_DATASET_EDGES_FILENAME', 'edges.dat')
    params["verts_filepath"] = _dataset_path('YTF_DATASET_VERTS_FILENAME', 'verts.dat')
    # Net params
    params["bn_decay"] = 0.9
    params["bn_epsilon"] = 0.001
    params["model_verts_suffix"] = "verts"
    params["model_edges_suffix"] = "edges"
    # Pretrained
    params["vgg_weights"] = os.path.abspath("./data/vgg16_weights.npz")
    params["pretrained_model"] = os.path.abspath("./pretrained_model")
    # Loss function
    params["loss_vertices_exists"] = _env_flag("YTF_LOSS_VERTICES_EXISTS", 1)
    params["edges_loss_weight"] = float(os.getenv("YTF_EDGES_LOSS_WEIGHT", 1.))
    params["verts_loss_weight"] = float(os.getenv("YTF_VERTS_LOSS_WEIGHT", 10.))
    # Optimizer params
    params["optimizer"] = os.getenv("YTF_TRAINER_OPTIMIZER", "AdamOptimizer")
    params["lr_base"] = float(os.getenv('YTF_TRAINER_LR_BASE', 0.001))
    params["lr_decay_epoch"] = int(os.getenv('YTF_TRAINER_LR_DECAY_EPOCH', 25))
    params["lr_decay_rate"] = float(os.getenv('YTF_TRAINER_LR_DECAY_RATE', 0.3))
    # Parsed like every other switch; comparing the raw string with 0 meant
    # that setting this variable to "0" could never disable staircase decay.
    params["lr_decay_staircase"] = _env_flag('YTF_TRAINER_LR_DECAY_STAIRCASE', 1)
    params["batch_size_per_gpu"] = int(os.getenv('YTF_TRAINER_BATCH_SIZE_PER_GPU', 1))
    params["available_gpus"] = get_available_gpus()
    # With no GPU reported, the total batch falls back to one per-GPU batch.
    params["batch_size"] = params["batch_size_per_gpu"] * max(len(params["available_gpus"]), 1)
    params["weight_decay"] = float(os.getenv('YTF_TRAINER_WEIGHT_DECAY', 0.0001))
    params["epoch_start"] = int(os.getenv('YTF_TRAINER_EPOCH_START', 0))
    params["epoch_max"] = int(os.getenv('YTF_TRAINER_EPOCH_MAX', 200))
    # Data augmentation
    params["equalize_hist_enable"] = _env_flag("YTF_EQUALIZE_HIST_ENABLE", 1)
    params["data_augmentation"] = _env_flag("YTF_ENABLE_DATA_AUGMENTATION", 1)
    params["remove_empty_cell"] = _env_flag("YTF_ENABLE_REMOVE_EMPTY_CELL", 1)
    params["data_split_seed"] = int(os.getenv("YTF_DATA_SPLIT_SEED", 42))
    # Inference inputs and outputs
    params["inference_input"] = os.getenv("YTF_INFERENCE_INPUT", "inference_input")
    params["inference_edges_output"] = os.getenv("YTF_INFERENCE_EDGES_OUTPUT", "inference_edges")
    params["inference_verts_output"] = os.getenv("YTF_INFERENCE_VERTS_OUTPUT", "inference_verts")
    # Validate params
    params["validate_batchs"] = int(os.getenv('YTF_VALIDATE_BATCH_CNT', 5))
    # Output params; user-supplied folders may be nested, hence makedirs.
    params["out_model"] = os.getenv("YTF_MODEL_OUTPUT_FOLDER", "./model-data")
    _ensure_dir(params["out_model"])
    params["out_logs"] = os.getenv("YTF_LOGS_OUTPUT_FOLDER", "./logs")
    _ensure_dir(params["out_logs"])
    params["compact_progress"] = True
    params["save_every_epoches"] = int(os.getenv("YTF_TRAINER_SAVE_EVERY_EPOCHES", 25))
    params["out_snapshot"] = os.getenv("YTF_SNAPSHOT_OUTPUT_FOLDER", "./snapshot")
    _ensure_dir(params["out_snapshot"])
    # Sorted so configuration dumps from different runs line up; the single
    # formatted argument prints the same under Python 2 and Python 3.
    for key in sorted(params):
        print("%s %s" % (key, params[key]))
    return params

def get_nv_param(name):
    """Look up the Nirvana job parameter called `name`.

    Only usable when YTF_NIRVANA was set at import time; otherwise `nv` was
    never imported and this raises NameError. The lookup uses `.get`, so a
    missing name presumably yields None rather than raising -- confirm
    against the nirvana.job_context API.
    """
    job_parameters = nv.context().get_parameters()
    return job_parameters.get(name)

def get_nv_input(name):
    """Look up the Nirvana job input called `name`.

    Only usable when YTF_NIRVANA was set at import time (otherwise `nv` is
    undefined and this raises NameError). The lookup uses `.get`, so a
    missing name presumably yields None -- confirm against nirvana.job_context.
    """
    job_inputs = nv.context().get_inputs()
    return job_inputs.get(name)

def get_nv_output(name):
    """Look up the Nirvana job output called `name`.

    Only usable when YTF_NIRVANA was set at import time (otherwise `nv` is
    undefined and this raises NameError). The lookup uses `.get`, so a
    missing name presumably yields None -- confirm against nirvana.job_context.
    """
    job_outputs = nv.context().get_outputs()
    return job_outputs.get(name)

def init_nirvana_params():
    """Build the training configuration from Nirvana job parameters.

    Keys match those produced by init_local_params(), but tunable values are
    read with get_nv_param() (passed through unconverted) instead of YTF_*
    environment variables. Output folders are fixed paths under the current
    working directory and are created if missing. The resulting dict is
    printed one "key value" line per entry and returned.
    """
    params = {}

    def copy_nv(pairs):
        # Copy (params key, Nirvana parameter name) pairs, in the given order.
        for key, nv_name in pairs:
            params[key] = get_nv_param(nv_name)

    def make_output_dir(key, path):
        # Store the absolute output path and create the folder if missing.
        params[key] = os.path.abspath(path)
        if not os.path.exists(params[key]):
            os.mkdir(params[key])

    # Image params
    copy_nv([("width", "image_width"),
             ("height", "image_height")])
    params["chns"] = 3
    # Data source: each Nirvana parameter names a file inside ./data
    for key, nv_name in (("images_filepath", "img_dat"),
                         ("edges_filepath", "edges_dat"),
                         ("verts_filepath", "verts_dat")):
        params[key] = os.path.abspath("./data/{}".format(get_nv_param(nv_name)))
    # Net params
    params["bn_decay"] = 0.9
    params["bn_epsilon"] = 0.001
    params["model_verts_suffix"] = "verts"
    params["model_edges_suffix"] = "edges"
    # Pretrained
    params["vgg_weights"] = os.path.abspath("./data/vgg16_weights.npz")
    params["pretrained_model"] = os.path.abspath("./pretrained_model")
    # Loss function
    copy_nv([("loss_vertices_exists", "loss_vertices_exists"),
             ("edges_loss_weight", "edges_loss_weight"),
             ("verts_loss_weight", "verts_loss_weight")])
    # Optimizer params
    copy_nv([("optimizer", "optimizer"),
             ("lr_base", "lr_base"),
             ("lr_decay_epoch", "lr_decay_epoch"),
             ("lr_decay_rate", "lr_decay_rate")])
    params["lr_decay_staircase"] = True
    copy_nv([("batch_size_per_gpu", "batch_size_per_gpu")])
    params["available_gpus"] = get_available_gpus()
    # No reported GPU still means one per-GPU batch.
    gpu_multiplier = len(params["available_gpus"]) or 1
    params["batch_size"] = params["batch_size_per_gpu"] * gpu_multiplier
    copy_nv([("weight_decay", "weight_decay"),
             ("epoch_start", "epoch_start"),
             ("epoch_max", "epoch_end")])
    # Data augmentation
    copy_nv([("equalize_hist_enable", "equalize_hist"),
             ("data_augmentation", "data_augmentation"),
             ("remove_empty_cell", "remove_empty_cell"),
             ("data_split_seed", "data_split_seed")])
    # Inference inputs and outputs
    copy_nv([("inference_input", "inference_input"),
             ("inference_edges_output", "inference_edges_output"),
             ("inference_verts_output", "inference_verts_output")])
    # Validate params
    copy_nv([("validate_batchs", "validate_batchs")])
    # Output params
    make_output_dir("out_model", "./model-data/")
    make_output_dir("out_logs", "./logs")
    params["compact_progress"] = False
    copy_nv([("save_every_epoches", "save_every_epoch")])
    make_output_dir("out_snapshot", "./snapshot")
    for key, value in params.items():
        print("%s %s" % (key, value))
    return params

def init_params():
    """Return the training configuration for the current environment.

    Without YTF_NIRVANA in the environment, this is init_local_params().
    Otherwise the config comes from init_nirvana_params(); then
    snapshot.init("model", params) is called and its result replaces
    params["epoch_start"], and snapshot.init("logs", params) is called too.
    """
    if "YTF_NIRVANA" not in os.environ:
        # Local machine: everything comes from YTF_* environment variables.
        return init_local_params()

    nirvana_params = init_nirvana_params()
    # Imported lazily: the snapshot module is only needed under Nirvana.
    import snapshot
    nirvana_params["epoch_start"] = snapshot.init("model", nirvana_params)
    snapshot.init("logs", nirvana_params)
    return nirvana_params
