Advanced Tutorial 17: Slicer
Overview
In this tutorial, we will talk about the following topics:
- Slicer Overview
- Example Usecase 1: 3D to 2D
- Example Usecase 2: Sliding Windows
- Building Your Own Slicer
This tutorial will demonstrate modifications to the UNet3D Apphub, so you may want to look at that first to get more context for the problem setting.
Slicer Overview
Suppose you have a single batch of data, but for whatever reason you can't / don't want to run the entire thing through your model at once. Slicers allow you to cut a batch of data apart and run it through your Network in chunks. Slicers then re-combine the output before passing it along to Network post-processing and Trace functions.
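The slice, run, re-combine cycle described above can be pictured with a framework-free sketch. This is only an illustration in plain NumPy and not how FastEstimator implements Slicers; `toy_model` and `run_in_chunks` are hypothetical names introduced for the example.

```python
import numpy as np


def toy_model(x):
    # Hypothetical stand-in for a network forward pass
    return x * 2.0


def run_in_chunks(batch, chunk_size):
    # Slice: cut the batch apart along axis 0
    chunks = [batch[i:i + chunk_size] for i in range(0, len(batch), chunk_size)]
    # Run: each chunk goes through the model separately
    outputs = [toy_model(chunk) for chunk in chunks]
    # Un-slice: re-combine the per-chunk outputs into a single batch
    return np.concatenate(outputs, axis=0)


batch = np.arange(12, dtype=np.float32).reshape(6, 2)
combined = run_in_chunks(batch, chunk_size=4)
print(combined.shape)  # (6, 2)
```

The re-combined result matches what a single full-batch pass through `toy_model` would produce; only the amount of data processed at one time changes.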
In this tutorial we will take a look at this through the lens of an electron microscopy 3D cell dataset. This dataset contains only 2 images: one of size 800x800x50 and the other 800x800x24. These would be rather large images to feed through a network in a single step.
In our UNet3D Apphub we took advantage of the FE dataset implementation to automatically convert the data volumes into around 75 training images and 25 validation images, each of size 256x256x24. That was fine for our 3D model, but what if you want to pass the data directly into a 2D network? Or what if you think side lengths of 256 are too boring and you want to try something else instead? Well, in that case you have come to the right place. Let's see how to make it happen...
Example Usecase 1: 3D data into a 2D Network
One reason that you might want to slice data apart is if you have 3D image volumes, but you want to run inference on them slice-by-slice using a 2D network architecture. This can be accommodated using an AxisSlicer.
The Data
First let's load up some 3D data and have a look:
from fastestimator.dataset.data.em_3d import load_data
train_data, eval_data = load_data()
print(f"Training Samples: {len(train_data)}")
print(f"Eval Samples: {len(eval_data)}")
print(f"Sample Shape: {train_data[0]['image'].shape}")
Training Samples: 75
Eval Samples: 25
Sample Shape: (256, 256, 24)
We'll feed this data through a pretty straightforward pre-processing pipeline:
import fastestimator as fe
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, VerticalFlip
from fastestimator.op.numpyop.univariate import Minmax
from fastestimator.op.numpyop.univariate.expand_dims import ExpandDims
pipeline = fe.Pipeline(
    train_data=train_data,
    eval_data=eval_data,
    batch_size=1,
    ops=[
        Sometimes(numpy_op=HorizontalFlip(image_in="image", mask_in="label", mode='train')),
        Sometimes(numpy_op=VerticalFlip(image_in="image", mask_in="label", mode='train')),
        Minmax(inputs="image", outputs="image"),
        ExpandDims(inputs="image", outputs="image"),  # We'll add a channel dimension to the images
    ])
sample_batch = pipeline.get_results()
print(f"Image Batch Shape: {sample_batch['image'].shape}")
print(f"Label Batch Shape: {sample_batch['label'].shape}")
Image Batch Shape: torch.Size([1, 256, 256, 24, 1])
Label Batch Shape: torch.Size([1, 256, 256, 24, 6])
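To see where the trailing `1` and the leading batch dimension in the image shape come from, here is a rough NumPy sketch of the last two pipeline steps. It assumes that Minmax rescales values with `(x - min) / (max - min)` and that ExpandDims appends a trailing axis; a small random volume stands in for a real (256, 256, 24) sample.

```python
import numpy as np

rng = np.random.default_rng(0)
# Small stand-in for one (256, 256, 24) image volume
volume = rng.integers(0, 256, size=(8, 8, 4)).astype(np.float32)

# Min-max rescaling (assumed formula): smallest value maps to 0, largest to 1
scaled = (volume - volume.min()) / (volume.max() - volume.min())

# Add a trailing channel axis, then a leading batch axis (batch_size=1)
with_channel = np.expand_dims(scaled, axis=-1)
batch = with_channel[np.newaxis, ...]

print(batch.shape)  # (1, 8, 8, 4, 1)
```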
The Network
In our apphub we spent quite a lot of time defining a 3D model architecture that could handle this data. But what if you just want to use a basic 2D UNet Model? In that case you'll want to use an AxisSlicer. The AxisSlicer slices our input data along a given axis and then runs each slice through the network separately. The slices are then put back together again before being passed on to subsequent Network post-processing Ops or Traces. You'll also need to specify how to un-slice (re-combine) the data which is generated during the repeated forward passes over the Network. In this case we will use a MeanUnslicer in order to merge our loss terms together for log printing, and re-use our AxisSlicer to stack the network predictions back into a 3D volume:
import tensorflow as tf
from fastestimator.architecture.tensorflow import UNet
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.slicer import AxisSlicer, MeanUnslicer
model = fe.build(model_fn=lambda: UNet(input_size=(256, 256, 1), output_channel=6),
                 optimizer_fn=lambda: tf.optimizers.legacy.Adam(learning_rate=0.0001))
network = fe.Network(
    ops=[
        ModelOp(inputs="image", model=model, outputs="pred"),
        CrossEntropy(inputs=("pred", "label"), outputs="loss", form="binary"),
        UpdateOp(model=model, loss_name="loss")
    ],
    slicers=[
        AxisSlicer(slice=["image", "label"], unslice=["pred"], axis=3),
        MeanUnslicer(unslice="loss")
    ])
sample_prediction = network.transform(data=sample_batch, mode='eval')
print(f"Image Batch Shape: {sample_prediction['image'].shape}")
print(f"Label Batch Shape: {sample_prediction['label'].shape}")
print(f"Prediction Batch Shape: {sample_prediction['pred'].shape}")
print(f"Mean Loss Value: {sample_prediction['loss']}")
Image Batch Shape: (1, 256, 256, 24, 1)
Label Batch Shape: (1, 256, 256, 24, 6)
Prediction Batch Shape: (1, 256, 256, 24, 6)
Mean Loss Value: 1.1275629997253418
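The shapes printed above can be reproduced with a small NumPy sketch of what slicing along axis 3 and stacking back amounts to. This is an illustration under assumptions, not FastEstimator's implementation: it assumes the sliced axis is dropped from each slice, it uses a small stand-in volume, and `fake_model_2d` is a hypothetical placeholder for the 2D UNet.

```python
import numpy as np


def fake_model_2d(x):
    # Hypothetical 2D "network": maps 1 input channel to 6 output channels
    return np.repeat(x, 6, axis=-1)


image = np.random.rand(1, 16, 16, 4, 1)  # small stand-in for (1, 256, 256, 24, 1)
axis = 3

# Slice: one 2D image per position along the depth axis, with that axis dropped
slices = [np.take(image, i, axis=axis) for i in range(image.shape[axis])]
print(len(slices), slices[0].shape)  # 4 (1, 16, 16, 1)

# Run each slice separately and record a made-up per-slice loss
preds = [fake_model_2d(s) for s in slices]
losses = [float(p.mean()) for p in preds]

# Un-slice: stack predictions back along the depth axis, average the losses
pred_volume = np.stack(preds, axis=axis)
mean_loss = sum(losses) / len(losses)
print(pred_volume.shape)  # (1, 16, 16, 4, 6)
```

The stacking step plays the role of the AxisSlicer un-slice on "pred", and the averaging step plays the role of the MeanUnslicer on "loss".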
Finally, let's add in a metric trace and see the training:
from fastestimator.trace.metric import Dice
channel_mapping = {0: 'Cell', 1: 'Mitochondria', 2: 'AlphaGranule',
                   3: 'CanalicularVessel', 4: 'GranuleBody', 5: 'GranuleCore'}
traces = [Dice(true_key="label", pred_key="pred", channel_mapping=channel_mapping)]
estimator = fe.Estimator(
    pipeline=pipeline,
    network=network,
    traces=traces,
    log_steps=75,
    eval_log_steps=[-1],
    epochs=10)
# estimator.fit() # Uncomment this if you want to re-run the training yourself
    ______           __  ______     __  _                 __
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 75; num_device: 1;
FastEstimator-Train: step: 1; loss: 0.35076174;
FastEstimator-Train: step: 75; loss: 0.06861534; steps/sec: 0.45;
FastEstimator-Train: step: 75; epoch: 1; epoch_time(sec): 173.2;
FastEstimator-Eval: step: 75; epoch: 1; Dice: 0.1650141; Dice_AlphaGranule: 0.0; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.56379896; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.41382408; Dice_Mitochondria: 0.012461557; loss: 0.22645305;
FastEstimator-Train: step: 150; loss: 0.036027152; steps/sec: 0.41;
FastEstimator-Train: step: 150; epoch: 2; epoch_time(sec): 182.15;
FastEstimator-Eval: step: 150; epoch: 2; Dice: 0.28468058; Dice_AlphaGranule: 0.22724058; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.84920985; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.59404886; Dice_Mitochondria: 0.037584234; loss: 0.11351043;
FastEstimator-Train: step: 225; loss: 0.019175015; steps/sec: 0.43;
FastEstimator-Train: step: 225; epoch: 3; epoch_time(sec): 173.2;
FastEstimator-Eval: step: 225; epoch: 3; Dice: 0.27223733; Dice_AlphaGranule: 0.14862165; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.87892294; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.57796216; Dice_Mitochondria: 0.027917214; loss: 0.10777818;
FastEstimator-Train: step: 300; loss: 0.038775463; steps/sec: 0.43;
FastEstimator-Train: step: 300; epoch: 4; epoch_time(sec): 175.43;
FastEstimator-Eval: step: 300; epoch: 4; Dice: 0.26352742; Dice_AlphaGranule: 0.06612307; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.85590476; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.6462211; Dice_Mitochondria: 0.012915753; loss: 0.15005614;
FastEstimator-Train: step: 375; loss: 0.041875094; steps/sec: 0.42;
FastEstimator-Train: step: 375; epoch: 5; epoch_time(sec): 179.78;
FastEstimator-Eval: step: 375; epoch: 5; Dice: 0.25157312; Dice_AlphaGranule: 0.09126292; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.8459115; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.5518494; Dice_Mitochondria: 0.020414837; loss: 0.14314592;
FastEstimator-Train: step: 450; loss: 0.036384713; steps/sec: 0.42;
FastEstimator-Train: step: 450; epoch: 6; epoch_time(sec): 179.49;
FastEstimator-Eval: step: 450; epoch: 6; Dice: 0.24621534; Dice_AlphaGranule: 0.07902191; Dice_CanalicularVessel: 0.0016777955; Dice_Cell: 0.81178457; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.5809567; Dice_Mitochondria: 0.0038511453; loss: 0.19332533;
FastEstimator-Train: step: 525; loss: 0.032831687; steps/sec: 0.42;
FastEstimator-Train: step: 525; epoch: 7; epoch_time(sec): 177.97;
FastEstimator-Eval: step: 525; epoch: 7; Dice: 0.36281154; Dice_AlphaGranule: 0.3231712; Dice_CanalicularVessel: 0.15168218; Dice_Cell: 0.92331016; Dice_GranuleBody: 0.010409508; Dice_GranuleCore: 0.75507385; Dice_Mitochondria: 0.013222287; loss: 0.099501915;
FastEstimator-Train: step: 600; loss: 0.039932176; steps/sec: 0.42;
FastEstimator-Train: step: 600; epoch: 8; epoch_time(sec): 178.89;
FastEstimator-Eval: step: 600; epoch: 8; Dice: 0.40107933; Dice_AlphaGranule: 0.43582454; Dice_CanalicularVessel: 0.15312457; Dice_Cell: 0.9320569; Dice_GranuleBody: 0.11979302; Dice_GranuleCore: 0.76528555; Dice_Mitochondria: 0.0003915782; loss: 0.07295823;
FastEstimator-Train: step: 675; loss: 0.025439756; steps/sec: 0.42;
FastEstimator-Train: step: 675; epoch: 9; epoch_time(sec): 180.04;
FastEstimator-Eval: step: 675; epoch: 9; Dice: 0.39481708; Dice_AlphaGranule: 0.3910953; Dice_CanalicularVessel: 0.14935504; Dice_Cell: 0.930211; Dice_GranuleBody: 0.118480444; Dice_GranuleCore: 0.7636419; Dice_Mitochondria: 0.016118763; loss: 0.094853036;
FastEstimator-Train: step: 750; loss: 0.016932765; steps/sec: 0.42;
FastEstimator-Train: step: 750; epoch: 10; epoch_time(sec): 180.29;
FastEstimator-Eval: step: 750; epoch: 10; Dice: 0.3541808; Dice_AlphaGranule: 0.17112847; Dice_CanalicularVessel: 0.16391844; Dice_Cell: 0.8922416; Dice_GranuleBody: 0.14522946; Dice_GranuleCore: 0.7323538; Dice_Mitochondria: 0.020212961; loss: 0.15120216;
FastEstimator-Finish: step: 750; model_lr: 1e-04; total_time(sec): 2069.42;
Note that your steps/sec as reported by logging will be much lower when using slicers than you might normally expect, since each 'step' is now actually 24 mini-steps.
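As a back-of-the-envelope check on that note, the logged rate can be converted into slice-level forward passes. The figures below are taken from the log above (roughly 0.42 steps/sec, 75 steps per epoch, 24 slices per volume); the conversion ignores any per-step overhead, so treat it as an approximation.

```python
slices_per_step = 24   # depth of each (256, 256, 24) training volume
steps_per_sec = 0.42   # typical value from the training log
steps_per_epoch = 75   # one step per training volume with batch_size=1

# Each logged step hides one forward pass per slice
forward_passes_per_sec = steps_per_sec * slices_per_step
forward_passes_per_epoch = steps_per_epoch * slices_per_step

print(f"{forward_passes_per_sec:.2f} slice-level passes per second")
print(f"{forward_passes_per_epoch} slice-level passes per epoch")  # 1800
```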
Example Usecase 2: Using a sliding window to work around GPU memory limits
Let's imagine a world where you wanted to work with this cell data, but in which FE had not implemented its convenient dataset tiling feature for you. Instead you are confronted with some gigantic images:
from fastestimator.dataset.data.em_3d import load_data
train_data, eval_data = load_data(tile=False)
print(f"Training Samples: {len(train_data)}")
print(f"Eval Samples: {len(eval_data)}")
print(f"Train Image Shape: {train_data[0]['image'].shape}")
print(f"Train Label Shape: {train_data[0]['label'].shape}")
print(f"Eval Sample Shape: {eval_data[0]['image'].shape}")
print(f"Eval Label Shape: {eval_data[0]['label'].shape}")
Training Samples: 1
Eval Samples: 1
Train Image Shape: (800, 800, 50)
Train Label Shape: (800, 800, 50, 6)
Eval Sample Shape: (800, 800, 24)
Eval Label Shape: (800, 800, 24, 6)
We'll use the same pipeline processing from our first usecase:
import fastestimator as fe
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, VerticalFlip
from fastestimator.op.numpyop.univariate import Minmax
from fastestimator.op.numpyop.univariate.expand_dims import ExpandDims
pipeline = fe.Pipeline(
    train_data=train_data,
    eval_data=eval_data,
    batch_size=1,
    ops=[
        Sometimes(numpy_op=HorizontalFlip(image_in="image", mask_in="label", mode='train')),
        Sometimes(numpy_op=VerticalFlip(image_in="image", mask_in="label", mode='train')),
        Minmax(inputs="image", outputs="image"),
        ExpandDims(inputs="image", outputs="image"),  # We'll add a channel dimension to the images
    ],
    num_process=0)  # Since we only have 1 image to work with we'll turn off multi-processing
sample_batch = pipeline.get_results()
print(f"Image Batch Shape: {sample_batch['image'].shape}")
print(f"Label Batch Shape: {sample_batch['label'].shape}")
Image Batch Shape: torch.Size([1, 800, 800, 50, 1])
Label Batch Shape: torch.Size([1, 800, 800, 50, 6])
The Network
Rather than cutting the data beforehand, we can use a SlidingSlicer to move over it in a sliding-window fashion. Each chunk of data is run through the network separately and then re-combined before being passed on to subsequent Network post-processing Ops or Traces. You'll also need to specify how to un-slice (re-combine) the data which is generated during the repeated forward passes over the Network. In this case we will use a MeanUnslicer in order to merge our loss terms together for log printing, and re-use our SlidingSlicer to paste the network predictions back into a 3D volume.
Let's cut our image into windows of size (256, 256, 1) and then run them through a 2D network. In this case we don't want to cut along the batch or channel dimensions, so we will set their values to -1 in the SlidingSlicer window_size argument. Note that 800 is not evenly divisible by 256. SlidingSlicer gives you several ways to handle this. By default it will simply drop any leftover data along each window axis. You can instead keep partial slices by switching the 'pad_mode' to 'partial', or use padding to fill out the final slices by switching 'pad_mode' to 'constant'. You could also consider customizing the 'strides' so that each window partially overlaps the previous one and the windows tile the full extent exactly. In our example, you could stride by 136 to achieve this (136*4+256=800). When overlapping windows are re-combined later, any values where there is overlap will be averaged together. If you prefer a simple sum you can change this by modifying the 'unslice_mode' option.
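The stride arithmetic above can be checked with a few lines of plain Python. This is a sketch of the reasoning only, not SlidingSlicer's internals; `window_starts` is a hypothetical helper that lists where each full window would begin along one axis.

```python
def window_starts(length, window, stride):
    # Start offsets of every full window that fits inside `length`
    return list(range(0, length - window + 1, stride))


# Non-overlapping windows: stride equal to the window size
plain = window_starts(800, 256, 256)
print(plain)                    # [0, 256, 512]
print(800 - (plain[-1] + 256))  # 32 leftover positions that would be dropped

# Overlapping windows with a stride of 136
overlapped = window_starts(800, 256, 136)
print(overlapped)               # [0, 136, 272, 408, 544]
print(overlapped[-1] + 256)     # 800, so the last window ends exactly at the edge
```

With a stride of 136 there are five windows per 800-long axis and nothing is left over, at the cost of processing the overlapping regions more than once.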
Let's use the stride method here for the sake of example. Since we are feeding our data into a 2D network, we'll also use the 'squeeze_window' ability of our SlidingSlicer to automatically remove our z axis during forward passes:
import tensorflow as tf
from fastestimator.architecture.tensorflow import UNet
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.slicer import SlidingSlicer, MeanUnslicer
model = fe.build(model_fn=lambda: UNet(input_size=(256, 256, 1), output_channel=6),
                 optimizer_fn=lambda: tf.optimizers.legacy.Adam(learning_rate=0.0001))
network = fe.Network(
    ops=[
        ModelOp(inputs="image", model=model, outputs="pred"),
        CrossEntropy(inputs=("pred", "label"), outputs="loss", form="binary"),
        UpdateOp(model=model, loss_name="loss")
    ],
    slicers=[
        SlidingSlicer(slice=["image", "label"],
                      unslice=["pred"],
                      window_size=(-1, 256, 256, 1, -1),
                      strides=(0, 136, 136, 1, 0),
                      squeeze_window=True),
        MeanUnslicer(unslice="loss")
    ])
Let's take a look at a sample prediction after running through the network. This is going to take a while, since a single step is now responsible for processing 5x5x50=1250 mini-batches.
# sample_prediction = network.transform(data=sample_batch, mode='eval') # Uncomment this if you want to inspect the shapes yourself
# print(f"Image Batch Shape: {sample_prediction['image'].shape}")
# print(f"Label Batch Shape: {sample_prediction['label'].shape}")
# print(f"Prediction Batch Shape: {sample_prediction['pred'].shape}")
# print(f"Mean Loss Value: {sample_prediction['loss']}")
Image Batch Shape: (1, 800, 800, 50, 1)
Label Batch Shape: (1, 800, 800, 50, 6)
Prediction Batch Shape: (1, 800, 800, 50, 6)
Mean Loss Value: 0.9557278156280518
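The prediction volume has the same spatial size as the input because overlapping windows are merged when they are pasted back. A 1D NumPy sketch of that merging idea follows. It is an assumption-laden illustration (a toy signal of length 10 with made-up per-window predictions), not the SlidingSlicer code itself.

```python
import numpy as np

length, window, stride = 10, 4, 3  # windows starting at 0, 3, 6 tile the signal exactly
starts = range(0, length - window + 1, stride)
signal = np.arange(length, dtype=np.float64)

total = np.zeros(length)
counts = np.zeros(length)
for start in starts:
    chunk_pred = signal[start:start + window] + 1.0  # hypothetical per-window prediction
    total[start:start + window] += chunk_pred        # accumulate overlapping values
    counts[start:start + window] += 1                # how many windows touched each position

averaged = total / counts  # overlap handled by averaging
summed = total             # what a plain sum would give instead

print(counts)    # positions 3 and 6 are covered by two windows
print(averaged)  # equals signal + 1 everywhere
```

Averaging leaves the doubly covered positions on the same scale as the rest, whereas a plain sum doubles them.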
Finally, let's add in a metric trace and see the training:
from fastestimator.trace.metric import Dice
channel_mapping = {0: 'Cell', 1: 'Mitochondria', 2: 'AlphaGranule',
                   3: 'CanalicularVessel', 4: 'GranuleBody', 5: 'GranuleCore'}
traces = [Dice(true_key="label", pred_key="pred", channel_mapping=channel_mapping)]
estimator = fe.Estimator(
    pipeline=pipeline,
    network=network,
    traces=traces,
    log_steps=1,
    eval_log_steps=[-1],
    epochs=10)
# estimator.fit() # Uncomment this if you want to re-run the training yourself
FastEstimator-Warn: No ModelSaver Trace detected. Models will not be saved.
FastEstimator-Start: step: 1; logging_interval: 1; num_device: 1;
FastEstimator-Train: step: 1; loss: 0.09241561;
FastEstimator-Train: step: 1; epoch: 1; epoch_time(sec): 187.88;
FastEstimator-Eval: step: 1; epoch: 1; Dice: 0.19790895; Dice_AlphaGranule: 0.0; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.76499873; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.42245498; Dice_Mitochondria: 0.0; loss: 0.17678283;
FastEstimator-Train: step: 2; loss: 0.05950647; steps/sec: 0.0;
FastEstimator-Train: step: 2; epoch: 2; epoch_time(sec): 207.26;
FastEstimator-Eval: step: 2; epoch: 2; Dice: 0.31357098; Dice_AlphaGranule: 0.27818844; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.91047233; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.68593067; Dice_Mitochondria: 0.006834461; loss: 0.09680609;
FastEstimator-Train: step: 3; loss: 0.047980543; steps/sec: 0.0;
FastEstimator-Train: step: 3; epoch: 3; epoch_time(sec): 225.46;
FastEstimator-Eval: step: 3; epoch: 3; Dice: 0.2589626; Dice_AlphaGranule: 0.00012370278; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.87941015; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.6742417; Dice_Mitochondria: 0.0; loss: 0.13281296;
FastEstimator-Train: step: 4; loss: 0.03970469; steps/sec: 0.0;
FastEstimator-Train: step: 4; epoch: 4; epoch_time(sec): 303.56;
FastEstimator-Eval: step: 4; epoch: 4; Dice: 0.2699795; Dice_AlphaGranule: 0.0039450238; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.87593764; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.738118; Dice_Mitochondria: 0.0018762958; loss: 0.15984353;
FastEstimator-Train: step: 5; loss: 0.035081092; steps/sec: 0.0;
FastEstimator-Train: step: 5; epoch: 5; epoch_time(sec): 343.57;
FastEstimator-Eval: step: 5; epoch: 5; Dice: 0.28861472; Dice_AlphaGranule: 0.10057701; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.878825; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.73525584; Dice_Mitochondria: 0.017030539; loss: 0.15729764;
FastEstimator-Train: step: 6; loss: 0.030691482; steps/sec: 0.0;
FastEstimator-Train: step: 6; epoch: 6; epoch_time(sec): 373.16;
FastEstimator-Eval: step: 6; epoch: 6; Dice: 0.2974657; Dice_AlphaGranule: 0.10909026; Dice_CanalicularVessel: 0.0; Dice_Cell: 0.91101515; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.74959046; Dice_Mitochondria: 0.0150984535; loss: 0.134259;
FastEstimator-Train: step: 7; loss: 0.040506963; steps/sec: 0.0;
FastEstimator-Train: step: 7; epoch: 7; epoch_time(sec): 384.6;
FastEstimator-Eval: step: 7; epoch: 7; Dice: 0.3132342; Dice_AlphaGranule: 0.36752912; Dice_CanalicularVessel: 0.10559372; Dice_Cell: 0.8301301; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.55129784; Dice_Mitochondria: 0.02485453; loss: 0.1310871;
FastEstimator-Train: step: 8; loss: 0.03810389; steps/sec: 0.0;
FastEstimator-Train: step: 8; epoch: 8; epoch_time(sec): 488.05;
FastEstimator-Eval: step: 8; epoch: 8; Dice: 0.3384758; Dice_AlphaGranule: 0.10329177; Dice_CanalicularVessel: 0.3118691; Dice_Cell: 0.8723342; Dice_GranuleBody: 0.0; Dice_GranuleCore: 0.69752693; Dice_Mitochondria: 0.045832805; loss: 0.1552314;
FastEstimator-Train: step: 9; loss: 0.030884113; steps/sec: 0.0;
FastEstimator-Train: step: 9; epoch: 9; epoch_time(sec): 550.86;
FastEstimator-Eval: step: 9; epoch: 9; Dice: 0.34353375; Dice_AlphaGranule: 0.12995201; Dice_CanalicularVessel: 0.033380825; Dice_Cell: 0.87134874; Dice_GranuleBody: 0.25112656; Dice_GranuleCore: 0.74927706; Dice_Mitochondria: 0.026117384; loss: 0.15764695;
FastEstimator-Train: step: 10; loss: 0.027712101; steps/sec: 0.0;
FastEstimator-Train: step: 10; epoch: 10; epoch_time(sec): 572.82;
FastEstimator-Eval: step: 10; epoch: 10; Dice: 0.4659363; Dice_AlphaGranule: 0.08518501; Dice_CanalicularVessel: 0.6543881; Dice_Cell: 0.8772722; Dice_GranuleBody: 0.37757555; Dice_GranuleCore: 0.7437467; Dice_Mitochondria: 0.057450105; loss: 0.14906605;
FastEstimator-Finish: step: 10; model1_lr: 1e-04; total_time(sec): 5205.04;
Note that your steps/sec as reported by logging will be much lower when using slicers than you might normally expect, since each 'step' is now actually 1250 mini-steps.
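The mini-step count follows directly from the volume shape, window size, and strides used above. The helpers below (`windows_along` and `mini_steps`) are hypothetical names for a quick sanity check, assuming only full windows are counted.

```python
from math import prod


def windows_along(length, window, stride):
    # Number of full windows that fit along one axis
    return (length - window) // stride + 1


def mini_steps(shape, window, strides):
    # Total number of windows is the product of the per-axis counts
    return prod(windows_along(n, w, s) for n, w, s in zip(shape, window, strides))


train_steps = mini_steps((800, 800, 50), (256, 256, 1), (136, 136, 1))
eval_steps = mini_steps((800, 800, 24), (256, 256, 1), (136, 136, 1))
print(train_steps, eval_steps)  # 1250 600
```

The same calculation gives 600 mini-steps for the shallower 800x800x24 evaluation volume.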
Building Your Own Slicer
If the existing Slicers don't meet your needs, you can always customize your own. To do so, simply inherit the base class and implement one or both of the abstract methods defined there:
from fastestimator.slicer.slicer import Slicer
class DoubleBatchSlicer(Slicer):
    def _slice_batch(self, batch):
        # Implement this method if you want your slicer to be able to cut keys apart
        # In this toy example we just duplicate the batch tensor twice
        return [batch, batch]

    def _unslice_batch(self, slices, key):
        # Implement this method if you want your slicer to be able to put keys back together
        # In our toy example we just ignore the second mini-batch that we created
        return slices[0]
As a slightly more practical example, suppose that when recording loss values while using a slicer, you want to get the maximum loss rather than the mean. You could implement a MaxUnslicer as follows:
class MaxUnslicer(Slicer):
    def __init__(self, unslice, mode=None, ds_id=None):
        super().__init__(slice=None, unslice=unslice, mode=mode, ds_id=ds_id)

    def _unslice_batch(self, slices, key):
        maks = slices[0]
        for minibatch in slices[1:]:
            maks = max(maks, minibatch)
        return maks
Let's test out our new creation:
from fastestimator.slicer.slicer import forward_slicers, reverse_slicers
loss = tf.random.uniform((5,1))
print(f"loss values: {loss}")
batch = {"loss": loss}
axis_slicer = AxisSlicer(slice="loss", unslice=None, axis=0)
max_unslicer = MaxUnslicer(unslice="loss")
mini_batches = forward_slicers(slicers=[axis_slicer, max_unslicer], data=batch)
print(f"number of slices: {len(mini_batches)}")
output = reverse_slicers(slicers=[axis_slicer, max_unslicer], data=mini_batches, original_data=batch)
print(f"max loss: {output['loss']}")
loss values: [[0.7506162 ]
 [0.07465518]
 [0.4861052 ]
 [0.9194704 ]
 [0.28890288]]
number of slices: 5
max loss: [0.9194704]
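The MaxUnslicer above uses Python's built-in `max`, which compares whole tensors and therefore suits the single-valued loss slices in this test. If each slice held several values and you wanted a per-element maximum instead, one possible variation is to fold the slices with an elementwise maximum. The sketch below shows that idea in plain NumPy; `elementwise_max` is a hypothetical helper, not part of FastEstimator.

```python
from functools import reduce

import numpy as np


def elementwise_max(slices):
    # Fold the list of per-slice arrays with an elementwise maximum
    return reduce(np.maximum, slices)


slices = [np.array([0.2, 0.9]), np.array([0.5, 0.1]), np.array([0.3, 0.4])]
result = elementwise_max(slices)
print(result)  # [0.5 0.9]
```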