eigen_cam

EigenCAM

Bases: Trace

A trace which draws EigenCAM heatmaps on top of images.

These are useful for visualizing the outputs of the feature extractor component of a model. They are relatively insensitive to adversarial attacks, so they should not be used to try to detect such attacks. See https://arxiv.org/abs/2008.00299 for more details.

Parameters:

images (str, required)
    The key corresponding to images onto which to draw the CAM outputs.

activations (str, required)
    The key corresponding to outputs from a convolution layer from which to draw the CAM outputs. You can easily extract these from any model by using the 'intermediate_layers' variable in a ModelOp (see the usage sketch after this list).

n_components (int, default: 3)
    How many principal components to visualize.

n_samples (Optional[int], default: 5)
    How many images in total to display every epoch (or None to display all available images).

labels (Optional[str], default: None)
    The key corresponding to the true labels of the images to be visualized.

preds (Optional[str], default: None)
    The key corresponding to the model prediction for each image.

label_mapping (Optional[Dict[str, Any]], default: None)
    A mapping of {class_string: model_output_value} used to convert model outputs into class names for display.

outputs (str, default: 'eigencam')
    The key into which to write the EigenCAM images.

mode (Union[None, str, Iterable[str]], default: '!train')
    What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument like "!infer" or "!train".

ds_id (Union[None, str, Iterable[str]], default: None)
    What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all ds_ids except for a particular one, you can pass an argument like "!ds1".
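
The following is a minimal, hypothetical usage sketch. The key names ("x", "y", "y_pred", "conv_out") and the layer name "conv2d_1" are placeholders, and the exact 'intermediate_layers' wiring may differ between FastEstimator versions:

import fastestimator as fe
from fastestimator.architecture.tensorflow import LeNet
from fastestimator.op.tensorop.loss import CrossEntropy
from fastestimator.op.tensorop.model import ModelOp
from fastestimator.trace.xai import EigenCAM

model = fe.build(model_fn=LeNet, optimizer_fn="adam")
network = fe.Network(ops=[
    # 'intermediate_layers' exposes the named layer's output under the extra key "conv_out"
    ModelOp(model=model, inputs="x", outputs=["y_pred", "conv_out"], intermediate_layers="conv2d_1"),
    CrossEntropy(inputs=("y_pred", "y"), outputs="ce"),
])
traces = [
    # Pass 'traces' to fe.Estimator(...) along with the pipeline and network
    EigenCAM(images="x", activations="conv_out", labels="y", preds="y_pred",
             label_mapping={str(i): i for i in range(10)})  # hypothetical class-name mapping
]
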
Source code in fastestimator/fastestimator/trace/xai/eigen_cam.py
from typing import Any, Dict, Iterable, List, Optional, Union

import cv2
import numpy as np
import tensorflow as tf
import torch

from fastestimator.backend.argmax import argmax
from fastestimator.backend.concat import concat
from fastestimator.backend.get_image_dims import get_image_dims
from fastestimator.backend.reduce_max import reduce_max
from fastestimator.backend.squeeze import squeeze
from fastestimator.backend.to_number import to_number
from fastestimator.trace.trace import Trace
from fastestimator.util.data import Data
from fastestimator.util.img_data import ImgData
from fastestimator.util.traceability_util import traceable


@traceable()
class EigenCAM(Trace):
    """A trace which draws EigenCAM heatmaps on top of images.

    These are useful for visualizing the outputs of the feature extractor component of a model. They are relatively
    insensitive to adversarial attacks, so they should not be used to try to detect such attacks. See
    https://arxiv.org/abs/2008.00299 for more details.

    Args:
        images: The key corresponding to images onto which to draw the CAM outputs.
        activations: The key corresponding to outputs from a convolution layer from which to draw the CAM outputs. You
            can easily extract these from any model by using the 'intermediate_layers' variable in a ModelOp.
        n_components: How many principal components to visualize.
        n_samples: How many images in total to display every epoch (or None to display all available images).
        labels: The key corresponding to the true labels of the images to be visualized.
        preds: The key corresponding to the model prediction for each image.
        label_mapping: A mapping of {class_string: model_output_value} used to convert model outputs into class names
            for display.
        outputs: The key into which to write the eigencam images.
        mode: What mode(s) to execute this Trace in. For example, "train", "eval", "test", or "infer". To execute
            regardless of mode, pass None. To execute in all modes except for a particular one, you can pass an argument
            like "!infer" or "!train".
        ds_id: What dataset id(s) to execute this Trace in. To execute regardless of ds_id, pass None. To execute in all
            ds_ids except for a particular one, you can pass an argument like "!ds1".
    """
    def __init__(self,
                 images: str,
                 activations: str,
                 n_components: int = 3,
                 n_samples: Optional[int] = 5,
                 labels: Optional[str] = None,
                 preds: Optional[str] = None,
                 label_mapping: Optional[Dict[str, Any]] = None,
                 outputs: str = "eigencam",
                 mode: Union[None, str, Iterable[str]] = "!train",
                 ds_id: Union[None, str, Iterable[str]] = None):
        self.image_key = images
        self.activation_key = activations
        self.true_label_key = labels
        self.pred_label_key = preds
        inputs = [x for x in (images, activations, labels, preds) if x is not None]
        self.n_components = n_components
        self.n_samples = n_samples
        # TODO - handle non-hashable labels
        self.label_mapping = {val: key for key, val in label_mapping.items()} if label_mapping else None
        super().__init__(inputs=inputs, outputs=outputs, mode=mode, ds_id=ds_id)
        self.images = []
        self.activations = []
        self.labels = []
        self.preds = []
        self.n_found = 0

    def _reset(self) -> None:
        """Clear memory for next epoch.
        """
        self.images = []
        self.activations = []
        self.labels = []
        self.preds = []
        self.n_found = 0

    def _project_2d(self, activations: np.ndarray) -> List[np.ndarray]:
        """Project 2D convolution activations maps into 2D principal component maps.

        Args:
            activations: A tensor of shape (batch, channels, height, width) to be transformed.

        Returns:
            Principal component projections of the `activations`.
        """
        projections = [[] for _ in range(self.n_components)]
        for activation in activations:
            flat = activation.reshape(activation.shape[0], -1).transpose()
            flat = flat - flat.mean(axis=0)
            _, _, VT = np.linalg.svd(flat, full_matrices=False)  # only the right singular vectors (VT) are needed
            for i in range(self.n_components):
                component_i = flat @ VT[i, :]
                component_i = component_i.reshape(activation.shape[1:])
                projections[i].append(component_i)
        return [np.array(elem, dtype=np.float32) for elem in projections]

    def on_batch_end(self, data: Data) -> None:
        if self.n_samples is None or self.n_found < self.n_samples:
            self.images.append(data[self.image_key])
            self.activations.append(data[self.activation_key])
            if self.true_label_key:
                self.labels.append(data[self.true_label_key])
            if self.pred_label_key:
                self.preds.append(data[self.pred_label_key])
            self.n_found += len(data[self.image_key])

    def on_epoch_end(self, data: Data) -> None:
        # Keep only the user-specified number of samples
        images = concat(self.images)[:self.n_samples or self.n_found]
        _, height, width = get_image_dims(images)
        activations = to_number(concat(self.activations)[:self.n_samples or self.n_found])
        if tf.is_tensor(images):
            activations = np.moveaxis(activations, source=-1, destination=1)  # Activations should be channel first
        args = {}
        labels = None if not self.labels else concat(self.labels)[:self.n_samples or self.n_found]
        if labels is not None:
            if len(labels.shape) > 1:
                labels = argmax(labels, axis=-1)
            if self.label_mapping:
                labels = np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(labels))])
            args[self.true_label_key] = labels
        preds = None if not self.preds else concat(self.preds)[:self.n_samples or self.n_found]
        if preds is not None:
            if len(preds.shape) > 1:
                preds = argmax(preds, axis=-1)
            if self.label_mapping:
                preds = np.array([self.label_mapping[clazz] for clazz in to_number(squeeze(preds))])
            args[self.pred_label_key] = preds
        args[self.image_key] = images
        # Clear memory
        self._reset()
        # Make the image
        components = self._project_2d(activations)
        components = [np.maximum(component, 0) for component in components]
        masks = []
        for component_batch in components:
            img_batch = []
            for img in component_batch:
                img = cv2.resize(img, (width, height))  # cv2.resize expects dsize as (width, height)
                # Min-max normalize the component map into [0, 1] before color-mapping
                img = img - np.min(img)
                img = img / np.max(img)
                img = cv2.cvtColor(cv2.applyColorMap(np.uint8(255 * img), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
                img = np.float32(img) / 255
                img_batch.append(img)
            img_batch = np.array(img_batch, dtype=np.float32)
            # Switch to channel first for pytorch
            if isinstance(images, torch.Tensor):
                img_batch = np.moveaxis(img_batch, source=-1, destination=1)
            masks.append(img_batch)

        components = [images + mask for mask in masks]  # This seems to work even if the image is 1 channel instead of 3
        components = [image / reduce_max(image) for image in components]

        for idx, elem in enumerate(components):
            args[f"Component {idx}"] = elem

        result = ImgData(**args)
        data.write_without_log(self.outputs[0], result)
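
For intuition, the projection performed by _project_2d (and the normalization step from on_epoch_end) can be reproduced with plain NumPy. This is a sketch on a dummy channel-first activation batch; the shapes are arbitrary:

import numpy as np

# Dummy channel-first activation batch: (batch=2, channels=8, height=7, width=7)
activations = np.random.rand(2, 8, 7, 7).astype(np.float32)
n_components = 3

projections = [[] for _ in range(n_components)]
for activation in activations:
    # Treat every spatial location as a point in channel space: (H*W, C)
    flat = activation.reshape(activation.shape[0], -1).transpose()
    flat = flat - flat.mean(axis=0)
    _, _, VT = np.linalg.svd(flat, full_matrices=False)  # rows of VT are principal directions
    for i in range(n_components):
        # Project onto the i-th principal direction and restore the spatial layout
        projections[i].append((flat @ VT[i, :]).reshape(activation.shape[1:]))

components = [np.array(p, dtype=np.float32) for p in projections]  # 3 arrays of shape (2, 7, 7)

# As in on_epoch_end: clamp negatives, then min-max normalize before color-mapping
heatmap = np.maximum(components[0][0], 0)
heatmap = heatmap - heatmap.min()
heatmap = heatmap / heatmap.max()  # like the source, this assumes the map is not constant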