import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import os.path
import models
import trainers

# Report which TensorFlow build is in use before anything else runs.
print(tf.__version__)

# Absolute directory of this script (symlinks resolved). The trainer call below
# is handed SCRIPT_DIR + "/data/" as its data directory.
SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))

# --- Sequence lengths --------------------------------------------------------
INPUT_SEQ_LENGTH = 10
OUTPUT_SEQ_LENGTH = 10
# Combined input + output length.
# NOTE(review): not referenced in the visible code — confirm it is still used.
SEQ_LENGTH = OUTPUT_SEQ_LENGTH + INPUT_SEQ_LENGTH

# --- Learning-rate schedule (forwarded to the trainer) -----------------------
LR_BASE = 1e-3
LR_DECAY_EPOCH = 25
LR_DECAY_RATE = 0.5
LR_DECAY_STAIRCASE = True  # forwarded to SimpleTrainer only
LR_DECAY_MIN_EPOCHS = 10   # forwarded to SmartTrainer only

BATCH_SIZE = 16

# --- Output locations --------------------------------------------------------
# NOTE(review): these paths resolve against the current working directory,
# whereas the data directory is anchored at SCRIPT_DIR; run from the script's
# folder or confirm the mismatch is intended.
# Previous experiment folder, kept for reference:
#   ./model-data/test_ps4_bs16_bn9e-1_wd1e-3_lr1e-3_lrdc25r_lrdcr5e-3_peephole/
OUT_FOLDER = os.path.abspath("./model-data/test_do0.5all/")
TFBOARD_FOLDER = os.path.abspath("./tf_board/")

# --- Epoch bookkeeping (forwarded to the trainer) ----------------------------
EPOCH_START = 11
EPOCH_MAX = 200
SAVE_EVERY_EPOCHES = 10

# NOTE(review): not passed to either trainer call below — possibly unused.
VALID_CHECK_EPOCHS = 10

# Forwarded to the trainer as-is. The model's dropout switch is computed as
# DROP_OUT < 1.0, which suggests this is a keep probability (1.0 = off) — confirm.
DROP_OUT = 0.5

# Zero disables dataset regeneration.
EPOCHES_DATASET_REGENERATE = 0

# --- Model / trainer switches ------------------------------------------------
PEEPHOLE_ENABLED = True    # forwarded to models.get_model
BATCHNORM_ENABLED = True   # forwarded to models.get_model
CONDITIONAL_MODEL = False  # forwarded to models.get_model

# True selects trainers.SmartTrainer; False selects trainers.SimpleTrainer.
SMART_TRAINER = True

def loss_function(Y, logits):
    """Build the training loss: summed sigmoid cross-entropy plus regularization.

    Args:
        Y: Target labels, passed as ``labels`` to
            ``tf.nn.sigmoid_cross_entropy_with_logits``.
        logits: Unscaled model outputs, passed as ``logits`` to the same op.

    Returns:
        A scalar tensor equal to the sum, over every element, of the
        element-wise sigmoid cross-entropy. When the
        ``tf.GraphKeys.REGULARIZATION_LOSSES`` collection is non-empty, the
        ``tf.add_n`` of all its tensors is added on top.
    """
    cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=Y)
    # Summed (no axis), not averaged, so the loss magnitude grows with the
    # number of elements in the batch.
    loss = tf.reduce_sum(cross_entropy)

    # Read the collection once instead of three separate lookups.
    # NOTE(review): presumably populated by regularizers inside
    # models.get_model — verify.
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if reg_losses:
        # Printed each time this function is called with a non-empty collection.
        print('tf.GraphKeys.REGULARIZATION_LOSSES ', len(reg_losses))
        loss += tf.add_n(reg_losses)
    return loss


# Positional arguments shared, slot for slot, by both trainer classes. Only the
# trailing learning-rate arguments differ, so keep the common part in one place
# to stop the two branches from drifting apart.
_common_trainer_args = (
    INPUT_SEQ_LENGTH, OUTPUT_SEQ_LENGTH,
    # NOTE(review): the literal 2 here is unnamed — confirm its meaning in trainers.
    SCRIPT_DIR + "/data/", OUT_FOLDER, 2,
    EPOCHES_DATASET_REGENERATE, BATCH_SIZE,
    # Last flag is True iff DROP_OUT < 1.0. NOTE(review): presumably the
    # dropout on/off switch, and 3 is an unnamed literal — confirm in models.get_model.
    models.get_model(CONDITIONAL_MODEL, 3, BATCHNORM_ENABLED, PEEPHOLE_ENABLED, DROP_OUT < 1.0),
    loss_function,
    DROP_OUT,
    EPOCH_START, EPOCH_MAX, SAVE_EVERY_EPOCHES,
    TFBOARD_FOLDER,
    LR_BASE, LR_DECAY_RATE, LR_DECAY_EPOCH,
)

if SMART_TRAINER:
    # NOTE(review): 1e-6 is an unnamed literal in this slot — presumably a
    # learning-rate floor; confirm against trainers.SmartTrainer and consider
    # promoting it to a named constant.
    trainer = trainers.SmartTrainer(*(_common_trainer_args + (1e-6, LR_DECAY_MIN_EPOCHS)))
else:
    trainer = trainers.SimpleTrainer(*(_common_trainer_args + (LR_DECAY_STAIRCASE,)))
trainer.train()

