import tensorflow as tf
from tensorflow.python.framework import graph_util
from tensorflow.python.platform import gfile
import random
from scipy import misc
import numpy as np
import os.path
import params
import importlib

def loss_func(logits, labels):
    """Return the scalar training loss for a batch.

    The loss is the mean sparse softmax cross-entropy between ``logits`` and
    the integer ``labels``.  When the graph's REGULARIZATION_LOSSES collection
    is not empty, the sum of its entries is added on top.
    """
    per_sample = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
    total = tf.reduce_mean(per_sample)

    # Regularisation terms are whatever has been registered in the graph
    # collection by the time this function is called.
    reg_terms = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if len(reg_terms) > 0:
        print('tf.GraphKeys.REGULARIZATION_LOSSES ', len(reg_terms))
        total += tf.add_n(reg_terms)
    return total

def pick_train_batch(train_data, batch_idx):
    """Assemble one training batch with a fixed number of samples per class.

    The batch is laid out class by class: ``params.CL_PER_BATCH`` images for
    each positive class, then (if ``params.USE_INVALID``)
    ``params.INVALID_PER_BATCH`` images for each invalid class, then (if
    ``params.USE_NEGATIVES``) ``params.NEG_PER_BATCH`` negative images.
    ``params.BATCH_SIZE`` is expected to equal the sum of those counts; rows
    not covered by them would stay uninitialised.

    Args:
        train_data: sequence of per-class image arrays indexed by class label
            (positive classes first, then invalid classes, negatives at
            ``params.NEG_CLASS_IDX``).
        batch_idx: zero-based index of the batch inside the current epoch.

    Returns:
        ``(images_data, labels_data)``: a float32 array of shape
        ``[BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHNS]`` and an int32
        array of shape ``[BATCH_SIZE]`` holding the class label of every row.

    Raises:
        ValueError: if negatives are requested but the negative set is empty.
    """
    images_data = np.empty([params.BATCH_SIZE, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS], np.float32)
    labels_data = np.empty([params.BATCH_SIZE], np.int32)

    # Positive classes: the same source window is taken from every class.
    batch_start = batch_idx * params.CL_PER_BATCH
    batch_end = batch_start + params.CL_PER_BATCH
    for label in range(params.POS_CLASSES_CNT):
        start = label * params.CL_PER_BATCH
        end = start + params.CL_PER_BATCH
        images_data[start:end, ...] = train_data[label][batch_start:batch_end, ...]
        labels_data[start:end] = label

    # Invalid classes are stored right after the positive ones.
    if params.USE_INVALID:
        batch_start = batch_idx * params.INVALID_PER_BATCH
        batch_end = batch_start + params.INVALID_PER_BATCH
        start_invalid = params.POS_CLASSES_CNT * params.CL_PER_BATCH
        for label in range(params.INVALID_CLASSES_CNT):
            start = start_invalid + label * params.INVALID_PER_BATCH
            end = start + params.INVALID_PER_BATCH
            images_data[start:end, ...] = train_data[params.POS_CLASSES_CNT + label][batch_start:batch_end, ...]
            labels_data[start:end] = params.POS_CLASSES_CNT + label

    if params.USE_NEGATIVES:
        # The bounds check must look at the same array the samples are taken
        # from; with invalid classes enabled, index POS_CLASSES_CNT is the
        # first invalid class, not the negatives.
        negatives = train_data[params.NEG_CLASS_IDX]
        neg_total = negatives.shape[0]
        if 0 == neg_total:
            raise ValueError("no negative samples to pick from")

        batch_start = batch_idx * params.NEG_PER_BATCH
        batch_end = batch_start + params.NEG_PER_BATCH

        start = params.POS_CLASSES_CNT * params.CL_PER_BATCH + (params.INVALID_CLASSES_CNT * params.INVALID_PER_BATCH if params.USE_INVALID else 0)
        end = start + params.NEG_PER_BATCH
        if batch_end > neg_total:
            # Wrap around the end of the negative set row by row, so the
            # window always yields exactly NEG_PER_BATCH samples.
            rows = np.arange(batch_start, batch_end) % neg_total
            images_data[start:end, ...] = negatives[rows, ...]
        else:
            images_data[start:end, ...] = negatives[batch_start:batch_end, ...]
        labels_data[start:end] = params.NEG_CLASS_IDX

    return images_data, labels_data

