import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.python.training import moving_averages
import numpy as np

# Names of the graph collections in which `batch_norm` and `batch_norm_rnn`
# register their non-trainable moving-mean / moving-variance variables
# (in addition to tf.GraphKeys.GLOBAL_VARIABLES).
BN_MOVING_MEAN_COLLECTION = "BN_MOVING_MEAN_COLLECTION"
BN_MOVING_VAR_COLLECTION = "BN_MOVING_VAR_COLLECTION"


def _two_element_tuple(int_or_tuple):
    """Converts `int_or_tuple` to height, width.
    Several of the functions that follow accept arguments as either
    a tuple of 2 integers or a single integer.  A single integer
    indicates that the 2 values of the tuple are the same.
    This functions normalizes the input value by always returning a tuple.
    Args:
        int_or_tuple: A list of 2 ints, a single int or a tf.TensorShape.
    Returns:
        A tuple with 2 values.
    Raises:
        ValueError: If `int_or_tuple` it not well formed.
    """
    if isinstance(int_or_tuple, (list, tuple)):
        if len(int_or_tuple) != 2:
            raise ValueError('Must be a list with 2 elements: %s' % int_or_tuple)
        return int(int_or_tuple[0]), int(int_or_tuple[1])
    if isinstance(int_or_tuple, int):
        return int(int_or_tuple), int(int_or_tuple)
    if isinstance(int_or_tuple, tf.TensorShape):
        if len(int_or_tuple) == 2:
            return int_or_tuple[0], int_or_tuple[1]
    raise ValueError('Must be an int, a list with 2 elements or a TensorShape of length 2')


def unison_shuffle(a, b):
    """Shuffle two equally long arrays with one shared random permutation.

    Args:
        a: First array; must support indexing with an integer index array.
        b: Second array with the same length as `a`.

    Returns:
        A tuple `(a_shuffled, b_shuffled)` in which element `i` of both
        results comes from the same original position.

    Raises:
        AssertionError: If `a` and `b` differ in length.
    """
    count = len(a)
    assert count == len(b)
    order = np.arange(count)
    # In-place shuffle of the index array, driven by NumPy's global RNG.
    np.random.shuffle(order)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b


def batch_norm(inputs, name, decay, epsilon, is_training):
    """Batch-normalize `inputs` over every axis except the last one.

    Creates (or reuses, via tf.AUTO_REUSE) the variables `beta`, `gamma`,
    `moving_mean` and `moving_variance` under the variable scope `name`.
    The moving statistics are non-trainable and are also registered in
    BN_MOVING_MEAN_COLLECTION / BN_MOVING_VAR_COLLECTION.

    Args:
        inputs: Tensor to normalize; its last static dimension sizes all
            of the variables.
        name: Variable scope name for this layer.
        decay: Decay handed to `assign_moving_average` for both moving stats.
        epsilon: Variance epsilon handed to `tf.nn.batch_normalization`.
        is_training: Predicate for `tf.cond`. When true, batch statistics
            are used and the moving averages are updated; otherwise the
            stored moving statistics are used.

    Returns:
        The result of the `tf.cond` selecting between the two normalizations.
    """
    static_shape = inputs.get_shape()
    reduction_axes = list(range(len(static_shape) - 1))
    num_channels = static_shape[-1]
    param_shape = [num_channels]

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        beta = tf.get_variable('beta', param_shape, initializer=tf.zeros_initializer)
        gamma = tf.get_variable('gamma', param_shape, initializer=tf.ones_initializer())

        moving_mean = tf.get_variable(
            'moving_mean', param_shape,
            initializer=tf.zeros_initializer,
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, BN_MOVING_MEAN_COLLECTION])
        moving_variance = tf.get_variable(
            'moving_variance', param_shape,
            initializer=tf.ones_initializer(),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, BN_MOVING_VAR_COLLECTION])

        def normalize_with_batch_stats():
            batch_mean, batch_variance = tf.nn.moments(inputs, reduction_axes)
            # Moving averages are updated without zero-debiasing.
            mean_update = moving_averages.assign_moving_average(
                moving_mean, batch_mean, decay, zero_debias=False)
            variance_update = moving_averages.assign_moving_average(
                moving_variance, batch_variance, decay, zero_debias=False)
            # Tie the updates to the output so they run whenever the
            # training branch is evaluated.
            with tf.control_dependencies([mean_update, variance_update]):
                return tf.nn.batch_normalization(
                    inputs, batch_mean, batch_variance, beta, gamma, epsilon)

        def normalize_with_moving_stats():
            return tf.nn.batch_normalization(
                inputs, moving_mean, moving_variance, beta, gamma, epsilon)

        return tf.cond(is_training, normalize_with_batch_stats, normalize_with_moving_stats)


