import os
#os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import numpy as np
import conv_lstm

import ckpt_snapshot
import nirvana.job_context as nv

def _tfrecords_count(filepath):
    """Return the number of records stored in the TFRecord file `filepath`."""
    # Only the number of iterations matters; the record payloads are discarded.
    return sum(1 for _ in tf.python_io.tf_record_iterator(filepath))

def _image_resize(image, dst_height):
    """Resize `image` to `dst_height` rows while keeping its aspect ratio.

    Args:
        image: image tensor whose first two axes are height and width
            (a single image without a batch axis).
        dst_height: target height in pixels.

    Returns:
        The area-resized image of size [dst_height, dst_width], where
        dst_width is round(dst_height * width / height).
    """
    shape = tf.shape(image)
    src_height = tf.to_float(shape[0])
    src_width = tf.to_float(shape[1])

    # Width that preserves the aspect ratio, rounded to the nearest pixel.
    scaled_width = tf.round(dst_height * src_width / src_height)
    dst_width = tf.to_int32(scaled_width)
    target_size = tf.stack([dst_height, dst_width])

    # resize_area works on batches: add a batch axis of one and drop it again.
    batched = tf.expand_dims(image, axis=0)
    resized = tf.image.resize_area(batched, target_size)
    return tf.squeeze(resized, axis=[0])

def _image_to_sequence(image, seq_length, out_size):
    """Cut `image` into a sequence of `seq_length` crops taken left to right.

    Crop number i spans the full image height and the columns
    [x, x + out_size), where x = int(i * step) and
    step = (image_width - out_size) / (seq_length - 1). Each crop is then
    resized to out_size x out_size.

    NOTE: the step divides by (seq_length - 1), so seq_length == 1 is not
    supported.

    Args:
        image: float image tensor [height, width, 3].
        seq_length: number of crops to produce.
        out_size: width of the crop window and side of every output crop.

    Returns:
        Tensor reshaped to [seq_length, out_size, out_size, 3].
    """
    image_height = tf.shape(image)[0]  # not used below; kept so the built graph stays the same
    image_width = tf.shape(image)[1]

    stride = tf.to_float((image_width - out_size) / (seq_length - 1))

    def _append_crop(index, crops):
        left = tf.to_int64(tf.multiply(stride, tf.to_float(index)))
        window = image[:, left : left + out_size, :]
        # The resize gives every crop the same static spatial size
        # [out_size, out_size]; it changes nothing when the window already
        # has that size.
        crop = tf.image.resize_area(tf.expand_dims(window, axis=0),
                                    [out_size, out_size])
        grown = tf.cond(tf.equal(index, 0),
                        lambda: crop,
                        lambda: tf.concat([crops, crop], axis=0))
        return [index + 1, grown]

    def _has_more(index, crops):
        return index < seq_length

    counter = tf.constant(0)
    no_crops = tf.zeros([0, out_size, out_size, 3], dtype=tf.float32)
    # The accumulator grows along axis 0 on every iteration, hence the None
    # in its shape invariant.
    loop_result = tf.while_loop(
        _has_more, _append_crop,
        loop_vars=[counter, no_crops],
        shape_invariants=[counter.get_shape(),
                          tf.TensorShape([None, out_size, out_size, 3])])

    # The loop leaves the leading dimension unknown (None), but it always
    # equals seq_length, so pin it with a reshape.
    return tf.reshape(loop_result[1], [seq_length, out_size, out_size, 3])

