DCGAN on the MNIST Dataset
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
import tempfile
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from matplotlib import pyplot as plt
import fastestimator as fe
from fastestimator.backend import binary_crossentropy, feed_forward
from fastestimator.dataset.data import mnist
from fastestimator.op.numpyop import LambdaOp
from fastestimator.op.numpyop.univariate import ExpandDims, Normalize
from fastestimator.op.tensorop import TensorOp
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import ModelSaver
batch_size = 256
epochs = 50
train_steps_per_epoch = None
save_dir = tempfile.mkdtemp()
model_name = 'model_epoch_50.h5'
Building components
Step 1: Prepare training data and define a Pipeline
We load the data from tf.keras.datasets.mnist (via FastEstimator's mnist dataset API) and define a series of operations to perform on the data before training:
train_data, _ = mnist.load_data()
pipeline = fe.Pipeline(
    train_data=train_data,
    batch_size=batch_size,
    ops=[
        ExpandDims(inputs="x", outputs="x"),
        Normalize(inputs="x", outputs="x", mean=1.0, std=1.0, max_pixel_value=127.5),
        LambdaOp(fn=lambda: np.random.normal(size=[100]).astype('float32'), outputs="z")
    ])
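Before moving on, it can be useful to verify what the Pipeline actually produces. As a quick hedged check (pipeline.get_results is the batch-preview API in recent FastEstimator releases, and is not part of the original notebook), we can pull a single batch and confirm that "x" has been normalized to [-1, 1] and that "z" holds a 100-dimensional noise vector per sample:
batch = pipeline.get_results()
print(batch["x"].shape)  # expected: (256, 28, 28, 1)
print(float(tf.reduce_min(batch["x"])), float(tf.reduce_max(batch["x"])))  # roughly -1.0 and 1.0
print(batch["z"].shape)  # expected: (256, 100)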
Step 2: Create a model and FastEstimator Network
First, we define the network architecture for both our Generator and Discriminator. After defining the architectures, we pass them, along with their optimizers, to fe.build.
def generator():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100, )))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Reshape((7, 7, 256)))
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    return model
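As a quick hedged sanity check (not part of the original notebook): the generator should map a 100-dimensional noise vector to a 28x28 single-channel image, with the Dense layer producing a 7x7x256 feature map and the two stride-2 transposed convolutions upsampling it to 14x14 and then 28x28:
g = generator()
print(g.output_shape)  # expected: (None, 28, 28, 1)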
def discriminator():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model
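The discriminator mirrors that path with stride-2 convolutions and emits a single unbounded logit per image (positive leaning "real"). We leave it as a logit because the losses below use from_logits=True. Another quick hedged check:
d = discriminator()
print(d.output_shape)  # expected: (None, 1)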
gen_model = fe.build(model_fn=generator, optimizer_fn=lambda: tf.optimizers.Adam(1e-4))
disc_model = fe.build(model_fn=discriminator, optimizer_fn=lambda: tf.optimizers.Adam(1e-4))
We define the generator and discriminator losses as TensorOps. Like any Op, they can take multiple inputs and produce multiple outputs.
class GLoss(TensorOp):
    """Compute generator loss."""
    def forward(self, data, state):
        return binary_crossentropy(y_pred=data, y_true=tf.ones_like(data), from_logits=True)

class DLoss(TensorOp):
    """Compute discriminator loss."""
    def forward(self, data, state):
        true_score, fake_score = data
        real_loss = binary_crossentropy(y_pred=true_score, y_true=tf.ones_like(true_score), from_logits=True)
        fake_loss = binary_crossentropy(y_pred=fake_score, y_true=tf.zeros_like(fake_score), from_logits=True)
        total_loss = real_loss + fake_loss
        return total_loss
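To see why DLoss is minimized by a discriminator that scores real images with large positive logits and fakes with large negative ones, here is a small hedged check using plain TensorFlow (tf.nn.sigmoid_cross_entropy_with_logits computes the same quantity binary_crossentropy does when from_logits=True):
real_logit = tf.constant([[4.0]])   # confident "real" prediction on a real image
fake_logit = tf.constant([[-4.0]])  # confident "fake" prediction on a generated image
real_term = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(real_logit), logits=real_logit)
fake_term = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.zeros_like(fake_logit), logits=fake_logit)
print(float(real_term + fake_term))  # ~0.036, i.e. near-zero discriminator loss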
fe.Network takes a series of operators. Here we pass our models, wrapped in ModelOps, along with our loss functions and their update rules:
network = fe.Network(ops=[
    ModelOp(model=gen_model, inputs="z", outputs="x_fake"),
    ModelOp(model=disc_model, inputs="x_fake", outputs="fake_score"),
    GLoss(inputs="fake_score", outputs="gloss"),
    UpdateOp(model=gen_model, loss_name="gloss"),
    ModelOp(inputs="x", model=disc_model, outputs="true_score"),
    DLoss(inputs=("true_score", "fake_score"), outputs="dloss"),
    UpdateOp(model=disc_model, loss_name="dloss")
])
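Note that these ops execute in order on every step: the generator produces "x_fake" from noise, the discriminator scores it, the generator is updated against that score, and then the discriminator is scored on real data and updated against both scores. For debugging, one hedged option (network.transform is available in recent FastEstimator releases, and this check is not part of the original notebook) is to push a single pipeline batch through the graph without launching a full training loop:
batch = pipeline.get_results()
batch = network.transform(batch, mode="train")  # note: "train" mode may also apply one optimizer update per model
print(batch["gloss"], batch["dloss"])           # scalar losses for this single batch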
Step 3: Prepare Estimator and configure the training loop
We will define an Estimator that has four notable arguments: network, pipeline, epochs, and traces. Our Network and Pipeline objects are passed here as arguments along with the number of epochs and a Trace, in this case one designed to save our generator model every 5 epochs.
traces = ModelSaver(model=gen_model, save_dir=save_dir, frequency=5)
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=epochs,
                         traces=traces,
                         train_steps_per_epoch=train_steps_per_epoch)
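Since train_steps_per_epoch is None, each epoch consumes the full 60,000-image training set (235 steps at batch size 256, as the log below confirms). If you just want to verify that everything is wired up before committing to a full run, a hedged variant (the smoke_test name is ours, not part of the original notebook) caps the step count:
smoke_test = fe.Estimator(pipeline=pipeline,
                          network=network,
                          epochs=1,
                          traces=traces,
                          train_steps_per_epoch=10)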
Training
estimator.fit()
[FastEstimator ASCII-art banner]
FastEstimator-Warn: the key 'y' is being pruned since it is unused outside of the Pipeline. To prevent this, you can declare the key as an input of a Trace or TensorOp.
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1;
FastEstimator-Train: step: 1; dloss: 1.4074755; gloss: 0.6938505;
FastEstimator-Train: step: 100; dloss: 1.071872; gloss: 0.6625423; steps/sec: 96.76;
FastEstimator-Train: step: 200; dloss: 1.1390224; gloss: 0.7711286; steps/sec: 153.13;
FastEstimator-Train: step: 235; epoch: 1; epoch_time: 6.54 sec;
FastEstimator-Train: step: 300; dloss: 1.1058098; gloss: 0.8283564; steps/sec: 21.38;
FastEstimator-Train: step: 400; dloss: 1.2330828; gloss: 0.8234701; steps/sec: 137.98;
FastEstimator-Train: step: 470; epoch: 2; epoch_time: 4.88 sec;
FastEstimator-Train: step: 500; dloss: 1.0860771; gloss: 0.8769846; steps/sec: 26.53;
FastEstimator-Train: step: 600; dloss: 1.3000612; gloss: 0.76368; steps/sec: 139.42;
FastEstimator-Train: step: 700; dloss: 1.3079492; gloss: 0.7894113; steps/sec: 136.25;
FastEstimator-Train: step: 705; epoch: 3; epoch_time: 4.73 sec;
FastEstimator-Train: step: 800; dloss: 0.95769334; gloss: 1.0632424; steps/sec: 24.9;
FastEstimator-Train: step: 900; dloss: 1.0605512; gloss: 1.0276053; steps/sec: 118.47;
FastEstimator-Train: step: 940; epoch: 4; epoch_time: 5.15 sec;
FastEstimator-Train: step: 1000; dloss: 1.1515687; gloss: 1.0443377; steps/sec: 21.87;
FastEstimator-Train: step: 1100; dloss: 1.312542; gloss: 0.8080373; steps/sec: 143.04;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_5.h5
FastEstimator-Train: step: 1175; epoch: 5; epoch_time: 5.49 sec;
FastEstimator-Train: step: 1200; dloss: 1.5896232; gloss: 0.69061863; steps/sec: 24.5;
FastEstimator-Train: step: 1300; dloss: 1.0532413; gloss: 1.063014; steps/sec: 139.64;
FastEstimator-Train: step: 1400; dloss: 1.1432295; gloss: 0.9709608; steps/sec: 118.93;
FastEstimator-Train: step: 1410; epoch: 6; epoch_time: 5.16 sec;
FastEstimator-Train: step: 1500; dloss: 1.1589987; gloss: 0.88277376; steps/sec: 23.83;
FastEstimator-Train: step: 1600; dloss: 1.4223424; gloss: 0.82864374; steps/sec: 136.68;
FastEstimator-Train: step: 1645; epoch: 7; epoch_time: 5.19 sec;
FastEstimator-Train: step: 1700; dloss: 1.0954866; gloss: 1.0265985; steps/sec: 24.66;
FastEstimator-Train: step: 1800; dloss: 1.2621644; gloss: 0.8546214; steps/sec: 137.29;
FastEstimator-Train: step: 1880; epoch: 8; epoch_time: 5.04 sec;
FastEstimator-Train: step: 1900; dloss: 1.23395; gloss: 0.9362403; steps/sec: 24.59;
FastEstimator-Train: step: 2000; dloss: 1.1891618; gloss: 0.93046916; steps/sec: 141.07;
FastEstimator-Train: step: 2100; dloss: 1.1270559; gloss: 0.83571935; steps/sec: 135.63;
FastEstimator-Train: step: 2115; epoch: 9; epoch_time: 5.04 sec;
FastEstimator-Train: step: 2200; dloss: 1.0217198; gloss: 1.0431002; steps/sec: 24.59;
FastEstimator-Train: step: 2300; dloss: 1.0043646; gloss: 0.987643; steps/sec: 134.45;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_10.h5
FastEstimator-Train: step: 2350; epoch: 10; epoch_time: 5.05 sec;
FastEstimator-Train: step: 2400; dloss: 1.1894727; gloss: 0.957337; steps/sec: 26.01;
FastEstimator-Train: step: 2500; dloss: 1.5908065; gloss: 0.67834544; steps/sec: 144.04;
FastEstimator-Train: step: 2585; epoch: 11; epoch_time: 4.8 sec;
FastEstimator-Train: step: 2600; dloss: 1.1317401; gloss: 1.0065825; steps/sec: 24.88;
FastEstimator-Train: step: 2700; dloss: 1.0598496; gloss: 1.0172932; steps/sec: 131.43;
FastEstimator-Train: step: 2800; dloss: 1.2017224; gloss: 0.9210052; steps/sec: 133.42;
FastEstimator-Train: step: 2820; epoch: 12; epoch_time: 5.07 sec;
FastEstimator-Train: step: 2900; dloss: 1.1818032; gloss: 0.8824554; steps/sec: 25.56;
FastEstimator-Train: step: 3000; dloss: 1.0128953; gloss: 1.0633328; steps/sec: 123.48;
FastEstimator-Train: step: 3055; epoch: 13; epoch_time: 4.97 sec;
FastEstimator-Train: step: 3100; dloss: 1.1521642; gloss: 0.98313975; steps/sec: 24.99;
FastEstimator-Train: step: 3200; dloss: 1.2207543; gloss: 1.0411373; steps/sec: 130.05;
FastEstimator-Train: step: 3290; epoch: 14; epoch_time: 5.11 sec;
FastEstimator-Train: step: 3300; dloss: 1.0686741; gloss: 1.04258; steps/sec: 24.05;
FastEstimator-Train: step: 3400; dloss: 1.1607264; gloss: 0.9313967; steps/sec: 125.47;
FastEstimator-Train: step: 3500; dloss: 1.108815; gloss: 1.0197322; steps/sec: 128.05;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_15.h5
FastEstimator-Train: step: 3525; epoch: 15; epoch_time: 5.18 sec;
FastEstimator-Train: step: 3600; dloss: 0.84284836; gloss: 1.3197017; steps/sec: 25.89;
FastEstimator-Train: step: 3700; dloss: 1.3472433; gloss: 0.9266864; steps/sec: 131.6;
FastEstimator-Train: step: 3760; epoch: 16; epoch_time: 4.87 sec;
FastEstimator-Train: step: 3800; dloss: 1.1090488; gloss: 1.148634; steps/sec: 24.4;
FastEstimator-Train: step: 3900; dloss: 1.0149467; gloss: 1.171385; steps/sec: 134.44;
FastEstimator-Train: step: 3995; epoch: 17; epoch_time: 5.12 sec;
FastEstimator-Train: step: 4000; dloss: 1.0201095; gloss: 1.3736193; steps/sec: 24.89;
FastEstimator-Train: step: 4100; dloss: 1.1828756; gloss: 1.1176487; steps/sec: 127.13;
FastEstimator-Train: step: 4200; dloss: 0.9339646; gloss: 1.2892692; steps/sec: 129.5;
FastEstimator-Train: step: 4230; epoch: 18; epoch_time: 5.09 sec;
FastEstimator-Train: step: 4300; dloss: 1.031354; gloss: 1.1992577; steps/sec: 22.92;
FastEstimator-Train: step: 4400; dloss: 0.59360105; gloss: 1.7209172; steps/sec: 131.32;
FastEstimator-Train: step: 4465; epoch: 19; epoch_time: 5.35 sec;
FastEstimator-Train: step: 4500; dloss: 0.8413017; gloss: 1.3761427; steps/sec: 22.88;
FastEstimator-Train: step: 4600; dloss: 1.0306518; gloss: 1.3534794; steps/sec: 133.22;
FastEstimator-Train: step: 4700; dloss: 1.1800865; gloss: 1.1182059; steps/sec: 128.35;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_20.h5
FastEstimator-Train: step: 4700; epoch: 20; epoch_time: 5.41 sec;
FastEstimator-Train: step: 4800; dloss: 1.0419474; gloss: 1.090816; steps/sec: 24.75;
FastEstimator-Train: step: 4900; dloss: 0.9564374; gloss: 1.2522755; steps/sec: 136.4;
FastEstimator-Train: step: 4935; epoch: 21; epoch_time: 5.05 sec;
FastEstimator-Train: step: 5000; dloss: 1.1178182; gloss: 1.0265534; steps/sec: 26.88;
FastEstimator-Train: step: 5100; dloss: 0.9095393; gloss: 1.2775141; steps/sec: 139.1;
FastEstimator-Train: step: 5170; epoch: 22; epoch_time: 4.95 sec;
FastEstimator-Train: step: 5200; dloss: 1.1132319; gloss: 1.1454469; steps/sec: 20.47;
FastEstimator-Train: step: 5300; dloss: 1.1913971; gloss: 1.2246889; steps/sec: 73.91;
FastEstimator-Train: step: 5400; dloss: 0.94440126; gloss: 1.2502146; steps/sec: 84.12;
FastEstimator-Train: step: 5405; epoch: 23; epoch_time: 6.99 sec;
FastEstimator-Train: step: 5500; dloss: 1.0678749; gloss: 1.2804728; steps/sec: 17.85;
FastEstimator-Train: step: 5600; dloss: 1.0255647; gloss: 1.2616677; steps/sec: 80.99;
FastEstimator-Train: step: 5640; epoch: 24; epoch_time: 7.2 sec;
FastEstimator-Train: step: 5700; dloss: 0.9666496; gloss: 1.2296994; steps/sec: 19.36;
FastEstimator-Train: step: 5800; dloss: 1.116164; gloss: 0.96720326; steps/sec: 87.0;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_25.h5
FastEstimator-Train: step: 5875; epoch: 25; epoch_time: 6.74 sec;
FastEstimator-Train: step: 5900; dloss: 1.034753; gloss: 1.1733758; steps/sec: 19.73;
FastEstimator-Train: step: 6000; dloss: 1.0251997; gloss: 1.2077578; steps/sec: 80.12;
FastEstimator-Train: step: 6100; dloss: 0.9085286; gloss: 1.4258803; steps/sec: 82.07;
FastEstimator-Train: step: 6110; epoch: 26; epoch_time: 6.79 sec;
FastEstimator-Train: step: 6200; dloss: 0.91629565; gloss: 1.2635107; steps/sec: 18.95;
FastEstimator-Train: step: 6300; dloss: 0.9700071; gloss: 1.2341801; steps/sec: 81.38;
FastEstimator-Train: step: 6345; epoch: 27; epoch_time: 6.91 sec;
FastEstimator-Train: step: 6400; dloss: 1.0925424; gloss: 1.1383421; steps/sec: 19.67;
FastEstimator-Train: step: 6500; dloss: 0.886595; gloss: 1.2264949; steps/sec: 88.04;
FastEstimator-Train: step: 6580; epoch: 28; epoch_time: 6.63 sec;
FastEstimator-Train: step: 6600; dloss: 0.9399621; gloss: 1.2295549; steps/sec: 18.91;
FastEstimator-Train: step: 6700; dloss: 1.1239285; gloss: 1.0375732; steps/sec: 85.2;
FastEstimator-Train: step: 6800; dloss: 1.115287; gloss: 1.074056; steps/sec: 87.68;
FastEstimator-Train: step: 6815; epoch: 29; epoch_time: 6.77 sec;
FastEstimator-Train: step: 6900; dloss: 0.99310243; gloss: 1.0726111; steps/sec: 19.46;
FastEstimator-Train: step: 7000; dloss: 1.3016744; gloss: 1.1302423; steps/sec: 80.95;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_30.h5
FastEstimator-Train: step: 7050; epoch: 30; epoch_time: 6.59 sec;
FastEstimator-Train: step: 7100; dloss: 1.1211019; gloss: 1.0915082; steps/sec: 22.69;
FastEstimator-Train: step: 7200; dloss: 1.0433137; gloss: 1.3584124; steps/sec: 74.73;
FastEstimator-Train: step: 7285; epoch: 31; epoch_time: 6.6 sec;
FastEstimator-Train: step: 7300; dloss: 1.3104564; gloss: 1.133613; steps/sec: 18.21;
FastEstimator-Train: step: 7400; dloss: 1.195024; gloss: 1.1311966; steps/sec: 82.59;
FastEstimator-Train: step: 7500; dloss: 1.1094584; gloss: 1.1424885; steps/sec: 84.76;
FastEstimator-Train: step: 7520; epoch: 32; epoch_time: 6.99 sec;
FastEstimator-Train: step: 7600; dloss: 0.9371455; gloss: 1.1962738; steps/sec: 18.9;
FastEstimator-Train: step: 7700; dloss: 1.039571; gloss: 1.158539; steps/sec: 84.59;
FastEstimator-Train: step: 7755; epoch: 33; epoch_time: 6.9 sec;
FastEstimator-Train: step: 7800; dloss: 1.2328132; gloss: 1.0920696; steps/sec: 18.49;
FastEstimator-Train: step: 7900; dloss: 1.263283; gloss: 1.057182; steps/sec: 75.0;
FastEstimator-Train: step: 7990; epoch: 34; epoch_time: 7.16 sec;
FastEstimator-Train: step: 8000; dloss: 1.1744511; gloss: 1.0851499; steps/sec: 18.84;
FastEstimator-Train: step: 8100; dloss: 1.2414798; gloss: 0.9666091; steps/sec: 68.94;
FastEstimator-Train: step: 8200; dloss: 1.101659; gloss: 1.0953238; steps/sec: 84.85;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_35.h5
FastEstimator-Train: step: 8225; epoch: 35; epoch_time: 7.23 sec;
FastEstimator-Train: step: 8300; dloss: 1.1480453; gloss: 1.0929885; steps/sec: 18.74;
FastEstimator-Train: step: 8400; dloss: 1.0150359; gloss: 1.2104923; steps/sec: 75.67;
FastEstimator-Train: step: 8460; epoch: 36; epoch_time: 6.93 sec;
FastEstimator-Train: step: 8500; dloss: 1.2448437; gloss: 0.95777893; steps/sec: 19.26;
FastEstimator-Train: step: 8600; dloss: 1.128643; gloss: 0.9374757; steps/sec: 84.76;
FastEstimator-Train: step: 8695; epoch: 37; epoch_time: 6.76 sec;
FastEstimator-Train: step: 8700; dloss: 1.263104; gloss: 1.0629486; steps/sec: 18.54;
FastEstimator-Train: step: 8800; dloss: 1.0654749; gloss: 1.1919184; steps/sec: 81.95;
FastEstimator-Train: step: 8900; dloss: 1.0990744; gloss: 1.0925766; steps/sec: 80.36;
FastEstimator-Train: step: 8930; epoch: 38; epoch_time: 7.05 sec;
FastEstimator-Train: step: 9000; dloss: 1.3174253; gloss: 0.9213707; steps/sec: 21.57;
FastEstimator-Train: step: 9100; dloss: 1.3086524; gloss: 0.89782405; steps/sec: 87.45;
FastEstimator-Train: step: 9165; epoch: 39; epoch_time: 6.25 sec;
FastEstimator-Train: step: 9200; dloss: 1.2255262; gloss: 0.9562658; steps/sec: 18.78;
FastEstimator-Train: step: 9300; dloss: 1.2224264; gloss: 1.0891576; steps/sec: 80.81;
FastEstimator-Train: step: 9400; dloss: 1.2472556; gloss: 1.0490451; steps/sec: 76.39;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_40.h5
FastEstimator-Train: step: 9400; epoch: 40; epoch_time: 7.04 sec;
FastEstimator-Train: step: 9500; dloss: 1.2820866; gloss: 0.8918587; steps/sec: 19.16;
FastEstimator-Train: step: 9600; dloss: 1.266947; gloss: 0.9083342; steps/sec: 80.53;
FastEstimator-Train: step: 9635; epoch: 41; epoch_time: 6.89 sec;
FastEstimator-Train: step: 9700; dloss: 1.2590028; gloss: 0.91593623; steps/sec: 18.61;
FastEstimator-Train: step: 9800; dloss: 1.2626295; gloss: 0.8595179; steps/sec: 85.58;
FastEstimator-Train: step: 9870; epoch: 42; epoch_time: 6.98 sec;
FastEstimator-Train: step: 9900; dloss: 1.2414654; gloss: 0.9379455; steps/sec: 19.63;
FastEstimator-Train: step: 10000; dloss: 1.209197; gloss: 0.8486848; steps/sec: 80.11;
FastEstimator-Train: step: 10100; dloss: 1.221647; gloss: 1.0560616; steps/sec: 82.71;
FastEstimator-Train: step: 10105; epoch: 43; epoch_time: 6.79 sec;
FastEstimator-Train: step: 10200; dloss: 1.2570242; gloss: 0.9568354; steps/sec: 18.62;
FastEstimator-Train: step: 10300; dloss: 1.1466038; gloss: 1.1030976; steps/sec: 86.59;
FastEstimator-Train: step: 10340; epoch: 44; epoch_time: 6.91 sec;
FastEstimator-Train: step: 10400; dloss: 1.3359842; gloss: 0.815323; steps/sec: 18.28;
FastEstimator-Train: step: 10500; dloss: 1.436327; gloss: 0.82999104; steps/sec: 79.86;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_45.h5
FastEstimator-Train: step: 10575; epoch: 45; epoch_time: 7.27 sec;
FastEstimator-Train: step: 10600; dloss: 1.3078375; gloss: 0.9893749; steps/sec: 17.96;
FastEstimator-Train: step: 10700; dloss: 1.50738; gloss: 0.88688844; steps/sec: 79.93;
FastEstimator-Train: step: 10800; dloss: 1.2391019; gloss: 0.94254756; steps/sec: 82.19;
FastEstimator-Train: step: 10810; epoch: 46; epoch_time: 7.2 sec;
FastEstimator-Train: step: 10900; dloss: 1.2033031; gloss: 0.9017648; steps/sec: 20.95;
FastEstimator-Train: step: 11000; dloss: 1.1743715; gloss: 0.9873478; steps/sec: 80.95;
FastEstimator-Train: step: 11045; epoch: 47; epoch_time: 6.37 sec;
FastEstimator-Train: step: 11100; dloss: 1.2350852; gloss: 0.8709105; steps/sec: 19.28;
FastEstimator-Train: step: 11200; dloss: 1.3278979; gloss: 0.8681942; steps/sec: 81.55;
FastEstimator-Train: step: 11280; epoch: 48; epoch_time: 7.01 sec;
FastEstimator-Train: step: 11300; dloss: 1.2809682; gloss: 0.87823415; steps/sec: 18.17;
FastEstimator-Train: step: 11400; dloss: 1.3148129; gloss: 0.9447537; steps/sec: 80.3;
FastEstimator-Train: step: 11500; dloss: 1.2640177; gloss: 0.9070884; steps/sec: 82.19;
FastEstimator-Train: step: 11515; epoch: 49; epoch_time: 7.03 sec;
FastEstimator-Train: step: 11600; dloss: 1.209223; gloss: 0.99375534; steps/sec: 20.29;
FastEstimator-Train: step: 11700; dloss: 1.1119881; gloss: 0.991818; steps/sec: 83.43;
FastEstimator-ModelSaver: Saved model to /tmp/tmp7hg1f143/model_epoch_50.h5
FastEstimator-Train: step: 11750; epoch: 50; epoch_time: 6.52 sec;
FastEstimator-Finish: step: 11750; model1_lr: 1e-04; model_lr: 1e-04; total_time: 307.39 sec;
Inferencing
For inferencing, we first have to load the trained model weights. We load the trained generator weights using fe.build:
model_path = os.path.join(save_dir, model_name)
trained_model = fe.build(model_fn=generator, weights_path=model_path, optimizer_fn=lambda: tf.optimizers.Adam(1e-4))
We will then generate some images from random noise:
images = feed_forward(trained_model, np.random.normal(size=(16, 100)), training=False)
fe.util.BatchDisplay(image=images).show()
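fe.util.BatchDisplay renders the 16 generated digits as a grid. If it is unavailable in your FastEstimator installation, a minimal hedged fallback using the matplotlib import from above looks like this; note the tanh output lives in [-1, 1] and must be rescaled for display:
fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for img, ax in zip(images, axes.flat):
    ax.imshow((np.squeeze(img) + 1.0) / 2.0, cmap='gray')  # map [-1, 1] -> [0, 1]
    ax.axis('off')
plt.show()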