We'll start by getting the imports out of the way:
import tempfile
import os
import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.dataset.data import cifair10
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, PadIfNeeded, RandomCrop
from fastestimator.op.numpyop.univariate import CoarseDropout, Normalize, Calibrate
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.summary.logs import visualize_logs
from fastestimator.trace.adapt import PBMCalibrator
from fastestimator.trace.io import BestModelSaver
from fastestimator.trace.metric import CalibrationError, MCC
from fastestimator.util import to_list
label_mapping = {
    'airplane': 0,
    'automobile': 1,
    'bird': 2,
    'cat': 3,
    'deer': 4,
    'dog': 5,
    'frog': 6,
    'horse': 7,
    'ship': 8,
    'truck': 9
}
And let's define a function to build a generic ciFAIR10 estimator. We will show how to use combinations of extra traces and post-processing ops to enhance this estimator throughout the tutorial.
def build_estimator(extra_traces=None, postprocessing_ops=None):
    batch_size = 128
    save_dir = tempfile.mkdtemp()
    extra_traces = to_list(extra_traces)
    postprocessing_ops = to_list(postprocessing_ops)
    train_data, eval_data = cifair10.load_data()
    test_data = eval_data.split(range(len(eval_data) // 2))
    pipeline = fe.Pipeline(
        train_data=train_data,
        eval_data=eval_data,
        test_data=test_data,
        batch_size=batch_size,
        ops=[
            Normalize(inputs="x", outputs="x", mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
            PadIfNeeded(min_height=40, min_width=40, image_in="x", image_out="x", mode="train"),
            RandomCrop(32, 32, image_in="x", image_out="x", mode="train"),
            Sometimes(HorizontalFlip(image_in="x", image_out="x", mode="train")),
            CoarseDropout(inputs="x", outputs="x", mode="train", max_holes=1),
        ],
        num_process=0)
    model = fe.build(model_fn=lambda: LeNet(input_shape=(32, 32, 3)), optimizer_fn="adam")
    network = fe.Network(
        ops=[
            ModelOp(model=model, inputs="x", outputs="y_pred"),
            CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
            UpdateOp(model=model, loss_name="ce")
        ],
        pops=postprocessing_ops)  # <---- Some of the secret sauce will go here
    traces = [
        MCC(true_key="y", pred_key="y_pred"),
        BestModelSaver(model=model, save_dir=save_dir, metric="mcc", save_best_mode="max", load_best_final=True),
    ]
    traces = traces + extra_traces  # <---- Most of the secret sauce will go here
    estimator = fe.Estimator(pipeline=pipeline,
                             network=network,
                             epochs=21,
                             traces=traces,
                             log_steps=300)
    return estimator
Calculating Calibration Error
Suppose you have a neural network that is performing image classification. For the sake of argument, let's imagine that the classification problem is to look at x-ray images and determine whether or not a patient has cancer. Let's further suppose that your model is very accurate: when it assigns a higher probability to 'cancer' the patient is almost always sick, and when it assigns a higher probability to 'healthy' the patient is almost always fine. It could be tempting to think that the job is done, but there is still a potential problem for real-world deployments of your model.
Suppose a physician using your model runs an image and gets a report saying that it is 51% likely that the patient is healthy, and 49% likely that there is a cancerous tumor. In reality the patient is indeed healthy. From an accuracy point of view, your model is doing just fine. However, if the doctor sees that it is 49% likely that there is a tumor, they are likely to order a biopsy in order to be on the safe side. Taken to an extreme, suppose that your model always predicts a 49% probability of a tumor whenever it sees a healthy patient. Even though the model might have perfect accuracy, in practice it would always result in extra surgical procedures being performed.
Ideally, if the model says that there is a 49% probability of a tumor, you would expect there to actually be a tumor in 49% of those cases. The discrepancy between a model's predicted probability of a class and the true probability of that class conditioned on the prediction is measured as the calibration error. Calibration error is notoriously difficult to estimate correctly, but FE provides a Trace for this based on a 2019 NeurIPS spotlight paper titled "Verified Uncertainty Calibration".
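To make the definition above concrete, here is a minimal sketch of a naive, binned estimate of calibration error using only NumPy. It is not the estimator used by the FE trace (which follows the paper mentioned above); the function name `expected_calibration_error` and the equal-width binning are assumptions made purely for illustration.

```python
import numpy as np


def expected_calibration_error(probs, labels, n_bins=10):
    """Naive top-label calibration error with equal-width confidence bins.

    probs:  (n_samples, n_classes) predicted probabilities.
    labels: (n_samples,) integer class labels.
    """
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels)
    confidence = probs.max(axis=1)            # probability of the predicted class
    correct = probs.argmax(axis=1) == labels  # whether that prediction was right
    # Assign each sample to a confidence bin: [0, 1/n), [1/n, 2/n), ...
    bin_ids = np.minimum((confidence * n_bins).astype(int), n_bins - 1)
    ece = 0.0
    for b in range(n_bins):
        in_bin = bin_ids == b
        if not in_bin.any():
            continue
        gap = abs(confidence[in_bin].mean() - correct[in_bin].mean())
        ece += in_bin.mean() * gap  # weight the gap by the fraction of samples in the bin
    return ece


# The "always 51% healthy" model from the text: every prediction is correct,
# yet the stated confidence (0.51) is far from the observed accuracy (1.0).
probs = np.tile([0.51, 0.49], (100, 1))
labels = np.zeros(100, dtype=int)
print(expected_calibration_error(probs, labels))  # roughly 0.49
```

The perfectly accurate model in this toy case still has a calibration error of about 0.49, which is the gap between its stated confidence and how often it is actually right.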
The CalibrationError trace can be used just like any other metric trace, though it can also optionally compute confidence intervals around the estimated error. Keep in mind that to measure calibration error you would want your validation dataset to have a reasonable real-world class distribution (only a small percentage of people in the population actually have cancer, for example). For the purpose of easy illustration we will be using the ciFAIR10 dataset, and computing a 95% confidence interval for the estimated calibration error of the model:
estimator = build_estimator(extra_traces=CalibrationError(true_key="y", pred_key="y_pred", confidence_interval=95))
summary = estimator.fit("experiment1")
FastEstimator-Start: step: 1; logging_interval: 300; num_device: 0;
FastEstimator-Train: step: 1; ce: 2.364285;
FastEstimator-Train: step: 300; ce: 1.5092063; steps/sec: 13.04;
FastEstimator-Train: step: 391; epoch: 1; epoch_time: 31.73 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 391; epoch: 1; calibration_error: (0.0349, 0.0394, 0.045); ce: 1.3210957; max_mcc: 0.4815202560109812; mcc: 0.4815202560109812; since_best_mcc: 0;
FastEstimator-Train: step: 600; ce: 1.2845359; steps/sec: 13.74;
FastEstimator-Train: step: 782; epoch: 2; epoch_time: 26.76 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 782; epoch: 2; calibration_error: (0.0277, 0.0323, 0.0372); ce: 1.1139303; max_mcc: 0.5592140800930085; mcc: 0.5592140800930085; since_best_mcc: 0;
FastEstimator-Train: step: 900; ce: 1.2871141; steps/sec: 14.45;
FastEstimator-Train: step: 1173; epoch: 3; epoch_time: 27.92 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 1173; epoch: 3; calibration_error: (0.0318, 0.0358, 0.0402); ce: 1.0209823; max_mcc: 0.5947684007258542; mcc: 0.5947684007258542; since_best_mcc: 0;
FastEstimator-Train: step: 1200; ce: 1.1260216; steps/sec: 13.86;
FastEstimator-Train: step: 1500; ce: 1.174921; steps/sec: 14.19;
FastEstimator-Train: step: 1564; epoch: 4; epoch_time: 28.79 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 1564; epoch: 4; calibration_error: (0.0256, 0.0288, 0.0334); ce: 0.97435844; max_mcc: 0.6225507938118597; mcc: 0.6225507938118597; since_best_mcc: 0;
FastEstimator-Train: step: 1800; ce: 0.9238555; steps/sec: 11.79;
FastEstimator-Train: step: 1955; epoch: 5; epoch_time: 31.89 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 1955; epoch: 5; calibration_error: (0.0208, 0.026, 0.0295); ce: 0.8851633; max_mcc: 0.6524498253308791; mcc: 0.6524498253308791; since_best_mcc: 0;
FastEstimator-Train: step: 2100; ce: 1.1505089; steps/sec: 13.71;
FastEstimator-Train: step: 2346; epoch: 6; epoch_time: 27.85 sec;
FastEstimator-Eval: step: 2346; epoch: 6; calibration_error: (0.0399, 0.0454, 0.0513); ce: 0.8992597; max_mcc: 0.6524498253308791; mcc: 0.6422697701389847; since_best_mcc: 1;
FastEstimator-Train: step: 2400; ce: 0.88649476; steps/sec: 13.89;
FastEstimator-Train: step: 2700; ce: 0.83883905; steps/sec: 14.12;
FastEstimator-Train: step: 2737; epoch: 7; epoch_time: 27.7 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 2737; epoch: 7; calibration_error: (0.0311, 0.0374, 0.0427); ce: 0.8641518; max_mcc: 0.6573990501670419; mcc: 0.6573990501670419; since_best_mcc: 0;
FastEstimator-Train: step: 3000; ce: 0.91415524; steps/sec: 14.43;
FastEstimator-Train: step: 3128; epoch: 8; epoch_time: 27.35 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 3128; epoch: 8; calibration_error: (0.0419, 0.0467, 0.0505); ce: 0.86572677; max_mcc: 0.6622589033779989; mcc: 0.6622589033779989; since_best_mcc: 0;
FastEstimator-Train: step: 3300; ce: 0.84909713; steps/sec: 14.06;
FastEstimator-Train: step: 3519; epoch: 9; epoch_time: 29.29 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 3519; epoch: 9; calibration_error: (0.0227, 0.0284, 0.0349); ce: 0.7914774; max_mcc: 0.6860117005983222; mcc: 0.6860117005983222; since_best_mcc: 0;
FastEstimator-Train: step: 3600; ce: 0.90386593; steps/sec: 13.1;
FastEstimator-Train: step: 3900; ce: 1.0241306; steps/sec: 13.93;
FastEstimator-Train: step: 3910; epoch: 10; epoch_time: 28.02 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 3910; epoch: 10; calibration_error: (0.0276, 0.0317, 0.0365); ce: 0.78860706; max_mcc: 0.6993861770313364; mcc: 0.6993861770313364; since_best_mcc: 0;
FastEstimator-Train: step: 4200; ce: 0.8331722; steps/sec: 14.46;
FastEstimator-Train: step: 4301; epoch: 11; epoch_time: 27.1 sec;
FastEstimator-Eval: step: 4301; epoch: 11; calibration_error: (0.0298, 0.0346, 0.0399); ce: 0.7831341; max_mcc: 0.6993861770313364; mcc: 0.6894049337786957; since_best_mcc: 1;
FastEstimator-Train: step: 4500; ce: 0.8837549; steps/sec: 14.11;
FastEstimator-Train: step: 4692; epoch: 12; epoch_time: 28.27 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 4692; epoch: 12; calibration_error: (0.0166, 0.0218, 0.0265); ce: 0.7365888; max_mcc: 0.7091949717856824; mcc: 0.7091949717856824; since_best_mcc: 0;
FastEstimator-Train: step: 4800; ce: 0.9474237; steps/sec: 13.8;
FastEstimator-Train: step: 5083; epoch: 13; epoch_time: 28.08 sec;
FastEstimator-Eval: step: 5083; epoch: 13; calibration_error: (0.0397, 0.045, 0.0491); ce: 0.7789904; max_mcc: 0.7091949717856824; mcc: 0.7014650315495937; since_best_mcc: 1;
FastEstimator-Train: step: 5100; ce: 1.0256269; steps/sec: 13.7;
FastEstimator-Train: step: 5400; ce: 0.9247025; steps/sec: 11.6;
FastEstimator-Train: step: 5474; epoch: 14; epoch_time: 33.01 sec;
FastEstimator-Eval: step: 5474; epoch: 14; calibration_error: (0.0362, 0.0399, 0.0442); ce: 0.755077; max_mcc: 0.7091949717856824; mcc: 0.7070422768702824; since_best_mcc: 2;
FastEstimator-Train: step: 5700; ce: 0.70286965; steps/sec: 13.75;
FastEstimator-Train: step: 5865; epoch: 15; epoch_time: 29.03 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 5865; epoch: 15; calibration_error: (0.0232, 0.0276, 0.0324); ce: 0.7251368; max_mcc: 0.7163199141358709; mcc: 0.7163199141358709; since_best_mcc: 0;
FastEstimator-Train: step: 6000; ce: 0.86999273; steps/sec: 13.33;
FastEstimator-Train: step: 6256; epoch: 16; epoch_time: 28.46 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 6256; epoch: 16; calibration_error: (0.0139, 0.0188, 0.0246); ce: 0.6940484; max_mcc: 0.7354250490842272; mcc: 0.7354250490842272; since_best_mcc: 0;
FastEstimator-Train: step: 6300; ce: 0.7672814; steps/sec: 13.6;
FastEstimator-Train: step: 6600; ce: 0.73419267; steps/sec: 13.72;
FastEstimator-Train: step: 6647; epoch: 17; epoch_time: 28.82 sec;
FastEstimator-Eval: step: 6647; epoch: 17; calibration_error: (0.0114, 0.0151, 0.0198); ce: 0.70021456; max_mcc: 0.7354250490842272; mcc: 0.7238136191672119; since_best_mcc: 1;
FastEstimator-Train: step: 6900; ce: 0.72660446; steps/sec: 12.8;
FastEstimator-Train: step: 7038; epoch: 18; epoch_time: 30.39 sec;
FastEstimator-Eval: step: 7038; epoch: 18; calibration_error: (0.0347, 0.0385, 0.0451); ce: 0.71250015; max_mcc: 0.7354250490842272; mcc: 0.7180530704737343; since_best_mcc: 2;
FastEstimator-Train: step: 7200; ce: 0.8266677; steps/sec: 13.58;
FastEstimator-Train: step: 7429; epoch: 19; epoch_time: 28.03 sec;
FastEstimator-Eval: step: 7429; epoch: 19; calibration_error: (0.0181, 0.0244, 0.0309); ce: 0.68665934; max_mcc: 0.7354250490842272; mcc: 0.7327478409220392; since_best_mcc: 3;
FastEstimator-Train: step: 7500; ce: 0.71262574; steps/sec: 13.92;
FastEstimator-Train: step: 7800; ce: 0.7178256; steps/sec: 14.22;
FastEstimator-Train: step: 7820; epoch: 20; epoch_time: 27.46 sec;
FastEstimator-Eval: step: 7820; epoch: 20; calibration_error: (0.0226, 0.0274, 0.0326); ce: 0.68733853; max_mcc: 0.7354250490842272; mcc: 0.7334799441801962; since_best_mcc: 4;
FastEstimator-Train: step: 8100; ce: 0.5832919; steps/sec: 14.46;
FastEstimator-Train: step: 8211; epoch: 21; epoch_time: 27.44 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Eval: step: 8211; epoch: 21; calibration_error: (0.0285, 0.0319, 0.0374); ce: 0.67022; max_mcc: 0.7361255970251015; mcc: 0.7361255970251015; since_best_mcc: 0;
FastEstimator-BestModelSaver: Restoring model from /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpeksb1mr2/model_best_mcc.h5
FastEstimator-Finish: step: 8211; model_lr: 0.001; total_time: 793.06 sec;
estimator.test()
FastEstimator-Test: step: 8211; epoch: 21; calibration_error: (0.0208, 0.0258, 0.0301); ce: 0.7097068; mcc: 0.7399028376036656;
<fastestimator.summary.summary.Summary at 0x108064d30>
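The three values logged for calibration_error read as a lower bound, the point estimate, and an upper bound of the requested interval (the code later in this tutorial accesses the middle value through the `.y` field). One common way to obtain such an interval is a percentile bootstrap. Whether the FE trace uses exactly this procedure is not established here, so treat the following as an illustrative sketch on synthetic data; `top_label_ece` and `bootstrap_interval` are hypothetical helpers.

```python
import numpy as np


def top_label_ece(confidence, correct, n_bins=10):
    """Naive binned calibration error from per-sample confidence and correctness."""
    bin_ids = np.minimum((confidence * n_bins).astype(int), n_bins - 1)
    ece = 0.0
    for b in np.unique(bin_ids):
        in_bin = bin_ids == b
        ece += in_bin.mean() * abs(confidence[in_bin].mean() - correct[in_bin].mean())
    return ece


def bootstrap_interval(confidence, correct, level=95, n_resamples=200, seed=0):
    """Percentile bootstrap: returns (lower, estimate, upper)."""
    rng = np.random.default_rng(seed)
    n = len(confidence)
    stats = []
    for _ in range(n_resamples):
        idx = rng.integers(0, n, size=n)  # resample the evaluation set with replacement
        stats.append(top_label_ece(confidence[idx], correct[idx]))
    tail = (100 - level) / 2
    lower, upper = np.percentile(stats, [tail, 100 - tail])
    return lower, top_label_ece(confidence, correct), upper


rng = np.random.default_rng(42)
confidence = rng.uniform(0.5, 1.0, size=5000)
# Synthetic over-confident model: true accuracy is 5 points below the stated confidence.
correct = rng.uniform(size=5000) < (confidence - 0.05)
low, est, high = bootstrap_interval(confidence, correct)
print(f"calibration error: ({low:.4f}, {est:.4f}, {high:.4f})")
```

A wider interval signals that the evaluation set is too small to pin down the calibration error precisely, which is worth keeping in mind when comparing epochs.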
Let's take a look at how the calibration error changed over training:
visualize_logs([summary], include_metrics={'calibration_error', 'mcc', 'ce'})
As we can see from the graph above, calibration error is considerably noisier than classical metrics like mcc or accuracy. In this case it does seem to have improved somewhat with training, though the correlation isn't strong enough to expect that you can eliminate calibration error just by training longer. Instead, we will see how you can effectively calibrate a model after the fact:
Generating and Applying a Model Calibrator
While there have been many proposed approaches for model calibration, we will again be leveraging the Verified Uncertainty Calibration paper mentioned above to achieve highly sample-efficient model re-calibration. There are two steps involved here. The first step is that we will use the PBMCalibrator trace to generate a 'Platt binner marginal calibrator'. This calibrator is separate from the neural network, but will take neural network outputs and return calibrated outputs. A consequence of performing this calibration is that the output vector for a prediction will, in general, no longer sum to 1, since each class is calibrated independently.
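To see why independent per-class calibration breaks the sum-to-1 property, here is a minimal sketch that substitutes plain equal-width histogram binning for the Platt binner. The class `MarginalHistogramCalibrator` and the synthetic data are assumptions for illustration only, and this is not the algorithm that PBMCalibrator implements.

```python
import numpy as np


class MarginalHistogramCalibrator:
    """Toy per-class histogram binning (a stand-in, not the Platt binner)."""

    def __init__(self, n_bins=10):
        self.n_bins = n_bins
        self.bin_values = None  # shape (n_classes, n_bins) after fit()

    def _bin_ids(self, p):
        return np.minimum((p * self.n_bins).astype(int), self.n_bins - 1)

    def fit(self, probs, labels):
        probs = np.asarray(probs, dtype=float)
        labels = np.asarray(labels)
        n_classes = probs.shape[1]
        self.bin_values = np.zeros((n_classes, self.n_bins))
        for k in range(n_classes):  # every class gets its own, independent mapping
            ids = self._bin_ids(probs[:, k])
            is_k = labels == k
            for b in range(self.n_bins):
                in_bin = ids == b
                # Empirical frequency of class k in this bin; use the bin centre if the bin is empty.
                self.bin_values[k, b] = is_k[in_bin].mean() if in_bin.any() else (b + 0.5) / self.n_bins
        return self

    def __call__(self, probs):
        probs = np.asarray(probs, dtype=float)
        out = np.empty_like(probs)
        for k in range(probs.shape[1]):
            out[:, k] = self.bin_values[k, self._bin_ids(probs[:, k])]
        return out


rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=2000)
# Synthetic, sharply peaked "softmax" outputs for a 3-class problem.
logits = rng.normal(size=(2000, 3)) + 2.0 * np.eye(3)[labels]
probs = np.exp(3 * logits)
probs /= probs.sum(axis=1, keepdims=True)

calibrated = MarginalHistogramCalibrator().fit(probs, labels)(probs)
print(probs[:3].sum(axis=1))       # original rows sum to 1
print(calibrated[:3].sum(axis=1))  # calibrated rows generally do not
```

Because each column is remapped without looking at the others, nothing forces a row to remain a probability distribution. If downstream code needs that, the rows could be renormalized afterwards, at the cost of disturbing the per-class calibration.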
Of course, simply having such a calibration object is not useful if we don't use it. To make use of our calibrator object we will use the Calibrate NumpyOp, which can load any calibrator object from disk and then apply it during Network post-processing. Since we are using a best model saver, we will only save the calibrator object when since_best_mcc is 0, so that when we re-load the best model we will also be loading the correct calibrator for that model.
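The save-only-on-best rule can be sketched in plain Python as follows. The per-epoch numbers are made up, small dictionaries stand in for calibrator objects, and `pickle` is used as an assumed serialization format; none of this reflects how the FE traces are implemented internally.

```python
import os
import pickle
import tempfile

# Hypothetical per-epoch results: (mcc, calibrator fitted on that epoch's eval predictions).
epoch_results = [
    (0.48, {"epoch": 1}),
    (0.55, {"epoch": 2}),
    (0.54, {"epoch": 3}),
    (0.60, {"epoch": 4}),
    (0.59, {"epoch": 5}),
]

save_path = os.path.join(tempfile.mkdtemp(), "calibrator.pkl")
best_mcc, since_best = None, 0
for mcc, calibrator in epoch_results:
    if best_mcc is None or mcc > best_mcc:
        best_mcc, since_best = mcc, 0
    else:
        since_best += 1
    # Only overwrite the file when this epoch produced the best model so far.
    if since_best == 0:
        with open(save_path, "wb") as f:
            pickle.dump(calibrator, f)

with open(save_path, "rb") as f:
    restored = pickle.load(f)
print(restored)  # {'epoch': 4}: the calibrator that matches the best model
```

Without the check, the file would hold the calibrator from the final epoch, which was fitted to the outputs of a different set of weights than the ones restored at the end of training.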
save_path = os.path.join(tempfile.mkdtemp(), 'calibrator.pkl')
estimator = build_estimator(
    extra_traces=[
        CalibrationError(true_key="y", pred_key="y_pred", confidence_interval=95),
        PBMCalibrator(true_key="y", pred_key="y_pred", save_path=save_path, save_if_key="since_best_mcc", mode="eval"),
        # We will also compare the MCC and calibration error between the original and calibrated samples:
        MCC(true_key="y", pred_key="y_pred_calibrated", output_name="mcc (calibrated)", mode="test"),
        CalibrationError(true_key="y", pred_key="y_pred_calibrated", output_name="calibration_error (calibrated)", confidence_interval=95, mode="test"),
    ],
    postprocessing_ops=Calibrate(inputs="y_pred", outputs="y_pred_calibrated", calibration_fn=save_path, mode="test"))
summary = estimator.fit("experiment2")
FastEstimator-Start: step: 1; logging_interval: 300; num_device: 0;
WARNING:tensorflow:5 out of the last 43 calls to <function TFNetwork._forward_step_static at 0x17a27d670> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
FastEstimator-Train: step: 1; ce: 2.2869081;
FastEstimator-Train: step: 300; ce: 1.4376379; steps/sec: 15.34;
FastEstimator-Train: step: 391; epoch: 1; epoch_time: 26.92 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 391; epoch: 1; calibration_error: (0.0282, 0.033, 0.0391); ce: 1.2963489; max_mcc: 0.47887144783984326; mcc: 0.47887144783984326; since_best_mcc: 0;
FastEstimator-Train: step: 600; ce: 1.2960017; steps/sec: 14.03;
FastEstimator-Train: step: 782; epoch: 2; epoch_time: 27.64 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 782; epoch: 2; calibration_error: (0.033, 0.0385, 0.0417); ce: 1.1607816; max_mcc: 0.5413508452419946; mcc: 0.5413508452419946; since_best_mcc: 0;
FastEstimator-Train: step: 900; ce: 1.3636127; steps/sec: 13.55;
FastEstimator-Train: step: 1173; epoch: 3; epoch_time: 30.64 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 1173; epoch: 3; calibration_error: (0.0224, 0.0271, 0.0339); ce: 1.0179937; max_mcc: 0.6101771269094974; mcc: 0.6101771269094974; since_best_mcc: 0;
FastEstimator-Train: step: 1200; ce: 1.084051; steps/sec: 12.39;
FastEstimator-Train: step: 1500; ce: 1.0849717; steps/sec: 11.78;
FastEstimator-Train: step: 1564; epoch: 4; epoch_time: 33.79 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 1564; epoch: 4; calibration_error: (0.0199, 0.0254, 0.0301); ce: 0.9683536; max_mcc: 0.6222706205728137; mcc: 0.6222706205728137; since_best_mcc: 0;
FastEstimator-Train: step: 1800; ce: 1.002698; steps/sec: 11.22;
FastEstimator-Train: step: 1955; epoch: 5; epoch_time: 35.22 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 1955; epoch: 5; calibration_error: (0.033, 0.037, 0.0426); ce: 0.9345478; max_mcc: 0.6267778416655716; mcc: 0.6267778416655716; since_best_mcc: 0;
FastEstimator-Train: step: 2100; ce: 1.0518067; steps/sec: 10.68;
FastEstimator-Train: step: 2346; epoch: 6; epoch_time: 35.07 sec;
FastEstimator-Eval: step: 2346; epoch: 6; calibration_error: (0.044, 0.049, 0.0534); ce: 0.94733584; max_mcc: 0.6267778416655716; mcc: 0.6233546799342925; since_best_mcc: 1;
FastEstimator-Train: step: 2400; ce: 0.9290854; steps/sec: 11.98;
FastEstimator-Train: step: 2700; ce: 0.8167516; steps/sec: 13.0;
FastEstimator-Train: step: 2737; epoch: 7; epoch_time: 29.83 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 2737; epoch: 7; calibration_error: (0.0268, 0.0312, 0.0355); ce: 0.882226; max_mcc: 0.6533327147622581; mcc: 0.6533327147622581; since_best_mcc: 0;
FastEstimator-Train: step: 3000; ce: 0.8312993; steps/sec: 13.76;
FastEstimator-Train: step: 3128; epoch: 8; epoch_time: 28.22 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 3128; epoch: 8; calibration_error: (0.0315, 0.0353, 0.0407); ce: 0.86187744; max_mcc: 0.6627038830783283; mcc: 0.6627038830783283; since_best_mcc: 0;
FastEstimator-Train: step: 3300; ce: 1.00288; steps/sec: 14.13;
FastEstimator-Train: step: 3519; epoch: 9; epoch_time: 27.96 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 3519; epoch: 9; calibration_error: (0.0194, 0.023, 0.0279); ce: 0.80819094; max_mcc: 0.6796538303957567; mcc: 0.6796538303957567; since_best_mcc: 0;
FastEstimator-Train: step: 3600; ce: 0.86661005; steps/sec: 13.73;
FastEstimator-Train: step: 3900; ce: 0.9506703; steps/sec: 14.18;
FastEstimator-Train: step: 3910; epoch: 10; epoch_time: 27.83 sec;
FastEstimator-Eval: step: 3910; epoch: 10; calibration_error: (0.0382, 0.0426, 0.0465); ce: 0.8241831; max_mcc: 0.6796538303957567; mcc: 0.6746589510196896; since_best_mcc: 1;
FastEstimator-Train: step: 4200; ce: 0.66708326; steps/sec: 14.2;
FastEstimator-Train: step: 4301; epoch: 11; epoch_time: 27.53 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 4301; epoch: 11; calibration_error: (0.026, 0.031, 0.0372); ce: 0.7791259; max_mcc: 0.6915740268068222; mcc: 0.6915740268068222; since_best_mcc: 0;
FastEstimator-Train: step: 4500; ce: 0.78781855; steps/sec: 14.01;
FastEstimator-Train: step: 4692; epoch: 12; epoch_time: 31.77 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 4692; epoch: 12; calibration_error: (0.0317, 0.0363, 0.0405); ce: 0.7906082; max_mcc: 0.694238202567274; mcc: 0.694238202567274; since_best_mcc: 0;
FastEstimator-Train: step: 4800; ce: 0.87651783; steps/sec: 10.46;
FastEstimator-Train: step: 5083; epoch: 13; epoch_time: 37.82 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 5083; epoch: 13; calibration_error: (0.021, 0.0252, 0.03); ce: 0.7387682; max_mcc: 0.7108036127220436; mcc: 0.7108036127220436; since_best_mcc: 0;
FastEstimator-Train: step: 5100; ce: 0.8433088; steps/sec: 10.66;
FastEstimator-Train: step: 5400; ce: 0.5917939; steps/sec: 13.79;
FastEstimator-Train: step: 5474; epoch: 14; epoch_time: 28.78 sec;
FastEstimator-Eval: step: 5474; epoch: 14; calibration_error: (0.0295, 0.0331, 0.0366); ce: 0.77765036; max_mcc: 0.7108036127220436; mcc: 0.6947198040176703; since_best_mcc: 1;
FastEstimator-Train: step: 5700; ce: 0.8003416; steps/sec: 14.15;
FastEstimator-Train: step: 5865; epoch: 15; epoch_time: 27.74 sec;
FastEstimator-Eval: step: 5865; epoch: 15; calibration_error: (0.0376, 0.0426, 0.0477); ce: 0.7729173; max_mcc: 0.7108036127220436; mcc: 0.6948574421893091; since_best_mcc: 2;
FastEstimator-Train: step: 6000; ce: 0.79758346; steps/sec: 14.07;
FastEstimator-Train: step: 6256; epoch: 16; epoch_time: 27.36 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 6256; epoch: 16; calibration_error: (0.0202, 0.024, 0.03); ce: 0.7178384; max_mcc: 0.7197751708113554; mcc: 0.7197751708113554; since_best_mcc: 0;
FastEstimator-Train: step: 6300; ce: 0.8341557; steps/sec: 14.05;
FastEstimator-Train: step: 6600; ce: 0.73111624; steps/sec: 14.46;
FastEstimator-Train: step: 6647; epoch: 17; epoch_time: 27.31 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 6647; epoch: 17; calibration_error: (0.0241, 0.0306, 0.0353); ce: 0.7144198; max_mcc: 0.7279127104193787; mcc: 0.7279127104193787; since_best_mcc: 0;
FastEstimator-Train: step: 6900; ce: 0.8111193; steps/sec: 14.57;
FastEstimator-Train: step: 7038; epoch: 18; epoch_time: 27.24 sec;
FastEstimator-Eval: step: 7038; epoch: 18; calibration_error: (0.0292, 0.0348, 0.0407); ce: 0.71747106; max_mcc: 0.7279127104193787; mcc: 0.7237059364078412; since_best_mcc: 1;
FastEstimator-Train: step: 7200; ce: 0.71593326; steps/sec: 14.33;
FastEstimator-Train: step: 7429; epoch: 19; epoch_time: 27.45 sec;
FastEstimator-Eval: step: 7429; epoch: 19; calibration_error: (0.0228, 0.0282, 0.0336); ce: 0.69454527; max_mcc: 0.7279127104193787; mcc: 0.7254248272218604; since_best_mcc: 2;
FastEstimator-Train: step: 7500; ce: 0.92269725; steps/sec: 13.95;
FastEstimator-Train: step: 7800; ce: 0.7913792; steps/sec: 11.81;
FastEstimator-Train: step: 7820; epoch: 20; epoch_time: 32.34 sec;
FastEstimator-BestModelSaver: Saved model to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-PBMCalibrator: Calibrator written to /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Eval: step: 7820; epoch: 20; calibration_error: (0.019, 0.0248, 0.0301); ce: 0.66156435; max_mcc: 0.7342423316872604; mcc: 0.7342423316872604; since_best_mcc: 0;
FastEstimator-Train: step: 8100; ce: 0.72583616; steps/sec: 12.59;
FastEstimator-Train: step: 8211; epoch: 21; epoch_time: 30.54 sec;
FastEstimator-Eval: step: 8211; epoch: 21; calibration_error: (0.0253, 0.033, 0.0379); ce: 0.69342196; max_mcc: 0.7342423316872604; mcc: 0.7329696458591745; since_best_mcc: 1;
FastEstimator-BestModelSaver: Restoring model from /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmp8silmrs5/model1_best_mcc.h5
FastEstimator-Finish: step: 8211; model1_lr: 0.001; total_time: 841.56 sec;
estimator.test()
FastEstimator-Calibrate: calibration function loaded from /var/folders/3r/h9kh47050gv6rbt_pgf8cl540000gn/T/tmpimlnddtj/calibrator.pkl
FastEstimator-Test: step: 8211; epoch: 21; calibration_error: (0.023, 0.0275, 0.0335); calibration_error (calibrated): (0.0015, 0.0054, 0.0106); ce: 0.69831085; mcc: 0.7473948072665257; mcc (calibrated): 0.7516498941076983;
<fastestimator.summary.summary.Summary at 0x17a297370>
visualize_logs([summary], include_metrics={'calibration_error', 'mcc', 'ce', "calibration_error (calibrated)", "mcc (calibrated)"})
delta = summary.history['test']['mcc (calibrated)'][8211] - summary.history['test']['mcc'][8211]
relative_delta = delta / summary.history['test']['mcc'][8211]
print(f"mcc change after calibration: {delta} ({relative_delta*100}%)")
mcc change after calibration: 0.004255086841172595 (0.5693225052947423%)
delta = summary.history['test']['calibration_error (calibrated)'][8211].y - summary.history['test']['calibration_error'][8211].y
relative_delta = delta / summary.history['test']['calibration_error'][8211].y
print(f"calibration error change after calibration: {delta} ({relative_delta*100}%)")
calibration error change after calibration: -0.0221 (-80.36363636363637%)
As we can see from the graphs and values above, with the use of a Platt binner marginal calibrator we can dramatically reduce a model's calibration error (in this case by over 80%) with very little effect on model performance (in this case MCC changed by less than 1%, and actually increased slightly).