def _number_label_to_digits_sequence(number, seq_length, seq_invert, pad_begin):
    """Turn an integer label into a fixed-length sequence of decimal digits.

    At most `seq_length` least-significant digits of `number` are extracted;
    the remaining positions are filled with the padding value 10. Extraction
    stops as soon as the remaining value is not positive, so a `number` of 0
    produces a sequence made only of padding.

    Args:
        number: scalar int32 tensor holding the label.
        seq_length: length of the returned sequence.
        seq_invert: if true, digits are ordered least-significant first;
            otherwise most-significant first.
        pad_begin: if true, padding is placed before the digits, otherwise
            after them.

    Returns:
        int32 tensor of shape [seq_length].
    """
    pad_value = 10

    def _extract_digit(step, remaining, digits):
        digit = tf.stack([tf.mod(remaining, 10)])
        # Digits come out least-significant first; appending keeps that
        # order, prepending reverses it.
        pieces = [digits, digit] if seq_invert else [digit, digits]
        collected = tf.cond(tf.equal(step, 0),
                            lambda: digit,
                            lambda: tf.concat(pieces, axis=0))
        return [step + 1, tf.div(remaining, 10), collected]

    def _keep_going(step, remaining, digits):
        return tf.logical_and(step < seq_length, 0 < remaining)

    counter = tf.constant(0)
    no_digits = tf.zeros([0], dtype=tf.int32)
    digits = tf.while_loop(
        _keep_going, _extract_digit,
        loop_vars=[counter, number, no_digits],
        shape_invariants=[counter.get_shape(), number.get_shape(),
                          tf.TensorShape(None)])[2]

    padding = tf.ones([seq_length - tf.shape(digits)[0]], tf.int32) * pad_value
    # The reshape gives the result a static shape of [seq_length].
    ordered = [padding, digits] if pad_begin else [digits, padding]
    return tf.reshape(tf.concat(ordered, axis=0), [seq_length])

def _make_dataset(data_paths, image_seq_length, digit_seq_length, out_size, batch_size, prefetch, digit_seq_invert = False, crop_bbox = 0, digit_seq_pad_begin = True, shuffle = True):
    """Build a batched tf.data pipeline of (image sequence, digit sequence) pairs.

    Every record is decoded from JPEG, scaled to [-1, 1], optionally cropped
    around its bounding box, resized to `out_size` rows, cut into
    `image_seq_length` square crops and paired with the digits of its label.

    Args:
        data_paths: TFRecord file path(s) passed to tf.data.TFRecordDataset.
        image_seq_length: number of crops per image.
        digit_seq_length: length of the digit label sequence.
        out_size: image height after resizing and side of every crop.
        batch_size: batch size; also used as the number of parallel map calls.
        prefetch: prefetch buffer size; also used as the shuffle buffer size.
        digit_seq_invert: forwarded to _number_label_to_digits_sequence.
        crop_bbox: cropping mode, with bbox coordinates read as fractions of
            the image size:
            1 - every crop edge is drawn uniformly between the image border
                and the matching bbox edge;
            2 - crop exactly at the bbox;
            3 - every crop edge is drawn uniformly between the bbox edge and
                the bbox edge moved outwards by 10% of the bbox size
                (clipped to the image);
            any other value - no cropping.
            In all modes the crop is one pixel smaller than the difference
            of the pixel coordinates in both dimensions.
        digit_seq_pad_begin: forwarded to _number_label_to_digits_sequence.
        shuffle: whether to shuffle the records.

    Returns:
        The batched dataset.
    """
    feature_spec = {
        'image/filename': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
        'image/object/bbox/xmin': tf.FixedLenFeature((), tf.float32, default_value=0.),
        'image/object/bbox/xmax': tf.FixedLenFeature((), tf.float32, default_value=1.),
        'image/object/bbox/ymin': tf.FixedLenFeature((), tf.float32, default_value=0.),
        'image/object/bbox/ymax': tf.FixedLenFeature((), tf.float32, default_value=1.),
        'image/object/class/label': tf.FixedLenFeature((), tf.int64, default_value=0),
    }
    GAP_PERCENT = 0.1

    def _to_pixels(fraction, image, axis):
        # Convert a relative coordinate into a pixel index along `axis`.
        return tf.to_int32(fraction * tf.to_float(tf.shape(image)[axis]))

    def _crop(image, xmin, xmax, ymin, ymax):
        return tf.image.crop_to_bounding_box(image, ymin, xmin,
                                             ymax - ymin - 1, xmax - xmin - 1)

    def _parse_function(example_proto):
        parsed = tf.parse_single_example(example_proto, feature_spec)
        label = tf.to_int32(parsed['image/object/class/label'])
        image = tf.image.decode_jpeg(parsed['image/encoded'], channels=3)
        # Map pixel values from [0, 255] to [-1, 1].
        image = tf.div(tf.to_float(image), 255.0)
        image = tf.subtract(image, 0.5)
        image = tf.multiply(image, 2.0)

        box_xmin = parsed['image/object/bbox/xmin']
        box_xmax = parsed['image/object/bbox/xmax']
        box_ymin = parsed['image/object/bbox/ymin']
        box_ymax = parsed['image/object/bbox/ymax']

        print("crop_bbox type: {}".format(crop_bbox))
        if crop_bbox == 1:
            xmin = _to_pixels(tf.random_uniform([], 0., box_xmin), image, 1)
            xmax = _to_pixels(tf.random_uniform([], box_xmax, 1.), image, 1)
            ymin = _to_pixels(tf.random_uniform([], 0., box_ymin), image, 0)
            ymax = _to_pixels(tf.random_uniform([], box_ymax, 1.), image, 0)
            image = _crop(image, xmin, xmax, ymin, ymax)
        elif crop_bbox == 2:
            xmin = _to_pixels(box_xmin, image, 1)
            xmax = _to_pixels(box_xmax, image, 1)
            ymin = _to_pixels(box_ymin, image, 0)
            ymax = _to_pixels(box_ymax, image, 0)
            image = _crop(image, xmin, xmax, ymin, ymax)
        elif crop_bbox == 3:
            width = box_xmax - box_xmin
            height = box_ymax - box_ymin

            rand_xmin = tf.random_uniform([], tf.maximum(0., box_xmin - GAP_PERCENT * width), box_xmin)
            rand_xmax = tf.random_uniform([], box_xmax, tf.minimum(1., box_xmax + GAP_PERCENT * width))
            rand_ymin = tf.random_uniform([], tf.maximum(0., box_ymin - GAP_PERCENT * height), box_ymin)
            rand_ymax = tf.random_uniform([], box_ymax, tf.minimum(1., box_ymax + GAP_PERCENT * height))

            xmin = _to_pixels(rand_xmin, image, 1)
            xmax = _to_pixels(rand_xmax, image, 1)
            ymin = _to_pixels(rand_ymin, image, 0)
            ymax = _to_pixels(rand_ymax, image, 0)
            image = _crop(image, xmin, xmax, ymin, ymax)

        image = _image_resize(image, out_size)
        # image_seq: [image_seq_length, out_size, out_size, 3]
        image_seq = _image_to_sequence(image, image_seq_length, out_size)
        digit_seq = _number_label_to_digits_sequence(label, digit_seq_length,
                                                     digit_seq_invert, digit_seq_pad_begin)
        return image_seq, digit_seq

    dataset = tf.data.TFRecordDataset(data_paths)
    dataset = dataset.map(_parse_function, num_parallel_calls=batch_size)
    dataset = dataset.prefetch(prefetch)
    if shuffle:
        dataset = dataset.shuffle(prefetch)
    return dataset.batch(batch_size)

