Curriculum Learning with SuperLoss (TensorFlow Backend)
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
In this example, we demonstrate how to add curriculum learning to any project using SuperLoss. When we learn something in school, we are first taught easy versions of a task before graduating to more difficult problems. Curriculum learning seeks to emulate that process with neural networks. One way to do this is to modify the data pipeline so that it changes the order in which examples are presented. An easier way is to modify the loss term so that difficult examples contribute less until later in training. Curriculum learning has been shown to be especially useful when a dataset contains label noise: noisy samples are effectively 'hard' examples, and we want to postpone trying to fit them.
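To make the loss-reweighting idea concrete, here is a minimal NumPy sketch of a much cruder curriculum than SuperLoss: hard, self-paced-style weighting. Samples whose loss exceeds a threshold are ignored, and the threshold is raised every epoch so that harder samples are admitted later. The function names and the `start_threshold`/`growth` schedule are hypothetical choices for illustration; they are not part of FastEstimator.

```python
import numpy as np


def self_paced_weights(losses, threshold):
    """Hard 0/1 weights: keep only samples whose loss is below the threshold."""
    return (np.asarray(losses, dtype=float) < threshold).astype(float)


def curriculum_loss(losses, epoch, start_threshold=1.0, growth=0.5):
    """Mean loss over the admitted samples; the threshold grows with the epoch.

    start_threshold and growth are hypothetical schedule parameters.
    """
    losses = np.asarray(losses, dtype=float)
    threshold = start_threshold + growth * epoch
    weights = self_paced_weights(losses, threshold)
    # Guard against dividing by zero when no sample is admitted
    return float(np.sum(weights * losses) / max(np.sum(weights), 1.0))


batch_losses = np.array([0.2, 0.8, 2.5, 4.0])  # the last two samples are 'hard'
for epoch in (0, 4, 8):
    # Early epochs average only the easy samples; later epochs include the hard ones
    print(epoch, curriculum_loss(batch_losses, epoch))
```

SuperLoss instead assigns each sample a continuous confidence weight computed in closed form, rather than hard 0/1 weights driven by a hand-written epoch schedule.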
Import the required libraries
```python
import math
import tempfile

import numpy as np
from tensorflow.keras.layers import BatchNormalization, Conv2D, Dense, Flatten, MaxPooling2D
from tensorflow.keras.models import Sequential

import fastestimator as fe
from fastestimator.dataset.data import cifair100
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import CoarseDropout, Normalize
from fastestimator.op.tensorop.loss import CrossEntropy, SuperLoss
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.trace.io import BestModelSaver
from fastestimator.trace.metric import MCC
from fastestimator.trace.xai import LabelTracker
```
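```python
# Training parameters
epochs = 50
batch_size = 128
train_steps_per_epoch = None  # None: use the whole training set every epoch
eval_steps_per_epoch = None  # None: use the whole evaluation set every epoch
save_dir = tempfile.mkdtemp()
```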
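With both step limits set to `None`, every epoch is a full pass over the data, so the number of optimizer steps follows directly from the dataset size and batch size. The sketch below assumes the ciFAIR-100 training split has 50,000 images (the same training images as CIFAR-100) and that the final partial batch is kept; under those assumptions the numbers agree with the step counts printed in the training logs below.

```python
import math

n_train = 50_000  # assumption: size of the ciFAIR-100 training split
batch_size = 128
epochs = 50

# The last batch of each epoch is partial, so round up
steps_per_epoch = math.ceil(n_train / batch_size)
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)  # 391 19550
```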
Step 1 - Data preparation
In this step, we load the ciFAIR-100 training and evaluation datasets using a FastEstimator API, then create a test set by splitting 50% of the evaluation set off into a separate dataset. We also corrupt the training data with 40% label noise to simulate the low-quality annotations found in many real-world datasets.
```python
train_data, eval_data = cifair100.load_data()
# Move 50% of the evaluation set into a new test set
test_data = eval_data.split(0.5)
```
```python
def corrupt_dataset(dataset, n_classes=100, corruption_fraction=0.4):
    # Keep track of which samples were corrupted for visualization later
    corrupted = [0 for _ in range(len(dataset))]
    # Perform the actual label corruption. This assumes balanced classes,
    # which holds for the ciFAIR-100 training set.
    n_samples_per_class = len(dataset) // n_classes
    n_to_corrupt_per_class = math.floor(corruption_fraction * n_samples_per_class)
    n_corrupted = [0] * n_classes
    i = 0
    while any(elem < n_to_corrupt_per_class for elem in n_corrupted):
        current_class = dataset[i]['y'].item()
        if n_corrupted[current_class] < n_to_corrupt_per_class:
            # A random offset in [1, n_classes - 1] guarantees the new label differs from the old one
            dataset[i]['y'] = (dataset[i]['y'] + np.random.randint(1, n_classes)) % n_classes
            n_corrupted[current_class] += 1
            corrupted[i] = 1
        i += 1
    # Put the corruption labels into the dataset for visualization
    dataset['data_labels'] = np.array(corrupted, dtype=int).reshape((len(dataset), 1))


corrupt_dataset(train_data)
```
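The same corruption scheme can be tried on a plain label array, which makes it easy to check its guarantees in isolation. The sketch below is a hypothetical NumPy-only variant (`corrupt_labels` is not a FastEstimator function): it picks the samples to flip at random within each class instead of taking the first ones encountered, but keeps the tutorial's trick of adding a non-zero offset modulo the number of classes, so a flipped label never stays the same.

```python
import numpy as np


def corrupt_labels(labels, n_classes, corruption_fraction, rng):
    """Flip a fixed fraction of each class to a different, random class.

    Returns the corrupted labels and a 0/1 mask of the samples that changed.
    """
    original = np.asarray(labels)
    noisy = original.copy()
    mask = np.zeros(len(original), dtype=int)
    for c in range(n_classes):
        # Select from the ORIGINAL labels so earlier flips do not leak into later classes
        idx = np.flatnonzero(original == c)
        n_flip = int(np.floor(corruption_fraction * len(idx)))
        chosen = rng.choice(idx, size=n_flip, replace=False)
        # Offsets in [1, n_classes - 1] never map a label back onto itself
        offsets = rng.integers(1, n_classes, size=n_flip)
        noisy[chosen] = (original[chosen] + offsets) % n_classes
        mask[chosen] = 1
    return noisy, mask


rng = np.random.default_rng(0)
y = np.repeat(np.arange(100), 500)  # balanced toy labels: 500 samples per class
y_noisy, corrupted = corrupt_labels(y, 100, 0.4, rng)
print(corrupted.sum())  # 20000
```

With 500 samples per class, 40% noise flips 200 labels per class, or 20,000 labels in total.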
Step 2 - Build some Estimators
We define a function that builds a relatively simple estimator given only a loss op as input. This lets us compare the effect of a regular loss versus SuperLoss on our artificially corrupted dataset.
```python
def big_lenet(classes=100, input_shape=(32, 32, 3)):
    # Like a LeNet model, but bigger.
    model = Sequential()
    model.add(Conv2D(32, (3, 3), activation='swish', input_shape=input_shape))
    model.add(BatchNormalization())
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(64, (3, 3), activation='swish'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D((2, 2)))
    model.add(Conv2D(128, (3, 3), activation='swish'))
    model.add(BatchNormalization())
    model.add(MaxPooling2D((2, 2)))
    model.add(Flatten())
    model.add(Dense(128, activation='swish'))
    model.add(BatchNormalization())
    model.add(Dense(classes, activation='softmax'))
    return model
```
```python
def build_estimator(loss_op):
    pipeline = fe.Pipeline(train_data=train_data,
                           eval_data=eval_data,
                           test_data=test_data,
                           batch_size=batch_size,
                           ops=[Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
                                PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
                                RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
                                Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
                                CoarseDropout(inputs="x", outputs="x", max_holes=1, mode="train"),
                                ])
    model = fe.build(model_fn=big_lenet, optimizer_fn='adam')
    network = fe.Network(ops=[
        ModelOp(model=model, inputs="x", outputs="y_pred"),
        loss_op,  # <<<----------------------------- This is where the secret sauce will go
        UpdateOp(model=model, loss_name="ce")
    ])
    traces = [
        MCC(true_key="y", pred_key="y_pred"),
        BestModelSaver(model=model, save_dir=save_dir, metric="mcc", save_best_mode="max", load_best_final=True),
        # We will also visualize the difference between the normal and corrupted image confidence scores. You could follow this with an
        # ImageViewer trace, but we will get the data out of the system summary instead later for viewing.
        LabelTracker(metric="confidence", label="data_labels", label_mapping={"Normal": 0, "Corrupted": 1}, mode="train", outputs="label_confidence"),
    ]
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=epochs,
                             traces=traces,
                             train_steps_per_epoch=train_steps_per_epoch,
                             eval_steps_per_epoch=eval_steps_per_epoch,
                             log_steps=300)
    return estimator
```
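As a sanity check on `big_lenet`, its spatial sizes and parameter count can be worked out by hand. The sketch assumes the Keras defaults the model relies on: `'valid'` padding with stride 1 for `Conv2D`, 2x2 max pooling with stride 2, and four parameters per channel for `BatchNormalization` (two trainable, plus the two non-trainable moving statistics).

```python
def conv_valid(size, kernel=3):
    # An unpadded, stride-1 convolution trims kernel - 1 pixels
    return size - kernel + 1


def pool(size, window=2):
    # Non-overlapping 2x2 'valid' pooling halves the size, rounding down
    return size // window


size, channels, params = 32, 3, 0
for filters in (32, 64, 128):
    params += 3 * 3 * channels * filters + filters  # conv kernel + bias
    params += 4 * filters  # BatchNorm: gamma, beta, moving mean, moving variance
    size = pool(conv_valid(size))
    channels = filters

flat = size * size * channels
params += flat * 128 + 128 + 4 * 128  # Dense(128) + its BatchNorm
params += 128 * 100 + 100  # Dense(100) softmax head
print(size, flat, params)  # 2 512 173220
```

If this arithmetic holds, `big_lenet().count_params()` should report 173,220, with most of the weights in the third convolution (73,856) and the first dense layer (65,664).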
Step 3 - Train a baseline model with a regular loss
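Let's start by training a baseline model with standard CrossEntropy and see what we get. We also define a `FakeSuperLoss` wrapper: it still runs SuperLoss to produce per-sample confidence estimates, but returns the ordinary mean cross-entropy as the training loss. The model is therefore trained on the regular loss, while we can still compare confidence scores on clean and corrupted samples.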
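```python
class FakeSuperLoss(SuperLoss):
    def forward(self, data, state):
        # Run SuperLoss only to obtain its per-sample confidence scores
        superloss, confidence = super().forward(data, state)
        # Train on the ordinary (unweighted) mean of the wrapped loss instead
        regularloss = fe.backend.reduce_mean(self.loss.forward(data, state))
        return [regularloss, confidence]
```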
```python
loss = FakeSuperLoss(CrossEntropy(inputs=("y_pred", "y"), outputs="ce"), output_confidence="confidence")
estimator_regular = build_estimator(loss)
regular = estimator_regular.fit("RegularLoss")
```
```text
Metal device set to: Apple M1 Max
FastEstimator-Start: step: 1; logging_interval: 300; num_device: 0;
FastEstimator-Eval: step: 391; epoch: 1; ce: 3.9165292; max_mcc: 0.11637703081989584; mcc: 0.11637703081989584; since_best_mcc: 0;
FastEstimator-Eval: step: 782; epoch: 2; ce: 3.4811482; max_mcc: 0.19463705463257847; mcc: 0.19463705463257847; since_best_mcc: 0;
FastEstimator-Eval: step: 1173; epoch: 3; ce: 3.2860541; max_mcc: 0.22608677082453127; mcc: 0.22608677082453127; since_best_mcc: 0;
FastEstimator-Eval: step: 1564; epoch: 4; ce: 3.2074647; max_mcc: 0.24498756546447986; mcc: 0.24498756546447986; since_best_mcc: 0;
FastEstimator-Eval: step: 1955; epoch: 5; ce: 3.0909023; max_mcc: 0.2673344783186453; mcc: 0.2673344783186453; since_best_mcc: 0;
FastEstimator-Eval: step: 2346; epoch: 6; ce: 3.0296268; max_mcc: 0.28257658668430546; mcc: 0.28257658668430546; since_best_mcc: 0;
FastEstimator-Eval: step: 2737; epoch: 7; ce: 2.9461927; max_mcc: 0.3041788741339668; mcc: 0.3041788741339668; since_best_mcc: 0;
FastEstimator-Eval: step: 3128; epoch: 8; ce: 2.9334276; max_mcc: 0.3041788741339668; mcc: 0.3011113825398861; since_best_mcc: 1;
FastEstimator-Eval: step: 3519; epoch: 9; ce: 2.9065385; max_mcc: 0.30455805637425065; mcc: 0.30455805637425065; since_best_mcc: 0;
FastEstimator-Eval: step: 3910; epoch: 10; ce: 2.8042293; max_mcc: 0.3314968221510782; mcc: 0.3314968221510782; since_best_mcc: 0;
FastEstimator-Eval: step: 4301; epoch: 11; ce: 2.7552197; max_mcc: 0.3357567841576629; mcc: 0.3357567841576629; since_best_mcc: 0;
FastEstimator-Eval: step: 4692; epoch: 12; ce: 2.736461; max_mcc: 0.3510061360991489; mcc: 0.3510061360991489; since_best_mcc: 0;
FastEstimator-Eval: step: 5083; epoch: 13; ce: 2.75954; max_mcc: 0.3510061360991489; mcc: 0.33251882862713855; since_best_mcc: 1;
FastEstimator-Eval: step: 5474; epoch: 14; ce: 2.7278647; max_mcc: 0.3510061360991489; mcc: 0.3468396191191313; since_best_mcc: 2;
FastEstimator-Eval: step: 5865; epoch: 15; ce: 2.698529; max_mcc: 0.35409691305967284; mcc: 0.35409691305967284; since_best_mcc: 0;
FastEstimator-Eval: step: 6256; epoch: 16; ce: 2.6763847; max_mcc: 0.36294897145164334; mcc: 0.36294897145164334; since_best_mcc: 0;
FastEstimator-Eval: step: 6647; epoch: 17; ce: 2.6446831; max_mcc: 0.37233087727784103; mcc: 0.37233087727784103; since_best_mcc: 0;
FastEstimator-Eval: step: 7038; epoch: 18; ce: 2.6735303; max_mcc: 0.37233087727784103; mcc: 0.3589859924598546; since_best_mcc: 1;
FastEstimator-Eval: step: 7429; epoch: 19; ce: 2.677819; max_mcc: 0.37233087727784103; mcc: 0.3722223425197881; since_best_mcc: 2;
FastEstimator-Eval: step: 7820; epoch: 20; ce: 2.6317315; max_mcc: 0.38029253313466144; mcc: 0.38029253313466144; since_best_mcc: 0;
FastEstimator-Eval: step: 8211; epoch: 21; ce: 2.6565547; max_mcc: 0.38029253313466144; mcc: 0.36699826561040333; since_best_mcc: 1;
FastEstimator-Eval: step: 8602; epoch: 22; ce: 2.6064296; max_mcc: 0.38029253313466144; mcc: 0.3740086404189197; since_best_mcc: 2;
FastEstimator-Eval: step: 8993; epoch: 23; ce: 2.6085975; max_mcc: 0.38029253313466144; mcc: 0.36779678534767773; since_best_mcc: 3;
FastEstimator-Eval: step: 9384; epoch: 24; ce: 2.6242566; max_mcc: 0.38029253313466144; mcc: 0.37316928622026396; since_best_mcc: 4;
FastEstimator-Eval: step: 9775; epoch: 25; ce: 2.5628254; max_mcc: 0.3911783657392671; mcc: 0.3911783657392671; since_best_mcc: 0;
FastEstimator-Eval: step: 10166; epoch: 26; ce: 2.5749164; max_mcc: 0.3911783657392671; mcc: 0.3824618003063329; since_best_mcc: 1;
FastEstimator-Eval: step: 10557; epoch: 27; ce: 2.579395; max_mcc: 0.3911783657392671; mcc: 0.38492637219709647; since_best_mcc: 2;
FastEstimator-Eval: step: 10948; epoch: 28; ce: 2.557443; max_mcc: 0.3911783657392671; mcc: 0.38992180085047434; since_best_mcc: 3;
FastEstimator-Eval: step: 11339; epoch: 29; ce: 2.555817; max_mcc: 0.3911783657392671; mcc: 0.3834724048810534; since_best_mcc: 4;
FastEstimator-Eval: step: 11730; epoch: 30; ce: 2.6084797; max_mcc: 0.3911783657392671; mcc: 0.38230879829021064; since_best_mcc: 5;
FastEstimator-Eval: step: 12121; epoch: 31; ce: 2.5444205; max_mcc: 0.3911783657392671; mcc: 0.3903882701332637; since_best_mcc: 6;
FastEstimator-Eval: step: 12512; epoch: 32; ce: 2.555135; max_mcc: 0.39336794619867993; mcc: 0.39336794619867993; since_best_mcc: 0;
FastEstimator-Eval: step: 12903; epoch: 33; ce: 2.5638745; max_mcc: 0.3997956449593042; mcc: 0.3997956449593042; since_best_mcc: 0;
FastEstimator-Eval: step: 13294; epoch: 34; ce: 2.581514; max_mcc: 0.3997956449593042; mcc: 0.3840599318109791; since_best_mcc: 1;
FastEstimator-Eval: step: 13685; epoch: 35; ce: 2.5393128; max_mcc: 0.3997956449593042; mcc: 0.39358100358310194; since_best_mcc: 2;
FastEstimator-Eval: step: 14076; epoch: 36; ce: 2.5451062; max_mcc: 0.3997956449593042; mcc: 0.3909940656566139; since_best_mcc: 3;
FastEstimator-Eval: step: 14467; epoch: 37; ce: 2.4956336; max_mcc: 0.4018185133565588; mcc: 0.4018185133565588; since_best_mcc: 0;
FastEstimator-Eval: step: 14858; epoch: 38; ce: 2.5473502; max_mcc: 0.4018185133565588; mcc: 0.39070522900122134; since_best_mcc: 1;
FastEstimator-Eval: step: 15249; epoch: 39; ce: 2.5087152; max_mcc: 0.40282440628039246; mcc: 0.40282440628039246; since_best_mcc: 0;
FastEstimator-Eval: step: 15640; epoch: 40; ce: 2.5329254; max_mcc: 0.40282440628039246; mcc: 0.4009165593189717; since_best_mcc: 1;
FastEstimator-Eval: step: 16031; epoch: 41; ce: 2.5511973; max_mcc: 0.40282440628039246; mcc: 0.38809165892740816; since_best_mcc: 2;
FastEstimator-Eval: step: 16422; epoch: 42; ce: 2.5401368; max_mcc: 0.40282440628039246; mcc: 0.4002505965558874; since_best_mcc: 3;
FastEstimator-Eval: step: 16813; epoch: 43; ce: 2.5123787; max_mcc: 0.403635709044651; mcc: 0.403635709044651; since_best_mcc: 0;
FastEstimator-Eval: step: 17204; epoch: 44; ce: 2.5150542; max_mcc: 0.403635709044651; mcc: 0.3967572301803061; since_best_mcc: 1;
FastEstimator-Eval: step: 17595; epoch: 45; ce: 2.5106926; max_mcc: 0.4070381050333296; mcc: 0.4070381050333296; since_best_mcc: 0;
FastEstimator-Eval: step: 17986; epoch: 46; ce: 2.5253594; max_mcc: 0.4070381050333296; mcc: 0.399264786449083; since_best_mcc: 1;
FastEstimator-Eval: step: 18377; epoch: 47; ce: 2.5333138; max_mcc: 0.4070381050333296; mcc: 0.39595400336978015; since_best_mcc: 2;
FastEstimator-Eval: step: 18768; epoch: 48; ce: 2.5043597; max_mcc: 0.4070381050333296; mcc: 0.39856495181853063; since_best_mcc: 3;
FastEstimator-Eval: step: 19159; epoch: 49; ce: 2.5006938; max_mcc: 0.4070381050333296; mcc: 0.4048320575822205; since_best_mcc: 4;
FastEstimator-Eval: step: 19550; epoch: 50; ce: 2.511425; max_mcc: 0.4070381050333296; mcc: 0.3974247087308549; since_best_mcc: 5;
FastEstimator-BestModelSaver: Restoring model from /var/folders/lx/drkxftt117gblvgsp1p39rlc0000gn/T/tmpzn39e8qh/model_best_mcc.h5
FastEstimator-Finish: step: 19550; model_lr: 0.001; total_time: 624.5 sec;
```
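The `LabelTracker` trace compares the `confidence` values of samples labeled `Normal` with those labeled `Corrupted`. As a rough, framework-free illustration of that kind of comparison (not how `LabelTracker` is actually implemented), the hypothetical helper below averages per-sample confidences within each group of `label_mapping`. The input values in the usage line are made up for illustration and are not results from this run.

```python
import numpy as np


def summarize_confidence(confidence, data_labels, label_mapping):
    """Mean confidence for each group, e.g. label_mapping={"Normal": 0, "Corrupted": 1}."""
    confidence = np.asarray(confidence, dtype=float).ravel()
    # data_labels may have shape (N, 1), as created by corrupt_dataset
    data_labels = np.asarray(data_labels).ravel()
    return {name: float(confidence[data_labels == value].mean())
            for name, value in label_mapping.items()}


# Made-up confidences for two clean and two corrupted samples
print(summarize_confidence([2.0, 1.5, 0.4, 0.6], [[0], [0], [1], [1]], {"Normal": 0, "Corrupted": 1}))
```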
Step 4 - Train a model using SuperLoss
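Now it's time to see whether curriculum learning with SuperLoss can help us overcome the label noise. SuperLoss wraps the existing loss op and writes its result under the same output key (`ce`), so the logged `ce` values below are SuperLoss values rather than raw cross-entropy, which is why they can be negative. Note how easy it is to add SuperLoss to any existing loss function: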
```python
# output_confidence is only needed if you want to visualize the confidence scores
loss = SuperLoss(CrossEntropy(inputs=("y_pred", "y"), outputs="ce"), output_confidence="confidence")
estimator_super = build_estimator(loss)
superL = estimator_super.fit("SuperLoss")
```
```text
FastEstimator-Start: step: 1; logging_interval: 300; num_device: 0;
FastEstimator-Eval: step: 391; epoch: 1; ce: -0.8853086; max_mcc: 0.14302881374073162; mcc: 0.14302881374073162; since_best_mcc: 0;
FastEstimator-Eval: step: 782; epoch: 2; ce: -1.2232248; max_mcc: 0.22132539422808534; mcc: 0.22132539422808534; since_best_mcc: 0;
FastEstimator-Eval: step: 1173; epoch: 3; ce: -1.1744673; max_mcc: 0.26100251264473845; mcc: 0.26100251264473845; since_best_mcc: 0;
FastEstimator-Eval: step: 1564; epoch: 4; ce: -1.1474909; max_mcc: 0.27333068073148203; mcc: 0.27333068073148203; since_best_mcc: 0;
FastEstimator-Eval: step: 1955; epoch: 5; ce: -1.1647666; max_mcc: 0.3067382097563979; mcc: 0.3067382097563979; since_best_mcc: 0;
FastEstimator-Eval: step: 2346; epoch: 6; ce: -1.1477449; max_mcc: 0.32147236431980514; mcc: 0.32147236431980514; since_best_mcc: 0;
FastEstimator-Eval: step: 2737; epoch: 7; ce: -1.1068447; max_mcc: 0.3412255408982984; mcc: 0.3412255408982984; since_best_mcc: 0;
FastEstimator-Eval: step: 3128; epoch: 8; ce: -1.1532923; max_mcc: 0.34505209454575014; mcc: 0.34505209454575014; since_best_mcc: 0;
FastEstimator-Eval: step: 3519; epoch: 9; ce: -1.072229; max_mcc: 0.35470884254952145; mcc: 0.35470884254952145; since_best_mcc: 0;
FastEstimator-Eval: step: 3910; epoch: 10; ce: -1.1424695; max_mcc: 0.3699711914588184; mcc: 0.3699711914588184; since_best_mcc: 0;
FastEstimator-Eval: step: 4301; epoch: 11; ce: -1.1189286; max_mcc: 0.3699711914588184; mcc: 0.3585887867728886; since_best_mcc: 1;
FastEstimator-Eval: step: 4692; epoch: 12; ce: -1.1282852; max_mcc: 0.376192949299191; mcc: 0.376192949299191; since_best_mcc: 0;
FastEstimator-Eval: step: 5083; epoch: 13; ce: -1.1448267; max_mcc: 0.3781592646816613; mcc: 0.3781592646816613; since_best_mcc: 0;
FastEstimator-Eval: step: 5474; epoch: 14; ce: -1.1674703; max_mcc: 0.39153773775904055; mcc: 0.39153773775904055; since_best_mcc: 0;
FastEstimator-Eval: step: 5865; epoch: 15; ce: -1.0727899; max_mcc: 0.39153773775904055; mcc: 0.3706045023403391; since_best_mcc: 1;
FastEstimator-Eval: step: 6256; epoch: 16; ce: -1.1405761; max_mcc: 0.4041517220685679; mcc: 0.4041517220685679; since_best_mcc: 0;
FastEstimator-Eval: step: 6647; epoch: 17; ce: -1.1009356; max_mcc: 0.4041517220685679; mcc: 0.39731992137483313; since_best_mcc: 1;
FastEstimator-Eval: step: 7038; epoch: 18; ce: -1.1205976; max_mcc: 0.40548032886265584; mcc: 0.40548032886265584; since_best_mcc: 0;
FastEstimator-Eval: step: 7429; epoch: 19; ce: -1.1357507; max_mcc: 0.40837250526187713; mcc: 0.40837250526187713; since_best_mcc: 0;
FastEstimator-Eval: step: 7820; epoch: 20; ce: -1.1415228; max_mcc: 0.40837250526187713; mcc: 0.4074934800213416; since_best_mcc: 1;
FastEstimator-Eval: step: 8211; epoch: 21; ce: -1.113405; max_mcc: 0.42303159513114147; mcc: 0.42303159513114147; since_best_mcc: 0;
FastEstimator-Eval: step: 8602; epoch: 22; ce: -1.1111972; max_mcc: 0.42303159513114147; mcc: 0.41537172896026864; since_best_mcc: 1;
FastEstimator-Eval: step: 8993; epoch: 23; ce: -1.1185064; max_mcc: 0.42303159513114147; mcc: 0.41804243248667144; since_best_mcc: 2;
FastEstimator-Eval: step: 9384; epoch: 24; ce: -1.135175; max_mcc: 0.4369457266318141; mcc: 0.4369457266318141; since_best_mcc: 0;
FastEstimator-Eval: step: 9775; epoch: 25; ce: -1.107126; max_mcc: 0.4369457266318141; mcc: 0.42386157171628613; since_best_mcc: 1;
FastEstimator-Eval: step: 10166; epoch: 26; ce: -1.1384557; max_mcc: 0.4369457266318141; mcc: 0.4180638344673729; since_best_mcc: 2;
FastEstimator-Eval: step: 10557; epoch: 27; ce: -1.1131722; max_mcc: 0.4369457266318141; mcc: 0.42625128157245296; since_best_mcc: 3;
```
FastEstimator-Train: step: 10948; epoch: 28; epoch_time: 15.93 sec; FastEstimator-Eval: step: 10948; epoch: 28; ce: -1.1155487; max_mcc: 0.4369457266318141; mcc: 0.42908067568626684; since_best_mcc: 4; FastEstimator-Train: step: 11100; ce: -2.3713639; steps/sec: 23.7; FastEstimator-Train: step: 11339; epoch: 29; epoch_time: 16.33 sec; FastEstimator-Eval: step: 11339; epoch: 29; ce: -1.1220049; max_mcc: 0.4369457266318141; mcc: 0.4260630601839669; since_best_mcc: 5; FastEstimator-Train: step: 11400; ce: -2.169163; steps/sec: 23.17; FastEstimator-Train: step: 11700; ce: -2.1599045; steps/sec: 27.62; FastEstimator-Train: step: 11730; epoch: 30; epoch_time: 15.79 sec; FastEstimator-Eval: step: 11730; epoch: 30; ce: -1.0984906; max_mcc: 0.4369457266318141; mcc: 0.4285067929184998; since_best_mcc: 6; FastEstimator-Train: step: 12000; ce: -1.8843775; steps/sec: 24.1; FastEstimator-Train: step: 12121; epoch: 31; epoch_time: 15.78 sec; FastEstimator-Eval: step: 12121; epoch: 31; ce: -1.1562611; max_mcc: 0.4369457266318141; mcc: 0.42826860921000354; since_best_mcc: 7; FastEstimator-Train: step: 12300; ce: -2.6584306; steps/sec: 24.0; FastEstimator-Train: step: 12512; epoch: 32; epoch_time: 15.73 sec; FastEstimator-Eval: step: 12512; epoch: 32; ce: -1.1435204; max_mcc: 0.4369457266318141; mcc: 0.4333150458296645; since_best_mcc: 8; FastEstimator-Train: step: 12600; ce: -2.306051; steps/sec: 24.09; FastEstimator-Train: step: 12900; ce: -1.3515851; steps/sec: 27.23; FastEstimator-Train: step: 12903; epoch: 33; epoch_time: 15.95 sec; FastEstimator-Eval: step: 12903; epoch: 33; ce: -1.1314979; max_mcc: 0.4369457266318141; mcc: 0.43645579540783064; since_best_mcc: 9; FastEstimator-Train: step: 13200; ce: -2.8421557; steps/sec: 23.81; FastEstimator-Train: step: 13294; epoch: 34; epoch_time: 15.81 sec; FastEstimator-Eval: step: 13294; epoch: 34; ce: -1.0980072; max_mcc: 0.4369457266318141; mcc: 0.43075809854762503; since_best_mcc: 10; FastEstimator-Train: step: 13500; ce: 
-2.8916705; steps/sec: 24.25; FastEstimator-Train: step: 13685; epoch: 35; epoch_time: 16.25 sec; FastEstimator-BestModelSaver: Saved model to /var/folders/lx/drkxftt117gblvgsp1p39rlc0000gn/T/tmpzn39e8qh/model1_best_mcc.h5 FastEstimator-Eval: step: 13685; epoch: 35; ce: -1.1472872; max_mcc: 0.44365012520788033; mcc: 0.44365012520788033; since_best_mcc: 0; FastEstimator-Train: step: 13800; ce: -2.8401277; steps/sec: 23.06; FastEstimator-Train: step: 14076; epoch: 36; epoch_time: 15.82 sec; FastEstimator-Eval: step: 14076; epoch: 36; ce: -1.1562262; max_mcc: 0.44365012520788033; mcc: 0.4313645449516555; since_best_mcc: 1; FastEstimator-Train: step: 14100; ce: -2.5230148; steps/sec: 24.17; FastEstimator-Train: step: 14400; ce: -2.8517623; steps/sec: 26.92; FastEstimator-Train: step: 14467; epoch: 37; epoch_time: 15.98 sec; FastEstimator-Eval: step: 14467; epoch: 37; ce: -1.2001235; max_mcc: 0.44365012520788033; mcc: 0.430704291123061; since_best_mcc: 2; FastEstimator-Train: step: 14700; ce: -2.1312857; steps/sec: 24.05; FastEstimator-Train: step: 14858; epoch: 38; epoch_time: 15.73 sec; FastEstimator-BestModelSaver: Saved model to /var/folders/lx/drkxftt117gblvgsp1p39rlc0000gn/T/tmpzn39e8qh/model1_best_mcc.h5 FastEstimator-Eval: step: 14858; epoch: 38; ce: -1.1431301; max_mcc: 0.4472147715296702; mcc: 0.4472147715296702; since_best_mcc: 0; FastEstimator-Train: step: 15000; ce: -2.7680423; steps/sec: 23.84; FastEstimator-Train: step: 15249; epoch: 39; epoch_time: 15.99 sec; FastEstimator-Eval: step: 15249; epoch: 39; ce: -1.1730366; max_mcc: 0.4472147715296702; mcc: 0.43998571300322525; since_best_mcc: 1; FastEstimator-Train: step: 15300; ce: -2.3842983; steps/sec: 23.82; FastEstimator-Train: step: 15600; ce: -2.61199; steps/sec: 27.26; FastEstimator-Train: step: 15640; epoch: 40; epoch_time: 16.01 sec; FastEstimator-Eval: step: 15640; epoch: 40; ce: -1.1539762; max_mcc: 0.4472147715296702; mcc: 0.4367261360226644; since_best_mcc: 2; FastEstimator-Train: step: 15900; 
ce: -2.4703374; steps/sec: 22.97; FastEstimator-Train: step: 16031; epoch: 41; epoch_time: 16.48 sec; FastEstimator-Eval: step: 16031; epoch: 41; ce: -1.1659281; max_mcc: 0.4472147715296702; mcc: 0.4353277252173622; since_best_mcc: 3; FastEstimator-Train: step: 16200; ce: -2.339718; steps/sec: 23.3; FastEstimator-Train: step: 16422; epoch: 42; epoch_time: 15.97 sec; FastEstimator-Eval: step: 16422; epoch: 42; ce: -1.1800143; max_mcc: 0.4472147715296702; mcc: 0.4397506284062948; since_best_mcc: 4; FastEstimator-Train: step: 16500; ce: -2.7436943; steps/sec: 23.55; FastEstimator-Train: step: 16800; ce: -2.3053098; steps/sec: 27.47; FastEstimator-Train: step: 16813; epoch: 43; epoch_time: 16.11 sec; FastEstimator-Eval: step: 16813; epoch: 43; ce: -1.1810557; max_mcc: 0.4472147715296702; mcc: 0.4442676528937523; since_best_mcc: 5; FastEstimator-Train: step: 17100; ce: -2.8315272; steps/sec: 24.14; FastEstimator-Train: step: 17204; epoch: 44; epoch_time: 15.73 sec; FastEstimator-Eval: step: 17204; epoch: 44; ce: -1.1615522; max_mcc: 0.4472147715296702; mcc: 0.43377874758841084; since_best_mcc: 6; FastEstimator-Train: step: 17400; ce: -2.6666136; steps/sec: 23.62; FastEstimator-Train: step: 17595; epoch: 45; epoch_time: 16.26 sec; FastEstimator-Eval: step: 17595; epoch: 45; ce: -1.1887122; max_mcc: 0.4472147715296702; mcc: 0.43760807300504523; since_best_mcc: 7; FastEstimator-Train: step: 17700; ce: -2.892805; steps/sec: 23.27; FastEstimator-Train: step: 17986; epoch: 46; epoch_time: 16.19 sec; FastEstimator-BestModelSaver: Saved model to /var/folders/lx/drkxftt117gblvgsp1p39rlc0000gn/T/tmpzn39e8qh/model1_best_mcc.h5 FastEstimator-Eval: step: 17986; epoch: 46; ce: -1.1846745; max_mcc: 0.4516928294301088; mcc: 0.4516928294301088; since_best_mcc: 0; FastEstimator-Train: step: 18000; ce: -1.9225727; steps/sec: 23.57; FastEstimator-Train: step: 18300; ce: -2.5122561; steps/sec: 26.8; FastEstimator-Train: step: 18377; epoch: 47; epoch_time: 16.16 sec; FastEstimator-Eval: 
step: 18377; epoch: 47; ce: -1.190774; max_mcc: 0.4516928294301088; mcc: 0.44820341989774154; since_best_mcc: 1; FastEstimator-Train: step: 18600; ce: -2.5353403; steps/sec: 23.46; FastEstimator-Train: step: 18768; epoch: 48; epoch_time: 16.36 sec; FastEstimator-BestModelSaver: Saved model to /var/folders/lx/drkxftt117gblvgsp1p39rlc0000gn/T/tmpzn39e8qh/model1_best_mcc.h5 FastEstimator-Eval: step: 18768; epoch: 48; ce: -1.1954826; max_mcc: 0.4535162464296696; mcc: 0.4535162464296696; since_best_mcc: 0; FastEstimator-Train: step: 18900; ce: -1.7238238; steps/sec: 22.65; FastEstimator-Train: step: 19159; epoch: 49; epoch_time: 16.18 sec; FastEstimator-BestModelSaver: Saved model to /var/folders/lx/drkxftt117gblvgsp1p39rlc0000gn/T/tmpzn39e8qh/model1_best_mcc.h5 FastEstimator-Eval: step: 19159; epoch: 49; ce: -1.1699214; max_mcc: 0.456133498289478; mcc: 0.456133498289478; since_best_mcc: 0; FastEstimator-Train: step: 19200; ce: -2.1122396; steps/sec: 23.94; FastEstimator-Train: step: 19500; ce: -2.927199; steps/sec: 27.23; FastEstimator-Train: step: 19550; epoch: 50; epoch_time: 16.02 sec; FastEstimator-Eval: step: 19550; epoch: 50; ce: -1.175966; max_mcc: 0.456133498289478; mcc: 0.43898991915944996; since_best_mcc: 1; FastEstimator-BestModelSaver: Restoring model from /var/folders/lx/drkxftt117gblvgsp1p39rlc0000gn/T/tmpzn39e8qh/model1_best_mcc.h5 FastEstimator-Finish: step: 19550; model1_lr: 0.001; total_time: 890.63 sec;
Step 5 - Performance Comparison
Let's take a look at how the two final models compare on the held-out test set:
estimator_regular.test()
FastEstimator-Test: step: 19550; epoch: 50; ce: 2.5485935; mcc: 0.3998134494844641;
estimator_super.test()
FastEstimator-Test: step: 19550; epoch: 50; ce: -1.0826309; mcc: 0.44526758544276457;
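Both models are compared using MCC (Matthews correlation coefficient), which is 1 for perfect predictions and close to 0 for chance-level ones. For readers unfamiliar with the multi-class version, here is a minimal pure-Python sketch of the standard formula, MCC = (c·s − Σₖ pₖtₖ) / sqrt((s² − Σₖ pₖ²)(s² − Σₖ tₖ²)), where s is the number of samples, c the number of correct predictions, and pₖ and tₖ the number of times class k is predicted and truly occurs. This is the same formula scikit-learn's `matthews_corrcoef` computes; the function name `multiclass_mcc` is hypothetical, and this is not FastEstimator's `MCC` trace implementation.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def multiclass_mcc(y_true: Sequence[Hashable], y_pred: Sequence[Hashable]) -> float:
    """Multi-class Matthews correlation coefficient computed from label counts."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    s = len(y_true)
    c = sum(1 for t, p in zip(y_true, y_pred) if t == p)  # correct predictions
    t_counts = Counter(y_true)  # how often each class truly occurs
    p_counts = Counter(y_pred)  # how often each class is predicted
    classes = set(t_counts) | set(p_counts)
    cov_tp = c * s - sum(t_counts[k] * p_counts[k] for k in classes)
    cov_tt = s * s - sum(n * n for n in t_counts.values())
    cov_pp = s * s - sum(n * n for n in p_counts.values())
    if cov_tt == 0 or cov_pp == 0:
        # Degenerate case (e.g. a constant predictor): define MCC as 0
        return 0.0
    return cov_tp / math.sqrt(cov_tt * cov_pp)


if __name__ == "__main__":
    print(multiclass_mcc([0, 0, 1, 1], [0, 1, 1, 1]))  # ≈ 0.577 (1 / sqrt(3))
```

In the binary case this reduces to the familiar TP/TN/FP/FN form of MCC, so the same function covers both settings.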
fe.summary.logs.visualize_logs([regular, superL], include_metrics={'mcc', 'ce', 'max_mcc'})
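As the test results above show, a simple one-line change that adds SuperLoss to the training procedure raised the model's test MCC from about 0.400 to about 0.445 (roughly 4.5 points) in the presence of noisy training labels. Note that the `ce` values of the two runs are not directly comparable: cross-entropy itself can never be negative, so the negative `ce` reported for the SuperLoss run is the SuperLoss output rather than the raw cross-entropy. Let's also take a look at the confidence scores that SuperLoss generates for the noisy vs. clean data: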
fe.summary.logs.visualize_logs(estimator_super.system.custom_graphs['label_confidence'])
As the graph above demonstrates, the corrupted samples have noticeably lower average confidence scores than the clean samples. The same pattern appears when we analyze the confidence scores during regular training, but the separation is weaker:
fe.summary.logs.visualize_logs(estimator_regular.system.custom_graphs['label_confidence'])
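Where do these confidence scores come from? In the SuperLoss paper linked at the top of this example, each sample's task loss ℓ is compared to a threshold τ, and the confidence σ has a closed-form solution, σ = exp(−W(½·max(−2/e, (ℓ − τ)/λ))), where W is the principal branch of the Lambert W function and λ is a regularization weight. The loss that is actually minimized is (ℓ − τ)·σ + λ·(log σ)². The scalar sketch below implements just those two formulas in plain Python. It is an illustration under stated assumptions, not FastEstimator's batched TensorFlow `SuperLoss` op: `tau` and `lam` are passed in as hypothetical fixed values (the library may set τ differently), and because the standard library has no Lambert W, the hypothetical helper `lambert_w0` solves it with Halley's method (`scipy.special.lambertw` is an alternative if SciPy is available).

```python
import math


def lambert_w0(z: float, tol: float = 1e-12, max_iter: int = 100) -> float:
    """Principal branch of the Lambert W function: solves w * exp(w) = z for z >= -1/e."""
    branch_point = -1.0 / math.e
    if z < branch_point - 1e-12:
        raise ValueError("W0(z) is only real for z >= -1/e")
    if z <= branch_point:
        return -1.0  # W0(-1/e) = -1
    # Starting guess: series expansion near the branch point, log1p(z) elsewhere
    if z < -0.25:
        p = math.sqrt(2.0 * (math.e * z + 1.0))
        w = -1.0 + p - p * p / 3.0
    else:
        w = math.log1p(z)
    for _ in range(max_iter):
        ew = math.exp(w)
        f = w * ew - z
        # Halley's method step for f(w) = w * exp(w) - z
        step = f / (ew * (w + 1.0) - (w + 2.0) * f / (2.0 * w + 2.0))
        w -= step
        if abs(step) <= tol * (1.0 + abs(w)):
            break
    return w


def superloss(loss: float, tau: float, lam: float = 1.0) -> tuple[float, float]:
    """Return (SuperLoss value, confidence sigma) for one per-sample task loss.

    tau: hypothetical fixed threshold separating easy (loss < tau) from hard samples.
    lam: regularization weight; larger values keep sigma closer to 1.
    """
    beta = (loss - tau) / lam
    # Closed-form optimal confidence; the clamp keeps the Lambert W argument >= -1/e
    sigma = math.exp(-lambert_w0(0.5 * max(-2.0 / math.e, beta)))
    value = (loss - tau) * sigma + lam * math.log(sigma) ** 2
    return value, sigma


if __name__ == "__main__":
    for task_loss in (0.5, 2.0, 5.0):
        value, sigma = superloss(task_loss, tau=2.0)
        print(f"loss={task_loss:.1f}  superloss={value:.3f}  confidence={sigma:.3f}")
```

Because σ is chosen to minimize the expression, the SuperLoss value is never larger than its value at σ = 1, which is ℓ − τ. Samples whose loss is below τ therefore get σ > 1 and a negative SuperLoss value, which is consistent with the negative `ce` values logged for the SuperLoss run. Samples above τ, such as those with corrupted labels, get σ < 1, so their gradients are scaled down.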