Image Classification Using Vision Transformer¶
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
Vision Transformer (ViT) is a new alternative to Convolutional Neural Networks (CNNs) in the field of computer vision. The idea of ViT was inspired by the success of the Transformer and BERT architectures in NLP applications. In this example, we will implement a ViT in PyTorch and show how to pre-train it and then fine-tune it on a downstream task, reaching good results with minimal downstream training time.
ViT Model¶
The ViT model is almost the same as the original Transformer except for the following differences:
- The input image is broken down into small patches, which are treated as a sequence in the same way as tokens in language. The patching and embedding are implemented by a single Conv2d operation in patch_embedding (see the shape sketch after this list).
- Unlike the original Transformer, the positional embedding is a trainable parameter rather than a fixed encoding.
- Similar to BERT, a CLS token is prepended to the patch sequence. In contrast to BERT, the value of the CLS token is trainable.
- After the Transformer encoding, only the embedding corresponding to the CLS token is used as the feature for the classification layer.
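To make the patching step concrete, here is a minimal shape sketch (not part of the original example, with illustrative variable names) showing how a Conv2d with kernel_size = stride = patch_size converts an image batch into a token sequence, and how the CLS token adds one more position:

import torch
import torch.nn as nn

patch_size, em_dim = 16, 768
images = torch.randn(2, 3, 224, 224)                      # [B, C, H, W]
patchify = nn.Conv2d(3, em_dim, kernel_size=patch_size, stride=patch_size, bias=False)
tokens = patchify(images).flatten(2).transpose(1, 2)      # [2, 196, 768]: (224 / 16)**2 = 196 patches
cls = torch.zeros(2, 1, em_dim)                           # stand-in for the trainable CLS token
sequence = torch.cat([cls, tokens], dim=1)
print(sequence.shape)                                     # torch.Size([2, 197, 768]): 196 patches + 1 CLS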
import torch
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class ViTEmbeddings(nn.Module):
    def __init__(self, image_size=224, patch_size=16, num_channels=3, em_dim=768, drop=0.1) -> None:
        super().__init__()
        assert image_size % patch_size == 0, "image size must be an integer multiple of patch size"
        self.patch_embedding = nn.Conv2d(num_channels, em_dim, kernel_size=patch_size, stride=patch_size, bias=False)
        self.position_embedding = nn.Parameter(torch.zeros(1, (image_size // patch_size)**2 + 1, em_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, em_dim))
        self.dropout = nn.Dropout(drop)

    def forward(self, x):
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)  # [B, C, H, W] -> [B, num_patches, em_dim]
        x = torch.cat([self.cls_token.expand(x.size(0), -1, -1), x], dim=1)  # [B, num_patches+1, em_dim]
        x = x + self.position_embedding
        x = self.dropout(x)
        return x
class ViTEncoder(nn.Module):
    def __init__(self, num_layers, image_size, patch_size, num_channels, em_dim, drop, num_heads, ff_dim):
        super().__init__()
        self.embedding = ViTEmbeddings(image_size, patch_size, num_channels, em_dim, drop)
        encoder_layer = TransformerEncoderLayer(em_dim,
                                                nhead=num_heads,
                                                dim_feedforward=ff_dim,
                                                activation='gelu',
                                                dropout=drop)
        self.encoder = TransformerEncoder(encoder_layer=encoder_layer, num_layers=num_layers)
        self.layernorm = nn.LayerNorm(em_dim, eps=1e-6)

    def forward(self, x):
        x = self.embedding(x)
        x = x.transpose(0, 1)  # switch batch and sequence length dimensions to match the PyTorch [seq, batch, dim] convention
        x = self.encoder(x)
        x = self.layernorm(x[0])  # keep only the embedding of the CLS token (the first position)
        return x
class ViTModel(nn.Module):
    def __init__(self,
                 num_classes,
                 num_layers=12,
                 image_size=224,
                 patch_size=16,
                 num_channels=3,
                 em_dim=768,
                 drop=0.1,
                 num_heads=12,
                 ff_dim=3072):
        super().__init__()
        self.vit_encoder = ViTEncoder(num_layers=num_layers,
                                      image_size=image_size,
                                      patch_size=patch_size,
                                      num_channels=num_channels,
                                      em_dim=em_dim,
                                      drop=drop,
                                      num_heads=num_heads,
                                      ff_dim=ff_dim)
        self.linear_classifier = nn.Linear(em_dim, num_classes)

    def forward(self, x):
        x = self.vit_encoder(x)
        x = self.linear_classifier(x)
        return x
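As a quick sanity check (not part of the original example), the model can be run on a random batch of CIFAR-sized images to confirm that the output has one logit per class; the small configuration below is chosen arbitrarily:

tiny_vit = ViTModel(num_classes=10, image_size=32, patch_size=4, num_layers=2, em_dim=64, num_heads=4, ff_dim=128)
dummy_batch = torch.randn(4, 3, 32, 32)   # [B, C, H, W]
logits = tiny_vit(dummy_batch)
print(logits.shape)                       # torch.Size([4, 10])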
Now let's define some parameters that will be used later:
batch_size=128
pretrain_epochs=100
finetune_epochs=1
train_steps_per_epoch=None
eval_steps_per_epoch=None
Upstream Pre-training¶
We will use ciFAIR-100 as our upstream dataset (a drop-in replacement for CIFAR-100 whose test set has duplicates of training images removed). The data preprocessing and augmentation follow a standard recipe: normalization, padded random crop, random horizontal flip, and coarse dropout.
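If you want to peek at the raw data first, here is an optional inspection sketch (assuming the ciFAIR datasets behave like FastEstimator's other in-memory datasets, where indexing returns a dictionary with an image "x" and a label "y"):

from fastestimator.dataset.data import cifair100

train_data, eval_data = cifair100.load_data()
print(len(train_data), len(eval_data))    # expected: 50000 training and 10000 evaluation samples
sample = train_data[0]
print(sample["x"].shape, sample["y"])     # expected: a (32, 32, 3) uint8 image and an integer label in [0, 100)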
import tempfile
import fastestimator as fe
from fastestimator.dataset.data import cifair10, cifair100
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import ChannelTranspose, CoarseDropout, Normalize
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.metric import Accuracy
def pretrain(batch_size,
             epochs,
             model_dir=tempfile.mkdtemp(),
             train_steps_per_epoch=None,
             eval_steps_per_epoch=None):
    train_data, eval_data = cifair100.load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        eval_data=eval_data,
        batch_size=batch_size,
        ops=[
            Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
            RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
            Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
            CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
            ChannelTranspose(inputs="x", outputs="x")
        ])
    model = fe.build(
        model_fn=lambda: ViTModel(num_classes=100,
                                  image_size=32,
                                  patch_size=4,
                                  num_layers=6,
                                  num_channels=3,
                                  em_dim=256,
                                  num_heads=8,
                                  ff_dim=512),
        optimizer_fn=lambda x: torch.optim.SGD(x, lr=0.01, momentum=0.9, weight_decay=1e-4))
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
        UpdateOp(model=model, loss_name="ce")
    ])
    traces = [
        Accuracy(true_key="y", pred_key="y_pred")
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch,
                             eval_steps_per_epoch=eval_steps_per_epoch,
                             log_steps=0)
    estimator.fit(warmup=False)
    return model
Start Pre-training¶
Let's train the ViT model for 100 epochs to get the pre-trained weights. This takes roughly 40 minutes on a single GTX 1080 Ti GPU.
Here we are only training a mini version of the actual ViT model, and its ciFAIR-100 accuracy after 100 epochs is similar to the 55% top-1 performance reported in the community. However, training the official ViTModel with its original parameters on the JFT-300M dataset would produce much better encoder weights, at the cost of a much longer training time. The paper used this strategy to reach close to 81% ImageNet downstream top-1 accuracy.
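To put "mini version" into perspective, the parameter counts can be compared directly. Below is a quick sketch (not part of the original example); note that the ViTModel defaults above roughly match the ViT-Base configuration (12 layers, 768-dim embeddings, 12 heads, 3072-dim feed-forward):

def count_params(m):
    return sum(p.numel() for p in m.parameters())

mini_vit = ViTModel(num_classes=100, image_size=32, patch_size=4, num_layers=6, em_dim=256, num_heads=8, ff_dim=512)
base_sized_vit = ViTModel(num_classes=100)  # default arguments: 12 layers, em_dim=768, 12 heads, ff_dim=3072
print(f"mini: {count_params(mini_vit) / 1e6:.1f}M parameters")
print(f"base-sized: {count_params(base_sized_vit) / 1e6:.1f}M parameters")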
pretrained_model = pretrain(batch_size=batch_size,
epochs=pretrain_epochs,
train_steps_per_epoch=train_steps_per_epoch,
eval_steps_per_epoch=eval_steps_per_epoch)
______ __ ______ __ _ __ / ____/___ ______/ /_/ ____/____/ /_(_)___ ___ ____ _/ /_____ _____ / /_ / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/ / __/ / /_/ (__ ) /_/ /___(__ ) /_/ / / / / / / /_/ / /_/ /_/ / / /_/ \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/ FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved. FastEstimator-Start: step: 1; logging_interval: 0; num_device: 1; FastEstimator-Eval: step: 391; epoch: 1; accuracy: 0.1088; ce: 3.8288884; FastEstimator-Eval: step: 782; epoch: 2; accuracy: 0.158; ce: 3.5546496; FastEstimator-Eval: step: 1173; epoch: 3; accuracy: 0.1876; ce: 3.3546138; FastEstimator-Eval: step: 1564; epoch: 4; accuracy: 0.225; ce: 3.1463547; FastEstimator-Eval: step: 1955; epoch: 5; accuracy: 0.2515; ce: 3.0236564; FastEstimator-Eval: step: 2346; epoch: 6; accuracy: 0.2781; ce: 2.8541021; FastEstimator-Eval: step: 2737; epoch: 7; accuracy: 0.2922; ce: 2.7996583; FastEstimator-Eval: step: 3128; epoch: 8; accuracy: 0.3128; ce: 2.6991034; FastEstimator-Eval: step: 3519; epoch: 9; accuracy: 0.3314; ce: 2.57633; FastEstimator-Eval: step: 3910; epoch: 10; accuracy: 0.3394; ce: 2.5583541; FastEstimator-Eval: step: 4301; epoch: 11; accuracy: 0.3635; ce: 2.4394403; FastEstimator-Eval: step: 4692; epoch: 12; accuracy: 0.3717; ce: 2.4280012; FastEstimator-Eval: step: 5083; epoch: 13; accuracy: 0.3845; ce: 2.3532598; FastEstimator-Eval: step: 5474; epoch: 14; accuracy: 0.3756; ce: 2.3746123; FastEstimator-Eval: step: 5865; epoch: 15; accuracy: 0.4079; ce: 2.2628024; FastEstimator-Eval: step: 6256; epoch: 16; accuracy: 0.4045; ce: 2.2397344; FastEstimator-Eval: step: 6647; epoch: 17; accuracy: 0.4175; ce: 2.183634; FastEstimator-Eval: step: 7038; epoch: 18; accuracy: 0.4167; ce: 2.209709; FastEstimator-Eval: step: 7429; epoch: 19; accuracy: 0.4339; ce: 2.1296408; FastEstimator-Eval: step: 7820; epoch: 20; accuracy: 0.4182; ce: 2.1953375; FastEstimator-Eval: step: 8211; epoch: 21; accuracy: 0.438; ce: 2.1236746; FastEstimator-Eval: step: 8602; epoch: 22; accuracy: 0.4438; ce: 2.092245; FastEstimator-Eval: step: 8993; epoch: 23; accuracy: 0.4559; ce: 2.0420241; FastEstimator-Eval: step: 9384; epoch: 24; accuracy: 0.461; ce: 2.021573; FastEstimator-Eval: step: 9775; epoch: 25; accuracy: 0.4577; ce: 2.0449996; FastEstimator-Eval: step: 10166; epoch: 26; accuracy: 0.4648; ce: 2.0265305; FastEstimator-Eval: step: 10557; epoch: 27; accuracy: 0.4609; ce: 2.0219545; FastEstimator-Eval: step: 10948; epoch: 28; accuracy: 0.4599; ce: 2.0249476; FastEstimator-Eval: step: 11339; epoch: 29; accuracy: 0.4799; ce: 1.958858; FastEstimator-Eval: step: 11730; epoch: 30; accuracy: 0.4651; ce: 2.0040777; FastEstimator-Eval: step: 12121; epoch: 31; accuracy: 0.4759; ce: 1.9787812; FastEstimator-Eval: step: 12512; epoch: 32; accuracy: 0.4815; ce: 1.9677575; FastEstimator-Eval: step: 12903; epoch: 33; accuracy: 0.4836; ce: 1.9488634; FastEstimator-Eval: step: 13294; epoch: 34; accuracy: 0.4698; ce: 2.0040216; FastEstimator-Eval: step: 13685; epoch: 35; accuracy: 0.4854; ce: 1.933885; FastEstimator-Eval: step: 14076; epoch: 36; accuracy: 0.4915; ce: 1.9364777; FastEstimator-Eval: step: 14467; epoch: 37; accuracy: 0.4872; ce: 1.9454862; FastEstimator-Eval: step: 14858; epoch: 38; accuracy: 0.4953; ce: 1.9281081; FastEstimator-Eval: step: 15249; epoch: 39; accuracy: 0.4987; ce: 1.8994861; FastEstimator-Eval: step: 15640; epoch: 40; accuracy: 0.4972; ce: 1.9311935; FastEstimator-Eval: step: 16031; epoch: 41; accuracy: 0.4999; ce: 1.9120353; 
FastEstimator-Eval: step: 16422; epoch: 42; accuracy: 0.4999; ce: 1.9262657; FastEstimator-Eval: step: 16813; epoch: 43; accuracy: 0.5003; ce: 1.9173524; FastEstimator-Eval: step: 17204; epoch: 44; accuracy: 0.5099; ce: 1.9153186; FastEstimator-Eval: step: 17595; epoch: 45; accuracy: 0.5064; ce: 1.9490457; FastEstimator-Eval: step: 17986; epoch: 46; accuracy: 0.4941; ce: 1.9536077; FastEstimator-Eval: step: 18377; epoch: 47; accuracy: 0.5109; ce: 1.9044245; FastEstimator-Eval: step: 18768; epoch: 48; accuracy: 0.5015; ce: 1.9598173; FastEstimator-Eval: step: 19159; epoch: 49; accuracy: 0.5036; ce: 1.9729359; FastEstimator-Eval: step: 19550; epoch: 50; accuracy: 0.5087; ce: 1.9391878; FastEstimator-Eval: step: 19941; epoch: 51; accuracy: 0.509; ce: 1.9359056; FastEstimator-Eval: step: 20332; epoch: 52; accuracy: 0.5055; ce: 1.9588828; FastEstimator-Eval: step: 20723; epoch: 53; accuracy: 0.5104; ce: 1.9606155; FastEstimator-Eval: step: 21114; epoch: 54; accuracy: 0.502; ce: 2.0006518; FastEstimator-Eval: step: 21505; epoch: 55; accuracy: 0.5102; ce: 1.9584115; FastEstimator-Eval: step: 21896; epoch: 56; accuracy: 0.5031; ce: 2.0014715; FastEstimator-Eval: step: 22287; epoch: 57; accuracy: 0.5135; ce: 1.9787788; FastEstimator-Eval: step: 22678; epoch: 58; accuracy: 0.5041; ce: 2.000715; FastEstimator-Eval: step: 23069; epoch: 59; accuracy: 0.509; ce: 2.0128548; FastEstimator-Eval: step: 23460; epoch: 60; accuracy: 0.5124; ce: 2.0124855; FastEstimator-Eval: step: 23851; epoch: 61; accuracy: 0.5126; ce: 2.0069928; FastEstimator-Eval: step: 24242; epoch: 62; accuracy: 0.5109; ce: 2.0440252; FastEstimator-Eval: step: 24633; epoch: 63; accuracy: 0.5176; ce: 2.0516953; FastEstimator-Eval: step: 25024; epoch: 64; accuracy: 0.513; ce: 2.0567915; FastEstimator-Eval: step: 25415; epoch: 65; accuracy: 0.5103; ce: 2.0795443; FastEstimator-Eval: step: 25806; epoch: 66; accuracy: 0.5038; ce: 2.1041098; FastEstimator-Eval: step: 26197; epoch: 67; accuracy: 0.5087; ce: 2.1095006; FastEstimator-Eval: step: 26588; epoch: 68; accuracy: 0.5121; ce: 2.1082811; FastEstimator-Eval: step: 26979; epoch: 69; accuracy: 0.5071; ce: 2.1289942; FastEstimator-Eval: step: 27370; epoch: 70; accuracy: 0.5192; ce: 2.1182418; FastEstimator-Eval: step: 27761; epoch: 71; accuracy: 0.5175; ce: 2.1154375; FastEstimator-Eval: step: 28152; epoch: 72; accuracy: 0.5139; ce: 2.1458533; FastEstimator-Eval: step: 28543; epoch: 73; accuracy: 0.5152; ce: 2.15533; FastEstimator-Eval: step: 28934; epoch: 74; accuracy: 0.5079; ce: 2.199765; FastEstimator-Eval: step: 29325; epoch: 75; accuracy: 0.5053; ce: 2.1914499; FastEstimator-Eval: step: 29716; epoch: 76; accuracy: 0.5072; ce: 2.2124186; FastEstimator-Eval: step: 30107; epoch: 77; accuracy: 0.5102; ce: 2.1962357; FastEstimator-Eval: step: 30498; epoch: 78; accuracy: 0.5134; ce: 2.2328248; FastEstimator-Eval: step: 30889; epoch: 79; accuracy: 0.5078; ce: 2.2428932; FastEstimator-Eval: step: 31280; epoch: 80; accuracy: 0.5101; ce: 2.278882; FastEstimator-Eval: step: 31671; epoch: 81; accuracy: 0.5121; ce: 2.2327974; FastEstimator-Eval: step: 32062; epoch: 82; accuracy: 0.5165; ce: 2.2351596; FastEstimator-Eval: step: 32453; epoch: 83; accuracy: 0.5146; ce: 2.252345; FastEstimator-Eval: step: 32844; epoch: 84; accuracy: 0.5176; ce: 2.265459; FastEstimator-Eval: step: 33235; epoch: 85; accuracy: 0.5035; ce: 2.3590689; FastEstimator-Eval: step: 33626; epoch: 86; accuracy: 0.5101; ce: 2.3089907; FastEstimator-Eval: step: 34017; epoch: 87; accuracy: 0.5125; ce: 2.3316317; FastEstimator-Eval: 
step: 34408; epoch: 88; accuracy: 0.5072; ce: 2.3564215; FastEstimator-Eval: step: 34799; epoch: 89; accuracy: 0.5128; ce: 2.310105; FastEstimator-Eval: step: 35190; epoch: 90; accuracy: 0.517; ce: 2.294423; FastEstimator-Eval: step: 35581; epoch: 91; accuracy: 0.5169; ce: 2.303369; FastEstimator-Eval: step: 35972; epoch: 92; accuracy: 0.5125; ce: 2.355053; FastEstimator-Eval: step: 36363; epoch: 93; accuracy: 0.5154; ce: 2.3520496; FastEstimator-Eval: step: 36754; epoch: 94; accuracy: 0.5081; ce: 2.377033; FastEstimator-Eval: step: 37145; epoch: 95; accuracy: 0.515; ce: 2.405619; FastEstimator-Eval: step: 37536; epoch: 96; accuracy: 0.5193; ce: 2.3753698; FastEstimator-Eval: step: 37927; epoch: 97; accuracy: 0.5184; ce: 2.3919399; FastEstimator-Eval: step: 38318; epoch: 98; accuracy: 0.5048; ce: 2.4691393; FastEstimator-Eval: step: 38709; epoch: 99; accuracy: 0.5226; ce: 2.3937373; FastEstimator-Eval: step: 39100; epoch: 100; accuracy: 0.5154; ce: 2.383983; FastEstimator-Finish: step: 39100; model_lr: 0.01; total_time: 2576.35 sec;
Downstream Fine-tuning¶
A general rule of thumb for successful downstream fine-tuning is to choose a downstream task with less variety and complexity than the upstream one. In this example, since we used ciFAIR-100 as our upstream task, a good candidate for the downstream dataset is ciFAIR-10. The official implementation applied the same practice at a larger scale, using JFT-300M as the upstream dataset and ImageNet as the downstream task.
Given the similarity between our downstream and upstream datasets, the fine-tuning configuration is almost the same as before; the main changes are the dataset itself and the number of output classes (10 instead of 100).
def finetune(pretrained_model,
             batch_size,
             epochs,
             model_dir=tempfile.mkdtemp(),
             train_steps_per_epoch=None,
             eval_steps_per_epoch=None):
    train_data, eval_data = cifair10.load_data()
    pipeline = fe.Pipeline(
        train_data=train_data,
        eval_data=eval_data,
        batch_size=batch_size,
        ops=[
            Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
            RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
            Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
            CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
            ChannelTranspose(inputs="x", outputs="x")
        ])
    model = fe.build(
        model_fn=lambda: ViTModel(num_classes=10,
                                  image_size=32,
                                  patch_size=4,
                                  num_layers=6,
                                  num_channels=3,
                                  em_dim=256,
                                  num_heads=8,
                                  ff_dim=512),
        optimizer_fn=lambda x: torch.optim.SGD(x, lr=0.01, momentum=0.9, weight_decay=1e-4))
    # load the pre-trained encoder's weights (handling the multi-GPU wrapped case as well)
    if hasattr(model, "module"):
        model.module.vit_encoder.load_state_dict(pretrained_model.module.vit_encoder.state_dict())
    else:
        model.vit_encoder.load_state_dict(pretrained_model.vit_encoder.state_dict())
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
        UpdateOp(model=model, loss_name="ce")
    ])
    traces = [
        Accuracy(true_key="y", pred_key="y_pred")
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch,
                             eval_steps_per_epoch=eval_steps_per_epoch)
    estimator.fit(warmup=False)
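An alternative fine-tuning strategy, not used in this example, is linear probing: freeze the pre-trained encoder and train only the classification head. The sketch below shows the idea in plain PyTorch; wiring it into the training loop above would require passing only the unfrozen parameters through fe.build's optimizer_fn.

def freeze_encoder(vit_model):
    # freeze the pre-trained encoder so that only the linear classifier receives gradient updates
    for p in vit_model.vit_encoder.parameters():
        p.requires_grad = False

probe_model = ViTModel(num_classes=10, image_size=32, patch_size=4, num_layers=6, em_dim=256, num_heads=8, ff_dim=512)
freeze_encoder(probe_model)
head_params = [p for p in probe_model.parameters() if p.requires_grad]  # just the classifier weights
optimizer = torch.optim.SGD(head_params, lr=0.01, momentum=0.9)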
Start the Fine-tuning¶
The downstream ViT reuses the encoder pre-trained on ciFAIR-100. To illustrate the effect of the pre-trained encoder, we train the downstream task for only a single epoch.
finetune(pretrained_model,
batch_size=batch_size,
epochs=finetune_epochs,
train_steps_per_epoch=train_steps_per_epoch,
eval_steps_per_epoch=eval_steps_per_epoch)
______ __ ______ __ _ __ / ____/___ ______/ /_/ ____/____/ /_(_)___ ___ ____ _/ /_____ _____ / /_ / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/ / __/ / /_/ (__ ) /_/ /___(__ ) /_/ / / / / / / /_/ / /_/ /_/ / / /_/ \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/ FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved. FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1; FastEstimator-Train: step: 1; ce: 4.801615; FastEstimator-Train: step: 100; ce: 1.0262994; steps/sec: 17.72; FastEstimator-Train: step: 200; ce: 0.74568576; steps/sec: 17.57; FastEstimator-Train: step: 300; ce: 0.7660386; steps/sec: 17.54; FastEstimator-Train: step: 391; epoch: 1; epoch_time: 22.4 sec; FastEstimator-Eval: step: 391; epoch: 1; accuracy: 0.7426; ce: 0.7396317; FastEstimator-Finish: step: 391; model2_lr: 0.01; total_time: 25.4 sec;
With only one epoch of training, we are able to get 74% top-1 accuracy on the ciFAIR-10 test set. Not bad, huh?
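As an optional follow-up (not run here), the benefit of pre-training can be made explicit by repeating the one-epoch fine-tuning with a randomly initialized encoder. Since finetune only copies the vit_encoder weights from its first argument, passing a freshly constructed ViTModel with the same encoder configuration (on a single-device setup) gives a from-scratch baseline:

scratch_model = ViTModel(num_classes=100, image_size=32, patch_size=4, num_layers=6, em_dim=256, num_heads=8, ff_dim=512)
finetune(scratch_model, batch_size=batch_size, epochs=finetune_epochs)  # only its vit_encoder is used; expect noticeably lower one-epoch accuracy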