Horse to Zebra Unpaired Image Translation with CycleGAN in FastEstimator¶
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
This notebook demonstrates how to perform unpaired image-to-image translation using CycleGAN in FastEstimator.
We will specifically look at the problem of translating horse images into zebra images.
import tempfile
import numpy as np
import torch
import torch.nn as nn
from torch.nn.init import normal_
import fastestimator as fe
from fastestimator.backend import reduce_mean
from fastestimator.util import ImageDisplay, GridDisplay
# Parameters
epochs = 200
batch_size = 1
train_steps_per_epoch = None
save_dir = tempfile.mkdtemp()
weight = 10.0
data_dir = None
Building Components¶
Downloading the data¶
First, we will download the dataset of horses and zebras via our dataset API. The images will first be downloaded from here.
Since this task requires unpaired sets of horse and zebra images, the horse2zebra dataset is implemented using BatchDataset in FastEstimator. Hence, we need to specify the batch size while loading the dataset.
from fastestimator.dataset.data.horse2zebra import load_data
train_data, test_data = load_data(batch_size=batch_size, root_dir=data_dir)
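Since load_data returns BatchDataset objects, indexing into them yields an entire batch rather than a single example. As a quick sanity check (illustrative, not part of the original notebook; "A" and "B" hold the file paths of a horse image and a zebra image, respectively):
# Each dataset element is already a batch: a list of `batch_size` dicts,
# each holding one unpaired horse/zebra sample.
batch = train_data[0]
print(len(batch))        # batch_size
print(batch[0].keys())   # dict_keys(['A', 'B'])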
Step 1: Create pipeline¶
Let's create the pipeline. Since batch_size must be None when a BatchDataset is being used, we will not provide the batch_size argument.
from fastestimator.op.numpyop import Delete
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, RandomCrop, Resize
from fastestimator.op.numpyop.univariate import ChannelTranspose, Normalize, ReadImage
pipeline = fe.Pipeline(
    train_data=train_data,
    ops=[
        ReadImage(inputs=["A", "B"], outputs=["A", "B"]),
        Normalize(inputs=["A", "B"], outputs=["real_A", "real_B"], mean=1.0, std=1.0, max_pixel_value=127.5),
        Resize(height=286, width=286, image_in="real_A", image_out="real_A", mode="train"),
        RandomCrop(height=256, width=256, image_in="real_A", image_out="real_A", mode="train"),
        Resize(height=286, width=286, image_in="real_B", image_out="real_B", mode="train"),
        RandomCrop(height=256, width=256, image_in="real_B", image_out="real_B", mode="train"),
        Sometimes(HorizontalFlip(image_in="real_A", image_out="real_A", mode="train")),
        Sometimes(HorizontalFlip(image_in="real_B", image_out="real_B", mode="train")),
        ChannelTranspose(inputs=["real_A", "real_B"], outputs=["real_A", "real_B"]),
        Delete(keys=["A", "B"])
    ])
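The Normalize op above maps raw pixel values into [-1, 1], matching the tanh output of the generators. Assuming the standard normalization convention output = (x - mean * max_pixel_value) / (std * max_pixel_value), the chosen parameters reduce to x / 127.5 - 1 (a quick check, not part of the original notebook):
import numpy as np
# With mean=1.0, std=1.0, max_pixel_value=127.5:
x = np.array([0.0, 127.5, 255.0])
print((x - 1.0 * 127.5) / (1.0 * 127.5))  # [-1.  0.  1.]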
We can visualize sample images from the pipeline using the get_results method.
def Minmax(img):
    """Rescale an image to the [0, 255] uint8 range for display."""
    img_max = np.max(img)
    img_min = np.min(img)
    img = (img - img_min) / max((img_max - img_min), 1e-7)
    img = (img * 255).astype(np.uint8)
    return img
sample_batch = pipeline.get_results()
horse_img = sample_batch["real_A"][0]
horse_img = np.transpose(horse_img.numpy(), (1, 2, 0))
horse_img = np.expand_dims(Minmax(horse_img), 0)
zebra_img = sample_batch["real_B"][0]
zebra_img = np.transpose(zebra_img.numpy(), (1, 2, 0))
zebra_img = np.expand_dims(Minmax(zebra_img), 0)
GridDisplay([ImageDisplay(image=horse_img[0], title="Horse"),
             ImageDisplay(image=zebra_img[0], title="Zebra")
             ]).show()
Step 2: Create Network¶
In CycleGAN, two generators and two discriminators are trained:
- Generator g_AtoB learns to map horse images to zebra images
- Generator g_BtoA learns to map zebra images to horse images
- Discriminator d_A learns to differentiate between real horse images and fake horse images produced by g_BtoA
- Discriminator d_B learns to differentiate between real zebra images and fake zebra images produced by g_AtoB

The generator architecture is a modified ResNet, and the discriminator architecture is a PatchGAN.
class ResidualBlock(nn.Module):
    """Residual block architecture"""
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        self.layers = nn.Sequential(nn.ReflectionPad2d(1),
                                    nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size),
                                    nn.InstanceNorm2d(out_channels),
                                    nn.ReLU(),
                                    nn.ReflectionPad2d(1),
                                    nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size),
                                    nn.InstanceNorm2d(out_channels))
        for layer in self.layers:
            if isinstance(layer, nn.Conv2d):
                normal_(layer.weight.data, mean=0, std=0.02)

    def forward(self, x):
        x_out = self.layers(x)
        x_out = x_out + x
        return x_out


class Discriminator(nn.Module):
    """Discriminator network architecture (PatchGAN)"""
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
                                    nn.LeakyReLU(0.2),
                                    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                                    nn.InstanceNorm2d(128),
                                    nn.LeakyReLU(0.2),
                                    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
                                    nn.InstanceNorm2d(256),
                                    nn.LeakyReLU(0.2),
                                    nn.ReflectionPad2d(1),
                                    nn.Conv2d(256, 512, kernel_size=4, stride=1),
                                    nn.InstanceNorm2d(512),
                                    nn.LeakyReLU(0.2),
                                    nn.ReflectionPad2d(1),
                                    nn.Conv2d(512, 1, kernel_size=4, stride=1))
        for layer in self.layers:
            if isinstance(layer, nn.Conv2d):
                normal_(layer.weight.data, mean=0, std=0.02)

    def forward(self, x):
        x = self.layers(x)
        return x


class Generator(nn.Module):
    """Generator network architecture (modified ResNet)"""
    def __init__(self, num_blocks=9):
        super().__init__()
        self.layers1 = nn.Sequential(nn.ReflectionPad2d(3),
                                     nn.Conv2d(3, 64, kernel_size=7, stride=1),
                                     nn.InstanceNorm2d(64),
                                     nn.ReLU(),
                                     nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                                     nn.InstanceNorm2d(128),
                                     nn.ReLU(),
                                     nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                                     nn.InstanceNorm2d(256),
                                     nn.ReLU())
        self.resblocks = nn.Sequential(*[ResidualBlock(256, 256) for i in range(num_blocks)])
        self.layers2 = nn.Sequential(nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                                     nn.InstanceNorm2d(128),
                                     nn.ReLU(),
                                     nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                                     nn.InstanceNorm2d(64),
                                     nn.ReLU(),
                                     nn.ReflectionPad2d(3),
                                     nn.Conv2d(64, 3, kernel_size=7, stride=1))
        for block in [self.layers1, self.layers2]:
            for layer in block:
                if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                    normal_(layer.weight.data, mean=0, std=0.02)

    def forward(self, x):
        x = self.layers1(x)
        x = self.resblocks(x)
        x = self.layers2(x)
        x = torch.tanh(x)
        return x
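As a minimal sanity check (illustrative, not part of the original notebook), we can push a dummy batch through both networks: the generator preserves the 256x256 input resolution, while the PatchGAN discriminator reduces it to a 30x30 grid of per-patch real/fake predictions.
# Shape check on a random 256x256 RGB batch
with torch.no_grad():
    dummy = torch.randn(1, 3, 256, 256)
    print(Generator()(dummy).shape)      # torch.Size([1, 3, 256, 256])
    print(Discriminator()(dummy).shape)  # torch.Size([1, 1, 30, 30])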
g_AtoB = fe.build(model_fn=Generator,
                  model_name="g_AtoB",
                  optimizer_fn=lambda x: torch.optim.Adam(x, lr=2e-4, betas=(0.5, 0.999)))
g_BtoA = fe.build(model_fn=Generator,
                  model_name="g_BtoA",
                  optimizer_fn=lambda x: torch.optim.Adam(x, lr=2e-4, betas=(0.5, 0.999)))
d_A = fe.build(model_fn=Discriminator,
               model_name="d_A",
               optimizer_fn=lambda x: torch.optim.Adam(x, lr=2e-4, betas=(0.5, 0.999)))
d_B = fe.build(model_fn=Discriminator,
               model_name="d_B",
               optimizer_fn=lambda x: torch.optim.Adam(x, lr=2e-4, betas=(0.5, 0.999)))
Defining Loss functions¶
For each network, we need to define the associated losses.
Because the horse and zebra images are unpaired, the generator's loss is more involved than a plain adversarial loss. It is composed of three terms:
- adversarial
- cycle-consistency
- identity

The cycle-consistency and identity terms are weighted by a parameter LAMBDA; the authors used 10 for LAMBDA in the paper.
Let's consider computing the loss for g_AtoB, which translates horses to zebras:
- The adversarial term is computed as the mean squared error (the least-squares GAN loss used in the code below) between ones and d_B's predictions on the translated images.
- The cycle-consistency term is computed as the mean absolute error between the original horse images and the cycled horse images, which are translated forward by g_AtoB and then backward by g_BtoA.
- The identity term is computed as the mean absolute error between the original zebra images and the output of g_AtoB on those images.

The discriminator's loss is the standard adversarial loss (again in least-squares form), computed as the mean squared error between:
- Ones and the discriminator's predictions on real images
- Zeros and its predictions on fake images
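In equation form (a sketch of the objectives implemented by the GLoss and DLoss ops below, where $a$ is a horse image, $b$ is a zebra image, $\lambda$ is LAMBDA, the discriminators see fakes drawn from the history buffer introduced below, and averaging over pixels and discriminator patches is left implicit):

$$\mathcal{L}_{\text{gen}} = \big(d_B(g_{AtoB}(a)) - 1\big)^2 + \big(d_A(g_{BtoA}(b)) - 1\big)^2 + \lambda\big(\lVert g_{BtoA}(g_{AtoB}(a)) - a\rVert_1 + \lVert g_{AtoB}(g_{BtoA}(b)) - b\rVert_1\big) + 0.5\,\lambda\big(\lVert g_{AtoB}(b) - b\rVert_1 + \lVert g_{BtoA}(a) - a\rVert_1\big)$$

$$\mathcal{L}_{d_B} = 0.5\big[\big(d_B(b) - 1\big)^2 + d_B(g_{AtoB}(a))^2\big], \qquad \mathcal{L}_{d_A} = 0.5\big[\big(d_A(a) - 1\big)^2 + d_A(g_{BtoA}(b))^2\big]$$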
from fastestimator.op.tensorop import TensorOp
class GLoss(TensorOp):
    """TensorOp to compute generator loss"""
    def __init__(self, inputs, weight, outputs=None, mode=None, average_loss=True):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode)
        self.loss_fn = nn.MSELoss(reduction="none")
        self.LAMBDA = weight
        self.average_loss = average_loss

    def _adversarial_loss(self, fake_img):
        return torch.mean(self.loss_fn(fake_img, torch.ones_like(fake_img)), dim=(2, 3))

    def _identity_loss(self, real_img, same_img):
        return 0.5 * self.LAMBDA * torch.mean(torch.abs(real_img - same_img), dim=(1, 2, 3))

    def _cycle_loss(self, real_img, cycled_img):
        return self.LAMBDA * torch.mean(torch.abs(real_img - cycled_img), dim=(1, 2, 3))

    def forward(self, data, state):
        real_img, fake_img, cycled_img, same_img = data
        total_loss = self._adversarial_loss(fake_img) + self._identity_loss(real_img, same_img) + self._cycle_loss(
            real_img, cycled_img)
        if self.average_loss:
            total_loss = reduce_mean(total_loss)
        return total_loss
class DLoss(TensorOp):
    """TensorOp to compute discriminator loss"""
    def __init__(self, inputs, outputs=None, mode=None, average_loss=True):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode)
        self.loss_fn = nn.MSELoss(reduction="none")
        self.average_loss = average_loss

    def forward(self, data, state):
        # `real_img` and `fake_img` are the discriminator's patch predictions
        # on real and fake images, not the images themselves
        real_img, fake_img = data
        real_img_loss = torch.mean(self.loss_fn(real_img, torch.ones_like(real_img)), dim=(2, 3))
        fake_img_loss = torch.mean(self.loss_fn(fake_img, torch.zeros_like(real_img)), dim=(2, 3))
        total_loss = real_img_loss + fake_img_loss
        if self.average_loss:
            total_loss = reduce_mean(total_loss)
        return 0.5 * total_loss
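The loss ops can be smoke-tested outside of a Network (a hypothetical standalone call, not part of the original notebook; during training their inputs come from the ModelOps below, and the discriminator outputs are 30x30 patch maps):
g_loss_op = GLoss(inputs=("real_A", "d_fake_B", "cycled_A", "same_A"), weight=10.0, outputs="g_AtoB_loss")
d_loss_op = DLoss(inputs=("d_real_A", "buffer_d_fake_A"), outputs="d_A_loss")
real = torch.rand(1, 3, 256, 256)   # stand-in for a real image batch
d_out = torch.rand(1, 1, 30, 30)    # stand-in for discriminator patch outputs
print(g_loss_op.forward((real, d_out, torch.rand_like(real), torch.rand_like(real)), state={}))  # scalar tensor
print(d_loss_op.forward((d_out, torch.rand_like(d_out)), state={}))                              # scalar tensor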
We implement an image buffer as a TensorOp, which stores previous images produced by the generators and uses them to update the discriminators, as outlined in Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks.
class Buffer(TensorOp):
    """An image history buffer: stores up to `buffer_size` previously generated
    images and randomly serves them in place of the newest fakes."""
    def __init__(self, image_in=None, image_out=None, mode=None, buffer_size=50):
        super().__init__(inputs=image_in, outputs=image_out, mode=mode)
        self.buffer_size = buffer_size
        self.num_imgs = 0
        self.image_buffer = []

    def forward(self, data, state):
        output = []
        for image in data:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.buffer_size:
                # Fill the buffer first, passing images straight through
                self.image_buffer.append(image)
                output.append(image)
                self.num_imgs += 1
            else:
                if np.random.uniform() > 0.5:
                    # Swap the incoming image with a randomly chosen stored one
                    idx = np.random.randint(self.buffer_size)
                    temp = self.image_buffer[idx].clone()
                    self.image_buffer[idx] = image
                    output.append(temp)
                else:
                    output.append(image)
        output = torch.cat(output, 0)
        return output
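To illustrate the behavior (a hypothetical standalone call, not part of the original notebook; within the Network the op receives the generator outputs): the first buffer_size images pass straight through while filling the buffer, and afterwards each incoming image is either returned unchanged or swapped with a randomly chosen stored image with 50% probability.
buf = Buffer(image_in="fake_A", image_out="buffer_fake_A", buffer_size=2)
imgs = torch.randn(4, 3, 8, 8)
out = buf.forward(imgs, state={})
print(out.shape)  # torch.Size([4, 3, 8, 8]); the last two rows may come from the buffer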
With the associated losses defined, we can now define the Network object.
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
network = fe.Network(ops=[
    ModelOp(inputs="real_A", model=g_AtoB, outputs="fake_B"),
    ModelOp(inputs="real_B", model=g_BtoA, outputs="fake_A"),
    Buffer(image_in="fake_A", image_out="buffer_fake_A"),
    Buffer(image_in="fake_B", image_out="buffer_fake_B"),
    ModelOp(inputs="real_A", model=d_A, outputs="d_real_A"),
    ModelOp(inputs="fake_A", model=d_A, outputs="d_fake_A"),
    ModelOp(inputs="buffer_fake_A", model=d_A, outputs="buffer_d_fake_A"),
    ModelOp(inputs="real_B", model=d_B, outputs="d_real_B"),
    ModelOp(inputs="fake_B", model=d_B, outputs="d_fake_B"),
    ModelOp(inputs="buffer_fake_B", model=d_B, outputs="buffer_d_fake_B"),
    ModelOp(inputs="real_A", model=g_BtoA, outputs="same_A"),
    ModelOp(inputs="fake_B", model=g_BtoA, outputs="cycled_A"),
    ModelOp(inputs="real_B", model=g_AtoB, outputs="same_B"),
    ModelOp(inputs="fake_A", model=g_AtoB, outputs="cycled_B"),
    GLoss(inputs=("real_A", "d_fake_B", "cycled_A", "same_A"), weight=weight, outputs="g_AtoB_loss"),
    GLoss(inputs=("real_B", "d_fake_A", "cycled_B", "same_B"), weight=weight, outputs="g_BtoA_loss"),
    DLoss(inputs=("d_real_A", "buffer_d_fake_A"), outputs="d_A_loss"),
    DLoss(inputs=("d_real_B", "buffer_d_fake_B"), outputs="d_B_loss"),
    UpdateOp(model=g_AtoB, loss_name="g_AtoB_loss"),
    UpdateOp(model=g_BtoA, loss_name="g_BtoA_loss"),
    UpdateOp(model=d_A, loss_name="d_A_loss"),
    UpdateOp(model=d_B, loss_name="d_B_loss")
])
Here, we use a linear learning rate decay for training: the learning rate is held at 2e-4 for the first 100 epochs and then decays linearly, reaching zero by epoch 200.
def lr_schedule(epoch):
    """Hold the learning rate at 2e-4 for 100 epochs, then decay it linearly to 0 at epoch 200."""
    if epoch <= 100:
        lr = 2e-4
    else:
        lr = 2e-4 * (200 - epoch) / 100
    return lr
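A quick check of the schedule at a few milestones:
print(lr_schedule(1), lr_schedule(100), lr_schedule(150), lr_schedule(200))
# 0.0002 0.0002 0.0001 0.0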
In this example we will use ModelSaver traces to save the two generators g_AtoB and g_BtoA every 10 epochs, and LRScheduler traces to update each model's learning rate.
from fastestimator.trace.adapt import LRScheduler
from fastestimator.trace.io import ModelSaver
traces = [
    ModelSaver(model=g_AtoB, save_dir=save_dir, frequency=10),
    ModelSaver(model=g_BtoA, save_dir=save_dir, frequency=10),
    LRScheduler(model=g_AtoB, lr_fn=lr_schedule),
    LRScheduler(model=g_BtoA, lr_fn=lr_schedule),
    LRScheduler(model=d_A, lr_fn=lr_schedule),
    LRScheduler(model=d_B, lr_fn=lr_schedule)
]
Step 3: Estimator¶
Finally, we are ready to define the Estimator object and call its fit method to start training.
For demo purposes, the run logged below used only 50 epochs.
estimator = fe.Estimator(network=network,
                         pipeline=pipeline,
                         epochs=epochs,
                         traces=traces,
                         log_steps=1000,
                         train_steps_per_epoch=train_steps_per_epoch)
Training¶
estimator.fit()
    ______           __  ______     __  _                 __
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/

FastEstimator-Start: step: 1; logging_interval: 1000; num_device: 1;
FastEstimator-Train: step: 1; d_A_loss: 3.2108502; d_A_lr: 0.0002; d_B_loss: 3.1976771; d_B_lr: 0.0002; g_AtoB_loss: 13.665354; g_AtoB_lr: 0.0002; g_BtoA_loss: 13.892114; g_BtoA_lr: 0.0002;
FastEstimator-Train: step: 1000; d_A_loss: 0.010945682; d_A_lr: 0.0002; d_B_loss: 0.3114647; d_B_lr: 0.0002; g_AtoB_loss: 15.73657; g_AtoB_lr: 0.0002; g_BtoA_loss: 9.770332; g_BtoA_lr: 0.0002; steps/sec: 10.26;
FastEstimator-Train: step: 1334; epoch: 1; epoch_time: 138.9 sec;
FastEstimator-Train: step: 2000; d_A_loss: 0.21845323; d_A_lr: 0.0002; d_B_loss: 0.3194074; d_B_lr: 0.0002; g_AtoB_loss: 7.026337; g_AtoB_lr: 0.0002; g_BtoA_loss: 9.131756; g_BtoA_lr: 0.0002; steps/sec: 9.05;
FastEstimator-Train: step: 2668; epoch: 2; epoch_time: 143.35 sec;
FastEstimator-Train: step: 3000; d_A_loss: 0.2580701; d_A_lr: 0.0002; d_B_loss: 0.43598607; d_B_lr: 0.0002; g_AtoB_loss: 4.68789; g_AtoB_lr: 0.0002; g_BtoA_loss: 8.580699; g_BtoA_lr: 0.0002; steps/sec: 8.18;
FastEstimator-Train: step: 4000; d_A_loss: 0.12693632; d_A_lr: 0.0002; d_B_loss: 0.04536821; d_B_lr: 0.0002; g_AtoB_loss: 5.3422995; g_AtoB_lr: 0.0002; g_BtoA_loss: 7.5210814; g_BtoA_lr: 0.0002; steps/sec: 8.7;
FastEstimator-Train: step: 4002; epoch: 3; epoch_time: 169.88 sec;
FastEstimator-Train: step: 5000; d_A_loss: 0.24568222; d_A_lr: 0.0002; d_B_loss: 0.04645998; d_B_lr: 0.0002; g_AtoB_loss: 6.6831093; g_AtoB_lr: 0.0002; g_BtoA_loss: 7.0778866; g_BtoA_lr: 0.0002; steps/sec: 8.15;
FastEstimator-Train: step: 5336; epoch: 4; epoch_time: 159.61 sec;
FastEstimator-Train: step: 6000; d_A_loss: 0.16322158; d_A_lr: 0.0002; d_B_loss: 0.019150853; d_B_lr: 0.0002; g_AtoB_loss: 7.7576237; g_AtoB_lr: 0.0002; g_BtoA_loss: 10.1092415; g_BtoA_lr: 0.0002; steps/sec: 9.25;
FastEstimator-Train: step: 6670; epoch: 5; epoch_time: 135.78 sec;
FastEstimator-Train: step: 7000; d_A_loss: 0.10134612; d_A_lr: 0.0002; d_B_loss: 0.17742005; d_B_lr: 0.0002; g_AtoB_loss: 7.486726; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.1879535; g_BtoA_lr: 0.0002; steps/sec: 8.96;
FastEstimator-Train: step: 8000; d_A_loss: 0.13598159; d_A_lr: 0.0002; d_B_loss: 0.06359689; d_B_lr: 0.0002; g_AtoB_loss: 5.198177; g_AtoB_lr: 0.0002; g_BtoA_loss: 5.1237717; g_BtoA_lr: 0.0002; steps/sec: 10.15;
FastEstimator-Train: step: 8004; epoch: 6; epoch_time: 146.04 sec;
FastEstimator-Train: step: 9000; d_A_loss: 0.24748227; d_A_lr: 0.0002; d_B_loss: 0.08985389; d_B_lr: 0.0002; g_AtoB_loss: 7.138328; g_AtoB_lr: 0.0002; g_BtoA_loss: 12.837671; g_BtoA_lr: 0.0002; steps/sec: 9.73;
FastEstimator-Train: step: 9338; epoch: 7; epoch_time: 132.51 sec;
FastEstimator-Train: step: 10000; d_A_loss: 0.28274047; d_A_lr: 0.0002; d_B_loss: 0.23293617; d_B_lr: 0.0002; g_AtoB_loss: 3.688202; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.485837; g_BtoA_lr: 0.0002; steps/sec: 9.29;
FastEstimator-Train: step: 10672; epoch: 8; epoch_time: 155.15 sec;
FastEstimator-Train: step: 11000; d_A_loss: 0.03592102; d_A_lr: 0.0002; d_B_loss: 0.2748941; d_B_lr: 0.0002; g_AtoB_loss: 4.3695536; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.2612844; g_BtoA_lr: 0.0002; steps/sec: 8.29;
FastEstimator-Train: step: 12000; d_A_loss: 0.05218474; d_A_lr: 0.0002; d_B_loss: 0.017800268; d_B_lr: 0.0002; g_AtoB_loss: 19.132637; g_AtoB_lr: 0.0002; g_BtoA_loss: 17.503262; g_BtoA_lr: 0.0002; steps/sec: 11.89;
FastEstimator-Train: step: 12006; epoch: 9; epoch_time: 126.77 sec;
FastEstimator-Train: step: 13000; d_A_loss: 0.13947828; d_A_lr: 0.0002; d_B_loss: 0.17353384; d_B_lr: 0.0002; g_AtoB_loss: 3.1774468; g_AtoB_lr: 0.0002; g_BtoA_loss: 6.9188337; g_BtoA_lr: 0.0002; steps/sec: 11.05;
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_AtoB_epoch_10.pt
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_BtoA_epoch_10.pt
FastEstimator-Train: step: 13340; epoch: 10; epoch_time: 118.57 sec;
FastEstimator-Train: step: 14000; d_A_loss: 0.18080026; d_A_lr: 0.0002; d_B_loss: 0.10328345; d_B_lr: 0.0002; g_AtoB_loss: 4.9533176; g_AtoB_lr: 0.0002; g_BtoA_loss: 7.5026536; g_BtoA_lr: 0.0002; steps/sec: 11.06;
FastEstimator-Train: step: 14674; epoch: 11; epoch_time: 119.07 sec;
FastEstimator-Train: step: 15000; d_A_loss: 0.0076176217; d_A_lr: 0.0002; d_B_loss: 0.10207303; d_B_lr: 0.0002; g_AtoB_loss: 16.60973; g_AtoB_lr: 0.0002; g_BtoA_loss: 13.892546; g_BtoA_lr: 0.0002; steps/sec: 11.05;
FastEstimator-Train: step: 16000; d_A_loss: 0.18173808; d_A_lr: 0.0002; d_B_loss: 0.052283786; d_B_lr: 0.0002; g_AtoB_loss: 6.226133; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.731463; g_BtoA_lr: 0.0002; steps/sec: 12.0;
FastEstimator-Train: step: 16008; epoch: 12; epoch_time: 117.27 sec;
FastEstimator-Train: step: 17000; d_A_loss: 0.11598459; d_A_lr: 0.0002; d_B_loss: 0.121301845; d_B_lr: 0.0002; g_AtoB_loss: 4.3975754; g_AtoB_lr: 0.0002; g_BtoA_loss: 6.546061; g_BtoA_lr: 0.0002; steps/sec: 11.02;
FastEstimator-Train: step: 17342; epoch: 13; epoch_time: 118.98 sec;
FastEstimator-Train: step: 18000; d_A_loss: 0.15181227; d_A_lr: 0.0002; d_B_loss: 0.16069743; d_B_lr: 0.0002; g_AtoB_loss: 6.0471025; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.4856486; g_BtoA_lr: 0.0002; steps/sec: 11.05;
FastEstimator-Train: step: 18676; epoch: 14; epoch_time: 118.49 sec;
FastEstimator-Train: step: 19000; d_A_loss: 0.20692636; d_A_lr: 0.0002; d_B_loss: 0.07573674; d_B_lr: 0.0002; g_AtoB_loss: 4.6534934; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.594749; g_BtoA_lr: 0.0002; steps/sec: 11.02;
FastEstimator-Train: step: 20000; d_A_loss: 0.21789747; d_A_lr: 0.0002; d_B_loss: 0.21549332; d_B_lr: 0.0002; g_AtoB_loss: 4.447349; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.937718; g_BtoA_lr: 0.0002; steps/sec: 11.97;
FastEstimator-Train: step: 20010; epoch: 15; epoch_time: 118.25 sec;
FastEstimator-Train: step: 21000; d_A_loss: 0.28535625; d_A_lr: 0.0002; d_B_loss: 0.22555062; d_B_lr: 0.0002; g_AtoB_loss: 2.8892808; g_AtoB_lr: 0.0002; g_BtoA_loss: 5.4860334; g_BtoA_lr: 0.0002; steps/sec: 11.08;
FastEstimator-Train: step: 21344; epoch: 16; epoch_time: 118.57 sec;
FastEstimator-Train: step: 22000; d_A_loss: 0.19136402; d_A_lr: 0.0002; d_B_loss: 0.109166086; d_B_lr: 0.0002; g_AtoB_loss: 3.6108365; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.6259403; g_BtoA_lr: 0.0002; steps/sec: 11.03;
FastEstimator-Train: step: 22678; epoch: 17; epoch_time: 118.98 sec;
FastEstimator-Train: step: 23000; d_A_loss: 0.15124202; d_A_lr: 0.0002; d_B_loss: 0.1616896; d_B_lr: 0.0002; g_AtoB_loss: 3.598625; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.349188; g_BtoA_lr: 0.0002; steps/sec: 11.0;
FastEstimator-Train: step: 24000; d_A_loss: 0.16977733; d_A_lr: 0.0002; d_B_loss: 0.06550271; d_B_lr: 0.0002; g_AtoB_loss: 6.7928667; g_AtoB_lr: 0.0002; g_BtoA_loss: 5.1823797; g_BtoA_lr: 0.0002; steps/sec: 11.98;
FastEstimator-Train: step: 24012; epoch: 18; epoch_time: 117.98 sec;
FastEstimator-Train: step: 25000; d_A_loss: 0.15378422; d_A_lr: 0.0002; d_B_loss: 0.08358267; d_B_lr: 0.0002; g_AtoB_loss: 6.564144; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.2897534; g_BtoA_lr: 0.0002; steps/sec: 11.06;
FastEstimator-Train: step: 25346; epoch: 19; epoch_time: 118.51 sec;
FastEstimator-Train: step: 26000; d_A_loss: 0.0070417607; d_A_lr: 0.0002; d_B_loss: 0.12145958; d_B_lr: 0.0002; g_AtoB_loss: 15.752372; g_AtoB_lr: 0.0002; g_BtoA_loss: 6.0002174; g_BtoA_lr: 0.0002; steps/sec: 11.06;
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_AtoB_epoch_20.pt
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_BtoA_epoch_20.pt
FastEstimator-Train: step: 26680; epoch: 20; epoch_time: 118.79 sec;
FastEstimator-Train: step: 27000; d_A_loss: 0.0622226; d_A_lr: 0.0002; d_B_loss: 0.04732838; d_B_lr: 0.0002; g_AtoB_loss: 10.552514; g_AtoB_lr: 0.0002; g_BtoA_loss: 9.402086; g_BtoA_lr: 0.0002; steps/sec: 11.01;
FastEstimator-Train: step: 28000; d_A_loss: 0.028189775; d_A_lr: 0.0002; d_B_loss: 0.13978265; d_B_lr: 0.0002; g_AtoB_loss: 10.884317; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.892234; g_BtoA_lr: 0.0002; steps/sec: 11.87;
FastEstimator-Train: step: 28014; epoch: 21; epoch_time: 118.8 sec;
FastEstimator-Train: step: 29000; d_A_loss: 0.09849219; d_A_lr: 0.0002; d_B_loss: 0.124093324; d_B_lr: 0.0002; g_AtoB_loss: 5.1369505; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.663924; g_BtoA_lr: 0.0002; steps/sec: 11.02;
FastEstimator-Train: step: 29348; epoch: 22; epoch_time: 118.78 sec;
FastEstimator-Train: step: 30000; d_A_loss: 0.12628208; d_A_lr: 0.0002; d_B_loss: 0.10195969; d_B_lr: 0.0002; g_AtoB_loss: 2.6797104; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.2230177; g_BtoA_lr: 0.0002; steps/sec: 11.14;
FastEstimator-Train: step: 30682; epoch: 23; epoch_time: 123.55 sec;
FastEstimator-Train: step: 31000; d_A_loss: 0.15682167; d_A_lr: 0.0002; d_B_loss: 0.32548973; d_B_lr: 0.0002; g_AtoB_loss: 4.840435; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.4905968; g_BtoA_lr: 0.0002; steps/sec: 4.6;
FastEstimator-Train: step: 32000; d_A_loss: 0.11451645; d_A_lr: 0.0002; d_B_loss: 0.06042458; d_B_lr: 0.0002; g_AtoB_loss: 3.4695773; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.142742; g_BtoA_lr: 0.0002; steps/sec: 3.78;
FastEstimator-Train: step: 32016; epoch: 24; epoch_time: 430.36 sec;
FastEstimator-Train: step: 33000; d_A_loss: 0.29871425; d_A_lr: 0.0002; d_B_loss: 0.16944641; d_B_lr: 0.0002; g_AtoB_loss: 4.16166; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.8225899; g_BtoA_lr: 0.0002; steps/sec: 2.83;
FastEstimator-Train: step: 33350; epoch: 25; epoch_time: 499.05 sec;
FastEstimator-Train: step: 34000; d_A_loss: 0.114286; d_A_lr: 0.0002; d_B_loss: 0.109651625; d_B_lr: 0.0002; g_AtoB_loss: 4.5099387; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.6966114; g_BtoA_lr: 0.0002; steps/sec: 2.77;
FastEstimator-Train: step: 34684; epoch: 26; epoch_time: 396.64 sec;
FastEstimator-Train: step: 35000; d_A_loss: 0.07977181; d_A_lr: 0.0002; d_B_loss: 0.21154696; d_B_lr: 0.0002; g_AtoB_loss: 4.3349495; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.3471603; g_BtoA_lr: 0.0002; steps/sec: 3.37;
FastEstimator-Train: step: 36000; d_A_loss: 0.16729036; d_A_lr: 0.0002; d_B_loss: 0.062212124; d_B_lr: 0.0002; g_AtoB_loss: 4.9130187; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.848304; g_BtoA_lr: 0.0002; steps/sec: 3.76;
FastEstimator-Train: step: 36018; epoch: 27; epoch_time: 372.0 sec;
FastEstimator-Train: step: 37000; d_A_loss: 0.2021332; d_A_lr: 0.0002; d_B_loss: 0.13457865; d_B_lr: 0.0002; g_AtoB_loss: 2.7795298; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.057055; g_BtoA_lr: 0.0002; steps/sec: 3.65;
FastEstimator-Train: step: 37352; epoch: 28; epoch_time: 385.47 sec;
FastEstimator-Train: step: 38000; d_A_loss: 0.41481462; d_A_lr: 0.0002; d_B_loss: 0.0745297; d_B_lr: 0.0002; g_AtoB_loss: 4.2751684; g_AtoB_lr: 0.0002; g_BtoA_loss: 6.6050787; g_BtoA_lr: 0.0002; steps/sec: 3.17;
FastEstimator-Train: step: 38686; epoch: 29; epoch_time: 389.7 sec;
FastEstimator-Train: step: 39000; d_A_loss: 0.2179724; d_A_lr: 0.0002; d_B_loss: 0.07892369; d_B_lr: 0.0002; g_AtoB_loss: 2.9368668; g_AtoB_lr: 0.0002; g_BtoA_loss: 5.7667227; g_BtoA_lr: 0.0002; steps/sec: 3.14;
FastEstimator-Train: step: 40000; d_A_loss: 0.38974488; d_A_lr: 0.0002; d_B_loss: 0.10798988; d_B_lr: 0.0002; g_AtoB_loss: 4.255815; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.5183225; g_BtoA_lr: 0.0002; steps/sec: 3.84;
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_AtoB_epoch_30.pt
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_BtoA_epoch_30.pt
FastEstimator-Train: step: 40020; epoch: 30; epoch_time: 402.99 sec;
FastEstimator-Train: step: 41000; d_A_loss: 0.37121627; d_A_lr: 0.0002; d_B_loss: 0.11034091; d_B_lr: 0.0002; g_AtoB_loss: 2.4258504; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.009646; g_BtoA_lr: 0.0002; steps/sec: 2.78;
FastEstimator-Train: step: 41354; epoch: 31; epoch_time: 414.77 sec;
FastEstimator-Train: step: 42000; d_A_loss: 0.114548594; d_A_lr: 0.0002; d_B_loss: 0.248036; d_B_lr: 0.0002; g_AtoB_loss: 5.9243593; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.291023; g_BtoA_lr: 0.0002; steps/sec: 3.64;
FastEstimator-Train: step: 42688; epoch: 32; epoch_time: 440.97 sec;
FastEstimator-Train: step: 43000; d_A_loss: 0.08480996; d_A_lr: 0.0002; d_B_loss: 0.018437251; d_B_lr: 0.0002; g_AtoB_loss: 6.847295; g_AtoB_lr: 0.0002; g_BtoA_loss: 9.564554; g_BtoA_lr: 0.0002; steps/sec: 3.18;
FastEstimator-Train: step: 44000; d_A_loss: 0.118619055; d_A_lr: 0.0002; d_B_loss: 0.13077882; d_B_lr: 0.0002; g_AtoB_loss: 2.9808536; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.238474; g_BtoA_lr: 0.0002; steps/sec: 2.77;
FastEstimator-Train: step: 44022; epoch: 33; epoch_time: 444.05 sec;
FastEstimator-Train: step: 45000; d_A_loss: 0.16026257; d_A_lr: 0.0002; d_B_loss: 0.07697069; d_B_lr: 0.0002; g_AtoB_loss: 3.919404; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.2792253; g_BtoA_lr: 0.0002; steps/sec: 3.35;
FastEstimator-Train: step: 45356; epoch: 34; epoch_time: 396.14 sec;
FastEstimator-Train: step: 46000; d_A_loss: 0.045636296; d_A_lr: 0.0002; d_B_loss: 0.07664771; d_B_lr: 0.0002; g_AtoB_loss: 3.9265654; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.7000437; g_BtoA_lr: 0.0002; steps/sec: 3.41;
FastEstimator-Train: step: 46690; epoch: 35; epoch_time: 379.74 sec;
FastEstimator-Train: step: 47000; d_A_loss: 0.16862245; d_A_lr: 0.0002; d_B_loss: 0.05204377; d_B_lr: 0.0002; g_AtoB_loss: 4.0016527; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.7975125; g_BtoA_lr: 0.0002; steps/sec: 4.13;
FastEstimator-Train: step: 48000; d_A_loss: 0.17546852; d_A_lr: 0.0002; d_B_loss: 0.27352092; d_B_lr: 0.0002; g_AtoB_loss: 2.6668086; g_AtoB_lr: 0.0002; g_BtoA_loss: 2.9398174; g_BtoA_lr: 0.0002; steps/sec: 3.32;
FastEstimator-Train: step: 48024; epoch: 36; epoch_time: 359.12 sec;
FastEstimator-Train: step: 49000; d_A_loss: 0.09111771; d_A_lr: 0.0002; d_B_loss: 0.09417858; d_B_lr: 0.0002; g_AtoB_loss: 3.3456988; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.4651456; g_BtoA_lr: 0.0002; steps/sec: 2.64;
FastEstimator-Train: step: 49358; epoch: 37; epoch_time: 447.61 sec;
FastEstimator-Train: step: 50000; d_A_loss: 0.11524596; d_A_lr: 0.0002; d_B_loss: 0.12104061; d_B_lr: 0.0002; g_AtoB_loss: 9.68929; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.233226; g_BtoA_lr: 0.0002; steps/sec: 2.92;
FastEstimator-Train: step: 50692; epoch: 38; epoch_time: 529.96 sec;
FastEstimator-Train: step: 51000; d_A_loss: 0.105448775; d_A_lr: 0.0002; d_B_loss: 0.07446474; d_B_lr: 0.0002; g_AtoB_loss: 2.8403997; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.0699873; g_BtoA_lr: 0.0002; steps/sec: 3.09;
FastEstimator-Train: step: 52000; d_A_loss: 0.32795072; d_A_lr: 0.0002; d_B_loss: 0.2948647; d_B_lr: 0.0002; g_AtoB_loss: 2.2968879; g_AtoB_lr: 0.0002; g_BtoA_loss: 2.3472998; g_BtoA_lr: 0.0002; steps/sec: 2.89;
FastEstimator-Train: step: 52026; epoch: 39; epoch_time: 414.88 sec;
FastEstimator-Train: step: 53000; d_A_loss: 0.14258908; d_A_lr: 0.0002; d_B_loss: 0.09975105; d_B_lr: 0.0002; g_AtoB_loss: 4.417849; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.8894544; g_BtoA_lr: 0.0002; steps/sec: 2.64;
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_AtoB_epoch_40.pt
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_BtoA_epoch_40.pt
FastEstimator-Train: step: 53360; epoch: 40; epoch_time: 474.55 sec;
FastEstimator-Train: step: 54000; d_A_loss: 0.40623146; d_A_lr: 0.0002; d_B_loss: 0.13360925; d_B_lr: 0.0002; g_AtoB_loss: 3.5933185; g_AtoB_lr: 0.0002; g_BtoA_loss: 8.165068; g_BtoA_lr: 0.0002; steps/sec: 2.67;
FastEstimator-Train: step: 54694; epoch: 41; epoch_time: 488.97 sec;
FastEstimator-Train: step: 55000; d_A_loss: 0.32001668; d_A_lr: 0.0002; d_B_loss: 0.19170062; d_B_lr: 0.0002; g_AtoB_loss: 5.8548794; g_AtoB_lr: 0.0002; g_BtoA_loss: 2.9984937; g_BtoA_lr: 0.0002; steps/sec: 2.69;
FastEstimator-Train: step: 56000; d_A_loss: 0.18761483; d_A_lr: 0.0002; d_B_loss: 0.1781841; d_B_lr: 0.0002; g_AtoB_loss: 6.030718; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.817857; g_BtoA_lr: 0.0002; steps/sec: 3.07;
FastEstimator-Train: step: 56028; epoch: 42; epoch_time: 507.62 sec;
FastEstimator-Train: step: 57000; d_A_loss: 0.21374495; d_A_lr: 0.0002; d_B_loss: 0.030690651; d_B_lr: 0.0002; g_AtoB_loss: 3.0272703; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.9826207; g_BtoA_lr: 0.0002; steps/sec: 2.84;
FastEstimator-Train: step: 57362; epoch: 43; epoch_time: 469.08 sec;
FastEstimator-Train: step: 58000; d_A_loss: 0.015637912; d_A_lr: 0.0002; d_B_loss: 0.01957808; d_B_lr: 0.0002; g_AtoB_loss: 11.956296; g_AtoB_lr: 0.0002; g_BtoA_loss: 7.9149184; g_BtoA_lr: 0.0002; steps/sec: 2.76;
FastEstimator-Train: step: 58696; epoch: 44; epoch_time: 461.95 sec;
FastEstimator-Train: step: 59000; d_A_loss: 0.028864238; d_A_lr: 0.0002; d_B_loss: 0.0040443847; d_B_lr: 0.0002; g_AtoB_loss: 12.938608; g_AtoB_lr: 0.0002; g_BtoA_loss: 9.135632; g_BtoA_lr: 0.0002; steps/sec: 2.78;
FastEstimator-Train: step: 60000; d_A_loss: 0.05312334; d_A_lr: 0.0002; d_B_loss: 0.037270255; d_B_lr: 0.0002; g_AtoB_loss: 7.9331784; g_AtoB_lr: 0.0002; g_BtoA_loss: 5.792024; g_BtoA_lr: 0.0002; steps/sec: 2.62;
FastEstimator-Train: step: 60030; epoch: 45; epoch_time: 503.9 sec;
FastEstimator-Train: step: 61000; d_A_loss: 0.089201994; d_A_lr: 0.0002; d_B_loss: 0.17658553; d_B_lr: 0.0002; g_AtoB_loss: 2.7340398; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.7218227; g_BtoA_lr: 0.0002; steps/sec: 3.43;
FastEstimator-Train: step: 61364; epoch: 46; epoch_time: 403.93 sec;
FastEstimator-Train: step: 62000; d_A_loss: 0.16056217; d_A_lr: 0.0002; d_B_loss: 0.036287144; d_B_lr: 0.0002; g_AtoB_loss: 3.8993127; g_AtoB_lr: 0.0002; g_BtoA_loss: 4.52674; g_BtoA_lr: 0.0002; steps/sec: 3.23;
FastEstimator-Train: step: 62698; epoch: 47; epoch_time: 378.64 sec;
FastEstimator-Train: step: 63000; d_A_loss: 0.1105344; d_A_lr: 0.0002; d_B_loss: 0.07777613; d_B_lr: 0.0002; g_AtoB_loss: 5.8097258; g_AtoB_lr: 0.0002; g_BtoA_loss: 6.6205196; g_BtoA_lr: 0.0002; steps/sec: 3.01;
FastEstimator-Train: step: 64000; d_A_loss: 0.36758858; d_A_lr: 0.0002; d_B_loss: 0.15424584; d_B_lr: 0.0002; g_AtoB_loss: 3.3357177; g_AtoB_lr: 0.0002; g_BtoA_loss: 7.3640885; g_BtoA_lr: 0.0002; steps/sec: 3.56;
FastEstimator-Train: step: 64032; epoch: 48; epoch_time: 439.95 sec;
FastEstimator-Train: step: 65000; d_A_loss: 0.048484586; d_A_lr: 0.0002; d_B_loss: 0.20226882; d_B_lr: 0.0002; g_AtoB_loss: 4.131755; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.4538093; g_BtoA_lr: 0.0002; steps/sec: 2.99;
FastEstimator-Train: step: 65366; epoch: 49; epoch_time: 406.73 sec;
FastEstimator-Train: step: 66000; d_A_loss: 0.054230925; d_A_lr: 0.0002; d_B_loss: 0.15945143; d_B_lr: 0.0002; g_AtoB_loss: 3.2555401; g_AtoB_lr: 0.0002; g_BtoA_loss: 3.5932798; g_BtoA_lr: 0.0002; steps/sec: 3.46;
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_AtoB_epoch_50.pt
FastEstimator-ModelSaver: Saved model to /tmp/tmpn34de7gf/g_BtoA_epoch_50.pt
FastEstimator-Train: step: 66700; epoch: 50; epoch_time: 476.56 sec;
FastEstimator-Finish: step: 66700; d_A_lr: 0.0002; d_B_lr: 0.0002; g_AtoB_lr: 0.0002; g_BtoA_lr: 0.0002; total_time: 14691.95 sec;
Inferencing¶
Below are the inference results of the two generators.
idx = np.random.randint(len(test_data))
data = test_data[idx][0]
result = pipeline.transform(data, mode="infer")
network = fe.Network(ops=[
    ModelOp(inputs="real_A", model=g_AtoB, outputs="fake_B"),
    ModelOp(inputs="real_B", model=g_BtoA, outputs="fake_A"),
])
predictions = network.transform(result, mode="infer")
horse_img = np.transpose(predictions["real_A"].numpy(), (0, 2, 3, 1))
zebra_img = np.transpose(predictions["real_B"].numpy(), (0, 2, 3, 1))
fake_zebra = np.transpose(predictions["fake_B"].numpy(), (0, 2, 3, 1))
fake_horse = np.transpose(predictions["fake_A"].numpy(), (0, 2, 3, 1))
GridDisplay([ImageDisplay(image=horse_img[0], title="Real Horse"),
             ImageDisplay(image=fake_zebra[0], title="Fake Zebra")
             ]).show()
GridDisplay([ImageDisplay(image=zebra_img[0], title="Real Zebra"),
             ImageDisplay(image=fake_horse[0], title="Fake Horse")
             ]).show()
Note the addition of zebra-like stripe texture on top of the horses when translating from horses to zebras. When translating zebras to horses, we can observe that the generator removes the stripe texture from the zebras.