import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
import utils
import conv_lstm

# Batch-norm hyper-parameters.  They are placed under the 'decay' and
# 'epsilon' keys of the bn_params dict that the model builders in this file
# hand to the ConvLSTM cells when batch_norm is enabled.
BN_DECAY = 0.999
# Alternative value, currently disabled: BN_DECAY = 0.9
BN_EPSILON = 0.001

# Scale given to tf.contrib.layers.l2_regularizer for the weight and bias
# regularizers created by the model builders in this file.
WEIGHT_DECAY = 0.001

def _seq2list(X, axis=1):
    """Split X along `axis` into a list of tensors with that axis removed.

    Args:
        X: tensor whose size along `axis` is known statically.
        axis: index of the axis to split on (default 1).

    Returns:
        A list with one tensor per index along `axis`; each element has the
        split axis squeezed away.

    Raises:
        ValueError: if the static shape of X has no known size along `axis`.
    """
    seq_len = X.get_shape().as_list()[axis]
    if seq_len is None:
        # tf.split needs a concrete number of pieces.  Fail here with a
        # message that names the real problem instead of letting tf.split
        # choke on a None argument.
        raise ValueError(
            "_seq2list needs a statically known size along axis %d, "
            "but the static shape is %s" % (axis, X.get_shape()))
    return [tf.squeeze(piece, [axis]) for piece in tf.split(X, seq_len, axis)]

def convLSTM_l3_cond(X, keep_prob, is_training, out_seq_len, batch_norm, peephole, dropout):
    """Build a three-layer ConvLSTM encoder/decoder whose decoder is fed its own output.

    X is split along axis 1 into a list of frames that is run through the
    "encoder" cell with tf.nn.static_rnn.  The final encoder state seeds the
    "decoder" cell, which is stepped out_seq_len times: the first step
    receives the last input frame, every later step receives the sigmoid of
    the previous step's logits.

    Args:
        X: input tensor; axis 1 is the sequence axis (presumably
            (batch, time, height, width, channels) -- TODO confirm with caller).
        keep_prob: keep probability passed to tf.nn.dropout on each decoder
            output; only used when `dropout` is truthy.
        is_training: stored under 'is_training' in the batch-norm parameters.
        out_seq_len: number of decoder steps to unroll (used with range()).
        batch_norm: when truthy the cells receive a batch-norm parameter
            dict, otherwise they receive bn_params=None.
        peephole: forwarded to both cells as has_peephole.
        dropout: when truthy, dropout is applied to every decoder output.

    Returns:
        The per-step logits, each expanded at axis 1 and concatenated along
        that axis.
    """
    bn_params = None
    if batch_norm:
        bn_params = {'decay': BN_DECAY, 'epsilon': BN_EPSILON, 'is_training': is_training}

    frames = _seq2list(X)

    # Static per-frame shape with the leading (batch) dimension dropped.
    frame_shape = frames[0].get_shape().as_list()[1:]
    num_channels = frame_shape[-1]

    # One L2 regularizer serves every weight and bias in this model.
    l2 = tf.contrib.layers.l2_regularizer(WEIGHT_DECAY)
    cell_kwargs = {
        'bn_params': bn_params,
        'has_peephole': peephole,
        'weight_regularizer': l2,
        'bias_regularizer': l2,
    }

    encoder = conv_lstm.MultiConvLSTMCell("encoder", frame_shape, [5, 5, 5], [128, 64, 64],
                                          **cell_kwargs)
    decoder = conv_lstm.MultiConvLSTMCell("decoder", frame_shape, [5, 5, 5], [128, 64, 64],
                                          **cell_kwargs)

    # Only the final encoder state is needed; the per-step outputs are discarded.
    _, state = tf.nn.static_rnn(encoder, frames, dtype=tf.float32)

    decoder_input = frames[-1]
    step_logits = []
    for _ in range(out_seq_len):
        cell_out, state = decoder(decoder_input, state)
        if dropout:
            cell_out = tf.nn.dropout(cell_out, keep_prob)
        # NOTE(review): the same name "fc_out" is passed on every step; whether
        # that shares one set of projection weights depends on utils.conv2d --
        # confirm.
        frame_logits = utils.conv2d(cell_out, "fc_out", 1, num_channels,
                                    non_linear_func=None,
                                    weight_regularizer=l2,
                                    bias_regularizer=l2)
        step_logits.append(tf.expand_dims(frame_logits, 1))
        # Feed the squashed prediction back in as the next decoder input.
        decoder_input = tf.sigmoid(frame_logits)

    return tf.concat(step_logits, 1)

