import numbers

import tensorflow as tf
from tensorflow.python.training import moving_averages

def _two_element_tuple(int_or_tuple):
    """Converts `int_or_tuple` to a (height, width) pair.

    Several of the functions that follow accept arguments either as a pair
    of integers or as a single integer. A single integer means both values
    are the same. This function normalizes the input by always returning
    a tuple.

    Args:
        int_or_tuple: A list/tuple of 2 ints, a single integer (any
            `numbers.Integral`, e.g. a Python int or a numpy integer scalar),
            or a tf.TensorShape of length 2.
    Returns:
        A tuple with 2 values. For list/tuple/integer input both values are
        Python ints. For a TensorShape the two dimension entries are returned
        as-is.
    Raises:
        ValueError: If `int_or_tuple` is not well formed.
    """
    if isinstance(int_or_tuple, (list, tuple)):
        if len(int_or_tuple) != 2:
            # Wrap in a 1-tuple. Formatting a tuple directly with `%` would
            # spread its elements over the format string and raise TypeError
            # instead of the intended ValueError.
            raise ValueError('Must be a list with 2 elements: %s' % (int_or_tuple,))
        return int(int_or_tuple[0]), int(int_or_tuple[1])
    if isinstance(int_or_tuple, numbers.Integral):
        return int(int_or_tuple), int(int_or_tuple)
    if isinstance(int_or_tuple, tf.TensorShape):
        if len(int_or_tuple) == 2:
            return int_or_tuple[0], int_or_tuple[1]
    raise ValueError('Must be an int, a list with 2 elements or a TensorShape of length 2')

def _add_loss_regulizer(var, coeff):
    """Register an L2 penalty `coeff * tf.nn.l2_loss(var)` for `var`.

    The penalty tensor is appended to the
    tf.GraphKeys.REGULARIZATION_LOSSES collection; summing that collection
    into the training loss is left to the caller.

    Args:
        var: Variable/tensor to penalize.
        coeff: Penalty weight. When it is 0 (or None) nothing is added.
    """
    if coeff is None or coeff == 0:
        return
    loss = tf.multiply(tf.nn.l2_loss(var), coeff)
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES, loss)

def concat(values, axis, name="concat"):
    """Concatenate `values` along `axis`, portably across TF versions.

    TF >= 1.0 provides tf.concat(values, axis). Older releases are routed to
    tf.concat_v2, which takes arguments in the same order.

    Args:
        values: List of tensors to concatenate.
        axis: Dimension along which to concatenate.
        name: Name for the resulting op.
    Returns:
        The concatenated tensor.
    """
    major_version = int(tf.__version__.split('.')[0])
    # Only the selected branch is evaluated, so tf.concat_v2 is never looked
    # up on TF versions that no longer provide it.
    concat_fn = tf.concat if major_version >= 1 else tf.concat_v2
    return concat_fn(values, axis, name=name)

def prelu(inputs, name, shared_axes = None, reuse = True) :
    """Parametric ReLU: f(x) = max(0, x) + alpha * min(0, x).

    `alpha` is a variable named "alpha" in variable scope `name`. It is
    zero-initialized and has the shape of `inputs` without its first
    dimension, so no parameters are created along axis 0.

    Args:
        inputs: Input tensor. Its first dimension gets no alpha parameters
            (presumably the batch axis).
        name: Variable scope holding the "alpha" variable.
        shared_axes: Optional iterable of axes, counted including axis 0 (so
            1 is the first non-first axis), along which alpha has size 1 and
            is broadcast, i.e. shared.
        reuse: Passed to tf.variable_scope. With the default True the
            variable is presumably expected to exist already.
    Returns:
        Tensor with the same shape as `inputs`.
    """
    param_shape = list(inputs.get_shape()[1:])
    if shared_axes is not None:
        for axis in shared_axes:
            param_shape[axis - 1] = 1
    with tf.variable_scope(name, reuse=reuse):
        alpha = tf.get_variable("alpha",
                                shape=param_shape,
                                initializer=tf.zeros_initializer)
    # Written as relu(x) + alpha * min(x, 0) rather than max(alpha * x, x):
    # the max form equals PReLU only while 0 <= alpha <= 1, and alpha is a
    # learnable, unconstrained variable. Both forms agree inside [0, 1].
    return tf.nn.relu(inputs) + alpha * tf.minimum(inputs, 0.0)

