We'll start by getting the imports out of the way:
import os
import tempfile

import numpy as np

import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.backend import squeeze
from fastestimator.dataset.data import cifair10
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import CoarseDropout, Normalize
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.schedule import RepeatScheduler
from fastestimator.trace.io import BestModelSaver, ImageViewer
from fastestimator.trace.metric import MCC
from fastestimator.trace.xai import Saliency
from fastestimator.util import to_number, BatchDisplay, GridDisplay
label_mapping = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9
}
Saliency Maps
Suppose you have a neural network that is performing image classification. The network tells you that the image it is looking at is an airplane, but you want to know whether it is really detecting an airplane, or whether it is 'cheating' by noticing the blue sky in the image background. To answer this question, all you need to do is add the Saliency Trace to your list of traces, and pass its output to one of the ImageSaver, ImageViewer, or TensorBoard Traces.
batch_size = 128
save_dir = tempfile.mkdtemp()
train_data, eval_data = cifair10.load_data()
test_data = eval_data.split(range(len(eval_data) // 2))
pipeline = fe.Pipeline(
train_data=train_data,
eval_data=eval_data,
test_data=test_data,
batch_size=batch_size,
ops=[Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
],
num_process=0)
model = fe.build(model_fn=lambda: LeNet(input_shape=(32, 32, 3)), optimizer_fn="adam")
network = fe.Network(ops=[
ModelOp(model=model, inputs="x", outputs="y_pred"),
CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
UpdateOp(model=model, loss_name="ce")
])
traces = [
MCC(true_key="y", pred_key="y_pred"),
BestModelSaver(model=model, save_dir=save_dir, metric="mcc", save_best_mode="max", load_best_final=True),
RepeatScheduler([Saliency(model=model,
model_inputs="x",
class_key="y",
model_outputs="y_pred",
samples=5,
label_mapping=label_mapping),
None, None, None, None]), # Only compute Saliency every 5 epochs for cleaner logs
RepeatScheduler([ImageViewer(inputs="saliency"), None, None, None, None]) # Only display the images every 5 epochs for cleaner logs
]
estimator = fe.Estimator(pipeline=pipeline,
network=network,
epochs=21,
traces=traces,
log_steps=300)
In this example we will be using the ImageViewer Trace, since it allows us to visualize the outputs within this Notebook. If you want your images to appear in TensorBoard instead, simply construct a TensorBoard Trace with its "write_images" argument set to "saliency".
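For reference, here is a minimal sketch of that alternative (the log_dir value is just a placeholder):
from fastestimator.trace.io import TensorBoard
# Hypothetical stand-in for the ImageViewer entry in the traces list above;
# this writes the saliency images to TensorBoard instead of rendering them inline:
tb_trace = TensorBoard(log_dir="tf_logs", write_images="saliency")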
estimator.fit()
    ______           __  ______     __  _                 __
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/

FastEstimator-Start: step: 1; logging_interval: 300; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.388823;
FastEstimator-Train: step: 300; ce: 1.3864682; steps/sec: 14.24;
FastEstimator-Train: step: 391; epoch: 1; epoch_time: 30.73 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 391; epoch: 1; ce: 1.3359833; max_mcc: 0.4555820671855027; mcc: 0.4555820671855027; since_best_mcc: 0;
FastEstimator-Train: step: 600; ce: 1.3321579; steps/sec: 13.15;
FastEstimator-Train: step: 782; epoch: 2; epoch_time: 27.53 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 782; epoch: 2; ce: 1.1216551; max_mcc: 0.549275060620413; mcc: 0.549275060620413; since_best_mcc: 0;
FastEstimator-Train: step: 900; ce: 1.1421483; steps/sec: 14.22;
FastEstimator-Train: step: 1173; epoch: 3; epoch_time: 27.34 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 1173; epoch: 3; ce: 1.0429373; max_mcc: 0.5941480962039674; mcc: 0.5941480962039674; since_best_mcc: 0;
FastEstimator-Train: step: 1200; ce: 1.0719717; steps/sec: 14.11;
FastEstimator-Train: step: 1500; ce: 1.1470697; steps/sec: 11.58;
FastEstimator-Train: step: 1564; epoch: 4; epoch_time: 35.64 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 1564; epoch: 4; ce: 0.94491786; max_mcc: 0.6282373140880735; mcc: 0.6282373140880735; since_best_mcc: 0;
FastEstimator-Train: step: 1800; ce: 1.0663648; steps/sec: 8.0;
FastEstimator-Train: step: 1955; epoch: 5; epoch_time: 49.97 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 1955; epoch: 5; ce: 0.93076646; max_mcc: 0.6366079684298591; mcc: 0.6366079684298591; since_best_mcc: 0;
FastEstimator-Train: step: 2100; ce: 0.9752668; steps/sec: 7.75;
FastEstimator-Train: step: 2346; epoch: 6; epoch_time: 45.59 sec;
FastEstimator-Eval: step: 2346; epoch: 6; ce: 0.94940007; max_mcc: 0.6366079684298591; mcc: 0.6282007111891783; since_best_mcc: 1;
FastEstimator-Train: step: 2400; ce: 0.8638771; steps/sec: 9.27;
FastEstimator-Train: step: 2700; ce: 0.9501883; steps/sec: 8.57;
FastEstimator-Train: step: 2737; epoch: 7; epoch_time: 44.78 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 2737; epoch: 7; ce: 0.86018026; max_mcc: 0.6611946478733883; mcc: 0.6611946478733883; since_best_mcc: 0;
FastEstimator-Train: step: 3000; ce: 1.0164135; steps/sec: 9.76;
FastEstimator-Train: step: 3128; epoch: 8; epoch_time: 37.88 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 3128; epoch: 8; ce: 0.82946527; max_mcc: 0.6749589983887951; mcc: 0.6749589983887951; since_best_mcc: 0;
FastEstimator-Train: step: 3300; ce: 1.04843; steps/sec: 11.26;
FastEstimator-Train: step: 3519; epoch: 9; epoch_time: 37.24 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 3519; epoch: 9; ce: 0.8264248; max_mcc: 0.6798102916434279; mcc: 0.6798102916434279; since_best_mcc: 0;
FastEstimator-Train: step: 3600; ce: 0.9739647; steps/sec: 10.03;
FastEstimator-Train: step: 3900; ce: 0.8769522; steps/sec: 10.1;
FastEstimator-Train: step: 3910; epoch: 10; epoch_time: 38.44 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 3910; epoch: 10; ce: 0.7981984; max_mcc: 0.6902392572292428; mcc: 0.6902392572292428; since_best_mcc: 0;
FastEstimator-Train: step: 4200; ce: 0.918536; steps/sec: 10.6;
FastEstimator-Train: step: 4301; epoch: 11; epoch_time: 36.43 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 4301; epoch: 11; ce: 0.76242864; max_mcc: 0.6969021893369869; mcc: 0.6969021893369869; since_best_mcc: 0;
FastEstimator-Train: step: 4500; ce: 0.8292513; steps/sec: 12.17;
FastEstimator-Train: step: 4692; epoch: 12; epoch_time: 31.82 sec;
FastEstimator-Eval: step: 4692; epoch: 12; ce: 0.7939056; max_mcc: 0.6969021893369869; mcc: 0.6816917188407557; since_best_mcc: 1;
FastEstimator-Train: step: 4800; ce: 0.932951; steps/sec: 12.07;
FastEstimator-Train: step: 5083; epoch: 13; epoch_time: 34.38 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 5083; epoch: 13; ce: 0.7424109; max_mcc: 0.7060542503045709; mcc: 0.7060542503045709; since_best_mcc: 0;
FastEstimator-Train: step: 5100; ce: 0.72088; steps/sec: 10.88;
FastEstimator-Train: step: 5400; ce: 0.91354585; steps/sec: 10.09;
FastEstimator-Train: step: 5474; epoch: 14; epoch_time: 39.62 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 5474; epoch: 14; ce: 0.7411622; max_mcc: 0.7093458570160348; mcc: 0.7093458570160348; since_best_mcc: 0;
FastEstimator-Train: step: 5700; ce: 0.8136387; steps/sec: 9.16;
FastEstimator-Train: step: 5865; epoch: 15; epoch_time: 42.75 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 5865; epoch: 15; ce: 0.7313029; max_mcc: 0.7149928925831751; mcc: 0.7149928925831751; since_best_mcc: 0;
FastEstimator-Train: step: 6000; ce: 0.90510035; steps/sec: 8.93;
FastEstimator-Train: step: 6256; epoch: 16; epoch_time: 44.4 sec;
FastEstimator-Eval: step: 6256; epoch: 16; ce: 0.73286784; max_mcc: 0.7149928925831751; mcc: 0.7099151095875645; since_best_mcc: 1;
FastEstimator-Train: step: 6300; ce: 0.67238724; steps/sec: 9.33;
FastEstimator-Train: step: 6600; ce: 0.8542966; steps/sec: 11.31;
FastEstimator-Train: step: 6647; epoch: 17; epoch_time: 33.8 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 6647; epoch: 17; ce: 0.7084915; max_mcc: 0.7236824568143699; mcc: 0.7236824568143699; since_best_mcc: 0;
FastEstimator-Train: step: 6900; ce: 0.6791389; steps/sec: 11.57;
FastEstimator-Train: step: 7038; epoch: 18; epoch_time: 34.54 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Eval: step: 7038; epoch: 18; ce: 0.69131714; max_mcc: 0.7360980191985622; mcc: 0.7360980191985622; since_best_mcc: 0;
FastEstimator-Train: step: 7200; ce: 0.9284673; steps/sec: 11.07;
FastEstimator-Train: step: 7429; epoch: 19; epoch_time: 37.6 sec;
FastEstimator-Eval: step: 7429; epoch: 19; ce: 0.71362627; max_mcc: 0.7360980191985622; mcc: 0.7194020040317212; since_best_mcc: 1;
FastEstimator-Train: step: 7500; ce: 0.72266275; steps/sec: 10.25;
FastEstimator-Train: step: 7800; ce: 0.8315146; steps/sec: 9.93;
FastEstimator-Train: step: 7820; epoch: 20; epoch_time: 38.39 sec;
FastEstimator-Eval: step: 7820; epoch: 20; ce: 0.71525294; max_mcc: 0.7360980191985622; mcc: 0.7239562712561208; since_best_mcc: 2;
FastEstimator-Train: step: 8100; ce: 0.70906603; steps/sec: 9.13;
FastEstimator-Train: step: 8211; epoch: 21; epoch_time: 40.53 sec;
FastEstimator-Eval: step: 8211; epoch: 21; ce: 0.7312853; max_mcc: 0.7360980191985622; mcc: 0.7147634884721064; since_best_mcc: 3;
FastEstimator-BestModelSaver: Restoring model from /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmppqk4i0uq/model_best_mcc.h5
FastEstimator-Finish: step: 8211; model_lr: 0.001; total_time: 910.36 sec;
estimator.test()
FastEstimator-Test: step: 8211; epoch: 21; ce: 0.7301065; mcc: 0.7327857485184789;
In the images above, the 'saliency' column corresponds to a raw saliency mask generated by back-propagating a model's output prediction onto the input image. 'Smoothed saliency' combines multiple saliency masks for each image 'x', where each mask is generated by slightly perturbing the input 'x' before running the forward and backward gradient passes. The number of samples to be combined is controlled by the "smoothing" argument in the Saliency Trace constructor. 'Integrated saliency' is a saliency mask generated by starting from a baseline noise image and linearly interpolating it towards 'x' over a number of steps defined by the "integrating" argument in the Saliency constructor; the resulting masks are then combined. The 'SmInt Saliency' (Smoothed-Integrated) column combines smoothing and integration. SmInt is generally considered to give the most reliable indication of the important features in an image, but it also takes the longest to compute. It is possible to disable the more complex columns by setting the 'smoothing' and 'integrating' parameters to 0. The 'x saliency' column shows the input image overlaid with whichever saliency column is furthest to the right (SmInt, unless that has been disabled).
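As a concrete sketch, here is what a cheaper Saliency Trace producing only the raw mask might look like (same model and keys as above, with all other arguments unchanged):
# Setting 'smoothing' and 'integrating' to 0 disables the smoothed,
# integrated, and SmInt columns, leaving only the raw saliency mask:
fast_saliency = Saliency(model=model,
                         model_inputs="x",
                         class_key="y",
                         model_outputs="y_pred",
                         samples=5,
                         label_mapping=label_mapping,
                         smoothing=0,
                         integrating=0)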
Saliency Maps without Traces
Suppose that you want to generate Saliency masks without using a Trace. This can be done through the fe.xai package:
pipeline.batch_size = 6
batch = pipeline.get_results(num_steps=2, mode='eval')[1]  # Get some images we haven't already seen above
batch = fe.backend.to_tensor(batch, "tf") # Convert the batch to TensorFlow
saliency_generator = fe.xai.SaliencyNet(model=model, model_inputs="x", model_outputs="y_pred")
images = saliency_generator.get_smoothed_masks(batch=batch, nsamples=6, nintegration=100)
# Let's convert 'y' and 'y_pred' from numeric values to strings for readability:
val_to_label = {val: key for key, val in label_mapping.items()}
y = np.array([val_to_label[clazz] for clazz in to_number(squeeze(batch["y"]))])
y_pred = np.array([val_to_label[clazz] for clazz in to_number(squeeze(images["y_pred"]))])
# Now simply load up a Display object and let it handle laying out the final result for you
save_dir = tempfile.mkdtemp()
fig = GridDisplay([BatchDisplay(text=y, title='y'),
BatchDisplay(text=y_pred, title='y_pred'),
BatchDisplay(image=batch['x'], title='x'),
BatchDisplay(image=images['saliency'], title='saliency', color_map='inferno')
])
fig.show(save_path=os.path.join(save_dir, "t08a_saliency.png")) # save_path is optional, but a useful feature to know about
fig.show()
Saving to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpu8vwrtta/t08a_saliency.png
The SaliencyNet class also provides 'get_masks' and 'get_integrated_masks' methods for generating other versions of saliency masks. For a detailed overview of various saliency algorithms and their benefits/drawbacks, see https://distill.pub/2020/attribution-baselines/.
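For instance, a minimal sketch using the same saliency_generator and batch from above (note that the 'nsamples' argument name for get_integrated_masks is an assumption):
raw_masks = saliency_generator.get_masks(batch=batch)  # one forward/backward pass per image
integrated_masks = saliency_generator.get_integrated_masks(batch=batch, nsamples=100)  # 'nsamples' is an assumed argument name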