Anomaly Detection with FastEstimator
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
In this notebook we demonstrate how to perform anomaly detection with a one-class classifier, as described in Adversarially Learned One-Class Classifier for Novelty Detection. In the real world, the outlier or novelty class is often absent from the training dataset; such problems can be efficiently modeled with one-class classifiers. In the algorithm demonstrated below, two networks are trained to compete with each other: one network acts as a novelty detector, while the other enhances the inliers and distorts the outliers. We use images of the digit "1" from the MNIST dataset for training, and images of the other digits as outliers.
import tempfile
import fastestimator as fe
import numpy as np
import tensorflow as tf
from fastestimator.backend import binary_crossentropy
from fastestimator.op.numpyop import LambdaOp
from fastestimator.op.numpyop.univariate import ExpandDims, Normalize
from fastestimator.op.tensorop import TensorOp
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace import Trace
from fastestimator.trace.io import BestModelSaver
from fastestimator.util import ImageDisplay, GridDisplay, to_number
from sklearn.metrics import auc, f1_score, roc_curve
from tensorflow.keras import layers
# Parameters
epochs=20
batch_size=128
train_steps_per_epoch=None
save_dir=tempfile.mkdtemp()
(x_train, y_train), (x_eval, y_eval) = tf.keras.datasets.mnist.load_data()
# Create Training Dataset
x_train, y_train = x_train[np.where((y_train == 1))], np.zeros(y_train[np.where((y_train == 1))].shape)
train_data = fe.dataset.NumpyDataset({"x": x_train, "y": y_train})
# Create Validation Dataset
x_eval0, y_eval0 = x_eval[np.where((y_eval == 1))], np.ones(y_eval[np.where((y_eval == 1))].shape)
x_eval1, y_eval1 = x_eval[np.where((y_eval != 1))], y_eval[np.where((y_eval != 1))]
# Ensuring outliers comprise 50% of the dataset
index = np.random.choice(x_eval1.shape[0], int(x_eval0.shape[0]), replace=False)
x_eval1, y_eval1 = x_eval1[index], np.zeros(y_eval1[index].shape)
x_eval, y_eval = np.concatenate([x_eval0, x_eval1]), np.concatenate([y_eval0, y_eval1])
eval_data = fe.dataset.NumpyDataset({"x": x_eval, "y": y_eval})
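As a quick sanity check (illustrative, not part of the original example), we can confirm that the evaluation set is balanced:
print(len(train_data), len(eval_data))
print(np.mean(y_eval))  # 0.5 -> half inliers (label 1), half outliers (label 0)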
Step 1: Create Pipeline
We will use a LambdaOp to add noise to the images during training.
pipeline = fe.Pipeline(
    train_data=train_data,
    eval_data=eval_data,
    batch_size=batch_size,
    ops=[
        ExpandDims(inputs="x", outputs="x"),  # add a channel dimension: (28, 28) -> (28, 28, 1)
        Normalize(inputs="x", outputs="x", mean=1.0, std=1.0, max_pixel_value=127.5),
        # Corrupt the input with Gaussian noise, but only during training
        LambdaOp(fn=lambda x: x + np.random.normal(loc=0.0, scale=0.155, size=(28, 28, 1)),
                 inputs="x",
                 outputs="x_w_noise",
                 mode="train")
    ])
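A note on the Normalize settings above: assuming the usual (pixel / max_pixel_value - mean) / std formula, these arguments map pixel values from [0, 255] to [-1, 1], which matches the tanh output of the reconstructor defined below. A quick illustrative check:
px = np.array([0.0, 127.5, 255.0])
print((px / 127.5 - 1.0) / 1.0)  # [-1.  0.  1.]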
We can visualize sample images from our Pipeline using the get_results method.
sample_batch = pipeline.get_results()
GridDisplay([ImageDisplay(image=sample_batch["x"][0], color_map='greys', title="Image"),
ImageDisplay(image=sample_batch["x_w_noise"][0], color_map='greys', title="Noisy Image")
]).show()
Step 2: Create Network
The architecture of our model consists of an autoencoder (encoder-decoder) network and a discriminator network.
[Credit: https://arxiv.org/pdf/1802.09088.pdf]
def reconstructor(input_shape=(28, 28, 1)):
model = tf.keras.Sequential()
# Encoder Block
model.add(
layers.Conv2D(32, (5, 5),
strides=(2, 2),
padding='same',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
input_shape=input_shape))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU(0.2))
model.add(
layers.Conv2D(64, (5, 5),
strides=(2, 2),
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
padding='same'))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU(0.2))
model.add(
layers.Conv2D(128, (5, 5),
strides=(2, 2),
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
padding='same'))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU(0.2))
# Decoder Block
model.add(
layers.Conv2DTranspose(32, (5, 5),
strides=(2, 2),
output_padding=(0, 0),
padding='same',
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02)))
model.add(layers.BatchNormalization())
model.add(layers.ReLU())
model.add(
layers.Conv2DTranspose(16, (5, 5),
strides=(2, 2),
padding='same',
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02)))
model.add(layers.BatchNormalization())
model.add(layers.ReLU())
model.add(
layers.Conv2DTranspose(1, (5, 5),
strides=(2, 2),
padding='same',
kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.02),
activation='tanh'))
return model
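The encoder halves the spatial resolution three times (28 → 14 → 7 → 4), and the decoder mirrors it back, so the model should map a 28 × 28 × 1 image to a reconstruction of the same shape. A quick illustrative check:
print(reconstructor().output_shape)  # expected: (None, 28, 28, 1)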
def discriminator(input_shape=(28, 28, 1)):
model = tf.keras.Sequential()
model.add(
layers.Conv2D(16, (5, 5),
strides=(2, 2),
padding='same',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
input_shape=input_shape))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU(0.2))
model.add(
layers.Conv2D(32, (5, 5),
strides=(2, 2),
padding='same',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU(0.2))
model.add(
layers.Conv2D(64, (5, 5),
strides=(2, 2),
padding='same',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU(0.2))
model.add(
layers.Conv2D(128, (5, 5),
strides=(2, 2),
padding='same',
kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02)))
model.add(layers.LeakyReLU(0.2))
model.add(layers.Flatten())
model.add(layers.Dense(1, activation="sigmoid"))
return model
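Likewise, the discriminator reduces the image to a single sigmoid score (an illustrative check):
print(discriminator().output_shape)  # expected: (None, 1)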
recon_model = fe.build(model_fn=reconstructor, optimizer_fn=lambda: tf.optimizers.RMSprop(2e-4), model_name="reconstructor")
disc_model = fe.build(model_fn=discriminator,
optimizer_fn=lambda: tf.optimizers.RMSprop(1e-4),
model_name="discriminator")
Defining Loss
The losses of the two networks are similar to those of a standard GAN, with the exception that the autoencoder has an additional reconstruction loss term to enforce similarity between the input and the reconstructed image.
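Concretely, writing $\mathcal{R}$ for the reconstructor, $\mathcal{D}$ for the discriminator, $\tilde{x}$ for the noise-corrupted input, and $\mathrm{BCE}$ for binary cross-entropy, the objectives implemented below are:

$$\mathcal{L}_{\mathcal{R}} = \mathrm{BCE}\big(\mathcal{D}(\mathcal{R}(\tilde{x})),\ 1\big) + \alpha\,\mathrm{BCE}\big(\mathcal{R}(\tilde{x}),\ x\big)$$
$$\mathcal{L}_{\mathcal{D}} = \mathrm{BCE}\big(\mathcal{D}(x),\ 1\big) + \mathrm{BCE}\big(\mathcal{D}(\mathcal{R}(\tilde{x})),\ 0\big)$$

where $\alpha$ (0.2 by default below) balances the adversarial and reconstruction terms.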
We first define custom TensorOps to calculate the losses of the two networks.
class RLoss(TensorOp):
    def __init__(self, alpha=0.2, inputs=None, outputs=None, mode=None):
        super().__init__(inputs, outputs, mode)
        self.alpha = alpha

    def forward(self, data, state):
        fake_score, x_fake, x = data
        # Reconstruction term: keep the reconstructed image close to the input
        recon_loss = binary_crossentropy(y_true=x, y_pred=x_fake, from_logits=True)
        # Adversarial term: the reconstructor tries to make its output score as real
        adv_loss = binary_crossentropy(y_pred=fake_score, y_true=tf.ones_like(fake_score), from_logits=True)
        return adv_loss + self.alpha * recon_loss
class DLoss(TensorOp):
    def forward(self, data, state):
        true_score, fake_score = data
        # The discriminator should score real images as 1 and reconstructions as 0
        real_loss = binary_crossentropy(y_pred=true_score, y_true=tf.ones_like(true_score), from_logits=True)
        fake_loss = binary_crossentropy(y_pred=fake_score, y_true=tf.zeros_like(fake_score), from_logits=True)
        total_loss = real_loss + fake_loss
        return total_loss
We now define the Network object:
network = fe.Network(ops=[
    # The reconstructor sees the noisy image during training and the clean image during eval
    ModelOp(model=recon_model, inputs="x_w_noise", outputs="x_fake", mode="train"),
    ModelOp(model=recon_model, inputs="x", outputs="x_fake", mode="eval"),
    # The discriminator scores both the reconstruction and the original image
    ModelOp(model=disc_model, inputs="x_fake", outputs="fake_score"),
    ModelOp(model=disc_model, inputs="x", outputs="true_score"),
    RLoss(inputs=("fake_score", "x_fake", "x"), outputs="rloss"),
    UpdateOp(model=recon_model, loss_name="rloss"),
    DLoss(inputs=("true_score", "fake_score"), outputs="dloss"),
    UpdateOp(model=disc_model, loss_name="dloss")
])
In this example we will also use the following traces:
- BestModelSaver for saving the best model. For illustration purposes, we will save these models to a temporary directory.
- A custom trace to calculate the area under the ROC curve (AUC) and the F1 score.
class F1AUCScores(Trace):
    """Computes the AUC and F1 score for a classification task and reports them back to the logger."""
    def __init__(self, true_key, pred_key, mode=("eval", "test"), output_name=("auc_score", "f1_score")):
        super().__init__(inputs=(true_key, pred_key), outputs=output_name, mode=mode)
        self.y_true = []
        self.y_pred = []

    @property
    def true_key(self):
        return self.inputs[0]

    @property
    def pred_key(self):
        return self.inputs[1]

    def on_epoch_begin(self, data):
        self.y_true = []
        self.y_pred = []

    def on_batch_end(self, data):
        y_true, y_pred = to_number(data[self.true_key]), to_number(data[self.pred_key])
        assert y_pred.size == y_true.size
        self.y_pred.extend(y_pred.ravel())
        self.y_true.extend(y_true.ravel())

    def on_epoch_end(self, data):
        fpr, tpr, thresholds = roc_curve(self.y_true, self.y_pred, pos_label=1)
        roc_auc = auc(fpr, tpr)
        # Binarize the scores at the equal-error-rate (EER) threshold, where FPR = 1 - TPR
        eer_threshold = thresholds[np.nanargmin(np.absolute((1 - tpr - fpr)))]
        y_pred_class = np.copy(self.y_pred)
        y_pred_class[y_pred_class >= eer_threshold] = 1
        y_pred_class[y_pred_class < eer_threshold] = 0
        f_score = f1_score(self.y_true, y_pred_class)
        data.write_with_log(self.outputs[0], roc_auc)
        data.write_with_log(self.outputs[1], f_score)
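The trace binarizes the scores at the equal-error-rate (EER) threshold, the point on the ROC curve where the false-positive rate equals the false-negative rate (FPR = 1 - TPR). A toy illustration with made-up scores:
fpr, tpr, thresholds = roc_curve([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(thresholds[np.nanargmin(np.absolute(1 - tpr - fpr))])  # 0.4 for these made-up scores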
traces = [
F1AUCScores(true_key="y", pred_key="fake_score", mode="eval", output_name=["auc_score", "f1_score"]),
BestModelSaver(model=recon_model, save_dir=save_dir, metric='f1_score', save_best_mode='max', load_best_final=True),
BestModelSaver(model=disc_model, save_dir=save_dir, metric='f1_score', save_best_mode='max', load_best_final=True)
]
Step 3: Create Estimator
estimator = fe.Estimator(pipeline=pipeline,
network=network,
epochs=epochs,
traces=traces,
train_steps_per_epoch=train_steps_per_epoch)
Training
estimator.fit()
FastEstimator-Warn: the key 'y' is being pruned since it is unused outside of the Pipeline. To prevent this, you can declare the key as an input of a Trace or TensorOp. FastEstimator-Start: step: 1; logging_interval: 100; num_device: 1;
FastEstimator-Train: step: 1; dloss: 1.4055667; rloss: 0.9982634; FastEstimator-Train: step: 53; epoch: 1; epoch_time: 8.74 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 155.95; Eval Progress: 11/17; steps/sec: 185.13; Eval Progress: 17/17; steps/sec: 172.52; FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/reconstructor_best_f1_score.h5 FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/discriminator_best_f1_score.h5 FastEstimator-Eval: step: 53; epoch: 1; auc_score: 0.4215160395117313; dloss: 1.4071065; f1_score: 0.4431718061674009; max_f1_score: 0.4431718061674009; rloss: 1.0221571; since_best_f1_score: 0; FastEstimator-Train: step: 100; dloss: 0.18782794; rloss: 3.4576418; steps/sec: 12.69; FastEstimator-Train: step: 106; epoch: 2; epoch_time: 5.69 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 158.05; Eval Progress: 11/17; steps/sec: 237.7; Eval Progress: 17/17; steps/sec: 212.97; FastEstimator-Eval: step: 106; epoch: 2; auc_score: 0.0593130082089697; dloss: 1.7725676; f1_score: 0.12945838837516513; max_f1_score: 0.4431718061674009; rloss: 1.739687; since_best_f1_score: 1; FastEstimator-Train: step: 159; epoch: 3; epoch_time: 5.57 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 203.15; Eval Progress: 11/17; steps/sec: 257.38; Eval Progress: 17/17; steps/sec: 267.37; FastEstimator-Eval: step: 159; epoch: 3; auc_score: 0.3312002949795261; dloss: 3.8447435; f1_score: 0.3762114537444934; max_f1_score: 0.4431718061674009; rloss: 3.5073256; since_best_f1_score: 2; FastEstimator-Train: step: 200; dloss: 0.0074149473; rloss: 5.2848363; steps/sec: 8.9; FastEstimator-Train: step: 212; epoch: 4; epoch_time: 5.7 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 133.68; Eval Progress: 11/17; steps/sec: 182.39; Eval Progress: 17/17; steps/sec: 172.01; FastEstimator-Eval: step: 212; epoch: 4; auc_score: 0.12731859729472725; dloss: 10.529196; f1_score: 0.20000000000000004; max_f1_score: 0.4431718061674009; rloss: 6.2233853; since_best_f1_score: 3; FastEstimator-Train: step: 265; epoch: 5; epoch_time: 5.84 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 113.41; Eval Progress: 11/17; steps/sec: 165.91; Eval Progress: 17/17; steps/sec: 268.33; FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/reconstructor_best_f1_score.h5 FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/discriminator_best_f1_score.h5 FastEstimator-Eval: step: 265; epoch: 5; auc_score: 0.9694777697995305; dloss: 16.567978; f1_score: 0.9053280493174812; max_f1_score: 0.9053280493174812; rloss: 9.16089; since_best_f1_score: 0; FastEstimator-Train: step: 300; dloss: 9.0785594e-05; rloss: 9.5743; steps/sec: 8.69; FastEstimator-Train: step: 318; epoch: 6; epoch_time: 5.87 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 153.15; Eval Progress: 11/17; steps/sec: 223.88; Eval Progress: 17/17; steps/sec: 217.8; FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/reconstructor_best_f1_score.h5 FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/discriminator_best_f1_score.h5 FastEstimator-Eval: step: 318; epoch: 6; auc_score: 0.9818909740146323; dloss: 20.46954; f1_score: 0.9263992948435434; max_f1_score: 0.9263992948435434; rloss: 10.067438; since_best_f1_score: 0; FastEstimator-Train: step: 371; epoch: 7; epoch_time: 5.85 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 202.2; Eval Progress: 11/17; steps/sec: 238.39; Eval Progress: 17/17; steps/sec: 224.66; FastEstimator-Eval: step: 371; epoch: 7; auc_score: 
0.6487768052941063; dloss: 17.003485; f1_score: 0.6144366197183099; max_f1_score: 0.9263992948435434; rloss: 13.551922; since_best_f1_score: 1; FastEstimator-Train: step: 400; dloss: 2.1421674e-06; rloss: 13.66077; steps/sec: 8.66; FastEstimator-Train: step: 424; epoch: 8; epoch_time: 5.64 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 139.75; Eval Progress: 11/17; steps/sec: 249.63; Eval Progress: 17/17; steps/sec: 213.55; FastEstimator-Eval: step: 424; epoch: 8; auc_score: 0.5244898988918861; dloss: 26.341074; f1_score: 0.5312775330396475; max_f1_score: 0.9263992948435434; rloss: 14.630283; since_best_f1_score: 2; FastEstimator-Train: step: 477; epoch: 9; epoch_time: 5.71 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 118.19; Eval Progress: 11/17; steps/sec: 188.76; Eval Progress: 17/17; steps/sec: 175.64; FastEstimator-Eval: step: 477; epoch: 9; auc_score: 0.8110687185856508; dloss: 21.545362; f1_score: 0.7548500881834214; max_f1_score: 0.9263992948435434; rloss: 15.638036; since_best_f1_score: 3; FastEstimator-Train: step: 500; dloss: 4.0664025e-05; rloss: 13.955134; steps/sec: 8.85; FastEstimator-Train: step: 530; epoch: 10; epoch_time: 5.67 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 169.3; Eval Progress: 11/17; steps/sec: 207.95; Eval Progress: 17/17; steps/sec: 202.13; FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/reconstructor_best_f1_score.h5 FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/discriminator_best_f1_score.h5 FastEstimator-Eval: step: 530; epoch: 10; auc_score: 0.9777368083991539; dloss: 17.341671; f1_score: 0.9463028169014085; max_f1_score: 0.9463028169014085; rloss: 12.420723; since_best_f1_score: 0; FastEstimator-Train: step: 583; epoch: 11; epoch_time: 5.64 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 128.68; Eval Progress: 11/17; steps/sec: 247.49; Eval Progress: 17/17; steps/sec: 240.88; FastEstimator-Eval: step: 583; epoch: 11; auc_score: 0.9810429078771178; dloss: 12.695894; f1_score: 0.9219920669898635; max_f1_score: 0.9463028169014085; rloss: 6.138731; since_best_f1_score: 1; FastEstimator-Train: step: 600; dloss: 0.1742793; rloss: 2.5519087; steps/sec: 8.71; FastEstimator-Train: step: 636; epoch: 12; epoch_time: 5.91 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 69.28; Eval Progress: 11/17; steps/sec: 187.24; Eval Progress: 17/17; steps/sec: 212.75; FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/reconstructor_best_f1_score.h5 FastEstimator-BestModelSaver: Saved model to /tmp/tmp2w1fltl6/discriminator_best_f1_score.h5 FastEstimator-Eval: step: 636; epoch: 12; auc_score: 0.99651613654447; dloss: 3.250877; f1_score: 0.9788546255506608; max_f1_score: 0.9788546255506608; rloss: 0.61733353; since_best_f1_score: 0; FastEstimator-Train: step: 689; epoch: 13; epoch_time: 5.54 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 170.37; Eval Progress: 11/17; steps/sec: 251.05; Eval Progress: 17/17; steps/sec: 223.27; FastEstimator-Eval: step: 689; epoch: 13; auc_score: 0.9530811775893187; dloss: 3.0862265; f1_score: 0.8974020255394101; max_f1_score: 0.9788546255506608; rloss: 2.3505096; since_best_f1_score: 1; FastEstimator-Train: step: 700; dloss: 0.09519278; rloss: 5.07388; steps/sec: 8.95; FastEstimator-Train: step: 742; epoch: 14; epoch_time: 5.68 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 189.04; Eval Progress: 11/17; steps/sec: 235.52; Eval Progress: 17/17; steps/sec: 271.66; FastEstimator-Eval: step: 742; epoch: 14; auc_score: 
0.9910233072638709; dloss: 2.9711509; f1_score: 0.9616571176729837; max_f1_score: 0.9788546255506608; rloss: 0.5347926; since_best_f1_score: 2; FastEstimator-Train: step: 795; epoch: 15; epoch_time: 5.67 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 153.04; Eval Progress: 11/17; steps/sec: 247.51; Eval Progress: 17/17; steps/sec: 232.68; FastEstimator-Eval: step: 795; epoch: 15; auc_score: 0.9850724834559181; dloss: 1.6044965; f1_score: 0.9494060712714474; max_f1_score: 0.9788546255506608; rloss: 3.59912; since_best_f1_score: 3; FastEstimator-Train: step: 800; dloss: 0.30007884; rloss: 2.662114; steps/sec: 8.83; FastEstimator-Train: step: 848; epoch: 16; epoch_time: 5.9 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 72.03; Eval Progress: 11/17; steps/sec: 156.27; Eval Progress: 17/17; steps/sec: 209.58; FastEstimator-Eval: step: 848; epoch: 16; auc_score: 0.9890345242484814; dloss: 2.4346852; f1_score: 0.9379128137384413; max_f1_score: 0.9788546255506608; rloss: 1.1848502; since_best_f1_score: 4; FastEstimator-Train: step: 900; dloss: 2.2080877; rloss: 0.1052717; steps/sec: 15.36; FastEstimator-Train: step: 901; epoch: 17; epoch_time: 5.79 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 115.86; Eval Progress: 11/17; steps/sec: 181.74; Eval Progress: 17/17; steps/sec: 141.38; FastEstimator-Eval: step: 901; epoch: 17; auc_score: 0.9592586698752159; dloss: 1.7382531; f1_score: 0.8748898678414097; max_f1_score: 0.9788546255506608; rloss: 1.7491591; since_best_f1_score: 5; FastEstimator-Train: step: 954; epoch: 18; epoch_time: 5.87 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 88.9; Eval Progress: 11/17; steps/sec: 133.96; Eval Progress: 17/17; steps/sec: 209.24; FastEstimator-Eval: step: 954; epoch: 18; auc_score: 0.97444080032603; dloss: 2.4483051; f1_score: 0.9131776112825033; max_f1_score: 0.9788546255506608; rloss: 2.4951015; since_best_f1_score: 6; FastEstimator-Train: step: 1000; dloss: 1.7254086; rloss: 0.19289978; steps/sec: 8.57; FastEstimator-Train: step: 1007; epoch: 19; epoch_time: 5.89 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 115.95; Eval Progress: 11/17; steps/sec: 184.74; Eval Progress: 17/17; steps/sec: 213.9; FastEstimator-Eval: step: 1007; epoch: 19; auc_score: 0.9717246599002504; dloss: 1.33784; f1_score: 0.9040492957746478; max_f1_score: 0.9788546255506608; rloss: 0.684746; since_best_f1_score: 7; FastEstimator-Train: step: 1060; epoch: 20; epoch_time: 5.94 sec; Eval Progress: 1/17; Eval Progress: 5/17; steps/sec: 141.37; Eval Progress: 11/17; steps/sec: 212.39; Eval Progress: 17/17; steps/sec: 228.58; FastEstimator-Eval: step: 1060; epoch: 20; auc_score: 0.7852941062314424; dloss: 1.8164662; f1_score: 0.7163652404058226; max_f1_score: 0.9788546255506608; rloss: 1.6960378; since_best_f1_score: 8; FastEstimator-BestModelSaver: Restoring model from /tmp/tmp2w1fltl6/reconstructor_best_f1_score.h5 FastEstimator-BestModelSaver: Restoring model from /tmp/tmp2w1fltl6/discriminator_best_f1_score.h5 FastEstimator-Finish: step: 1060; discriminator_lr: 1e-04; reconstructor_lr: 0.0002; total_time: 212.81 sec;
Inferencing
Once training is finished, we apply the model to visualize the reconstructions of an inlier and an outlier.
idx0 = np.random.randint(len(x_eval0))
idx1 = np.random.randint(len(x_eval1))
data = [{"x": x_eval0[idx0]}, {"x": x_eval1[idx1]}]
result = [pipeline.transform(data[i], mode="infer") for i in range(len(data))]
# An inference-only Network: just the model ops, without the loss and update ops
network = fe.Network(ops=[
    ModelOp(model=recon_model, inputs="x", outputs="x_fake"),
    ModelOp(model=disc_model, inputs="x_fake", outputs="fake_score")
])
output_imgs = [network.transform(result[i], mode="infer") for i in range(len(result))]
base_image = output_imgs[0]["x"].numpy()
anomaly_image = output_imgs[1]["x"].numpy()
recon_base_image = output_imgs[0]["x_fake"].numpy()
recon_anomaly_image = output_imgs[1]["x_fake"].numpy()
GridDisplay([ImageDisplay(image=base_image[0], color_map='greys', title="Input Image"),
ImageDisplay(image=recon_base_image[0], color_map='greys', title="Reconstructed Image")
]).show()
GridDisplay([ImageDisplay(image=anomaly_image[0], color_map='greys', title="Input Image"),
ImageDisplay(image=recon_anomaly_image[0], color_map='greys', title="Reconstructed Image")
]).show()
Note that the network was trained on inliers only, so it is able to properly reconstruct them but does a poor job of reconstructing the outliers, which makes it easier for the discriminator to detect the outliers.
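Since the discriminator's score is already what the evaluation trace consumes, it can also serve directly as an inlier score at inference time (higher means more inlier-like). An illustrative check on the two samples above:
print(float(output_imgs[0]["fake_score"].numpy()))  # inlier "1": expected to score high
print(float(output_imgs[1]["fake_score"].numpy()))  # outlier digit: expected to score low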
Using your own dataset
This example assumes each sample contains a gray-scale image array such as:
{"x": np.ones([28, 28])}
If you would like to read images from disk, you can create a Dataset that produces the file paths, and use a ReadImage Op in the Pipeline to read the images from disk.
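For example, a minimal sketch (the CSV path and column layout here are hypothetical; depending on how your images are stored, you may or may not still need the ExpandDims Op for a channel dimension):
from fastestimator.dataset import CSVDataset
from fastestimator.op.numpyop.univariate import ReadImage

# Hypothetical CSV whose "x" column holds image file paths
train_data = CSVDataset("/path/to/images.csv")
pipeline = fe.Pipeline(
    train_data=train_data,
    batch_size=batch_size,
    ops=[
        ReadImage(inputs="x", outputs="x", color_flag="gray"),
        # ...followed by the same Normalize / LambdaOp ops as in the Pipeline above
        Normalize(inputs="x", outputs="x", mean=1.0, std=1.0, max_pixel_value=127.5)
    ])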
If your images are 3-channel, you can remove the ExpandDims Op and adjust the input shape in the model definitions.