def _get_data(dataset):
    """Create an initializable iterator over `dataset`.

    Returns:
        Tuple (first element, second element, iterator), where the elements
        are components 0 and 1 of the iterator's next item. The iterator
        must be initialized before they are evaluated.
    """
    iterator = dataset.make_initializable_iterator()
    batch = iterator.get_next()
    first, second = batch[0], batch[1]
    return first, second, iterator

def _seq_to_list(image_seq, axis = 1):
    """Split `image_seq` along `axis` into a list of tensors without that axis.

    The size of `axis` is read from the static shape, so it must be known
    when the graph is built.
    """
    static_shape = image_seq.get_shape().as_list()
    slices = tf.split(image_seq, static_shape[axis], axis)
    # Every slice still has a size-1 dimension at `axis`; remove it.
    return [tf.squeeze(piece, [axis]) for piece in slices]

def _make_model(image_seq, out_seq_length, kernel_szs, out_chns, strides, weight_decay, bn_params, peephole, output_keep_prob, fc_keep_prob, is_training, fc_additional = 0):
    """Build the encoder/decoder ConvLSTM network and return per-step logits.

    The encoder cell is unrolled over the image sequence with static_rnn;
    its final state seeds the decoder cell, which is stepped
    `out_seq_length` times without input. Each decoder output is flattened
    and passed through the shared fully connected head.

    Args:
        image_seq: tensor whose axis 1 is the sequence axis
            ([batch, seq_len, height, width, channels]).
        out_seq_length: number of decoder steps, i.e. output positions.
        kernel_szs, out_chns, strides: forwarded to
            conv_lstm.MultiConvLSTMCell.
        weight_decay: L2 regularization scale; no regularizer when <= 0.
        bn_params: batch-norm parameters forwarded to the cells and to the
            optional extra fully connected layer.
        peephole: forwarded to the cells as `has_peephole`.
        output_keep_prob: keep probability of the cell outputs while
            training; 1.0 is used when `is_training` is false.
        fc_keep_prob: keep probability of the dropout around the fully
            connected layers.
        is_training: boolean scalar tensor selecting training behaviour.
        fc_additional: width of an extra hidden fully connected layer;
            0 disables it.

    Returns:
        Logits of shape [batch, out_seq_length, DIGIT_CLASSES_COUNT + 1];
        the extra class presumably matches the padding label 10 produced by
        _number_label_to_digits_sequence.
    """
    DIGIT_CLASSES_COUNT = 10

    frames = _seq_to_list(image_seq)
    frame_shape = frames[0].get_shape().as_list()[1:]  # without batch size

    regularizer = None
    if weight_decay > 0.:
        regularizer = tf.contrib.layers.l2_regularizer(weight_decay)

    def _build_cell(scope_name):
        # Dropout on the cell outputs is switched off outside of training.
        keep_prob = tf.cond(is_training, lambda: output_keep_prob, lambda: 1.0)
        return conv_lstm.MultiConvLSTMCell(scope_name, frame_shape, kernel_szs, out_chns, strides,
                                           bn_params=bn_params, has_peephole=peephole,
                                           output_keep_prob=keep_prob,
                                           weight_regularizer=regularizer, bias_regularizer=regularizer,
                                           concat_all_cells_output=False)

    encoder = _build_cell("encoder")
    decoder = _build_cell("decoder")

    _, state = tf.nn.static_rnn(encoder, frames, dtype=tf.float32)

    step_logits = []
    for _ in range(out_seq_length):
        features, state = decoder(None, state)
        features = slim.flatten(features)
        features = slim.dropout(features, keep_prob=fc_keep_prob, is_training=is_training)
        if fc_additional > 0:
            # NOTE(review): no normalizer_fn is passed here, so
            # normalizer_params presumably has no effect - confirm.
            features = slim.fully_connected(features, fc_additional,
                                            normalizer_params=bn_params,
                                            weights_regularizer=regularizer,
                                            weights_initializer=slim.variance_scaling_initializer(),
                                            activation_fn=tf.nn.relu, reuse=tf.AUTO_REUSE, scope="fc_internal")
            features = slim.dropout(features, keep_prob=fc_keep_prob, is_training=is_training)

        # AUTO_REUSE shares the same head weights across all decoder steps.
        scores = slim.fully_connected(features, DIGIT_CLASSES_COUNT + 1,
                                      weights_regularizer=regularizer,
                                      weights_initializer=slim.variance_scaling_initializer(),
                                      activation_fn=None, reuse=tf.AUTO_REUSE, scope="fc")
        step_logits.append(tf.expand_dims(scores, 1))
    return tf.concat(step_logits, 1)

