import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
import tf_utils

def get_weights(weights, name):
    """Return the initializer for a convolution kernel variable.

    Args:
        weights: mapping of pretrained arrays indexed by variable name
            (e.g. the object returned by ``np.load``), or ``None``.
        name: key of the kernel array inside ``weights``.

    Returns:
        A constant initializer holding ``weights[name]`` when pretrained
        weights are given, otherwise a fresh Xavier (Glorot) conv2d initializer.
    """
    if weights is not None:
        return tf.constant_initializer(weights[name])
    return tf.contrib.layers.xavier_initializer_conv2d()

def get_bias(weights, name):
    """Return the initializer for a bias variable.

    Args:
        weights: mapping of pretrained arrays indexed by variable name
            (e.g. the object returned by ``np.load``), or ``None``.
        name: key of the bias array inside ``weights``.

    Returns:
        A constant initializer holding ``weights[name]`` when pretrained
        weights are given, otherwise a zeros initializer.
    """
    if weights is None:
        # Instantiate the initializer (TF >= 1.0 exposes it as a class) so the
        # caller always receives a ready-to-use initializer object, just like
        # the one returned by get_weights().
        return tf.zeros_initializer()
    return tf.constant_initializer(weights[name])

def bilinear_upsample_weights(kernel_sz, stride, chns):
    """Build a bilinear-interpolation kernel for a transposed convolution.

    Each output channel ``i`` is connected only to input channel ``i``, so
    every channel is upsampled independently.

    Args:
        kernel_sz: spatial size of the (square) kernel.
        stride: upsampling factor; the half-width of the tent filter.
        chns: number of channels (the kernel maps ``chns`` -> ``chns``).

    Returns:
        float32 array of shape ``(kernel_sz, kernel_sz, chns, chns)``.
    """
    # The geometric centre of the kernel. (k - 1) / 2 is correct for both odd
    # and even k. The float literal keeps Python 2 from doing integer division.
    center = (kernel_sz - 1) / 2.0
    taps = np.arange(kernel_sz, dtype=np.float64)
    # 1-D tent filter of half-width `stride`. Taps farther than `stride` from
    # the centre are clamped to zero rather than going negative; otherwise
    # two negative factors would multiply into spurious positive weights.
    profile = np.maximum(0.0, 1.0 - np.abs(taps - center) / float(stride))
    # The 2-D bilinear kernel is separable: the outer product of the 1-D tents.
    upsample_kernel = np.outer(profile, profile)

    weights = np.zeros((kernel_sz, kernel_sz, chns, chns), dtype=np.float32)
    for i in range(chns):
        weights[:, :, i, i] = upsample_kernel
    return weights

# (output channels, number of 3x3 conv layers) for each of the five VGG16 stages.
_VGG16_STAGES = ((64, 2), (128, 2), (256, 3), (512, 3), (512, 3))


def _vgg_conv(inputs, name, out_chns, weights, bn_params, regulizer):
    """Build one 3x3 VGG16 conv layer, initialised from `weights` when given."""
    return tf_utils.conv2d(inputs, name, 3, out_chns,
                           bn_params=bn_params, use_bias=(bn_params is None),
                           weight_initializer=get_weights(weights, name + '_W'),
                           bias_initializer=get_bias(weights, name + '_b'),
                           weight_regularizer=regulizer,
                           bias_regularizer=regulizer)


def _side_output(features, index, out_shape, regulizer):
    """Project stage `index` features to one channel and upsample them to `out_shape`."""
    branch = tf_utils.conv2d(features, 'branch%d' % index, 1, 1,
                             non_linear_func=None,
                             weight_regularizer=regulizer,
                             bias_regularizer=regulizer)
    if index == 1:
        # Stage 1 is taken before any pooling, so it is not upsampled.
        return branch
    # Stage `index` sits behind (index - 1) 2x2 poolings.
    stride = 2 ** (index - 1)
    return tf_utils.conv2d_transpose(branch, 'upscale%d' % stride, 2 * stride, out_shape, stride,
                                     out_chns=1,
                                     non_linear_func=None,
                                     weight_regularizer=regulizer,
                                     bias_regularizer=regulizer)


def model_l5(X, params, is_training):
    """Build a VGG16 network with five side outputs and a fused output (HED-style).

    Args:
        X: input image batch. The pooling ksize layout [1, 2, 2, 1] implies
            NHWC. Dims 1 and 2 of ``tf.shape(X)`` are used as the output
            spatial size.
        params: configuration dict.
            - "weight_decay" (required): L2 weight; no regularizer is used
              when it is <= 0.
            - "bn_decay" and "bn_epsilon" (optional): batch norm is enabled
              only when both are present.
            - "vgg16_weights" (optional): path to a file readable by
              ``np.load`` whose entries are keyed like 'conv1_1_W' and
              'conv1_1_b'.
        is_training: forwarded to the batch-norm parameters.

    Returns:
        A list ``[branch1, branch2, branch3, branch4, branch5, fuse]``.
        Each entry is a single-channel tensor with no non-linearity applied
        (presumably logits; confirm with the loss). branch2 through branch5
        are upsampled to ``(N, H, W, 1)`` of the input. ``fuse`` is a 1x1
        conv over all five branches, initialised to 0.2 (their average).
    """
    regulizer = tf.contrib.layers.l2_regularizer(params["weight_decay"]) if params["weight_decay"] > 0. else None
    x_shape = tf.shape(X)
    out_shape = tf.stack([x_shape[0], x_shape[1], x_shape[2], 1])

    bn_params = None
    if "bn_decay" in params and "bn_epsilon" in params:
        bn_params = {'decay': params["bn_decay"], 'epsilon': params["bn_epsilon"], 'is_training': is_training}

    weights = None
    if "vgg16_weights" in params and bn_params is None:
        # Pretrained weights are only used when batch norm is off. A BN model
        # would also need the matching pretrained BN parameters (gamma, beta,
        # etc.), which this code does not load.
        weights = np.load(os.path.abspath(params["vgg16_weights"]))

    # VGG16 backbone: a 2x2 max-pool between stages (none after stage 5).
    # The last conv of every stage feeds a side-output branch.
    net = X
    stage_outputs = []
    for stage, (out_chns, n_convs) in enumerate(_VGG16_STAGES, 1):
        if stage > 1:
            net = tf.nn.max_pool(net, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME",
                                 name="pool%d" % (stage - 1))
        for layer in range(1, n_convs + 1):
            net = _vgg_conv(net, "conv%d_%d" % (stage, layer), out_chns, weights, bn_params, regulizer)
        stage_outputs.append(net)

    branches = [_side_output(features, stage, out_shape, regulizer)
                for stage, features in enumerate(stage_outputs, 1)]

    concat = tf.concat(branches, axis=-1, name='concat')
    fuse = tf_utils.conv2d(concat, 'fuse', 1, 1,
                           bn_params=bn_params, use_bias=(bn_params is None),
                           non_linear_func=None,
                           weight_initializer=tf.constant_initializer(0.2),
                           weight_regularizer=regulizer,
                           bias_regularizer=regulizer)
    return branches + [fuse]