def batch_norm( inputs,
                decay = 0.999,
                epsilon = 0.001,
                train = False):
    """Batch-normalize `inputs` over every axis except the last one.

    This function opens no variable scope of its own. Its variables are
    created/looked up with tf.get_variable in whatever scope the caller has
    open, each with one entry per element of the last dimension:
        beta            -- offset, zeros, trainable
        gamma           -- scale, ones, trainable
        moving_mean     -- zeros, not trainable
        moving_variance -- ones, not trainable

    Args:
        inputs: Tensor to normalize. Statistics are reduced over all axes
            except the last.
        decay: Decay handed to moving_averages.assign_moving_average.
        epsilon: Variance epsilon for tf.nn.batch_normalization.
        train: If true, normalize with the batch statistics and attach the
            moving-average updates as control dependencies of the output.
            Otherwise normalize with the stored moving statistics.
    Returns:
        The normalized tensor (same shape as `inputs`).
    """
    in_shape = inputs.get_shape()
    reduce_axes = [ax for ax in range(len(in_shape) - 1)]
    n_chns = in_shape[-1]

    beta = tf.get_variable('beta', [n_chns], initializer=tf.zeros_initializer)
    gamma = tf.get_variable('gamma', [n_chns], initializer=tf.ones_initializer())

    pop_mean = tf.get_variable('moving_mean', [n_chns],
                               initializer=tf.zeros_initializer, trainable=False)
    pop_var = tf.get_variable('moving_variance', [n_chns],
                              initializer=tf.ones_initializer(), trainable=False)

    if not train:
        # Inference: use the stored statistics and schedule no updates.
        stat_mean, stat_var = pop_mean, pop_var
        updates = []
    else:
        stat_mean, stat_var = tf.nn.moments(inputs, reduce_axes)
        # The output depends on these ops, so evaluating it also runs the
        # moving-average updates.
        updates = [
            moving_averages.assign_moving_average(pop_mean, stat_mean, decay),
            moving_averages.assign_moving_average(pop_var, stat_var, decay),
        ]

    with tf.control_dependencies(updates):
        return tf.nn.batch_normalization(inputs, stat_mean, stat_var,
                                         beta, gamma, epsilon)

def conv2d( inputs,
            name,
            kernel_sz,
            out_chns,
            strides = 1,
            dilation_rate = 1,
            bn_params = None,
            weight_initializer = tf.contrib.layers.xavier_initializer_conv2d(),
            bias_initializer = tf.zeros_initializer,
            loss_reg_weight = 0.0,
            loss_reg_biases = 0.0,
            padding = 'SAME',
            non_linear_func = tf.nn.relu,
            use_bias = True,
            train = False,
            reuse = True):
    """2-D convolution layer with optional bias, batch norm and activation.

    Variables live in variable scope `name`: 'weights' with shape
    [kernel_h, kernel_w, in_channels, out_chns] and, if `use_bias`, 'biases'
    with shape [out_chns]. The stride list [1, stride_h, stride_w, 1]
    presumes the default NHWC layout for `inputs`.

    Args:
        inputs: Input tensor. Its last dimension is the input channel count.
        name: Variable scope name; also used as the activation op name.
        kernel_sz: int or pair -> (kernel_h, kernel_w).
        out_chns: Number of output channels.
        strides: int or pair -> (stride_h, stride_w).
        dilation_rate: int or pair. Any value other than 1 along either axis
            routes through tf.nn.convolution.
        bn_params: Optional dict of extra keyword arguments for batch_norm.
            'train' is always set from `train`. The caller's dict is not
            modified.
        weight_initializer: Kernel initializer. The default instance is
            created once, at import time.
        bias_initializer: Bias initializer.
        loss_reg_weight: L2 coefficient for the kernel (applied only when
            `train`).
        loss_reg_biases: L2 coefficient for the biases (applied only when
            `train`).
        padding: 'SAME' or 'VALID'.
        non_linear_func: Activation called as f(x, name=name), or None.
        use_bias: Whether to create and add the 'biases' variable.
        train: Build the training variant (regularizers, batch statistics).
        reuse: Passed to tf.variable_scope.
    Returns:
        The output tensor.
    """
    inp_chns = inputs.get_shape()[-1]
    kernel_h, kernel_w = _two_element_tuple(kernel_sz)
    stride_h, stride_w = _two_element_tuple(strides)
    dilation_h, dilation_w = _two_element_tuple(dilation_rate)

    with tf.variable_scope(name, reuse=reuse):
        kernel = tf.get_variable('weights',
                                 shape=[kernel_h, kernel_w, inp_chns, out_chns],
                                 initializer=weight_initializer)
        if train:
            _add_loss_regulizer(kernel, loss_reg_weight)

        # Compare the normalized pair so that (1, 1) / [1, 1] are also
        # recognized as "no dilation".
        if dilation_h == 1 and dilation_w == 1:
            conv = tf.nn.conv2d(inputs, kernel, [1, stride_h, stride_w, 1],
                                padding=padding)
        else:
            conv = tf.nn.convolution(inputs, kernel, padding,
                                     strides=[stride_h, stride_w],
                                     dilation_rate=[dilation_h, dilation_w])

        if use_bias:
            biases = tf.get_variable('biases', [out_chns], initializer=bias_initializer)
            if train:
                _add_loss_regulizer(biases, loss_reg_biases)
            conv = tf.nn.bias_add(conv, biases)

        if bn_params is not None:
            # Build a fresh kwargs dict instead of writing 'train' into the
            # caller's (possibly shared) dictionary.
            conv = batch_norm(conv, **dict(bn_params, train=train))

        if non_linear_func is not None:
            conv = non_linear_func(conv, name=name)
    return conv

