We'll start by getting the imports out of the way:
In [1]:
import tempfile

import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.backend import reduce_mean
from fastestimator.dataset.data import cifair10
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import CoarseDropout, Normalize
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.op.tensorop import LambdaOp
from fastestimator.trace.io import BestModelSaver, ImageViewer
from fastestimator.trace.metric import MCC
from fastestimator.trace.xai import LabelTracker

label_mapping = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9
}
Label Tracking
Suppose you are doing some training and you want to know whether a particular class is easier or harder for your network to learn than the others. One way to investigate this is with the LabelTracker Trace. It takes any per-element metric (such as sample-wise loss) together with any label vector (usually class labels, though it could be any grouping) and produces a visualization at the end of training:
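Before looking at the full training script, here is a minimal standalone TensorFlow sketch (not part of the FastEstimator example itself, using made-up values) illustrating what a "per-element" metric means: the sample-wise loss has one entry per element in the batch, while the usual training loss is a single batch-averaged scalar.

import tensorflow as tf

# Hypothetical batch of 4 samples: true class indices and predicted class probabilities
y_true = tf.constant([3, 5, 3, 8])
y_pred = tf.nn.softmax(tf.random.uniform((4, 10)))

# Sample-wise loss: shape (4,), one value per element -- this is what LabelTracker consumes
sample_ce = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)

# Batch-averaged loss: a single scalar -- this is what the optimizer update uses
ce = tf.reduce_mean(sample_ce)
print(sample_ce.shape, ce.shape)  # (4,) ()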
In [2]:
batch_size = 128
save_dir = tempfile.mkdtemp()

# Load ciFAIR-10 and carve a test split off of the evaluation data
train_data, eval_data = cifair10.load_data()
test_data = eval_data.split(range(len(eval_data) // 2))

pipeline = fe.Pipeline(
    train_data=train_data,
    eval_data=eval_data,
    test_data=test_data,
    batch_size=batch_size,
    ops=[Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
         PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
         RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
         Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
         CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
         ],
    num_process=0)

model = fe.build(model_fn=lambda: LeNet(input_shape=(32, 32, 3)), optimizer_fn="adam")

network = fe.Network(ops=[
    ModelOp(model=model, inputs="x", outputs="y_pred"),
    # Training: the usual batch-averaged loss that drives the optimizer update
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce", mode="train"),
    # Eval/test: keep the per-sample loss so LabelTracker can group it by label
    CrossEntropy(inputs=("y_pred", "y"), outputs="sample_ce", mode=("eval", "test"), average_loss=False),
    # Average the per-sample loss back down so 'ce' is still logged as a scalar during eval/test
    LambdaOp(inputs="sample_ce", outputs="ce", mode=("eval", "test"), fn=lambda x: reduce_mean(x)),
    UpdateOp(model=model, loss_name="ce")
])

traces = [
    MCC(true_key="y", pred_key="y_pred"),
    BestModelSaver(model=model, save_dir=save_dir, metric="mcc", save_best_mode="max", load_best_final=True),
    # Track the per-sample loss ('sample_ce') grouped by true label ('y') over the course of training
    LabelTracker(label="y", metric="sample_ce", label_mapping=label_mapping, outputs="ce_vs_y", bounds=None, mode=["eval", "test"]),
    # Display the resulting figure
    ImageViewer(inputs="ce_vs_y", mode=["eval", "test"])
]

estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=10,
                         traces=traces,
                         log_steps=300)
2022-04-13 16:01:46.649443: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA. To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-04-13 16:01:46.739177: W tensorflow/python/util/util.cc:368] Sets are not currently considered sequences, but this may change in the future, so consider avoiding using them.
In [3]:
estimator.fit()
FastEstimator-Start: step: 1; logging_interval: 300; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.3176317;
FastEstimator-Train: step: 300; ce: 1.5395001; steps/sec: 14.3;
FastEstimator-Train: step: 391; epoch: 1; epoch_time: 29.12 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Eval: step: 391; epoch: 1; ce: 1.2877194; max_mcc: 0.4842492824031186; mcc: 0.4842492824031186; since_best_mcc: 0;
FastEstimator-Train: step: 600; ce: 1.2664933; steps/sec: 13.28;
FastEstimator-Train: step: 782; epoch: 2; epoch_time: 29.89 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Eval: step: 782; epoch: 2; ce: 1.1504955; max_mcc: 0.5389898387019242; mcc: 0.5389898387019242; since_best_mcc: 0;
FastEstimator-Train: step: 900; ce: 1.2066176; steps/sec: 12.84;
FastEstimator-Train: step: 1173; epoch: 3; epoch_time: 28.73 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Eval: step: 1173; epoch: 3; ce: 1.0454248; max_mcc: 0.5840932018667955; mcc: 0.5840932018667955; since_best_mcc: 0;
FastEstimator-Train: step: 1200; ce: 1.1817951; steps/sec: 13.8;
FastEstimator-Train: step: 1500; ce: 1.1124868; steps/sec: 10.77;
FastEstimator-Train: step: 1564; epoch: 4; epoch_time: 35.84 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Eval: step: 1564; epoch: 4; ce: 0.9683806; max_mcc: 0.6100634215605975; mcc: 0.6100634215605975; since_best_mcc: 0;
FastEstimator-Train: step: 1800; ce: 1.2120445; steps/sec: 12.0;
FastEstimator-Train: step: 1955; epoch: 5; epoch_time: 32.19 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Eval: step: 1955; epoch: 5; ce: 0.89093876; max_mcc: 0.6506148663922158; mcc: 0.6506148663922158; since_best_mcc: 0;
FastEstimator-Train: step: 2100; ce: 0.7867976; steps/sec: 11.92;
FastEstimator-Train: step: 2346; epoch: 6; epoch_time: 30.14 sec;
FastEstimator-Eval: step: 2346; epoch: 6; ce: 0.88949853; max_mcc: 0.6506148663922158; mcc: 0.6494923768152842; since_best_mcc: 1;
FastEstimator-Train: step: 2400; ce: 0.8040351; steps/sec: 13.66;
FastEstimator-Train: step: 2700; ce: 1.0257889; steps/sec: 13.92;
FastEstimator-Train: step: 2737; epoch: 7; epoch_time: 28.51 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Eval: step: 2737; epoch: 7; ce: 0.8607704; max_mcc: 0.671126859105468; mcc: 0.671126859105468; since_best_mcc: 0;
FastEstimator-Train: step: 3000; ce: 0.9415474; steps/sec: 13.78;
FastEstimator-Train: step: 3128; epoch: 8; epoch_time: 27.86 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Eval: step: 3128; epoch: 8; ce: 0.79306936; max_mcc: 0.6862720645156658; mcc: 0.6862720645156658; since_best_mcc: 0;
FastEstimator-Train: step: 3300; ce: 0.8957753; steps/sec: 14.19;
FastEstimator-Train: step: 3519; epoch: 9; epoch_time: 27.66 sec;
FastEstimator-Eval: step: 3519; epoch: 9; ce: 0.78325427; max_mcc: 0.6862720645156658; mcc: 0.6832005047098153; since_best_mcc: 1;
FastEstimator-Train: step: 3600; ce: 0.8727247; steps/sec: 13.87;
FastEstimator-Train: step: 3900; ce: 0.89711684; steps/sec: 11.83;
FastEstimator-Train: step: 3910; epoch: 10; epoch_time: 32.38 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Eval: step: 3910; epoch: 10; ce: 0.79790723; max_mcc: 0.6892349113253013; mcc: 0.6892349113253013; since_best_mcc: 0;
FastEstimator-BestModelSaver: Restoring model from /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmph1ky2kgz/model_best_mcc.h5
FastEstimator-Finish: step: 3910; model_lr: 0.001; total_time: 317.37 sec;
From the 'ce_vs_y' graph above, it seems that cats are relatively difficult for the network to learn, whereas automobiles are fairly easy.
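Since the LabelTracker and ImageViewer traces were also registered for "test" mode, the same per-class loss visualization can be generated for the held-out test split. A minimal follow-up would be the line below (the exact plot and metrics will vary from run to run):

estimator.test()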