import tensorflow as tf


def residual_block(inputs, filters, kernel_size, strides, train_config, scope, is_training):
    """Build a two-convolution residual block inside variable scope ``scope``.

    Main branch: conv(kernel_size, strides) -> BN -> ReLU -> conv(kernel_size, 1) -> BN.
    Shortcut: a 1x1 conv with stride ``strides`` followed by BN when the main
    branch changes the spatial stride or the channel count; otherwise the
    identity. The two branches are summed and passed through ReLU.

    Args:
        inputs: input feature map. Channels are treated as the last axis, which
            matches the ``tf.layers.conv2d`` default ``data_format`` used here.
        filters: number of output channels of every convolution in the block.
        kernel_size: kernel size of the two main-branch convolutions.
        strides: stride of the first main-branch conv and of the projection.
        train_config: mapping that must contain ``"weight_decay"``, the L2
            factor applied to every kernel and bias.
        scope: variable-scope name for the block's variables.
        is_training: passed as ``training=`` to every batch-normalization layer.

    Returns:
        The block's output tensor.
    """
    weight_decay = train_config["weight_decay"]

    def conv(x, n_filters, size, stride, name):
        # "same"-padded conv with L2 regularization on both kernel and bias.
        return tf.layers.conv2d(
            x,
            filters=n_filters,
            kernel_size=size,
            strides=stride,
            padding="same",
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
            name=name,
        )

    def batch_norm(x, name):
        # BN layers are named explicitly: unnamed tf.layers get uniquified on
        # every call, so AUTO_REUSE would NOT share them across repeated graph
        # builds. The names equal the ones tf.layers auto-assigns on the first
        # build, so variable names (and existing checkpoints) are unchanged.
        return tf.layers.batch_normalization(x, training=is_training, name=name)

    # Project the shortcut whenever the identity would not match the main
    # branch: a stride change, or a (statically known) channel-count change.
    input_shape = inputs.get_shape()
    in_channels = input_shape.as_list()[-1] if input_shape.ndims is not None else None
    needs_projection = strides != 1 or (in_channels is not None and in_channels != filters)

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        branch1 = conv(inputs, filters, kernel_size, strides, "branch1_1")
        branch1 = tf.nn.relu(batch_norm(branch1, "batch_normalization"))
        branch1 = conv(branch1, filters, kernel_size, 1, "branch1_2")
        branch1 = batch_norm(branch1, "batch_normalization_1")
        if needs_projection:
            shortcut = conv(inputs, filters, 1, strides, "branch2")
            shortcut = batch_norm(shortcut, "batch_normalization_2")
        else:
            shortcut = tf.identity(inputs, name="branch2")
        return tf.nn.relu(branch1 + shortcut)


def get_model(images, train_config, is_training, num_classes=2):
    """Build the residual classification network and return its logits.

    Layout: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool, then four stages of
    ``residual_block`` with (64, 128, 256, 512) filters and (3, 4, 6, 3) blocks
    (the first block of stages 3-5 uses stride 2), then a 7x7 average pool,
    flatten, dropout (rate 0.5) and a dense ``"logits"`` layer.

    NOTE(review): the variable scope is named "resnet18", but a (3, 4, 6, 3)
    basic-block layout is the ResNet-34 configuration. The scope name is kept
    as-is so existing variable names / checkpoints stay valid.

    NOTE(review): ``pool5`` uses a fixed 7x7 window with "valid" padding, so
    this assumes the last stage outputs a 7x7 map (a 224x224 input yields
    7x7 here). Other input sizes change the width fed into "logits". TODO
    confirm the expected input size.

    NOTE(review): ``tf.layers.batch_normalization`` registers its moving-average
    updates in ``tf.GraphKeys.UPDATE_OPS``. This function does not run them,
    so the training loop must.

    Args:
        images: input image batch. Channels are the last axis (the
            ``tf.layers`` default data format).
        train_config: mapping that must contain ``"weight_decay"``, the L2
            factor applied to every conv kernel and bias.
        is_training: controls batch-norm mode and whether dropout is applied.
        num_classes: number of output logits (default 2, as before).

    Returns:
        The unscaled logits tensor produced by the ``"logits"`` dense layer.
    """
    weight_decay = train_config["weight_decay"]
    # (stage index used in scope names, filters, number of residual blocks)
    stages = ((2, 64, 3), (3, 128, 4), (4, 256, 6), (5, 512, 3))

    with tf.variable_scope('resnet18', reuse=tf.AUTO_REUSE):
        net = tf.layers.conv2d(
            images,
            filters=64,
            kernel_size=7,
            strides=2,
            padding="same",
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
            name="conv1"
        )
        # Explicit name (equal to the first-build auto name) so AUTO_REUSE
        # actually shares these BN variables across repeated graph builds.
        net = tf.layers.batch_normalization(net, training=is_training, name="batch_normalization")
        net = tf.nn.relu(net)
        net = tf.layers.max_pooling2d(net, pool_size=3, strides=2, name="pool1")

        for stage, filters, num_blocks in stages:
            for block in range(1, num_blocks + 1):
                # Downsample at the first block of every stage except the
                # first one (the max-pool has already reduced resolution).
                strides = 2 if (block == 1 and stage > 2) else 1
                net = residual_block(
                    net,
                    filters=filters,
                    kernel_size=3,
                    strides=strides,
                    train_config=train_config,
                    scope="conv%d_%d" % (stage, block),
                    is_training=is_training
                )

        net = tf.layers.average_pooling2d(net, pool_size=7, strides=1, name="pool5")
        net = tf.layers.flatten(net)
        net = tf.layers.dropout(net, rate=0.5, training=is_training)
        return tf.layers.dense(net, num_classes, name="logits")
