from __future__ import division
import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import train_params
from data_utils import Dataset
import os
import print_progress
from models import hed
import tf_utils

def loss_func(logits, labels):
    """Sum a per-image weighted sigmoid cross-entropy over all side outputs.

    For every entry of ``logits`` the per-pixel loss is
    ``-betta * (1 - labels) * log(sigmoid(logit))``
    ``- (1 - betta) * labels * log(1 - sigmoid(logit))``,
    where ``betta`` is, per image, ``sum(labels) / (height * width)``.
    The per-pixel loss is averaged with ``tf.reduce_mean`` and the averages
    of all entries are added together.

    NOTE(review): with this formula sigmoid(logit) is driven towards 1 where
    ``labels`` is 0 and towards 0 where ``labels`` is 1 -- confirm that this
    matches the label polarity produced by the dataset and expected at
    inference time.

    Args:
        logits: sequence of logit tensors, one per model output; each must be
            broadcast-compatible with ``labels``.
        labels: float tensor indexed as [batch, height, width, channels]
            (axes 1 and 2 are used as the image area, axes 1-3 are summed).

    Returns:
        Scalar tensor with the accumulated loss (the integer 0 when
        ``logits`` is empty).
    """
    labels_area = tf.shape(labels)[1] * tf.shape(labels)[2]
    labels_fraction = tf.div(tf.to_float(tf.reduce_sum(labels, axis=[1, 2, 3])),
                             tf.to_float(labels_area))
    # The fraction has shape [batch]. Multiplying it as-is against a
    # [batch, H, W, 1] tensor would broadcast it along the LAST axis and
    # produce a [batch, H, W, batch] result in which every image is weighted
    # by every other image's fraction. Reshape so that each image is scaled
    # only by its own weight.
    betta = tf.reshape(labels_fraction, [-1, 1, 1, 1])
    loss = 0
    for side_logits in logits:
        # -softplus(-x) == log(sigmoid(x)); -softplus(x) == log(1 - sigmoid(x))
        not_edges = -betta * tf.multiply(1 - labels, -tf.nn.softplus(-side_logits))
        edges = -(1 - betta) * tf.multiply(labels, -tf.nn.softplus(side_logits))
        loss += tf.reduce_mean(not_edges + edges)
    return loss

def save_graphdef(sess, epoch, out_folder, outputs):
    """Freeze the default graph and write it as ``model-<epoch>.gdef``.

    The default graph's variables are converted to constants using the values
    held by ``sess``, the resulting GraphDef is serialized into ``out_folder``
    and the number of nodes it contains is printed.

    Args:
        sess: session that holds the current variable values.
        epoch: value substituted into the output file name.
        out_folder: folder the ``.gdef`` file is written to.
        outputs: list of output node names passed to
            ``convert_variables_to_constants``.
    """
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess,
        tf.get_default_graph().as_graph_def(),
        outputs)
    target_path = os.path.join(out_folder, "model-{}.gdef".format(epoch))
    with gfile.GFile(target_path, "wb") as graph_file:
        graph_file.write(frozen_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(frozen_graph_def.node))

def _ensure_dir(path):
    # Create the folder (and missing parents) unless it already exists.
    if not os.path.isdir(path):
        os.makedirs(path)


def _print_params(params, dataset):
    # Print the software, data, optimizer and output settings of this run.
    # Hardware and software params
    print("Tensorflow version: {}".format(tf.__version__))
    print("Available gpus: {}".format(params["available_gpus"]))
    # Data params
    dataset.print_info()
    # Optimizer params
    print("Optimizer: {}".format(params["optimizer"]))
    print("    LR base: {}".format(params["lr_base"]))
    print("    LR decay epoches: {}".format(params["lr_decay_epoch"]))
    print("    LR decay rate: {}".format(params["lr_decay_rate"]))
    print("    Batch size: {}".format(params["batch_size"]))
    print("    Batch per epoch: {}".format(params["train_batch_per_epoch"]))
    # Output folder
    print("Output model folder: {}".format(params["out_model"]))
    print("Output logs folder: {}".format(params["out_logs"]))


def _make_optimizer(params, lr):
    # Build the optimizer named by params["optimizer"]; any other name selects Adam.
    optimizer_name = params["optimizer"]
    if optimizer_name == "GradientDescentOptimizer":
        return tf.train.GradientDescentOptimizer(lr)
    if optimizer_name == "MomentumOptimizer":
        return tf.train.MomentumOptimizer(lr, params["momentum"])
    if optimizer_name == "AdagradOptimizer":
        return tf.train.AdagradOptimizer(lr)
    return tf.train.AdamOptimizer(lr)


def _add_regularization(loss):
    # Return loss plus everything currently in REGULARIZATION_LOSSES (loss itself if empty).
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if reg_losses:
        return loss + tf.add_n(reg_losses)
    return loss


