
sliding_slicer

SlidingSlicer

Bases: Slicer

A slicer which cuts inputs using a sliding window.

This slicer cuts inputs using a sliding window, which should have the same rank as the input array. For example, if you want to feed a giant batched channel-first 2D image [B, C, W, H] into a model which can only handle smaller images of size [B, C, W2, H2] then you could set window_size=[-1, -1, W2, H2].
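
For orientation, here is a minimal usage sketch. It is not taken from this page: the Network wiring (including the slicers keyword), the toy model, and the "x" / "y_pred" keys are assumptions made purely for illustration.

import torch
import fastestimator as fe
from fastestimator.op.tensorop.model import ModelOp
from fastestimator.slicer import SlidingSlicer

# A toy channel-first model; in practice this would be your real network
model = fe.build(model_fn=lambda: torch.nn.Conv2d(3, 1, kernel_size=3, padding=1), optimizer_fn="adam")

# Pipeline batches are [B, 3, 512, 512], but the model should only see [B, 3, 128, 128] tiles
slicer = SlidingSlicer(window_size=[-1, -1, 128, 128],  # -1 keeps the batch and channel sizes
                       strides=(),                      # default: step by window_size, so no overlap
                       pad_mode="partial",              # trailing tiles may be smaller than 128x128
                       slice="x",                       # cut this key into tiles
                       unslice="y_pred")                # stitch this key back together afterwards

network = fe.Network(ops=[ModelOp(model=model, inputs="x", outputs="y_pred")],
                     slicers=[slicer])

With pad_mode='partial' the border tiles are simply smaller; switch to 'constant' with an appropriate pad_val if your model requires fixed-size inputs.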

Parameters:

Name Type Description Default
window_size Sequence[int]

The size of your sliding window. Use -1 for any axis when you want to copy the size from the input.

required
strides Union[int, Sequence[int]]

How far to step in each direction each time you move a sliding window. Use 0 for each axis which you set as -1 in window_size. If you set strides < window_size you will get overlapping windows. If you set strides > window_size you will drop data from your input. If a single integer is provided it will be used for every axis where sliding occurs. If an empty sequence (default) is provided, then strides will be set equal to window_size such that there is no overlap between windows. A worked example follows this parameter table.

()
pad_mode str

One of 'drop', 'partial', or 'constant'. If 'drop' then the final slice along various axes will be dropped if the window_size does not tile perfectly with the input shape. If 'partial' then the final slice along various axes may be of a different (smaller) shape than other slices (if window_size does not perfectly tile the input shape). If 'constant' then the final slice along various axes will be filled out by pad_val.

'drop'
pad_val Union[int, float]

The value to fill out the input with when pad_mode is 'constant'.

0.0
slice Union[None, str, Sequence[str]]

The input key(s) which this Slicer slices. Data which this slicer does not cut will be replicated across the resulting minibatches so that the network ops always have access to all of the batch keys. If multiple keys are provided here it is expected that they all have the same shape.

None
unslice Union[None, str, Sequence[str]]

The input key(s) which this Slicer un-slices. By default (empty tuple) the Slicer will un-slice whatever keys were specified in slice. If you do not want to un-slice those keys, then pass None or manually specify the specific key(s) which you would like this slicer to un-slice.

()
unslice_mode str

When re-combining your slices, how do you want to handle overlapping regions? Options are 'sum' or 'avg'. Regardless of option, any gaps between slices will be filled in by the pad_val.

'avg'
unslice_shape Optional[Tuple[int, ...]]

The shape of your final tensor(s) after un-slicing. This argument is only required if you did not provide slice keys. If slice keys were provided, then the unslice_shape is inferred at every step based on the shape of the inputs before slicing.

None
squeeze_window bool

If enabled, then all axes which are set to 1 in window_size will be squeezed out during the network forward pass, and re-introduced later during un-slicing. This makes it more convenient if you want to, for example, use SlidingSlicer to feed 3D image data into a 2D network (though you may want to instead consider AxisSlicer as a simpler way to achieve that purpose).

False
mode Union[None, str, Iterable[str]]

What mode(s) to invoke this Slicer in. For example, "train", "eval", "test", or "infer". To invoke regardless of mode, pass None. To invoke in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

None
ds_id Union[None, str, Iterable[str]]

What dataset id(s) to invoke this Slicer in. To invoke regardless of ds_id, pass None. To invoke in all ds_ids except for a particular one, you can pass an argument like "!ds1".

None
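
To make the interaction of window_size, strides, and pad_mode concrete, here is a small standalone sketch (plain Python, not part of the class) that reproduces the per-axis start/stop arithmetic described above for a single sliding axis of length 10 with window_size=4 and strides=3:

def axis_windows(length, window, stride, pad_mode):
    """Return the (start, stop) pairs produced along one sliding axis."""
    windows = []
    for start in range(0, length, stride):
        stop = start + window
        if stop > length:
            if pad_mode == 'drop':      # discard the trailing partial window entirely
                continue
            if pad_mode == 'partial':   # keep it, but shorter than window_size
                stop = length
            # 'constant': keep the full-size window; the input is right-padded with pad_val
        windows.append((start, stop))
    return windows

for mode in ('drop', 'partial', 'constant'):
    print(mode, axis_windows(length=10, window=4, stride=3, pad_mode=mode))
# drop [(0, 4), (3, 7), (6, 10)]
# partial [(0, 4), (3, 7), (6, 10), (9, 10)]
# constant [(0, 4), (3, 7), (6, 10), (9, 13)]  (positions 10-12 are filled by pad_val)

The real slicer applies this same arithmetic independently to every axis that is slid over, and the resulting minibatches are the Cartesian product of the per-axis windows.
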
Source code in fastestimator/fastestimator/slicer/sliding_slicer.py
@traceable()
class SlidingSlicer(Slicer):
    """A slicer which cuts inputs using a sliding window.

    This slicer cuts inputs using a sliding window, which should have the same rank as the input array. For example, if
    you want to feed a giant batched channel-first 2D image [B, C, W, H] into a model which can only handle smaller
    images of size [B, C, W2, H2] then you could set window_size=[-1, -1, W2, H2].

    Args:
        window_size: The size of your sliding window. Use -1 for any axis when you want to copy the size from the input.
        strides: How far to step in each direction each time you move a sliding window. Use 0 for each axis which you
            set as -1 in `window_size`. If you set `strides` < `window_size` you will get overlapping windows. If you
            set `strides` > `window_size` you will drop data from your input. If a single integer is provided it will be
            used for every axis where sliding occurs. If an empty sequence (default) is provided, then `strides` will be
            set equal to `window_size` such that there is no overlap between windows.
        pad_mode: One of 'drop', 'partial', or 'constant'. If 'drop' then the final slice along various axes will be
            dropped if the `window_size` does not tile perfectly with the input shape. If 'partial' then the final slice
            along various axes may be of a different (smaller) shape than other slices (if `window_size` does not
            perfectly tile the input shape). If 'constant' then the final slice along various axes will be filled out by
            `pad_val`.
        pad_val: The value to fill out the input with when `pad_mode` is 'constant'.
        slice: The input key(s) which this Slicer slices. Data which this slicer does not cut will be replicated across
            the resulting minibatches so that the network ops always have access to all of the batch keys. If multiple
            keys are provided here it is expected that they all have the same shape.
        unslice: The input key(s) which this Slicer un-slices. By default (empty tuple) the Slicer will un-slice
            whatever keys were specified in `slice`. If you do not want to un-slice those keys, then pass None or
            manually specify the specific key(s) which you would like this slicer to un-slice.
        unslice_mode: When re-combining your slices, how do you want to handle overlapping regions? Options are 'sum'
            or 'avg'. Regardless of option, any gaps between slices will be filled in by the `pad_val`.
        unslice_shape: The shape of your final tensor(s) after un-slicing. This argument is only required if you did
            not provide `slice` keys. If `slice` keys were provided, then the unslice_shape is inferred at every step
            based on the shape of the inputs before slicing.
        squeeze_window: If enabled, then all axes which are set to 1 in `window_size` will be squeezed out during the
            network forward pass, and re-introduced later during un-slicing. This makes it more convenient if you want
            to, for example, use SlidingSlicer to feed 3D image data into a 2D network (though you may want to instead
            consider AxisSlicer as a simpler way to achieve that purpose).
        mode: What mode(s) to invoke this Slicer in. For example, "train", "eval", "test", or "infer". To invoke
            regardless of mode, pass None. To invoke in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to invoke this Slicer in. To invoke regardless of ds_id, pass None. To invoke in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
    """
    def __init__(self,
                 window_size: Sequence[int],
                 strides: Union[int, Sequence[int]] = (),
                 pad_mode: str = 'drop',
                 pad_val: Union[int, float] = 0.0,
                 slice: Union[None, str, Sequence[str]] = None,
                 unslice: Union[None, str, Sequence[str]] = (),
                 unslice_mode: str = "avg",
                 unslice_shape: Optional[Tuple[int, ...]] = None,
                 squeeze_window: bool = False,
                 mode: Union[None, str, Iterable[str]] = None,
                 ds_id: Union[None, str, Iterable[str]] = None) -> None:
        super().__init__(slice=slice, unslice=unslice, mode=mode, ds_id=ds_id)
        self.window_size = to_list(window_size)
        assert len(self.window_size) > 0, "A window_size must be specified"
        for dim in self.window_size:
            assert isinstance(dim, int), \
                f"All window_size dimensions must be integers, but found {dim} of type {type(dim)}"
            assert dim > 0 or dim == -1, f"All window_size dimensions must be positive or -1, but found {dim}"
            # TODO - allow window_size=0 to delete an axis? or some other way to reduce dimensionality?
        if isinstance(strides, int):
            self.strides = [strides if size > 0 else 0 for size in self.window_size]
        else:
            self.strides = to_list(strides)
            if not strides:
                self.strides = [size if size > 0 else 0 for size in self.window_size]
        assert len(self.strides) == len(self.window_size), "strides and window_size should have the same number of " + \
            f"dimensions, but got {self.strides} vs {self.window_size}"
        for dim in self.strides:
            assert isinstance(dim, int), f"All stride dimensions must be integers, but found {dim} of type {type(dim)}"
            assert dim >= 0, f"All stride dimensions must be non-negative, but found {dim}."
        options = ('drop', 'partial', 'constant')
        assert pad_mode in options, f"pad_mode must be one of {options}, but got {pad_mode}"
        self.pad_mode = pad_mode
        self.pad_val = pad_val
        options = ("sum", "avg")
        assert unslice_mode in options, f"unslice_mode must be one of: {options}, but got {unslice_mode}"
        self.unslice_mode = unslice_mode
        self.static_unslice_shape = True if unslice_shape is not None else False
        self.unslice_shape: Optional[List[int]] = to_list(unslice_shape)
        if self.unslice_inputs and not self.slice_inputs:
            assert self.static_unslice_shape, "If you want to use SlidingSlicer to unslice data you must either " + \
                "use it to perform the initial slicing, or else provide the unslice_shape that you want your data " + \
                "to take after it is reconstructed."
        self.auto_squeeze: List[int] = []
        if squeeze_window:
            for idx, size in enumerate(self.window_size):
                if size == 1:
                    self.auto_squeeze.append(idx)
        self.replica_batch_sizes: Dict[int, int] = {}

    def _slice_batch(self, batch: Tensor) -> List[Tensor]:
        shape = list(get_shape(batch))
        if len(self.window_size) != len(shape):
            raise ValueError(f"Sliding window size: {self.window_size} is incompatible with data shape: {shape}. " + \
                               "They must have the same rank.")
        if self.unslice_inputs and not self.static_unslice_shape:
            # If we have to unslice things later, then need to remember the desired shape if the user didn't give us one
            self.unslice_shape = [int(x) for x in tuple(shape)]
            if get_num_gpus() > 1:
                # Unfortunately in tf multi-gpu the batches are split over multiple replicas, in which case we need to
                # manually correct the desired batch dimension later
                replica_context = tf.distribute.get_replica_context()
                if replica_context is not None:
                    replica_id = int(replica_context.replica_id_in_sync_group)
                    self.replica_batch_sizes[replica_id] = int(shape[0])
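        # Axes with a stride of 0, or a stride at least as large as the axis itself, are taken whole (slice(None));
        # the remaining axes are marked None and are filled in with concrete windows by _get_cuts below.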
        stride_template = [
            slice(None) if stride == 0 or stride >= dim else None for stride, dim in zip(self.strides, shape)
        ]
        cuts = self._get_cuts(shape, stride_template)
        batch = self._solve_padding(batch=batch, batch_shape=shape)
        minibatches = [batch[cut] for cut in cuts]
        for ax in reversed(self.auto_squeeze):
            minibatches = [squeeze(minibatch, axis=ax) for minibatch in minibatches]
        return minibatches

    def _get_cuts(self, data_shape: List[int], stride_template: List[Optional[slice]]) -> List[List[slice]]:
        results = []
        for axis, cut_template in enumerate(stride_template):
            if cut_template is None:
                target_shape = data_shape[axis]
                for start in range(0, target_shape, self.strides[axis]):
                    stop = start + self.window_size[axis]
                    if stop > target_shape:
                        if self.pad_mode == 'drop':
                            continue
                        elif self.pad_mode == 'partial':
                            stop = target_shape
                        else:
                            # Padding the input
                            pass
                    new_cut = slice(start, stop)
                    new_template = list(stride_template)
                    new_template[axis] = new_cut
                    results.extend(self._get_cuts(data_shape, new_template))
                # Each level of recursion should only solve one axis, so break here
                break
        if not results:
            # No slices were generated, so all of the slices are already defined within the stride_template
            return [stride_template]
        return results

    def _solve_padding(self, batch: TensorT, batch_shape: List[int]) -> TensorT:
        if self.pad_mode == 'constant':
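            # Right-pad each sliding axis by window_size minus the remainder of (length % stride) so that the final
            # window becomes full-sized; axes which tile evenly (or which are not slid over) get zero padding.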
            paddings = [[0, int(win - ((tru % (stride or tru)) or win))] for tru,
                        win,
                        stride in zip(batch_shape, self.window_size, self.strides)]
            if isinstance(batch, torch.Tensor):
                paddings.reverse()  # Torch padding reads from right-most dim to left-most dim
                paddings = [elem for x in paddings for elem in x]
                batch = torch.nn.functional.pad(batch, pad=paddings, mode='constant', value=self.pad_val)
            else:
                batch = tf.pad(batch,
                               paddings=tf.constant(paddings),
                               mode="CONSTANT",
                               constant_values=tf.cast(self.pad_val, dtype=batch.dtype))
        return batch

    def _unslice_batch(self, slices: Tuple[Tensor, ...], key: str) -> Tensor:
        target_shape = self.unslice_shape
        assert target_shape is not None, "Unit tests should run forward_slicers before running reverse_slicers"
        if self.replica_batch_sizes:
            target_shape[0] = sum(self.replica_batch_sizes.values())

        stride_template = [
            slice(None) if stride == 0 or stride >= dim else None for stride, dim in zip(self.strides, target_shape)
        ]
        cuts = self._get_cuts(target_shape, stride_template)
        assert len(cuts) == len(slices), f"SlidingSlicer could not unslice key: {key}. It received {len(slices)} " + \
            f"slices, but a target_shape of {target_shape} with strides of {self.strides} could not have produced this."

        # TF doesn't support slice assignment, and np doesn't support indexing using list of slice, so we'll use torch
        # for everything and then cast back to tf later if needed
        merged_dtype = torch.float32 if self.unslice_mode == 'avg' else to_tensor(slices[0], target_type='torch').dtype
        merged = torch.zeros(size=target_shape, dtype=merged_dtype)
        merged = self._solve_padding(batch=merged, batch_shape=target_shape)
        counts = torch.zeros(size=target_shape, dtype=torch.int32)
        counts = self._solve_padding(batch=counts, batch_shape=target_shape)

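        # Paste each slice back into its window location, counting how many slices touched each element so that
        # overlapping regions can be averaged (when unslice_mode is 'avg') after accumulation.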
        for cut, data in zip(cuts, slices):
            for ax in self.auto_squeeze:
                data = expand_dims(data, axis=ax)
            merged[cut] += to_tensor(data=data, target_type='torch')
            counts[cut] += 1

        if self.unslice_mode == 'avg':
            merged /= torch.maximum(counts, torch.tensor(1.0))

        merged[counts == 0] = torch.tensor(self.pad_val, dtype=merged_dtype)

        # Remove any border padding which may have been added
        merged = merged[[slice(target) for target in target_shape]]

        return merged if isinstance(slices[0], torch.Tensor) else to_tensor(data=merged, target_type='tf')