How robust is a model at test time?
Robustness is a model's ability to make reliable predictions when its inputs are affected by varying conditions or when its assumptions are not fully satisfied. In this tutorial, we introduce model robustness and show how to use the FastEstimator Search API and Visualization API to check it.
Here we test how well a model copes with rotated inputs. First, let's write a generic get_estimator function that we will reuse in every experiment in this tutorial, making sure that each experiment uses the same training and test sets.
import fastestimator as fe
import os
import tempfile
from fastestimator.dataset.data import cifair10
from fastestimator.op.numpyop.univariate import ChannelTranspose, Normalize
from fastestimator.op.numpyop.multivariate import Rotate, Affine
from fastestimator.architecture.pytorch import LeNet
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import ModelSaver
from fastestimator.trace.metric import Accuracy
def get_estimator(
        save_dir,
        weight_path=None,
        model_name='robust',
        train_rotate=None,
        test_rotate=None,
        train_shear=None,
        epochs=24,
        visualize=False):
    # Split the evaluation data 50/50 with a fixed seed so that every experiment
    # is tested on exactly the same images.
    train_data, eval_data = cifair10.load_data()
    test_data = eval_data.split(0.5, seed=0)

    # Optional training-time augmentations.
    numpy_op = []
    if train_shear is not None:
        numpy_op.append(Affine(image_in="x", shear=train_shear, mode="train", border_handling='constant', fill_value=0))
    if train_rotate is not None:
        numpy_op.append(Affine(image_in="x", rotate=train_rotate, mode="train", border_handling='constant', fill_value=0))
    # Test-time rotation: identical lower and upper limits pin the angle to exactly test_rotate.
    # Only add the op when an angle is given, so no Rotate op is built with limit=[None, None].
    if test_rotate is not None:
        numpy_op.append(Rotate(image_in="x", limit=[test_rotate, test_rotate], mode="test"))

    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           test_data=test_data,
                           batch_size=32,
                           ops=numpy_op + [
                               Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
                               ChannelTranspose(inputs="x", outputs="x")
                           ])
    model = fe.build(model_fn=lambda: LeNet(input_shape=(3, 32, 32)),
                     optimizer_fn="adam",
                     weights_path=weight_path,
                     model_name=model_name)
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
        UpdateOp(model=model, loss_name="ce", mode="train")
    ])
    traces = [
        Accuracy(true_key="y", pred_key="y_pred"),
        # Save the weights once, at the end of the final epoch.
        ModelSaver(model=model, save_dir=save_dir, frequency=epochs)
    ]
    estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=epochs, traces=traces)
    if visualize:
        return estimator, pipeline
    return estimator
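The seeded split in get_estimator is what keeps the test set identical across experiments. As a rough illustration of why a fixed seed gives a repeatable partition, here is a plain-Python sketch. `seeded_split` is a hypothetical helper, FastEstimator's own `split` may choose items differently, and 10,000 is only an illustrative dataset size.

```python
import random


def seeded_split(num_items, fraction, seed):
    """Hypothetical helper: split indices 0..num_items-1 into (kept, split_off).

    Illustration only; FastEstimator's dataset split may select items differently.
    """
    rng = random.Random(seed)  # dedicated generator, so global random state is untouched
    indices = list(range(num_items))
    split_off = sorted(rng.sample(indices, round(fraction * num_items)))
    chosen = set(split_off)
    kept = [i for i in indices if i not in chosen]
    return kept, split_off


# Calling the split twice with the same seed yields the same partition.
first_eval, first_test = seeded_split(10_000, 0.5, seed=0)
second_eval, second_test = seeded_split(10_000, 0.5, seed=0)
print(first_test == second_test)         # True
print(len(first_eval), len(first_test))  # 5000 5000
```

Without a fixed seed, each call to get_estimator could pick a different test set, and accuracies from different experiments would no longer be directly comparable.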
Let's train a model without any augmentation.
save_dir = tempfile.mkdtemp()
est, no_aug_pipe = get_estimator(save_dir, weight_path=None, model_name='Without_augmentation', epochs=3, visualize=True)
est.fit()
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.297823;
FastEstimator-Train: step: 100; ce: 1.9229205; steps/sec: 20.72;
FastEstimator-Train: step: 200; ce: 1.8488135; steps/sec: 37.45;
FastEstimator-Train: step: 300; ce: 1.97854; steps/sec: 36.75;
FastEstimator-Train: step: 400; ce: 1.4034971; steps/sec: 41.87;
FastEstimator-Train: step: 500; ce: 1.634158; steps/sec: 36.72;
FastEstimator-Train: step: 600; ce: 1.5278842; steps/sec: 40.42;
FastEstimator-Train: step: 700; ce: 1.5716941; steps/sec: 20.76;
FastEstimator-Train: step: 800; ce: 1.4656043; steps/sec: 19.85;
FastEstimator-Train: step: 900; ce: 1.3644964; steps/sec: 22.87;
FastEstimator-Train: step: 1000; ce: 1.2243099; steps/sec: 27.8;
FastEstimator-Train: step: 1100; ce: 1.1308111; steps/sec: 30.43;
FastEstimator-Train: step: 1200; ce: 1.4400374; steps/sec: 30.82;
FastEstimator-Train: step: 1300; ce: 1.0056257; steps/sec: 27.23;
FastEstimator-Train: step: 1400; ce: 1.337088; steps/sec: 20.05;
FastEstimator-Train: step: 1500; ce: 1.532153; steps/sec: 10.55;
FastEstimator-Train: step: 1563; epoch: 1; epoch_time(sec): 62.79;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 20.32;
Eval Progress: 104/157; steps/sec: 112.04;
Eval Progress: 157/157; steps/sec: 105.34;
FastEstimator-Eval: step: 1563; epoch: 1; accuracy: 0.5198; ce: 1.3610618;
FastEstimator-Train: step: 1600; ce: 1.3166769; steps/sec: 30.71;
FastEstimator-Train: step: 1700; ce: 1.1115323; steps/sec: 36.85;
FastEstimator-Train: step: 1800; ce: 1.2530088; steps/sec: 23.32;
FastEstimator-Train: step: 1900; ce: 1.1652532; steps/sec: 34.8;
FastEstimator-Train: step: 2000; ce: 1.0679957; steps/sec: 36.68;
FastEstimator-Train: step: 2100; ce: 1.1941316; steps/sec: 24.23;
FastEstimator-Train: step: 2200; ce: 0.9643715; steps/sec: 34.52;
FastEstimator-Train: step: 2300; ce: 1.3712162; steps/sec: 36.0;
FastEstimator-Train: step: 2400; ce: 1.1851138; steps/sec: 26.4;
FastEstimator-Train: step: 2500; ce: 1.0683028; steps/sec: 53.08;
FastEstimator-Train: step: 2600; ce: 1.2203339; steps/sec: 26.74;
FastEstimator-Train: step: 2700; ce: 0.9800033; steps/sec: 22.31;
FastEstimator-Train: step: 2800; ce: 1.1035116; steps/sec: 29.24;
FastEstimator-Train: step: 2900; ce: 0.93366957; steps/sec: 20.45;
FastEstimator-Train: step: 3000; ce: 0.9092339; steps/sec: 33.47;
FastEstimator-Train: step: 3100; ce: 1.1804193; steps/sec: 8.85;
FastEstimator-Train: step: 3126; epoch: 2; epoch_time(sec): 63.83;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 24.04;
Eval Progress: 104/157; steps/sec: 16.5;
Eval Progress: 157/157; steps/sec: 17.76;
FastEstimator-Eval: step: 3126; epoch: 2; accuracy: 0.6108; ce: 1.1001097;
FastEstimator-Train: step: 3200; ce: 1.0673318; steps/sec: 14.38;
FastEstimator-Train: step: 3300; ce: 0.80209684; steps/sec: 36.37;
FastEstimator-Train: step: 3400; ce: 1.0805223; steps/sec: 22.28;
FastEstimator-Train: step: 3500; ce: 0.8957855; steps/sec: 22.97;
FastEstimator-Train: step: 3600; ce: 0.8971278; steps/sec: 61.19;
FastEstimator-Train: step: 3700; ce: 0.95779765; steps/sec: 27.5;
FastEstimator-Train: step: 3800; ce: 1.0616807; steps/sec: 36.44;
FastEstimator-Train: step: 3900; ce: 1.2883285; steps/sec: 31.54;
FastEstimator-Train: step: 4000; ce: 1.0859238; steps/sec: 17.53;
FastEstimator-Train: step: 4100; ce: 1.16119; steps/sec: 34.34;
FastEstimator-Train: step: 4200; ce: 0.7867568; steps/sec: 49.79;
FastEstimator-Train: step: 4300; ce: 0.8639324; steps/sec: 26.84;
FastEstimator-Train: step: 4400; ce: 1.1114055; steps/sec: 10.92;
FastEstimator-Train: step: 4500; ce: 0.58221227; steps/sec: 19.2;
FastEstimator-Train: step: 4600; ce: 0.65876365; steps/sec: 32.65;
FastEstimator-ModelSaver: Saved model to /tmp/tmp26og0a4e/Without_augmentation_epoch_3.pt
FastEstimator-Train: step: 4689; epoch: 3; epoch_time(sec): 60.23;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 115.84;
Eval Progress: 104/157; steps/sec: 23.7;
Eval Progress: 157/157; steps/sec: 38.77;
FastEstimator-Eval: step: 4689; epoch: 3; accuracy: 0.6584; ce: 0.9637741;
FastEstimator-Finish: step: 4689; total_time(sec): 204.87; Without_augmentation_lr: 0.001;
Now that we have trained a model without any augmentation, let's load its weights and test it on the test set while rotating the input images by various angles. We will use the FastEstimator Search API to do this.
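In get_estimator, the test-time Rotate op receives identical lower and upper limits (`limit=[test_rotate, test_rotate]`), so every test image should be rotated by exactly that angle. To make that concrete, here is a minimal NumPy sketch of a fixed-angle rotation of a single-channel image. `rotate_nearest` is a hypothetical helper, not FastEstimator's implementation: it uses nearest-neighbor sampling and fills uncovered corners with a constant, whereas the library's Rotate op may interpolate and handle borders differently.

```python
import numpy as np


def rotate_nearest(image, angle_deg, fill=0):
    """Hypothetical helper: rotate a 2D image counterclockwise by angle_deg about its center.

    Uses inverse mapping with nearest-neighbor sampling; output pixels whose source
    falls outside the input are set to `fill`.
    """
    h, w = image.shape
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    theta = np.deg2rad(angle_deg)
    cos_t, sin_t = np.cos(theta), np.sin(theta)

    rows, cols = np.indices((h, w))
    # Output pixel positions relative to the center, with y pointing up.
    x_out = cols - cx
    y_out = cy - rows
    # Undo the rotation to find where each output pixel comes from.
    x_src = x_out * cos_t + y_out * sin_t
    y_src = -x_out * sin_t + y_out * cos_t
    src_c = np.rint(x_src + cx).astype(int)
    src_r = np.rint(cy - y_src).astype(int)

    inside = (src_r >= 0) & (src_r < h) & (src_c >= 0) & (src_c < w)
    out = np.full_like(image, fill)
    out[inside] = image[src_r[inside], src_c[inside]]
    return out


img = np.arange(16).reshape(4, 4)
print(rotate_nearest(img, 90))  # same as np.rot90(img)
print(rotate_nearest(np.ones((8, 8)), 45))  # corners become 0
```

The 45-degree case shows why a rotated 32x32 image loses information in its corners, one reason accuracy can drop even for angles the model has "seen" in other orientations.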
First, let's define a generic evaluation function for the grid search:
from fastestimator.search.visualize import visualize_search
from fastestimator.search import GridSearch
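def score_fn(search_idx, rotate, weight_path, save_dir, field_name):
    # GridSearch supplies search_idx; it is not needed here.
    # Rebuild the estimator with the saved weights and a fixed test-time rotation, then test it.
    est = get_estimator(save_dir, weight_path=weight_path, test_rotate=rotate, epochs=3, visualize=False)
    hist = est.test(summary="myexp")
    acc = float(hist.history["test"]["accuracy"][0])
    return {field_name: acc}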
We sweep the rotation angle from 0 to 350 degrees in steps of 10 degrees and use GridSearch to pass each angle to score_fn, which forwards it to get_estimator as test_rotate.
weight_path = os.path.join(save_dir, 'Without_augmentation_epoch_3.pt')
rot = list(range(0, 360, 10))
no_aug_grid_search = GridSearch(eval_fn=lambda search_idx, rotate: score_fn(search_idx, rotate, weight_path, save_dir, field_name="Accuracy without Augmentation"), params={"rotate": rot})
no_aug_grid_search.fit()
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.6664; ce: 0.9561062;
FastEstimator-Search: Evaluated {'rotate': 0, 'search_idx': 1}, result: {'Accuracy without Augmentation': 0.6664}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.6422; ce: 1.028634;
FastEstimator-Search: Evaluated {'rotate': 10, 'search_idx': 2}, result: {'Accuracy without Augmentation': 0.6422}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5804; ce: 1.198358;
FastEstimator-Search: Evaluated {'rotate': 20, 'search_idx': 3}, result: {'Accuracy without Augmentation': 0.5804}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4862; ce: 1.5152924;
FastEstimator-Search: Evaluated {'rotate': 30, 'search_idx': 4}, result: {'Accuracy without Augmentation': 0.4862}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3996; ce: 1.8845491;
FastEstimator-Search: Evaluated {'rotate': 40, 'search_idx': 5}, result: {'Accuracy without Augmentation': 0.3996}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3414; ce: 2.1744254;
FastEstimator-Search: Evaluated {'rotate': 50, 'search_idx': 6}, result: {'Accuracy without Augmentation': 0.3414}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3072; ce: 2.351854;
FastEstimator-Search: Evaluated {'rotate': 60, 'search_idx': 7}, result: {'Accuracy without Augmentation': 0.3072}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2864; ce: 2.4515882;
FastEstimator-Search: Evaluated {'rotate': 70, 'search_idx': 8}, result: {'Accuracy without Augmentation': 0.2864}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2818; ce: 2.5172606;
FastEstimator-Search: Evaluated {'rotate': 80, 'search_idx': 9}, result: {'Accuracy without Augmentation': 0.2818}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2932; ce: 2.4945111;
FastEstimator-Search: Evaluated {'rotate': 90, 'search_idx': 10}, result: {'Accuracy without Augmentation': 0.2932}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.256; ce: 2.734745;
FastEstimator-Search: Evaluated {'rotate': 100, 'search_idx': 11}, result: {'Accuracy without Augmentation': 0.256}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2414; ce: 2.9066494;
FastEstimator-Search: Evaluated {'rotate': 110, 'search_idx': 12}, result: {'Accuracy without Augmentation': 0.2414}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2352; ce: 3.0595553;
FastEstimator-Search: Evaluated {'rotate': 120, 'search_idx': 13}, result: {'Accuracy without Augmentation': 0.2352}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2346; ce: 3.0900187;
FastEstimator-Search: Evaluated {'rotate': 130, 'search_idx': 14}, result: {'Accuracy without Augmentation': 0.2346}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2496; ce: 2.9533136;
FastEstimator-Search: Evaluated {'rotate': 140, 'search_idx': 15}, result: {'Accuracy without Augmentation': 0.2496}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2694; ce: 2.6867018;
FastEstimator-Search: Evaluated {'rotate': 150, 'search_idx': 16}, result: {'Accuracy without Augmentation': 0.2694}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.301; ce: 2.4071147;
FastEstimator-Search: Evaluated {'rotate': 160, 'search_idx': 17}, result: {'Accuracy without Augmentation': 0.301}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.316; ce: 2.240197;
FastEstimator-Search: Evaluated {'rotate': 170, 'search_idx': 18}, result: {'Accuracy without Augmentation': 0.316}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3428; ce: 2.1289763;
FastEstimator-Search: Evaluated {'rotate': 180, 'search_idx': 19}, result: {'Accuracy without Augmentation': 0.3428}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3158; ce: 2.2378516;
FastEstimator-Search: Evaluated {'rotate': 190, 'search_idx': 20}, result: {'Accuracy without Augmentation': 0.3158}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3002; ce: 2.378881;
FastEstimator-Search: Evaluated {'rotate': 200, 'search_idx': 21}, result: {'Accuracy without Augmentation': 0.3002}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2768; ce: 2.6192675;
FastEstimator-Search: Evaluated {'rotate': 210, 'search_idx': 22}, result: {'Accuracy without Augmentation': 0.2768}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2486; ce: 2.8439047;
FastEstimator-Search: Evaluated {'rotate': 220, 'search_idx': 23}, result: {'Accuracy without Augmentation': 0.2486}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2346; ce: 2.9615552;
FastEstimator-Search: Evaluated {'rotate': 230, 'search_idx': 24}, result: {'Accuracy without Augmentation': 0.2346}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2364; ce: 2.9517527;
FastEstimator-Search: Evaluated {'rotate': 240, 'search_idx': 25}, result: {'Accuracy without Augmentation': 0.2364}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2488; ce: 2.8593893;
FastEstimator-Search: Evaluated {'rotate': 250, 'search_idx': 26}, result: {'Accuracy without Augmentation': 0.2488}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2634; ce: 2.761762;
FastEstimator-Search: Evaluated {'rotate': 260, 'search_idx': 27}, result: {'Accuracy without Augmentation': 0.2634}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2904; ce: 2.538681;
FastEstimator-Search: Evaluated {'rotate': 270, 'search_idx': 28}, result: {'Accuracy without Augmentation': 0.2904}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2768; ce: 2.636479;
FastEstimator-Search: Evaluated {'rotate': 280, 'search_idx': 29}, result: {'Accuracy without Augmentation': 0.2768}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.277; ce: 2.5930767;
FastEstimator-Search: Evaluated {'rotate': 290, 'search_idx': 30}, result: {'Accuracy without Augmentation': 0.277}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2798; ce: 2.5175831;
FastEstimator-Search: Evaluated {'rotate': 300, 'search_idx': 31}, result: {'Accuracy without Augmentation': 0.2798}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3082; ce: 2.318865;
FastEstimator-Search: Evaluated {'rotate': 310, 'search_idx': 32}, result: {'Accuracy without Augmentation': 0.3082}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3678; ce: 1.9600892;
FastEstimator-Search: Evaluated {'rotate': 320, 'search_idx': 33}, result: {'Accuracy without Augmentation': 0.3678}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.478; ce: 1.527449;
FastEstimator-Search: Evaluated {'rotate': 330, 'search_idx': 34}, result: {'Accuracy without Augmentation': 0.478}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5888; ce: 1.1805097;
FastEstimator-Search: Evaluated {'rotate': 340, 'search_idx': 35}, result: {'Accuracy without Augmentation': 0.5888}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.646; ce: 1.0058187;
FastEstimator-Search: Evaluated {'rotate': 350, 'search_idx': 36}, result: {'Accuracy without Augmentation': 0.646}
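GridSearch evaluates every combination of the values in `params`; with a single parameter, that is simply one evaluation per angle in `rot`. The plain-Python sketch below mimics that behavior to make the log easier to read. `grid_search` is a hypothetical stand-in, not FastEstimator's implementation: its numbering of `search_idx` from 1 mirrors the values printed in the log, and its result format is our own simplification.

```python
from itertools import product


def grid_search(eval_fn, params):
    """Hypothetical stand-in for GridSearch: call eval_fn once per parameter combination."""
    names = list(params)
    results = []
    # Cartesian product of all parameter value lists; one entry per combination.
    for search_idx, values in enumerate(product(*(params[n] for n in names)), start=1):
        kwargs = dict(zip(names, values))
        score = eval_fn(search_idx=search_idx, **kwargs)
        results.append({"param": {**kwargs, "search_idx": search_idx}, "result": score})
    return results


rot = list(range(0, 360, 10))
# A toy evaluation function standing in for score_fn (no model involved).
toy = grid_search(lambda search_idx, rotate: {"toy_score": abs(180 - rotate)}, {"rotate": rot})
print(len(toy))          # 36
print(toy[0]["param"])   # {'rotate': 0, 'search_idx': 1}
```

Because every angle is evaluated independently, the cost of the sweep grows linearly with the number of angles; a finer grid gives a smoother robustness curve at proportionally higher test time.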
Let's visualize how the model performs on rotated test images:
visualize_search(search=no_aug_grid_search, title="Model Robustness Without Augmentation")
How can data augmentation improve rotation robustness?
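The last experiment shows that accuracy drops sharply as the test images are rotated away from upright: it falls from 0.6664 at 0 degrees to 0.2346 at 130 and 230 degrees, and recovers only as the angle approaches 360 degrees (0.646 at 350 degrees). In other words, the trained model is not robust to rotation at test time. Can we improve rotation robustness by augmenting the training images?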
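One simple way to condense a robustness curve into a few numbers is to report the upright accuracy, the mean accuracy over all tested angles, and the worst-case angle. The sketch below does this for the accuracies printed in the grid-search log of the model trained without augmentation. The choice of summary statistics is ours, not something FastEstimator computes, and `robustness_summary` is a hypothetical helper.

```python
# Test accuracies copied from the grid-search log (model trained without augmentation).
no_aug = {
    0: 0.6664, 10: 0.6422, 20: 0.5804, 30: 0.4862, 40: 0.3996, 50: 0.3414,
    60: 0.3072, 70: 0.2864, 80: 0.2818, 90: 0.2932, 100: 0.256, 110: 0.2414,
    120: 0.2352, 130: 0.2346, 140: 0.2496, 150: 0.2694, 160: 0.301, 170: 0.316,
    180: 0.3428, 190: 0.3158, 200: 0.3002, 210: 0.2768, 220: 0.2486, 230: 0.2346,
    240: 0.2364, 250: 0.2488, 260: 0.2634, 270: 0.2904, 280: 0.2768, 290: 0.277,
    300: 0.2798, 310: 0.3082, 320: 0.3678, 330: 0.478, 340: 0.5888, 350: 0.646,
}


def robustness_summary(acc_by_angle):
    """Hypothetical helper: upright, mean, and worst-case accuracy over a rotation sweep."""
    upright = acc_by_angle[0]
    worst = min(acc_by_angle.values())
    worst_angles = sorted(a for a, v in acc_by_angle.items() if v == worst)
    mean_acc = sum(acc_by_angle.values()) / len(acc_by_angle)
    return {"upright": upright, "mean": mean_acc, "worst": worst, "worst_angles": worst_angles}


summary = robustness_summary(no_aug)
print(summary["worst"], summary["worst_angles"])  # 0.2346 [130, 230]
print(round(summary["mean"], 4))
```

The same helper can be applied to the results of the augmented models below, which makes it easier to compare them than reading the plots alone.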
The effect of applying shear augmentation
Shear is an augmentation that slides one side of the image relative to the opposite side, turning a square or rectangular image into a parallelogram. Would training with shear augmentation make the model robust to rotation? Let's visualize the shear operation first.
est, shear_pipe = get_estimator(save_dir, weight_path=None, model_name='Shear', train_shear=45, epochs=3, visualize=True)
from fastestimator.util import GridDisplay, BatchDisplay
shear_results = shear_pipe.get_results()
no_aug_results = no_aug_pipe.get_results()
sample_num = 3
fig = GridDisplay([
BatchDisplay(image=no_aug_results['x'][0:sample_num], title="Pipeline Input"),
BatchDisplay(image=shear_results['x'][0:sample_num], title="Pipeline Output")
])
fig.show()
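The geometric difference between shear and rotation becomes clear when both are applied as 2x2 matrices to the corners of a unit square. This NumPy sketch uses the common convention x' = x + tan(angle) * y for a horizontal shear; how FastEstimator's `Affine` op parameterizes its shear angle internally is not shown here, and `shear_matrix`, `rotation_matrix`, and `corner_angle_deg` are hypothetical helpers.

```python
import numpy as np


def shear_matrix(angle_deg):
    """Horizontal shear: x' = x + tan(angle) * y, y' = y."""
    return np.array([[1.0, np.tan(np.deg2rad(angle_deg))],
                     [0.0, 1.0]])


def rotation_matrix(angle_deg):
    """Counterclockwise rotation about the origin."""
    t = np.deg2rad(angle_deg)
    return np.array([[np.cos(t), -np.sin(t)],
                     [np.sin(t), np.cos(t)]])


def corner_angle_deg(points):
    """Interior angle (degrees) at corner 0, between the edges to corners 1 and 3."""
    a = points[:, 1] - points[:, 0]
    b = points[:, 3] - points[:, 0]
    cos_angle = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))
    return np.degrees(np.arccos(cos_angle))


# Corners of a unit square, one (x, y) point per column.
square = np.array([[0.0, 1.0, 1.0, 0.0],
                   [0.0, 0.0, 1.0, 1.0]])

sheared = shear_matrix(45) @ square
rotated = rotation_matrix(45) @ square
print(np.round(sheared, 3))  # parallelogram with corners (0,0), (1,0), (2,1), (1,1)
print(round(corner_angle_deg(sheared), 1), round(corner_angle_deg(rotated), 1))  # 45.0 90.0
```

Both matrices preserve area (determinant 1), but only rotation preserves angles and edge lengths. A shear-augmented model therefore sees distorted objects rather than rotated ones, which is why it is an open question whether shear helps with rotation.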
Now that we have visualized the shear operation, let's train a model with random shear angles in the range [-45, 45] degrees.
est.fit()
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.2861838;
FastEstimator-Train: step: 100; ce: 2.005511; steps/sec: 9.38;
FastEstimator-Train: step: 200; ce: 1.6937112; steps/sec: 6.46;
FastEstimator-Train: step: 300; ce: 1.8294314; steps/sec: 4.56;
FastEstimator-Train: step: 400; ce: 1.4309876; steps/sec: 6.09;
FastEstimator-Train: step: 500; ce: 1.5529174; steps/sec: 8.53;
FastEstimator-Train: step: 600; ce: 1.8520718; steps/sec: 12.41;
FastEstimator-Train: step: 700; ce: 1.3964385; steps/sec: 5.37;
FastEstimator-Train: step: 800; ce: 1.5918491; steps/sec: 8.14;
FastEstimator-Train: step: 900; ce: 1.7113352; steps/sec: 12.04;
FastEstimator-Train: step: 1000; ce: 1.4144588; steps/sec: 3.41;
FastEstimator-Train: step: 1100; ce: 1.6062101; steps/sec: 7.75;
FastEstimator-Train: step: 1200; ce: 1.4574857; steps/sec: 10.48;
FastEstimator-Train: step: 1300; ce: 1.7639947; steps/sec: 13.06;
FastEstimator-Train: step: 1400; ce: 1.5409832; steps/sec: 11.45;
FastEstimator-Train: step: 1500; ce: 1.4090402; steps/sec: 10.2;
FastEstimator-Train: step: 1563; epoch: 1; epoch_time(sec): 208.05;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 14.17;
Eval Progress: 104/157; steps/sec: 11.21;
Eval Progress: 157/157; steps/sec: 27.28;
FastEstimator-Eval: step: 1563; epoch: 1; accuracy: 0.5006; ce: 1.3493171;
FastEstimator-Train: step: 1600; ce: 1.5268304; steps/sec: 8.59;
FastEstimator-Train: step: 1700; ce: 1.5136603; steps/sec: 8.34;
FastEstimator-Train: step: 1800; ce: 1.2788031; steps/sec: 6.0;
FastEstimator-Train: step: 1900; ce: 1.2211283; steps/sec: 12.23;
FastEstimator-Train: step: 2000; ce: 1.2741537; steps/sec: 11.08;
FastEstimator-Train: step: 2100; ce: 1.327008; steps/sec: 6.86;
FastEstimator-Train: step: 2200; ce: 1.7329004; steps/sec: 8.57;
FastEstimator-Train: step: 2300; ce: 1.5304513; steps/sec: 11.48;
FastEstimator-Train: step: 2400; ce: 1.1817951; steps/sec: 13.21;
FastEstimator-Train: step: 2500; ce: 1.4848573; steps/sec: 13.29;
FastEstimator-Train: step: 2600; ce: 1.4725429; steps/sec: 14.05;
FastEstimator-Train: step: 2700; ce: 1.3283855; steps/sec: 8.0;
FastEstimator-Train: step: 2800; ce: 1.3344322; steps/sec: 3.3;
FastEstimator-Train: step: 2900; ce: 1.3632934; steps/sec: 3.59;
FastEstimator-Train: step: 3000; ce: 1.4991808; steps/sec: 13.95;
FastEstimator-Train: step: 3100; ce: 1.3375795; steps/sec: 9.72;
FastEstimator-Train: step: 3126; epoch: 2; epoch_time(sec): 199.82;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 44.0;
Eval Progress: 104/157; steps/sec: 77.09;
Eval Progress: 157/157; steps/sec: 53.15;
FastEstimator-Eval: step: 3126; epoch: 2; accuracy: 0.5242; ce: 1.3092657;
FastEstimator-Train: step: 3200; ce: 1.2650018; steps/sec: 10.98;
FastEstimator-Train: step: 3300; ce: 1.169062; steps/sec: 11.88;
FastEstimator-Train: step: 3400; ce: 1.6375405; steps/sec: 13.0;
FastEstimator-Train: step: 3500; ce: 1.2077605; steps/sec: 11.34;
FastEstimator-Train: step: 3600; ce: 1.4420668; steps/sec: 6.71;
FastEstimator-Train: step: 3700; ce: 1.0475363; steps/sec: 11.78;
FastEstimator-Train: step: 3800; ce: 1.25505; steps/sec: 12.77;
FastEstimator-Train: step: 3900; ce: 1.6666105; steps/sec: 14.69;
FastEstimator-Train: step: 4000; ce: 1.1326865; steps/sec: 15.31;
FastEstimator-Train: step: 4100; ce: 1.667225; steps/sec: 17.02;
FastEstimator-Train: step: 4200; ce: 0.94082505; steps/sec: 14.59;
FastEstimator-Train: step: 4300; ce: 1.6120632; steps/sec: 12.3;
FastEstimator-Train: step: 4400; ce: 1.4536572; steps/sec: 6.06;
FastEstimator-Train: step: 4500; ce: 1.0896178; steps/sec: 13.95;
FastEstimator-Train: step: 4600; ce: 1.3797795; steps/sec: 15.07;
FastEstimator-ModelSaver: Saved model to /tmp/tmp26og0a4e/Shear_epoch_3.pt
FastEstimator-Train: step: 4689; epoch: 3; epoch_time(sec): 133.64;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 103.06;
Eval Progress: 104/157; steps/sec: 38.39;
Eval Progress: 157/157; steps/sec: 108.79;
FastEstimator-Eval: step: 4689; epoch: 3; accuracy: 0.5906; ce: 1.1530296;
FastEstimator-Finish: step: 4689; Shear_lr: 0.001; total_time(sec): 559.21;
Now let's use GridSearch again to test this new model on input images rotated from 0 to 350 degrees in steps of 10 degrees.
weight_path = os.path.join(save_dir, 'Shear_epoch_3.pt')
rot = list(range(0, 360, 10))
shear_grid_search = GridSearch(eval_fn=lambda search_idx, rotate: score_fn(search_idx, rotate, weight_path, save_dir, field_name="Shear Accuracy"), params={"rotate": rot})
shear_grid_search.fit()
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5988; ce: 1.1352092;
FastEstimator-Search: Evaluated {'rotate': 0, 'search_idx': 1}, result: {'Shear Accuracy': 0.5988}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5908; ce: 1.1508583;
FastEstimator-Search: Evaluated {'rotate': 10, 'search_idx': 2}, result: {'Shear Accuracy': 0.5908}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5784; ce: 1.2101475;
FastEstimator-Search: Evaluated {'rotate': 20, 'search_idx': 3}, result: {'Shear Accuracy': 0.5784}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5358; ce: 1.3203115;
FastEstimator-Search: Evaluated {'rotate': 30, 'search_idx': 4}, result: {'Shear Accuracy': 0.5358}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.485; ce: 1.4910775;
FastEstimator-Search: Evaluated {'rotate': 40, 'search_idx': 5}, result: {'Shear Accuracy': 0.485}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.427; ce: 1.7015682;
FastEstimator-Search: Evaluated {'rotate': 50, 'search_idx': 6}, result: {'Shear Accuracy': 0.427}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3684; ce: 1.9068954;
FastEstimator-Search: Evaluated {'rotate': 60, 'search_idx': 7}, result: {'Shear Accuracy': 0.3684}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3346; ce: 2.0805237;
FastEstimator-Search: Evaluated {'rotate': 70, 'search_idx': 8}, result: {'Shear Accuracy': 0.3346}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2974; ce: 2.262713;
FastEstimator-Search: Evaluated {'rotate': 80, 'search_idx': 9}, result: {'Shear Accuracy': 0.2974}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2952; ce: 2.2669344;
FastEstimator-Search: Evaluated {'rotate': 90, 'search_idx': 10}, result: {'Shear Accuracy': 0.2952}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2602; ce: 2.533539;
FastEstimator-Search: Evaluated {'rotate': 100, 'search_idx': 11}, result: {'Shear Accuracy': 0.2602}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2584; ce: 2.58759;
FastEstimator-Search: Evaluated {'rotate': 110, 'search_idx': 12}, result: {'Shear Accuracy': 0.2584}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2612; ce: 2.600429;
FastEstimator-Search: Evaluated {'rotate': 120, 'search_idx': 13}, result: {'Shear Accuracy': 0.2612}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.27; ce: 2.5737288;
FastEstimator-Search: Evaluated {'rotate': 130, 'search_idx': 14}, result: {'Shear Accuracy': 0.27}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2832; ce: 2.4878793;
FastEstimator-Search: Evaluated {'rotate': 140, 'search_idx': 15}, result: {'Shear Accuracy': 0.2832}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2948; ce: 2.3791826;
FastEstimator-Search: Evaluated {'rotate': 150, 'search_idx': 16}, result: {'Shear Accuracy': 0.2948}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3062; ce: 2.2879615;
FastEstimator-Search: Evaluated {'rotate': 160, 'search_idx': 17}, result: {'Shear Accuracy': 0.3062}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3092; ce: 2.2508817;
FastEstimator-Search: Evaluated {'rotate': 170, 'search_idx': 18}, result: {'Shear Accuracy': 0.3092}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3184; ce: 2.1892116;
FastEstimator-Search: Evaluated {'rotate': 180, 'search_idx': 19}, result: {'Shear Accuracy': 0.3184}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3072; ce: 2.2258189;
FastEstimator-Search: Evaluated {'rotate': 190, 'search_idx': 20}, result: {'Shear Accuracy': 0.3072}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3; ce: 2.2513013;
FastEstimator-Search: Evaluated {'rotate': 200, 'search_idx': 21}, result: {'Shear Accuracy': 0.3}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2894; ce: 2.3335693;
FastEstimator-Search: Evaluated {'rotate': 210, 'search_idx': 22}, result: {'Shear Accuracy': 0.2894}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2692; ce: 2.463834;
FastEstimator-Search: Evaluated {'rotate': 220, 'search_idx': 23}, result: {'Shear Accuracy': 0.2692}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2524; ce: 2.5842586;
FastEstimator-Search: Evaluated {'rotate': 230, 'search_idx': 24}, result: {'Shear Accuracy': 0.2524}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2506; ce: 2.62961;
FastEstimator-Search: Evaluated {'rotate': 240, 'search_idx': 25}, result: {'Shear Accuracy': 0.2506}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2468; ce: 2.6211019;
FastEstimator-Search: Evaluated {'rotate': 250, 'search_idx': 26}, result: {'Shear Accuracy': 0.2468}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2494; ce: 2.5996192;
FastEstimator-Search: Evaluated {'rotate': 260, 'search_idx': 27}, result: {'Shear Accuracy': 0.2494}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2814; ce: 2.3607845;
FastEstimator-Search: Evaluated {'rotate': 270, 'search_idx': 28}, result: {'Shear Accuracy': 0.2814}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.2798; ce: 2.3904302;
FastEstimator-Search: Evaluated {'rotate': 280, 'search_idx': 29}, result: {'Shear Accuracy': 0.2798}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3108; ce: 2.2020493;
FastEstimator-Search: Evaluated {'rotate': 290, 'search_idx': 30}, result: {'Shear Accuracy': 0.3108}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3652; ce: 1.9952031;
FastEstimator-Search: Evaluated {'rotate': 300, 'search_idx': 31}, result: {'Shear Accuracy': 0.3652}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4182; ce: 1.7621305;
FastEstimator-Search: Evaluated {'rotate': 310, 'search_idx': 32}, result: {'Shear Accuracy': 0.4182}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4808; ce: 1.5279862;
FastEstimator-Search: Evaluated {'rotate': 320, 'search_idx': 33}, result: {'Shear Accuracy': 0.4808}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.538; ce: 1.3273461;
FastEstimator-Search: Evaluated {'rotate': 330, 'search_idx': 34}, result: {'Shear Accuracy': 0.538}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5756; ce: 1.2092421;
FastEstimator-Search: Evaluated {'rotate': 340, 'search_idx': 35}, result: {'Shear Accuracy': 0.5756}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5936; ce: 1.1571609;
FastEstimator-Search: Evaluated {'rotate': 350, 'search_idx': 36}, result: {'Shear Accuracy': 0.5936}
visualize_search(search=shear_grid_search, title="Model Robustness With Shear")
The effect of applying rotation augmentation
Next, let's train the model while randomly rotating the training images in the range [-90, 90] degrees. First, let's visualize what the rotation operation looks like.
est, rotate_pipe = get_estimator(save_dir, weight_path=None, model_name='Rotation', train_rotate=90, epochs=3, visualize=True)
rotate_results = rotate_pipe.get_results()
no_aug_results = no_aug_pipe.get_results()
sample_num = 3
fig = GridDisplay([
BatchDisplay(image=no_aug_results['x'][0:sample_num], title="Pipeline Input"),
BatchDisplay(image=rotate_results['x'][0:sample_num], title="Pipeline Output")
])
fig.show()
est.fit()
FastEstimator-Start: step: 1; logging_interval: 100; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.299035;
FastEstimator-Train: step: 100; ce: 2.075551; steps/sec: 6.59;
FastEstimator-Train: step: 200; ce: 1.96548; steps/sec: 14.31;
FastEstimator-Train: step: 300; ce: 1.6383593; steps/sec: 7.94;
FastEstimator-Train: step: 400; ce: 1.8449334; steps/sec: 8.19;
FastEstimator-Train: step: 500; ce: 1.7785522; steps/sec: 11.14;
FastEstimator-Train: step: 600; ce: 1.8474475; steps/sec: 10.75;
FastEstimator-Train: step: 700; ce: 1.697898; steps/sec: 8.71;
FastEstimator-Train: step: 800; ce: 1.5075471; steps/sec: 4.63;
FastEstimator-Train: step: 900; ce: 1.6469618; steps/sec: 6.64;
FastEstimator-Train: step: 1000; ce: 1.5503968; steps/sec: 9.44;
FastEstimator-Train: step: 1100; ce: 1.711024; steps/sec: 9.04;
FastEstimator-Train: step: 1200; ce: 1.5141683; steps/sec: 12.01;
FastEstimator-Train: step: 1300; ce: 1.781623; steps/sec: 13.39;
FastEstimator-Train: step: 1400; ce: 1.5333574; steps/sec: 10.26;
FastEstimator-Train: step: 1500; ce: 1.853748; steps/sec: 11.98;
FastEstimator-Train: step: 1563; epoch: 1; epoch_time(sec): 179.68;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 31.86;
Eval Progress: 104/157; steps/sec: 61.86;
Eval Progress: 157/157; steps/sec: 70.95;
FastEstimator-Eval: step: 1563; epoch: 1; accuracy: 0.4368; ce: 1.5425291;
FastEstimator-Train: step: 1600; ce: 1.4006021; steps/sec: 7.67;
FastEstimator-Train: step: 1700; ce: 1.3395236; steps/sec: 10.88;
FastEstimator-Train: step: 1800; ce: 1.498162; steps/sec: 15.62;
FastEstimator-Train: step: 1900; ce: 1.5839293; steps/sec: 10.67;
FastEstimator-Train: step: 2000; ce: 1.149652; steps/sec: 7.71;
FastEstimator-Train: step: 2100; ce: 1.7589302; steps/sec: 10.33;
FastEstimator-Train: step: 2200; ce: 1.4050286; steps/sec: 7.69;
FastEstimator-Train: step: 2300; ce: 1.7004766; steps/sec: 13.79;
FastEstimator-Train: step: 2400; ce: 1.3629683; steps/sec: 13.57;
FastEstimator-Train: step: 2500; ce: 1.5042443; steps/sec: 7.83;
FastEstimator-Train: step: 2600; ce: 1.4782196; steps/sec: 10.63;
FastEstimator-Train: step: 2700; ce: 1.9164765; steps/sec: 9.35;
FastEstimator-Train: step: 2800; ce: 1.4845995; steps/sec: 8.12;
FastEstimator-Train: step: 2900; ce: 1.1254141; steps/sec: 7.15;
FastEstimator-Train: step: 3000; ce: 1.6342312; steps/sec: 8.51;
FastEstimator-Train: step: 3100; ce: 1.0823914; steps/sec: 10.25;
FastEstimator-Train: step: 3126; epoch: 2; epoch_time(sec): 161.58;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 30.03;
Eval Progress: 104/157; steps/sec: 46.97;
Eval Progress: 157/157; steps/sec: 35.09;
FastEstimator-Eval: step: 3126; epoch: 2; accuracy: 0.493; ce: 1.4054103;
FastEstimator-Train: step: 3200; ce: 1.474694; steps/sec: 8.11;
FastEstimator-Train: step: 3300; ce: 1.22069; steps/sec: 3.53;
FastEstimator-Train: step: 3400; ce: 1.3306724; steps/sec: 3.82;
FastEstimator-Train: step: 3500; ce: 1.6650531; steps/sec: 13.52;
FastEstimator-Train: step: 3600; ce: 1.6610831; steps/sec: 13.37;
FastEstimator-Train: step: 3700; ce: 1.4613279; steps/sec: 8.74;
FastEstimator-Train: step: 3800; ce: 1.1969187; steps/sec: 12.35;
FastEstimator-Train: step: 3900; ce: 1.2569786; steps/sec: 7.64;
FastEstimator-Train: step: 4000; ce: 1.191232; steps/sec: 3.82;
FastEstimator-Train: step: 4100; ce: 1.3513925; steps/sec: 10.37;
FastEstimator-Train: step: 4200; ce: 1.3323293; steps/sec: 10.39;
FastEstimator-Train: step: 4300; ce: 1.3886821; steps/sec: 13.58;
FastEstimator-Train: step: 4400; ce: 0.8490051; steps/sec: 8.1;
FastEstimator-Train: step: 4500; ce: 1.0546207; steps/sec: 13.12;
FastEstimator-Train: step: 4600; ce: 1.5466611; steps/sec: 13.66;
FastEstimator-ModelSaver: Saved model to /tmp/tmp26og0a4e/Rotation_epoch_3.pt
FastEstimator-Train: step: 4689; epoch: 3; epoch_time(sec): 199.32;
Eval Progress: 1/157;
Eval Progress: 52/157; steps/sec: 33.79;
Eval Progress: 104/157; steps/sec: 95.27;
Eval Progress: 157/157; steps/sec: 89.83;
FastEstimator-Eval: step: 4689; epoch: 3; accuracy: 0.5168; ce: 1.3441926;
FastEstimator-Finish: step: 4689; Rotation_lr: 0.001; total_time(sec): 553.19;
Let's use Grid Search again to test the model's performance on input images rotated from 0 to 350 degrees in steps of 10 degrees.
# Weights saved after the 3rd (final) epoch of rotation-augmented training
weight_path = os.path.join(save_dir, 'Rotation_epoch_3.pt')
# Test-time rotation angles: 0, 10, ..., 350 degrees
rot = list(range(0, 360, 10))
rotation_grid_search = GridSearch(
    eval_fn=lambda search_idx, rotate: score_fn(search_idx, rotate, weight_path, save_dir, field_name="Rotation Accuracy"),
    params={"rotate": rot})
rotation_grid_search.fit()
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.526; ce: 1.3442501;
FastEstimator-Search: Evaluated {'rotate': 0, 'search_idx': 1}, result: {'Rotation Accuracy': 0.526}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.529; ce: 1.3476864;
FastEstimator-Search: Evaluated {'rotate': 10, 'search_idx': 2}, result: {'Rotation Accuracy': 0.529}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5274; ce: 1.3577137;
FastEstimator-Search: Evaluated {'rotate': 20, 'search_idx': 3}, result: {'Rotation Accuracy': 0.5274}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5238; ce: 1.3682574;
FastEstimator-Search: Evaluated {'rotate': 30, 'search_idx': 4}, result: {'Rotation Accuracy': 0.5238}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5208; ce: 1.3765999;
FastEstimator-Search: Evaluated {'rotate': 40, 'search_idx': 5}, result: {'Rotation Accuracy': 0.5208}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.52; ce: 1.3815999;
FastEstimator-Search: Evaluated {'rotate': 50, 'search_idx': 6}, result: {'Rotation Accuracy': 0.52}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5024; ce: 1.3979217;
FastEstimator-Search: Evaluated {'rotate': 60, 'search_idx': 7}, result: {'Rotation Accuracy': 0.5024}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4924; ce: 1.4351144;
FastEstimator-Search: Evaluated {'rotate': 70, 'search_idx': 8}, result: {'Rotation Accuracy': 0.4924}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4728; ce: 1.4845449;
FastEstimator-Search: Evaluated {'rotate': 80, 'search_idx': 9}, result: {'Rotation Accuracy': 0.4728}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4506; ce: 1.5334638;
FastEstimator-Search: Evaluated {'rotate': 90, 'search_idx': 10}, result: {'Rotation Accuracy': 0.4506}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.413; ce: 1.6307365;
FastEstimator-Search: Evaluated {'rotate': 100, 'search_idx': 11}, result: {'Rotation Accuracy': 0.413}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.381; ce: 1.7120235;
FastEstimator-Search: Evaluated {'rotate': 110, 'search_idx': 12}, result: {'Rotation Accuracy': 0.381}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3628; ce: 1.7878714;
FastEstimator-Search: Evaluated {'rotate': 120, 'search_idx': 13}, result: {'Rotation Accuracy': 0.3628}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3404; ce: 1.8583707;
FastEstimator-Search: Evaluated {'rotate': 130, 'search_idx': 14}, result: {'Rotation Accuracy': 0.3404}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3232; ce: 1.9118625;
FastEstimator-Search: Evaluated {'rotate': 140, 'search_idx': 15}, result: {'Rotation Accuracy': 0.3232}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3132; ce: 1.9561087;
FastEstimator-Search: Evaluated {'rotate': 150, 'search_idx': 16}, result: {'Rotation Accuracy': 0.3132}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3096; ce: 1.982389;
FastEstimator-Search: Evaluated {'rotate': 160, 'search_idx': 17}, result: {'Rotation Accuracy': 0.3096}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3082; ce: 2.002314;
FastEstimator-Search: Evaluated {'rotate': 170, 'search_idx': 18}, result: {'Rotation Accuracy': 0.3082}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.31; ce: 1.9804295;
FastEstimator-Search: Evaluated {'rotate': 180, 'search_idx': 19}, result: {'Rotation Accuracy': 0.31}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3014; ce: 2.0179684;
FastEstimator-Search: Evaluated {'rotate': 190, 'search_idx': 20}, result: {'Rotation Accuracy': 0.3014}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3058; ce: 2.0084887;
FastEstimator-Search: Evaluated {'rotate': 200, 'search_idx': 21}, result: {'Rotation Accuracy': 0.3058}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3106; ce: 1.9818201;
FastEstimator-Search: Evaluated {'rotate': 210, 'search_idx': 22}, result: {'Rotation Accuracy': 0.3106}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.322; ce: 1.9469423;
FastEstimator-Search: Evaluated {'rotate': 220, 'search_idx': 23}, result: {'Rotation Accuracy': 0.322}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3334; ce: 1.891612;
FastEstimator-Search: Evaluated {'rotate': 230, 'search_idx': 24}, result: {'Rotation Accuracy': 0.3334}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.35; ce: 1.8142512;
FastEstimator-Search: Evaluated {'rotate': 240, 'search_idx': 25}, result: {'Rotation Accuracy': 0.35}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.3808; ce: 1.7215782;
FastEstimator-Search: Evaluated {'rotate': 250, 'search_idx': 26}, result: {'Rotation Accuracy': 0.3808}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4116; ce: 1.6249053;
FastEstimator-Search: Evaluated {'rotate': 260, 'search_idx': 27}, result: {'Rotation Accuracy': 0.4116}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.444; ce: 1.5260195;
FastEstimator-Search: Evaluated {'rotate': 270, 'search_idx': 28}, result: {'Rotation Accuracy': 0.444}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4716; ce: 1.4676648;
FastEstimator-Search: Evaluated {'rotate': 280, 'search_idx': 29}, result: {'Rotation Accuracy': 0.4716}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.4906; ce: 1.413264;
FastEstimator-Search: Evaluated {'rotate': 290, 'search_idx': 30}, result: {'Rotation Accuracy': 0.4906}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5042; ce: 1.3752073;
FastEstimator-Search: Evaluated {'rotate': 300, 'search_idx': 31}, result: {'Rotation Accuracy': 0.5042}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5196; ce: 1.3499327;
FastEstimator-Search: Evaluated {'rotate': 310, 'search_idx': 32}, result: {'Rotation Accuracy': 0.5196}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.52; ce: 1.3427712;
FastEstimator-Search: Evaluated {'rotate': 320, 'search_idx': 33}, result: {'Rotation Accuracy': 0.52}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5126; ce: 1.3381072;
FastEstimator-Search: Evaluated {'rotate': 330, 'search_idx': 34}, result: {'Rotation Accuracy': 0.5126}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.5194; ce: 1.3385159;
FastEstimator-Search: Evaluated {'rotate': 340, 'search_idx': 35}, result: {'Rotation Accuracy': 0.5194}
FastEstimator-Test: step: None; epoch: 3; accuracy: 0.522; ce: 1.3364942;
FastEstimator-Search: Evaluated {'rotate': 350, 'search_idx': 36}, result: {'Rotation Accuracy': 0.522}
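Conceptually, a grid search calls the evaluation function once for every value (or combination of values) in the parameter grid and records the parameters next to the result, which is what the FastEstimator-Search log lines above report. The sketch below mimics that loop in plain Python. It is not FastEstimator's GridSearch implementation; simple_grid_search and toy_score_fn are hypothetical names, the {"param": ..., "result": ...} record layout is an assumption for illustration, and the toy score is a synthetic curve, not the logged accuracies.

```python
from itertools import product
from typing import Any, Callable, Dict, List


def simple_grid_search(eval_fn: Callable[..., Dict[str, float]],
                       params: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Hypothetical helper: evaluate eval_fn on every combination in params."""
    names = list(params)
    summary = []
    # product(...) enumerates every combination of the parameter values
    for search_idx, values in enumerate(product(*(params[n] for n in names)), start=1):
        kwargs = dict(zip(names, values))
        result = eval_fn(search_idx=search_idx, **kwargs)
        # Record layout (assumed) mirrors the "Evaluated {...}, result: {...}" logs
        summary.append({"param": {**kwargs, "search_idx": search_idx}, "result": result})
    return summary


def toy_score_fn(search_idx: int, rotate: int) -> Dict[str, float]:
    # Hypothetical stand-in for score_fn: a synthetic score that is 1.0 at 0 degrees
    # and 0.5 at 180 degrees. It does NOT reproduce the accuracies logged above.
    distance = abs((rotate + 180) % 360 - 180)  # angular distance from upright, 0..180
    return {"Toy Score": round(1.0 - distance / 360, 4)}


summary = simple_grid_search(toy_score_fn, {"rotate": list(range(0, 360, 10))})
print(len(summary), summary[0], summary[18])
```

With a single parameter the search is just one pass over the 36 angles; adding a second key to params would evaluate every pairwise combination.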
visualize_search(search=rotation_grid_search, title="Model Robustness With Rotation")
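The plot above shows that introducing random rotation during training makes the model more robust to rotation: test accuracy stays at or above 0.50 for every angle within 60 degrees of upright. The model is still much less robust to rotations in the range [90, 270] degrees, where accuracy falls to about 0.30 near 180 degrees, since it was not trained on rotations in that range.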
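To put numbers on that observation, the sketch below compares the rotation model's average accuracy inside the [90, 270] band with its average outside it. The dictionary is transcribed by hand from the FastEstimator-Search log above for illustration; this is plain Python, not part of the FastEstimator API.

```python
from statistics import mean

# Rotation-augmented model: test angle (degrees) -> accuracy, copied from the log above
rotation_acc = {
    0: 0.526, 10: 0.529, 20: 0.5274, 30: 0.5238, 40: 0.5208, 50: 0.52,
    60: 0.5024, 70: 0.4924, 80: 0.4728, 90: 0.4506, 100: 0.413, 110: 0.381,
    120: 0.3628, 130: 0.3404, 140: 0.3232, 150: 0.3132, 160: 0.3096, 170: 0.3082,
    180: 0.31, 190: 0.3014, 200: 0.3058, 210: 0.3106, 220: 0.322, 230: 0.3334,
    240: 0.35, 250: 0.3808, 260: 0.4116, 270: 0.444, 280: 0.4716, 290: 0.4906,
    300: 0.5042, 310: 0.5196, 320: 0.52, 330: 0.5126, 340: 0.5194, 350: 0.522,
}

# Split the angles into the poorly handled band [90, 270] and all other angles
inside = [acc for angle, acc in rotation_acc.items() if 90 <= angle <= 270]
outside = [acc for angle, acc in rotation_acc.items() if not 90 <= angle <= 270]
# Angle with the lowest accuracy
worst_angle = min(rotation_acc, key=rotation_acc.get)

print(f"mean accuracy in [90, 270]: {mean(inside):.4f}")
print(f"mean accuracy elsewhere:    {mean(outside):.4f}")
print(f"worst angle: {worst_angle} degrees ({rotation_acc[worst_angle]})")
```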
Model Comparison
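Let's compare the rotation robustness of all three trained models. No new tests are run here: comparison_fn looks up each model's accuracy at a given angle in the results of the grid searches already performed and merges them into one result, so that one more GridSearch can plot the three curves together.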
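def comparison_fn(search_idx, rotate, no_aug_results, shear_results, rotation_results):
    # Each summary is ordered by angle (0, 10, ..., 350), so the entry for
    # `rotate` sits at index rotate // 10
    acc = no_aug_results[rotate // 10]['result']
    shear_acc = shear_results[rotate // 10]['result']
    rotation_acc = rotation_results[rotate // 10]['result']
    # Merge the three single-metric dicts into one result for plotting
    return {**acc, **shear_acc, **rotation_acc}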
no_aug_results = no_aug_grid_search.get_search_summary()
shear_results = shear_grid_search.get_search_summary()
rotation_results = rotation_grid_search.get_search_summary()
rot = list(range(0, 360, 10))
model_comparison_grid_search = GridSearch(eval_fn=lambda search_idx, rotate: comparison_fn(search_idx, rotate, no_aug_results, shear_results, rotation_results), params={"rotate": rot})
model_comparison_grid_search.fit()
FastEstimator-Search: Evaluated {'rotate': 0, 'search_idx': 1}, result: {'Accuracy without Augmentation': 0.6664, 'Shear Accuracy': 0.5988, 'Rotation Accuracy': 0.526}
FastEstimator-Search: Evaluated {'rotate': 10, 'search_idx': 2}, result: {'Accuracy without Augmentation': 0.6422, 'Shear Accuracy': 0.5908, 'Rotation Accuracy': 0.529}
FastEstimator-Search: Evaluated {'rotate': 20, 'search_idx': 3}, result: {'Accuracy without Augmentation': 0.5804, 'Shear Accuracy': 0.5784, 'Rotation Accuracy': 0.5274}
FastEstimator-Search: Evaluated {'rotate': 30, 'search_idx': 4}, result: {'Accuracy without Augmentation': 0.4862, 'Shear Accuracy': 0.5358, 'Rotation Accuracy': 0.5238}
FastEstimator-Search: Evaluated {'rotate': 40, 'search_idx': 5}, result: {'Accuracy without Augmentation': 0.3996, 'Shear Accuracy': 0.485, 'Rotation Accuracy': 0.5208}
FastEstimator-Search: Evaluated {'rotate': 50, 'search_idx': 6}, result: {'Accuracy without Augmentation': 0.3414, 'Shear Accuracy': 0.427, 'Rotation Accuracy': 0.52}
FastEstimator-Search: Evaluated {'rotate': 60, 'search_idx': 7}, result: {'Accuracy without Augmentation': 0.3072, 'Shear Accuracy': 0.3684, 'Rotation Accuracy': 0.5024}
FastEstimator-Search: Evaluated {'rotate': 70, 'search_idx': 8}, result: {'Accuracy without Augmentation': 0.2864, 'Shear Accuracy': 0.3346, 'Rotation Accuracy': 0.4924}
FastEstimator-Search: Evaluated {'rotate': 80, 'search_idx': 9}, result: {'Accuracy without Augmentation': 0.2818, 'Shear Accuracy': 0.2974, 'Rotation Accuracy': 0.4728}
FastEstimator-Search: Evaluated {'rotate': 90, 'search_idx': 10}, result: {'Accuracy without Augmentation': 0.2932, 'Shear Accuracy': 0.2952, 'Rotation Accuracy': 0.4506}
FastEstimator-Search: Evaluated {'rotate': 100, 'search_idx': 11}, result: {'Accuracy without Augmentation': 0.256, 'Shear Accuracy': 0.2602, 'Rotation Accuracy': 0.413}
FastEstimator-Search: Evaluated {'rotate': 110, 'search_idx': 12}, result: {'Accuracy without Augmentation': 0.2414, 'Shear Accuracy': 0.2584, 'Rotation Accuracy': 0.381}
FastEstimator-Search: Evaluated {'rotate': 120, 'search_idx': 13}, result: {'Accuracy without Augmentation': 0.2352, 'Shear Accuracy': 0.2612, 'Rotation Accuracy': 0.3628}
FastEstimator-Search: Evaluated {'rotate': 130, 'search_idx': 14}, result: {'Accuracy without Augmentation': 0.2346, 'Shear Accuracy': 0.27, 'Rotation Accuracy': 0.3404}
FastEstimator-Search: Evaluated {'rotate': 140, 'search_idx': 15}, result: {'Accuracy without Augmentation': 0.2496, 'Shear Accuracy': 0.2832, 'Rotation Accuracy': 0.3232}
FastEstimator-Search: Evaluated {'rotate': 150, 'search_idx': 16}, result: {'Accuracy without Augmentation': 0.2694, 'Shear Accuracy': 0.2948, 'Rotation Accuracy': 0.3132}
FastEstimator-Search: Evaluated {'rotate': 160, 'search_idx': 17}, result: {'Accuracy without Augmentation': 0.301, 'Shear Accuracy': 0.3062, 'Rotation Accuracy': 0.3096}
FastEstimator-Search: Evaluated {'rotate': 170, 'search_idx': 18}, result: {'Accuracy without Augmentation': 0.316, 'Shear Accuracy': 0.3092, 'Rotation Accuracy': 0.3082}
FastEstimator-Search: Evaluated {'rotate': 180, 'search_idx': 19}, result: {'Accuracy without Augmentation': 0.3428, 'Shear Accuracy': 0.3184, 'Rotation Accuracy': 0.31}
FastEstimator-Search: Evaluated {'rotate': 190, 'search_idx': 20}, result: {'Accuracy without Augmentation': 0.3158, 'Shear Accuracy': 0.3072, 'Rotation Accuracy': 0.3014}
FastEstimator-Search: Evaluated {'rotate': 200, 'search_idx': 21}, result: {'Accuracy without Augmentation': 0.3002, 'Shear Accuracy': 0.3, 'Rotation Accuracy': 0.3058}
FastEstimator-Search: Evaluated {'rotate': 210, 'search_idx': 22}, result: {'Accuracy without Augmentation': 0.2768, 'Shear Accuracy': 0.2894, 'Rotation Accuracy': 0.3106}
FastEstimator-Search: Evaluated {'rotate': 220, 'search_idx': 23}, result: {'Accuracy without Augmentation': 0.2486, 'Shear Accuracy': 0.2692, 'Rotation Accuracy': 0.322}
FastEstimator-Search: Evaluated {'rotate': 230, 'search_idx': 24}, result: {'Accuracy without Augmentation': 0.2346, 'Shear Accuracy': 0.2524, 'Rotation Accuracy': 0.3334}
FastEstimator-Search: Evaluated {'rotate': 240, 'search_idx': 25}, result: {'Accuracy without Augmentation': 0.2364, 'Shear Accuracy': 0.2506, 'Rotation Accuracy': 0.35}
FastEstimator-Search: Evaluated {'rotate': 250, 'search_idx': 26}, result: {'Accuracy without Augmentation': 0.2488, 'Shear Accuracy': 0.2468, 'Rotation Accuracy': 0.3808}
FastEstimator-Search: Evaluated {'rotate': 260, 'search_idx': 27}, result: {'Accuracy without Augmentation': 0.2634, 'Shear Accuracy': 0.2494, 'Rotation Accuracy': 0.4116}
FastEstimator-Search: Evaluated {'rotate': 270, 'search_idx': 28}, result: {'Accuracy without Augmentation': 0.2904, 'Shear Accuracy': 0.2814, 'Rotation Accuracy': 0.444}
FastEstimator-Search: Evaluated {'rotate': 280, 'search_idx': 29}, result: {'Accuracy without Augmentation': 0.2768, 'Shear Accuracy': 0.2798, 'Rotation Accuracy': 0.4716}
FastEstimator-Search: Evaluated {'rotate': 290, 'search_idx': 30}, result: {'Accuracy without Augmentation': 0.277, 'Shear Accuracy': 0.3108, 'Rotation Accuracy': 0.4906}
FastEstimator-Search: Evaluated {'rotate': 300, 'search_idx': 31}, result: {'Accuracy without Augmentation': 0.2798, 'Shear Accuracy': 0.3652, 'Rotation Accuracy': 0.5042}
FastEstimator-Search: Evaluated {'rotate': 310, 'search_idx': 32}, result: {'Accuracy without Augmentation': 0.3082, 'Shear Accuracy': 0.4182, 'Rotation Accuracy': 0.5196}
FastEstimator-Search: Evaluated {'rotate': 320, 'search_idx': 33}, result: {'Accuracy without Augmentation': 0.3678, 'Shear Accuracy': 0.4808, 'Rotation Accuracy': 0.52}
FastEstimator-Search: Evaluated {'rotate': 330, 'search_idx': 34}, result: {'Accuracy without Augmentation': 0.478, 'Shear Accuracy': 0.538, 'Rotation Accuracy': 0.5126}
FastEstimator-Search: Evaluated {'rotate': 340, 'search_idx': 35}, result: {'Accuracy without Augmentation': 0.5888, 'Shear Accuracy': 0.5756, 'Rotation Accuracy': 0.5194}
FastEstimator-Search: Evaluated {'rotate': 350, 'search_idx': 36}, result: {'Accuracy without Augmentation': 0.646, 'Shear Accuracy': 0.5936, 'Rotation Accuracy': 0.522}
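comparison_fn finds each model's result at position rotate // 10, which only works because all three searches used the same ordered grid with a 10 degree step. A slightly more defensive variant keys the results by angle instead of by position. The sketch below assumes each summary entry is a dict holding the searched parameters under a 'param' key next to the 'result' key used above; that layout is an assumption to check against your FastEstimator version. index_by_angle and merge_results are hypothetical helpers, and the sample entries are copied from the comparison log for three angles only.

```python
from typing import Any, Dict, List


def index_by_angle(summary: List[Dict[str, Any]]) -> Dict[int, Dict[str, float]]:
    # Assumed entry layout: {'param': {'rotate': ..., ...}, 'result': {...}}
    return {entry["param"]["rotate"]: entry["result"] for entry in summary}


def merge_results(rotate: int, *summaries: List[Dict[str, Any]]) -> Dict[str, float]:
    merged: Dict[str, float] = {}
    for summary in summaries:
        # Raises KeyError if a search never evaluated this angle,
        # instead of silently reading a neighbouring entry
        merged.update(index_by_angle(summary)[rotate])
    return merged


# Three angles per model, copied from the comparison log above
no_aug = [{"param": {"rotate": a}, "result": {"Accuracy without Augmentation": v}}
          for a, v in [(0, 0.6664), (90, 0.2932), (180, 0.3428)]]
shear = [{"param": {"rotate": a}, "result": {"Shear Accuracy": v}}
         for a, v in [(0, 0.5988), (90, 0.2952), (180, 0.3184)]]
rotation = [{"param": {"rotate": a}, "result": {"Rotation Accuracy": v}}
            for a, v in [(0, 0.526), (90, 0.4506), (180, 0.31)]]

print(merge_results(90, no_aug, shear, rotation))
```

Keying by angle also keeps working if one search used a different step or order, as long as every model was evaluated at the requested angle.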
visualize_search(search=model_comparison_grid_search, title="Model Robustness", groups=[['Accuracy without Augmentation', 'Shear Accuracy', 'Rotation Accuracy']])
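The plot shows that randomly shearing or rotating the input images during training makes the model more robust to rotation at test time. For example, at 60 degrees the model trained without augmentation reaches 0.3072 accuracy, the shear-augmented model 0.3684, and the rotation-augmented model 0.5024. This robustness comes at a cost on unrotated images, where accuracy drops from 0.6664 without augmentation to 0.5988 with shear and 0.526 with rotation, and near 180 degrees all three models land between 0.31 and 0.34. So augmentation can make the model aware of rotation and increase its robustness, at least for the angles covered during training.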
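Because the three models start from different accuracies on unrotated images, one way to compare robustness on its own is to divide each model's accuracy at a given angle by its accuracy at 0 degrees. This normalization is an illustration of our own, not a metric defined by FastEstimator; the numbers are copied from the comparison log above.

```python
# Accuracy at 0 and 60 degrees for each model, copied from the comparison log
logged = {
    "Accuracy without Augmentation": {0: 0.6664, 60: 0.3072},
    "Shear Accuracy": {0: 0.5988, 60: 0.3684},
    "Rotation Accuracy": {0: 0.526, 60: 0.5024},
}

# Fraction of its unrotated accuracy that each model keeps at 60 degrees
retained = {name: acc[60] / acc[0] for name, acc in logged.items()}
for name, ratio in retained.items():
    print(f"{name}: {ratio:.2f}")
```

With these values the rotation-augmented model keeps about 96% of its unrotated accuracy at 60 degrees, compared with about 62% for the shear-augmented model and 46% for the model trained without augmentation.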
NOTE: All three models compared above are LeNet models trained for only 3 epochs. Training for longer or choosing a different architecture can give much better performance.