def _loss_func(logits, labels):
    """Return mean sparse softmax cross-entropy plus the regularization losses.

    The regularization term is the sum of everything currently in the
    REGULARIZATION_LOSSES collection; it is skipped when the collection is
    empty.
    """
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    total = tf.reduce_mean(cross_entropy)
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if reg_losses:
        total += tf.add_n(reg_losses)
    return total

def _create_optimizer(optimizer_name, lr, weight_decay, momentum = 0.9):
    """Create the optimizer selected by `optimizer_name`.

    Args:
        optimizer_name: 'momentum', 'adam' or 'adamw'. Any other value falls
            back to Adam.
        lr: learning rate (number or tensor).
        weight_decay: decay coefficient, used by 'adamw' only, where the
            decoupled decay passed to the optimizer is weight_decay * lr.
        momentum: momentum coefficient, used by 'momentum' only.

    Returns:
        The optimizer instance.
    """
    if optimizer_name == 'momentum':
        # MomentumOptimizer has no default for `momentum`; calling it with
        # the learning rate alone raises TypeError.
        return tf.train.MomentumOptimizer(lr, momentum)
    if optimizer_name == 'adamw':
        return tf.contrib.opt.AdamWOptimizer(weight_decay * lr, lr)
    # 'adam' and every unrecognised name.
    return tf.train.AdamOptimizer(lr)

