Multi-Task Learning using Uncertainty Weighted Loss¶
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
Multi-task learning is popular in many deep learning applications. For example, in object detection a network performs both classification and localization for each object, so the final loss is a combination of the classification loss and the regression loss. The most common way of combining two losses is to simply add them together:
$loss_{total} = loss_1 + loss_2$
However, a problem emerges when the two losses are on different numerical scales. The usual workaround is to manually design or experimentally determine the best weights, which is time-consuming and computationally expensive:
$loss_{total} = w_1loss_1 + w_2loss_2$
This paper presents an interesting idea: make the weights $w_1$ and $w_2$ trainable parameters derived from the uncertainty of each task, so that the network dynamically down-weights the task with higher uncertainty and leans on the task it is more confident about.
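Concretely, the paper derives each weight from the task's homoscedastic uncertainty $\sigma_i$. Using the trainable parameter $w_i = \log\sigma_i^2$ (the form implemented by the UncertaintyLossNet later in this notebook), the combined loss becomes:
$loss_{total} = e^{-w_1}loss_1 + e^{-w_2}loss_2 + w_1 + w_2$
The factor $e^{-w_i}$ shrinks a task's contribution as its uncertainty grows, while the additive $w_i$ terms act as a regularizer that keeps the network from trivially minimizing the loss by declaring every task infinitely uncertain.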
import os
import tempfile
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as fn
import numpy as np
from torch.nn.init import kaiming_normal_ as he_normal
from torchvision import models
import fastestimator as fe
from fastestimator.backend import reduce_mean
from fastestimator.op.numpyop import Delete
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, LongestMaxSize, PadIfNeeded, ReadMat, ShiftScaleRotate
from fastestimator.op.numpyop.univariate import ChannelTranspose, Normalize, ReadImage, Reshape
from fastestimator.op.tensorop import TensorOp
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.schedule import cosine_decay
from fastestimator.trace.adapt import LRScheduler
from fastestimator.trace.io import BestModelSaver
from fastestimator.trace.metric import Accuracy, Dice
# Parameters
epochs = 25
batch_size = 8
train_steps_per_epoch = None
eval_steps_per_epoch = None
save_dir = tempfile.mkdtemp()
data_dir = None
Building Components¶
Dataset¶
We will use the CUB200 2010 dataset from Caltech. It contains 6033 bird images from 200 categories, and each image comes with a corresponding segmentation mask. Our task is therefore to both classify and segment the bird in a given image.
We use the FastEstimator API to load the CUB200 dataset and then split it into training, evaluation, and test sets.
from fastestimator.dataset.data import cub200

train_data = cub200.load_data(root_dir=data_dir)  # downloads the dataset on first use
eval_data = train_data.split(0.3)  # move 30% of the training samples into an eval set
test_data = eval_data.split(0.5)   # move half of the eval set into a test set
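As a quick sanity check of the split (the exact counts depend on the download, but with 6033 images the 70/15/15 split above yields roughly 4200/900/900 samples):
print(len(train_data), len(eval_data), len(test_data))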
Step 1: Create Pipeline¶
We read the images with ReadImage, and the masks stored in a MAT file with ReadMat. There is other information stored in the MAT file, so we specify the key seg to retrieve only the mask.
The main preprocessing task is to resize the images and masks to 512 by 512 pixels. We use LongestMaxSize (to preserve the aspect ratio) and PadIfNeeded to reach the target size. We augment the image and mask in the same way, and rescale the image pixel values to lie between -1 and 1, since we are using pre-trained ImageNet weights.
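To see where the -1 to 1 range comes from: Normalize maps each pixel x to (x / max_pixel_value - mean) / std, so with the arguments used in the pipeline below:
# (0   / 127.5 - 1.0) / 1.0 = -1.0   (darkest pixel)
# (255 / 127.5 - 1.0) / 1.0 = +1.0   (brightest pixel)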
pipeline = fe.Pipeline(
batch_size=batch_size,
train_data=train_data,
eval_data=eval_data,
test_data=test_data,
ops=[
ReadImage(inputs="image", outputs="image", parent_path=train_data.parent_path),
Normalize(inputs="image", outputs="image", mean=1.0, std=1.0, max_pixel_value=127.5),
ReadMat(inputs='annotation', outputs="seg", parent_path=train_data.parent_path),
Delete(keys="annotation"),
LongestMaxSize(max_size=512, image_in="image", image_out="image", mask_in="seg", mask_out="seg"),
PadIfNeeded(min_height=512,
min_width=512,
image_in="image",
image_out="image",
mask_in="seg",
mask_out="seg",
border_mode=cv2.BORDER_CONSTANT,
value=0,
mask_value=0),
ShiftScaleRotate(image_in="image",
mask_in="seg",
image_out="image",
mask_out="seg",
mode="train",
shift_limit=0.2,
rotate_limit=15.0,
scale_limit=0.2,
border_mode=cv2.BORDER_CONSTANT,
value=0,
mask_value=0),
Sometimes(HorizontalFlip(image_in="image", mask_in="seg", image_out="image", mask_out="seg", mode="train")),
ChannelTranspose(inputs="image", outputs="image"),
Reshape(shape=(1, 512, 512), inputs="seg", outputs="seg")
])
Let's visualize our Pipeline results¶
from fastestimator.util import ImageDisplay, GridDisplay
result = pipeline.get_results()
GridDisplay([ImageDisplay(image=result["image"][1],
title="Original Image"),
ImageDisplay(image=result["image"][1],
masks=np.squeeze(result["seg"][1].numpy()),
title="Mask Overlay"),
]).show()
Step 2: Create Network¶
In this implementation, the network architecture is not the focus. Therefore, we are going to create something out of the blue :). How about a combination of ResNet50 and UNet that can do both classification and segmentation? We can call it ResUnet50.
class Upsample2D(nn.Module):
"""Upsampling Block"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.upsample = nn.Sequential(
nn.Upsample(mode="bilinear", scale_factor=2, align_corners=True),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.ReLU(inplace=True))
for l in self.upsample:
if isinstance(l, nn.Conv2d):
he_normal(l.weight.data)
def forward(self, x):
return self.upsample(x)
class DecBlock(nn.Module):
"""Decoder Block"""
def __init__(self, upsample_in_ch, conv_in_ch, out_ch):
super().__init__()
self.upsample = Upsample2D(upsample_in_ch, out_ch)
self.conv_layers = nn.Sequential(
nn.Conv2d(conv_in_ch, out_ch, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
nn.ReLU(inplace=True))
for l in self.conv_layers:
if isinstance(l, nn.Conv2d):
he_normal(l.weight.data)
def forward(self, x_up, x_down):
x = self.upsample(x_up)
x = torch.cat([x, x_down], 1)
x = self.conv_layers(x)
return x
class ResUnet50(nn.Module):
"""Network Architecture"""
def __init__(self, num_classes=200):
super().__init__()
base_model = models.resnet50(pretrained=True)
self.enc1 = nn.Sequential(*list(base_model.children())[:3])
self.input_pool = list(base_model.children())[3]
self.enc2 = nn.Sequential(*list(base_model.children())[4])
self.enc3 = nn.Sequential(*list(base_model.children())[5])
self.enc4 = nn.Sequential(*list(base_model.children())[6])
self.enc5 = nn.Sequential(*list(base_model.children())[7])
self.fc = nn.Linear(2048, num_classes)
self.dec6 = DecBlock(2048, 1536, 512)
self.dec7 = DecBlock(512, 768, 256)
self.dec8 = DecBlock(256, 384, 128)
self.dec9 = DecBlock(128, 128, 64)
self.dec10 = Upsample2D(64, 2)
self.mask = nn.Conv2d(2, 1, kernel_size=1)
def forward(self, x):
x_e1 = self.enc1(x)
x_e1_1 = self.input_pool(x_e1)
x_e2 = self.enc2(x_e1_1)
x_e3 = self.enc3(x_e2)
x_e4 = self.enc4(x_e3)
x_e5 = self.enc5(x_e4)
x_label = fn.max_pool2d(x_e5, kernel_size=x_e5.size()[2:])
x_label = x_label.view(x_label.shape[0], -1)
x_label = self.fc(x_label)
x_label = torch.softmax(x_label, dim=-1)
x_d6 = self.dec6(x_e5, x_e4)
x_d7 = self.dec7(x_d6, x_e3)
x_d8 = self.dec8(x_d7, x_e2)
x_d9 = self.dec9(x_d8, x_e1)
x_d10 = self.dec10(x_d9)
x_mask = self.mask(x_d10)
x_mask = torch.sigmoid(x_mask)
return x_label, x_mask
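As a quick shape check of ResUnet50 (a sketch; instantiating the model triggers the pretrained ResNet50 download, and the output shapes below assume the 512x512 inputs produced by our pipeline):
model = ResUnet50()
label_pred, mask_pred = model(torch.rand(1, 3, 512, 512))
print(label_pred.shape, mask_pred.shape)  # torch.Size([1, 200]) torch.Size([1, 1, 512, 512])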
Besides the ResUnet50, we will have a second network that holds the trainable weighting parameters of the weighted loss. We call it our uncertainty model. In the network ops, ResUnet50 produces both a predicted label and a predicted mask. These two predictions are fed to the classification loss and segmentation loss operators respectively, and finally both losses are passed to the uncertainty model to create the final loss.
class UncertaintyLossNet(nn.Module):
"""Creates Uncertainty weighted loss model https://arxiv.org/abs/1705.07115
"""
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.zeros(1))
self.w2 = nn.Parameter(torch.zeros(1))
    def forward(self, x):
        # x = [cls_loss, seg_loss]; w1 and w2 play the role of log(sigma^2) for each task
        loss = torch.exp(-self.w1) * x[0] + self.w1 + torch.exp(-self.w2) * x[1] + self.w2
        return loss
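Since w1 and w2 are initialized to zero, the op starts out as a plain sum of its two inputs; the $e^{-w_i}$ factors only drift away from 1 once training updates the parameters. A quick check with hypothetical loss values:
loss_net = UncertaintyLossNet()
print(loss_net([torch.tensor(2.0), torch.tensor(0.5)]))  # tensor([2.5000], grad_fn=<AddBackward0>)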
We also implement a TensorOp to average the output of UncertaintyLossNet over each batch. Since the two CrossEntropy ops below use average_loss=False, the uncertainty model receives per-sample losses, and ReduceLoss turns the weighted combination into the scalar loss that the UpdateOps expect:
class ReduceLoss(TensorOp):
def forward(self, data, state):
return reduce_mean(data)
resunet50 = fe.build(model_fn=ResUnet50,
model_name="resunet50",
optimizer_fn=lambda x: torch.optim.Adam(x, lr=1e-4))
uncertainty = fe.build(model_fn=UncertaintyLossNet,
model_name="uncertainty",
optimizer_fn=lambda x: torch.optim.Adam(x, lr=1e-5))
network = fe.Network(ops=[
ModelOp(inputs='image',
model=resunet50,
outputs=["label_pred", "mask_pred"]),
CrossEntropy(inputs=["label_pred", "label"],
outputs="cls_loss",
form="sparse",
average_loss=False),
CrossEntropy(inputs=["mask_pred", "seg"],
outputs="seg_loss",
form="binary",
average_loss=False),
ModelOp(inputs=["cls_loss", "seg_loss"],
model=uncertainty,
outputs="total_loss"),
ReduceLoss(inputs="total_loss", outputs="total_loss"),
UpdateOp(model=resunet50, loss_name="total_loss"),
UpdateOp(model=uncertainty, loss_name="total_loss")
])
Step 3: Create Estimator¶
We will use four different traces to control and monitor the training: Dice and Accuracy will measure the segmentation and classification results, BestModelSaver will save the model with the best loss, and LRScheduler will apply a cosine learning rate decay throughout the training loop.
traces = [
    Accuracy(true_key="label", pred_key="label_pred"),
    Dice(true_key="seg", pred_key="mask_pred"),
    BestModelSaver(model=resunet50,
                   save_dir=save_dir,
                   metric="total_loss",
                   save_best_mode="min"),
    LRScheduler(model=resunet50,
                lr_fn=lambda step: cosine_decay(
                    step, cycle_length=13200, init_lr=1e-4))  # 13200 steps = 528 steps/epoch * 25 epochs
]
estimator = fe.Estimator(network=network,
pipeline=pipeline,
traces=traces,
epochs=epochs,
train_steps_per_epoch=train_steps_per_epoch,
eval_steps_per_epoch=eval_steps_per_epoch,
log_steps=500)
Training and Testing¶
The whole training (25 epochs) takes about 1 hour and 20 minutes on a single V100 GPU. We reach a Dice score of ~0.86 and ~83% classification accuracy by the end of the training.
estimator.fit()
    ______           __  ______     __  _                 __
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/

FastEstimator-Start: step: 1; logging_interval: 500; num_device: 1;
FastEstimator-Train: step: 1; resunet50_lr: 1e-04; total_loss: 9.154388;
FastEstimator-Train: step: 500; resunet50_lr: 9.964993e-05; steps/sec: 6.11; total_loss: 4.1662703;
FastEstimator-Train: step: 528; epoch: 1; epoch_time: 91.81 sec;
FastEstimator-BestModelSaver: Saved model to /tmp/tmpxbzrtksw/resunet50_best_total_loss.pt
FastEstimator-Eval: step: 528; epoch: 1; accuracy: 0.15248618784530388; Dice: 0.78948456; min_total_loss: 4.0875998; since_best_total_loss: 0; total_loss: 4.0875998;
FastEstimator-Eval: step: 1056; epoch: 2; accuracy: 0.4276243093922652; Dice: 0.8241268; min_total_loss: 2.297407; since_best_total_loss: 0; total_loss: 2.297407;
FastEstimator-Eval: step: 1584; epoch: 3; accuracy: 0.5580110497237569; Dice: 0.8163089; min_total_loss: 1.7144347; since_best_total_loss: 0; total_loss: 1.7144347;
...
FastEstimator-Eval: step: 12672; epoch: 24; accuracy: 0.8198895027624309; Dice: 0.85663784; min_total_loss: 0.70189565; since_best_total_loss: 0; total_loss: 0.70189565;
FastEstimator-Eval: step: 13200; epoch: 25; accuracy: 0.8209944751381215; Dice: 0.86008275; min_total_loss: 0.70189565; since_best_total_loss: 1; total_loss: 0.70731735;
FastEstimator-Finish: step: 13200; resunet50_lr: 1e-06; total_time: 2554.61 sec; uncertainty_lr: 1e-05;
Let's load the model with the best loss and check our performance on the test set:
fe.backend.load_model(resunet50, os.path.join(save_dir, 'resunet50_best_total_loss.pt'))
estimator.test()
FastEstimator-Test: step: 13200; epoch: 25; accuracy: 0.825414364640884; Dice: 0.8634865; total_loss: 0.6717207;
Inferencing¶
We randomly select an image from the test dataset and use pipeline.transform to process the image. We then generate predictions using network.transform and visualize them.
data = test_data[np.random.randint(low=0, high=len(test_data))]
result = pipeline.transform(data, mode="infer")
img = np.squeeze(result["image"])
img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for display
mask_gt = np.squeeze(result["seg"])
Visualize Ground Truth¶
GridDisplay([ImageDisplay(image=img,
title="Original Image"),
ImageDisplay(image=img,
masks=mask_gt,
title="Mask Overlay"),
]).show()
Visualize Prediction¶
network = fe.Network(ops=[
    ModelOp(inputs='image', model=resunet50, outputs=["label_pred", "mask_pred"])  # forward pass only; no loss ops needed at inference
])
predictions = network.transform(result, mode="infer")
predicted_mask = predictions["mask_pred"].numpy()
pred_mask = np.squeeze(predicted_mask)
pred_mask = np.round(pred_mask).astype(mask_gt.dtype)  # threshold the sigmoid output at 0.5
GridDisplay([ImageDisplay(image=img,
title="Original Image"),
ImageDisplay(image=img,
masks=pred_mask,
title="Mask Overlay"),
]).show()