def fill_negatives(img_cnt):
    """Generate negative samples as random square crops of the listed images.

    Image paths are read from ``params.NEGS_LIST_FILEPATH`` (one per line;
    relative paths are resolved against the directory of that list file) and
    visited in random order.  Images that cannot be read, are not 3-channel,
    or are smaller than 500 pixels in either dimension are skipped.  From
    every accepted image ``params.TRY_NEG_PER_IMAGE`` random square crops are
    taken, resized to ``IMAGE_SIZE x IMAGE_SIZE`` and scaled to ``[0, 1]``.

    Args:
        img_cnt: number of negative samples wanted.

    Returns:
        float32 array of shape ``[n, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHNS]``.
        ``n`` equals ``img_cnt`` when enough crops could be produced;
        otherwise only the rows that were actually filled are returned, so
        the caller never receives uninitialised memory.
    """
    print("start filling negatives...")
    data = np.empty([img_cnt, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS], np.float32)
    if 0 == img_cnt:
        # Nothing to fill; avoids writing past the end of an empty array.
        print("end filling negatives.")
        return data

    cnt = 0
    with open(params.NEGS_LIST_FILEPATH, "r") as list_file:
        negs_list = list_file.read().splitlines()
    random.shuffle(negs_list)
    negs_dir = os.path.dirname(params.NEGS_LIST_FILEPATH)
    for line in negs_list:
        neg_img_path = line
        if not os.path.isabs(neg_img_path):
            neg_img_path = os.path.normpath(os.path.join(negs_dir, line))

        print(cnt, "  ", neg_img_path)
        # NOTE(review): misc.imread / misc.imresize are deprecated or removed
        # in newer SciPy releases - confirm the SciPy version this runs on.
        try:
            image = misc.imread(neg_img_path)
        except Exception as err:
            # Best effort: an unreadable file is skipped, but Ctrl-C and
            # interpreter exit are no longer swallowed.
            print("  skipped, cannot read image: ", err)
            continue
        if 0 == image.size:
            continue
        if 3 != len(image.shape):
            continue
        if (500 > image.shape[0]) or (500 > image.shape[1]) or (3 != image.shape[2]):
            continue

        # RGB to BGR
        image = image[:, :, ::-1]
        for _ in range(params.TRY_NEG_PER_IMAGE):
            # Random top-left corner, then a random square side that still
            # fits inside the image and is at least IMAGE_SIZE.
            left = random.randint(0, image.shape[1] - params.IMAGE_SIZE)
            top = random.randint(0, image.shape[0] - params.IMAGE_SIZE)
            size = random.randint(params.IMAGE_SIZE, min(image.shape[1] - left, image.shape[0] - top))
            right = left + size
            bottom = top + size

            crop = image[top:bottom, left:right, :]
            crop = misc.imresize(crop, [params.IMAGE_SIZE, params.IMAGE_SIZE]).astype(np.float32) / 255.

            data[cnt, :, :, :] = crop[:, :, :]
            cnt = cnt + 1
            if cnt >= img_cnt:
                print("end filling negatives.")
                return data

    print("end filling negatives.")
    print("There are not enought negatives file for add negatives")
    # Rows from cnt onwards were never written; drop them.
    return data[:cnt]

def _read_class_images(folder, file_name):
    """Load one raw uint8 image dump as a float32 array scaled to [0, 1].

    The file is interpreted as a flat sequence of
    ``IMAGE_SIZE x IMAGE_SIZE x IMAGE_CHNS`` images; the result has shape
    ``[num_images, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHNS]``.
    """
    data = np.fromfile(folder + '/' + file_name, dtype=np.uint8).astype(np.float32) / 255.
    num_images = data.size // params.IMAGE_SIZE // params.IMAGE_SIZE // params.IMAGE_CHNS
    return data.reshape(num_images, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS)