def train(params):
    """Build the training and inference graphs and run the training loop.

    Steps, in order:
      * load the dataset and create the output folders;
      * build the training graph (one tower per entry of
        ``params["available_gpus"]`` when more than one is listed, otherwise
        a single model) and a separate single-image inference graph whose
        output node is named ``params["inference_edges_output"]``;
      * optionally restore weights: from ``params["pretrained_model"]`` when
        ``epoch_start`` is 0 and that folder exists, or from the snapshot of
        epoch ``epoch_start - 1`` when ``epoch_start`` is positive;
      * for every epoch in ``[epoch_start, epoch_max]``: train on all training
        batches, evaluate on the validation batches, save a checkpoint plus a
        frozen GraphDef every ``save_every_epoches`` epochs (never for epoch
        0) and always overwrite the single rolling snapshot.

    Args:
        params: dict of settings. Keys read here: ``out_model``, ``out_logs``,
            ``out_snapshot``, ``batch_size``, ``available_gpus``,
            ``optimizer``, ``momentum`` (Momentum optimizer only),
            ``lr_base``, ``lr_decay_epoch``, ``lr_decay_rate``,
            ``lr_decay_staircase``, ``chns``, ``inference_input``,
            ``inference_edges_output``, ``equalize_hist_enable``,
            ``epoch_start``, ``epoch_max``, ``pretrained_model``,
            ``compact_progress`` and ``save_every_epoches``. The key
            ``train_batch_per_epoch`` is written into the dict. The dict is
            also handed to ``Dataset`` and ``hed.model_l5``.

    Raises:
        Exception: if the training set holds fewer images than one batch, or
            if the pretrained folder does not contain exactly one ``.index``
            file.
    """
    # Load data
    dataset = Dataset(params)
    # Create output folders. The snapshot folder is needed as well because a
    # snapshot checkpoint is written into it after every epoch.
    _ensure_dir(params["out_model"])
    _ensure_dir(params["out_logs"])
    _ensure_dir(params["out_snapshot"])
    # Batch per epoch
    train_batch_per_epoch = dataset.train_cnt() // params["batch_size"]
    params["train_batch_per_epoch"] = train_batch_per_epoch
    # Fail before any graph is built: zero batches would also mean zero decay
    # steps for the learning-rate schedule below.
    if train_batch_per_epoch == 0:
        raise Exception("Too few images in data for this batch size and GPU devices count")
    # Print params
    _print_params(params, dataset)

    # Prepare training operation
    images_ph = tf.placeholder(tf.uint8, shape=[None, None, None, params["chns"]])
    edges_ph = tf.placeholder(tf.uint8, shape=[None, None, None, 1])

    bool_false = tf.constant(False, tf.bool)
    is_training = tf.placeholder_with_default(bool_false, shape=[], name="is_training")

    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(params["lr_base"],
                                    global_step,
                                    params["lr_decay_epoch"] * params["train_batch_per_epoch"],
                                    params["lr_decay_rate"],
                                    staircase=params["lr_decay_staircase"])

    optimizer = _make_optimizer(params, lr)

    gpus_count = len(params["available_gpus"])
    if 1 < gpus_count:
        # Multi-GPU training: the fed batch is split evenly between the
        # devices and the per-device gradients are averaged.
        gradients = []
        all_loss = []
        all_loss_reg = []
        images_ph_per_gpu = tf.split(images_ph, gpus_count)
        edges_ph_per_gpu = tf.split(edges_ph, gpus_count)
        for i, gpu_name in enumerate(params["available_gpus"]):
            with tf.device(gpu_name):
                # Inputs: scale the uint8 images to [-1, 1]
                images_phf = 2.0 * ((tf.to_float(images_ph_per_gpu[i]) / 255.) - 0.5)
                edges_phf = tf.to_float(edges_ph_per_gpu[i])
                # Models
                tower_logits = hed.model_l5(images_phf, params, is_training)
                # Losses
                tower_loss = loss_func(tower_logits, edges_phf)
                tower_loss_reg = _add_regularization(tower_loss)
                all_loss.append(tower_loss)
                all_loss_reg.append(tower_loss_reg)
                # Gradients
                gradients.append(optimizer.compute_gradients(tower_loss_reg))
        # Losses
        loss = tf.reduce_mean(all_loss)
        loss_reg = tf.reduce_mean(all_loss_reg)
        # Train step
        avg_gradients = tf_utils.average_gradients(gradients)
        train_step = optimizer.apply_gradients(avg_gradients, global_step=global_step)
    else:
        # Inputs: scale the uint8 images to [-1, 1]
        images_phf = 2.0 * ((tf.to_float(images_ph) / 255.) - 0.5)
        edges_phf = tf.to_float(edges_ph)
        # Model
        logits = hed.model_l5(images_phf, params, is_training)
        # Losses
        loss = loss_func(logits, edges_phf)
        loss_reg = _add_regularization(loss)
        # Training step
        train_step = optimizer.minimize(loss_reg, global_step=global_step)

    # Prepare inference operation
    # Inputs
    inference_ph = tf.placeholder(tf.uint8, shape=[1, None, None, params["chns"]], name=params["inference_input"])
    inference_phf = tf.to_float(inference_ph) / 255.
    if params["equalize_hist_enable"]:
        inference_phf = tf_utils.equalize_histogram_bgr(inference_phf)
    inference_phf = 2.0 * (inference_phf - 0.5)
    # Model
    logits = hed.model_l5(inference_phf, params, is_training)
    # Outputs: this op is never fetched here; it is created so that the named
    # output node exists in the graph frozen by save_graphdef.
    inference_edges = tf.sigmoid(logits[-1], name=params["inference_edges_output"])
    # Tensorboard
    tf.summary.scalar('learning rate', lr)
    tf.summary.scalar('loss', loss)
    tf.summary.scalar('loss_reg', loss_reg)
    summary_merged = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(params["out_logs"], tf.get_default_graph())
    # Train
    valid_batch_per_epoch = dataset.validate_cnt() // params["batch_size"]
    saver = tf.train.Saver(max_to_keep=params["epoch_max"] + 1, allow_empty=True)
    snapshot_saver = tf.train.Saver(max_to_keep=1, allow_empty=True)
    try:
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            if 0 == params["epoch_start"] and os.path.isdir(params["pretrained_model"]):
                # A checkpoint prefix is the index file name without its
                # ".index" suffix; the prefix itself may contain dots.
                index_suffix = ".index"
                models = [name[:-len(index_suffix)]
                          for name in os.listdir(params["pretrained_model"])
                          if name.endswith(index_suffix)]
                if len(models) != 1:
                    raise Exception("Incorrect number of available pretrained models: {}".format(len(models)))
                print("Load pretrained model: {}".format(models[0]))
                model_path = os.path.join(params["pretrained_model"], models[0])
                # Only the trainable variables are taken from the pretrained model.
                restore_saver = tf.train.Saver(tf.trainable_variables())
                restore_saver.restore(sess, model_path)
                sess.run(tf.assign(global_step, 0))
            if 0 < params["epoch_start"]:
                # Resume from the snapshot written at the end of the previous epoch.
                model_path = os.path.join(params["out_snapshot"], 'model-%d' % (params["epoch_start"] - 1))
                saver.restore(sess, model_path)
            # From here on no new ops may be added to the graph.
            sess.graph.finalize()

            for epoch in range(params["epoch_start"], params["epoch_max"] + 1):
                epoch_string = "Epoch: {}".format(epoch)
                if params["compact_progress"]:
                    print_progress.printProgressBar(0, train_batch_per_epoch, prefix=epoch_string, suffix='')
                else:
                    print(epoch_string)
                epoch_loss = 0
                epoch_loss_reg = 0
                dataset.shuffle_train_data()
                for batch_indx in range(train_batch_per_epoch):
                    batch = dataset.get_batch(batch_indx, params["batch_size"], is_training=True)
                    feed_dict = {images_ph: batch["image"], edges_ph: batch["edges"], is_training: True}
                    _, batch_loss, batch_loss_reg, summary_ = sess.run(
                        [train_step, loss, loss_reg, summary_merged], feed_dict=feed_dict)
                    # Read the step counter once per batch, after the update.
                    step = sess.run(global_step)
                    if params["compact_progress"]:
                        print_progress.printProgressBar(
                            batch_indx + 1, train_batch_per_epoch, prefix=epoch_string,
                            suffix="loss = {} ({})".format(batch_loss, batch_loss_reg))
                    elif 0 == (batch_indx % 25):
                        print("{}. loss = {} ({})".format(step, batch_loss, batch_loss_reg))
                    epoch_loss += batch_loss
                    epoch_loss_reg += batch_loss_reg
                    summary_writer.add_summary(summary_, step)

                print("epoch average loss = {} ({}), lr = {}".format(
                    epoch_loss / train_batch_per_epoch,
                    epoch_loss_reg / train_batch_per_epoch,
                    sess.run(lr)))

                if 0 < valid_batch_per_epoch:
                    epoch_loss = 0
                    epoch_loss_reg = 0
                    for batch_indx in range(valid_batch_per_epoch):
                        batch = dataset.get_batch(batch_indx, params["batch_size"], is_training=False)
                        feed_dict = {images_ph: batch["image"], edges_ph: batch["edges"], is_training: False}
                        batch_loss, batch_loss_reg = sess.run([loss, loss_reg], feed_dict=feed_dict)
                        epoch_loss += batch_loss
                        epoch_loss_reg += batch_loss_reg

                    print("epoch validation loss = {} ({})".format(
                        epoch_loss / valid_batch_per_epoch,
                        epoch_loss_reg / valid_batch_per_epoch))
                else:
                    # Fewer validation images than one batch: there is nothing
                    # to average, so do not divide by zero.
                    print("epoch validation skipped: too few validation images for this batch size")

                if 0 == (epoch % params["save_every_epoches"]) and (0 != epoch):
                    model_path = os.path.join(params["out_model"], 'model')
                    saver.save(sess, model_path, global_step=epoch)
                    save_graphdef(sess, epoch, params["out_model"], [params["inference_edges_output"]])
                # Save model for snapshot
                model_path = os.path.join(params["out_snapshot"], 'model')
                snapshot_saver.save(sess, model_path, global_step=epoch)
    finally:
        # Flush pending summaries and release the event file even on failure.
        summary_writer.close()
    print("Training completed")

def main():
    """Initialize the training parameters and start training with them."""
    train(train_params.init_params())

# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