def const_conv2d( inputs,
                  name,
                  strides = 1,
                  padding = 'SAME',
                  non_linear_func = tf.nn.relu,
                  use_bias = True):
    """Apply an already-built conv layer by looking its tensors up by name.

    Reads '<name>/weights:0' and, only when `use_bias` is true,
    '<name>/biases:0' from the default graph. These names match what
    conv2d(..., name=name) creates when called outside any enclosing
    variable scope (assumption: verify for nested scopes). No new variables
    are created.

    Args:
        inputs: Input tensor. The stride list [1, sh, sw, 1] presumes NHWC.
        name: Name/scope prefix of the existing layer.
        strides: int or pair -> (stride_h, stride_w).
        padding: 'SAME' or 'VALID'.
        non_linear_func: Activation called as f(x, name=name), or None.
        use_bias: Whether to look up and add the layer's biases.
    Returns:
        The output tensor.
    """
    stride_h, stride_w = _two_element_tuple(strides)

    graph = tf.get_default_graph()
    kernel = graph.get_tensor_by_name(name + '/weights:0')
    conv = tf.nn.conv2d(inputs, kernel, [1, stride_h, stride_w, 1], padding=padding)
    if use_bias:
        # Only fetched when needed, so bias-free layers (no '<name>/biases')
        # can be re-applied too.
        biases = graph.get_tensor_by_name(name + '/biases:0')
        conv = tf.nn.bias_add(conv, biases)
    if non_linear_func is not None:
        conv = non_linear_func(conv, name=name)
    return conv

def max_pool( inputs,
              name,
              kernel_sz = 2,
              strides = 2,
              padding='SAME') :
    """2-D max pooling.

    The window and stride lists are built as [1, h, w, 1], so `inputs` is
    presumably laid out NHWC (tf.nn.max_pool's default data format).

    Args:
        inputs: Tensor to pool.
        name: Name of the pooling op.
        kernel_sz: int or pair -> pooling window (height, width).
        strides: int or pair -> stride (height, width).
        padding: 'SAME' or 'VALID'.
    Returns:
        The pooled tensor.
    """
    win_h, win_w = _two_element_tuple(kernel_sz)
    step_h, step_w = _two_element_tuple(strides)
    window = [1, win_h, win_w, 1]
    step = [1, step_h, step_w, 1]
    return tf.nn.max_pool(inputs, ksize=window, strides=step,
                          padding=padding, name=name)