def batch_norm_rnn(inputs, name, rnn_step, use_beta, decay, epsilon, is_training):
    """Batch normalization with separate moving statistics per RNN step.

    `gamma` (initialized to 0.1) and the optional `beta` are shared under
    the variable scope `name`, while the moving mean and variance variables
    carry `rnn_step` in their names, so each step gets its own pair.

    Args:
        inputs: Tensor to normalize over all axes except the last; its last
            static dimension sizes all of the variables.
        name: Variable scope name (variables are reused via tf.AUTO_REUSE).
        rnn_step: Value formatted into the moving-statistics variable names.
        use_beta: If truthy, `beta` is a variable initialized to zeros;
            otherwise a constant zero offset is used.
        decay: Decay handed to `assign_moving_average` for both moving stats.
        epsilon: Variance epsilon handed to `tf.nn.batch_normalization`.
        is_training: Predicate for `tf.cond`. When true, batch statistics
            are used and the moving averages are updated; otherwise the
            stored moving statistics are used.

    Returns:
        The result of the `tf.cond` selecting between the two normalizations.
    """
    static_shape = inputs.get_shape()
    reduction_axes = list(range(len(static_shape) - 1))
    num_channels = static_shape[-1]
    param_shape = [num_channels]

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        if use_beta:
            beta = tf.get_variable('beta', param_shape, initializer=tf.zeros_initializer)
        else:
            beta = tf.constant(0.0, shape=param_shape, name='beta')
        # The scale starts at 0.1 rather than at one.
        gamma = tf.get_variable('gamma', param_shape, initializer=tf.constant_initializer(0.1))

        moving_mean = tf.get_variable(
            'moving_mean_step{}'.format(rnn_step), param_shape,
            initializer=tf.zeros_initializer,
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, BN_MOVING_MEAN_COLLECTION])
        moving_variance = tf.get_variable(
            'moving_variance_step{}'.format(rnn_step), param_shape,
            initializer=tf.ones_initializer(),
            trainable=False,
            collections=[tf.GraphKeys.GLOBAL_VARIABLES, BN_MOVING_VAR_COLLECTION])

        def normalize_with_batch_stats():
            batch_mean, batch_variance = tf.nn.moments(inputs, reduction_axes)
            # Moving averages are updated without zero-debiasing.
            mean_update = moving_averages.assign_moving_average(
                moving_mean, batch_mean, decay, zero_debias=False)
            variance_update = moving_averages.assign_moving_average(
                moving_variance, batch_variance, decay, zero_debias=False)
            # Tie the updates to the output so they run whenever the
            # training branch is evaluated.
            with tf.control_dependencies([mean_update, variance_update]):
                return tf.nn.batch_normalization(
                    inputs, batch_mean, batch_variance, beta, gamma, epsilon)

        def normalize_with_moving_stats():
            return tf.nn.batch_normalization(
                inputs, moving_mean, moving_variance, beta, gamma, epsilon)

        return tf.cond(is_training, normalize_with_batch_stats, normalize_with_moving_stats)