def read_train_data():
    """Load the training set, one array per class.

    Classes are appended in label order: every file of
    ``params.POS_CLASSES_FILES``, then (if ``params.USE_INVALID``) every file
    of ``params.INVALID_CLASSES_FILES``, then (if ``params.USE_NEGATIVES``)
    the negative class.  Negatives come from ``params.NEGATIVE_FILE`` when
    that file exists in ``params.TRAIN_FOLDER``; otherwise they are generated
    with ``fill_negatives``.

    Returns:
        ``(train_data, train_img_cnt_min)``: the list of per-class arrays and
        the smallest image count among the positive and invalid classes
        (never more than 1000000, the initial value of the running minimum).
    """
    train_data = []
    train_img_cnt_min = 1000000

    class_files = list(params.POS_CLASSES_FILES)
    if params.USE_INVALID:
        class_files += list(params.INVALID_CLASSES_FILES)

    for class_file in class_files:
        data = _read_class_images(params.TRAIN_FOLDER, class_file)
        train_img_cnt_min = min(train_img_cnt_min, data.shape[0])
        train_data.append(data)

    if params.USE_NEGATIVES:
        if os.path.exists(params.TRAIN_FOLDER + '/' + params.NEGATIVE_FILE):
            train_data.append(_read_class_images(params.TRAIN_FOLDER, params.NEGATIVE_FILE))
        else:
            # Generate as many negatives as the smallest class has images.
            # (The original referenced an undefined name here.)
            train_data.append(fill_negatives(train_img_cnt_min))
    return train_data, train_img_cnt_min

def read_eval_data():
    """Load the evaluation set for the positive classes.

    Every file of ``params.POS_CLASSES_FILES`` is read from
    ``params.TEST_FOLDER`` as raw uint8 pixels, scaled to ``[0, 1]`` and
    reshaped to ``[num_images, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHNS]``.

    Returns:
        ``(eval_data, eval_img_cnt_min)``: the list of per-class arrays and
        the smallest per-class image count (capped by the initial value
        1000000 of the running minimum).
    """
    smallest_cnt = 1000000
    per_class = []
    for class_file in params.POS_CLASSES_FILES:
        pixels = np.fromfile(params.TEST_FOLDER + '/' + class_file, dtype=np.uint8).astype(np.float32) / 255.
        images_in_file = pixels.size // params.IMAGE_SIZE // params.IMAGE_SIZE // params.IMAGE_CHNS
        images = pixels.reshape(images_in_file, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS)
        smallest_cnt = min(smallest_cnt, images.shape[0])
        per_class.append(images)

    # Negative samples are not loaded here (the loader for them is disabled),
    # so the returned list holds the positive classes only.
    return per_class, smallest_cnt

