calibration_error
CalibrationError
Bases: Trace
A trace which computes the calibration error for a given set of predictions.
Unlike many common calibration error estimation algorithms, this one has actual theoretical bounds on the quality of its output: https://arxiv.org/pdf/1909.10155v1.pdf.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| true_key | str | Name of the key that corresponds to ground truth in the batch dictionary. | required |
| pred_key | str | Name of the key that corresponds to predicted score in the batch dictionary. | required |
| mode | Union[None, str, Iterable[str]] | What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train". | ('eval', 'test') |
| ds_id | Union[None, str, Iterable[str]] | What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1". | None |
| output_name | str | What to call the output from this trace (for example in the logger output). | 'calibration_error' |
| method | str | Either 'marginal' or 'top-label'. 'marginal' calibration averages the calibration error over each class, whereas 'top-label' computes the error based only on the most confident predictions. | 'marginal' |
| confidence_interval | Optional[int] | The calibration error confidence interval to be reported (estimated empirically). Should be in the range (0, 100), or else None to omit this extra calculation. | None |
| per_ds | bool | Whether to automatically compute this metric individually for every ds_id it runs on, in addition to computing an aggregate across all ds_ids on which it runs. This is automatically False if the output_name contains a "\|" character. | True |
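A minimal usage sketch is shown below. The key names "y" and "y_pred" are assumptions about how the batch dictionary is laid out in a particular pipeline/network setup, and the Estimator shown in the comment is only indicative:

```python
from fastestimator.trace.metric import CalibrationError

# Construct the trace. "y" and "y_pred" are assumed batch dictionary keys;
# substitute whatever keys your pipeline and network actually produce.
calibration_trace = CalibrationError(true_key="y",
                                     pred_key="y_pred",
                                     method="top-label",
                                     confidence_interval=95)

# The trace is then passed to an Estimator via its `traces` argument, e.g.:
# estimator = fe.Estimator(pipeline=pipeline, network=network, epochs=2,
#                          traces=[calibration_trace])
```

During eval and test, the result appears in the logs under the chosen output_name ('calibration_error' by default), with the empirical confidence interval reported alongside it when confidence_interval is set.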