Advanced Tutorial 15: Finetuning Tutorial
Overview
In this tutorial we are going to cover finetuning using FastEstimator. This tutorial is structured as follows:
- Setting Things Up
- TensorFlow Workflow
- PyTorch Workflow
Setting Things Up
First let's get some imports out of the way:
import os
import tempfile
import tensorflow as tf
# Since we will be mixing TF and Torch in the tutorial, we need to stop TF from taking all of the GPU memory.
# Normally you would pick either TF or Torch, so you don't need to worry about this.
physical_devices = tf.config.list_physical_devices('GPU')
for device in physical_devices:
    try:
        tf.config.experimental.set_memory_growth(device, True)
    except:
        pass
import fastestimator as fe
from fastestimator.trace.metric import Accuracy
from fastestimator.op.numpyop.univariate import ChannelTranspose, CoarseDropout, Normalize, Onehot
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.schedule.schedule import EpochScheduler
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.dataset.data import cifair100, cifair10
from fastestimator.architecture.tensorflow import LeNet as lenet_tf
from tensorflow.keras import Sequential, layers
from tensorflow.keras import Model
from fastestimator.architecture.pytorch import LeNet as lenet_torch
import torch.nn as nn
from torch import load, Tensor, cuda, save
import torch.nn.functional as fn
Define Reusable Methods
We define helper functions to build the Pipeline, Network, and Estimator that will be shared by both the TensorFlow and PyTorch workflows below.
def get_pipeline(dataset, num_classes, batch_size, mode='tf', min_height=40, min_width=40):
    train_data, eval_data = dataset.load_data()
    mean_value = (0.4914, 0.4822, 0.4465)
    std_value = (0.2471, 0.2435, 0.2616)
    ops = [Normalize(inputs="x", outputs="x", mean=mean_value, std=std_value),
           PadIfNeeded(min_height=min_height, min_width=min_width, image_in="x", image_out="x", mode="train"),
           RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
           Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
           CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
           Onehot(inputs="y", outputs="y", mode="train", num_classes=num_classes, label_smoothing=0.2)]
    if mode == 'torch':
        ops.append(ChannelTranspose(inputs="x", outputs="x"))
    return fe.Pipeline(train_data=train_data,
                       eval_data=eval_data,
                       batch_size=batch_size,
                       ops=ops)
def get_network(model):
    return fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce")])
def get_estimator(pipeline, network, epochs):
    traces = [Accuracy(true_key="y", pred_key="y_pred")]
    return fe.Estimator(pipeline=pipeline,
                        network=network,
                        epochs=epochs,
                        traces=traces,
                        log_steps=0)
Let's also define some default training parameters:
# training parameters
epochs_pretrain = 10
epochs_finetune = 5
batch_size = 64
base_num_classes = 100
finetune_num_classes = 10
model_dir = tempfile.gettempdir()
TensorFlow Workflow
Train Base Model
Now that the boring stuff is done, let's train our first base model. We use the TensorFlow LeNet to train on ciFAIR-100 with 100 classes, training for 10 epochs and saving the model at the end of the training job.
tf_input_shape = (32, 32, 3)
model_tf_pretrain = fe.build(model_fn=lambda: lenet_tf(input_shape=tf_input_shape, classes=base_num_classes), optimizer_fn="adam")
pipeline_tf_pretrain = get_pipeline(cifair100, base_num_classes, batch_size)
network_tf_pretrain = get_network(model_tf_pretrain)
estimator_tf_pretrain = get_estimator(pipeline_tf_pretrain, network_tf_pretrain, epochs_pretrain)
estimator_tf_pretrain.fit(warmup=False)
fe.backend.save_model(model_tf_pretrain, save_dir=model_dir, model_name="lenet_tf")
2022-04-28 17:29:24.469363: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-04-28 17:29:26.182234: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1525] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38420 MB memory: -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:bd:00.0, compute capability: 8.0
2022-04-28 17:29:28.562053: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 1;
2022-04-28 17:29:35.097002: I tensorflow/stream_executor/cuda/cuda_dnn.cc:368] Loaded cuDNN version 8100
2022-04-28 17:29:37.270702: I tensorflow/stream_executor/cuda/cuda_blas.cc:1786] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
FastEstimator-Train: step: 782; epoch: 1;
FastEstimator-Eval: step: 782; epoch: 1; accuracy: 0.1402; ce: 3.6746857;
FastEstimator-Train: step: 1564; epoch: 2;
FastEstimator-Eval: step: 1564; epoch: 2; accuracy: 0.2204; ce: 3.2619572;
FastEstimator-Train: step: 2346; epoch: 3;
FastEstimator-Eval: step: 2346; epoch: 3; accuracy: 0.2469; ce: 3.1025422;
FastEstimator-Train: step: 3128; epoch: 4;
FastEstimator-Eval: step: 3128; epoch: 4; accuracy: 0.2879; ce: 2.9410963;
FastEstimator-Train: step: 3910; epoch: 5;
FastEstimator-Eval: step: 3910; epoch: 5; accuracy: 0.2944; ce: 2.8627439;
FastEstimator-Train: step: 4692; epoch: 6;
FastEstimator-Eval: step: 4692; epoch: 6; accuracy: 0.3167; ce: 2.7871962;
FastEstimator-Train: step: 5474; epoch: 7;
FastEstimator-Eval: step: 5474; epoch: 7; accuracy: 0.324; ce: 2.7451925;
FastEstimator-Train: step: 6256; epoch: 8;
FastEstimator-Eval: step: 6256; epoch: 8; accuracy: 0.3267; ce: 2.7209747;
FastEstimator-Train: step: 7038; epoch: 9;
FastEstimator-Eval: step: 7038; epoch: 9; accuracy: 0.3453; ce: 2.6295624;
FastEstimator-Train: step: 7820; epoch: 10;
FastEstimator-Eval: step: 7820; epoch: 10; accuracy: 0.3513; ce: 2.601092;
FastEstimator-Finish: step: 7820; model_lr: 0.001; total_time: 108.44 sec;
'/tmp/lenet_tf.h5'
Load a new dataset for finetuning
For finetuning, we use the FastEstimator API to load the ciFAIR-10 dataset. You can use your own dataset by updating the get_pipeline method (a sketch follows below).
pipeline_tf_finetune = get_pipeline(cifair10, finetune_num_classes, batch_size)
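If you want to plug in your own data instead, any object that exposes a load_data() method returning a train dataset and an eval dataset will work with get_pipeline unchanged. Below is a minimal sketch (not part of the original tutorial) using FastEstimator's NumpyDataset with hypothetical in-memory arrays; replace the random placeholders with your real images and labels.
import numpy as np
from fastestimator.dataset.numpy_dataset import NumpyDataset

class MyDataset:
    """Hypothetical stand-in for cifair10/cifair100: exposes the same load_data() interface."""
    @staticmethod
    def load_data():
        # placeholder 32x32 RGB images and integer class labels; swap in your own arrays here
        x_train = np.random.randint(0, 256, (100, 32, 32, 3), dtype=np.uint8)
        y_train = np.random.randint(0, 10, (100,))
        x_eval = np.random.randint(0, 256, (20, 32, 32, 3), dtype=np.uint8)
        y_eval = np.random.randint(0, 10, (20,))
        return NumpyDataset({"x": x_train, "y": y_train}), NumpyDataset({"x": x_eval, "y": y_eval})

# pipeline_custom = get_pipeline(MyDataset, num_classes=10, batch_size=batch_size)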
Extending Base Model for Finetuning
Import Pretrained Model
Now we are ready to extend our base model for the finetuning task.
Let's load the pretrained weights saved in the previous step. The weights file has an .h5 extension: since we passed lenet_tf as the model_name to the save_model function, the weights were saved as lenet_tf.h5.
weights_path = os.path.join(model_dir, "lenet_tf.h5")
pretrained_lenet_tf = fe.build(model_fn=lambda: lenet_tf(input_shape=tf_input_shape, classes=base_num_classes), optimizer_fn="adam", weights_path=weights_path)
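An equivalent alternative (just a sketch, producing the same result) is to build the model first and then restore the weights with fe.backend.load_model, the counterpart of the save_model call used above; the pretrained_lenet_tf_alt name below is only for illustration.
pretrained_lenet_tf_alt = fe.build(model_fn=lambda: lenet_tf(input_shape=tf_input_shape, classes=base_num_classes), optimizer_fn="adam")
fe.backend.load_model(pretrained_lenet_tf_alt, weights_path=weights_path)  # restore the saved lenet_tf.h5 weights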
Extending Base Model
Let's remove the classification head of the pretrained model and build a backbone. We will be using fe.build to build a new fe model.
def get_tf_backbone(pretrained_model):
    model = Model(inputs=pretrained_model.inputs, outputs=pretrained_model.layers[-3].output)
    return model
backbone_tf = fe.build(model_fn=lambda: get_tf_backbone(pretrained_lenet_tf), optimizer_fn="adam")
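As an optional sanity check, the backbone should now end at the Flatten layer, so each image is mapped to a 1024-dimensional feature vector (64 channels of 4x4 features after the convolutions). This is also why the classification head defined next expects input_shape=(1024,).
# quick check of the backbone's output shape
print(backbone_tf.output_shape)  # expected: (None, 1024)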
Next, we will define a classification head that can be used for the finetuning task. This is simply two Dense layers.
def get_class_head(finetune_num_classes):
    return Sequential([layers.Dense(64, activation='relu', input_shape=(1024,)),
                       layers.Dense(finetune_num_classes, activation='softmax')])
cls_head_tf_finetune = fe.build(model_fn=lambda: get_class_head(finetune_num_classes), optimizer_fn="adam")
Combine Base Model and Finetune Model
If you want to save the finetuned model, you can combine the backbone model and the class head model into a single model and provide it to a ModelSaver trace later.
def combined_tf_model(backbone_model, cls_head_finetune):
    backbone_output = backbone_model.layers[-1].output
    x = cls_head_finetune.layers[0](backbone_output)
    x = cls_head_finetune.layers[1](x)
    model = Model(inputs=backbone_model.inputs, outputs=x)
    return model
final_model_tf = fe.build(model_fn=lambda: combined_tf_model(backbone_tf, cls_head_tf_finetune), optimizer_fn="adam")
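If you would rather have the combined model written to disk automatically during training instead of (or in addition to) the manual save_model call at the end, you could attach a ModelSaver trace. Below is a minimal sketch of such an estimator builder; it is not used in the run that follows, and the get_estimator_with_saver name is only for illustration.
from fastestimator.trace.io import ModelSaver

def get_estimator_with_saver(pipeline, network, epochs, model, save_dir):
    # same as get_estimator, plus a ModelSaver that writes `model` to `save_dir` at the end of every epoch
    traces = [Accuracy(true_key="y", pred_key="y_pred"),
              ModelSaver(model=model, save_dir=save_dir, frequency=1)]
    return fe.Estimator(pipeline=pipeline, network=network, epochs=epochs, traces=traces, log_steps=0)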
Start Finetuning
For finetuning, we want to train different parts of the network in the following manner:
- epoch 1-3: freeze the backbone, train the classification head only
- epoch 4-end: train the backbone and the classification head together
Let's use EpochScheduler to define when the backbone and class head weights are updated; UpdateOp is responsible for the weight updates. In the scheduler dictionary each key is the epoch at which its value takes effect, so {1: None, 4: UpdateOp(...)} means the backbone receives no updates during epochs 1-3 and is updated from epoch 4 onward.
network_tf_finetune = fe.Network(ops=[
    ModelOp(model=backbone_tf, inputs="x", outputs="feature"),
    ModelOp(model=cls_head_tf_finetune, inputs="feature", outputs="y_pred"),
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
    EpochScheduler({1: None, 4: UpdateOp(model=backbone_tf, loss_name="ce")}),
    EpochScheduler({1: UpdateOp(model=cls_head_tf_finetune, loss_name="ce")})])
estimator_tf_finetune = get_estimator(pipeline_tf_finetune, network_tf_finetune, epochs_finetune)
Let's train the finetuning model on our new dataset, starting from the pretrained weights.
estimator_tf_finetune.fit(warmup=False)
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 1;
/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:1082: UserWarning: "`categorical_crossentropy` received `from_logits=True`, but the `output` argument was produced by a sigmoid or softmax activation and thus does not represent logits. Was this intended?"
FastEstimator-Train: step: 782; epoch: 1;
/usr/local/lib/python3.8/dist-packages/tensorflow/python/util/dispatch.py:1082: UserWarning: "`sparse_categorical_crossentropy` received `from_logits=True`, but the `output` argument was produced by a sigmoid or softmax activation and thus does not represent logits. Was this intended?"
FastEstimator-Eval: step: 782; epoch: 1; accuracy: 0.6009; ce: 1.1969987;
FastEstimator-Train: step: 1564; epoch: 2;
FastEstimator-Eval: step: 1564; epoch: 2; accuracy: 0.6269; ce: 1.1396555;
FastEstimator-Train: step: 2346; epoch: 3;
FastEstimator-Eval: step: 2346; epoch: 3; accuracy: 0.6305; ce: 1.1114084;
FastEstimator-Train: step: 3128; epoch: 4;
FastEstimator-Eval: step: 3128; epoch: 4; accuracy: 0.666; ce: 1.0342808;
FastEstimator-Train: step: 3910; epoch: 5;
FastEstimator-Eval: step: 3910; epoch: 5; accuracy: 0.674; ce: 1.0168855;
FastEstimator-Finish: step: 3910; model2_lr: 0.001; model3_lr: 0.001; total_time: 50.76 sec;
Finally, let's save our finetuned model.
fe.backend.save_model(final_model_tf, save_dir=model_dir, model_name="final_tf_finetune")
'/tmp/final_tf_finetune.h5'
PyTorch Workflow
Train Base Model
Let's train our first PyTorch base model. We use the PyTorch LeNet to train on ciFAIR-100 with 100 classes, training for 10 epochs and saving the model at the end of the training job.
torch_input_shape = (3, 32, 32)
model_torch_pretrain = fe.build(model_fn=lambda: lenet_torch(input_shape=torch_input_shape, classes=base_num_classes), optimizer_fn="adam")
pipeline_torch_pretrain = get_pipeline(cifair100, base_num_classes, batch_size, 'torch')
network_torch_pretrain = get_network(model_torch_pretrain)
estimator_torch_pretrain = get_estimator(pipeline_torch_pretrain, network_torch_pretrain, epochs_pretrain)
estimator_torch_pretrain.fit()
fe.backend.save_model(model_torch_pretrain, save_dir=model_dir, model_name="lenet_torch")
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 1;
FastEstimator-Train: step: 782; epoch: 1;
FastEstimator-Eval: step: 782; epoch: 1; accuracy: 0.1401; ce: 3.658039;
FastEstimator-Train: step: 1564; epoch: 2;
FastEstimator-Eval: step: 1564; epoch: 2; accuracy: 0.1941; ce: 3.3732774;
FastEstimator-Train: step: 2346; epoch: 3;
FastEstimator-Eval: step: 2346; epoch: 3; accuracy: 0.2451; ce: 3.1511996;
FastEstimator-Train: step: 3128; epoch: 4;
FastEstimator-Eval: step: 3128; epoch: 4; accuracy: 0.2643; ce: 3.025363;
FastEstimator-Train: step: 3910; epoch: 5;
FastEstimator-Eval: step: 3910; epoch: 5; accuracy: 0.2903; ce: 2.8853397;
FastEstimator-Train: step: 4692; epoch: 6;
FastEstimator-Eval: step: 4692; epoch: 6; accuracy: 0.2999; ce: 2.8033638;
FastEstimator-Train: step: 5474; epoch: 7;
FastEstimator-Eval: step: 5474; epoch: 7; accuracy: 0.3158; ce: 2.7701957;
FastEstimator-Train: step: 6256; epoch: 8;
FastEstimator-Eval: step: 6256; epoch: 8; accuracy: 0.3213; ce: 2.7396107;
FastEstimator-Train: step: 7038; epoch: 9;
FastEstimator-Eval: step: 7038; epoch: 9; accuracy: 0.334; ce: 2.7355165;
FastEstimator-Train: step: 7820; epoch: 10;
FastEstimator-Eval: step: 7820; epoch: 10; accuracy: 0.3326; ce: 2.655726;
FastEstimator-Finish: step: 7820; model5_lr: 0.001; total_time: 115.39 sec;
'/tmp/lenet_torch.pt'
Load a new dataset for finetuning
For finetuning, we again use the FastEstimator API to load the ciFAIR-10 dataset. You can use your own dataset by changing the get_pipeline method.
pipeline_torch_finetune = get_pipeline(cifair10, finetune_num_classes, batch_size, 'torch')
Extending Base Model for Finetuning
Import Pretrained Model
Now we are ready to extend our base model for the finetuning task.
Let's load the pretrained weights saved in the previous step. The weights file has a .pt extension: since we passed lenet_torch as the model_name to the save_model function, the weights were saved as lenet_torch.pt. Replace the path if you used a different model_name in save_model.
weights_path = os.path.join(model_dir, 'lenet_torch.pt')
model_torch_pretrained = fe.build(model_fn=lambda: lenet_torch(input_shape=torch_input_shape, classes=base_num_classes), optimizer_fn="adam", weights_path=weights_path)
Extending Base Model
Let's remove the classification head (the fully-connected layers) of the pretrained model and build a new backbone. We will be using fe.build to build a new fe model.
class BackboneTorch(nn.Module):
    def __init__(self, model_torch_pretrained) -> None:
        super().__init__()
        self.pool_kernel = 2
        # keep only the three convolution layers; drop the two fully-connected layers (the classification head)
        if isinstance(model_torch_pretrained, nn.DataParallel):
            self.backbone_layers = nn.Sequential(*(list(model_torch_pretrained.module.children())[:-2]))
        else:
            self.backbone_layers = nn.Sequential(*(list(model_torch_pretrained.children())[:-2]))

    def forward(self, x: Tensor) -> Tensor:
        x = fn.relu(self.backbone_layers[0](x))
        x = fn.max_pool2d(x, self.pool_kernel)
        x = fn.relu(self.backbone_layers[1](x))
        x = fn.max_pool2d(x, self.pool_kernel)
        x = fn.relu(self.backbone_layers[2](x))
        return x
backbone_torch = fe.build(model_fn=lambda: BackboneTorch(model_torch_pretrained), optimizer_fn="adam")
Next, we will define a classification head that can be used for the finetuning task. This is simply two nn.Linear layers.
class ClassifierHead(nn.Module):
    def __init__(self, classes=10):
        super().__init__()
        self.fc1 = nn.Linear(1024, 64)
        self.fc2 = nn.Linear(64, classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # flatten the backbone features to a 1024-dimensional vector per image
        x = fn.relu(self.fc1(x))
        x = fn.softmax(self.fc2(x), dim=-1)
        return x
cls_head_torch_finetune = fe.build(model_fn=lambda: ClassifierHead(classes=finetune_num_classes), optimizer_fn="adam")
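As an optional sanity check (assuming the default single-device build), we can push a dummy batch through the backbone to confirm that the flattened feature size matches fc1's 1024 inputs.
import torch

with torch.no_grad():
    # create a dummy batch on the same device as the backbone parameters
    dummy = torch.zeros(1, 3, 32, 32, device=next(backbone_torch.parameters()).device)
    print(backbone_torch(dummy).shape)  # expected: torch.Size([1, 64, 4, 4]), i.e. 64*4*4 = 1024 features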
Combine Base Model and Finetune Model
As with the TensorFlow workflow, if you want to save the finetuned model, you can combine the backbone model and the class head model into a single module and provide it to a ModelSaver trace later.
class CombinedTorchModel(nn.Module):
    def __init__(self, backbone, cls_head):
        super().__init__()
        self.backbone = backbone
        self.cls_head = cls_head

    def forward(self, x):
        x = self.backbone(x)
        x = self.cls_head(x)
        return x
final_torch_model = fe.build(model_fn=lambda: CombinedTorchModel(backbone_torch, cls_head_torch_finetune), optimizer_fn=None)
Start Finetuning
As before, for finetuning we want to train different parts of the network in the following manner:
- epoch 1-3: freeze the backbone, train the classification head only
- epoch 4-end: train the backbone and the classification head together
Let's use EpochScheduler to define when the backbone and class head weights are updated; UpdateOp is responsible for the weight updates.
network_torch_finetune = fe.Network(ops=[
    ModelOp(model=backbone_torch, inputs="x", outputs="feature"),
    ModelOp(model=cls_head_torch_finetune, inputs="feature", outputs="y_pred"),
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce", from_logits=True),
    EpochScheduler({1: None, 4: UpdateOp(model=backbone_torch, loss_name="ce")}),
    EpochScheduler({1: UpdateOp(model=cls_head_torch_finetune, loss_name="ce")})])
estimator_torch_finetune = get_estimator(pipeline_torch_finetune, network_torch_finetune, epochs_finetune)
Let's train the finetuning model on our new dataset, starting from the pretrained weights.
estimator_torch_finetune.fit()
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 0; num_device: 1;
FastEstimator-Train: step: 782; epoch: 1;
FastEstimator-Eval: step: 782; epoch: 1; accuracy: 0.5877; ce: 1.8731252;
FastEstimator-Train: step: 1564; epoch: 2;
FastEstimator-Eval: step: 1564; epoch: 2; accuracy: 0.6221; ce: 1.838632;
FastEstimator-Train: step: 2346; epoch: 3;
FastEstimator-Eval: step: 2346; epoch: 3; accuracy: 0.6047; ce: 1.8550365;
FastEstimator-Train: step: 3128; epoch: 4;
FastEstimator-Eval: step: 3128; epoch: 4; accuracy: 0.6269; ce: 1.8336332;
FastEstimator-Train: step: 3910; epoch: 5;
FastEstimator-Eval: step: 3910; epoch: 5; accuracy: 0.6301; ce: 1.8286426;
FastEstimator-Finish: step: 3910; model7_lr: 0.001; model8_lr: 0.001; total_time: 60.43 sec;
Finally, let's save our finetuned model.
fe.backend.save_model(final_torch_model, save_dir=model_dir, model_name="final_torch_finetune")
'/tmp/final_torch_finetune.pt'