"""ResNet models for Keras.

Reference paper:

  - [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) (CVPR 2016)
"""

from typing import Callable

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import backend
from tensorflow.keras.applications import resnet
from tensorflow.keras.models import Model
from tensorflow.python.keras.utils import data_utils


# URL of the weights file that `ResNet` hands to `data_utils.get_file`
# (as the download origin) before calling `model.load_weights`.
SANDBOX_WEIGHTS_PATH = "http://s3.mds.yandex.net/sandbox-tmp/3165736689/resnet_weights.h5"
# Expected hash of that file, passed to `data_utils.get_file` as `file_hash`.
# NOTE(review): 32 hex digits, i.e. MD5-sized -- confirm the algorithm.
WEIGHTS_HASH = "4d473c1dd8becc155b73f8504c6f6626"


def padd_for_aligning_pixels(inputs: tf.Tensor):
    """Zero-pad a 4-D tensor so that axes 1 and 2 line up with a stride of 32.

    Axes 1 and 2 are each grown at their trailing edge up to the next
    multiple of 32. On top of that, a fixed border of 3 leading and
    2 trailing elements is added on both axes. Axes 0 and 3 are left untouched.

    NOTE(review): the code reads the channel count from axis 3 and pads axes
    1 and 2, so it presumably expects a `channels_last` (NHWC) tensor --
    confirm against callers.

    Arguments:
        inputs: rank-4 tensor to pad.

    Returns:
        The padded tensor, with its static shape set to
        `[None, None, None, <size of axis 3 of the input>]`.
    """
    alignment = 32.0
    num_channels = inputs.shape[3]

    # Dynamic sizes of the two padded axes, known only at run time.
    spatial = tf.shape(inputs)[1:3]
    # Round each size up to the nearest multiple of `alignment`.
    aligned = tf.math.ceil(tf.cast(spatial, tf.float32) / alignment) * alignment
    extra = tf.cast(aligned, tf.int32) - spatial

    paddings = tf.stack([
        [0, 0],
        [3, 2 + extra[0]],
        [3, 2 + extra[1]],
        [0, 0],
    ])
    padded = tf.pad(inputs, paddings, name='conv1_pad')

    # The padding amounts are dynamic, so restate what is statically known.
    padded.set_shape([None, None, None, num_channels])
    return padded


def ResNet(stack_fn: Callable,
           preprocessing_func: Callable,
           preact: bool,
           use_bias: bool,
           model_name='resnet',
           input_shape=None,
           **kwargs) -> tf.keras.Model:
    """Builds a ResNet-style backbone and loads its weights.

    Reference paper:

    Deep Residual Learning for Image Recognition
        (https://arxiv.org/abs/1512.03385)

    The network is assembled as follows:

      1. `preprocessing_func` wrapped in a `Lambda` layer,
      2. `padd_for_aligning_pixels` wrapped in a `Lambda` layer,
      3. a 7x7, stride-2 convolution with 64 filters (`conv1_conv`),
      4. batch normalization + ReLU, only when `preact` is False,
      5. a one-element zero padding on the leading side of both spatial
         axes followed by a 3x3, stride-2 max pooling,
      6. whatever `stack_fn` builds on top of the pooled tensor.

    The value returned by `stack_fn` becomes the outputs of the model. When
    `preact` is True its last element is replaced in place by a batch
    normalization + ReLU of itself, so in that case `stack_fn` has to
    return a mutable sequence such as a list.

    Weights are always fetched from `SANDBOX_WEIGHTS_PATH` through
    `data_utils.get_file` (cache sub-directory `models`, checked against
    `WEIGHTS_HASH`) and loaded into the model; there is no switch to skip it.

    NOTE(review): the batch-normalization axis follows the Keras
    `image_data_format`, but `padd_for_aligning_pixels` indexes the channel
    count on axis 3 -- presumably only `channels_last` is actually
    supported; confirm.

    Arguments:
        stack_fn: a function that takes the pooled stem tensor and returns
            the output tensor(s) of the stacked residual blocks.
        preprocessing_func: a function applied to the raw input tensor
            before anything else.
        preact: whether to use pre-activation or not. When False, batch
            normalization and ReLU follow the first convolution; when True,
            they are applied to the last output instead.
        use_bias: whether the first convolution uses a bias.
        model_name: string, model name.
        input_shape: optional shape tuple forwarded to `layers.Input`.
        kwargs: not supported; passing any raises `ValueError`.

    Returns:
        A `keras.Model` instance with the weights loaded.

    Raises:
        *ValueError*: if any unexpected keyword argument is given.
    """
    if kwargs:
        raise ValueError('Unknown argument(s): %s' % (kwargs,))

    channels_last = backend.image_data_format() == 'channels_last'
    bn_axis = 3 if channels_last else 1

    img_input = layers.Input(shape=input_shape)

    # Stem: preprocessing, alignment padding, then the first convolution.
    stem = layers.Lambda(preprocessing_func, name="preprocess_input")(img_input)
    stem = layers.Lambda(padd_for_aligning_pixels,
                         name="padd_for_aligning_pixels")(stem)
    stem = layers.Conv2D(64, 7, strides=2, use_bias=use_bias,
                         name='conv1_conv')(stem)

    if not preact:
        stem = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                         name='conv1_bn')(stem)
        stem = layers.Activation('relu', name='conv1_relu')(stem)

    stem = layers.ZeroPadding2D(padding=((1, 0), (1, 0)), name='pool1_pad')(stem)
    stem = layers.MaxPooling2D(3, strides=2, name='pool1_pool')(stem)

    outputs = stack_fn(stem)

    if preact:
        # Pre-activation variants normalize/activate once more at the very
        # end; only the last output is affected.
        final = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                          name='post_bn')(outputs[-1])
        outputs[-1] = layers.Activation('relu', name='post_relu')(final)

    model = Model(img_input, outputs, name=model_name)

    # Fetch the weights file and load it.
    weights_path = data_utils.get_file(
        "resnet_weights.h5",
        SANDBOX_WEIGHTS_PATH,
        cache_subdir='models',
        file_hash=WEIGHTS_HASH)
    model.load_weights(weights_path)

    return model


