Microscopy Cell Segmentation Using UNet 3D 3+
[Paper] [Notebook] [TF Implementation] [Torch Implementation]
UNet, a deep learning network with an encoder-decoder architecture, is widely used in medical image segmentation. Combining multi-scale features is one of the important factors for accurate segmentation. UNet++ was developed as a modified UNet with an architecture of nested and dense skip connections. However, it does not explore sufficient information from full scales, and there is still large room for improvement.
The full-scale skip connections of UNet 3+ convert the inter-connections between the encoder and decoder as well as the intra-connections between the decoder sub-networks. Both UNet with plain connections and UNet++ with nested and dense connections fall short of exploring sufficient information from full scales, failing to explicitly learn the position and boundary of an organ. To remedy this defect in UNet and UNet++, each decoder layer in UNet 3+ incorporates both smaller- and same-scale feature maps from the encoder and larger-scale feature maps from the decoder, capturing fine-grained details and coarse-grained semantics in full scales.
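To make the full-scale aggregation concrete, here is a minimal PyTorch-style sketch of a single 3D decoder stage. It is an illustration only, not the implementation linked above; the module name, the cat_channels parameter, and the use of nearest-neighbor resizing for both down- and up-sampling are simplifying assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F

class FullScaleDecoderStage(nn.Module):
    """Illustrative single UNet 3+ decoder stage (3D variant).

    Every incoming feature map is resized to this stage's resolution,
    reduced to `cat_channels` channels, concatenated, then fused.
    """
    def __init__(self, in_channels_list, cat_channels=64):
        super().__init__()
        # One 3x3x3 conv per source scale to unify the channel counts.
        self.reducers = nn.ModuleList(
            nn.Conv3d(c, cat_channels, kernel_size=3, padding=1)
            for c in in_channels_list)
        fused = cat_channels * len(in_channels_list)
        self.fuse = nn.Conv3d(fused, fused, kernel_size=3, padding=1)

    def forward(self, sources, target_size):
        feats = []
        for reduce, x in zip(self.reducers, sources):
            # Down-sample larger maps / up-sample smaller maps so all
            # sources share this stage's (D, H, W) resolution.
            x = F.interpolate(x, size=target_size, mode='nearest')
            feats.append(reduce(x))
        # Full-scale aggregation: concatenate along the channel axis.
        return torch.relu(self.fuse(torch.cat(feats, dim=1)))
In the full network, such a stage would receive the same- and smaller-scale feature maps from the encoder plus the larger-scale feature maps from deeper decoder stages, mirroring the inter- and intra-connection pattern described above.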
This example modifies UNet 3+ to segment 3D datasets: we showcase UNet 3D 3+ by segmenting a 3D electron microscopy cell dataset.
Getting things ready
Let's import the necessary packages:
import tempfile
import fastestimator as fe
from fastestimator.dataset.data.em_3d import load_data
from fastestimator.op.numpyop.meta import Sometimes
from fastestimator.op.numpyop.multivariate import HorizontalFlip, VerticalFlip
from fastestimator.op.numpyop.univariate import ChannelTranspose, Minmax
from fastestimator.op.numpyop.univariate.expand_dims import ExpandDims
from fastestimator.op.tensorop import TensorOp
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp, UpdateOp
from fastestimator.op.tensorop.resize3d import Resize3D
from fastestimator.op.tensorop.argmax import Argmax
from fastestimator.trace.adapt import EarlyStopping, ReduceLROnPlateau
from fastestimator.trace.io import BestModelSaver
from fastestimator.trace.metric import Dice
from fastestimator.util import ImageDisplay, GridDisplay
Let's define several utility functions:
import copy
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation, rc
rc('animation', html='jshtml')
def apply_mask(image, mask, color, alpha=0.1):
    """
    Apply the given mask to the image.
    image: the input image (H, W, D, C)
    mask: the mask to overlay (H, W, D)
    color: a list of color values of length C
    alpha: opacity of the overlaid color
    """
    channels = image.shape[-1]
    for c in range(channels):
        # Alpha-blend the class color into every voxel covered by the mask.
        image[:, :, :, c] = np.where(mask == 1,
                                     image[:, :, :, c] * (1 - alpha) + alpha * color[c],
                                     image[:, :, :, c])
    return image
def generate_mask_overlay_image(image, mask, color_mapping):
    """
    Generate an image with the mask overlaid on it.
    image: the input image (H, W, D, C)
    mask: the mask to overlay (H, W, D, class)
    color_mapping: a dictionary mapping input classes to colors (eg: {0: [0, 0, 0], 1: [0, 40, 255]})
        The dictionary should have length "class" and each element should contain a list of length "C".
    image and color_mapping are expected to be in the same format (eg 0-1 or 0-255)
    """
    combined_image = copy.deepcopy(image)
    classes = mask.shape[-1]
    for i in range(classes):
        combined_image = apply_mask(combined_image, mask[:, :, :, i], color_mapping[i])
    return combined_image
def create_animation(images, labels):
    """
    Create an animation combining the input image and its mask overlay.
    images: the input image (H, W, D, 3)
    labels: the image to compare with (H, W, D, 3)
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    ax1.axis('off')
    ax2.axis('off')
    im1 = ax1.imshow(images[:, :, 0, :])
    im2 = ax2.imshow(labels[:, :, 0, :])
    fig.show()
    im = [im1, im2]

    def animate_func(i):
        # Advance both panels to depth slice i.
        im[0].set_array(images[:, :, i, :])
        im[1].set_array(labels[:, :, i, :])
        return im

    return animation.FuncAnimation(fig,
                                   animate_func,
                                   frames=images.shape[-2],
                                   interval=images.shape[-2] * 4)
color_mapping = {
    0: [0, 0, 0],
    1: [0.0, 0.16, 1.0],
    2: [0.0, 0.83, 1.0],
    3: [0.49, 1.0, 0.48],
    4: [0.50, 0.0, 0.0],
    5: [1.0, 0.28, 0.0],
    6: [1.0, 0.90, 0.0]}
def save_gif(image, label, color_mapping, file_name):
    overlay_image = generate_mask_overlay_image(image, label, color_mapping)
    create_animation(image, overlay_image).save(file_name)
Next, let's set up some hyperparameters related to the task:
batch_size = 1
epochs = 40
log_steps = 20
height = 256
width = 256
depth = 24
channels = 1
num_classes = 6
filters = 64
learning_rate = 1e-3
train_steps_per_epoch = None
eval_steps_per_epoch = None
save_dir = tempfile.mkdtemp()
data_dir = None
Importing the Dataset
The electron microscopy 3D cell dataset consists of two 3D images: one 800x800x50 and the other 800x800x24. The 800x800x50 image is used as the training dataset and the 800x800x24 image for validation. Instead of using the entire 800x800 images, the 800x800x50 image is tiled into 256x256x24 tiles with an overlap of 128, producing around 75 training images; similarly, the 800x800x24 image is tiled to produce 25 validation images.
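The tiling itself happens inside load_data, but a quick sketch (illustrative only, not the loader's actual code; the helper name num_tiles is hypothetical) shows where these counts come from:
def num_tiles(full_size, tile_size, overlap):
    # Consecutive tile origins advance by (tile_size - overlap).
    stride = tile_size - overlap
    return (full_size - tile_size) // stride + 1

per_axis = num_tiles(800, 256, 128)  # (800 - 256) // 128 + 1 = 5
print(per_axis * per_axis)           # 25 tiles per depth chunk
# Three overlapping depth chunks of 24 slices fit in the 50-slice training
# volume, giving roughly 25 * 3 = 75 training tiles, while the 24-slice
# validation volume holds a single chunk: 25 validation tiles.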
train_data, eval_data = load_data(data_dir)
print("training dataset length is {}".format(len(train_data)))
print("evaluation dataset length is {}".format(len(eval_data)))
print("dataset sample:")
print("Image Shape: ", train_data[0]['image'].shape)
print("Label Shape: ", train_data[0]['label'].shape)
training dataset length is 75
evaluation dataset length is 25
dataset sample:
Image Shape:  (256, 256, 24)
Label Shape:  (256, 256, 24, 6)
The image is a 256x256x24 numpy array of uint16. The label is a 256x256x24x6 one-hot encoded (6 classes) numpy array. The semantic labels classify each image voxel into one of six classes, indexed from 0-5:
| Index | Color | Class name |
| --- | --- | --- |
| 0 | Dark Blue | Cell |
| 1 | Cyan | Mitochondria |
| 2 | Green | Alpha granule |
| 3 | Yellow | Canalicular vessel |
| 4 | Red | Dense granule body |
| 5 | Purple | Dense granule core |
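Note that the six label channels do not cover every voxel: background voxels carry all-zero labels (the GIF-saving code at the end of this section adds an explicit background channel for exactly this reason). As a quick, purely illustrative way to inspect a sample's class distribution:
import numpy as np

label = train_data[0]['label']            # (256, 256, 24, 6) one-hot
class_map = np.argmax(label, axis=-1)     # per-voxel class index, 0-5
class_map[label.sum(axis=-1) == 0] = -1   # mark unlabeled/background voxels
print(np.unique(class_map, return_counts=True))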
Creating the Pipeline
Now that both training and validation datasets are created, we use Pipeline to define the preprocessing operations:
HorizontalFlip and VerticalFlip, each wrapped in Sometimes so it is applied randomly during training, are our data augmentations.
Minmax scales the image intensities into the range [0, 1].
ExpandDims adds a channel dimension at the end: (256x256x24) -> (256x256x24x1).
ChannelTranspose then moves that channel axis to the front, since the model expects channel-first inputs.
pipeline = fe.Pipeline(
    train_data=train_data,
    eval_data=eval_data,
    batch_size=batch_size,
    ops=[
        Sometimes(numpy_op=HorizontalFlip(image_in="image", mask_in="label", mode='train')),
        Sometimes(numpy_op=VerticalFlip(image_in="image", mask_in="label", mode='train')),
        Minmax(inputs="image", outputs="image"),
        ExpandDims(inputs="image", outputs="image"),
        ChannelTranspose(inputs=("image", "label"), outputs=("image", "label"), axes=(3, 0, 1, 2))
    ])
data = pipeline.get_results(mode='train')
image = data['image'].numpy()
label = data['label'].numpy()
print("Image shape: ", image.shape," Mask shape:", label.shape)
Image shape: (1, 1, 256, 256, 24) Mask shape: (1, 6, 256, 256, 24)
Visualizing Sample Data
image = np.squeeze(image[0])
label = np.transpose(label[0], (1, 2, 3, 0))
Now let's visualize the pipeline output: a single slice of the image with the label overlaid on it, using ImageDisplay.
ImageDisplay(image=image[:,:,0], masks=label[:,:,0], title='Ground Truth').show()
If you want to visualize all the slices of the pipeline output, use the code below to save the multi-slice visualization as a GIF.
# add a background channel: it stays 1 wherever a voxel belongs to none of the 6 classes
input_gt = np.ones(label.shape[:-1] + (1,), dtype=label.dtype)
input_gt[np.sum(label, axis=-1) == 1] = 0
label_out = np.concatenate((input_gt, label), axis=-1)
# replicate the grayscale image into 3 channels for RGB display
image = np.tile(np.expand_dims(image, -1), 3)
save_gif(image, label_out, color_mapping, 'training_data.gif')