def convLSTM_l3(X, keep_prob, is_training, out_seq_len, batch_norm, peephole, dropout):
    """Build a three-layer ConvLSTM encoder/decoder whose decoder gets no input.

    X is split along axis 1 into a list of frames that is run through the
    "encoder" cell with tf.nn.static_rnn.  The final encoder state seeds the
    "decoder" cell, which is then stepped out_seq_len times with None as its
    input, so each step is driven by the carried state alone.

    Args:
        X: input tensor; axis 1 is the sequence axis (presumably
            (batch, time, height, width, channels) -- TODO confirm with caller).
        keep_prob: keep probability; when `dropout` is truthy it is passed to
            both cells and to tf.nn.dropout on each decoder output.
        is_training: stored under 'is_training' in the batch-norm parameters.
        out_seq_len: number of decoder steps to unroll (used with range()).
        batch_norm: when truthy the cells receive a batch-norm parameter
            dict, otherwise they receive bn_params=None.
        peephole: forwarded to both cells as has_peephole.
        dropout: enables dropout; when falsy the cells get keep_prob=None and
            no dropout op is added to the decoder outputs.

    Returns:
        The per-step logits, each expanded at axis 1 and concatenated along
        that axis.
    """
    bn_params = None
    if batch_norm:
        bn_params = {'decay': BN_DECAY, 'epsilon': BN_EPSILON, 'is_training': is_training}

    frames = _seq2list(X)

    # Static per-frame shape with the leading (batch) dimension dropped.
    frame_shape = frames[0].get_shape().as_list()[1:]
    num_channels = frame_shape[-1]

    # One L2 regularizer serves every weight and bias in this model.
    l2 = tf.contrib.layers.l2_regularizer(WEIGHT_DECAY)
    # NOTE(review): with dropout enabled keep_prob is handed to the cells and
    # also applied again to each decoder output below -- confirm the double
    # application is intended.
    cell_kwargs = {
        'bn_params': bn_params,
        'has_peephole': peephole,
        'keep_prob': keep_prob if dropout else None,
        'weight_regularizer': l2,
        'bias_regularizer': l2,
    }

    encoder = conv_lstm.MultiConvLSTMCell("encoder", frame_shape, [5, 5, 5], [128, 64, 64],
                                          **cell_kwargs)
    decoder = conv_lstm.MultiConvLSTMCell("decoder", frame_shape, [5, 5, 5], [128, 64, 64],
                                          **cell_kwargs)

    # Only the final encoder state is needed; the per-step outputs are discarded.
    _, state = tf.nn.static_rnn(encoder, frames, dtype=tf.float32)

    step_logits = []
    for _ in range(out_seq_len):
        # The decoder is called with None as its input on every step.
        cell_out, state = decoder(None, state)
        if dropout:
            cell_out = tf.nn.dropout(cell_out, keep_prob)
        # NOTE(review): the same name "fc_out" is passed on every step; whether
        # that shares one set of projection weights depends on utils.conv2d --
        # confirm.
        frame_logits = utils.conv2d(cell_out, "fc_out", 1, num_channels,
                                    non_linear_func=None,
                                    weight_regularizer=l2,
                                    bias_regularizer=l2)
        step_logits.append(tf.expand_dims(frame_logits, 1))

    return tf.concat(step_logits, 1)


def get_model(conditional, levels, batch_norm, peephole, dropout):
    """Return a model-building callable for the requested configuration.

    Args:
        conditional: truthy selects convLSTM_l3_cond, falsy selects
            convLSTM_l3.
        levels: number of ConvLSTM levels; only 3 is supported.
        batch_norm: bound into the returned callable and forwarded to the
            builder.
        peephole: bound into the returned callable and forwarded to the
            builder.
        dropout: bound into the returned callable and forwarded to the
            builder.

    Returns:
        A callable taking (X, keep_prob, is_training, out_seq_len) that
        invokes the selected builder with the bound options appended.

    Raises:
        RuntimeError: if `levels` is not 3.
    """
    # Evaluate the flag before checking `levels`, as the branch order requires.
    use_cond = bool(conditional)

    if levels != 3:
        raise RuntimeError("Unable to make model")

    if use_cond:
        def model(X, keep_prob, is_training, out_seq_len):
            return convLSTM_l3_cond(X, keep_prob, is_training, out_seq_len,
                                    batch_norm, peephole, dropout)
    else:
        def model(X, keep_prob, is_training, out_seq_len):
            return convLSTM_l3(X, keep_prob, is_training, out_seq_len,
                               batch_norm, peephole, dropout)
    return model