def _save_gdef(sess, out_graph_path):
    """Freeze the default graph and write it to `out_graph_path`.

    Variables are converted to constants, keeping the nodes needed to
    compute 'inference_softmax'. Afterwards RefSwitch/AssignSub/AssignAdd
    nodes are rewritten to Switch/Sub/Add - presumably because these ops
    need real variables, which no longer exist in the frozen graph (TODO
    confirm).

    Args:
        sess: session holding the current variable values.
        out_graph_path: destination file for the serialized GraphDef.
    """
    print("Save gdef: {}".format(out_graph_path))

    frozen = graph_util.convert_variables_to_constants(
        sess, tf.get_default_graph().as_graph_def(), ['inference_softmax'])

    assign_to_plain = {'AssignSub': 'Sub', 'AssignAdd': 'Add'}
    for node in frozen.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            # Inputs whose name contains 'moving_' are redirected to their
            # '/read' output.
            for idx, input_name in enumerate(node.input):
                if 'moving_' in input_name:
                    node.input[idx] = input_name + '/read'
        elif node.op in assign_to_plain:
            node.op = assign_to_plain[node.op]
            if 'use_locking' in node.attr:
                del node.attr['use_locking']

    with gfile.GFile(out_graph_path, "wb") as out_file:
        out_file.write(frozen.SerializeToString())

