import tensorflow as tf


def fire_module(inputs, squeeze_depth, expand_depth, train_config, scope, is_training):
    """Build a SqueezeNet fire module.

    A 1x1 "squeeze" convolution reduces the channel count to ``squeeze_depth``.
    Two parallel "expand" convolutions (1x1 and 3x3, each with ``expand_depth``
    filters) are then applied to the squeezed tensor, and their outputs are
    concatenated along the channel axis. Every convolution is followed by
    batch normalization and ReLU.

    Args:
        inputs: 4-D feature map. Assumed NHWC (tf.layers' default
            ``channels_last``, consistent with the concat on axis 3) -- TODO
            confirm against callers.
        squeeze_depth: Number of filters in the 1x1 squeeze convolution.
        expand_depth: Number of filters in *each* expand branch.
        train_config: Mapping read for ``"weight_decay"``, the L2 coefficient
            applied to every kernel and bias in the module.
        scope: Variable-scope name; opened with ``tf.AUTO_REUSE``.
        is_training: Python bool or boolean tensor forwarded to batch norm.

    Returns:
        Tensor with ``2 * expand_depth`` channels and the same spatial size as
        ``inputs``: all convolutions use stride 1, and the 3x3 one uses 'same'
        padding.
    """
    weight_decay = train_config["weight_decay"]

    def conv_bn_relu(net, filters, kernel_size, name, padding="valid"):
        # Conv -> batch norm -> ReLU. The batch-norm layers are unnamed, so
        # their variable names depend on creation order; keep the calls below
        # in this order to stay checkpoint-compatible.
        # NOTE: a conv bias directly before batch norm is redundant (BN's
        # centering absorbs it); it is kept so variable sets do not change.
        net = tf.layers.conv2d(
            net,
            filters=filters,
            kernel_size=kernel_size,
            strides=1,
            padding=padding,
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
            name=name,
        )
        net = tf.layers.batch_normalization(net, training=is_training)
        return tf.nn.relu(net)

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        squeeze = conv_bn_relu(inputs, squeeze_depth, 1, "squeeze")
        # Both expand branches must consume `squeeze` (not each other): the
        # expand layer is two parallel convolutions whose outputs are
        # concatenated.
        expand1x1 = conv_bn_relu(squeeze, expand_depth, 1, "expand1x1")
        expand3x3 = conv_bn_relu(
            squeeze, expand_depth, 3, "expand3x3", padding="same"
        )
        return tf.concat([expand1x1, expand3x3], axis=3)


def get_model(images, train_config, is_training, num_classes=2):
    """Build a SqueezeNet-style classifier and return its logits.

    The network is laid out as follows:

    1. conv1 (3x3, stride 2) -> batch norm -> ReLU -> pool1
    2. fire2, fire3 -> pool3
    3. fire4, fire5 -> pool5
    4. fire6 .. fire9
    5. conv10 (1x1, ``num_classes`` filters, no bias)
    6. global average pooling -> flatten

    All max pools are 3x3 / stride 2 with 'valid' padding. Variables live
    under the ``squeezenet`` scope, which is opened with ``tf.AUTO_REUSE``, so
    calling this again reuses existing variables.

    Args:
        images: 4-D image batch. Assumed NHWC (tf.layers' default
            ``channels_last``) -- TODO confirm against the input pipeline.
        train_config: Mapping read for ``"weight_decay"``, the L2 coefficient
            applied to the convolution kernels and biases.
        is_training: Python bool or boolean tensor forwarded to batch norm.
        num_classes: Number of output logits (filters in conv10). Defaults
            to 2.

    Returns:
        A ``(batch, num_classes)`` logits tensor. No activation is applied
        after conv10.
    """
    weight_decay = train_config["weight_decay"]

    def fire(net, scope, squeeze_depth, expand_depth):
        # Thin wrapper so the stage sequence below reads like the
        # architecture table.
        return fire_module(
            net,
            squeeze_depth=squeeze_depth,
            expand_depth=expand_depth,
            train_config=train_config,
            scope=scope,
            is_training=is_training,
        )

    def max_pool(net, name):
        # The 3x3 / stride-2 'valid' max pool used after conv1, fire3, fire5.
        return tf.layers.max_pooling2d(net, pool_size=3, strides=2, name=name)

    with tf.variable_scope('squeezenet', reuse=tf.AUTO_REUSE):
        net = tf.layers.conv2d(
            images,
            filters=64,
            kernel_size=3,
            strides=2,
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            bias_regularizer=tf.keras.regularizers.l2(weight_decay),
            name="conv1",
        )
        net = tf.layers.batch_normalization(net, training=is_training)
        net = tf.nn.relu(net)
        net = max_pool(net, "pool1")

        net = fire(net, "fire2", 16, 64)
        net = fire(net, "fire3", 16, 64)
        net = max_pool(net, "pool3")

        net = fire(net, "fire4", 32, 128)
        net = fire(net, "fire5", 32, 128)
        net = max_pool(net, "pool5")

        net = fire(net, "fire6", 48, 192)
        net = fire(net, "fire7", 48, 192)
        net = fire(net, "fire8", 64, 256)
        net = fire(net, "fire9", 64, 256)

        # conv10 has no bias, so only the kernel is regularized.
        net = tf.layers.conv2d(
            net,
            filters=num_classes,
            kernel_size=1,
            strides=1,
            use_bias=False,
            kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
            name="conv10",
        )
        # Global average pooling over the spatial axes (1, 2 under NHWC). A
        # fixed 13x13 window only fits inputs whose final feature map is
        # exactly 13x13 (e.g. 224x224 images). Reducing over the spatial axes
        # yields one value per class for any input size, and matches the
        # 13x13 window exactly when it fits.
        net = tf.reduce_mean(net, axis=[1, 2], name="pool10")
        # Flatten of an already (batch, num_classes) tensor is a no-op
        # reshape; it is kept so the output op stays under the "logits" name.
        return tf.layers.flatten(net, name="logits")
