import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
import os.path
import mnist
import moving_mnist as mm
import utils

class TrainerBase(object):
    """Epoch-based training loop skeleton for a sequence-prediction model.

    The class wires together data loading, graph construction, TensorBoard
    summaries, checkpointing and the epoch loop.  Subclasses must implement
    ``prepare_optimizer`` and, inside it, define ``self.global_step``,
    ``self.lr`` and ``self.train_step`` because ``train_epoch`` reads all three.

    The hooks ``start_train``, ``start_epoch``, ``end_epoch`` and ``end_train``
    may be overridden; ``start_epoch``/``end_epoch`` stop training by
    returning ``False``.
    """

    def __init__(self,
                 input_seq_length,
                 output_seq_length,
                 data_folder,
                 out_folder,
                 obj_per_image,
                 data_epoches_regen,
                 batch_size,
                 model,
                 loss_function,
                 train_drop_out,
                 epoch_start,
                 epoch_stop,
                 save_every_epoch,
                 tf_board_dir):
        """Store the training configuration.

        Args:
            input_seq_length: index along axis 1 at which each encoded
                sequence is split into model input (X) and target (Y).
            output_seq_length: forwarded to ``model`` as its last argument.
            data_folder: passed to ``mm.load_data`` and ``mnist.read_mnist``.
            out_folder: checkpoint directory; also used as the raw prefix of
                the sample-dump folder name (see ``end_epoch``).
            obj_per_image: forwarded to ``mm.generate_data`` on regeneration.
            data_epoches_regen: regenerate the training set every this many
                epochs; ``0`` disables regeneration.
            batch_size: number of sequences per batch.
            model: callable ``(X, keep_prob, is_training, output_seq_length)``
                returning logits.
            loss_function: callable ``(Y, logits)`` returning the loss tensor.
            train_drop_out: value fed to the ``keep_prob`` placeholder while
                training (validation always feeds ``1.0``).
            epoch_start: first epoch index; when > 0 the checkpoint
                ``model-<epoch_start - 1>`` is restored before training.
            epoch_stop: last epoch index (inclusive).
            save_every_epoch: dump predicted samples every this many epochs.
            tf_board_dir: TensorBoard log directory, or ``None`` to disable
                summary writing.
        """
        self.data_epoches_regen_ = data_epoches_regen
        self.input_seq_length_ = input_seq_length
        self.output_seq_length_ = output_seq_length
        self.batch_size_ = batch_size
        self.model_ = model
        self.loss_function_ = loss_function
        self.train_drop_out_ = train_drop_out
        self.epoch_start_ = epoch_start
        self.epoch_stop_ = epoch_stop
        self.obj_per_image_ = obj_per_image
        self.data_folder_ = data_folder
        self.out_folder_ = out_folder
        self.save_every_epoch_ = save_every_epoch
        self.tf_board_dir_ = tf_board_dir

    def prepare_data(self):
        """Load train/validation data and split each sequence into X and Y.

        The raw MNIST set is loaded only when regeneration is enabled, since
        ``data_regenerate`` is its sole consumer.
        """
        train_data, valid_data, _ = mm.load_data(self.data_folder_)
        if (0 != self.data_epoches_regen_):
            self.mnist = mnist.read_mnist(self.data_folder_)

        # Axis 1 is the split axis: the first input_seq_length_ entries become X.
        self.valid_data_X, self.valid_data_Y = np.split(mm.encode_data(valid_data), [self.input_seq_length_, ], 1)
        self.train_data_X, self.train_data_Y = np.split(mm.encode_data(train_data), [self.input_seq_length_, ], 1)
        # Integer division: a trailing partial batch is never used.
        self.train_batch_per_epoch = train_data.shape[0] // self.batch_size_
        self.valid_batch_per_epoch = valid_data.shape[0] // self.batch_size_
        print("batch_per_epoch for train: ", self.train_batch_per_epoch)
        print("batch_per_epoch for validation: ", self.valid_batch_per_epoch)

    def prepare_model(self):
        """Build placeholders, the model, the loss and the sigmoid output.

        Requires ``prepare_data`` to have run, because placeholder shapes are
        taken from the loaded arrays (batch dimension left as ``None``).
        """
        inp_shape = self.train_data_X.shape
        out_shape = self.train_data_Y.shape
        self.X = tf.placeholder(tf.float32, [None, inp_shape[1], inp_shape[2], inp_shape[3], inp_shape[4]], name="X")
        self.Y = tf.placeholder(tf.float32, [None, out_shape[1], out_shape[2], out_shape[3], out_shape[4]], name="Y")
        self.keep_prob = tf.placeholder(tf.float32, name="keep_prob")
        self.is_training = tf.placeholder(tf.bool, name="is_training")
        logits = self.model_(self.X, self.keep_prob, self.is_training, self.output_seq_length_)
        self.loss = self.loss_function_(self.Y, logits)
        self.output = tf.sigmoid(logits)
        tf.summary.scalar("loss", self.loss)

    def prepare_optimizer(self):
        """Create ``self.global_step``, ``self.lr`` and ``self.train_step``.

        Raises:
            NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError

    def init_tfboard(self, sess):
        """Create merged summaries and train/validate writers if enabled."""
        if (self.tf_board_dir_ is not None):
            self.summary = tf.summary.merge_all()
            self.summary_train_writer = tf.summary.FileWriter(self.tf_board_dir_ + '/train', sess.graph)
            self.summary_validate_writer = tf.summary.FileWriter(self.tf_board_dir_ + '/validate')

    def _close_tfboard(self):
        """Flush and close any summary writers created by ``init_tfboard``."""
        for attr in ('summary_train_writer', 'summary_validate_writer'):
            writer = getattr(self, attr, None)
            if writer is not None:
                writer.close()

    def start_train(self, sess):
        """Hook called once before the first epoch; no-op by default."""
        pass

    def data_regenerate(self):
        """Replace the training set with freshly generated sequences.

        The new set keeps the current number of sequences, image size and
        total sequence length (X and Y lengths along axis 1).
        """
        print("regenerate train data")
        shape = self.train_data_X.shape
        # The sequence axis is axis 1 (the axis np.split cuts on), so the
        # full length is the X part plus the Y part along that axis.
        seq_length = shape[1] + self.train_data_Y.shape[1]
        # Pass the loaded dataset (self.mnist), not the imported `mnist` module.
        train_data = mm.generate_data(self.mnist, shape=(shape[2], shape[3]), seq_len=seq_length, seq_cnt=shape[0], obj_per_image=self.obj_per_image_)
        self.train_data_X, self.train_data_Y = np.split(mm.encode_data(train_data), [self.input_seq_length_, ], 1)

    def start_epoch(self, sess, epoch):
        """Hook called before each epoch; return ``False`` to stop training."""
        return True

    def train_epoch(self, sess, epoch):
        """Run one pass over the shuffled training data.

        Returns:
            The sum of per-batch losses over the epoch (not the average).
        """
        self.train_data_X, self.train_data_Y = utils.unison_shuffle(self.train_data_X, self.train_data_Y)
        train_loss_sum = 0
        for i in range(self.train_batch_per_epoch):
            data_X = self.train_data_X[i * self.batch_size_ : (i+1) * self.batch_size_, :, :, :, :]
            data_Y = self.train_data_Y[i * self.batch_size_ : (i+1) * self.batch_size_, :, :, :, :]
            if (self.tf_board_dir_ is None):
                loss, _ = sess.run([self.loss, self.train_step], feed_dict={self.X: data_X, self.Y: data_Y, self.keep_prob: self.train_drop_out_, self.is_training: True})
            else:
                summary, loss, _ = sess.run([self.summary, self.loss, self.train_step], feed_dict={self.X: data_X, self.Y: data_Y, self.keep_prob: self.train_drop_out_, self.is_training: True})
                # Summary step counts batches across epochs.
                self.summary_train_writer.add_summary(summary, epoch * self.train_batch_per_epoch + i)
            if (0 == (i % 25)):
                print("{}. loss = {}".format(sess.run(self.global_step), loss))
            train_loss_sum = train_loss_sum + loss

        print("epoch average loss = {}, lr = {}".format(train_loss_sum / self.train_batch_per_epoch, sess.run(self.lr)))
        return train_loss_sum

    def validate_epoch(self, sess, epoch):
        """Evaluate the loss on the validation data without training.

        Returns:
            The sum of per-batch validation losses (not the average).
        """
        valid_loss_sum = 0
        for i in range(self.valid_batch_per_epoch):
            data_X = self.valid_data_X[i * self.batch_size_ : (i+1) * self.batch_size_, :, :, :, :]
            data_Y = self.valid_data_Y[i * self.batch_size_ : (i+1) * self.batch_size_, :, :, :, :]
            if (self.tf_board_dir_ is None):
                loss = sess.run(self.loss, feed_dict={self.X: data_X, self.Y: data_Y, self.keep_prob: 1.0, self.is_training: False})
            else:
                summary, loss = sess.run([self.summary, self.loss], feed_dict={self.X: data_X, self.Y: data_Y, self.keep_prob: 1.0, self.is_training: False})
                self.summary_validate_writer.add_summary(summary, epoch * self.valid_batch_per_epoch + i)
            valid_loss_sum = valid_loss_sum + loss

        print("epoch validation loss = {}".format(valid_loss_sum / self.valid_batch_per_epoch))
        return valid_loss_sum

    def end_epoch(self, sess, epoch):
        """Checkpoint the model and periodically dump predicted samples.

        A checkpoint is written every epoch.  Every ``save_every_epoch_``
        epochs (but never at epoch 0) the model is run on the first validation
        batch and the decoded output is saved to a per-epoch folder.

        Returns:
            ``True`` to continue training (subclasses may return ``False``).
        """
        model_path = os.path.join(self.out_folder_, 'model')
        self.saver.save(sess, model_path, global_step=epoch)
        if (0 == (epoch % self.save_every_epoch_) and (0 != epoch)):
            data = self.valid_data_X[0 : self.batch_size_, :, :, :, :]
            output_ = sess.run(self.output, feed_dict={self.X: data, self.keep_prob: 1.0, self.is_training: False})
            dataset = mm.decode_data(output_)
            # NOTE(review): out_folder_ is concatenated without a separator,
            # so it presumably ends with a path separator -- confirm with callers.
            folder_path = "{}l3_epoch_{}".format(self.out_folder_, epoch)
            try:
                os.mkdir(folder_path)
            except OSError:
                # Best effort: typically the folder already exists.  Any real
                # problem surfaces when save_data writes into it.
                pass
            mm.save_data(dataset, folder_path)
        return True

    def end_train(self, sess):
        """Hook called once after the last epoch; no-op by default."""
        pass

    def train(self):
        """Build the graph and run the full training loop.

        Epochs run from ``epoch_start_`` to ``epoch_stop_`` inclusive, unless
        ``start_epoch`` or ``end_epoch`` returns ``False`` first.
        """
        self.prepare_data()
        self.prepare_model()
        utils.print_def_graph_vars()
        self.prepare_optimizer()
        self.saver = tf.train.Saver(max_to_keep=20, allow_empty=True)
        with tf.Session() as sess:
            self.init_tfboard(sess)
            try:
                sess.run(tf.global_variables_initializer())
                if (0 < self.epoch_start_):
                    # Resume from the checkpoint written at the previous epoch.
                    model_path = os.path.join(self.out_folder_, 'model-%d' % (self.epoch_start_ - 1))
                    self.saver.restore(sess, model_path)
                # Guard against accidentally adding ops inside the loop.
                sess.graph.finalize()
                self.start_train(sess)
                for epoch in range(self.epoch_start_, self.epoch_stop_ + 1):
                    if (0 != self.data_epoches_regen_):
                        if (epoch > self.epoch_start_) and ((epoch - self.epoch_start_) % self.data_epoches_regen_ == 0):
                            self.data_regenerate()
                    print("Epoch: ", epoch)
                    if (not self.start_epoch(sess, epoch)):
                        break
                    self.train_epoch(sess, epoch)
                    self.validate_epoch(sess, epoch)
                    if (not self.end_epoch(sess, epoch)):
                        break
                self.end_train(sess)
            finally:
                # Make sure buffered summaries reach disk even on failure.
                self._close_tfboard()

class SimpleTrainer(TrainerBase):
    """Trainer using Adam with an exponentially decaying learning rate."""

    def __init__(self,
                 input_seq_length,
                 output_seq_length,
                 data_folder,
                 out_folder,
                 obj_per_image,
                 data_epoches_regen,
                 batch_size,
                 model,
                 loss_function,
                 train_drop_out,
                 epoch_start,
                 epoch_stop,
                 save_every_epoch,
                 tf_board_dir,
                 lr_base,
                 lr_decay_rate,
                 lr_decay_epoches,
                 lr_staircase):
        """Forward the common settings to the base class and keep the schedule.

        Args:
            lr_base: initial learning rate.
            lr_decay_rate: decay factor given to ``tf.train.exponential_decay``.
            lr_decay_epoches: decay period expressed in epochs.
            lr_staircase: ``staircase`` flag of ``tf.train.exponential_decay``.

        The remaining arguments are passed unchanged to ``TrainerBase``.
        """
        super(SimpleTrainer, self).__init__(
            input_seq_length,
            output_seq_length,
            data_folder,
            out_folder,
            obj_per_image,
            data_epoches_regen,
            batch_size,
            model,
            loss_function,
            train_drop_out,
            epoch_start,
            epoch_stop,
            save_every_epoch,
            tf_board_dir,
        )
        self.lr_base_ = lr_base
        self.lr_decay_rate_ = lr_decay_rate
        self.lr_decay_epoches_ = lr_decay_epoches
        self.lr_staircase_ = lr_staircase

    def prepare_optimizer(self):
        """Create the step counter, the decayed learning rate and the Adam op."""
        self.global_step = tf.Variable(0, trainable=False)
        # The schedule is specified in epochs; convert to optimizer steps.
        decay_steps = self.lr_decay_epoches_ * self.train_batch_per_epoch
        self.lr = tf.train.exponential_decay(
            self.lr_base_,
            self.global_step,
            decay_steps,
            self.lr_decay_rate_,
            staircase=self.lr_staircase_,
        )
        optimizer = tf.train.AdamOptimizer(self.lr)
        self.train_step = optimizer.minimize(self.loss, global_step=self.global_step)

class SmartTrainer(TrainerBase):
    """Trainer that adapts the learning rate and stops early from loss history.

    The learning rate is multiplied by ``lr_decay_rate`` whenever the training
    loss stops improving, and training stops when the rate drops below
    ``lr_min`` or when the validation loss starts to grow.
    """

    def __init__(self,
                 input_seq_length,
                 output_seq_length,
                 data_folder,
                 out_folder,
                 obj_per_image,
                 data_epoches_regen,
                 batch_size,
                 model,
                 loss_function,
                 train_drop_out,
                 epoch_start,
                 epoch_stop,
                 save_every_epoch,
                 tf_board_dir,
                 lr_base,
                 lr_decay_rate,
                 lr_decay_epoches,
                 lr_min,
                 validate_check_epoches):
        """Store the adaptive-schedule settings on top of the base config.

        Args:
            lr_base: initial learning rate.
            lr_decay_rate: factor the learning rate is multiplied by on decay.
            lr_decay_epoches: number of previous epochs whose training loss is
                compared with the latest one to decide on a decay.
            lr_min: training stops once the learning rate falls below this.
            validate_check_epoches: number of previous epochs whose validation
                loss is compared with the latest one for early stopping.

        The remaining arguments are passed unchanged to ``TrainerBase``.
        """
        super(SmartTrainer, self).__init__(input_seq_length, output_seq_length, data_folder, out_folder, obj_per_image, data_epoches_regen, batch_size, model, loss_function, train_drop_out, epoch_start, epoch_stop, save_every_epoch, tf_board_dir)
        self.lr_base_ = lr_base
        self.lr_decay_rate_ = lr_decay_rate
        self.lr_decay_epoches_ = lr_decay_epoches
        self.lr_min_ = lr_min
        self.validate_check_epoches_ = validate_check_epoches

    def prepare_optimizer(self):
        """Create the step counter, a mutable learning rate and the Adam op.

        ``self.lr_dec`` is an assign op that scales the learning rate by
        ``lr_decay_rate_`` each time it is run.
        """
        self.global_step  = tf.Variable(0, trainable=False)
        self.lr           = tf.Variable(self.lr_base_, trainable=False)
        self.lr_dec       = tf.assign(self.lr, self.lr * tf.constant(self.lr_decay_rate_))
        self.train_step = tf.train.AdamOptimizer(self.lr).minimize(self.loss, global_step=self.global_step)

    def start_train(self, sess):
        """Reset the per-epoch loss histories before training begins."""
        super(SmartTrainer, self).start_train(sess)
        self.valid_loss_sum_all = []
        self.train_loss_sum_all = []

    def data_regenerate(self):
        """Regenerate training data and discard the loss histories.

        Losses measured on the old data are not comparable with the new set.
        """
        super(SmartTrainer, self).data_regenerate()
        self.valid_loss_sum_all = []
        self.train_loss_sum_all = []

    def start_epoch(self, sess, epoch):
        """Decay the learning rate if the training loss has plateaued.

        Returns:
            ``False`` when the base hook vetoes the epoch or the learning rate
            has fallen below ``lr_min_``; ``True`` otherwise.
        """
        if (not super(SmartTrainer, self).start_epoch(sess, epoch)):
            return False
        # Not enough history yet to judge progress.
        if (len(self.train_loss_sum_all) <= self.lr_decay_epoches_):
            return True
        # Still improving: mean of the previous N losses is at least 2% above
        # the latest loss.
        if (sum(self.train_loss_sum_all[-self.lr_decay_epoches_ - 1 : -1]) >= 1.02 * self.lr_decay_epoches_ * self.train_loss_sum_all[-1]):
            return True
        print("learning rate decay")
        sess.run(self.lr_dec)
        # Start a fresh history so the new rate gets a full window.
        self.train_loss_sum_all = []
        if (sess.run(self.lr) < self.lr_min_):
            print("stop because learning rate to small")
            return False
        return True

    def train_epoch(self, sess, epoch):
        """Train one epoch and record its summed loss in the history."""
        train_loss_sum = super(SmartTrainer, self).train_epoch(sess, epoch)
        self.train_loss_sum_all.append(train_loss_sum)
        return train_loss_sum

    def validate_epoch(self, sess, epoch):
        """Validate one epoch and record its summed loss in the history."""
        valid_loss_sum = super(SmartTrainer, self).validate_epoch(sess, epoch)
        self.valid_loss_sum_all.append(valid_loss_sum)
        return valid_loss_sum

    def end_epoch(self, sess, epoch):
        """Run the base end-of-epoch work, then decide on early stopping.

        Returns:
            ``False`` when the base hook requests a stop or the validation
            loss has grown relative to its recent average; ``True`` otherwise.
        """
        if (not super(SmartTrainer, self).end_epoch(sess, epoch)):
            return False
        # Compare per-batch averages (cross-multiplied to avoid division):
        # keep going while validation loss is within 10% of training loss.
        if (self.valid_loss_sum_all[-1] * self.train_batch_per_epoch <=
            self.train_loss_sum_all[-1] * self.valid_batch_per_epoch * 1.1):
            return True
        # Not enough history yet to judge the validation trend.
        if (len(self.valid_loss_sum_all) <= self.validate_check_epoches_):
            return True
        if (sum(self.valid_loss_sum_all[-self.validate_check_epoches_ - 1 : -1]) < 0.95 * self.validate_check_epoches_ * self.valid_loss_sum_all[-1]):
            print('average of previous loss: {} smaller than last {}'.format(sum(self.valid_loss_sum_all[-self.validate_check_epoches_ - 1 : -1]) / self.validate_check_epoches_, self.valid_loss_sum_all[-1]))
            # Stop: the average of the previous validate_check_epoches_
            # validation losses is below 95% of the latest one, i.e. the
            # validation loss grew by more than roughly 5%.
            return False
        return True