def search_hard_negatives(sess, train_data, data_mean):
    """Overwrite the leading rows of ``train_data`` with hard negative crops.

    Random crops are produced from the images listed in
    ``params.NEGS_LIST_FILEPATH`` exactly as for ordinary negatives, the
    ``data_mean`` is subtracted, and each group of
    ``params.TRY_NEG_PER_IMAGE`` crops is classified by the model.  A crop is
    "hard" when the predicted class is not ``params.NEG_CLASS_IDX`` or the
    softmax score of the negative class is below
    ``params.HARD_NEG_THRESHOLD``.  Hard crops are written into
    ``train_data`` in place, starting at row 0, until half of its rows
    (``train_data.shape[0] // 2``) have been replaced or the list runs out.

    Note that a new placeholder and new model/softmax/argmax ops are created
    in the default graph on every call.

    Args:
        sess: session used to run the classifier.
        train_data: array of negative samples, modified in place.
        data_mean: value subtracted from every crop before classification.
    """
    print("start filling hard negatives...")
    try_ph      = tf.placeholder(tf.float32, shape=[params.TRY_NEG_PER_IMAGE, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS])
    # presumably (input, is_training=False, reuse=True) - verify against the model module
    try_logits  = modelClass.model(try_ph, False, True)
    try_softmax = tf.nn.softmax(try_logits)
    try_argmax  = tf.argmax(try_logits, 1)

    try_data    = np.empty([params.TRY_NEG_PER_IMAGE, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS], np.float32)

    neg_img_need = train_data.shape[0] // 2

    with open(params.NEGS_LIST_FILEPATH, "r") as list_file:
        negs_list = list_file.read().splitlines()
    random.shuffle(negs_list)
    negs_dir = os.path.dirname(params.NEGS_LIST_FILEPATH)
    cnt = 0
    for line in negs_list:
        neg_img_path = line
        if not os.path.isabs(neg_img_path):
            neg_img_path = os.path.normpath(os.path.join(negs_dir, line))

        print(cnt, "  ", neg_img_path)
        # NOTE(review): misc.imread / misc.imresize are deprecated or removed
        # in newer SciPy releases - confirm the SciPy version this runs on.
        try:
            image = misc.imread(neg_img_path)
        except Exception as err:
            # Best effort: skip unreadable files without swallowing Ctrl-C.
            print("  skipped, cannot read image: ", err)
            continue
        if 0 == image.size:
            continue
        if 3 != len(image.shape):
            continue
        # `!=` instead of the Python-2-only `<>`, which is a SyntaxError on
        # Python 3 and kept the whole file from being parsed.
        if (500 > image.shape[0]) or (500 > image.shape[1]) or (3 != image.shape[2]):
            continue

        # RGB to BGR
        image = image[:, :, ::-1]
        for i in range(params.TRY_NEG_PER_IMAGE):
            left = random.randint(0, image.shape[1] - params.IMAGE_SIZE)
            top = random.randint(0, image.shape[0] - params.IMAGE_SIZE)
            size = random.randint(params.IMAGE_SIZE, min(image.shape[1] - left, image.shape[0] - top))
            right = left + size
            bottom = top + size

            crop = image[top:bottom, left:right, :]
            crop = misc.imresize(crop, [params.IMAGE_SIZE, params.IMAGE_SIZE]).astype(np.float32) / 255.

            # Same normalisation as the rest of the training data.
            crop[:, :] -= data_mean
            try_data[i, :, :, :] = crop[:, :, :]

        softmax, argmax = sess.run([try_softmax, try_argmax], feed_dict={try_ph: try_data})
        for i in range(params.TRY_NEG_PER_IMAGE):
            if (params.NEG_CLASS_IDX != argmax[i]) or (params.HARD_NEG_THRESHOLD > softmax[i][params.NEG_CLASS_IDX]):
                train_data[cnt, :, :, :] = try_data[i, :, :, :]
                cnt = cnt + 1
                if cnt >= neg_img_need:
                    print("end filling hard negatives...")
                    return

    print("end filling hard negatives...")

def calc_mean(data):
    """Return the per-channel mean over all pixels of all arrays in *data*.

    Args:
        data: iterable of 4-D arrays whose last axis is the channel axis
            (the sums are taken over axes 0, 1 and 2).  All arrays are
            expected to have the same number of channels.

    Returns:
        float64 array with one mean value per channel.  The channel count is
        taken from the data instead of being fixed to 3.

    Raises:
        ValueError: if *data* contains no pixels at all, instead of silently
            returning NaN.
    """
    channel_sum = None
    pixel_cnt = 0
    for cldata in data:
        # Accumulate in float64 to keep precision over large float32 sets.
        class_sum = cldata.sum(axis=(0, 1, 2), dtype=np.float64)
        channel_sum = class_sum if channel_sum is None else channel_sum + class_sum
        pixel_cnt += cldata.shape[0] * cldata.shape[1] * cldata.shape[2]

    if 0 == pixel_cnt:
        raise ValueError("cannot compute the mean of an empty data set")
    return channel_sum / pixel_cnt

def make_classes_names_list():
    """Return the class names in label order.

    Names are the file names of the positive classes, followed by those of
    the invalid classes when ``params.USE_INVALID`` is set, each with its
    last four characters cut off (presumably a dot plus a three-letter
    extension).  ``'negative'`` is appended when ``params.USE_NEGATIVES`` is
    set.
    """
    source_files = list(params.POS_CLASSES_FILES)
    if params.USE_INVALID:
        source_files.extend(params.INVALID_CLASSES_FILES)

    names = [file_name[:-4] for file_name in source_files]
    if params.USE_NEGATIVES:
        names.append('negative')
    return names