def conv2d(inputs,
           name,
           kernel_sz,
           out_chns,
           strides=1,
           bn_params=None,
           weight_initializer=tf.contrib.layers.xavier_initializer_conv2d(),
           bias_initializer=tf.zeros_initializer,
           weight_regularizer=None,
           bias_regularizer=None,
           padding='SAME',
           non_linear_func=tf.nn.relu,
           use_bias=True):
    """2-D convolution with optional bias, batch norm and activation.

    Variables `weights` (and `biases` when `use_bias`) are created or
    reused under the variable scope `name`.

    Args:
        inputs: Input tensor; its last static dimension is taken as the
            number of input channels.
        name: Variable scope name, also passed as `name` to the activation.
        kernel_sz: Kernel size, an int or a pair (height, width).
        out_chns: Number of output channels.
        strides: Stride, an int or a pair (height, width).
        bn_params: If not None, keyword arguments forwarded to `batch_norm`
            (applied after the bias, before the activation).
        weight_initializer: Initializer for `weights`.
        bias_initializer: Initializer for `biases`.
        weight_regularizer: Regularizer for `weights`.
        bias_regularizer: Regularizer for `biases`.
        padding: Padding mode handed to `tf.nn.conv2d`.
        non_linear_func: Activation called as `f(x, name=name)`, or None to
            skip it.
        use_bias: Whether to add a bias term.

    Returns:
        The output tensor of the layer.
    """
    in_channels = inputs.get_shape()[-1]
    k_rows, k_cols = _two_element_tuple(kernel_sz)
    step_rows, step_cols = _two_element_tuple(strides)
    filter_shape = [k_rows, k_cols, in_channels, out_chns]
    stride_spec = [1, step_rows, step_cols, 1]

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        filters = tf.get_variable('weights',
                                  shape=filter_shape,
                                  initializer=weight_initializer,
                                  regularizer=weight_regularizer)
        output = tf.nn.conv2d(inputs, filters, stride_spec, padding=padding)

        if use_bias:
            bias = tf.get_variable('biases', [out_chns],
                                   initializer=bias_initializer,
                                   regularizer=bias_regularizer)
            output = tf.nn.bias_add(output, bias)

        if bn_params is not None:
            output = batch_norm(output, "BN", **bn_params)

        if non_linear_func is not None:
            output = non_linear_func(output, name=name)
    return output


def conv2d_transpose(inputs,
                     name,
                     kernel_sz,
                     out_shape=None,
                     strides=1,
                     inp_chns=None,
                     out_chns=None,
                     bn_params=None,
                     weight_initializer=tf.contrib.layers.xavier_initializer_conv2d(),
                     bias_initializer=tf.zeros_initializer,
                     weight_regularizer=None,
                     bias_regularizer=None,
                     non_linear_func=tf.nn.relu,
                     use_bias=True):
    """Transposed 2-D convolution with optional bias, batch norm and activation.

    Variables `weights` (and `biases` when `use_bias`) are created or
    reused under the variable scope `name`.

    Args:
        inputs: Input tensor.
        name: Variable scope name, also passed as `name` to the activation.
        kernel_sz: Kernel size, an int or a pair (height, width).
        out_shape: Output shape. If None, it is derived from the dynamic
            input shape as (batch, height * stride_h, width * stride_w,
            out_chns), which requires `out_chns`.
        strides: Stride, an int or a pair (height, width).
        inp_chns: Number of input channels; defaults to the last static
            dimension of `inputs`.
        out_chns: Number of output channels. If None, it is taken from
            `out_shape[-1]`; if both are given, it overrides the last entry
            of `out_shape`.
        bn_params: If not None, keyword arguments forwarded to `batch_norm`
            (applied after the bias, before the activation).
        weight_initializer: Initializer for `weights`.
        bias_initializer: Initializer for `biases`.
        weight_regularizer: Regularizer for `weights`.
        bias_regularizer: Regularizer for `biases`.
        non_linear_func: Activation called as `f(x, name=name)`, or None to
            skip it.
        use_bias: Whether to add a bias term.

    Returns:
        The output tensor of the layer.

    Raises:
        ValueError: If neither `out_shape` nor `out_chns` is given.
    """
    if inp_chns is None:
        inp_chns = inputs.get_shape()[-1]

    kernel_h, kernel_w = _two_element_tuple(kernel_sz)
    stride_h, stride_w = _two_element_tuple(strides)

    if out_shape is None:
        if out_chns is None:
            raise ValueError('Either out_shape or out_chns must be provided')
        # The dynamic input shape is needed whenever the output shape has to
        # be derived, regardless of whether the caller supplied `inp_chns`.
        # Previously it was only computed when `inp_chns` was None, so
        # passing `inp_chns` without `out_shape` hit an unbound local.
        inp_shape = tf.shape(inputs)
        out_shape = tf.stack([inp_shape[0], inp_shape[1] * stride_h, inp_shape[2] * stride_w, out_chns])
    elif out_chns is None:
        # NOTE(review): `out_chns` is later used as a variable shape, so
        # `out_shape[-1]` presumably has to be a static value here - confirm
        # against callers.
        out_chns = out_shape[-1]
    else:
        out_shape = tf.stack([out_shape[0], out_shape[1], out_shape[2], out_chns])

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        # The kernel is laid out as (height, width, output channels,
        # input channels), i.e. with the channel order swapped relative to
        # the kernel in `conv2d`.
        kernel = tf.get_variable('weights',
                                 shape=[kernel_h, kernel_w, out_chns, inp_chns],
                                 initializer=weight_initializer,
                                 regularizer=weight_regularizer)

        deconv = tf.nn.conv2d_transpose(inputs, kernel, out_shape, [1, stride_h, stride_w, 1])

        if use_bias:
            biases = tf.get_variable('biases', [out_chns],
                                     initializer=bias_initializer,
                                     regularizer=bias_regularizer)
            deconv = tf.nn.bias_add(deconv, biases)

        if bn_params is not None:
            deconv = batch_norm(deconv, "BN", **bn_params)

        if non_linear_func is not None:
            deconv = non_linear_func(deconv, name=name)
    return deconv