def train(train_data_paths, train_dir,
         image_seq_length, image_size, batch_size, crop_bbox,
         digit_seq_length, digit_seq_invert, digit_seq_pad_begin,
         kernel_szs, out_chns, strides,
         lr_init, lr_decay_factor, lr_decay_epoches,
         epoch_start, epoch_end,
         optimizer_name = "adam",
         peephole = True,
         output_keep_prob = 0.5,
         fc_keep_prob = 0.5,
         fc_additional = 0,
         weight_decay = 0.001,
         bn_decay = 0.999, bn_epsilon = 0.001,
         save_after_epoch = 20, save_every_epoch = 10):
    """Train the digit-sequence model and periodically export a frozen graph.

    Builds the training graph on top of the TFRecord pipeline and a separate
    single-image inference graph (placeholder 'inference_input', output
    'inference_softmax'), then runs epochs epoch_start .. epoch_end - 1. A
    checkpoint is written after every epoch.

    Args:
        train_data_paths: iterable of TFRecord file paths.
        train_dir: directory for checkpoints and temporary exported graphs.
        image_seq_length: number of crops every image is cut into.
        image_size: image height after resizing and side of every crop.
        batch_size: training batch size.
        crop_bbox: cropping mode, see _make_dataset.
        digit_seq_length: length of the label / output digit sequence.
        digit_seq_invert, digit_seq_pad_begin: label layout options, see
            _number_label_to_digits_sequence.
        kernel_szs, out_chns, strides: ConvLSTM configuration forwarded to
            _make_model.
        lr_init: initial learning rate.
        lr_decay_factor: factor the learning rate is multiplied by at every
            decay step (staircase schedule).
        lr_decay_epoches: the decay period is
            lr_decay_epoches * record_count // batch_size global steps.
        epoch_start: first epoch to run; when > 0 the checkpoint
            'model.ckpt-<epoch_start - 1>' is restored from train_dir first.
        epoch_end: one past the last epoch to run.
        optimizer_name: optimizer selector, see _create_optimizer.
        peephole, output_keep_prob, fc_keep_prob, fc_additional,
            weight_decay: model options forwarded to _make_model.
        bn_decay, bn_epsilon: batch-norm parameters.
        save_after_epoch: first epoch at which a frozen graph may be
            exported.
        save_every_epoch: minimal number of epochs between two exports. The
            last epoch is always exported.
    """
    nv_ctx = nv.context()

    # Total number of training records; needed for the LR decay period.
    train_rec_cnt = sum(_tfrecords_count(path) for path in train_data_paths)

    train_ds = _make_dataset(train_data_paths, image_seq_length, digit_seq_length, image_size, batch_size, 2 * batch_size, digit_seq_invert, crop_bbox, digit_seq_pad_begin)

    # Defaults to False, so the exported inference graph needs no feed.
    is_training = tf.placeholder_with_default(False, shape=[], name='is_training')

    bn_params = {'decay': bn_decay, 'epsilon': bn_epsilon, 'is_training': is_training}

    ###########################################
    # Training graph.
    train_image_seq, train_digit_seq, train_iterator = _get_data(train_ds)
    logits = _make_model(train_image_seq, digit_seq_length, kernel_szs, out_chns, strides, weight_decay, bn_params, peephole, output_keep_prob, fc_keep_prob, is_training, fc_additional)

    loss = _loss_func(logits, train_digit_seq)
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(lr_init, global_step, lr_decay_epoches * train_rec_cnt // batch_size, lr_decay_factor, staircase=True)

    optimizer = _create_optimizer(optimizer_name, lr, weight_decay)

    # Collected before the inference graph is built, so only the update ops
    # of the training graph are attached to the train step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    updates = tf.group(*update_ops)
    with tf.control_dependencies([updates]):
        train_op = optimizer.minimize(loss, global_step=global_step)
    ###########################################

    ###########################################
    # Inference graph: one uint8 BGR image, preprocessed like the training
    # images (scaled to [-1, 1], resized, cut into a sequence).
    image_ph = tf.placeholder(tf.uint8, shape=[1, None, None, 3], name = 'inference_input')
    imagef = tf.to_float(tf.reverse(image_ph, [-1])) # BGR -> RGB -> float
    imagef = tf.squeeze(imagef, 0)
    imagef = tf.div(imagef, 255.0)
    imagef = tf.subtract(imagef, 0.5)
    imagef = tf.multiply(imagef, 2.0)
    imagef = _image_resize(imagef, image_size)
    imagef_seq = _image_to_sequence(imagef, image_seq_length, image_size)
    imagef_seq = tf.expand_dims(imagef_seq, axis = 0)

    # Keep probabilities of 1.0: no dropout at inference time.
    inference_logits = _make_model(imagef_seq, digit_seq_length, kernel_szs, out_chns, strides, weight_decay, bn_params, peephole, 1.0, 1.0, is_training, fc_additional)
    inference_softmax = tf.nn.softmax(inference_logits, name = 'inference_softmax')
    ###########################################

    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    last_saved_epoch = -1
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        if (epoch_start > 0):
            saver.restore(sess, '{}/model.ckpt-{}'.format(train_dir, epoch_start - 1))
            print('restored')
            last_saved_epoch = epoch_start - 1

        for epoch in range(epoch_start, epoch_end):
            sess.run(train_iterator.initializer)
            loss_sum = 0
            local_step = 0
            while True:
                try:
                    gs_, loss_, _ = sess.run([global_step, loss, train_op], feed_dict = {is_training: True})
                except tf.errors.OutOfRangeError:
                    # The iterator is exhausted: the epoch is over.
                    break
                loss_sum += loss_
                print("Step: {}, loss: {}".format(gs_, loss_))
                local_step += 1

            ckpt_prefix = saver.save(sess, train_dir + '/model.ckpt', global_step=epoch)
            nv_ctx.dump_state('last_ckpt_state', ckpt_prefix, ckpt_snapshot.dump_last_ckpt)
            if ((epoch >= save_after_epoch) and (epoch >= last_saved_epoch + save_every_epoch)) or (epoch == epoch_end - 1):
                out_graph_path = os.path.join(train_dir, "{}.gdef".format(epoch))
                _save_gdef(sess, out_graph_path)
                ckpt_snapshot.update_all_gdef(nv_ctx.get_outputs().get('gdef_tgz'), out_graph_path)
                # The local copy is removed right after update_all_gdef has
                # been handed the file.
                tf.gfile.Remove(out_graph_path)
                last_saved_epoch = epoch

            # Average over the steps actually executed. Dividing by
            # train_rec_cnt / batch_size is off whenever the record count is
            # not a multiple of the batch size. max() guards an empty epoch.
            print('Average train loss: {}'.format(loss_sum / max(local_step, 1)))