def train():
    """Build the graphs, then run the evaluate / train / checkpoint loop.

    Reads the training and evaluation sets, subtracts the per-channel
    training mean from both, and calls ``modelClass.model`` three times: for
    the training batch, for the evaluation batch and for a single-image uint8
    inference input (named ``inference_input``, output
    ``inference_softmax``).  Epochs run from ``params.EPOCH_START`` to
    ``params.EPOCH_MAX`` inclusive; with ``EPOCH_START > 0`` the variables
    are restored from checkpoint ``model-<EPOCH_START - 1>``.

    Each epoch first evaluates every positive class, then shuffles the
    training data and trains on ``batch_per_epoch`` batches, appending
    ``global_step batch_loss eval_loss`` lines to ``params.LOSS_FILE_DUMP``.
    Every ``params.SAVE_EVERY_EPOCHES`` epochs a checkpoint and a frozen
    graph ``<epoch>.gdef`` are written to ``params.OUT_FOLDER``.

    If at epoch 50 some positive class has no correctly classified
    evaluation sample, the run is started over in a new session.

    Raises:
        OSError: if the output folder is missing and cannot be created.
        ValueError: if the data is too small to form a single batch.
    """
    try:
        os.mkdir(params.OUT_FOLDER)
    except OSError:
        # Only "directory already there" is harmless; a missing parent or a
        # permission problem must not be reported as "already exists".
        if not os.path.isdir(params.OUT_FOLDER):
            raise
        print('Output dir %s already exists' % params.OUT_FOLDER)

    print('optimizer: ', params.OPTIMIZER)
    print('  LR base: ', params.LR_BASE)
    print('  LR decay epoches: ', params.LR_DECAY_EPOCH)
    print('  LR decay rate: ', params.LR_DECAY_RATE)

    # Both sets are normalised with the mean of the TRAINING data.
    train_data, train_img_cnt = read_train_data()
    train_data_mean = calc_mean(train_data)
    for data in train_data:
        data[:, :, :] -= train_data_mean

    eval_data, eval_img_cnt = read_eval_data()
    for data in eval_data:
        data[:, :, :] -= train_data_mean

    # The mean is stored as a graph variable so the exported inference graph
    # can normalise raw uint8 input by itself.
    data_mean = tf.get_variable("data_mean", shape=[params.IMAGE_CHNS],
                                initializer=tf.constant_initializer(train_data_mean))

    batch_per_epoch = train_img_cnt // params.CL_PER_BATCH
    if params.USE_NEGATIVES:
        neg_batch = train_data[-1].shape[0] // params.NEG_PER_BATCH
        if batch_per_epoch > neg_batch:
            batch_per_epoch = neg_batch
    if batch_per_epoch <= 0:
        # Fail before building the graph instead of dividing by zero after
        # the first evaluation pass.
        raise ValueError("not enough training data for a single batch")

    global_step = tf.Variable(0, trainable=False)

    #############################################################
    # prepare training operation
    images_ph = tf.placeholder(tf.float32, shape=[params.BATCH_SIZE, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS])
    labels_ph = tf.placeholder(tf.int32,   shape=[params.BATCH_SIZE])

    # presumably (input, is_training, reuse) - verify against the model module
    logits = modelClass.model(images_ph, True, False)
    train_loss = loss_func(logits, labels_ph)

    # Staircase decay: the rate drops once every LR_DECAY_EPOCH epochs.
    lr = tf.train.exponential_decay(params.LR_BASE, global_step, params.LR_DECAY_EPOCH * batch_per_epoch, params.LR_DECAY_RATE, staircase=True)
    if 'GradientDescentOptimizer' == params.OPTIMIZER:
        optimizer = tf.train.GradientDescentOptimizer(lr)
    elif 'AdagradOptimizer' == params.OPTIMIZER:
        optimizer = tf.train.AdagradOptimizer(lr)
    else:
        optimizer = tf.train.AdamOptimizer(lr)
    # Named train_op so it does not shadow this function's own name.
    train_op = optimizer.minimize(train_loss, global_step=global_step)
    saver = tf.train.Saver(max_to_keep=params.EPOCH_MAX + 1)
    #############################################################

    #############################################################
    # prepare evaluation operation
    eval_ph         = tf.placeholder(tf.float32, shape=[eval_img_cnt, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS])
    eval_logits     = modelClass.model(eval_ph, False, True)
    eval_softmax    = tf.nn.softmax(eval_logits)
    eval_argmax     = tf.argmax(eval_logits, 1)
    #############################################################

    #############################################################
    # prepare inference graph operation
    classes_names_list = make_classes_names_list()

    tf.constant(classes_names_list, tf.string, name="class_names")

    inference_inp       = tf.placeholder(tf.uint8, shape=[1, params.IMAGE_SIZE, params.IMAGE_SIZE, params.IMAGE_CHNS], name="inference_input")
    inference_inpf      = tf.to_float(inference_inp)
    inference_inpf      = (inference_inpf / 255.) - data_mean

    inference_logits    = modelClass.model(inference_inpf, False, True)

    tf.nn.softmax(inference_logits, name="inference_softmax")
    #############################################################

    print(classes_names_list)

    trained = False
    while not trained:
        trained = True
        # Context managers close the loss file and the session even when an
        # exception interrupts training.
        with open(params.LOSS_FILE_DUMP, "w") as loss_file, tf.Session() as sess:
            if 0 < params.EPOCH_START:
                model_path = os.path.join(params.OUT_FOLDER, 'model-%d' % (params.EPOCH_START - 1))
                saver.restore(sess, model_path)
            else:
                sess.run(tf.global_variables_initializer())
            print("data_mean", sess.run(data_mean))
            for epoch in range(params.EPOCH_START, params.EPOCH_MAX + 1):
                print("epoch: ", epoch, " global_step: ", sess.run(global_step), " lr: ", sess.run(lr))

                eval_loss = 0
                all_invalid_in_class = False
                for a in range(params.POS_CLASSES_CNT):
                    # eval_ph holds exactly eval_img_cnt images (the size of
                    # the smallest class), so larger classes are truncated
                    # instead of failing on a feed shape mismatch.
                    class_images = eval_data[a][:eval_img_cnt]
                    softmax, argmax = sess.run([eval_softmax, eval_argmax], feed_dict={eval_ph: class_images})
                    valid_cnt = np.count_nonzero(argmax == a)
                    if 0 == valid_cnt:
                        all_invalid_in_class = True
                    # 1e-5 keeps log() finite for zero probabilities.
                    class_loss = -np.log(softmax[:, a] + 1e-5).sum()
                    print("  classes id = ", a, " valid: ", valid_cnt, " from ", class_images.shape[0], ". loss = ", class_loss / class_images.shape[0], classes_names_list[a])
                    eval_loss += class_loss

                # Only the positive classes are evaluated: eval_data holds no
                # invalid or negative samples.
                # NOTE(review): the denominator still counts one extra class
                # when USE_NEGATIVES is set although negatives are not
                # evaluated; kept as is so logged values stay comparable.
                eval_loss = (eval_loss / eval_img_cnt / (params.POS_CLASSES_CNT + (1 if params.USE_NEGATIVES else 0)))
                print("  eval loss = ", eval_loss)

                # we walked to the local minimum, with bad generalization for the class
                if (50 == epoch) and all_invalid_in_class:
                    trained = False
                    break

                if params.HARD_NEG_SEARCH and (0 == (epoch % 100)) and (100 <= epoch):
                    # The last entry is the negative class only when
                    # negatives were actually loaded; otherwise it is a
                    # positive class that must not be overwritten.
                    if params.USE_NEGATIVES:
                        search_hard_negatives(sess, train_data[-1], train_data_mean)
                    if len(eval_data) > params.POS_CLASSES_CNT:
                        search_hard_negatives(sess, eval_data[-1], train_data_mean)

                for data in train_data:
                    np.random.shuffle(data)
                epoch_loss = 0
                for batch in range(batch_per_epoch):
                    images_data, labels_data = pick_train_batch(train_data, batch)
                    feed_dict = {images_ph: images_data, labels_ph: labels_data}
                    _, batch_loss = sess.run([train_op, train_loss], feed_dict=feed_dict)
                    print("  batch: ", batch, " loss = ", batch_loss)
                    epoch_loss += batch_loss
                    loss_file.write("%d %f %f\n" % (sess.run(global_step), batch_loss, eval_loss))
                    loss_file.flush()

                print("  average loss = ", epoch_loss / batch_per_epoch)

                if 0 == (epoch % params.SAVE_EVERY_EPOCHES):
                    model_path = os.path.join(params.OUT_FOLDER, 'model')
                    saver.save(sess, model_path, global_step=epoch)
                    #################################################
                    # save graph in protobuf
                    out_graph_path = os.path.join(params.OUT_FOLDER, "{}.gdef".format(epoch))
                    graph_def = tf.get_default_graph().as_graph_def()
                    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['inference_softmax', 'class_names'])
                    with gfile.GFile(out_graph_path, "wb") as f:
                        f.write(output_graph_def.SerializeToString())
                    print("%d ops in the final graph." % len(output_graph_def.node))


