CIFAR-10 Image Classification Using ResNet (PyTorch Backend)¶
In this example we are going to demonstrate how to train a CIFAR-10 image classification model using a ResNet architecture on the PyTorch backend. All of the training details, including the model structure, data preprocessing, and learning rate control, come from https://github.com/davidcpage/cifar10-fast. Note, however, that we will be using the ciFAIR-10 dataset, which fixes the train/test duplicates found in the original CIFAR-10 dataset (https://cvjena.github.io/cifair/).
Import the required libraries¶
import fastestimator as fe
import numpy as np
import matplotlib.pyplot as plt
import tempfile
from fastestimator.util import BatchDisplay, GridDisplay
# Training parameters
epochs = 24
batch_size = 512
train_steps_per_epoch = None
eval_steps_per_epoch = None
save_dir = tempfile.mkdtemp()
class_names = ["airplanes", "cars", "birds", "cats", "deer", "dogs", "frogs", "horses", "ships", "trucks"]
from fastestimator.dataset.data import cifair10
train_data, eval_data = cifair10.load_data()
test_data = eval_data.split(0.5)
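As a quick sanity check (assuming the standard ciFAIR-10 layout of 50,000 training and 10,000 evaluation images), the split above should leave 5,000 images each for evaluation and testing, since split(0.5) moves half of eval_data into the returned test_data:
print(len(train_data), len(eval_data), len(test_data))  # expect: 50000 5000 5000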
Set up a pre-processing Pipeline¶
Here we set up the data pipeline. This involves a variety of data augmentations, including random cropping, horizontal flipping, image obscuration, and smoothed one-hot label encoding. In addition, the image channels need to be transposed from HWC to CHW format to match PyTorch conventions. We set up these processing steps using Ops, and also bundle the data sources and batch_size together into our Pipeline.
from fastestimator.op.numpyop.univariate import ChannelTranspose, CoarseDropout, Normalize, Onehot
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
pipeline = fe.Pipeline(
    train_data=train_data,
    eval_data=eval_data,
    test_data=test_data,
    batch_size=batch_size,
    ops=[
        Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
        PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
        RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
        Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
        CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
        ChannelTranspose(inputs="x", outputs="x"),
        Onehot(inputs="y", outputs="y", mode="train", num_classes=10, label_smoothing=0.2)
    ])
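To make the smoothed one-hot encoding concrete: assuming Onehot applies the standard smoothing formula y_smooth = y_onehot * (1 - eps) + eps / num_classes, a label of 3 with label_smoothing=0.2 becomes a vector with 0.82 on the true class and 0.02 everywhere else. A minimal numpy sketch:
# Hypothetical illustration of the standard label smoothing formula
eps, num_classes = 0.2, 10
label = 3
y_smooth = np.full(num_classes, eps / num_classes)  # off-class value: 0.02
y_smooth[label] = 1.0 - eps + eps / num_classes     # on-class value: 0.82
print(y_smooth)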
Validate the Pipeline¶
To make sure the Pipeline works as expected, let's visualize its output and check the output size. Pipeline.get_results will return a batch of pipeline output for this purpose:
data = pipeline.get_results()
data_x = data["x"]
data_y = data["y"]
print("the pipeline output image size: {}".format(data_x.numpy().shape))
print("the pipeline output label size: {}".format(data_y.numpy().shape))
the pipeline output image size: (512, 3, 32, 32)
the pipeline output label size: (512, 10)
sample_num = 5
fig = GridDisplay([BatchDisplay(image=data_x[0:sample_num], title="Pipeline Output"),
                   BatchDisplay(text=np.argmax(data_y, axis=-1)[0:sample_num], title="Label")])
fig.show()
Step 2 - Network construction¶
FastEstimator supports both PyTorch and TensorFlow, so this section could use either backend. We will demonstrate only the PyTorch approach in this example.
Model construction¶
The model definition is implemented in PyTorch and instantiated by calling fe.build, which also associates the model with a specific optimizer. Here we are going to import the model architecture directly from FastEstimator.
from fastestimator.architecture.pytorch import ResNet9
model = fe.build(model_fn=ResNet9, optimizer_fn="adam")
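Passing optimizer_fn="adam" uses the framework's default Adam configuration. If you need explicit control over the optimizer settings (not required for this example), fe.build also accepts a function mapping the model parameters to a torch optimizer, for instance:
# Hedged alternative: the same build with an explicitly-configured Adam optimizer
# import torch
# model = fe.build(model_fn=ResNet9, optimizer_fn=lambda x: torch.optim.Adam(params=x, lr=1e-3))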
Network definition¶
Ops are the basic components of a Network, including models, loss calculation units, and post-processing units. In this step we are going to combine those pieces together into a Network:
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
network = fe.Network(ops=[
    ModelOp(model=model, inputs="x", outputs="y_pred"),
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
    UpdateOp(model=model, loss_name="ce", mode="train")
])
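These Ops run in order on every batch: ModelOp performs the forward pass to produce "y_pred", CrossEntropy compares "y_pred" against "y" to produce "ce", and UpdateOp backpropagates "ce" to update the model (training mode only). As a quick smoke test (a hedged sketch: this just runs an untrained forward pass on one evaluation batch, without updating anything), you can feed a pipeline batch through the Network:
batch = network.transform(pipeline.get_results(mode="eval"), mode="eval")
print(batch["y_pred"].shape, batch["ce"])  # (512, 10) predictions plus a scalar loss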
Step 3 - Estimator definition and training¶
In this step, we define an Estimator to connect our Network with our Pipeline, and set the Traces that will compute accuracy (Accuracy), save our best model (BestModelSaver), and adjust the learning rate of our optimizer over time (LRScheduler). We will then use Estimator.fit to trigger the training.
from fastestimator.trace.adapt import LRScheduler
from fastestimator.trace.io import BestModelSaver
from fastestimator.trace.metric import Accuracy
def lr_schedule(step):
    # Linear warm-up to the peak learning rate (0.4 * 0.1 = 0.04) over the first
    # 490 steps (5 epochs), then linear decay to 0 at step 2352 (epoch 24)
    if step <= 490:
        lr = step / 490 * 0.4
    else:
        lr = (2352 - step) / 1862 * 0.4
    return lr * 0.1
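Since matplotlib is already imported, we can visualize this schedule (a quick sketch; per the training log below there are 98 steps per epoch, so the warm-up peaks after epoch 5 and the rate decays to 0 by step 2352):
steps = np.arange(1, 2353)
plt.plot(steps, [lr_schedule(s) for s in steps])
plt.xlabel("step")
plt.ylabel("learning rate")
plt.show()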
traces = [
    Accuracy(true_key="y", pred_key="y_pred"),
    BestModelSaver(model=model, save_dir=save_dir, metric="accuracy", save_best_mode="max"),
    LRScheduler(model=model, lr_fn=lr_schedule)
]
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=epochs,
                         traces=traces,
                         train_steps_per_epoch=train_steps_per_epoch,
                         eval_steps_per_epoch=eval_steps_per_epoch)
estimator.fit() # start the training
    ______           __  ______     __  _                 __
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/

FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1;
FastEstimator-Train: step: 1; ce: 2.8280644; model_lr: 8.163265e-05;
FastEstimator-Train: step: 98; epoch: 1; epoch_time: 9.81 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 98; epoch: 1; accuracy: 0.5464; ce: 1.3504643; max_accuracy: 0.5464; since_best_accuracy: 0;
FastEstimator-Train: step: 100; ce: 1.5739734; model_lr: 0.008163265; steps/sec: 10.35;
FastEstimator-Train: step: 196; epoch: 2; epoch_time: 9.54 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 196; epoch: 2; accuracy: 0.6372; ce: 1.1597185; max_accuracy: 0.6372; since_best_accuracy: 0;
FastEstimator-Train: step: 200; ce: 1.4725939; model_lr: 0.01632653; steps/sec: 10.56;
FastEstimator-Train: step: 294; epoch: 3; epoch_time: 9.38 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 294; epoch: 3; accuracy: 0.7268; ce: 0.8682629; max_accuracy: 0.7268; since_best_accuracy: 0;
FastEstimator-Train: step: 300; ce: 1.348976; model_lr: 0.024489796; steps/sec: 10.51;
FastEstimator-Train: step: 392; epoch: 4; epoch_time: 9.4 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 392; epoch: 4; accuracy: 0.7662; ce: 0.7805661; max_accuracy: 0.7662; since_best_accuracy: 0;
FastEstimator-Train: step: 400; ce: 1.3233813; model_lr: 0.03265306; steps/sec: 10.51;
FastEstimator-Train: step: 490; epoch: 5; epoch_time: 9.48 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 490; epoch: 5; accuracy: 0.7768; ce: 0.75958633; max_accuracy: 0.7768; since_best_accuracy: 0;
FastEstimator-Train: step: 500; ce: 1.2187204; model_lr: 0.039785177; steps/sec: 10.62;
FastEstimator-Train: step: 588; epoch: 6; epoch_time: 9.33 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 588; epoch: 6; accuracy: 0.8296; ce: 0.6205404; max_accuracy: 0.8296; since_best_accuracy: 0;
FastEstimator-Train: step: 600; ce: 1.1162705; model_lr: 0.03763695; steps/sec: 10.58;
FastEstimator-Train: step: 686; epoch: 7; epoch_time: 9.42 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 686; epoch: 7; accuracy: 0.8496; ce: 0.59812355; max_accuracy: 0.8496; since_best_accuracy: 0;
FastEstimator-Train: step: 700; ce: 1.1050198; model_lr: 0.03548872; steps/sec: 10.58;
FastEstimator-Train: step: 784; epoch: 8; epoch_time: 9.39 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 784; epoch: 8; accuracy: 0.858; ce: 0.5833092; max_accuracy: 0.858; since_best_accuracy: 0;
FastEstimator-Train: step: 800; ce: 1.1079816; model_lr: 0.033340495; steps/sec: 10.59;
FastEstimator-Train: step: 882; epoch: 9; epoch_time: 9.34 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 882; epoch: 9; accuracy: 0.8608; ce: 0.5679447; max_accuracy: 0.8608; since_best_accuracy: 0;
FastEstimator-Train: step: 900; ce: 1.1163852; model_lr: 0.031192265; steps/sec: 10.33;
FastEstimator-Train: step: 980; epoch: 10; epoch_time: 9.58 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 980; epoch: 10; accuracy: 0.8716; ce: 0.55340326; max_accuracy: 0.8716; since_best_accuracy: 0;
FastEstimator-Train: step: 1000; ce: 1.0618074; model_lr: 0.02904404; steps/sec: 10.63;
FastEstimator-Train: step: 1078; epoch: 11; epoch_time: 9.39 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1078; epoch: 11; accuracy: 0.8834; ce: 0.49868757; max_accuracy: 0.8834; since_best_accuracy: 0;
FastEstimator-Train: step: 1100; ce: 1.0652387; model_lr: 0.026895812; steps/sec: 10.55;
FastEstimator-Train: step: 1176; epoch: 12; epoch_time: 9.4 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1176; epoch: 12; accuracy: 0.8902; ce: 0.50505483; max_accuracy: 0.8902; since_best_accuracy: 0;
FastEstimator-Train: step: 1200; ce: 1.0391171; model_lr: 0.024747584; steps/sec: 10.5;
FastEstimator-Train: step: 1274; epoch: 13; epoch_time: 9.48 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1274; epoch: 13; accuracy: 0.8958; ce: 0.4790287; max_accuracy: 0.8958; since_best_accuracy: 0;
FastEstimator-Train: step: 1300; ce: 1.0158072; model_lr: 0.022599356; steps/sec: 10.57;
FastEstimator-Train: step: 1372; epoch: 14; epoch_time: 9.38 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1372; epoch: 14; accuracy: 0.899; ce: 0.4504097; max_accuracy: 0.899; since_best_accuracy: 0;
FastEstimator-Train: step: 1400; ce: 1.005058; model_lr: 0.020451128; steps/sec: 10.6;
FastEstimator-Train: step: 1470; epoch: 15; epoch_time: 9.35 sec;
FastEstimator-Eval: step: 1470; epoch: 15; accuracy: 0.8934; ce: 0.47465906; max_accuracy: 0.899; since_best_accuracy: 1;
FastEstimator-Train: step: 1500; ce: 0.99793816; model_lr: 0.0183029; steps/sec: 10.55;
FastEstimator-Train: step: 1568; epoch: 16; epoch_time: 9.43 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1568; epoch: 16; accuracy: 0.9104; ce: 0.4402462; max_accuracy: 0.9104; since_best_accuracy: 0;
FastEstimator-Train: step: 1600; ce: 0.9909789; model_lr: 0.016154673; steps/sec: 10.62;
FastEstimator-Train: step: 1666; epoch: 17; epoch_time: 9.34 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1666; epoch: 17; accuracy: 0.9142; ce: 0.42767367; max_accuracy: 0.9142; since_best_accuracy: 0;
FastEstimator-Train: step: 1700; ce: 0.97114897; model_lr: 0.014006444; steps/sec: 10.7;
FastEstimator-Train: step: 1764; epoch: 18; epoch_time: 9.24 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1764; epoch: 18; accuracy: 0.9158; ce: 0.41351128; max_accuracy: 0.9158; since_best_accuracy: 0;
FastEstimator-Train: step: 1800; ce: 0.9536685; model_lr: 0.011858217; steps/sec: 10.56;
FastEstimator-Train: step: 1862; epoch: 19; epoch_time: 9.37 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1862; epoch: 19; accuracy: 0.9194; ce: 0.4258132; max_accuracy: 0.9194; since_best_accuracy: 0;
FastEstimator-Train: step: 1900; ce: 0.94474137; model_lr: 0.00970999; steps/sec: 10.59;
FastEstimator-Train: step: 1960; epoch: 20; epoch_time: 9.47 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 1960; epoch: 20; accuracy: 0.9234; ce: 0.39099675; max_accuracy: 0.9234; since_best_accuracy: 0;
FastEstimator-Train: step: 2000; ce: 0.93317723; model_lr: 0.0075617614; steps/sec: 10.39;
FastEstimator-Train: step: 2058; epoch: 21; epoch_time: 9.53 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 2058; epoch: 21; accuracy: 0.924; ce: 0.3927131; max_accuracy: 0.924; since_best_accuracy: 0;
FastEstimator-Train: step: 2100; ce: 0.9343517; model_lr: 0.0054135337; steps/sec: 10.41;
FastEstimator-Train: step: 2156; epoch: 22; epoch_time: 9.53 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 2156; epoch: 22; accuracy: 0.929; ce: 0.3977168; max_accuracy: 0.929; since_best_accuracy: 0;
FastEstimator-Train: step: 2200; ce: 0.93351036; model_lr: 0.0032653061; steps/sec: 10.51;
FastEstimator-Train: step: 2254; epoch: 23; epoch_time: 9.42 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpq5st5yqv/model_best_accuracy.pt
FastEstimator-Eval: step: 2254; epoch: 23; accuracy: 0.9374; ce: 0.376044; max_accuracy: 0.9374; since_best_accuracy: 0;
FastEstimator-Train: step: 2300; ce: 0.90236723; model_lr: 0.0011170784; steps/sec: 10.48;
FastEstimator-Train: step: 2352; epoch: 24; epoch_time: 9.5 sec;
FastEstimator-Eval: step: 2352; epoch: 24; accuracy: 0.9348; ce: 0.37468904; max_accuracy: 0.9374; since_best_accuracy: 1;
FastEstimator-Finish: step: 2352; model_lr: 0.0; total_time: 371.66 sec;
Model testing¶
Estimator.test will trigger model testing using all of the test data defined in the Pipeline. This will allow us to check our accuracy on previously unseen data.
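Because BestModelSaver wrote the best checkpoint into save_dir, you can optionally restore those weights before testing (a hedged sketch; the file name follows the model_best_accuracy.pt pattern shown in the training log above):
import os
fe.backend.load_model(model, os.path.join(save_dir, "model_best_accuracy.pt"))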
estimator.test()
FastEstimator-Test: step: 2352; epoch: 24; accuracy: 0.9252; ce: 0.3898231;
Inferencing¶
In this step we run image inference directly using the model we just trained. We randomly select 5 images from the test dataset and infer them one by one using Pipeline.transform and Network.transform. Note that the Pipeline no longer behaves the same as it did during training, because we don't want to apply data augmentation during inference. This was already handled when the Pipeline was defined: the augmentation Ops were restricted to the training mode (mode="train"), so they are skipped in "infer" mode.
sample_num = 5
for j in np.random.randint(low=0, high=len(test_data), size=sample_num):
    data = {"x": test_data["x"][j]}
    # run the pipeline
    data = pipeline.transform(data, mode="infer")
    # run the network
    data = network.transform(data, mode="infer")
    predict = data["y_pred"].numpy()
    fig = GridDisplay([BatchDisplay(image=data['x'], title="Pipeline Output"),
                       BatchDisplay(text=[class_names[np.argmax(predict)]], title="Predicted Class")])
    fig.show()