import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.framework import tensor_shape
from tensorflow.python.training import moving_averages
import numpy as np

def _batch_norm_rnn(inputs, name, rnn_step, use_beta, decay, epsilon, is_training):
    """Batch-normalize ``inputs`` for one RNN time step.

    The affine parameters (``beta``/``gamma``) are created directly under the
    scope ``name`` (opened with AUTO_REUSE), so they are shared by every call
    that uses the same ``name``. A separate pair of moving mean/variance
    variables is created for each ``rnn_step``. Statistics are computed over
    every axis except the last (channel) axis.

    Args:
        inputs: Tensor to normalize; its last dimension is the channel axis.
        name: Variable scope for the batch-norm variables.
        rnn_step: Index used to name this step's moving-statistics variables.
        use_beta: If True, learn a per-channel shift ``beta``; otherwise a
            constant zero shift is used.
        decay: Decay of the moving averages.
        epsilon: Small value added to the variance for numerical stability.
        is_training: Scalar boolean tensor, or a Python bool. When true, batch
            statistics are used and the moving averages are updated;
            otherwise the moving statistics of ``rnn_step`` are used.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    axis = list(range(len(inputs.get_shape()) - 1))
    out_chns = inputs.get_shape()[-1]

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        if use_beta:
            beta = tf.get_variable('beta', [out_chns], initializer=tf.zeros_initializer())
        else:
            beta = tf.constant(0.0, shape=[out_chns], name='beta')

        # The scale starts at 0.1 rather than 1 (presumably following
        # recurrent batch-norm practice -- confirm before changing).
        gamma = tf.get_variable('gamma', [out_chns], initializer=tf.constant_initializer(0.1))
        moving_mean = tf.get_variable('moving_mean_step{}'.format(rnn_step), [out_chns],
                                      initializer=tf.zeros_initializer(),
                                      trainable=False)
        moving_variance = tf.get_variable('moving_variance_step{}'.format(rnn_step), [out_chns],
                                          initializer=tf.ones_initializer(),
                                          trainable=False)

        def batch_norm_train():
            mean, variance = tf.nn.moments(inputs, axis)
            # zero_debias=False: plain exponential moving average.
            update_moving_mean = moving_averages.assign_moving_average(moving_mean, mean, decay, False)
            update_moving_variance = moving_averages.assign_moving_average(moving_variance, variance, decay, False)
            # Make the normalized output depend on the updates so they run
            # whenever the output is evaluated.
            with tf.control_dependencies([update_moving_mean, update_moving_variance]):
                return tf.nn.batch_normalization(inputs, mean, variance, beta, gamma, epsilon)

        def batch_norm_none_train():
            return tf.nn.batch_normalization(inputs, moving_mean, moving_variance, beta, gamma, epsilon)

        # TF 1.x tf.cond rejects a Python bool predicate, so a static flag
        # picks the branch directly at graph-construction time.
        if isinstance(is_training, (bool, np.bool_)):
            return batch_norm_train() if is_training else batch_norm_none_train()
        return tf.cond(is_training, batch_norm_train, batch_norm_none_train)

class ConvLSTMCell(rnn_cell_impl.RNNCell):
    """Convolutional LSTM cell with optional peepholes and recurrent batch norm.

    The four gates come from one convolution of the hidden state (``Wh``),
    plus one convolution of the inputs (``Wx``) when inputs are given, plus a
    bias. The result is split along the channel axis into
    (input gate, candidate, forget gate, output gate).

    Args:
        name: Variable scope under which all cell variables are created.
        input_shape: Per-example input shape with channels last; presumably
            excludes the batch dimension (e.g. ``[H, W, C]``) -- confirm.
            Spatial dimensions may be unknown (``None``).
        kernel_sz: Side length of the square convolution kernels.
        out_chns: Number of channels of the output and of both state tensors.
        strides: Spatial stride of the input convolution. The hidden-state
            convolution always uses stride 1.
        dilations: Dilation rate of both the input and the hidden-state
            convolutions. tf.nn.convolution does not support strides > 1
            combined with dilations > 1.
        bn_params: None to disable batch norm, or a dict providing ``decay``,
            ``epsilon`` and ``is_training`` for ``_batch_norm_rnn``.
        weight_initializer, bias_initializer, weight_regularizer,
        bias_regularizer: Passed to ``tf.get_variable`` for kernels and
            peepholes (weights) and for the gate biases.
        has_peephole: Add elementwise peephole weights ``wci``, ``wcf``,
            ``wco``.
    """
    def __init__(self,
                 name,
                 input_shape,
                 kernel_sz,
                 out_chns,
                 strides = 1,
                 dilations = 1,
                 bn_params = None,
                 weight_initializer = tf.contrib.layers.xavier_initializer_conv2d(),
                 bias_initializer = tf.zeros_initializer,
                 weight_regularizer = None,
                 bias_regularizer = None,
                 has_peephole = True):
        super(ConvLSTMCell, self).__init__(name=name)
        self.name_                  = name
        self.out_chns_              = out_chns
        self.kernel_sz_             = kernel_sz
        self.strides_               = strides
        self.dilations_             = dilations
        self.bn_params_             = bn_params
        self.weight_initializer_    = weight_initializer
        self.bias_initializer_      = bias_initializer
        self.weight_regularizer_    = weight_regularizer
        self.bias_regularizer_      = bias_regularizer
        self.has_peephole_          = has_peephole

        # Spatial size after the 'SAME'-padded, strided input convolution.
        self.output_size_ = tensor_shape.TensorShape(
            [self._same_padding_dim(dim, self.strides_) for dim in input_shape[:-1]] + [out_chns])

        self.state_size_ = rnn_cell_impl.LSTMStateTuple(self.output_size_, self.output_size_)

        # Passed to _batch_norm_rnn as rnn_step. Each call gets its own moving
        # statistics (one call per time step when statically unrolled --
        # presumably).
        self.call_idx_ = 0

    @staticmethod
    def _same_padding_dim(dim, stride):
        """Return the output length of a 'SAME'-padded conv: ceil(dim / stride).

        ``dim`` may be an int, None or a Dimension; unknown stays unknown.
        """
        dim = tensor_shape.as_dimension(dim)
        return (dim + (stride - 1)) // stride

    @property
    def state_size(self):
        return self.state_size_

    @property
    def output_size(self):
        return self.output_size_

    def __call__(self, inputs, state, scope=None):
        """Run one step of the cell.

        Args:
            inputs: Input tensor, or None to use only the hidden-state
                convolution.
            state: ``(cell, hidden)`` pair, e.g. from ``zero_state`` or a
                previous call.
            scope: Unused; variables live under the scope given by ``name``.

        Returns:
            ``(output, new_state)`` where ``new_state`` is
            ``LSTMStateTuple(c=new_cell, h=output)``. The stored ``new_cell``
            is the value before batch normalization.
        """
        cell, hidden = state
        if self.has_peephole_:
            # Peephole weights are elementwise over the cell tensor without
            # its leading axis (dims 1..3).
            peephole_shape = cell.get_shape().as_list()[1:4]

        with tf.variable_scope(self.name_, reuse=tf.AUTO_REUSE):
            hid_chns = hidden.get_shape()[-1]
            Wh = tf.get_variable('Wh',
                                 shape=[self.kernel_sz_, self.kernel_sz_, hid_chns, 4 * self.out_chns_],
                                 initializer = self.weight_initializer_,
                                 regularizer = self.weight_regularizer_)
            # The recurrent convolution always has stride 1 so the state keeps
            # its spatial size. It uses the same dilation as the input path.
            if self.dilations_ != 1:
                conv_h = tf.nn.convolution(hidden, Wh, 'SAME',
                                           dilation_rate=[self.dilations_, self.dilations_])
            else:
                conv_h = tf.nn.conv2d(hidden, Wh, [1, 1, 1, 1], padding='SAME')

            if self.bn_params_ is not None:
                conv_h = _batch_norm_rnn(conv_h, "WhBN", self.call_idx_, False, **self.bn_params_)
            conv = conv_h

            if inputs is not None:
                inp_chns = inputs.get_shape()[-1]
                Wx = tf.get_variable('Wx',
                                     shape=[self.kernel_sz_, self.kernel_sz_, inp_chns, 4 * self.out_chns_],
                                     initializer = self.weight_initializer_,
                                     regularizer = self.weight_regularizer_)
                if self.dilations_ != 1:
                    conv_x = tf.nn.convolution(inputs, Wx, 'SAME',
                                               strides=[self.strides_, self.strides_],
                                               dilation_rate=[self.dilations_, self.dilations_])
                else:
                    conv_x = tf.nn.conv2d(inputs, Wx, [1, self.strides_, self.strides_, 1], padding='SAME')
                if self.bn_params_ is not None:
                    conv_x = _batch_norm_rnn(conv_x, "WxBN", self.call_idx_, False, **self.bn_params_)
                conv = conv_h + conv_x

            biases = tf.get_variable('biases', [4 * self.out_chns_],
                                     initializer = self.bias_initializer_,
                                     regularizer = self.bias_regularizer_)
            conv = tf.nn.bias_add(conv, biases)
            input_gate, new_input, forget_gate, output_gate = tf.split(conv, 4, -1)

            if self.has_peephole_:
                Wci = tf.get_variable("wci", peephole_shape,
                                      initializer = tf.zeros_initializer(),
                                      regularizer = self.weight_regularizer_)
                input_gate = input_gate + cell * Wci
                Wcf = tf.get_variable("wcf", peephole_shape,
                                      initializer = tf.zeros_initializer(),
                                      regularizer = self.weight_regularizer_)
                forget_gate = forget_gate + cell * Wcf

            new_cell = tf.sigmoid(forget_gate) * cell + tf.sigmoid(input_gate) * tf.tanh(new_input)
            if self.has_peephole_:
                # The output-gate peephole looks at the *updated* cell.
                Wco = tf.get_variable("wco", peephole_shape,
                                      initializer = tf.zeros_initializer(),
                                      regularizer = self.weight_regularizer_)
                output_gate = output_gate + new_cell * Wco
            new_cell_bn = new_cell
            if self.bn_params_ is not None:
                new_cell_bn = _batch_norm_rnn(new_cell, "CellBN", self.call_idx_, True, **self.bn_params_)
            output = tf.tanh(new_cell_bn) * tf.sigmoid(output_gate)

        self.call_idx_ += 1
        # Return an LSTMStateTuple to match state_size and zero_state. Nest
        # structure checks and DropoutWrapper's default state filter (which
        # leaves the memory cell ``c`` undropped only for LSTMStateTuple) then
        # see the same state type on every step.
        new_state = rnn_cell_impl.LSTMStateTuple(new_cell, output)
        return output, new_state

class MultiConvLSTMCell(rnn_cell_impl.RNNCell):
    """Stack of ConvLSTMCell layers, optionally with variational dropout.

    Layer ``i`` is a ConvLSTMCell named ``"cell{i}"`` whose input shape is the
    previous layer's output shape. If any keep probability is given, every
    layer is wrapped in ``tf.contrib.rnn.DropoutWrapper`` with
    ``variational_recurrent=True``. Unspecified keep probabilities default to
    1.0.

    Args:
        name: Variable scope for the stack.
        input_shape: Per-example input shape of the first layer, channels last
            (presumably excluding the batch dimension -- confirm).
        kernel_sz: Sequence with one kernel size per layer.
        out_chns: Sequence with one output channel count per layer, the same
            length as ``kernel_sz``.
        strides, dilations: An int applied to every layer, or one value per
            layer.
        bn_params, weight_initializer, bias_initializer, weight_regularizer,
        bias_regularizer, has_peephole: Forwarded to every ConvLSTMCell.
        input_keep_prob, state_keep_prob, output_keep_prob: Optional dropout
            keep probabilities.
        concat_all_cells_output: If True, the step output is the channel-wise
            concatenation of every layer's output, which requires all layers
            to share the same spatial size. Otherwise only the last layer's
            output is returned.

    Raises:
        ValueError: If the per-layer arguments have mismatched lengths.
    """
    def __init__(self,
                 name,
                 input_shape,
                 kernel_sz,
                 out_chns,
                 strides = 1,
                 dilations = 1,
                 bn_params = None,
                 input_keep_prob = None,
                 state_keep_prob = None,
                 output_keep_prob = None,
                 weight_initializer = tf.contrib.layers.xavier_initializer_conv2d(),
                 bias_initializer = tf.zeros_initializer,
                 weight_regularizer = None,
                 bias_regularizer = None,
                 has_peephole = True,
                 concat_all_cells_output = True):
        # Explicit checks (not asserts): under `python -O` asserts vanish and
        # zip() below would silently drop layers on a length mismatch.
        num_layers = len(kernel_sz)
        if len(out_chns) != num_layers:
            raise ValueError('out_chns must have one entry per kernel_sz entry')
        if isinstance(strides, int):
            strides = num_layers * [strides]
        elif len(strides) != num_layers:
            raise ValueError('strides must be an int or have one entry per kernel_sz entry')
        if isinstance(dilations, int):
            dilations = num_layers * [dilations]
        elif len(dilations) != num_layers:
            raise ValueError('dilations must be an int or have one entry per kernel_sz entry')

        super(MultiConvLSTMCell, self).__init__(name=name)
        self.name_  = name
        self.concat_all_cells_output_ = concat_all_cells_output
        self.cells_ = []
        use_dropout = any(prob is not None
                          for prob in (input_keep_prob, state_keep_prob, output_keep_prob))
        with tf.variable_scope(self.name_, reuse=tf.AUTO_REUSE):
            for i, (kernel, out_ch, stride, dilation) in enumerate(zip(kernel_sz, out_chns, strides, dilations)):
                cell_input_shape = input_shape
                cell = ConvLSTMCell("cell{}".format(i), input_shape, kernel, out_ch, stride, dilation,
                                    bn_params,
                                    weight_initializer, bias_initializer,
                                    weight_regularizer, bias_regularizer,
                                    has_peephole)
                input_shape = cell.output_size

                if use_dropout:
                    # With variational_recurrent=True, input dropout needs the
                    # per-example input shape to build its mask. It is passed
                    # as a TensorShape (not a list) so nest treats it as one
                    # leaf.
                    input_size = None
                    if input_keep_prob is not None:
                        input_size = tensor_shape.TensorShape(cell_input_shape)
                    cell = tf.contrib.rnn.DropoutWrapper(
                        cell,
                        input_keep_prob = input_keep_prob if (input_keep_prob is not None) else 1.0,
                        state_keep_prob = state_keep_prob if (state_keep_prob is not None) else 1.0,
                        output_keep_prob = output_keep_prob if (output_keep_prob is not None) else 1.0,
                        variational_recurrent = True,
                        input_size = input_size,
                        dtype = tf.float32)
                self.cells_.append(cell)

    @property
    def state_size(self):
        return [cell.state_size for cell in self.cells_]

    @property
    def output_size(self):
        """Shape of one step's output, matching what ``__call__`` returns.

        With ``concat_all_cells_output`` the channel count is the sum of all
        layers' channels; the spatial dims are those of the last layer.
        """
        last_size = self.cells_[-1].output_size
        if not self.concat_all_cells_output_:
            return last_size
        total_chns = tensor_shape.Dimension(0)
        for cell in self.cells_:
            total_chns += cell.output_size[-1]
        return last_size[:-1].concatenate([total_chns])

    def __call__(self, inputs, state, scope=None):
        """Run every layer for one step.

        Args:
            inputs: Input to the first layer.
            state: Indexable per-layer states, laid out like ``state_size``.
            scope: Unused.

        Returns:
            ``(output, new_states)``. ``output`` is the last layer's output,
            or the channel-wise concatenation of all layer outputs when
            ``concat_all_cells_output`` is set. ``new_states`` is a list of
            the per-layer new states.
        """
        new_states = []
        outputs = []
        cur_inp = inputs
        with tf.variable_scope(self.name_, reuse=tf.AUTO_REUSE):
            for i, cell in enumerate(self.cells_):
                cur_inp, cell_state = cell(cur_inp, state[i])
                if self.concat_all_cells_output_:
                    outputs.append(cur_inp)
                new_states.append(cell_state)
        if self.concat_all_cells_output_:
            return tf.concat(outputs, -1), new_states
        return cur_inp, new_states