# ---------------------------------------------------------------------------
# Script section: apply the optional TSC_* environment overrides to `params`,
# load the model module and start training.
# ---------------------------------------------------------------------------

def _apply_env_override(env_name, attr_name, convert):
    """Set ``params.<attr_name>`` from environment variable *env_name*.

    Returns True when the variable is present (the converted value has been
    stored) and False when it is absent (``params`` is left untouched).
    """
    raw_value = os.environ.get(env_name)
    if raw_value is None:
        return False
    setattr(params, attr_name, convert(raw_value))
    return True


# Becomes True as soon as any override changes the composition of a batch.
BATCH_SIZE_RECALC = False

# Comma-separated list of positive class files to exclude from training.
removed_classes = os.environ.get('TSC_TRAINER_POS_CLASSES_REMOVE')
if removed_classes is not None:
    for removed_name in removed_classes.split(','):
        params.POS_CLASSES_FILES.remove(removed_name)
    params.POS_CLASSES_CNT = len(params.POS_CLASSES_FILES)
    params.OUT_CLASSES_CNT = params.POS_CLASSES_CNT + (params.INVALID_CLASSES_CNT if params.USE_INVALID else 0) + (1 if params.USE_NEGATIVES else 0)
    # The negative class, when used, is always the last output.
    params.NEG_CLASS_IDX = (params.OUT_CLASSES_CNT - 1) if params.USE_NEGATIVES else -1
    BATCH_SIZE_RECALC = True