def flatten(inputs):
    """Reshape `inputs` to 2-D, keeping the first dimension.

    Args:
        inputs: Tensor whose static shape has at least 2 dimensions.

    Returns:
        A tensor reshaped to `[-1, k]`, where `k` is the number of elements
        in the static shape of all dimensions after the first.

    Raises:
        ValueError: If `inputs` has fewer than 2 dimensions.
    """
    static_shape = inputs.get_shape()
    if len(static_shape) < 2:
        raise ValueError('Inputs must be have a least 2 dimensions')
    flat_size = static_shape[1:].num_elements()
    return tf.reshape(inputs, [-1, flat_size])


def fc(inputs,
       name,
       out_chns,
       bn_params=None,
       weight_initializer=tf.contrib.layers.xavier_initializer(),
       bias_initializer=tf.zeros_initializer,
       weight_regularizer=None,
       bias_regularizer=None,
       padding='SAME',
       non_linear_func=tf.nn.relu,
       use_bias=True):
    """Fully connected layer with optional bias, batch norm and activation.

    Variables `weights` (and `biases` when `use_bias`) are created or
    reused under the variable scope `name`.

    Args:
        inputs: Input tensor; its second static dimension is the number of
            input features.
        name: Variable scope name, also passed as `name` to the activation.
        out_chns: Number of output units.
        bn_params: If not None, keyword arguments forwarded to `batch_norm`
            (applied after the bias, before the activation).
        weight_initializer: Initializer for `weights`.
        bias_initializer: Initializer for `biases`.
        weight_regularizer: Regularizer for `weights`.
        bias_regularizer: Regularizer for `biases`.
        padding: Accepted but not used by this function.
        non_linear_func: Activation called as `f(x, name=name)`, or None to
            skip it.
        use_bias: Whether to add a bias term.

    Returns:
        The output tensor of the layer.
    """
    in_features = inputs.get_shape()[1]

    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        weights = tf.get_variable('weights', shape=[in_features, out_chns],
                                  initializer=weight_initializer,
                                  regularizer=weight_regularizer)
        output = tf.matmul(inputs, weights)

        if use_bias:
            bias = tf.get_variable('biases', shape=[out_chns],
                                   initializer=bias_initializer,
                                   regularizer=bias_regularizer)
            output = tf.nn.bias_add(output, bias)

        if bn_params is not None:
            output = batch_norm(output, "BN", **bn_params)

        if non_linear_func is not None:
            output = non_linear_func(output, name=name)

    return output


def dropout(x, name, keep_prob, is_training):
    """Apply dropout that is only active while training.

    Args:
        x: Input tensor.
        name: Variable scope name wrapping the ops.
        keep_prob: Keep probability handed to `tf.nn.dropout` in the
            training branch.
        is_training: Predicate for `tf.cond`. When true, dropout uses
            `keep_prob`; otherwise `tf.nn.dropout` is called with a keep
            probability of 1.

    Returns:
        The result of the `tf.cond` selecting between the two branches.
    """
    with tf.variable_scope(name, reuse=tf.AUTO_REUSE):
        return tf.cond(is_training,
                       lambda: tf.nn.dropout(x, keep_prob),
                       lambda: tf.nn.dropout(x, 1))


