Stable Diffusion for Chest X-Ray Images¶
[paper][pytorch code][notebook]
Introduction to Stable Diffusion¶
Building on Denoising Diffusion Probabilistic Models (DDPM), which apply diffusion in pixel space, Stable Diffusion applies diffusion in a latent space instead, which is faster and more reliable. Stable Diffusion consists of two parts: a latent embedder and a noise estimator network. The latent embedder is typically a variational autoencoder (VAE), a GAN, or any other autoencoder, and is used to extract latent features from the dataset. The noise estimator network learns to estimate the noise added to the latent features at various time steps. Applying DDPM to the latent features instead of directly to pixel values produces images with higher fidelity and is faster. The diffusion pipeline starts from random noise and iteratively applies denoising (noise estimation) at each time step to generate images at the end. We will implement Stable Diffusion with the PyTorch backend in FastEstimator and use it to generate chest X-ray images. A high-level sketch of the generation loop is shown below.
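To make the two-stage design concrete, here is a minimal sketch of the reverse (generation) process in latent space. The names noise_model, decoder, and scheduler refer to components we will build later in this tutorial, so treat this as an illustration of the idea rather than the exact sampling code.
import torch

def generate_images(noise_model, decoder, scheduler, latent_shape, steps):
    """Illustrative latent-diffusion sampling loop (hypothetical helper)."""
    x_t = torch.randn(latent_shape)  # start from pure Gaussian noise in latent space
    for t in reversed(range(steps)):  # iteratively denoise: t = steps-1, ..., 0
        t_batch = torch.full((latent_shape[0],), t, dtype=torch.long)
        pred_noise = noise_model(x_t, t_batch)  # noise estimator predicts the added noise
        x_t, _ = scheduler.estimate_x_t_prior_from_x_T(x_t, t_batch, pred_noise)  # one denoising step
    return decoder(x_t)  # latent embedder's decoder maps the final latent back to an image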
In this tutorial, we will talk about the following topics:
Define Basic Building Blocks ¶
Let's import the required libraries¶
import math
import tempfile
from tqdm import tqdm
from collections import namedtuple
from typing import Any, Dict
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
import fastestimator as fe
from fastestimator.backend._reduce_mean import reduce_mean
from fastestimator.backend._reduce_sum import reduce_sum
from fastestimator.dataset.data import nih_chestxray
from fastestimator.op.numpyop import Delete
from fastestimator.op.numpyop.multivariate import Resize
from fastestimator.op.numpyop.univariate import ChannelTranspose, Normalize, ReadImage
from fastestimator.op.tensorop import TensorOp, LambdaOp
from fastestimator.op.tensorop.loss import L1_Loss, LossOp
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import BestModelSaver, ModelSaver
from fastestimator.util import BatchDisplay, GridDisplay, ImageDisplay
Define training parameters¶
#training parameters
epochs = 40
# auto encoder trained for epochs//4 and noise network trained for epochs.
batch_size = 8
emb_channels = 16
num_samples = 16
image_size = 256
train_steps_per_epoch = 1000
eval_steps_per_epoch = 100
log_steps = 200
timesteps = 1000
model_dir = tempfile.mkdtemp()
Step 1 Set up data loader and preprocessing Pipeline¶
In this step, we define a common pipeline utility to load the NIH chest X-ray training and validation datasets and prepare FastEstimator's data Pipeline. Let's use the FastEstimator API to load the nih_chestxray dataset. The preprocessing steps include reading the image, resizing it to 256x256, normalizing the pixel values to the range [-1, 1], and transposing the channels. We set up these processing steps using Ops, while also defining the data source and batch size for the Pipeline.
def get_pipeline(image_size=256, mean=0.5, std=0.5, split_ratio=0.05):
train_dataset = nih_chestxray.load_data()
eval_dataset = train_dataset.split(split_ratio)
pipeline = fe.Pipeline(
train_data=train_dataset,
eval_data=eval_dataset,
batch_size=batch_size,
ops=[
ReadImage(inputs="x", outputs="image", color_flag='color'),
Resize(image_in="image", width=image_size, height=image_size),
Normalize(inputs="image", outputs="image", mean=mean, std=std),
ChannelTranspose(inputs="image", outputs="image"),
Delete(keys="x")
])
return pipeline
Step 2 Validate Pipeline¶
In order to make sure the pipeline works as expected, we need to visualize its output. Pipeline.get_results will return a batch of data for this purpose:
pipeline = get_pipeline(image_size=image_size, mean=0.5, std=0.5, split_ratio=0.05)
data = pipeline.get_results()
data_out = data["image"]*0.5 + 0.5
print("The pipeline output data size: {}".format(data_out.numpy().shape))
The pipeline output data size: (8, 3, 256, 256)
ImageDisplay(image=data_out[0], title="Sample Image").show()
class DiagonalGaussianDistribution:
def __init__(self, parameters):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape, generator=None, device=self.parameters.device)
return x
def kl(self):
return 0.5 * torch.mean(torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, dim=[1, 2, 3])
def sample_loss(self):
return self.sample(), torch.mean(self.kl())
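The DiagonalGaussianDistribution above interprets the encoder output as a mean and a log-variance stacked along the channel dimension, draws a latent with the reparameterization trick in sample(), and measures the divergence from a standard normal prior in kl(). A quick, hypothetical sanity check with toy values (not part of the training pipeline):
import torch

params = torch.zeros(1, 8, 8, 8)  # toy "moments": 4 mean channels followed by 4 log-variance channels
dist = DiagonalGaussianDistribution(params)
z = dist.sample()  # z = mean + std * eps, with eps ~ N(0, I) (reparameterization trick)
print(z.shape)  # torch.Size([1, 4, 8, 8])
print(float(dist.kl().mean()))  # 0.0, since mean=0 and logvar=0 match the standard normal prior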
class Upsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, 3, stride=1, padding=1)
def forward(self, x, emb=None):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x, emb=None):
x = torch.nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
x = self.conv(x)
return x
def nonlinearity(x):
return x * torch.sigmoid(x) # swish
def get_normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class ResnetBlock(nn.Module):
def __init__(self, in_channels, out_channels=None, temb_channels=512):
super().__init__()
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.norm1 = get_normalize(in_channels)
self.conv1 = torch.nn.Conv2d(self.in_channels, self.out_channels, 3, 1, 1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, self.out_channels)
self.norm2 = get_normalize(out_channels)
self.conv2 = torch.nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1)
if self.in_channels != self.out_channels:
self.conv_shortcut = torch.nn.Conv2d(self.in_channels, self.out_channels, 3, 1, 1)
def forward(self, x, temb=None):
h = x
h = self.conv1(nonlinearity(self.norm1(h)))
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.conv2(nonlinearity(self.norm2(h)))
if self.in_channels != self.out_channels:
x = self.conv_shortcut(x)
return x + h
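These blocks will be reused by both the latent embedder and the noise network, so it is worth noting their shape behaviour: ResnetBlock preserves the spatial size (changing only the channel count), Downsample halves it, and Upsample doubles it. A quick, hypothetical check:
import torch

x = torch.randn(1, 64, 32, 32)  # toy feature map
print(ResnetBlock(64, 128, temb_channels=0)(x).shape)  # torch.Size([1, 128, 32, 32]) - same spatial size
print(Downsample(64)(x).shape)  # torch.Size([1, 64, 16, 16]) - spatial size halved
print(Upsample(64)(x).shape)  # torch.Size([1, 64, 64, 64]) - spatial size doubled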
Training Latent Embedder ¶
Step 1 Define Data Pipeline¶
Let's load the data pipeline for training our latent embedder.
pipeline = get_pipeline(image_size=image_size, mean=0.5, std=0.5, split_ratio=0.05)
Step 2 Latent Embedder Model Construction¶
The latent embedder of a diffusion pipeline can be any autoencoder model, such as a VAE, VQGAN, or VQVAE. For simplicity, let's use a VAE model as the latent embedder. Both the encoder and decoder are implemented in PyTorch and instantiated by calling fe.build (which also associates each model with a specific optimizer).
class Encoder(nn.Module):
def __init__(self, in_channels, emb_channels=4, hid_chs=(64, 128, 256, 512)):
super().__init__()
self.in_channels = in_channels
self.emb_channels = emb_channels
self.temb_ch = 0
self.num_resolution = len(hid_chs)
self.inc = nn.Conv2d(self.in_channels, hid_chs[0], kernel_size=3, stride=1, padding=1)
self.down = nn.ModuleList()
block_in = hid_chs[0]
for i in range(self.num_resolution):
down = nn.Module()
block_out = hid_chs[i]
down.block = nn.ModuleList(
[ResnetBlock(block_in, block_out, self.temb_ch), ResnetBlock(block_out, block_out, self.temb_ch)])
block_in = block_out
if i != self.num_resolution - 1:
down.downsample = Downsample(block_out)
self.down.append(down)
self.mid = nn.ModuleList([
ResnetBlock(hid_chs[-1], hid_chs[-1], self.temb_ch),
ResnetBlock(hid_chs[-1], hid_chs[-1], self.temb_ch),
])
# end
self.norm_out = get_normalize(hid_chs[-1])
self.conv_out = torch.nn.Conv2d(hid_chs[-1], self.emb_channels, kernel_size=3, stride=1, padding=1)
self.quant_conv = torch.nn.Conv2d(self.emb_channels, self.emb_channels, 1)
def forward(self, x):
temb = None
# downsampling
h = self.inc(x)
for i_level in range(self.num_resolution):
for i in range(2):
h = self.down[i_level].block[i](h, temb)
if i_level != self.num_resolution - 1:
h = self.down[i_level].downsample(h)
# middle
h = self.mid[0](h, temb)
h = self.mid[1](h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior.sample_loss()
class Decoder(nn.Module):
def __init__(self, emb_channels, out_channels=3, hid_chs=(64, 128, 256, 512)):
super().__init__()
self.post_quant_conv = torch.nn.Conv2d(emb_channels, emb_channels, 1)
self.emb_channels = emb_channels
self.out_channels = out_channels
self.temb_ch = 0
self.num_resolutions = len(hid_chs)
self.inc = nn.Conv2d(self.emb_channels, hid_chs[-1], kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.ModuleList(
[ResnetBlock(hid_chs[-1], hid_chs[-1], self.temb_ch), ResnetBlock(hid_chs[-1], hid_chs[-1], self.temb_ch)])
self.up = nn.ModuleList()
block_in = hid_chs[-1]
for i_level in reversed(range(self.num_resolutions)):
up = nn.Module()
block_out = hid_chs[i_level]
up.block = nn.ModuleList([
ResnetBlock(block_in, block_out, self.temb_ch),
ResnetBlock(block_out, block_out, self.temb_ch),
ResnetBlock(block_out, block_out, self.temb_ch)
])
block_in = block_out
if i_level != 0:
up.upsample = Upsample(block_in)
self.up.append(up)
# end
self.norm_out = get_normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in, self.out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
temb = None
x = self.post_quant_conv(x)
# downsampling
h = self.inc(x)
# middle
h = self.mid[0](h, temb)
h = self.mid[1](h, temb)
for i_level in range(self.num_resolutions):
for i_block in range(3):
h = self.up[i_level].block[i_block](h, temb)
if i_level != self.num_resolutions - 1:
h = self.up[i_level].upsample(h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
encoder_model = fe.build(model_fn=lambda:Encoder(in_channels=3, emb_channels=2*emb_channels), optimizer_fn=lambda x: torch.optim.Adam(x, lr=1e-4), model_name="encoder")
decoder_model = fe.build(model_fn=lambda:Decoder(emb_channels=emb_channels, out_channels=3), optimizer_fn=lambda x: torch.optim.Adam(x, lr=1e-4), model_name="decoder")
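As an optional check (using a freshly constructed encoder on CPU so as not to disturb the models built above), we can verify the size of the latent space: the encoder downsamples three times, so a 3x256x256 image is compressed to an emb_channels x 32 x 32 latent.
import torch

with torch.no_grad():
    toy_encoder = Encoder(in_channels=3, emb_channels=2 * emb_channels)  # fresh CPU copy for inspection
    latent, kl = toy_encoder(torch.randn(1, 3, image_size, image_size))  # returns (latent sample, KL loss)
print(latent.shape)  # torch.Size([1, 16, 32, 32]) -> 16 latent channels on a 32x32 grid
print(float(kl))  # scalar KL term that regularizes the latent distribution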
Step 3 Define custom LPIPS loss and WeightedLoss Ops¶
LPIPS loss¶
Recently, as introduced by the LPIPS paper, the features of a VGG network trained on ImageNet classification have proven remarkably useful for image synthesis. LPIPS essentially computes the similarity between the activations of two image patches for some pre-defined network. This measure has been shown to match human perception well: a low LPIPS score means that the image patches are perceptually similar. Let's define an LPIPS loss Op which uses the LPIPS score.
class LPIPS_Loss(LossOp):
def __init__(self, inputs, outputs, mode=None):
super().__init__(inputs=inputs, outputs=outputs, mode=mode)
self.net = VGG16(requires_grad=False)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.chns = (64,128,256,512,512)
self.L = len(self.chns)
self.net.eval()
self.shift = torch.Tensor([-.030, -.088, -.188])[None,:, None, None]
self.scale = torch.Tensor([.458, .448, .450])[None,:, None, None]
self.net = self.net.to(device)
self.shift = self.shift.to(device)
self.scale = self.scale.to(device)
def normalize_tensor(self, in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
return in_feat/(norm_factor+eps)
def spatial_average(self, in_tens, keepdim=True):
return in_tens.mean([2,3], keepdim=keepdim)
def scaling_layer(self, in_tens):
return (in_tens - self.shift) / self.scale
def forward(self, data, state: Dict[str, Any]):
in0, in1 = data
in0_input, in1_input = self.scaling_layer(in0), self.scaling_layer(in1)
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
feats0, feats1, diffs = {}, {}, {}
for kk in range(self.L):
feats0[kk], feats1[kk] = self.normalize_tensor(outs0[kk]), self.normalize_tensor(outs1[kk])
diffs[kk] = (feats0[kk]-feats1[kk])**2
res = [self.spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
val = 0
for l in range(self.L):
val += res[l]
return reduce_mean(val)
class VGG16(torch.nn.Module):
def __init__(self, requires_grad=False):
super().__init__()
vgg_pretrained_features = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).features
self.slice1 = torch.nn.Sequential()
self.slice2 = torch.nn.Sequential()
self.slice3 = torch.nn.Sequential()
self.slice4 = torch.nn.Sequential()
self.slice5 = torch.nn.Sequential()
self.N_slices = 5
for x in range(4):
self.slice1.add_module(str(x), vgg_pretrained_features[x])
for x in range(4, 9):
self.slice2.add_module(str(x), vgg_pretrained_features[x])
for x in range(9, 16):
self.slice3.add_module(str(x), vgg_pretrained_features[x])
for x in range(16, 23):
self.slice4.add_module(str(x), vgg_pretrained_features[x])
for x in range(23, 30):
self.slice5.add_module(str(x), vgg_pretrained_features[x])
if not requires_grad:
for param in self.parameters():
param.requires_grad = False
def forward(self, X):
h = self.slice1(X)
h_relu1_2 = h
h = self.slice2(h)
h_relu2_2 = h
h = self.slice3(h)
h_relu3_3 = h
h = self.slice4(h)
h_relu4_3 = h
h = self.slice5(h)
h_relu5_3 = h
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
return out
WeightedLoss¶
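This Op combines the three losses defined above into a single training objective for the latent embedder. With the default weights from its constructor, the objective is

$$\mathcal{L} = \lambda_{L1}\,\mathcal{L}_{L1} + \lambda_{p}\,\mathcal{L}_{LPIPS} + \lambda_{emb}\,\mathcal{L}_{KL}, \qquad \lambda_{L1} = \lambda_{p} = 1,\ \lambda_{emb} = 10^{-6},$$

where the small embedding weight only lightly regularizes the latent distribution.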
class WeightedLoss(TensorOp):
def __init__(self, inputs, outputs, p_loss_weight=1.0, l1_loss_weights=1, embedding_loss_weight=1e-6):
super().__init__(inputs=inputs, outputs=outputs)
self.embedding_loss_weight = embedding_loss_weight
self.p_loss_weight = p_loss_weight
self.l1_loss_weights = l1_loss_weights
def forward(self, data, state: Dict[str, Any]):
l1_loss, p_loss, emb_loss = data
loss = l1_loss * self.l1_loss_weights + p_loss * self.p_loss_weight + emb_loss * self.embedding_loss_weight
return loss
Step 4 Latent Embedder Network definition¶
We are going to connect the encoder, the decoder, and the Ops together into a Network:
network = fe.Network(ops=[
ModelOp(model=encoder_model, inputs="image", outputs=['sample', 'emb_loss']),
LambdaOp(fn=lambda x: reduce_sum(x), inputs="emb_loss", outputs="emb_loss"),
ModelOp(model=decoder_model, inputs="sample", outputs='pred'),
L1_Loss(inputs=['pred', 'image'], outputs='l1_loss'),
LPIPS_Loss(inputs=['pred', 'image'], outputs='p_loss'),
WeightedLoss(inputs=['l1_loss', 'p_loss', 'emb_loss'], outputs='loss'),
UpdateOp(model=decoder_model, loss_name="loss"),
UpdateOp(model=encoder_model, loss_name="loss")
])
Step 5 - Training Latent Embedder¶
In this step, we define the Estimator to compile the Network and Pipeline, and indicate in traces that we want to save the best models. We can then use estimator.fit() to start the training process. It would take roughly 1 hour on an A100 machine.
traces = [BestModelSaver(model=encoder_model, save_dir=model_dir, metric="p_loss"),
BestModelSaver(model=decoder_model, save_dir=model_dir, metric="p_loss")]
estimator = fe.Estimator(pipeline=pipeline,
network=network,
monitor_names=["l1_loss", "p_loss", "emb_loss"],
epochs=max(epochs//4, 1),
traces=traces,
train_steps_per_epoch=train_steps_per_epoch,
eval_steps_per_epoch=eval_steps_per_epoch,
log_steps=log_steps)
estimator.fit() # start the training process
______ __ ______ __ _ __ / ____/___ ______/ /_/ ____/____/ /_(_)___ ___ ____ _/ /_____ _____ / /_ / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/ / __/ / /_/ (__ ) /_/ /___(__ ) /_/ / / / / / / /_/ / /_/ /_/ / / /_/ \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/ FastEstimator-Start: step: 1; logging_interval: 200; num_device: 1; FastEstimator-Train: step: 1; emb_loss: 0.037383616; l1_loss: 0.47173545; loss: 8.903404; p_loss: 8.431669; FastEstimator-Train: step: 200; emb_loss: 4.435126; l1_loss: 0.085861966; loss: 3.4965987; p_loss: 3.4107323; steps/sec: 3.41; FastEstimator-Train: step: 400; emb_loss: 6.0855517; l1_loss: 0.052325994; loss: 2.6048863; p_loss: 2.5525541; steps/sec: 3.39; FastEstimator-Train: step: 600; emb_loss: 7.2795267; l1_loss: 0.04208569; loss: 2.170533; p_loss: 2.12844; steps/sec: 3.39; FastEstimator-Train: step: 800; emb_loss: 7.480191; l1_loss: 0.045939397; loss: 2.2675288; p_loss: 2.221582; steps/sec: 3.39; FastEstimator-Train: step: 1000; emb_loss: 7.895735; l1_loss: 0.039577734; loss: 1.9695344; p_loss: 1.9299488; steps/sec: 3.39; FastEstimator-Train: step: 1000; epoch: 1; epoch_time(sec): 299.89; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.27; Eval Progress: 66/100; steps/sec: 12.54; Eval Progress: 100/100; steps/sec: 12.59; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 1000; epoch: 1; emb_loss: 7.9306865; l1_loss: 0.037546095; loss: 2.0234585; min_p_loss: 1.9859043; p_loss: 1.9859043; since_best_p_loss: 0; FastEstimator-Train: step: 1200; emb_loss: 8.49646; l1_loss: 0.03104583; loss: 1.7176442; p_loss: 1.68659; steps/sec: 3.03; FastEstimator-Train: step: 1400; emb_loss: 8.600871; l1_loss: 0.042625308; loss: 1.7329235; p_loss: 1.6902896; steps/sec: 3.31; FastEstimator-Train: step: 1600; emb_loss: 7.861264; l1_loss: 0.032471918; loss: 1.69544; p_loss: 1.6629603; steps/sec: 3.31; FastEstimator-Train: step: 1800; emb_loss: 8.58733; l1_loss: 0.03612524; loss: 1.4503835; p_loss: 1.4142498; steps/sec: 3.31; FastEstimator-Train: step: 2000; emb_loss: 9.115441; l1_loss: 0.033992134; loss: 1.36872; p_loss: 1.3347188; steps/sec: 3.31; FastEstimator-Train: step: 2000; epoch: 2; epoch_time(sec): 307.65; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.49; Eval Progress: 66/100; steps/sec: 12.53; Eval Progress: 100/100; steps/sec: 12.58; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 2000; epoch: 2; emb_loss: 9.05134; l1_loss: 0.0322605; loss: 1.4986274; min_p_loss: 1.4663578; p_loss: 1.4663578; since_best_p_loss: 0; FastEstimator-Train: step: 2200; emb_loss: 9.251516; l1_loss: 0.031391717; loss: 1.3930725; p_loss: 1.3616714; steps/sec: 3.02; FastEstimator-Train: step: 2400; emb_loss: 9.097033; l1_loss: 0.039184596; loss: 1.2977109; p_loss: 1.2585173; steps/sec: 3.31; FastEstimator-Train: step: 2600; emb_loss: 8.766001; l1_loss: 0.039495192; loss: 1.3219857; p_loss: 1.2824817; steps/sec: 3.31; FastEstimator-Train: step: 2800; emb_loss: 9.32091; l1_loss: 0.028865907; loss: 1.1922385; p_loss: 1.1633632; steps/sec: 3.31; FastEstimator-Train: step: 3000; emb_loss: 9.520845; l1_loss: 0.025876746; loss: 1.0360948; p_loss: 1.0102085; steps/sec: 3.31; FastEstimator-Train: step: 3000; epoch: 3; epoch_time(sec): 308.04; Eval Progress: 
1/100; Eval Progress: 33/100; steps/sec: 12.49; Eval Progress: 66/100; steps/sec: 12.7; Eval Progress: 100/100; steps/sec: 12.56; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 3000; epoch: 3; emb_loss: 9.29928; l1_loss: 0.026937537; loss: 1.1334379; min_p_loss: 1.1064912; p_loss: 1.1064912; since_best_p_loss: 0; FastEstimator-Train: step: 3200; emb_loss: 9.658361; l1_loss: 0.029041544; loss: 1.0870456; p_loss: 1.0579944; steps/sec: 3.01; FastEstimator-Train: step: 3400; emb_loss: 9.563161; l1_loss: 0.02888388; loss: 1.0563712; p_loss: 1.0274777; steps/sec: 3.31; FastEstimator-Train: step: 3600; emb_loss: 9.187937; l1_loss: 0.039789017; loss: 1.1583413; p_loss: 1.1185431; steps/sec: 3.31; FastEstimator-Train: step: 3800; emb_loss: 9.724596; l1_loss: 0.027038708; loss: 1.0858465; p_loss: 1.0587981; steps/sec: 3.31; FastEstimator-Train: step: 4000; emb_loss: 9.426285; l1_loss: 0.02786692; loss: 1.0704956; p_loss: 1.0426192; steps/sec: 3.31; FastEstimator-Train: step: 4000; epoch: 4; epoch_time(sec): 308.21; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.52; Eval Progress: 66/100; steps/sec: 12.64; Eval Progress: 100/100; steps/sec: 12.62; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 4000; epoch: 4; emb_loss: 9.803524; l1_loss: 0.025227038; loss: 1.0289197; min_p_loss: 1.0036829; p_loss: 1.0036829; since_best_p_loss: 0; FastEstimator-Train: step: 4200; emb_loss: 9.862602; l1_loss: 0.029327165; loss: 1.0476326; p_loss: 1.0182955; steps/sec: 3.0; FastEstimator-Train: step: 4400; emb_loss: 10.279909; l1_loss: 0.027822902; loss: 0.9707576; p_loss: 0.94292444; steps/sec: 3.31; FastEstimator-Train: step: 4600; emb_loss: 10.312988; l1_loss: 0.024221867; loss: 0.91694087; p_loss: 0.89270866; steps/sec: 3.31; FastEstimator-Train: step: 4800; emb_loss: 9.688666; l1_loss: 0.033004433; loss: 0.9814685; p_loss: 0.9484544; steps/sec: 3.3; FastEstimator-Train: step: 5000; emb_loss: 9.982787; l1_loss: 0.022129571; loss: 0.8866111; p_loss: 0.86447155; steps/sec: 3.31; FastEstimator-Train: step: 5000; epoch: 5; epoch_time(sec): 308.54; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.52; Eval Progress: 66/100; steps/sec: 12.69; Eval Progress: 100/100; steps/sec: 12.61; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 5000; epoch: 5; emb_loss: 10.134614; l1_loss: 0.02360063; loss: 0.930335; min_p_loss: 0.90672415; p_loss: 0.90672415; since_best_p_loss: 0; FastEstimator-Train: step: 5200; emb_loss: 10.170776; l1_loss: 0.02262507; loss: 0.9916533; p_loss: 0.96901804; steps/sec: 3.01; FastEstimator-Train: step: 5400; emb_loss: 10.602549; l1_loss: 0.028308954; loss: 0.9122198; p_loss: 0.8839003; steps/sec: 3.31; FastEstimator-Train: step: 5600; emb_loss: 9.950895; l1_loss: 0.02818404; loss: 0.88783056; p_loss: 0.85963655; steps/sec: 3.31; FastEstimator-Train: step: 5800; emb_loss: 10.108469; l1_loss: 0.024020797; loss: 0.8527174; p_loss: 0.8286865; steps/sec: 3.31; FastEstimator-Train: step: 6000; emb_loss: 10.923492; l1_loss: 0.020157157; loss: 0.7940448; p_loss: 0.7738767; steps/sec: 3.31; FastEstimator-Train: step: 6000; epoch: 6; epoch_time(sec): 
308.18; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.61; Eval Progress: 66/100; steps/sec: 12.7; Eval Progress: 100/100; steps/sec: 12.62; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 6000; epoch: 6; emb_loss: 10.46007; l1_loss: 0.023044894; loss: 0.8485108; min_p_loss: 0.8254554; p_loss: 0.8254554; since_best_p_loss: 0; FastEstimator-Train: step: 6200; emb_loss: 10.328095; l1_loss: 0.02312286; loss: 0.7586726; p_loss: 0.73553944; steps/sec: 3.02; FastEstimator-Train: step: 6400; emb_loss: 10.669849; l1_loss: 0.026237838; loss: 0.789012; p_loss: 0.7627635; steps/sec: 3.31; FastEstimator-Train: step: 6600; emb_loss: 10.574244; l1_loss: 0.021361798; loss: 0.76153404; p_loss: 0.74016166; steps/sec: 3.31; FastEstimator-Train: step: 6800; emb_loss: 10.43454; l1_loss: 0.02774733; loss: 0.7345116; p_loss: 0.70675385; steps/sec: 3.31; FastEstimator-Train: step: 7000; emb_loss: 9.856495; l1_loss: 0.024537547; loss: 0.8580492; p_loss: 0.8335018; steps/sec: 3.31; FastEstimator-Train: step: 7000; epoch: 7; epoch_time(sec): 308.24; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.45; Eval Progress: 66/100; steps/sec: 12.61; Eval Progress: 100/100; steps/sec: 12.62; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 7000; epoch: 7; emb_loss: 10.30711; l1_loss: 0.022338117; loss: 0.79276145; min_p_loss: 0.77041316; p_loss: 0.77041316; since_best_p_loss: 0; FastEstimator-Train: step: 7200; emb_loss: 10.826598; l1_loss: 0.022604145; loss: 0.7352582; p_loss: 0.7126432; steps/sec: 2.99; FastEstimator-Train: step: 7400; emb_loss: 10.573719; l1_loss: 0.029680258; loss: 0.7455427; p_loss: 0.7158519; steps/sec: 3.3; FastEstimator-Train: step: 7600; emb_loss: 10.697506; l1_loss: 0.021316916; loss: 0.7430857; p_loss: 0.7217581; steps/sec: 3.3; FastEstimator-Train: step: 7800; emb_loss: 10.563108; l1_loss: 0.040196143; loss: 0.8338117; p_loss: 0.793605; steps/sec: 3.3; FastEstimator-Train: step: 8000; emb_loss: 10.853371; l1_loss: 0.021914681; loss: 0.72414404; p_loss: 0.70221853; steps/sec: 3.3; FastEstimator-Train: step: 8000; epoch: 8; epoch_time(sec): 309.04; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.52; Eval Progress: 66/100; steps/sec: 12.55; Eval Progress: 100/100; steps/sec: 12.53; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 8000; epoch: 8; emb_loss: 10.474078; l1_loss: 0.021606512; loss: 0.7366684; min_p_loss: 0.7150514; p_loss: 0.7150514; since_best_p_loss: 0; FastEstimator-Train: step: 8200; emb_loss: 10.831636; l1_loss: 0.024379624; loss: 0.7284498; p_loss: 0.70405936; steps/sec: 3.01; FastEstimator-Train: step: 8400; emb_loss: 10.666431; l1_loss: 0.022530034; loss: 0.7008182; p_loss: 0.6782775; steps/sec: 3.31; FastEstimator-Train: step: 8600; emb_loss: 11.519823; l1_loss: 0.018397681; loss: 0.6034164; p_loss: 0.5850072; steps/sec: 3.31; FastEstimator-Train: step: 8800; emb_loss: 10.949921; l1_loss: 0.018667629; loss: 0.6600202; p_loss: 0.64134157; steps/sec: 3.31; FastEstimator-Train: step: 9000; emb_loss: 10.836271; l1_loss: 0.019294726; loss: 0.65506434; p_loss: 0.63575876; steps/sec: 3.31; FastEstimator-Train: step: 
9000; epoch: 9; epoch_time(sec): 308.26; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.52; Eval Progress: 66/100; steps/sec: 12.65; Eval Progress: 100/100; steps/sec: 12.61; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 9000; epoch: 9; emb_loss: 10.5072975; l1_loss: 0.022909844; loss: 0.72326744; min_p_loss: 0.7003471; p_loss: 0.7003471; since_best_p_loss: 0; FastEstimator-Train: step: 9200; emb_loss: 10.656288; l1_loss: 0.018141937; loss: 0.7049835; p_loss: 0.6868309; steps/sec: 3.01; FastEstimator-Train: step: 9400; emb_loss: 10.148081; l1_loss: 0.02395904; loss: 0.7796004; p_loss: 0.7556312; steps/sec: 3.31; FastEstimator-Train: step: 9600; emb_loss: 11.042529; l1_loss: 0.023092013; loss: 0.6431704; p_loss: 0.62006736; steps/sec: 3.31; FastEstimator-Train: step: 9800; emb_loss: 10.607915; l1_loss: 0.020616498; loss: 0.6487969; p_loss: 0.62816983; steps/sec: 3.3; FastEstimator-Train: step: 10000; emb_loss: 10.100873; l1_loss: 0.02332376; loss: 0.6772168; p_loss: 0.653883; steps/sec: 3.31; FastEstimator-Train: step: 10000; epoch: 10; epoch_time(sec): 308.52; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 12.47; Eval Progress: 66/100; steps/sec: 12.64; Eval Progress: 100/100; steps/sec: 12.52; FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/encoder_best_p_loss.pt FastEstimator-BestModelSaver: Saved model to /tmp/tmpvd1u5fp6/decoder_best_p_loss.pt FastEstimator-Eval: step: 10000; epoch: 10; emb_loss: 10.388235; l1_loss: 0.021197056; loss: 0.68105394; min_p_loss: 0.6598464; p_loss: 0.6598464; since_best_p_loss: 0; FastEstimator-Finish: step: 10000; decoder_lr: 0.0001; encoder_lr: 0.0001; total_time(sec): 3216.53;
Step 6 Running inference on Latent Embedder¶
Great, we are halfway through. Now that the latent embedder is trained, let's run the encoder and decoder to view some of the reconstructed images:
network = fe.Network(ops=[ModelOp(model=encoder_model, inputs="image", outputs=['sample', 'loss'], mode='infer', trainable=False, gradients=False),
ModelOp(model=decoder_model, inputs="sample", outputs='pred', mode='infer', trainable=False, gradients=False)])
infer_data = network.transform(data, mode="infer")
# lets denormalize the input images and prediction output
pred_data = (infer_data['pred'] * 0.5 + 0.5).clip(0, 1)
input_data = (infer_data['image'] * 0.5 + 0.5).clip(0, 1)
GridDisplay(
[BatchDisplay(image=input_data[:1], title="VAE Input"), BatchDisplay(image=pred_data[:1],
title="VAE output")]).show()
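As an additional, optional check (not part of the original training or inference code), we can quantify the VAE reconstruction quality on this batch with the mean absolute error in the denormalized [0, 1] range:
recon_error = float(abs(pred_data - input_data).mean())
print("Mean absolute reconstruction error: {:.4f}".format(recon_error))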
Train Noise Network ¶
Step 1 Define Data Pipeline¶
Let's load the data pipeline for training our noise network.
pipeline = get_pipeline(image_size=image_size, mean=0.5, std=0.5, split_ratio=0.05)
Step 2 Noise Network Model Construction¶
Let's build the noise network, a UNet-like model with additional time embeddings, to estimate the noise in the latent space at any given time step.
class SequentialEmb(nn.Sequential):
def forward(self, input_data, emb):
for layer in self:
if isinstance(layer, ResnetBlock):
input_data = layer(input_data, emb)
else:
input_data = layer(input_data)
return input_data
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
class NoiseNetwork(nn.Module):
def __init__(self, in_ch=1, out_ch=1, hid_chs=(256, 256, 512, 1024), emb_dim=4):
super().__init__()
self.depth = len(hid_chs)
self.in_blocks = nn.ModuleList()
# ----------- In-Convolution ------------
self.in_blocks.append(nn.Conv2d(in_ch, hid_chs[0], kernel_size=3, stride=1, padding=1))
# -------------- Encoder ----------------
for i in range(1, self.depth):
self.in_blocks.append(SequentialEmb(ResnetBlock(hid_chs[i-1], hid_chs[i], emb_dim)))
self.in_blocks.append(SequentialEmb(ResnetBlock(hid_chs[i], hid_chs[i], emb_dim)))
if i < self.depth - 1:
self.in_blocks.append(SequentialEmb(Downsample(hid_chs[i])))
# ----------- Middle ------------
self.middle_block = SequentialEmb(
ResnetBlock(in_channels=hid_chs[-1], out_channels=hid_chs[-1], temb_channels=emb_dim),
ResnetBlock(in_channels=hid_chs[-1], out_channels=hid_chs[-1], temb_channels=emb_dim))
# ------------ Decoder ----------
out_blocks = []
for i in range(1, self.depth):
for k in range(3):
seq_list = []
hidd_channels = hid_chs[i-1 if k==0 else i]
seq_list.append(ResnetBlock(hid_chs[i] + hidd_channels, out_channels=hidd_channels, temb_channels=emb_dim))
if i > 1 and k == 0:
seq_list.append(Upsample(hidd_channels))
out_blocks.append(SequentialEmb(*seq_list))
self.out_blocks = nn.ModuleList(out_blocks)
# --------------- Out-Convolution ----------------
self.outc = zero_module(nn.Sequential(nn.Conv2d(hid_chs[0], out_ch, 3, padding=1)))
self.emb_dim = emb_dim
self.pos_emb_dim = emb_dim // 4
self.time_emb = nn.Sequential(
nn.Linear(self.pos_emb_dim, self.emb_dim),
nn.SiLU(),
nn.Linear(self.emb_dim, self.emb_dim))
def get_sinusoidal_pos_emb(self, time, max_period=10000, downscale_freq_shift=1):
device = time.device
half_dim = self.pos_emb_dim // 2
emb = math.log(max_period) / (half_dim - downscale_freq_shift)
emb = torch.exp(-emb * torch.arange(half_dim, device=device))
emb = time[:, None] * emb[None, :]
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
def forward(self, x, t):
emb = None
if t is not None:
sinusoidal_emd = self.get_sinusoidal_pos_emb(t)
emb = self.time_emb(sinusoidal_emd)
x_input_block = []
# --------- Encoder --------------
for i, module in enumerate(self.in_blocks):
if i == 0:
x = module(x)
else:
x = module(x, emb)
x_input_block.append(x)
# ---------- Middle --------------
x = self.middle_block(x, emb)
# -------- Decoder -----------
for i in range(len(self.out_blocks), 0, -1):
if isinstance(self.out_blocks[i-1], SequentialEmb):
x = torch.cat([x, x_input_block.pop()], dim=1)
x = self.out_blocks[i-1](x, emb)
else:
x = self.out_blocks[i-1](x)
return self.outc(x)
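Like the UNet it resembles, the network returns a tensor with the same shape as its input latent, which is exactly what noise estimation requires. A quick, hypothetical sanity check on CPU with a small toy configuration (smaller channels than the real model, to keep it cheap):
import torch

toy_net = NoiseNetwork(in_ch=emb_channels, out_ch=emb_channels, hid_chs=(32, 32, 64, 128), emb_dim=128)
z_t = torch.randn(2, emb_channels, 32, 32)  # a batch of noised latents
t = torch.randint(0, timesteps, (2,), dtype=torch.long)  # random diffusion time steps
print(toy_net(z_t, t).shape)  # torch.Size([2, 16, 32, 32]) - same shape as the input latent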
Noise Model¶
Let's build a noise network model to estimate the noise at any given time step.
noise_model = fe.build(
model_fn=lambda: NoiseNetwork(in_ch=emb_channels, out_ch=emb_channels, hid_chs=[256, 256, 512, 1024], emb_dim=1024),
optimizer_fn=lambda x: torch.optim.Adam(x, lr=1e-5),
model_name="noise_model")
Step 3 Define Noise Sampler Op¶
Instead of iteratively adding noise to the given latent embedding one time step at a time, the noise scheduler adds Gaussian noise corresponding to any given time step in a single step. For more information about the math behind how this is applied, please refer to this paper.
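Concretely, the scheduler relies on the standard DDPM closed-form expression for the forward (noising) process: with $\bar{\alpha}_t = \prod_{s \le t} (1 - \beta_s)$, a clean latent $x_0$ can be noised to any time step $t$ in a single shot,

$$x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I),$$

which is what estimate_x_t below computes using the precomputed sqrt_alphas_cumprod and sqrt_one_minus_alphas_cumprod buffers.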
class GaussianNoiseScheduler():
def __init__(
self,
timesteps=1000,
T=None,
beta_start=0.002, # default 1e-4, stable-diffusion ~ 1e-3
beta_end=0.02,
betas=None,
device='cuda' if torch.cuda.is_available() else 'cpu'):
super().__init__()
self.device = device
self.timesteps = timesteps
self.T = timesteps if T is None else T
self.timesteps_array = torch.linspace(0, self.T - 1, self.timesteps, dtype=torch.long, device=self.device)
# NOTE: End is inclusive therefore use -1 to get [0, T-1]
# scaled_linear # proposed as "quadratic" in https://arxiv.org/abs/2006.11239, used in stable-diffusion
betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float64, device=self.device)**2
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)
self.betas = betas.to(torch.float32) # (0 , 1)
self.alphas = alphas.to(torch.float32)
self.alphas_cumprod = alphas_cumprod.to(torch.float32)
self.alphas_cumprod_prev = alphas_cumprod_prev
self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - alphas_cumprod)
self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod)
self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod - 1)
self.posterior_mean_coef1 = betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)
self.posterior_mean_coef2 = (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)
self.posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
def estimate_x_t(self, x_0, t, x_T=None):
# NOTE: t == 0 means diffused for 1 step (https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils.py#L108)
x_T = self.x_final(x_0) if x_T is None else x_T
def clipper(b):
tb = t[b]
if tb<0:
return x_0[b]
elif tb>=self.T:
return x_T[b]
else:
return self.sqrt_alphas_cumprod[tb]*x_0[b]+self.sqrt_one_minus_alphas_cumprod[tb]*x_T[b]
x_t = torch.stack([clipper(b) for b in range(t.shape[0])])
return x_t
def estimate_x_t_prior_from_x_T(self, x_t, t, x_T, use_log=True, var_scale=0):
x_0 = self.estimate_x_0(x_t, x_T, t)
return self.estimate_x_t_prior_from_x_0(x_t, t, x_0, use_log, var_scale)
def estimate_x_t_prior_from_x_0(self, x_t, t, x_0, use_log=True, var_scale=0):
mean = self.estimate_mean_t(x_t, x_0, t)
variance = self.estimate_variance_t(t, x_t.ndim, use_log, var_scale)
std = torch.exp(0.5*variance) if use_log else torch.sqrt(variance)
std[t==0] = 0.0
x_T = self.x_final(x_t)
x_t_prior = mean+std*x_T
return x_t_prior, x_0
def estimate_mean_t(self, x_t, x_0, t):
ndim = x_t.ndim
return (self.extract(self.posterior_mean_coef1, t, ndim)*x_0 + self.extract(self.posterior_mean_coef2, t, ndim)*x_t)
def sample(self, x_0):
"""Randomly sample t from [0,T] and return x_t and x_T based on x_0"""
t = torch.randint(0, self.T, (x_0.shape[0],), dtype=torch.long, device=x_0.device) # NOTE: High is exclusive, therefore [0, T-1]
x_T = self.x_final(x_0)
return self.estimate_x_t(x_0, t, x_T), x_T, t
def estimate_variance_t(self, t, ndim, log=True, var_scale=0, eps=1e-20):
min_variance = self.extract(self.posterior_variance, t, ndim)
max_variance = self.extract(self.betas, t, ndim)
if log:
min_variance = torch.log(min_variance.clamp(min=eps))
max_variance = torch.log(max_variance.clamp(min=eps))
return var_scale * max_variance + (1 - var_scale) * min_variance
def estimate_x_0(self, x_t, x_T, t):
ndim = x_t.ndim
x_0 = (self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t -
self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim)*x_T)
return x_0
def estimate_x_T(self, x_t, x_0, t):
ndim = x_t.ndim
return ((self.extract(self.sqrt_recip_alphas_cumprod, t, ndim)*x_t - x_0)/
self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim))
@classmethod
def x_final(cls, x):
return torch.randn_like(x)
@classmethod
def _clip_x_0(cls, x_0):
# See "static/dynamic thresholding" in Imagen https://arxiv.org/abs/2205.11487
m = 1 # Set this to about 4*sigma = 4 if latent diffusion is used
x_0 = x_0.clamp(-m, m)
return x_0
@staticmethod
def extract(x, t, ndim):
"""Extract values from x at t and reshape them to n-dim tensor"""
x = x.to(t.device)
return x.gather(0, t).reshape(-1, *((1,)*(ndim-1)))
Noise Sampler¶
We need to train a model that can predict the noise at any given time step within the configured range. Let's define a noise sampler Op which draws a random time step for each sample and returns the noised latent at that step, the sampled noise (the prediction target), and the time step itself.
class NoiseSampler(TensorOp):
def __init__(self, inputs, outputs, timesteps=1000):
super().__init__(inputs=inputs, outputs=outputs)
self.gausian_noise_scheduler = GaussianNoiseScheduler(timesteps=timesteps)
def forward(self, data, state: Dict[str, Any]):
encoded_image = data
with torch.no_grad():
encoded_image_t, encoded_image_T, t = self.gausian_noise_scheduler.sample(encoded_image)
return encoded_image_t, encoded_image_T, t
Step 4 Noise Network definition¶
We are going to connect the encoder, the noise model, and the Ops together into a Network:
network = fe.Network(ops=[
ModelOp(model=encoder_model, inputs="image", outputs=['sample', 'embedded_loss'], trainable=False, gradients=False),
NoiseSampler(inputs='sample', outputs=['encoded_image_t', 'encoded_image_T', 't'], timesteps=timesteps),
ModelOp(model=noise_model, inputs=["encoded_image_t", 't'], outputs='pred'),
L1_Loss(inputs=['pred', 'encoded_image_T'], outputs='l1_loss'),
UpdateOp(model=noise_model, loss_name="l1_loss")
])
Step 5 Training noise network¶
In this step, we define the Estimator to compile the Network and Pipeline, and indicate in traces that we want to save the model at a regular frequency. We can then use noise_estimator.fit() to start the training process. It would take roughly 1 hour on an A100 machine.
traces = [ModelSaver(model=noise_model, save_dir=model_dir, frequency=4)]
noise_estimator = fe.Estimator(pipeline=pipeline,
network=network,
epochs=epochs,
traces=traces,
train_steps_per_epoch=train_steps_per_epoch,
eval_steps_per_epoch=eval_steps_per_epoch,
log_steps=log_steps)
noise_estimator.fit()
______ __ ______ __ _ __ / ____/___ ______/ /_/ ____/____/ /_(_)___ ___ ____ _/ /_____ _____ / /_ / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/ / __/ / /_/ (__ ) /_/ /___(__ ) /_/ / / / / / / /_/ / /_/ /_/ / / /_/ \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/ FastEstimator-Start: step: 1; logging_interval: 200; num_device: 1; FastEstimator-Train: step: 1; l1_loss: 0.7976331; FastEstimator-Train: step: 200; l1_loss: 0.4933886; steps/sec: 12.31; FastEstimator-Train: step: 400; l1_loss: 0.30204284; steps/sec: 12.21; FastEstimator-Train: step: 600; l1_loss: 0.43382323; steps/sec: 12.26; FastEstimator-Train: step: 800; l1_loss: 0.2838242; steps/sec: 12.22; FastEstimator-Train: step: 1000; l1_loss: 0.38111776; steps/sec: 12.31; FastEstimator-Train: step: 1000; epoch: 1; epoch_time(sec): 88.67; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.55; Eval Progress: 66/100; steps/sec: 26.68; Eval Progress: 100/100; steps/sec: 26.04; FastEstimator-Eval: step: 1000; epoch: 1; l1_loss: 0.32913; FastEstimator-Train: step: 1200; l1_loss: 0.34653032; steps/sec: 7.86; FastEstimator-Train: step: 1400; l1_loss: 0.30205977; steps/sec: 11.37; FastEstimator-Train: step: 1600; l1_loss: 0.18366942; steps/sec: 11.38; FastEstimator-Train: step: 1800; l1_loss: 0.30938062; steps/sec: 11.37; FastEstimator-Train: step: 2000; l1_loss: 0.22511674; steps/sec: 11.36; FastEstimator-Train: step: 2000; epoch: 2; epoch_time(sec): 95.74; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.04; Eval Progress: 66/100; steps/sec: 26.66; Eval Progress: 100/100; steps/sec: 26.29; FastEstimator-Eval: step: 2000; epoch: 2; l1_loss: 0.31088215; FastEstimator-Train: step: 2200; l1_loss: 0.2525025; steps/sec: 7.84; FastEstimator-Train: step: 2400; l1_loss: 0.22132012; steps/sec: 11.46; FastEstimator-Train: step: 2600; l1_loss: 0.34543973; steps/sec: 11.37; FastEstimator-Train: step: 2800; l1_loss: 0.21044382; steps/sec: 11.44; FastEstimator-Train: step: 3000; l1_loss: 0.29380128; steps/sec: 11.37; FastEstimator-Train: step: 3000; epoch: 3; epoch_time(sec): 95.91; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.89; Eval Progress: 66/100; steps/sec: 26.88; Eval Progress: 100/100; steps/sec: 26.37; FastEstimator-Eval: step: 3000; epoch: 3; l1_loss: 0.30542818; FastEstimator-Train: step: 3200; l1_loss: 0.2643189; steps/sec: 7.72; FastEstimator-Train: step: 3400; l1_loss: 0.28451386; steps/sec: 11.47; FastEstimator-Train: step: 3600; l1_loss: 0.20400438; steps/sec: 11.43; FastEstimator-Train: step: 3800; l1_loss: 0.3432409; steps/sec: 11.41; FastEstimator-Train: step: 4000; l1_loss: 0.2634462; steps/sec: 11.46; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_4.pt FastEstimator-Train: step: 4000; epoch: 4; epoch_time(sec): 95.55; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.99; Eval Progress: 66/100; steps/sec: 26.35; Eval Progress: 100/100; steps/sec: 26.36; FastEstimator-Eval: step: 4000; epoch: 4; l1_loss: 0.28709236; FastEstimator-Train: step: 4200; l1_loss: 0.24100962; steps/sec: 7.73; FastEstimator-Train: step: 4400; l1_loss: 0.36472592; steps/sec: 11.41; FastEstimator-Train: step: 4600; l1_loss: 0.4051692; steps/sec: 11.47; FastEstimator-Train: step: 4800; l1_loss: 0.24432875; steps/sec: 11.41; FastEstimator-Train: step: 5000; l1_loss: 0.35177416; steps/sec: 11.44; FastEstimator-Train: step: 5000; epoch: 5; epoch_time(sec): 95.79; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 22.0; Eval Progress: 66/100; steps/sec: 26.01; Eval 
Progress: 100/100; steps/sec: 26.06; FastEstimator-Eval: step: 5000; epoch: 5; l1_loss: 0.2833061; FastEstimator-Train: step: 5200; l1_loss: 0.26877302; steps/sec: 7.83; FastEstimator-Train: step: 5400; l1_loss: 0.3780157; steps/sec: 11.37; FastEstimator-Train: step: 5600; l1_loss: 0.29162753; steps/sec: 11.44; FastEstimator-Train: step: 5800; l1_loss: 0.21758595; steps/sec: 11.36; FastEstimator-Train: step: 6000; l1_loss: 0.21974173; steps/sec: 11.41; FastEstimator-Train: step: 6000; epoch: 6; epoch_time(sec): 95.9; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.95; Eval Progress: 66/100; steps/sec: 26.3; Eval Progress: 100/100; steps/sec: 26.3; FastEstimator-Eval: step: 6000; epoch: 6; l1_loss: 0.2748123; FastEstimator-Train: step: 6200; l1_loss: 0.22006854; steps/sec: 7.76; FastEstimator-Train: step: 6400; l1_loss: 0.2977705; steps/sec: 11.45; FastEstimator-Train: step: 6600; l1_loss: 0.2318846; steps/sec: 11.52; FastEstimator-Train: step: 6800; l1_loss: 0.32937667; steps/sec: 11.42; FastEstimator-Train: step: 7000; l1_loss: 0.26467934; steps/sec: 11.4; FastEstimator-Train: step: 7000; epoch: 7; epoch_time(sec): 95.47; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.23; Eval Progress: 66/100; steps/sec: 26.35; Eval Progress: 100/100; steps/sec: 26.48; FastEstimator-Eval: step: 7000; epoch: 7; l1_loss: 0.27676448; FastEstimator-Train: step: 7200; l1_loss: 0.31056893; steps/sec: 7.78; FastEstimator-Train: step: 7400; l1_loss: 0.2424231; steps/sec: 11.5; FastEstimator-Train: step: 7600; l1_loss: 0.24082819; steps/sec: 11.37; FastEstimator-Train: step: 7800; l1_loss: 0.12187676; steps/sec: 11.47; FastEstimator-Train: step: 8000; l1_loss: 0.2669567; steps/sec: 11.4; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_8.pt FastEstimator-Train: step: 8000; epoch: 8; epoch_time(sec): 95.68; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.01; Eval Progress: 66/100; steps/sec: 26.31; Eval Progress: 100/100; steps/sec: 26.46; FastEstimator-Eval: step: 8000; epoch: 8; l1_loss: 0.28559822; FastEstimator-Train: step: 8200; l1_loss: 0.29148996; steps/sec: 7.78; FastEstimator-Train: step: 8400; l1_loss: 0.33090982; steps/sec: 11.46; FastEstimator-Train: step: 8600; l1_loss: 0.15168832; steps/sec: 11.45; FastEstimator-Train: step: 8800; l1_loss: 0.30693084; steps/sec: 11.39; FastEstimator-Train: step: 9000; l1_loss: 0.26546237; steps/sec: 11.45; FastEstimator-Train: step: 9000; epoch: 9; epoch_time(sec): 95.68; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.57; Eval Progress: 66/100; steps/sec: 26.74; Eval Progress: 100/100; steps/sec: 26.44; FastEstimator-Eval: step: 9000; epoch: 9; l1_loss: 0.2673853; FastEstimator-Train: step: 9200; l1_loss: 0.26033974; steps/sec: 7.7; FastEstimator-Train: step: 9400; l1_loss: 0.33055288; steps/sec: 11.35; FastEstimator-Train: step: 9600; l1_loss: 0.40346888; steps/sec: 11.42; FastEstimator-Train: step: 9800; l1_loss: 0.19011918; steps/sec: 11.36; FastEstimator-Train: step: 10000; l1_loss: 0.27358967; steps/sec: 11.39; FastEstimator-Train: step: 10000; epoch: 10; epoch_time(sec): 96.4; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.04; Eval Progress: 66/100; steps/sec: 26.48; Eval Progress: 100/100; steps/sec: 26.55; FastEstimator-Eval: step: 10000; epoch: 10; l1_loss: 0.26088646; FastEstimator-Train: step: 10200; l1_loss: 0.27276704; steps/sec: 7.73; FastEstimator-Train: step: 10400; l1_loss: 0.25251192; steps/sec: 11.36; FastEstimator-Train: step: 10600; l1_loss: 0.28084236; 
steps/sec: 11.43; FastEstimator-Train: step: 10800; l1_loss: 0.3099733; steps/sec: 11.39; FastEstimator-Train: step: 11000; l1_loss: 0.30634335; steps/sec: 11.46; FastEstimator-Train: step: 11000; epoch: 11; epoch_time(sec): 96.02; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.61; Eval Progress: 66/100; steps/sec: 26.46; Eval Progress: 100/100; steps/sec: 26.12; FastEstimator-Eval: step: 11000; epoch: 11; l1_loss: 0.2759275; FastEstimator-Train: step: 11200; l1_loss: 0.32498676; steps/sec: 7.41; FastEstimator-Train: step: 11400; l1_loss: 0.32771417; steps/sec: 11.41; FastEstimator-Train: step: 11600; l1_loss: 0.26384234; steps/sec: 11.31; FastEstimator-Train: step: 11800; l1_loss: 0.25880238; steps/sec: 11.38; FastEstimator-Train: step: 12000; l1_loss: 0.2735661; steps/sec: 11.37; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_12.pt FastEstimator-Train: step: 12000; epoch: 12; epoch_time(sec): 97.17; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 22.75; Eval Progress: 66/100; steps/sec: 26.31; Eval Progress: 100/100; steps/sec: 26.27; FastEstimator-Eval: step: 12000; epoch: 12; l1_loss: 0.26069826; FastEstimator-Train: step: 12200; l1_loss: 0.17161965; steps/sec: 7.59; FastEstimator-Train: step: 12400; l1_loss: 0.23038688; steps/sec: 11.41; FastEstimator-Train: step: 12600; l1_loss: 0.3070624; steps/sec: 11.36; FastEstimator-Train: step: 12800; l1_loss: 0.24158505; steps/sec: 11.35; FastEstimator-Train: step: 13000; l1_loss: 0.2275491; steps/sec: 11.44; FastEstimator-Train: step: 13000; epoch: 13; epoch_time(sec): 96.64; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.02; Eval Progress: 66/100; steps/sec: 26.56; Eval Progress: 100/100; steps/sec: 26.28; FastEstimator-Eval: step: 13000; epoch: 13; l1_loss: 0.2601124; FastEstimator-Train: step: 13200; l1_loss: 0.18746276; steps/sec: 7.57; FastEstimator-Train: step: 13400; l1_loss: 0.1519129; steps/sec: 11.36; FastEstimator-Train: step: 13600; l1_loss: 0.37590352; steps/sec: 11.44; FastEstimator-Train: step: 13800; l1_loss: 0.34013107; steps/sec: 11.33; FastEstimator-Train: step: 14000; l1_loss: 0.29082727; steps/sec: 11.38; FastEstimator-Train: step: 14000; epoch: 14; epoch_time(sec): 96.73; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.71; Eval Progress: 66/100; steps/sec: 26.34; Eval Progress: 100/100; steps/sec: 26.42; FastEstimator-Eval: step: 14000; epoch: 14; l1_loss: 0.2643661; FastEstimator-Train: step: 14200; l1_loss: 0.2927752; steps/sec: 7.6; FastEstimator-Train: step: 14400; l1_loss: 0.28371787; steps/sec: 11.35; FastEstimator-Train: step: 14600; l1_loss: 0.24411258; steps/sec: 11.4; FastEstimator-Train: step: 14800; l1_loss: 0.22488245; steps/sec: 11.32; FastEstimator-Train: step: 15000; l1_loss: 0.15199602; steps/sec: 11.42; FastEstimator-Train: step: 15000; epoch: 15; epoch_time(sec): 96.65; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.83; Eval Progress: 66/100; steps/sec: 26.56; Eval Progress: 100/100; steps/sec: 26.49; FastEstimator-Eval: step: 15000; epoch: 15; l1_loss: 0.266911; FastEstimator-Train: step: 15200; l1_loss: 0.1783285; steps/sec: 7.61; FastEstimator-Train: step: 15400; l1_loss: 0.15478465; steps/sec: 11.48; FastEstimator-Train: step: 15600; l1_loss: 0.36113173; steps/sec: 11.47; FastEstimator-Train: step: 15800; l1_loss: 0.22298545; steps/sec: 11.39; FastEstimator-Train: step: 16000; l1_loss: 0.3179836; steps/sec: 11.48; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_16.pt FastEstimator-Train: 
step: 16000; epoch: 16; epoch_time(sec): 96.36; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.72; Eval Progress: 66/100; steps/sec: 26.42; Eval Progress: 100/100; steps/sec: 26.22; FastEstimator-Eval: step: 16000; epoch: 16; l1_loss: 0.26159385; FastEstimator-Train: step: 16200; l1_loss: 0.17443973; steps/sec: 7.5; FastEstimator-Train: step: 16400; l1_loss: 0.2684257; steps/sec: 11.39; FastEstimator-Train: step: 16600; l1_loss: 0.19085012; steps/sec: 11.31; FastEstimator-Train: step: 16800; l1_loss: 0.33221033; steps/sec: 11.38; FastEstimator-Train: step: 17000; l1_loss: 0.16028506; steps/sec: 11.35; FastEstimator-Train: step: 17000; epoch: 17; epoch_time(sec): 97.02; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.1; Eval Progress: 66/100; steps/sec: 26.41; Eval Progress: 100/100; steps/sec: 26.07; FastEstimator-Eval: step: 17000; epoch: 17; l1_loss: 0.2573425; FastEstimator-Train: step: 17200; l1_loss: 0.3579504; steps/sec: 7.56; FastEstimator-Train: step: 17400; l1_loss: 0.28131068; steps/sec: 11.44; FastEstimator-Train: step: 17600; l1_loss: 0.3653924; steps/sec: 11.52; FastEstimator-Train: step: 17800; l1_loss: 0.29413274; steps/sec: 11.42; FastEstimator-Train: step: 18000; l1_loss: 0.38791138; steps/sec: 11.47; FastEstimator-Train: step: 18000; epoch: 18; epoch_time(sec): 96.27; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.26; Eval Progress: 66/100; steps/sec: 26.64; Eval Progress: 100/100; steps/sec: 24.54; FastEstimator-Eval: step: 18000; epoch: 18; l1_loss: 0.2617805; FastEstimator-Train: step: 18200; l1_loss: 0.24427862; steps/sec: 7.69; FastEstimator-Train: step: 18400; l1_loss: 0.24796224; steps/sec: 11.41; FastEstimator-Train: step: 18600; l1_loss: 0.26004028; steps/sec: 11.45; FastEstimator-Train: step: 18800; l1_loss: 0.2745578; steps/sec: 11.33; FastEstimator-Train: step: 19000; l1_loss: 0.28821558; steps/sec: 11.45; FastEstimator-Train: step: 19000; epoch: 19; epoch_time(sec): 95.95; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.09; Eval Progress: 66/100; steps/sec: 26.65; Eval Progress: 100/100; steps/sec: 26.51; FastEstimator-Eval: step: 19000; epoch: 19; l1_loss: 0.25696224; FastEstimator-Train: step: 19200; l1_loss: 0.2739766; steps/sec: 7.75; FastEstimator-Train: step: 19400; l1_loss: 0.1355182; steps/sec: 11.43; FastEstimator-Train: step: 19600; l1_loss: 0.21476263; steps/sec: 11.45; FastEstimator-Train: step: 19800; l1_loss: 0.1606172; steps/sec: 11.41; FastEstimator-Train: step: 20000; l1_loss: 0.32306737; steps/sec: 11.48; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_20.pt FastEstimator-Train: step: 20000; epoch: 20; epoch_time(sec): 95.87; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.62; Eval Progress: 66/100; steps/sec: 25.92; Eval Progress: 100/100; steps/sec: 26.36; FastEstimator-Eval: step: 20000; epoch: 20; l1_loss: 0.25400603; FastEstimator-Train: step: 20200; l1_loss: 0.16806644; steps/sec: 7.6; FastEstimator-Train: step: 20400; l1_loss: 0.36320537; steps/sec: 11.41; FastEstimator-Train: step: 20600; l1_loss: 0.19868326; steps/sec: 11.35; FastEstimator-Train: step: 20800; l1_loss: 0.41526514; steps/sec: 11.42; FastEstimator-Train: step: 21000; l1_loss: 0.19650456; steps/sec: 11.39; FastEstimator-Train: step: 21000; epoch: 21; epoch_time(sec): 96.31; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.73; Eval Progress: 66/100; steps/sec: 26.63; Eval Progress: 100/100; steps/sec: 26.52; FastEstimator-Eval: step: 21000; epoch: 21; l1_loss: 0.26024672; 
FastEstimator-Train: step: 21200; l1_loss: 0.15510589; steps/sec: 7.77; FastEstimator-Train: step: 21400; l1_loss: 0.24414665; steps/sec: 11.42; FastEstimator-Train: step: 21600; l1_loss: 0.2518507; steps/sec: 11.36; FastEstimator-Train: step: 21800; l1_loss: 0.3406285; steps/sec: 11.52; FastEstimator-Train: step: 22000; l1_loss: 0.24972007; steps/sec: 11.39; FastEstimator-Train: step: 22000; epoch: 22; epoch_time(sec): 95.83; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.96; Eval Progress: 66/100; steps/sec: 26.4; Eval Progress: 100/100; steps/sec: 26.18; FastEstimator-Eval: step: 22000; epoch: 22; l1_loss: 0.2508301; FastEstimator-Train: step: 22200; l1_loss: 0.19649792; steps/sec: 7.74; FastEstimator-Train: step: 22400; l1_loss: 0.27779052; steps/sec: 11.42; FastEstimator-Train: step: 22600; l1_loss: 0.19619058; steps/sec: 11.42; FastEstimator-Train: step: 22800; l1_loss: 0.31617478; steps/sec: 11.43; FastEstimator-Train: step: 23000; l1_loss: 0.20097603; steps/sec: 11.5; FastEstimator-Train: step: 23000; epoch: 23; epoch_time(sec): 95.69; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.15; Eval Progress: 66/100; steps/sec: 26.18; Eval Progress: 100/100; steps/sec: 25.16; FastEstimator-Eval: step: 23000; epoch: 23; l1_loss: 0.25543055; FastEstimator-Train: step: 23200; l1_loss: 0.25547707; steps/sec: 7.73; FastEstimator-Train: step: 23400; l1_loss: 0.29235286; steps/sec: 11.4; FastEstimator-Train: step: 23600; l1_loss: 0.16141641; steps/sec: 11.47; FastEstimator-Train: step: 23800; l1_loss: 0.37722278; steps/sec: 11.37; FastEstimator-Train: step: 24000; l1_loss: 0.2310235; steps/sec: 11.47; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_24.pt FastEstimator-Train: step: 24000; epoch: 24; epoch_time(sec): 95.9; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.49; Eval Progress: 66/100; steps/sec: 26.84; Eval Progress: 100/100; steps/sec: 26.63; FastEstimator-Eval: step: 24000; epoch: 24; l1_loss: 0.26347294; FastEstimator-Train: step: 24200; l1_loss: 0.19425388; steps/sec: 7.75; FastEstimator-Train: step: 24400; l1_loss: 0.22600873; steps/sec: 11.39; FastEstimator-Train: step: 24600; l1_loss: 0.3548479; steps/sec: 11.4; FastEstimator-Train: step: 24800; l1_loss: 0.26855838; steps/sec: 11.5; FastEstimator-Train: step: 25000; l1_loss: 0.35399187; steps/sec: 11.43; FastEstimator-Train: step: 25000; epoch: 25; epoch_time(sec): 95.79; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.06; Eval Progress: 66/100; steps/sec: 26.69; Eval Progress: 100/100; steps/sec: 26.7; FastEstimator-Eval: step: 25000; epoch: 25; l1_loss: 0.256749; FastEstimator-Train: step: 25200; l1_loss: 0.20960468; steps/sec: 7.71; FastEstimator-Train: step: 25400; l1_loss: 0.36978525; steps/sec: 11.39; FastEstimator-Train: step: 25600; l1_loss: 0.29932252; steps/sec: 11.34; FastEstimator-Train: step: 25800; l1_loss: 0.20496017; steps/sec: 11.41; FastEstimator-Train: step: 26000; l1_loss: 0.15937817; steps/sec: 11.39; FastEstimator-Train: step: 26000; epoch: 26; epoch_time(sec): 96.29; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.43; Eval Progress: 66/100; steps/sec: 26.5; Eval Progress: 100/100; steps/sec: 26.64; FastEstimator-Eval: step: 26000; epoch: 26; l1_loss: 0.2580045; FastEstimator-Train: step: 26200; l1_loss: 0.16689032; steps/sec: 7.68; FastEstimator-Train: step: 26400; l1_loss: 0.21982989; steps/sec: 11.39; FastEstimator-Train: step: 26600; l1_loss: 0.27728337; steps/sec: 11.32; FastEstimator-Train: step: 26800; l1_loss: 
0.4027999; steps/sec: 11.33; FastEstimator-Train: step: 27000; l1_loss: 0.19043693; steps/sec: 11.36; FastEstimator-Train: step: 27000; epoch: 27; epoch_time(sec): 96.57; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.24; Eval Progress: 66/100; steps/sec: 26.46; Eval Progress: 100/100; steps/sec: 26.23; FastEstimator-Eval: step: 27000; epoch: 27; l1_loss: 0.25776404; FastEstimator-Train: step: 27200; l1_loss: 0.20112318; steps/sec: 7.65; FastEstimator-Train: step: 27400; l1_loss: 0.24933486; steps/sec: 11.38; FastEstimator-Train: step: 27600; l1_loss: 0.2382636; steps/sec: 11.45; FastEstimator-Train: step: 27800; l1_loss: 0.26056936; steps/sec: 11.37; FastEstimator-Train: step: 28000; l1_loss: 0.24781737; steps/sec: 11.43; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_28.pt FastEstimator-Train: step: 28000; epoch: 28; epoch_time(sec): 96.21; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.11; Eval Progress: 66/100; steps/sec: 26.36; Eval Progress: 100/100; steps/sec: 26.56; FastEstimator-Eval: step: 28000; epoch: 28; l1_loss: 0.25914752; FastEstimator-Train: step: 28200; l1_loss: 0.40274465; steps/sec: 7.77; FastEstimator-Train: step: 28400; l1_loss: 0.28560466; steps/sec: 11.42; FastEstimator-Train: step: 28600; l1_loss: 0.2816598; steps/sec: 11.47; FastEstimator-Train: step: 28800; l1_loss: 0.31932306; steps/sec: 11.44; FastEstimator-Train: step: 29000; l1_loss: 0.16468275; steps/sec: 11.43; FastEstimator-Train: step: 29000; epoch: 29; epoch_time(sec): 95.84; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.8; Eval Progress: 66/100; steps/sec: 26.58; Eval Progress: 100/100; steps/sec: 26.74; FastEstimator-Eval: step: 29000; epoch: 29; l1_loss: 0.25642276; FastEstimator-Train: step: 29200; l1_loss: 0.42530376; steps/sec: 7.67; FastEstimator-Train: step: 29400; l1_loss: 0.24223192; steps/sec: 11.43; FastEstimator-Train: step: 29600; l1_loss: 0.31856352; steps/sec: 11.33; FastEstimator-Train: step: 29800; l1_loss: 0.28411725; steps/sec: 11.44; FastEstimator-Train: step: 30000; l1_loss: 0.21703652; steps/sec: 11.36; FastEstimator-Train: step: 30000; epoch: 30; epoch_time(sec): 96.61; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.83; Eval Progress: 66/100; steps/sec: 26.71; Eval Progress: 100/100; steps/sec: 25.95; FastEstimator-Eval: step: 30000; epoch: 30; l1_loss: 0.25333703; FastEstimator-Train: step: 30200; l1_loss: 0.14242204; steps/sec: 7.39; FastEstimator-Train: step: 30400; l1_loss: 0.1973533; steps/sec: 11.36; FastEstimator-Train: step: 30600; l1_loss: 0.19050522; steps/sec: 11.36; FastEstimator-Train: step: 30800; l1_loss: 0.13728365; steps/sec: 11.34; FastEstimator-Train: step: 31000; l1_loss: 0.2919001; steps/sec: 11.39; FastEstimator-Train: step: 31000; epoch: 31; epoch_time(sec): 97.03; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.44; Eval Progress: 66/100; steps/sec: 26.29; Eval Progress: 100/100; steps/sec: 26.38; FastEstimator-Eval: step: 31000; epoch: 31; l1_loss: 0.24761303; FastEstimator-Train: step: 31200; l1_loss: 0.19710764; steps/sec: 7.56; FastEstimator-Train: step: 31400; l1_loss: 0.2218371; steps/sec: 11.36; FastEstimator-Train: step: 31600; l1_loss: 0.27173433; steps/sec: 11.4; FastEstimator-Train: step: 31800; l1_loss: 0.26842794; steps/sec: 11.33; FastEstimator-Train: step: 32000; l1_loss: 0.17139399; steps/sec: 11.39; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_32.pt FastEstimator-Train: step: 32000; epoch: 32; epoch_time(sec): 97.15; Eval 
Progress: 1/100; Eval Progress: 33/100; steps/sec: 23.37; Eval Progress: 66/100; steps/sec: 26.41; Eval Progress: 100/100; steps/sec: 26.22; FastEstimator-Eval: step: 32000; epoch: 32; l1_loss: 0.25006843; FastEstimator-Train: step: 32200; l1_loss: 0.2317583; steps/sec: 7.59; FastEstimator-Train: step: 32400; l1_loss: 0.19050783; steps/sec: 11.39; FastEstimator-Train: step: 32600; l1_loss: 0.1785226; steps/sec: 11.46; FastEstimator-Train: step: 32800; l1_loss: 0.1975793; steps/sec: 11.42; FastEstimator-Train: step: 33000; l1_loss: 0.26653224; steps/sec: 11.47; FastEstimator-Train: step: 33000; epoch: 33; epoch_time(sec): 96.04; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.33; Eval Progress: 66/100; steps/sec: 26.25; Eval Progress: 100/100; steps/sec: 26.51; FastEstimator-Eval: step: 33000; epoch: 33; l1_loss: 0.25212887; FastEstimator-Train: step: 33200; l1_loss: 0.23135075; steps/sec: 7.72; FastEstimator-Train: step: 33400; l1_loss: 0.1994313; steps/sec: 11.35; FastEstimator-Train: step: 33600; l1_loss: 0.31442124; steps/sec: 11.4; FastEstimator-Train: step: 33800; l1_loss: 0.18425947; steps/sec: 11.37; FastEstimator-Train: step: 34000; l1_loss: 0.22232369; steps/sec: 11.42; FastEstimator-Train: step: 34000; epoch: 34; epoch_time(sec): 96.16; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.75; Eval Progress: 66/100; steps/sec: 26.81; Eval Progress: 100/100; steps/sec: 26.2; FastEstimator-Eval: step: 34000; epoch: 34; l1_loss: 0.25720927; FastEstimator-Train: step: 34200; l1_loss: 0.2963785; steps/sec: 7.69; FastEstimator-Train: step: 34400; l1_loss: 0.27527443; steps/sec: 11.43; FastEstimator-Train: step: 34600; l1_loss: 0.21376503; steps/sec: 11.39; FastEstimator-Train: step: 34800; l1_loss: 0.205087; steps/sec: 11.46; FastEstimator-Train: step: 35000; l1_loss: 0.31293505; steps/sec: 11.39; FastEstimator-Train: step: 35000; epoch: 35; epoch_time(sec): 96.01; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.2; Eval Progress: 66/100; steps/sec: 26.43; Eval Progress: 100/100; steps/sec: 26.64; FastEstimator-Eval: step: 35000; epoch: 35; l1_loss: 0.2528195; FastEstimator-Train: step: 35200; l1_loss: 0.19908914; steps/sec: 7.73; FastEstimator-Train: step: 35400; l1_loss: 0.26568374; steps/sec: 11.45; FastEstimator-Train: step: 35600; l1_loss: 0.3396192; steps/sec: 11.4; FastEstimator-Train: step: 35800; l1_loss: 0.24528845; steps/sec: 11.45; FastEstimator-Train: step: 36000; l1_loss: 0.31224015; steps/sec: 11.38; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_36.pt FastEstimator-Train: step: 36000; epoch: 36; epoch_time(sec): 95.84; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.71; Eval Progress: 66/100; steps/sec: 26.73; Eval Progress: 100/100; steps/sec: 26.52; FastEstimator-Eval: step: 36000; epoch: 36; l1_loss: 0.23870115; FastEstimator-Train: step: 36200; l1_loss: 0.25260645; steps/sec: 7.77; FastEstimator-Train: step: 36400; l1_loss: 0.17068475; steps/sec: 11.37; FastEstimator-Train: step: 36600; l1_loss: 0.29799157; steps/sec: 11.42; FastEstimator-Train: step: 36800; l1_loss: 0.21063522; steps/sec: 11.34; FastEstimator-Train: step: 37000; l1_loss: 0.14850116; steps/sec: 11.41; FastEstimator-Train: step: 37000; epoch: 37; epoch_time(sec): 96.08; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 25.25; Eval Progress: 66/100; steps/sec: 24.93; Eval Progress: 100/100; steps/sec: 26.38; FastEstimator-Eval: step: 37000; epoch: 37; l1_loss: 0.2534766; FastEstimator-Train: step: 37200; l1_loss: 0.16785488; 
steps/sec: 7.76; FastEstimator-Train: step: 37400; l1_loss: 0.22538082; steps/sec: 11.36; FastEstimator-Train: step: 37600; l1_loss: 0.23239955; steps/sec: 11.43; FastEstimator-Train: step: 37800; l1_loss: 0.1595094; steps/sec: 11.38; FastEstimator-Train: step: 38000; l1_loss: 0.2089673; steps/sec: 11.41; FastEstimator-Train: step: 38000; epoch: 38; epoch_time(sec): 96.04; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 26.02; Eval Progress: 66/100; steps/sec: 26.51; Eval Progress: 100/100; steps/sec: 26.14; FastEstimator-Eval: step: 38000; epoch: 38; l1_loss: 0.25527927; FastEstimator-Train: step: 38200; l1_loss: 0.17317447; steps/sec: 7.72; FastEstimator-Train: step: 38400; l1_loss: 0.3340689; steps/sec: 11.36; FastEstimator-Train: step: 38600; l1_loss: 0.32379156; steps/sec: 11.41; FastEstimator-Train: step: 38800; l1_loss: 0.22818284; steps/sec: 11.33; FastEstimator-Train: step: 39000; l1_loss: 0.36167914; steps/sec: 11.34; FastEstimator-Train: step: 39000; epoch: 39; epoch_time(sec): 96.27; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 26.28; Eval Progress: 66/100; steps/sec: 26.54; Eval Progress: 100/100; steps/sec: 26.29; FastEstimator-Eval: step: 39000; epoch: 39; l1_loss: 0.25252435; FastEstimator-Train: step: 39200; l1_loss: 0.39422715; steps/sec: 7.67; FastEstimator-Train: step: 39400; l1_loss: 0.28158087; steps/sec: 11.41; FastEstimator-Train: step: 39600; l1_loss: 0.27269578; steps/sec: 11.37; FastEstimator-Train: step: 39800; l1_loss: 0.17873858; steps/sec: 11.43; FastEstimator-Train: step: 40000; l1_loss: 0.28637937; steps/sec: 11.36; FastEstimator-ModelSaver: Saved model to /tmp/tmpvd1u5fp6/noise_model_epoch_40.pt FastEstimator-Train: step: 40000; epoch: 40; epoch_time(sec): 96.27; Eval Progress: 1/100; Eval Progress: 33/100; steps/sec: 24.41; Eval Progress: 66/100; steps/sec: 26.39; Eval Progress: 100/100; steps/sec: 25.94; FastEstimator-Eval: step: 40000; epoch: 40; l1_loss: 0.24847148; FastEstimator-Finish: step: 40000; encoder_lr: 0.0001; noise_model_lr: 1e-05; total_time(sec): 4341.57;
Generate Images using Diffusion Pipeline ¶
Congratulations, we have successfully trained a diffusion model!¶
Since the heavy lifting is already done, let's begin the fun part and generate random images using the diffusion pipeline.
Step 1: Let's Define the Diffusion Pipeline¶
The diffusion pipeline takes the pretrained noise network and decoder and uses them to generate new images from random noise. It starts from a noise sample whose shape matches the output of the latent embedder, iteratively removes noise with the noise network at each reverse time step, and finally decodes the denoised latent into an image.
class DiffusionInferencePipeline():
    def __init__(self, noise_network, decoder_model, timesteps=1000):
        super().__init__()
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.noise_scheduler = GaussianNoiseScheduler(timesteps=timesteps)
        self.noise_network = noise_network
        self.latent_decoder = decoder_model
        with torch.no_grad():
            for param in self.latent_decoder.parameters():
                param.requires_grad = False
            for param in self.noise_network.parameters():
                param.requires_grad = False
        self.latent_decoder.eval()
        self.noise_network.eval()
        self.latent_decoder.to(device)
        self.noise_network.to(device)

    @torch.no_grad()
    def noise_estimate(self, x_t, t, condition=None):
        # Note: x_t expected to be in range ~ [-1, 1]
        pred = self.noise_network(x_t, t)
        x_t_prior, x_0 = self.noise_scheduler.estimate_x_t_prior_from_x_T(x_t, t, pred, var_scale=False)
        x_T = pred
        return x_t_prior, x_0, x_T

    @torch.no_grad()
    def denoise(self, x_t, **kwargs):
        # ---------- run denoise loop ---------------
        timesteps_array = self.noise_scheduler.timesteps_array
        for i, t in tqdm(enumerate(reversed(timesteps_array))):
            # noise network prediction
            self.noise_network.eval()
            with torch.no_grad():
                x_t, _, _ = self.noise_estimate(x_t, t.expand(x_t.shape[0]))
                x_t = x_t.to(torch.float32)
        # ------ Eventually decode from latent space into image space --------
        if self.latent_decoder is not None:
            self.latent_decoder.eval()
            with torch.no_grad():
                x_t = self.latent_decoder(x_t)
        return x_t

    @torch.no_grad()
    def sample(self, num_samples, img_size, **kwargs):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        template = torch.zeros((num_samples, *img_size), device=self.device)
        x_T = self.noise_scheduler.x_final(template)
        x_0 = self.denoise(x_T, **kwargs)
        return x_0
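As a reminder of what estimate_x_t_prior_from_x_T in the GaussianNoiseScheduler defined earlier is responsible for at each iteration of the denoise loop, the sketch below spells out the standard DDPM reverse step: first recover an estimate of the clean latent x_0 from the predicted noise, then sample x_{t-1} from the posterior. This is an illustrative, self-contained version only; the function name ddpm_reverse_step and the schedule tensors betas, alphas, and alphas_bar are assumptions for exposition, not the scheduler's actual attributes.

import torch

def ddpm_reverse_step(x_t, eps_pred, t, betas, alphas, alphas_bar):
    # Recover an estimate of the clean latent x_0 from the predicted noise.
    x_0 = (x_t - torch.sqrt(1 - alphas_bar[t]) * eps_pred) / torch.sqrt(alphas_bar[t])
    x_0 = x_0.clamp(-1, 1)
    # Coefficients of the DDPM posterior mean q(x_{t-1} | x_t, x_0).
    alphas_bar_prev = alphas_bar[t - 1] if t > 0 else torch.tensor(1.0)
    coef_x0 = torch.sqrt(alphas_bar_prev) * betas[t] / (1 - alphas_bar[t])
    coef_xt = torch.sqrt(alphas[t]) * (1 - alphas_bar_prev) / (1 - alphas_bar[t])
    mean = coef_x0 * x_0 + coef_xt * x_t
    if t == 0:
        return mean, x_0  # no noise is added at the final step
    # Posterior variance, used to re-inject a small amount of noise while t > 0.
    var = betas[t] * (1 - alphas_bar_prev) / (1 - alphas_bar[t])
    x_t_prior = mean + torch.sqrt(var) * torch.randn_like(x_t)
    return x_t_prior, x_0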
Step 2: Define the noise sample and load the Diffusion Pipeline¶
We load the diffusion pipeline with the pretrained noise estimator and decoder, and define the shape of the noise sample to match the output of the latent embedder. Depending on the quality of the results, the number of training epochs for either the noise network or the latent embedder can be increased.
diffusion_sample = (emb_channels, image_size//8, image_size//8) # similar to the output of latent embedder
diffusion_pipeline = DiffusionInferencePipeline(noise_model, decoder_model, timesteps=timesteps)
Step 3: Generate images from the noise sample¶
sample_image_batch = diffusion_pipeline.sample(num_samples, diffusion_sample)
# undo the Normalize(mean=0.5, std=0.5) applied in the pipeline: map [-1, 1] back to [0, 1]
sample_image_batch = sample_image_batch*0.5 + 0.5
sample_image_batch = sample_image_batch.clamp(0, 1)
1000it [00:19, 50.62it/s]
BatchDisplay(image=sample_image_batch[:4].cpu(), title="Diffusion Output").show()
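If you want to keep the generated samples for later inspection, they can also be written to disk. The snippet below is a minimal sketch using torchvision's save_image; the output directory generated_xrays is an arbitrary choice, not something defined earlier in this tutorial.

import os
from torchvision.utils import save_image

out_dir = "generated_xrays"  # illustrative path, adjust as needed
os.makedirs(out_dir, exist_ok=True)
for idx, img in enumerate(sample_image_batch.cpu()):
    # the batch is already rescaled and clamped to [0, 1], so no extra normalization is needed
    save_image(img, os.path.join(out_dir, f"sample_{idx}.png"))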
Bravo! In around 2 hours we are able to produce decent 256x256 chest X-ray images using stable diffusion. For comparison, PGGAN takes around 24 hours on a V100 machine to produce 128x128 images, whereas stable diffusion produces 256x256 images in under 2 hours on an A100 machine, roughly 12x less training time (albeit on different hardware and at a higher resolution).