import tensorflow as tf
import tfutils.layers as tfutils
import params

# Batch-normalisation settings, forwarded to the layer builders as
# bn_params={'decay': BN_DECAY, 'epsilon': BN_EPSILON}.
BN_DECAY = 0.999
BN_EPSILON = 0.001

# Regularisation weight passed as loss_reg_weight to every layer builder.
WEIGHT_DECAY = 0.0005

def model(images, train=False, reuse=True):
    """Build the CNN classifier graph and return its logits.

    Layer stack (every conv passes 5 as its size argument, presumably a
    5x5 kernel -- see tfutils.conv2d):
        conv1(32)                 -> max_pool -> LRN
        conv2_1(64) -> conv2_2(64) -> max_pool -> LRN
        conv3(128)                -> max_pool -> LRN
        flatten -> fc1(1024) -> fc2(params.OUT_CLASSES_CNT)

    Args:
        images: 4-D Tensor of shape [batch_sz, height, width, channels],
                e.g. created by
                tf.placeholder(tf.float32, shape=[batch_sz, height, width, channels]).
        train:  forwarded unchanged to every tfutils layer builder
                (presumably selects training-mode behaviour; confirm in tfutils).
        reuse:  forwarded unchanged to every tfutils layer builder
                (presumably controls variable reuse; confirm in tfutils).

    Returns:
        A 2-D Tensor of logits, shape [batch_sz, params.OUT_CLASSES_CNT].
    """
    # Keyword arguments shared by every batch-normalised, weight-decayed
    # layer. A single bn_params dict object is handed to all of them.
    shared = dict(
        bn_params={'decay': BN_DECAY, 'epsilon': BN_EPSILON},
        loss_reg_weight=WEIGHT_DECAY,
        train=train,
        reuse=reuse,
    )

    def lrn(tensor, label):
        # The same local response normalisation follows every pooling stage.
        return tf.nn.lrn(tensor, 4, bias=1.0, alpha=0.0001, beta=0.75, name=label)

    # Stage 1: [BS H W IMAGE_CHNS] -> conv1 -> [BS H W 32]
    #          -> pool1 -> [BS H/2 W/2 32]
    net = tfutils.conv2d(images, 'conv1', 5, 32, **shared)
    net = lrn(tfutils.max_pool(net, name='pool1'), 'norm1')

    # Stage 2: [BS H/2 W/2 32] -> conv2_1, conv2_2 -> [BS H/2 W/2 64]
    #          -> pool2 -> [BS H/4 W/4 64]
    net = tfutils.conv2d(net, 'conv2_1', 5, 64, **shared)
    net = tfutils.conv2d(net, 'conv2_2', 5, 64, **shared)
    net = lrn(tfutils.max_pool(net, name='pool2'), 'norm2')

    # Stage 3: [BS H/4 W/4 64] -> conv3 -> [BS H/4 W/4 128]
    #          -> pool3 -> [BS H/8 W/8 128]
    net = tfutils.conv2d(net, 'conv3', 5, 128, **shared)
    net = lrn(tfutils.max_pool(net, name='pool3'), 'norm3')

    # Classifier head: flatten to [BS H/8*W/8*128], then fc1 -> [BS 1024].
    net = tfutils.flatten(net)
    net = tfutils.fc(net, 'fc1', 1024, **shared)

    # Output layer -> [BS CLASS_CNT]. bn_params=None and non_linear_func=None
    # presumably disable batch norm and the activation, yielding raw logits.
    # NOTE(review): dropout_keep_prob=-0.1 is passed through as-is; presumably
    # a sentinel tfutils.fc treats as "no dropout" -- confirm.
    return tfutils.fc(
        net, 'fc2', params.OUT_CLASSES_CNT,
        dropout_keep_prob=-0.1,
        bn_params=None,
        non_linear_func=None,
        loss_reg_weight=WEIGHT_DECAY,
        train=train,
        reuse=reuse,
    )