def equalize_histogram(channel):
    """Histogram-equalize a single image channel.

    Builds a 256-bin histogram of `channel` over the value range [0, 1],
    turns its cumulative sum into a lookup table and maps every pixel
    through that table using `int(channel * 255)` as the index.

    Args:
        channel: Float tensor. NOTE(review): presumably shaped
            (1, height, width, 1) with values in [0, 1] - confirm against
            callers.

    Returns:
        The equalized channel with a dimension of size 1 inserted at axis 3.
    """
    value_range = tf.constant([0., 1.], dtype=tf.float32)
    bin_counts = tf.histogram_fixed_width(channel, value_range, 256)
    cumulative = tf.cumsum(bin_counts)

    # Cumulative count at the first bin that holds any pixel.
    first_filled_bin = tf.reduce_min(tf.where(tf.greater(cumulative, 0)))
    lowest_count = cumulative[first_filled_bin]

    dims = tf.shape(channel)
    numerator = tf.to_float(cumulative - lowest_count)
    denominator = tf.to_float(dims[-3] * dims[-2] - 1)
    lookup_table = numerator / denominator

    bin_index = tf.cast(channel * 255, tf.int32)
    equalized = tf.gather_nd(lookup_table, bin_index)
    return tf.expand_dims(equalized, 3)


def equalize_histogram_bgr(image_bgr):
    """Histogram-equalize the HSV value channel of a BGR image batch of size 1.

    The channel order is reversed to RGB, the image is converted to HSV,
    the V channel is run through `equalize_histogram`, and the result is
    converted back to RGB and reversed to BGR again.

    Args:
        image_bgr: Image tensor whose static first dimension must be 1 and
            whose last axis holds the B, G, R channels.

    Returns:
        The equalized image with its last axis in BGR order.

    Raises:
        AssertionError: If the static batch dimension is not 1.
    """
    batch_size = image_bgr.shape.as_list()[0]
    assert batch_size == 1, "only batches with size 1 supported by tf equalize histogram"

    rgb = tf.reverse(image_bgr, [-1])
    hsv = tf.image.rgb_to_hsv(rgb)
    hue, saturation, value = tf.split(hsv, 3, 3)
    value = equalize_histogram(value)

    # Stacking adds an axis at position 3; the trailing size-1 axis left by
    # the split is then squeezed away.
    stacked = tf.stack([hue, saturation, value], 3)
    equalized_hsv = tf.squeeze(stacked, 4)

    equalized_rgb = tf.image.hsv_to_rgb(equalized_hsv)
    return tf.reverse(equalized_rgb, [-1])


def average_gradients(tower_grads):
    """Average gradients across towers, variable by variable.

    Args:
        tower_grads: One list of (gradient, variable) pairs per tower. The
            lists are zipped together, so the i-th pair of every tower is
            treated as belonging to the same variable.

    Returns:
        A list of (gradient, variable) pairs in which each gradient is the
        mean over the towers and the variable is taken from the first tower.
        Towers whose gradient for a variable is None are left out of that
        variable's mean; if every tower reports None, the pair is
        `(None, variable)`.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # Each grad_and_vars looks like the following:
        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
        variable = grad_and_vars[0][1]

        # A gradient is None when the variable has no gradient on a tower;
        # passing None to tf.expand_dims would fail, so skip those entries.
        expanded = [tf.expand_dims(g, 0) for g, _ in grad_and_vars if g is not None]
        if not expanded:
            average_grads.append((None, variable))
            continue

        stacked = tf.concat(axis=0, values=expanded)
        average_grads.append((tf.reduce_mean(stacked, 0), variable))
    return average_grads


def print_def_graph_vars():
    """Print every default-graph node whose op type contains "Variable".

    For each matching node the node name, its op type and the list of its
    inputs are printed on one line.
    """
    for node in tf.get_default_graph().as_graph_def().node:
        if "Variable" not in node.op:
            continue
        print(node.name, node.op, list(node.input))