_apply_env_override('TSC_TRAINER_IMAGE_CHNS', 'IMAGE_CHNS', int)
_apply_env_override('TSC_TRAINER_IMAGE_SIZE', 'IMAGE_SIZE', int)

# Changing the per-class count also resets the derived per-batch counts; the
# two more specific overrides below are applied afterwards and win.
if _apply_env_override('TSC_TRAINER_CL_PER_BATCH', 'CL_PER_BATCH', int):
    params.NEG_PER_BATCH = 4 * params.CL_PER_BATCH
    params.INVALID_PER_BATCH = params.CL_PER_BATCH
    BATCH_SIZE_RECALC = True

if _apply_env_override('TSC_TRAINER_NEG_PER_BATCH', 'NEG_PER_BATCH', int):
    BATCH_SIZE_RECALC = True

if _apply_env_override('TSC_TRAINER_INVALID_PER_BATCH', 'INVALID_PER_BATCH', int):
    BATCH_SIZE_RECALC = True

if BATCH_SIZE_RECALC:
    params.BATCH_SIZE = params.CL_PER_BATCH * params.POS_CLASSES_CNT + (params.INVALID_PER_BATCH * params.INVALID_CLASSES_CNT if params.USE_INVALID else 0) + (params.NEG_PER_BATCH if params.USE_NEGATIVES else 0)

# Plain scalar overrides that do not affect the batch layout.
_apply_env_override('TSC_TRAINER_EPOCH_START', 'EPOCH_START', int)
_apply_env_override('TSC_TRAINER_EPOCH_MAX', 'EPOCH_MAX', int)
_apply_env_override('TSC_TRAINER_LR_BASE', 'LR_BASE', float)
_apply_env_override('TSC_TRAINER_LR_DECAY_EPOCH', 'LR_DECAY_EPOCH', int)
_apply_env_override('TSC_TRAINER_LR_DECAY_RATE', 'LR_DECAY_RATE', float)
_apply_env_override('TSC_TRAINER_OPTIMIZER', 'OPTIMIZER', str)
_apply_env_override('TSC_TRAINER_SAVE_EVERY_EPOCHES', 'SAVE_EVERY_EPOCHES', int)


# The model module lives in the `models` package; TSC_MODEL_CLASS selects it.
model_class = 'models.' + os.environ.get('TSC_MODEL_CLASS', 'model_bn_32__64_64__128_reg')

print('model: ', model_class)
modelClass = importlib.import_module(model_class)

train()

