resnet9

ResNet9

A small 9-layer ResNet TensorFlow model for CIFAR-10 image classification. The model architecture is from https://github.com/davidcpage/cifar10-fast

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| input_size | Tuple[int, int, int] | The size of the input tensor (height, width, channels). | (32, 32, 3) |
| classes | int | The number of outputs the model should generate. | 10 |

Raises:

| Type | Description |
| --- | --- |
| ValueError | Length of input_size is not 3. |
| ValueError | input_size[0] or input_size[1] is not a multiple of 16. |
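The two error conditions above can be illustrated with a minimal sketch. The library's own `_check_input_size` helper is not shown on this page, so `check_input_size_sketch` below is a hypothetical stand-in that only mirrors the documented behavior. Its error messages are assumptions too.

```python
from typing import Tuple


def check_input_size_sketch(input_size: Tuple[int, int, int]) -> None:
    """Hypothetical stand-in for the documented validation (not the library's code)."""
    # Documented condition 1: the input size must have exactly three entries.
    if len(input_size) != 3:
        raise ValueError("Length of `input_size` is not 3.")
    height, width, _ = input_size
    # Documented condition 2: height and width must both be multiples of 16.
    if height % 16 != 0 or width % 16 != 0:
        raise ValueError("`input_size`[0] or `input_size`[1] is not a multiple of 16.")


check_input_size_sketch((32, 32, 3))  # valid input: returns without raising

for bad in [(32, 32), (30, 32, 3)]:
    try:
        check_input_size_sketch(bad)
    except ValueError as err:
        print(bad, "->", err)
```
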

Returns:

| Type | Description |
| --- | --- |
| Model | A TensorFlow ResNet9 model. |

Source code in fastestimator/fastestimator/architecture/tensorflow/resnet9.py
def ResNet9(input_size: Tuple[int, int, int] = (32, 32, 3), classes: int = 10) -> tf.keras.Model:
    """A small 9-layer ResNet TensorFlow model for CIFAR-10 image classification.
    The model architecture is from https://github.com/davidcpage/cifar10-fast

    Args:
        input_size: The size of the input tensor (height, width, channels).
        classes: The number of outputs the model should generate.

    Raises:
        ValueError: Length of `input_size` is not 3.
        ValueError: `input_size`[0] or `input_size`[1] is not a multiple of 16.

    Returns:
        A TensorFlow ResNet9 model.
    """
    _check_input_size(input_size)

    # prep layers
    inp = layers.Input(shape=input_size)
    x = layers.Conv2D(64, 3, padding='same')(inp)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # layer1
    x = layers.Conv2D(128, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Add()([x, residual(x, 128)])
    # layer2
    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # layer3
    x = layers.Conv2D(512, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Add()([x, residual(x, 512)])
    # layer4
    x = layers.GlobalMaxPool2D()(x)
    x = layers.Flatten()(x)
    x = layers.Dense(classes)(x)
    x = layers.Activation('softmax', dtype='float32')(x)
    model = tf.keras.Model(inputs=inp, outputs=x)

    return model
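
To see how the tensor shape changes through the listing above, here is a small pure-Python sketch that needs no TensorFlow. It relies on two Keras defaults: a 3x3 convolution with `padding='same'` keeps height and width, and `MaxPool2D()` uses a 2x2 window with stride 2. The helper name `trace_resnet9_shapes` is hypothetical and is not part of FastEstimator.

```python
from typing import List, Tuple


def trace_resnet9_shapes(input_size: Tuple[int, int, int] = (32, 32, 3),
                         classes: int = 10) -> List[Tuple[str, Tuple[int, ...]]]:
    """Hypothetical helper: per-stage output shapes, excluding the batch axis."""
    height, width, _ = input_size
    shapes: List[Tuple[str, Tuple[int, ...]]] = [("input", tuple(input_size))]
    # prep: the 'same' convolution keeps height and width and sets channels to 64.
    shapes.append(("prep", (height, width, 64)))
    # layer1-3: a 'same' convolution sets the channels, then max pooling halves
    # height and width. The residual additions do not change the shape.
    for name, channels in [("layer1", 128), ("layer2", 256), ("layer3", 512)]:
        height, width = height // 2, width // 2
        shapes.append((name, (height, width, channels)))
    # layer4: global max pooling collapses height and width; Dense maps to `classes`.
    shapes.append(("global_max_pool", (512,)))
    shapes.append(("dense_softmax", (classes,)))
    return shapes


for stage, shape in trace_resnet9_shapes():
    print(f"{stage:16s}{shape}")
```

With the default `(32, 32, 3)` input, the three pooling steps reduce the spatial size from 32x32 to 4x4 before global max pooling.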

residual

A ResNet unit for ResNet9.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | Input Keras tensor. | required |
| num_channel | int | The number of channels in each convolutional layer. | required |
Returns:

Output Keras tensor.

Source code in fastestimator/fastestimator/architecture/tensorflow/resnet9.py
def residual(x: tf.Tensor, num_channel: int) -> tf.Tensor:
    """A ResNet unit for ResNet9.

    Args:
        x: Input Keras tensor.
        num_channel: The number of channels in each convolutional layer.

    Returns:
        Output Keras tensor.
    """
    x = layers.Conv2D(num_channel, 3, padding='same')(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Conv2D(num_channel, 3, padding='same')(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    return x
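
Note that `residual` returns only the convolutional branch; the skip connection is added by the caller through `layers.Add()` in `ResNet9`. As a rough illustration of the unit's size, the sketch below counts its parameters in plain Python. It assumes the Keras defaults: `Conv2D` has a bias term, and `BatchNormalization` holds four values per channel (gamma, beta, moving mean, moving variance). All helper names are hypothetical.

```python
def conv2d_params(in_channels: int, out_channels: int, kernel_size: int = 3) -> int:
    """Kernel weights plus biases of a Conv2D layer with a square kernel."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def batch_norm_params(channels: int) -> int:
    """gamma, beta, moving mean and moving variance: one value per channel each."""
    return 4 * channels


def residual_params(in_channels: int, num_channel: int) -> int:
    """Hypothetical helper: total parameter count of one residual unit."""
    first = conv2d_params(in_channels, num_channel) + batch_norm_params(num_channel)
    second = conv2d_params(num_channel, num_channel) + batch_norm_params(num_channel)
    return first + second  # LeakyReLU adds no parameters


# The two units used in ResNet9 receive 128 and 512 input channels respectively.
print(residual_params(128, 128))
print(residual_params(512, 512))
```

The count includes the non-trainable moving statistics of batch normalization, so it corresponds to the total rather than the trainable parameter count.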