def block1(x, filters, kernel_size=3, stride=1,
           conv_shortcut=True, name=None):
    """A bottleneck residual block.

    The main branch is 1x1 conv -> BN -> ReLU -> `kernel_size` conv -> BN ->
    ReLU -> 1x1 conv -> BN, where the last convolution has `4 * filters`
    filters. It is added to the shortcut branch and passed through a ReLU.

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer.
        kernel_size: default 3, kernel size of the bottleneck layer.
        stride: default 1, stride of the first layer (and of the
            convolution shortcut, if any).
        conv_shortcut: default True, use convolution shortcut if truthy,
            otherwise identity shortcut. With an identity shortcut the
            input must already have the shape of the main branch output
            for the addition to be valid.
        name: string, block label used as a prefix for every layer name.
            When None, each layer is left unnamed so that Keras assigns
            its own unique names.
    # Returns
        Output tensor for the residual block.
    """
    bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1

    def layer_name(suffix):
        # Passing None as a layer name lets Keras auto-generate one, which
        # avoids a TypeError from `None + str` when no label is given.
        return None if name is None else name + suffix

    # Plain truthiness (as in upstream Keras) rather than `is True`, so that
    # truthy non-bool flags such as 1 or numpy.bool_ select the conv shortcut.
    if conv_shortcut:
        shortcut = layers.Conv2D(4 * filters, 1, strides=stride,
                                 name=layer_name('_0_conv'))(x)
        shortcut = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                             name=layer_name('_0_bn'))(shortcut)
    else:
        shortcut = x

    # 1x1 reduction; carries the block's stride.
    x = layers.Conv2D(filters, 1, strides=stride,
                      name=layer_name('_1_conv'))(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=layer_name('_1_bn'))(x)
    x = layers.Activation('relu', name=layer_name('_1_relu'))(x)

    # Spatial convolution at the bottleneck width.
    x = layers.Conv2D(filters, kernel_size, padding='SAME',
                      name=layer_name('_2_conv'))(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=layer_name('_2_bn'))(x)
    x = layers.Activation('relu', name=layer_name('_2_relu'))(x)

    # 1x1 expansion to 4 * filters; no activation before the addition.
    x = layers.Conv2D(4 * filters, 1, name=layer_name('_3_conv'))(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5,
                                  name=layer_name('_3_bn'))(x)

    x = layers.Add(name=layer_name('_add'))([shortcut, x])
    x = layers.Activation('relu', name=layer_name('_out'))(x)
    return x


def stack1(x, filters, blocks, stride1=2, name=None):
    """Chains `blocks` residual blocks into one stage.

    The first block uses a convolution shortcut and the stride `stride1`;
    every following block uses an identity shortcut and the default stride.
    Blocks are labelled `<name>_block1`, `<name>_block2`, ...

    # Arguments
        x: input tensor.
        filters: integer, filters of the bottleneck layer in a block.
        blocks: integer, number of blocks in the stage.
        stride1: default 2, stride of the first layer in the first block.
        name: string, stack label.
    # Returns
        Output tensor of the last block in the stage.
    """
    block_prefix = name + '_block'

    out = block1(x, filters, stride=stride1, name=block_prefix + '1')
    for block_no in range(2, blocks + 1):
        out = block1(out,
                     filters,
                     conv_shortcut=False,
                     name=block_prefix + str(block_no))
    return out


def ResNet50(input_shape=None, **kwargs):
    """Instantiates the ResNet50 architecture.

    Four stages (`conv2` .. `conv5`) with 3, 4, 6 and 3 bottleneck blocks
    and 64, 128, 256 and 512 bottleneck filters are stacked on the stem
    built by `ResNet`. The model exposes the output of every stage, in
    order, rather than only the last one.

    Arguments:
        input_shape: optional shape tuple forwarded to `ResNet`.
        kwargs: forwarded unchanged to `ResNet`.

    Returns:
        The model returned by `ResNet`, named `resnet50`.
    """
    # (bottleneck filters, number of blocks, stride of first block, label)
    stages = (
        (64, 3, 1, 'conv2'),
        (128, 4, 2, 'conv3'),
        (256, 6, 2, 'conv4'),
        (512, 3, 2, 'conv5'),
    )

    def stack_fn(x):
        # Each stage feeds the next; every stage output is collected.
        features = []
        for filters, blocks, stride, label in stages:
            x = stack1(x, filters, blocks, stride1=stride, name=label)
            features.append(x)
        return features

    return ResNet(stack_fn=stack_fn,
                  preprocessing_func=resnet.preprocess_input,
                  preact=False,
                  use_bias=True,
                  model_name='resnet50',
                  input_shape=input_shape,
                  **kwargs)