def avg_pool( inputs,
              name,
              kernel_sz = 2,
              strides = 2,
              padding='SAME') :
    """2-D average pooling.

    The window and stride lists are built as [1, h, w, 1], so `inputs` is
    presumably laid out NHWC (tf.nn.avg_pool's default data format).

    Args:
        inputs: Tensor to pool.
        name: Name of the pooling op.
        kernel_sz: int or pair -> pooling window (height, width).
        strides: int or pair -> stride (height, width).
        padding: 'SAME' or 'VALID'.
    Returns:
        The pooled tensor.
    """
    win_h, win_w = _two_element_tuple(kernel_sz)
    step_h, step_w = _two_element_tuple(strides)
    window = [1, win_h, win_w, 1]
    step = [1, step_h, step_w, 1]
    return tf.nn.avg_pool(inputs, ksize=window, strides=step,
                          padding=padding, name=name)

def flatten(inputs):
    """Collapse every dimension after the first into a single one.

    Args:
        inputs: Tensor of rank >= 2.
    Returns:
        Tensor of shape [-1, k], where k is the product of all dimensions
        after the first. k is a Python int when those dimensions are
        statically known, and is computed at run time otherwise.
    Raises:
        ValueError: If `inputs` has fewer than 2 dimensions.
    """
    shape = inputs.get_shape()
    if len(shape) < 2:
        raise ValueError('Inputs must be have a least 2 dimensions')
    k = shape[1:].num_elements()
    if k is None:
        # Some trailing dimension is unknown at graph-build time.
        # num_elements() returns None then, and [-1, None] is not a valid
        # shape, so compute the product from the dynamic shape instead.
        k = tf.reduce_prod(tf.shape(inputs)[1:])
    return tf.reshape(inputs, [-1, k])

def fc( inputs,
        name,
        out_chns,
        bn_params = None,
        weight_initializer = tf.contrib.layers.xavier_initializer(),
        bias_initializer = tf.zeros_initializer,
        loss_reg_weight = 0.0,
        loss_reg_biases = 0.0,
        padding = 'SAME',
        non_linear_func = tf.nn.relu,
        dropout_keep_prob = 0.5,
        train = False,
        reuse = True):
    """Fully connected layer with optional batch norm, activation and dropout.

    Variables live in variable scope `name`: 'weights' with shape
    [inputs.shape[1], out_chns] and 'biases' with shape [out_chns]. The
    tf.matmul with a 2-D weight matrix means `inputs` is expected to be 2-D.
    Use flatten() first otherwise.

    Args:
        inputs: 2-D input tensor.
        name: Variable scope name; also used as the activation op name.
        out_chns: Number of output units.
        bn_params: Optional dict of extra keyword arguments for batch_norm.
            'train' is always set from `train`. The caller's dict is not
            modified.
        weight_initializer: Weight initializer. The default instance is
            created once, at import time.
        bias_initializer: Bias initializer.
        loss_reg_weight: L2 coefficient for the weights (applied only when
            `train`).
        loss_reg_biases: L2 coefficient for the biases (applied only when
            `train`).
        padding: Unused. Kept so the signature mirrors conv2d.
        non_linear_func: Activation called as f(x, name=name), or None.
        dropout_keep_prob: Keep probability for tf.nn.dropout. Dropout is
            applied only when `train` is true and this value is > 0.
        train: Build the training variant (regularizers, batch statistics,
            dropout).
        reuse: Passed to tf.variable_scope.
    Returns:
        The output tensor of shape [batch, out_chns].
    """
    with tf.variable_scope(name, reuse=reuse):
        weights = tf.get_variable('weights', shape=[inputs.get_shape()[1], out_chns],
                                  initializer=weight_initializer)
        biases = tf.get_variable('biases', shape=[out_chns],
                                 initializer=bias_initializer)

        if train:
            _add_loss_regulizer(weights, loss_reg_weight)
            _add_loss_regulizer(biases, loss_reg_biases)

        # Named `out` (not `fc`) so it does not shadow this function's name.
        out = tf.matmul(inputs, weights) + biases

        if bn_params is not None:
            # Build a fresh kwargs dict instead of writing 'train' into the
            # caller's (possibly shared) dictionary.
            out = batch_norm(out, **dict(bn_params, train=train))

        if non_linear_func is not None:
            out = non_linear_func(out, name=name)

    if train and 0 < dropout_keep_prob:
        out = tf.nn.dropout(out, dropout_keep_prob)
    return out

