import argparse
import tensorflow as tf
from maps.wikimap.mapspro.services.autocart.tools.auto_toloker.train.utils import config
from maps.wikimap.mapspro.services.autocart.tools.auto_toloker.train.utils import dataset
from maps.wikimap.mapspro.services.autocart.tools.auto_toloker.train.utils import snapshot
from maps.wikimap.mapspro.services.autocart.tools.auto_toloker.train.nets import train_utils
from maps.wikimap.mapspro.services.autocart.tools.auto_toloker.train.nets import nets_factory


def average_loss(losses):
    """Return the arithmetic mean of the given loss values.

    Args:
        losses: sized sequence of numeric per-step loss values.

    Returns:
        The mean as a float, or NaN when ``losses`` is empty (for example
        when an epoch produced no batches), so that reporting the epoch
        summary does not fail with ZeroDivisionError.
    """
    count = len(losses)
    if count == 0:
        return float("nan")
    return sum(losses) / float(count)


def train(train_config):
    """Build the TensorFlow graph and run the train / save / validate loop.

    The graph has three parts that all call the same ``net`` function:
    training ops, validation ops, and placeholder-fed inference ops whose
    outputs are created under the fixed names ``inference_argmax``,
    ``inference_softmax`` and ``class_names``.

    Epochs run from ``train_config["epoch_start"]`` to
    ``train_config["epoch_end"]`` inclusive. When ``epoch_start`` is greater
    than zero, the snapshot of epoch ``epoch_start - 1`` is restored first.

    Args:
        train_config: mapping with the training settings. Keys read directly
            here: ``["dataset"]["train"]``, ``["dataset"]["valid"]``,
            ``["net"]``, ``["image"]["height"]``, ``["image"]["width"]``,
            ``["epoch_start"]``, ``["epoch_end"]``, ``["save_every_epoches"]``,
            ``["valid_every_epoches"]``, ``["model_path"]`` and
            ``["gdef_path"]``. The mapping is also passed as is to the
            dataset, network and train-op helpers.

    Returns:
        None. Progress is reported with ``print``; snapshots are written via
        the ``snapshot`` module.
    """
    print("Creating train and valid datasets")

    train_dataset = dataset.create_train_dataset(train_config)
    train_images, train_labels, train_init_op = dataset.get_data(train_dataset)
    train_dataset_size = dataset.get_elements_number(train_config["dataset"]["train"])
    print("Train dataset size: {}".format(train_dataset_size))

    valid_dataset = dataset.create_valid_dataset(train_config)
    valid_images, valid_labels, valid_init_op = dataset.get_data(valid_dataset)
    valid_dataset_size = dataset.get_elements_number(train_config["dataset"]["valid"])
    print("Valid dataset size: {}".format(valid_dataset_size))

    print("Creating neural network: {}".format(train_config["net"]))
    net = nets_factory.get_network_fn(train_config)

    print("Prepare training operations")
    # Defaults to False, so only the training step has to feed it explicitly.
    is_training = tf.placeholder_with_default(False, shape=[], name='is_training')
    loss, train_op = train_utils.train_op(
        net,
        train_images, train_labels,
        train_dataset_size,
        train_config,
        is_training
    )
    init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())

    print("Prepare validation operations")
    valid_logits = net(valid_images, train_config, is_training)
    valid_argmax = tf.argmax(valid_logits, axis=-1)
    # One boolean per validation sample: predicted class == label.
    valid_is_correct_op = tf.equal(tf.to_int32(valid_argmax), valid_labels)

    print("Prepate inference operations")
    image_ph = tf.placeholder(tf.uint8, shape=[None, None, None, 3], name="inference_image")
    mask_ph = tf.placeholder(tf.uint8, shape=[None, None, None, 1], name="inference_mask")
    image_ph = tf.reverse(image_ph, [-1])  # BGR->RGB
    inputs_ph = tf.concat([image_ph, mask_ph], axis=-1)
    inputs_phf = tf.to_float(inputs_ph)
    inputs_phf = tf.image.resize_images(
        inputs_phf,
        [train_config["image"]["height"], train_config["image"]["width"]]
    )
    inputs_phf = dataset.preprocess(inputs_phf)
    inference_logits = net(inputs_phf, train_config, False)
    # The results are not kept in variables: these ops are created only so
    # that they exist in the graph under the given names.
    tf.argmax(inference_logits, axis=-1, name="inference_argmax")
    tf.nn.softmax(inference_logits, name="inference_softmax")
    tf.constant(dataset.get_class_names(), tf.string, name="class_names")

    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        if train_config["epoch_start"] > 0:
            # Resume from the snapshot of the previous epoch.
            snapshot.restore(sess, saver, train_config["epoch_start"] - 1, train_config["model_path"])
        # Make the graph read-only so the loops below cannot add ops to it.
        sess.graph.finalize()

        for epoch in range(train_config["epoch_start"], train_config["epoch_end"] + 1):
            print("Epoch: {}".format(epoch))
            sess.run(train_init_op)
            train_losses = []
            step = 0
            # Run training steps until the dataset iterator is exhausted.
            while True:
                try:
                    _, batch_loss = sess.run([train_op, loss], feed_dict={is_training: True})
                    train_losses.append(batch_loss)
                    step += 1
                    print("Step: {}, loss: {}".format(step, batch_loss))
                except tf.errors.OutOfRangeError:
                    break
            print("Epoch {} ended, average loss: {}".format(epoch, average_loss(train_losses)))
            # The graph definition is written every `save_every_epoches`
            # epochs, the model snapshot after every epoch.
            # NOTE(review): presumably the per-epoch snapshot is what resuming
            # relies on - confirm that it is meant to be outside the `if`.
            if epoch % train_config["save_every_epoches"] == 0:
                snapshot.save_gdef(sess, epoch, train_config["gdef_path"])
            snapshot.save_model(sess, saver, epoch, train_config["model_path"])
            if epoch % train_config["valid_every_epoches"] == 0:
                sess.run(valid_init_op)
                valid_correct_cnt = 0
                while True:
                    try:
                        # Fetch the tensor itself (not a one-element list) so
                        # the result is the boolean array of this batch.
                        is_correct = sess.run(valid_is_correct_op, feed_dict={is_training: False})
                    except tf.errors.OutOfRangeError:
                        break
                    # Count the True entries; keeps the counter a plain int
                    # regardless of the batch size.
                    valid_correct_cnt += int(is_correct.sum())
                print("Evaluation ended, accuracy: {}".format(valid_correct_cnt / float(valid_dataset_size)))


def main():
    """Command-line entry point of the training tool.

    Parses ``--train_config`` (path to the train config) and
    ``--environment`` (``local`` or ``nirvana``), loads the config,
    initializes the snapshot module, stores the epoch to start from in
    ``epoch_start``, prints the resulting config and starts training.
    """
    arg_parser = argparse.ArgumentParser("Tool for train neural network")
    arg_parser.add_argument("--train_config", required=True, help="Path to train config")
    arg_parser.add_argument("--environment", required=True, choices=["local", "nirvana"])
    cli = arg_parser.parse_args()

    config_path = cli.train_config
    print("Loading train configuration from json file: {}".format(config_path))
    settings = config.load_train_config(config_path)

    print("Initializing snapshot")
    snapshot.init(settings, cli.environment)
    # The first epoch to run is decided by the snapshot module.
    first_epoch = snapshot.next_epoch(settings)
    settings["epoch_start"] = first_epoch

    config.print_train_config(settings)
    train(settings)


# Run the training tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
