Multi-Dataset Concept and API
When training a deep learning model, you may want to train or evaluate on multiple datasets. For example, you might want evaluation metrics reported separately for each of two datasets. This section shows how to do that conveniently in FastEstimator.
Adding multiple datasets to a Pipeline
If you have multiple datasets, you can simply provide a dictionary to the train_data, eval_data, and/or test_data arguments of Pipeline, with the keys being the dataset names and the values being the dataset instances.
For example:
pipeline = fe.Pipeline(eval_data={"ds1": eval_data1, "ds2": eval_data2}, ...)
In the above example, ds1 and ds2 are the names of the datasets. These names are arbitrary; any other names would work just as well.
Dataset-specific Ops
Sometimes different datasets require their own NumpyOps or TensorOps. For example, when we train a model on gray-scale images and have both gray-scale and colored evaluation sets, the gray-scale conversion only needs to be applied to the colored dataset.
In FastEstimator, to indicate that an Op should only be applied to a specific dataset (say ds1), you only need to pass that dataset's name as ds_id:
Op(..., ds_id="ds1") # run the op on ds1
ds_id works similarly to the mode argument of Ops: the Op only executes if its ds_id matches the current dataset. If ds_id is None (the default), the Op executes on all datasets.
The ds_id argument works in conjunction with mode. For example, Op(mode="train", ds_id="myds1") would only run during training, and only for the dataset named "myds1".
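The combined rule can be modeled as a simple predicate. The sketch below is a simplified model that assumes each argument is either a single string or None. should_execute is a hypothetical helper, not a FastEstimator function, and the real mode and ds_id arguments accept more forms than the ones shown here.

```python
from typing import Optional


def should_execute(op_mode: Optional[str], op_ds_id: Optional[str],
                   current_mode: str, current_ds_id: str) -> bool:
    """Hypothetical helper (not a FastEstimator function) modeling the gating rule.

    None means "no restriction", mirroring the default of both arguments.
    """
    mode_ok = op_mode is None or op_mode == current_mode
    ds_ok = op_ds_id is None or op_ds_id == current_ds_id
    # Both conditions must hold for the Op to run
    return mode_ok and ds_ok


# Op(mode="train", ds_id="myds1") from the text:
print(should_execute("train", "myds1", "train", "myds1"))  # True
print(should_execute("train", "myds1", "eval", "myds1"))   # False: wrong mode
print(should_execute("train", "myds1", "train", "myds2"))  # False: wrong dataset
print(should_execute(None, None, "eval", "myds2"))         # True: no restrictions
```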
Dataset-specific Traces
To only execute a Trace for a specific dataset, simply pass:
Trace(..., ds_id="ds1") # run the trace on ds1
When using multiple datasets, the built-in FastEstimator metric traces will automatically report their outputs for each dataset individually, as well as the overall metric aggregated over all datasets. We will demonstrate this behavior in detail in the example section.
Specifying multiple datasets in Ops or Traces
When an Op or Trace needs to execute on multiple datasets, simply provide a list, tuple, or set of dataset names to the ds_id argument.
For example:
Op(..., ds_id=["ds1", "ds2"]) # run on both ds1 and ds2
When there are so many datasets that listing every name becomes impractical, you can prefix a dataset name with ! to mean "all datasets except this one".
For example:
Op(..., ds_id=["!ds1", "!ds2"]) # run on all datasets except ds1 and ds2
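The following sketch resolves a ds_id value into the set of datasets an Op or Trace would run on. It covers None, a single name, a collection of names, and the ! exclusion prefix. resolve_ds_ids is a hypothetical helper, not FastEstimator code. This sketch also rejects lists that mix plain and ! names, because the text above does not describe how that case behaves.

```python
from typing import Iterable, Set, Union


def resolve_ds_ids(ds_id: Union[None, str, Iterable[str]],
                   all_ds_ids: Iterable[str]) -> Set[str]:
    """Hypothetical helper: which datasets would an Op/Trace with this ds_id run on?"""
    available = set(all_ds_ids)
    if ds_id is None:
        # Default: run on every dataset
        return available
    if isinstance(ds_id, str):
        ds_id = [ds_id]
    names = list(ds_id)
    excluded = {name[1:] for name in names if name.startswith("!")}
    included = {name for name in names if not name.startswith("!")}
    if excluded and included:
        # This sketch refuses to guess how mixed lists should behave
        raise ValueError("mixing '!name' and plain names is not handled by this sketch")
    if excluded:
        # "!ds1" means "every dataset except ds1"
        return available - excluded
    return available & included


datasets = ["ds1", "ds2", "ds3", "ds4"]
print(sorted(resolve_ds_ids(None, datasets)))              # ['ds1', 'ds2', 'ds3', 'ds4']
print(sorted(resolve_ds_ids(["ds1", "ds2"], datasets)))    # ['ds1', 'ds2']
print(sorted(resolve_ds_ids(["!ds1", "!ds2"], datasets)))  # ['ds3', 'ds4']
```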
Multi-dataset Example
In this example, we will train on the MNIST dataset but evaluate on both the MNIST test set and the SVHN-Cropped test set. Here are the dataset-specific steps in this example:
- Resize images in the SVHN-Cropped dataset from 32x32 to 28x28 to match the MNIST data
- Convert the SVHN-Cropped dataset to gray-scale
- Measure dataset-specific Accuracy as well as combined Accuracy
- Customize an AUC metric that works on a per-dataset level
- Save the best model based on evaluation AUC of a specific dataset
Prepare Dataset
from fastestimator.dataset.data import mnist, svhn_cropped
train_mnist, eval_mnist = mnist.load_data()
_, eval_svhn = svhn_cropped.load_data()
print("mnist evaluation dataset summary:")
print(eval_mnist.summary())
print("svhn_cropped evaluation dataset summary:")
print(eval_svhn.summary())
mnist evaluation dataset summary:
{"num_instances": 10000, "keys": {"x": {"shape": [28, 28], "dtype": "uint8"}, "y": {"num_unique_values": 10, "shape": [], "dtype": "uint8"}}}
svhn_cropped evaluation dataset summary:
{"num_instances": 26032, "keys": {"x": {"shape": [32, 32, 3], "dtype": "uint8"}, "y": {"shape": [1], "dtype": "uint8"}}}
Preprocessing
import fastestimator as fe
from fastestimator.op.numpyop.univariate import ExpandDims, Minmax, ToGray
from fastestimator.op.numpyop.multivariate import Resize
from fastestimator.op.numpyop import LambdaOp
pipeline = fe.Pipeline(train_data={"mnist": train_mnist},
                       eval_data={"mnist": eval_mnist, "svhn": eval_svhn},
                       batch_size=32,
                       ops=[
                           Resize(image_in="x", image_out="x", height=28, width=28, ds_id="svhn"),  # 32x32 -> 28x28 for svhn
                           ToGray(inputs="x", outputs="x", ds_id="svhn"),  # after ToGray, the output still has 3 channels
                           LambdaOp(fn=lambda x: x[..., 0:1], inputs="x", outputs="x", ds_id="svhn"),  # keep only the first channel for svhn
                           ExpandDims(inputs="x", outputs="x", ds_id="mnist"),  # (28, 28) -> (28, 28, 1) for mnist
                           LambdaOp(fn=fe.backend.squeeze, inputs="y", outputs="y", ds_id="svhn"),  # (1,) -> () to match the mnist y shape
                           Minmax(inputs="x", outputs="x")  # applied to both datasets
                       ])
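To see why each dataset-specific op is needed, it helps to trace only the array shapes, starting from the dataset summaries printed earlier. The goal is for both datasets to end up with x of shape (28, 28, 1) and a scalar y so they can share one model. The helpers below are hypothetical shape-only stand-ins, not FastEstimator ops. Minmax is left out because it does not change the shape.

```python
# Hypothetical shape-only stand-ins for the ops above; they transform shape tuples, not images.
def resize(shape, height, width):
    # Resize changes the spatial dims and keeps any channel dim
    return (height, width) + tuple(shape[2:])


def to_gray(shape):
    # Per the pipeline comment, the gray-scale output still has 3 channels
    return tuple(shape)


def first_channel(shape):
    # Mirrors the x[..., 0:1] slice: keep one channel but keep the axis
    return tuple(shape[:-1]) + (1,)


def expand_dims(shape):
    # Adds a trailing channel axis
    return tuple(shape) + (1,)


def squeeze(shape):
    # Drops every axis of size 1
    return tuple(d for d in shape if d != 1)


# Input shapes taken from the dataset summaries printed earlier
svhn_x = first_channel(to_gray(resize((32, 32, 3), height=28, width=28)))
mnist_x = expand_dims((28, 28))
svhn_y = squeeze((1,))
mnist_y = ()

print(svhn_x, mnist_x)  # (28, 28, 1) (28, 28, 1)
print(svhn_y, mnist_y)  # () ()
```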
Visualize MNIST preprocessing results
from fastestimator.util import BatchDisplay, GridDisplay
mnist_eval_data = pipeline.get_results(mode="eval", ds_id="mnist")
fig = GridDisplay([BatchDisplay(image=mnist_eval_data["x"][:2], title="image"),
                   BatchDisplay(text=mnist_eval_data["y"][:2], title="label")
                   ])
fig.show()
Visualize SVHN-Cropped preprocessing results
svhn_eval_data = pipeline.get_results(mode="eval", ds_id="svhn")
fig = GridDisplay([BatchDisplay(image=svhn_eval_data["x"][:2], title="image"),
                   BatchDisplay(text=svhn_eval_data["y"][:2], title="label")
                   ])
fig.show()
Define Model and Network
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
model = fe.build(model_fn=LeNet, optimizer_fn="adam")
network = fe.Network(ops=[
    ModelOp(model=model, inputs="x", outputs="y_pred"),
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
    UpdateOp(model=model, loss_name="ce")
])
Start Training with Only the Accuracy Trace
from fastestimator.trace.io import BestModelSaver
from fastestimator.trace.metric import Accuracy
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=4,
                         traces=Accuracy(true_key="y", pred_key="y_pred"),
                         train_steps_per_epoch=200)
estimator.fit()
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 0;
FastEstimator-Train: step: 1; ce|mnist: 2.297002;
FastEstimator-Train: step: 100; ce|mnist: 0.62006235; steps/sec|mnist: 68.14;
FastEstimator-Train: step: 200; ce|mnist: 0.100964725; steps/sec|mnist: 72.17;
FastEstimator-Train: step: 200; epoch: 1; epoch_time: 4.28 sec;
Eval Progress: 1/312;
Eval Progress: 104/312; steps/sec: 195.59;
Eval Progress: 208/312; steps/sec: 194.24;
Eval Progress: 312/312; steps/sec: 206.18;
FastEstimator-Eval: step: 200; epoch: 1; accuracy: 0.38574045293072823; accuracy|mnist: 0.9399; accuracy|svhn: 0.1728641671788568; ce: 2.479175; ce|mnist: 0.20342466; ce|svhn: 3.3542488;
FastEstimator-Train: step: 300; ce|mnist: 0.14703166; steps/sec|mnist: 54.67;
FastEstimator-Train: step: 400; ce|mnist: 0.03327415; steps/sec|mnist: 71.45;
FastEstimator-Train: step: 400; epoch: 2; epoch_time: 3.24 sec;
Eval Progress: 1/312;
Eval Progress: 104/312; steps/sec: 208.84;
Eval Progress: 208/312; steps/sec: 208.1;
Eval Progress: 312/312; steps/sec: 211.65;
FastEstimator-Eval: step: 400; epoch: 2; accuracy: 0.4563721136767318; accuracy|mnist: 0.9722; accuracy|svhn: 0.25822065150583895; ce: 1.9564626; ce|mnist: 0.09087765; ce|svhn: 2.6738186;
FastEstimator-Train: step: 500; ce|mnist: 0.26206714; steps/sec|mnist: 52.52;
FastEstimator-Train: step: 600; ce|mnist: 0.20257875; steps/sec|mnist: 69.8;
FastEstimator-Train: step: 600; epoch: 3; epoch_time: 3.33 sec;
Eval Progress: 1/312;
Eval Progress: 104/312; steps/sec: 209.52;
Eval Progress: 208/312; steps/sec: 216.23;
Eval Progress: 312/312; steps/sec: 209.66;
FastEstimator-Eval: step: 600; epoch: 3; accuracy: 0.44879551509769094; accuracy|mnist: 0.9719; accuracy|svhn: 0.24784880147510757; ce: 1.7997541; ce|mnist: 0.09163932; ce|svhn: 2.4565597;
FastEstimator-Train: step: 700; ce|mnist: 0.025095956; steps/sec|mnist: 52.47;
FastEstimator-Train: step: 800; ce|mnist: 0.12707141; steps/sec|mnist: 70.46;
FastEstimator-Train: step: 800; epoch: 4; epoch_time: 3.32 sec;
Eval Progress: 1/312;
Eval Progress: 104/312; steps/sec: 207.01;
Eval Progress: 208/312; steps/sec: 207.07;
Eval Progress: 312/312; steps/sec: 209.36;
FastEstimator-Eval: step: 800; epoch: 4; accuracy: 0.4620892539964476; accuracy|mnist: 0.9741; accuracy|svhn: 0.2654041180086048; ce: 1.7328861; ce|mnist: 0.083745226; ce|svhn: 2.3670154;
FastEstimator-Finish: step: 800; model_lr: 0.001; total_time: 41.48 sec;
As you can see in the training log, the Accuracy Trace created three keys: accuracy|mnist, accuracy|svhn, and accuracy. The accuracy|mnist and accuracy|svhn keys are measured on the individual datasets, while accuracy is measured on the combined evaluation set.
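The combined accuracy is not the plain average of the two per-dataset values. The check below takes the evaluation set sizes from the dataset summaries printed earlier (10000 MNIST and 26032 SVHN-Cropped samples) and the epoch-4 values from the log. The logged accuracy is consistent with pooling, that is, total correct predictions divided by total samples across both datasets. Per-dataset correct counts are recovered by rounding accuracy times dataset size, which assumes every evaluation sample was counted.

```python
# Per-dataset accuracies (epoch 4 of the log) and eval set sizes (dataset summaries)
mnist_acc, mnist_n = 0.9741, 10000
svhn_acc, svhn_n = 0.2654041180086048, 26032

# Recover the number of correct predictions on each dataset
mnist_correct = round(mnist_acc * mnist_n)  # 9741
svhn_correct = round(svhn_acc * svhn_n)     # 6909

# Pooled accuracy: all correct predictions over all samples
pooled = (mnist_correct + svhn_correct) / (mnist_n + svhn_n)
# Unweighted mean of the two accuracies, for contrast
unweighted_mean = (mnist_acc + svhn_acc) / 2

print(pooled)           # about 0.4621, matching the logged accuracy
print(unweighted_mean)  # about 0.6198, far from the logged value
```

Because SVHN-Cropped has roughly 2.6 times as many evaluation samples as MNIST, it dominates the pooled value.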
Customize an AUC metric that works for every dataset, then save the model based on the AUC of a specific dataset
Since this is a 10-class classification task, we simplify the AUC calculation by treating every label < 5 as class 0 and every other label as class 1. Because the trace below initializes its state in on_epoch_begin and writes its output in on_epoch_end, the per_ds decorator is all that is needed to enable multi-dataset support.
from fastestimator.trace.meta import per_ds
from sklearn import metrics
import numpy as np

@per_ds  # Without this decorator the trace would only compute the aggregate metric
class AUC(fe.trace.Trace):
    def on_epoch_begin(self, data):
        # Reset the accumulated labels at the start of every epoch
        self.y_true = []
        self.y_pred = []

    def on_batch_end(self, data):
        # Convert class scores to predicted labels, then binarize: labels < 5 -> 0, otherwise 1
        y_pred, y_true = np.argmax(data["y_pred"].numpy(), axis=-1), data["y"].numpy()
        y_pred, y_true = np.where(y_pred < 5, 0, 1), np.where(y_true < 5, 0, 1)
        self.y_pred.extend(y_pred.ravel())
        self.y_true.extend(y_true.ravel())

    def on_epoch_end(self, data):
        # Compute the ROC curve over the whole epoch and log its area
        fpr, tpr, _ = metrics.roc_curve(self.y_true, self.y_pred)
        auc = metrics.auc(fpr, tpr)
        data.write_with_log("auc", auc)
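One property of this trace is worth noting. It passes hard 0/1 predictions to metrics.roc_curve, not scores, so the ROC curve has only one operating point (FPR, TPR) between (0, 0) and (1, 1). This assumes both classes appear in the labels and the predictions. The trapezoidal area is then FPR·TPR/2 + (1 - FPR)(TPR + 1)/2 = (TPR + 1 - FPR)/2, which is balanced accuracy. The pure-Python check below compares the two on a small made-up example. rates and trapezoid_auc are illustrative helpers, not scikit-learn functions.

```python
def rates(y_true, y_pred):
    """Return (TPR, FPR) for binary 0/1 labels and predictions."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    return tp / (tp + fn), fp / (fp + tn)


def trapezoid_auc(points):
    """Area under a piecewise-linear curve given as sorted (x, y) points."""
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


# Made-up binarized labels and predictions
y_true = [0, 0, 0, 1, 1, 1, 1, 0]
y_pred = [0, 1, 0, 1, 1, 0, 1, 0]

tpr, fpr = rates(y_true, y_pred)                      # 0.75, 0.25
auc = trapezoid_auc([(0.0, 0.0), (fpr, tpr), (1.0, 1.0)])
balanced_accuracy = (tpr + (1 - fpr)) / 2
print(auc, balanced_accuracy)                         # 0.75 0.75
```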
import tempfile
from fastestimator.trace.io import BestModelSaver
estimator = fe.Estimator(pipeline=pipeline,
                         network=network,
                         epochs=4,
                         traces=[Accuracy(true_key="y", pred_key="y_pred"),
                                 AUC(inputs=("y", "y_pred"), outputs="auc", mode="eval"),
                                 BestModelSaver(model=model, save_dir=tempfile.mkdtemp(), metric="auc|svhn", save_best_mode="max")],
                         train_steps_per_epoch=200)
estimator.fit()
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 0;
FastEstimator-Train: step: 1; ce|mnist: 0.062433735;
FastEstimator-Train: step: 100; ce|mnist: 0.2802775; steps/sec|mnist: 63.92;
FastEstimator-Train: step: 200; ce|mnist: 0.23590814; steps/sec|mnist: 64.12;
FastEstimator-Train: step: 200; epoch: 1; epoch_time: 3.7 sec;
Eval Progress: 1/312;
Eval Progress: 104/312; steps/sec: 180.97;
Eval Progress: 208/312; steps/sec: 193.39;
Eval Progress: 312/312; steps/sec: 179.7;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp95pnqff2/model_best_auc|svhn.h5
FastEstimator-Eval: step: 200; epoch: 1; accuracy: 0.4888154973357016; accuracy|mnist: 0.9837; accuracy|svhn: 0.29870928088506454; auc: 0.7571002319848853; auc|mnist: 0.9892388633631616; auc|svhn: 0.6624670574144985; ce: 1.6440561; ce|mnist: 0.055526573; ce|svhn: 2.254879; max_auc|svhn: 0.6624670574144985; since_best_auc|svhn: 0;
FastEstimator-Train: step: 300; ce|mnist: 0.037645765; steps/sec|mnist: 44.2;
FastEstimator-Train: step: 400; ce|mnist: 0.05922611; steps/sec|mnist: 63.98;
FastEstimator-Train: step: 400; epoch: 2; epoch_time: 3.79 sec;
Eval Progress: 1/312;
Eval Progress: 104/312; steps/sec: 150.98;
Eval Progress: 208/312; steps/sec: 166.17;
Eval Progress: 312/312; steps/sec: 187.0;
FastEstimator-Eval: step: 400; epoch: 2; accuracy: 0.5021369893428064; accuracy|mnist: 0.9801; accuracy|svhn: 0.31853103872157346; auc: 0.7618485829159409; auc|mnist: 0.9859111115434452; auc|svhn: 0.6552186362897082; ce: 1.6789757; ce|mnist: 0.061825372; ce|svhn: 2.3008037; max_auc|svhn: 0.6624670574144985; since_best_auc|svhn: 1;
FastEstimator-Train: step: 500; ce|mnist: 0.037580382; steps/sec|mnist: 45.06;
FastEstimator-Train: step: 600; ce|mnist: 0.042693608; steps/sec|mnist: 58.93;
FastEstimator-Train: step: 600; epoch: 3; epoch_time: 3.93 sec;
Eval Progress: 1/312;
Eval Progress: 104/312; steps/sec: 172.32;
Eval Progress: 208/312; steps/sec: 185.12;
Eval Progress: 312/312; steps/sec: 172.22;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp95pnqff2/model_best_auc|svhn.h5
FastEstimator-Eval: step: 600; epoch: 3; accuracy: 0.5132937388987566; accuracy|mnist: 0.98; accuracy|svhn: 0.3340119852489244; auc: 0.7671458888899824; auc|mnist: 0.9866311280009643; auc|svhn: 0.6683979301971492; ce: 1.530564; ce|mnist: 0.06489386; ce|svhn: 2.0941448; max_auc|svhn: 0.6683979301971492; since_best_auc|svhn: 0;
FastEstimator-Train: step: 700; ce|mnist: 0.03424043; steps/sec|mnist: 43.93;
FastEstimator-Train: step: 800; ce|mnist: 0.12813392; steps/sec|mnist: 63.89;
FastEstimator-Train: step: 800; epoch: 4; epoch_time: 3.81 sec;
Eval Progress: 1/312;
Eval Progress: 104/312; steps/sec: 168.38;
Eval Progress: 208/312; steps/sec: 172.82;
Eval Progress: 312/312; steps/sec: 151.1;
FastEstimator-Eval: step: 800; epoch: 4; accuracy: 0.48159968916518653; accuracy|mnist: 0.9787; accuracy|svhn: 0.2906422864167179; auc: 0.7624943008308487; auc|mnist: 0.9883770373095143; auc|svhn: 0.6621539867415974; ce: 1.7956299; ce|mnist: 0.06795667; ce|svhn: 2.4599562; max_auc|svhn: 0.6683979301971492; since_best_auc|svhn: 1;
FastEstimator-Finish: step: 800; model_lr: 0.001; total_time: 47.55 sec;
Now the evaluation log reports auc|svhn, auc|mnist, and auc. Moreover, model saving is based on the best evaluation AUC on the svhn dataset: BestModelSaver saved the model only after epochs 1 and 3, the epochs in which auc|svhn reached a new maximum.
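The max_auc|svhn and since_best_auc|svhn fields in the log can be reproduced by replaying the logged auc|svhn values through a simple "save on a new maximum" rule. The sketch below does that. track_best is a hypothetical helper, not BestModelSaver's implementation, and treating only strict improvements as new maxima is an assumption about tie handling.

```python
# auc|svhn values copied from the four evaluation epochs of the second training log
auc_svhn = [0.6624670574144985, 0.6552186362897082, 0.6683979301971492, 0.6621539867415974]


def track_best(values):
    """Hypothetical helper: replay a 'save on new maximum' rule over per-epoch metric values."""
    best = float("-inf")
    since_best = 0
    history = []
    for epoch, value in enumerate(values, start=1):
        if value > best:  # strict improvement; tie handling is an assumption
            best, since_best, saved = value, 0, True
        else:
            since_best += 1
            saved = False
        history.append((epoch, saved, best, since_best))
    return history


for epoch, saved, best, since_best in track_best(auc_svhn):
    print(f"epoch {epoch}: saved={saved}, max_auc|svhn={best}, since_best_auc|svhn={since_best}")
```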