We'll start by getting the imports out of the way:
import tempfile
import tensorflow as tf
import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.backend import reduce_mean
from fastestimator.dataset.data import cifair10
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import CoarseDropout, Normalize
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.op.tensorop import LambdaOp
from fastestimator.trace.io import BestModelSaver, ImageViewer
from fastestimator.trace.metric import MCC
from fastestimator.trace.xai import InstanceTracker
from fastestimator.util import to_number, GridDisplay, ImageDisplay
import numpy as np
label_mapping = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9
}
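The label_mapping dictionary maps CIFAR-10 class names to the integer labels stored in the dataset. To turn a model output back into a readable name, you need the reverse lookup. Below is a minimal NumPy sketch of that inversion and of decoding a batch of class scores with argmax. The decode helper and the score values are hypothetical and exist only for illustration.

```python
import numpy as np

label_mapping = {'airplane': 0, 'automobile': 1, 'bird': 2, 'cat': 3, 'deer': 4,
                 'dog': 5, 'frog': 6, 'horse': 7, 'ship': 8, 'truck': 9}

# Invert name -> id into id -> name (ids are unique, so nothing is lost)
class_map = {v: k for k, v in label_mapping.items()}


def decode(scores):
    """Hypothetical helper: map a (batch, 10) array of class scores to class names."""
    return [class_map[int(i)] for i in np.argmax(scores, axis=-1)]


# Made-up scores for two samples
scores = np.zeros((2, 10))
scores[0, 1] = 1.0  # highest score on 'automobile'
scores[1, 8] = 1.0  # highest score on 'ship'
print(decode(scores))  # ['automobile', 'ship']
```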
Instance Tracking
Suppose you are doing some training and you want to know which samples in your dataset are the most difficult to learn (perhaps because they were mislabeled). Let's also suppose you're very curious about how well samples 10 and 18 from your training data do over time. One way to investigate this is with the InstanceTracker Trace. It takes as input any per-element metric (such as sample-wise loss) along with an index vector, and produces a visualization at the end of training.
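Before the full training setup, it may help to see the bookkeeping such a tracker needs, in miniature. The NumPy sketch below is hypothetical and is not FastEstimator's implementation: sample_cross_entropy computes one loss per sample (the analogue of CrossEntropy with average_loss=False, assuming the model outputs class probabilities), and ToyInstanceTracker records that value per sample index each epoch, then reports the n_max_to_keep highest, the n_min_to_keep lowest, and everything in list_to_keep. The parameter names mirror the InstanceTracker arguments used below, but ranking by the final epoch's value is an assumption of this sketch. The mean of the per-sample losses is what would drive the weight update, which is the role of the LambdaOp in the real network.

```python
import numpy as np
from collections import defaultdict


def sample_cross_entropy(probs, labels, eps=1e-7):
    """One cross-entropy value per sample (no averaging), given class probabilities."""
    probs = np.clip(probs, eps, 1.0)
    return -np.log(probs[np.arange(len(labels)), labels])


class ToyInstanceTracker:
    """Hypothetical stand-in for instance-tracking bookkeeping (not FastEstimator code)."""

    def __init__(self, n_max_to_keep=4, n_min_to_keep=0, list_to_keep=()):
        self.n_max = n_max_to_keep
        self.n_min = n_min_to_keep
        self.always = set(list_to_keep)
        self.history = defaultdict(list)  # index -> [(epoch, metric), ...]

    def update(self, epoch, indices, metrics):
        # Store one metric value per sample index for this epoch
        for idx, value in zip(indices, metrics):
            self.history[int(idx)].append((epoch, float(value)))

    def selected(self):
        # Rank by the most recent metric value (an assumption; the real Trace may differ)
        latest = {idx: vals[-1][1] for idx, vals in self.history.items()}
        ranked = sorted(latest, key=latest.get)  # ascending: easiest first
        hardest = ranked[::-1][:self.n_max]
        easiest = ranked[:self.n_min]
        return set(hardest) | set(easiest) | (self.always & set(latest))


# Two toy "epochs" over four samples with indices 0..3
labels = np.array([0, 1, 1, 0])
epoch_probs = [
    np.array([[0.6, 0.4], [0.5, 0.5], [0.3, 0.7], [0.5, 0.5]]),
    np.array([[0.9, 0.1], [0.6, 0.4], [0.2, 0.8], [0.3, 0.7]]),
]
tracker = ToyInstanceTracker(n_max_to_keep=1, n_min_to_keep=1, list_to_keep=[2])
for epoch, probs in enumerate(epoch_probs, start=1):
    sample_loss = sample_cross_entropy(probs, labels)
    tracker.update(epoch, indices=np.arange(4), metrics=sample_loss)

print(sorted(tracker.selected()))  # [0, 2, 3]
```

Sample 3 is kept as the hardest, sample 0 as the easiest, and sample 2 because it was explicitly requested. The full FastEstimator setup follows.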
batch_size = 128
save_dir = tempfile.mkdtemp()

train_data, eval_data = cifair10.load_data()
# Split half of the evaluation data off into a separate test set
test_data = eval_data.split(range(len(eval_data) // 2))
# Give every training sample an integer index so that InstanceTracker can tell samples apart
train_data['index'] = np.arange(len(train_data), dtype=np.int32).reshape((len(train_data), 1))

pipeline = fe.Pipeline(
    train_data=train_data,
    eval_data=eval_data,
    test_data=test_data,
    batch_size=batch_size,
    ops=[Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
         PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
         RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
         Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
         CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
         ],
    num_process=0)

model = fe.build(model_fn=lambda: LeNet(input_shape=(32, 32, 3)), optimizer_fn="adam")

network = fe.Network(ops=[
    ModelOp(model=model, inputs="x", outputs="y_pred"),
    # Keep one loss value per sample (average_loss=False) so it can be tracked per instance...
    CrossEntropy(inputs=("y_pred", "y"), outputs="sample_ce", average_loss=False),
    # ...and average it separately to get the loss used for the weight update
    LambdaOp(inputs="sample_ce", outputs="ce", fn=lambda x: reduce_mean(x)),
    UpdateOp(model=model, loss_name="ce")
])

traces = [
    MCC(true_key="y", pred_key="y_pred"),
    BestModelSaver(model=model, save_dir=save_dir, metric="mcc", save_best_mode="max", load_best_final=True),
    # Track the 4 samples with the highest sample_ce, plus samples 10, 18, and 10380
    InstanceTracker(index="index", metric="sample_ce", n_max_to_keep=4, n_min_to_keep=0, list_to_keep=[10, 18, 10380], outputs="ce_vs_idx", mode="train"),
    # Display the tracker's output graph
    ImageViewer(inputs="ce_vs_idx", mode="train")
]

estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=10,
                         traces=traces,
                         log_steps=300)
estimator.fit()
FastEstimator-Start: step: 1; logging_interval: 300; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.2912202;
FastEstimator-Train: step: 300; ce: 1.5846689; steps/sec: 13.43;
FastEstimator-Train: step: 391; epoch: 1; epoch_time: 31.35 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 30.96;
Eval Progress: 26/39; steps/sec: 33.3;
Eval Progress: 39/39; steps/sec: 29.91;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 391; epoch: 1; ce: 1.3090442; max_mcc: 0.47626476759536035; mcc: 0.47626476759536035; since_best_mcc: 0;
FastEstimator-Train: step: 600; ce: 1.2921344; steps/sec: 11.57;
FastEstimator-Train: step: 782; epoch: 2; epoch_time: 36.13 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 24.74;
Eval Progress: 26/39; steps/sec: 24.1;
Eval Progress: 39/39; steps/sec: 23.63;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 782; epoch: 2; ce: 1.115688; max_mcc: 0.5713074370763542; mcc: 0.5713074370763542; since_best_mcc: 0;
FastEstimator-Train: step: 900; ce: 1.114191; steps/sec: 10.21;
FastEstimator-Train: step: 1173; epoch: 3; epoch_time: 37.56 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 23.43;
Eval Progress: 26/39; steps/sec: 22.79;
Eval Progress: 39/39; steps/sec: 24.15;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 1173; epoch: 3; ce: 1.0538898; max_mcc: 0.5788361476143156; mcc: 0.5788361476143156; since_best_mcc: 0;
FastEstimator-Train: step: 1200; ce: 1.3872262; steps/sec: 10.4;
FastEstimator-Train: step: 1500; ce: 1.2397084; steps/sec: 9.96;
FastEstimator-Train: step: 1564; epoch: 4; epoch_time: 39.27 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 23.4;
Eval Progress: 26/39; steps/sec: 25.77;
Eval Progress: 39/39; steps/sec: 24.88;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 1564; epoch: 4; ce: 0.9873681; max_mcc: 0.6094672773776373; mcc: 0.6094672773776373; since_best_mcc: 0;
FastEstimator-Train: step: 1800; ce: 1.0803001; steps/sec: 9.81;
FastEstimator-Train: step: 1955; epoch: 5; epoch_time: 39.35 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 24.01;
Eval Progress: 26/39; steps/sec: 25.06;
Eval Progress: 39/39; steps/sec: 24.64;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 1955; epoch: 5; ce: 0.985664; max_mcc: 0.6117049529662967; mcc: 0.6117049529662967; since_best_mcc: 0;
FastEstimator-Train: step: 2100; ce: 1.0776561; steps/sec: 10.5;
FastEstimator-Train: step: 2346; epoch: 6; epoch_time: 36.1 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 25.68;
Eval Progress: 26/39; steps/sec: 25.82;
Eval Progress: 39/39; steps/sec: 24.76;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 2346; epoch: 6; ce: 0.91793406; max_mcc: 0.6461908231587581; mcc: 0.6461908231587581; since_best_mcc: 0;
FastEstimator-Train: step: 2400; ce: 1.0986989; steps/sec: 10.94;
FastEstimator-Train: step: 2700; ce: 0.9455121; steps/sec: 11.03;
FastEstimator-Train: step: 2737; epoch: 7; epoch_time: 35.52 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 24.21;
Eval Progress: 26/39; steps/sec: 25.1;
Eval Progress: 39/39; steps/sec: 23.46;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 2737; epoch: 7; ce: 0.85772955; max_mcc: 0.6636882958141539; mcc: 0.6636882958141539; since_best_mcc: 0;
FastEstimator-Train: step: 3000; ce: 0.9812018; steps/sec: 10.88;
FastEstimator-Train: step: 3128; epoch: 8; epoch_time: 35.98 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 24.45;
Eval Progress: 26/39; steps/sec: 26.3;
Eval Progress: 39/39; steps/sec: 26.25;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 3128; epoch: 8; ce: 0.8402723; max_mcc: 0.6733243035167943; mcc: 0.6733243035167943; since_best_mcc: 0;
FastEstimator-Train: step: 3300; ce: 1.0235794; steps/sec: 10.95;
FastEstimator-Train: step: 3519; epoch: 9; epoch_time: 36.29 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 24.68;
Eval Progress: 26/39; steps/sec: 26.31;
Eval Progress: 39/39; steps/sec: 25.95;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Eval: step: 3519; epoch: 9; ce: 0.8024753; max_mcc: 0.6875346994702772; mcc: 0.6875346994702772; since_best_mcc: 0;
FastEstimator-Train: step: 3600; ce: 1.112258; steps/sec: 10.68;
FastEstimator-Train: step: 3900; ce: 0.9471506; steps/sec: 10.55;
FastEstimator-Train: step: 3910; epoch: 10; epoch_time: 36.79 sec;
Eval Progress: 1/39;
Eval Progress: 13/39; steps/sec: 24.96;
Eval Progress: 26/39; steps/sec: 25.01;
Eval Progress: 39/39; steps/sec: 25.77;
FastEstimator-Eval: step: 3910; epoch: 10; ce: 0.82467973; max_mcc: 0.6875346994702772; mcc: 0.6747237819814983; since_best_mcc: 1;
FastEstimator-BestModelSaver: Restoring model from /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpsp9wc_hw/model_best_mcc.h5
FastEstimator-Finish: step: 3910; model_lr: 0.001; total_time: 385.42 sec;
From the graph above, it looks like data point 10 is pretty easy, whereas 18 is somewhat difficult. Performance on some of the hardest points actually seems to get worse over time (their loss goes up), so it may be worth visualizing them to see why the network is having a hard time. Let's take a look at 10380, for example:
data_idx = 10380
# Invert label_mapping so integer labels can be turned back into class names
class_map = {v: k for k, v in label_mapping.items()}
true_key = np.array([class_map[train_data[data_idx]['y'].item()]])
# Run the sample through the pipeline in test mode (so the train-only augmentations are skipped)
# and take the model's most likely class
data = pipeline.transform(train_data[data_idx], mode='test', target_type='tf')
y_pred = tf.argmax(model(data['x']), axis=-1).numpy().item()
pred_key = np.array([class_map[y_pred]])
fig = GridDisplay([ImageDisplay(text=true_key, title="y"),
                   ImageDisplay(text=pred_key, title="y_pred"),
                   ImageDisplay(image=train_data[data_idx]["x"], title="x")
                   ])
fig.show()
So we've got a (sideways) image of a car, but the network is probably keying on the blue/green tint of the image and deciding that it is a ship. It might also be confused by the rotation of the image. If you're trying to expand your dataset, this kind of example can tell you what sorts of images you might need to collect to get a more robust network. You could also try hue-shift and rotation data augmentation to correct for this.
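To make the augmentation suggestion more concrete, here is a minimal NumPy sketch of the two transforms, written as plain functions rather than FastEstimator ops. Everything in it is an assumption for illustration. hue_shift approximates a hue change by rotating the chroma (I, Q) channels of the YIQ color space rather than doing an exact HSV conversion. Rotations are restricted to multiples of 90 degrees, which matches the sideways car. Images are assumed to be float arrays in [0, 1], so raw uint8 CIFAR images would first need dividing by 255. random_augment and its 30-degree default are hypothetical choices.

```python
import numpy as np

# Standard RGB -> YIQ matrix; rows give Y (luma), I and Q (chroma)
RGB_TO_YIQ = np.array([[0.299, 0.587, 0.114],
                       [0.596, -0.274, -0.322],
                       [0.211, -0.523, 0.312]])
YIQ_TO_RGB = np.linalg.inv(RGB_TO_YIQ)


def hue_shift(image, degrees):
    """Approximate hue rotation of an (H, W, 3) float image in [0, 1].

    Rotates the (I, Q) chroma plane by `degrees` and leaves luma (Y) unchanged.
    """
    theta = np.deg2rad(degrees)
    rot = np.array([[1.0, 0.0, 0.0],
                    [0.0, np.cos(theta), -np.sin(theta)],
                    [0.0, np.sin(theta), np.cos(theta)]])
    transform = YIQ_TO_RGB @ rot @ RGB_TO_YIQ
    # Apply the 3x3 color transform to every pixel, then clip back into range
    return np.clip(image @ transform.T, 0.0, 1.0)


def random_augment(image, rng, max_hue_degrees=30.0):
    """Hypothetical augmentation: random 90-degree rotation plus a random hue shift."""
    k = rng.integers(0, 4)  # number of 90-degree turns (0-3); square images keep their shape
    rotated = np.rot90(image, k=k, axes=(0, 1))
    return hue_shift(rotated, rng.uniform(-max_hue_degrees, max_hue_degrees))


rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))  # stand-in for a normalized-to-[0, 1] CIFAR image
augmented = random_augment(image, rng)
print(augmented.shape)  # (32, 32, 3)
```

Gray pixels carry no chroma, so they pass through hue_shift unchanged, which is a quick sanity check on the color math. In a real pipeline you would more likely use an existing augmentation op and apply it only in train mode, like the other augmentations above; the sketch mainly shows what such an op does to a single image.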