SimCLR on ciFAIR10 Image Classification (TensorFlow Backend)¶
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
Labeled datasets are much more expensive to produce than their unlabeled counterparts, so it is often the case that only a small fraction of the available data can be labeled. Self-supervised learning algorithms, which don't require labeled data during training, have therefore become a major topic in recent ML research. SimCLR, proposed in 2020, achieved 85.8% top-5 accuracy on the ImageNet dataset using only 1% of the available labels.
The idea of SimCLR is to separate a visual task into two parts: an encoder and a classifier. The encoder projects images into a representation space, which the classifier then uses to make decisions. The encoder doesn't need to know the image class, but it does need to project an "image group" (a set of images generated from the same source image with data augmentation) to a tight cluster. By increasing the similarity between encoded images from the same image group while reducing the similarity between different groups, the encoder can be trained without explicit labels. Training the encoder this way is called "pretraining". Afterwards, users can attach any classifier to the pretrained encoder and finetune the whole model for a specific visual task. According to the paper, this achieves good results with only a small fraction of the available data being labeled.
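To make "similarity" concrete: like most contrastive methods, SimCLR measures it as the cosine similarity between L2-normalized embedding vectors. Below is a minimal illustrative sketch; the random vectors are stand-ins for real encoder outputs, not part of this tutorial's pipeline.

import tensorflow as tf

# Stand-ins for the embeddings of two augmented views of the same image.
z_i = tf.random.normal((128,))
z_j = tf.random.normal((128,))

# Cosine similarity: dot product of L2-normalized vectors, in [-1, 1].
similarity = tf.reduce_sum(tf.math.l2_normalize(z_i, -1) * tf.math.l2_normalize(z_j, -1))
print(float(similarity))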
In this tutorial we will demonstrate an implementation of SimCLR with the ciFAIR10 dataset. Some details of this implementation differ from the original paper. It draws upon the code provided here.
import tempfile
import tensorflow as tf
from tensorflow.keras import layers
import fastestimator as fe
from fastestimator.dataset.data import cifair10
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import ColorJitter, GaussianBlur, ToFloat, ToGray
from fastestimator.op.tensorop import TensorOp
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import BestModelSaver, ModelSaver
from fastestimator.trace.metric import Accuracy
# training parameters
epochs_pretrain = 50
epochs_finetune = 10
batch_size = 512
train_steps_per_epoch = None
eval_steps_per_epoch = None
save_dir = tempfile.mkdtemp()
Pre-Training Pipeline¶
In the SimCLR paper, the authors emphasized the importance of the data augmentation steps and how these directly impact the quality of the pretrained model. The preprocessing steps include random cropping, random horizontal flipping, random color jitter, random grayscale conversion, and random Gaussian blur. Each image goes through the pipeline and generates two augmented images, which constitute an image group (or, more specifically, a pair). The batches of augmented image pairs are later used for model pretraining.
train_data, eval_data = cifair10.load_data()
pipeline_pretrain = fe.Pipeline(
    train_data=train_data,
    batch_size=batch_size,
    ops=[
        PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x"),
        # augmentation 1
        RandomCrop(32, 32, image_in="x", image_out="x_aug"),
        Sometimes(HorizontalFlip(image_in="x_aug", image_out="x_aug"), prob=0.5),
        Sometimes(
            ColorJitter(inputs="x_aug", outputs="x_aug", brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
            prob=0.8),
        Sometimes(ToGray(inputs="x_aug", outputs="x_aug"), prob=0.2),
        Sometimes(GaussianBlur(inputs="x_aug", outputs="x_aug", blur_limit=(3, 3), sigma_limit=(0.1, 2.0)), prob=0.5),
        ToFloat(inputs="x_aug", outputs="x_aug"),
        # augmentation 2
        RandomCrop(32, 32, image_in="x", image_out="x_aug2"),
        Sometimes(HorizontalFlip(image_in="x_aug2", image_out="x_aug2"), prob=0.5),
        Sometimes(
            ColorJitter(inputs="x_aug2", outputs="x_aug2", brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
            prob=0.8),
        Sometimes(ToGray(inputs="x_aug2", outputs="x_aug2"), prob=0.2),
        Sometimes(GaussianBlur(inputs="x_aug2", outputs="x_aug2", blur_limit=(3, 3), sigma_limit=(0.1, 2.0)), prob=0.5),
        ToFloat(inputs="x_aug2", outputs="x_aug2")
    ])
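If you want to sanity-check the augmentations before training, you can pull one batch through the pipeline. The optional sketch below uses the pipeline defined above; the printed shapes assume the default batch size of 512.

# Optional: fetch one batch to verify that both augmented views are produced.
sample_batch = pipeline_pretrain.get_results(mode="train")
print(sample_batch["x_aug"].shape, sample_batch["x_aug2"].shape)  # e.g. (512, 32, 32, 3) twice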
Model¶
SimCLR contrastive learning is separated into two parts: pretraining and finetuning. In the pretraining step, the encoder is attached to a series of MLPs called the "projection head". During finetuning, the encoder is instead attached to a classifier called the "supervision head". The paper found that using a projection head during pretraining makes the data more clustered in the representation space.
Although the original paper used a ResNet50 architecture, we will use a ResNet9 for faster convergence.
def ResNet9(input_size=(32, 32, 3), dims=128, classes=10):
    """A small 9-layer ResNet TensorFlow model for ciFAIR10 image classification.

    The model architecture is from https://github.com/davidcpage/cifar10-fast

    Args:
        input_size: The size of the input tensor (height, width, channels).
        dims: The output dimensionality of the projection head used for contrastive pretraining.
        classes: The number of outputs the supervision head should generate.

    Returns:
        model_con: The encoder with the projection head attached (for pretraining).
        model_finetune: The same encoder with the supervision head attached (for finetuning).
    """
    # prep layers
    inp = layers.Input(shape=input_size)
    x = layers.Conv2D(64, 3, padding='same')(inp)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # layer1
    x = layers.Conv2D(128, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Add()([x, residual(x, 128)])
    # layer2
    x = layers.Conv2D(256, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    # layer3
    x = layers.Conv2D(512, 3, padding='same')(x)
    x = layers.MaxPool2D()(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Add()([x, residual(x, 512)])
    # layer4
    x = layers.GlobalMaxPool2D()(x)
    code = layers.Flatten()(x)
    # projection head (used during contrastive pretraining)
    p_head = layers.Dense(dims)(code)
    model_con = tf.keras.Model(inputs=inp, outputs=p_head)
    # supervision head (used during finetuning)
    s_head = layers.Dense(classes)(code)
    s_head = layers.Activation('softmax', dtype='float32')(s_head)
    model_finetune = tf.keras.Model(inputs=inp, outputs=s_head)
    return model_con, model_finetune
def residual(x, num_channel: int):
    """A residual unit for ResNet9.

    Args:
        x: Input Keras tensor.
        num_channel: The number of output channels.

    Returns:
        Output Keras tensor.
    """
    x = layers.Conv2D(num_channel, 3, padding='same')(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.Conv2D(num_channel, 3, padding='same')(x)
    x = layers.BatchNormalization(momentum=0.8)(x)
    x = layers.LeakyReLU(alpha=0.1)(x)
    return x
model_con, model_finetune = fe.build(model_fn=ResNet9, optimizer_fn=["adam", "adam"])
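As a quick optional check, passing a dummy batch through both models shows the two heads sitting on the shared encoder: the contrastive model emits 128-dimensional projections, while the finetuning model emits 10 class probabilities.

# Verify the output shapes of the two heads on the shared encoder.
dummy = tf.zeros((1, 32, 32, 3))
print(model_con(dummy).shape)       # (1, 128): projection head output
print(model_finetune(dummy).shape)  # (1, 10): softmax class probabilities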
Pre-Training Network¶
SimCLR uses NT-Xent (the normalized temperature-scaled cross-entropy loss) to train the encoder. Reducing this loss increases the similarity of positive augmented pairs and decreases the similarity of negative pairs, as the following GIF demonstrates. For a detailed derivation, please refer to the original paper.
(source: https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html)
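For reference, the NT-Xent loss for a positive pair of examples $(i, j)$ from the paper can be written as follows, where $\mathrm{sim}(\cdot,\cdot)$ denotes cosine similarity, $\tau$ is the temperature, and $2N$ is the number of augmented images in a batch:

$$\ell_{i,j} = -\log\frac{\exp(\mathrm{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N}\mathbb{1}_{[k \neq i]}\exp(\mathrm{sim}(z_i, z_k)/\tau)}$$

The implementation below computes this loss for both view orderings and averages over the batch.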
class NTXentOp(TensorOp):
    def __init__(self, arg1, arg2, outputs, temperature=1.0, mode=None):
        super().__init__(inputs=(arg1, arg2), outputs=outputs, mode=mode)
        self.temperature = temperature

    def forward(self, data, state):
        # Compute the NT-Xent loss between the two batches of projections
        arg1, arg2 = data
        loss, logit, label = NTXent(arg1, arg2, self.temperature)
        return loss, logit, label
def NTXent(A, B, temperature):
    large_number = 1e9
    batch_size = tf.shape(A)[0]
    # Project the embeddings onto the unit sphere so that dot products become cosine similarities
    A = tf.math.l2_normalize(A, -1)
    B = tf.math.l2_normalize(B, -1)
    # Diagonal mask to exclude self-similarity; labels mark the positive pair in each row
    mask = tf.one_hot(tf.range(batch_size), batch_size)
    labels = tf.one_hot(tf.range(batch_size), 2 * batch_size)
    # Pairwise similarity matrices, scaled by temperature
    aa = tf.matmul(A, A, transpose_b=True) / temperature
    aa = aa - mask * large_number  # push self-similarity out of the softmax
    ab = tf.matmul(A, B, transpose_b=True) / temperature
    bb = tf.matmul(B, B, transpose_b=True) / temperature
    bb = bb - mask * large_number
    ba = tf.matmul(B, A, transpose_b=True) / temperature
    # Cross entropy over [cross-view, same-view] logits for both view orderings
    loss_a = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([ab, aa], 1))
    loss_b = tf.nn.softmax_cross_entropy_with_logits(labels, tf.concat([ba, bb], 1))
    loss = tf.reduce_mean(loss_a + loss_b)
    return loss, ab, labels
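As an optional sanity check of the implementation above, you can feed the same toy projections in as both views; every row's positive logit should then dominate, so the contrastive "predictions" match the labels. The shapes below are illustrative stand-ins.

# Toy projections standing in for real encoder outputs.
A = tf.random.normal((4, 8))
loss, logit, label = NTXent(A, A, temperature=0.1)
print(float(loss))
# With identical views, the positive pair has the highest similarity in each row:
print(tf.reduce_all(tf.argmax(logit, -1) == tf.argmax(label, -1)).numpy())  # True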
network_pretrain = fe.Network(ops=[
    ModelOp(model=model_con, inputs="x_aug", outputs="y_pred"),
    ModelOp(model=model_con, inputs="x_aug2", outputs="y_pred2"),
    NTXentOp(arg1="y_pred", arg2="y_pred2", outputs=["NTXent", "logit", "label"]),
    UpdateOp(model=model_con, loss_name="NTXent")
])
Pre-Training Estimator¶
Next we are going to combine the pretraining pipeline and network together in the estimator class with an Accuracy
trace to monitor the contrastive accuracy and a ModelSaver
trace to save the pretrained model. We can then start the training.
traces = [
    Accuracy(true_key="label", pred_key="logit", mode="train", output_name="contrastive_accuracy"),
    ModelSaver(model=model_con, save_dir=save_dir)
]
estimator_pretrain = fe.Estimator(pipeline=pipeline_pretrain,
                                  network=network_pretrain,
                                  epochs=epochs_pretrain,
                                  traces=traces,
                                  train_steps_per_epoch=train_steps_per_epoch)
estimator_pretrain.fit()
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1;
FastEstimator-Train: step: 1; NTXent: 13.829769;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_1.h5
FastEstimator-Train: step: 98; epoch: 1; contrastive_accuracy: 0.191; epoch_time: 26.89 sec;
FastEstimator-Train: step: 100; NTXent: 12.382189; steps/sec: 4.52;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_2.h5
FastEstimator-Train: step: 196; epoch: 2; contrastive_accuracy: 0.5078; epoch_time: 20.97 sec;
FastEstimator-Train: step: 200; NTXent: 12.23844; steps/sec: 4.67;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_3.h5
FastEstimator-Train: step: 294; epoch: 3; contrastive_accuracy: 0.7107; epoch_time: 20.46 sec;
FastEstimator-Train: step: 300; NTXent: 12.149595; steps/sec: 4.79;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_4.h5
FastEstimator-Train: step: 392; epoch: 4; contrastive_accuracy: 0.8136; epoch_time: 21.08 sec;
FastEstimator-Train: step: 400; NTXent: 12.122614; steps/sec: 4.64;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_5.h5
FastEstimator-Train: step: 490; epoch: 5; contrastive_accuracy: 0.87096; epoch_time: 20.76 sec;
FastEstimator-Train: step: 500; NTXent: 12.102264; steps/sec: 4.75;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_6.h5
FastEstimator-Train: step: 588; epoch: 6; contrastive_accuracy: 0.91048; epoch_time: 20.02 sec;
FastEstimator-Train: step: 600; NTXent: 12.082749; steps/sec: 4.88;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_7.h5
FastEstimator-Train: step: 686; epoch: 7; contrastive_accuracy: 0.93192; epoch_time: 20.53 sec;
FastEstimator-Train: step: 700; NTXent: 12.061149; steps/sec: 4.8;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_8.h5
FastEstimator-Train: step: 784; epoch: 8; contrastive_accuracy: 0.9465; epoch_time: 19.97 sec;
FastEstimator-Train: step: 800; NTXent: 12.045004; steps/sec: 4.86;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_9.h5
FastEstimator-Train: step: 882; epoch: 9; contrastive_accuracy: 0.96042; epoch_time: 20.63 sec;
FastEstimator-Train: step: 900; NTXent: 12.045875; steps/sec: 4.77;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_10.h5
FastEstimator-Train: step: 980; epoch: 10; contrastive_accuracy: 0.96834; epoch_time: 20.79 sec;
FastEstimator-Train: step: 1000; NTXent: 12.04475; steps/sec: 4.71;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_11.h5
FastEstimator-Train: step: 1078; epoch: 11; contrastive_accuracy: 0.97196; epoch_time: 20.5 sec;
FastEstimator-Train: step: 1100; NTXent: 12.030878; steps/sec: 4.78;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_12.h5
FastEstimator-Train: step: 1176; epoch: 12; contrastive_accuracy: 0.97612; epoch_time: 20.54 sec;
FastEstimator-Train: step: 1200; NTXent: 12.024199; steps/sec: 4.78;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_13.h5
FastEstimator-Train: step: 1274; epoch: 13; contrastive_accuracy: 0.97904; epoch_time: 20.43 sec;
FastEstimator-Train: step: 1300; NTXent: 12.031016; steps/sec: 4.78;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_14.h5
FastEstimator-Train: step: 1372; epoch: 14; contrastive_accuracy: 0.97972; epoch_time: 20.6 sec;
FastEstimator-Train: step: 1400; NTXent: 12.023199; steps/sec: 4.75;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_15.h5
FastEstimator-Train: step: 1470; epoch: 15; contrastive_accuracy: 0.98242; epoch_time: 21.16 sec;
FastEstimator-Train: step: 1500; NTXent: 12.017285; steps/sec: 4.63;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_16.h5
FastEstimator-Train: step: 1568; epoch: 16; contrastive_accuracy: 0.98304; epoch_time: 20.61 sec;
FastEstimator-Train: step: 1600; NTXent: 12.014565; steps/sec: 4.74;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_17.h5
FastEstimator-Train: step: 1666; epoch: 17; contrastive_accuracy: 0.98406; epoch_time: 20.65 sec;
FastEstimator-Train: step: 1700; NTXent: 12.009141; steps/sec: 4.83;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_18.h5
FastEstimator-Train: step: 1764; epoch: 18; contrastive_accuracy: 0.9849; epoch_time: 19.89 sec;
FastEstimator-Train: step: 1800; NTXent: 12.001139; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_19.h5
FastEstimator-Train: step: 1862; epoch: 19; contrastive_accuracy: 0.98566; epoch_time: 19.94 sec;
FastEstimator-Train: step: 1900; NTXent: 12.003536; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_20.h5
FastEstimator-Train: step: 1960; epoch: 20; contrastive_accuracy: 0.98588; epoch_time: 19.95 sec;
FastEstimator-Train: step: 2000; NTXent: 11.999655; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_21.h5
FastEstimator-Train: step: 2058; epoch: 21; contrastive_accuracy: 0.98662; epoch_time: 19.97 sec;
FastEstimator-Train: step: 2100; NTXent: 11.992434; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_22.h5
FastEstimator-Train: step: 2156; epoch: 22; contrastive_accuracy: 0.98726; epoch_time: 19.96 sec;
FastEstimator-Train: step: 2200; NTXent: 11.998148; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_23.h5
FastEstimator-Train: step: 2254; epoch: 23; contrastive_accuracy: 0.988; epoch_time: 19.93 sec;
FastEstimator-Train: step: 2300; NTXent: 12.004744; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_24.h5
FastEstimator-Train: step: 2352; epoch: 24; contrastive_accuracy: 0.98886; epoch_time: 19.88 sec;
FastEstimator-Train: step: 2400; NTXent: 11.997645; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_25.h5
FastEstimator-Train: step: 2450; epoch: 25; contrastive_accuracy: 0.98946; epoch_time: 19.91 sec;
FastEstimator-Train: step: 2500; NTXent: 11.9974985; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_26.h5
FastEstimator-Train: step: 2548; epoch: 26; contrastive_accuracy: 0.98968; epoch_time: 19.96 sec;
FastEstimator-Train: step: 2600; NTXent: 11.988766; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_27.h5
FastEstimator-Train: step: 2646; epoch: 27; contrastive_accuracy: 0.98896; epoch_time: 19.91 sec;
FastEstimator-Train: step: 2700; NTXent: 11.992538; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_28.h5
FastEstimator-Train: step: 2744; epoch: 28; contrastive_accuracy: 0.98936; epoch_time: 19.95 sec;
FastEstimator-Train: step: 2800; NTXent: 11.984715; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_29.h5
FastEstimator-Train: step: 2842; epoch: 29; contrastive_accuracy: 0.99066; epoch_time: 19.95 sec;
FastEstimator-Train: step: 2900; NTXent: 11.989294; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_30.h5
FastEstimator-Train: step: 2940; epoch: 30; contrastive_accuracy: 0.99126; epoch_time: 19.89 sec;
FastEstimator-Train: step: 3000; NTXent: 11.980862; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_31.h5
FastEstimator-Train: step: 3038; epoch: 31; contrastive_accuracy: 0.99174; epoch_time: 19.9 sec;
FastEstimator-Train: step: 3100; NTXent: 11.988753; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_32.h5
FastEstimator-Train: step: 3136; epoch: 32; contrastive_accuracy: 0.99158; epoch_time: 19.94 sec;
FastEstimator-Train: step: 3200; NTXent: 11.982931; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_33.h5
FastEstimator-Train: step: 3234; epoch: 33; contrastive_accuracy: 0.99192; epoch_time: 19.91 sec;
FastEstimator-Train: step: 3300; NTXent: 11.983704; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_34.h5
FastEstimator-Train: step: 3332; epoch: 34; contrastive_accuracy: 0.99288; epoch_time: 19.94 sec;
FastEstimator-Train: step: 3400; NTXent: 11.982264; steps/sec: 4.9;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_35.h5
FastEstimator-Train: step: 3430; epoch: 35; contrastive_accuracy: 0.99274; epoch_time: 19.95 sec;
FastEstimator-Train: step: 3500; NTXent: 11.976917; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_36.h5
FastEstimator-Train: step: 3528; epoch: 36; contrastive_accuracy: 0.99184; epoch_time: 19.97 sec;
FastEstimator-Train: step: 3600; NTXent: 11.985281; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_37.h5
FastEstimator-Train: step: 3626; epoch: 37; contrastive_accuracy: 0.99332; epoch_time: 19.97 sec;
FastEstimator-Train: step: 3700; NTXent: 11.973089; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_38.h5
FastEstimator-Train: step: 3724; epoch: 38; contrastive_accuracy: 0.99292; epoch_time: 19.97 sec;
FastEstimator-Train: step: 3800; NTXent: 11.979197; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_39.h5
FastEstimator-Train: step: 3822; epoch: 39; contrastive_accuracy: 0.99344; epoch_time: 19.92 sec;
FastEstimator-Train: step: 3900; NTXent: 11.972714; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_40.h5
FastEstimator-Train: step: 3920; epoch: 40; contrastive_accuracy: 0.99342; epoch_time: 19.93 sec;
FastEstimator-Train: step: 4000; NTXent: 11.978983; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_41.h5
FastEstimator-Train: step: 4018; epoch: 41; contrastive_accuracy: 0.99362; epoch_time: 19.92 sec;
FastEstimator-Train: step: 4100; NTXent: 11.970781; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_42.h5
FastEstimator-Train: step: 4116; epoch: 42; contrastive_accuracy: 0.9934; epoch_time: 19.9 sec;
FastEstimator-Train: step: 4200; NTXent: 11.967566; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_43.h5
FastEstimator-Train: step: 4214; epoch: 43; contrastive_accuracy: 0.99374; epoch_time: 19.93 sec;
FastEstimator-Train: step: 4300; NTXent: 11.967752; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_44.h5
FastEstimator-Train: step: 4312; epoch: 44; contrastive_accuracy: 0.99356; epoch_time: 19.96 sec;
FastEstimator-Train: step: 4400; NTXent: 11.965156; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_45.h5
FastEstimator-Train: step: 4410; epoch: 45; contrastive_accuracy: 0.99422; epoch_time: 19.88 sec;
FastEstimator-Train: step: 4500; NTXent: 11.964204; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_46.h5
FastEstimator-Train: step: 4508; epoch: 46; contrastive_accuracy: 0.9936; epoch_time: 19.92 sec;
FastEstimator-Train: step: 4600; NTXent: 11.970972; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_47.h5
FastEstimator-Train: step: 4606; epoch: 47; contrastive_accuracy: 0.99444; epoch_time: 19.9 sec;
FastEstimator-Train: step: 4700; NTXent: 11.972122; steps/sec: 4.92;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_48.h5
FastEstimator-Train: step: 4704; epoch: 48; contrastive_accuracy: 0.9936; epoch_time: 19.93 sec;
FastEstimator-Train: step: 4800; NTXent: 11.961956; steps/sec: 4.93;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_49.h5
FastEstimator-Train: step: 4802; epoch: 49; contrastive_accuracy: 0.99414; epoch_time: 19.89 sec;
FastEstimator-Train: step: 4900; NTXent: 11.124333; steps/sec: 4.91;
FastEstimator-ModelSaver: Saved model to /tmp/tmp33wb3rot/model_epoch_50.h5
FastEstimator-Train: step: 4900; epoch: 50; contrastive_accuracy: 0.99446; epoch_time: 19.94 sec;
FastEstimator-Finish: step: 4900; model_lr: 0.001; total_time: 1055.67 sec;
Finetuning the Model on an Image Classification Task¶
Once the model is pretrained, we can finetune it on a specific task. In this case we are going to use the pretrained model for ciFAIR10 image classification. Recall that in a previous section we built both model_con and model_finetune. Because those two models share the same encoder object, (pre)training model_con also trains the encoder of model_finetune. The finetuning of the model is simply supervised training on top of the pretrained encoder. In order to demonstrate the benefit of SimCLR, we are going to finetune the network using only 10% of the labeled training data and compare with how well a model could do trained from scratch with the same data limitation.
split_train = train_data.split(0.1)
pipeline_finetune = fe.Pipeline(
    train_data=split_train,
    eval_data=eval_data,
    batch_size=batch_size,
    ops=[
        ToFloat(inputs="x", outputs="x")
    ])
network_finetune = fe.Network(ops=[
    ModelOp(model=model_finetune, inputs="x", outputs="y_pred"),
    CrossEntropy(inputs=["y_pred", "y"], outputs="ce"),
    UpdateOp(model=model_finetune, loss_name="ce")
])
traces = [
    Accuracy(true_key="y", pred_key="y_pred"),
    BestModelSaver(model=model_finetune, save_dir=save_dir, metric="accuracy", save_best_mode="max")
]
est_finetune = fe.Estimator(pipeline=pipeline_finetune,
                            network=network_finetune,
                            epochs=epochs_finetune,
                            traces=traces,
                            train_steps_per_epoch=train_steps_per_epoch)
est_finetune.fit()
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1;
FastEstimator-Train: step: 1; ce: 6.394948;
FastEstimator-Train: step: 9; epoch: 1; epoch_time: 3.6 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 9; epoch: 1; accuracy: 0.4504; ce: 1.8657482; max_accuracy: 0.4504; since_best_accuracy: 0;
FastEstimator-Train: step: 18; epoch: 2; epoch_time: 0.78 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 18; epoch: 2; accuracy: 0.5904; ce: 1.1950318; max_accuracy: 0.5904; since_best_accuracy: 0;
FastEstimator-Train: step: 27; epoch: 3; epoch_time: 0.78 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 27; epoch: 3; accuracy: 0.6299; ce: 1.0542316; max_accuracy: 0.6299; since_best_accuracy: 0;
FastEstimator-Train: step: 36; epoch: 4; epoch_time: 0.8 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 36; epoch: 4; accuracy: 0.6627; ce: 0.99589044; max_accuracy: 0.6627; since_best_accuracy: 0;
FastEstimator-Train: step: 45; epoch: 5; epoch_time: 0.79 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 45; epoch: 5; accuracy: 0.667; ce: 0.95694005; max_accuracy: 0.667; since_best_accuracy: 0;
FastEstimator-Train: step: 54; epoch: 6; epoch_time: 0.78 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 54; epoch: 6; accuracy: 0.6795; ce: 0.9522039; max_accuracy: 0.6795; since_best_accuracy: 0;
FastEstimator-Train: step: 63; epoch: 7; epoch_time: 0.78 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 63; epoch: 7; accuracy: 0.6907; ce: 0.9476255; max_accuracy: 0.6907; since_best_accuracy: 0;
FastEstimator-Train: step: 72; epoch: 8; epoch_time: 0.79 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 72; epoch: 8; accuracy: 0.698; ce: 0.9569526; max_accuracy: 0.698; since_best_accuracy: 0;
FastEstimator-Train: step: 81; epoch: 9; epoch_time: 0.81 sec;
FastEstimator-Eval: step: 81; epoch: 9; accuracy: 0.6973; ce: 0.96826965; max_accuracy: 0.698; since_best_accuracy: 1;
FastEstimator-Train: step: 90; epoch: 10; epoch_time: 0.82 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmp33wb3rot/model1_best_accuracy.h5
FastEstimator-Eval: step: 90; epoch: 10; accuracy: 0.7009; ce: 0.98749244; max_accuracy: 0.7009; since_best_accuracy: 0;
FastEstimator-Finish: step: 90; model1_lr: 0.001; total_time: 27.06 sec;
Results¶
We can see that the SimCLR-pretrained model achieved about 70% accuracy using only 10% of the labeled data. Without SimCLR pretraining, a vanilla ResNet9 only achieves around 57% accuracy with the same 10% of labels.
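To use the finetuned model afterwards, one option is to reload the best checkpoint saved by BestModelSaver and classify new images, routing them through the same pipeline preprocessing. The following is a sketch, assuming the checkpoint filename reported in the log above:

import os
import numpy as np

# Reload the best weights saved during finetuning (filename from the BestModelSaver log).
model_finetune.load_weights(os.path.join(save_dir, "model1_best_accuracy.h5"))

# Preprocess a single eval image with the finetuning pipeline, then predict its class.
sample = pipeline_finetune.transform(data={"x": eval_data[0]["x"]}, mode="eval")
class_probs = model_finetune(sample["x"])
print(int(np.argmax(class_probs)))  # predicted ciFAIR10 class index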