import tensorflow as tf
from tensorflow.python.client import device_lib


def loss_func(logits, labels):
    """Build the training loss tensor.

    The loss is the mean of the sparse softmax cross-entropy between
    ``logits`` and ``labels``. If the graph's ``REGULARIZATION_LOSSES``
    collection is non-empty at call time, the sum of its entries is added.

    Args:
        logits: Passed as ``logits`` to
            ``tf.nn.sparse_softmax_cross_entropy_with_logits``.
        labels: Passed as ``labels`` to the same op.

    Returns:
        The loss tensor.
    """
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels))
    # Read the collection once; it reflects whatever has been registered in
    # the graph up to this point.
    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if regularization_losses:
        loss += tf.add_n(regularization_losses)
    return loss


def create_optimizer(train_dataset_size, train_config):
    """Create the global step variable and an optimizer with a decaying learning rate.

    The learning rate comes from ``tf.train.exponential_decay`` in staircase
    mode. Its decay interval, in steps, is
    ``decay_epoches * train_dataset_size // batch_size``.

    Args:
        train_dataset_size: Number used together with ``batch_size`` to turn
            the decay interval from epochs into steps.
        train_config: Mapping that provides ``["lr"]["init"]``,
            ``["lr"]["decay_epoches"]``, ``["lr"]["decay_rate"]``,
            ``["batch_size"]`` and ``["optimizer"]["name"]``
            (``"momentum"`` or ``"adam"``).

    Returns:
        A ``(global_step, optimizer)`` tuple.

    Raises:
        ValueError: If ``train_config["optimizer"]["name"]`` is not a
            supported optimizer name.
    """
    optimizer_name = train_config["optimizer"]["name"]
    # Fail fast, before any variable is added to the graph, instead of
    # handing back a None optimizer that only breaks later at the call site.
    if optimizer_name not in ("momentum", "adam"):
        raise ValueError(
            "Unsupported optimizer name: {!r} (expected 'momentum' or 'adam')".format(
                optimizer_name))

    lr_config = train_config["lr"]
    global_step = tf.Variable(0, trainable=False, name="global_step")
    decay_steps = (
        lr_config["decay_epoches"] * train_dataset_size // train_config["batch_size"])
    lr = tf.train.exponential_decay(
        lr_config["init"],
        global_step,
        decay_steps,
        lr_config["decay_rate"],
        staircase=True,
        name="lr"
    )

    if optimizer_name == "momentum":
        optimizer = tf.train.MomentumOptimizer(lr, momentum=0.9)
    else:
        optimizer = tf.train.AdamOptimizer(lr)
    return global_step, optimizer


def get_available_gpus():
    """Return the names of the local devices whose ``device_type`` is ``'GPU'``.

    Devices are taken from ``device_lib.list_local_devices()`` and kept in
    the order that call reports them.
    """
    gpu_names = []
    for local_device in device_lib.list_local_devices():
        if local_device.device_type == 'GPU':
            gpu_names.append(local_device.name)
    return gpu_names


def average_gradients(tower_grads):
    """Average per-tower gradients position by position.

    Args:
        tower_grads: A list with one entry per tower; each entry is a list of
            ``(gradient, variable)`` pairs. Pairs at the same position across
            towers are combined, so every tower is expected to list its
            variables in the same order.

    Returns:
        A list of ``(averaged_gradient, variable)`` pairs, one per position.
        The variable is taken from the first tower. Gradients that are
        ``None`` are left out of the average; if every tower's gradient at a
        position is ``None``, the pair is ``(None, variable)``.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # The variable at this position is taken from the first tower.
        v = grad_and_vars[0][1]

        # A tower may report None for a variable its loss does not depend on;
        # such entries cannot be stacked, so drop them before averaging.
        grads = [g for g, _ in grad_and_vars if g is not None]
        if not grads:
            average_grads.append((None, v))
            continue

        # Stack along a new leading "tower" axis, then average over it.
        grad = tf.reduce_mean(tf.stack(grads, axis=0), 0)
        average_grads.append((grad, v))
    return average_grads


def train_op(net, train_images, train_labels, train_dataset_size, train_config, is_training):
    """Build the loss tensor and the training op.

    With at most one GPU reported by ``get_available_gpus()``, the network is
    built once and the optimizer minimizes its loss directly. With more than
    one, ``train_images`` and ``train_labels`` are divided with ``tf.split``
    into one shard per GPU, the network is built under each GPU's device
    scope, and the per-tower gradients are merged by ``average_gradients``
    before being applied.

    In both cases the op runs under a control dependency on everything in the
    ``UPDATE_OPS`` collection.

    Args:
        net: Callable invoked as ``net(images, train_config, is_training)``
            that returns logits.
        train_images: Input tensor handed to ``net`` (or split across GPUs).
        train_labels: Label tensor handed to ``loss_func`` (or split across GPUs).
        train_dataset_size: Forwarded to ``create_optimizer``.
        train_config: Forwarded to ``create_optimizer`` and to ``net``.
        is_training: Forwarded to ``net``.

    Returns:
        A ``(loss, op)`` tuple; in the multi-GPU case ``loss`` is the mean of
        the per-tower losses.
    """
    global_step, optimizer = create_optimizer(train_dataset_size, train_config)
    gpu_names = get_available_gpus()
    num_gpus = len(gpu_names)

    # Zero or one GPU: build a single copy of the network with no explicit
    # device placement.
    if num_gpus <= 1:
        logits = net(train_images, train_config, is_training)
        total_loss = loss_func(logits, train_labels)
        pending_updates = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
        with tf.control_dependencies([pending_updates]):
            step_op = optimizer.minimize(total_loss, global_step=global_step)
        return total_loss, step_op

    # Several GPUs: one tower per device, each fed its own shard.
    # NOTE(review): average_gradients pairs gradients by position, so this
    # presumably relies on `net` reusing the same variables on every call --
    # confirm against the network definition.
    image_shards = tf.split(train_images, num_gpus)
    label_shards = tf.split(train_labels, num_gpus)
    tower_losses = []
    tower_gradients = []
    for gpu_name, image_shard, label_shard in zip(gpu_names, image_shards, label_shards):
        with tf.device(gpu_name):
            tower_logits = net(image_shard, train_config, is_training)
            tower_loss = loss_func(tower_logits, label_shard)
            tower_losses.append(tower_loss)
            tower_gradients.append(optimizer.compute_gradients(tower_loss))

    mean_loss = tf.reduce_mean(tower_losses)
    merged_gradients = average_gradients(tower_gradients)
    pending_updates = tf.group(*tf.get_collection(tf.GraphKeys.UPDATE_OPS))
    with tf.control_dependencies([pending_updates]):
        step_op = optimizer.apply_gradients(merged_gradients, global_step=global_step)
    return mean_loss, step_op
