Utilities
The signxai.utils module provides utility functions used by both the PyTorch and TensorFlow implementations.
General Utilities
- signxai.utils.load_image(path, target_size=(224, 224), preprocess=True)
Loads the image at path, resizes it to target_size and, if preprocess is True, preprocesses it for use with models.
- signxai.utils.normalize_heatmap(heatmap, percentile=99)
Normalizes a heatmap to the specified percentile for visualization.
- Parameters:
heatmap (numpy.ndarray) – The heatmap to normalize
percentile (float, optional) – The percentile value for normalization
- Returns:
Normalized heatmap
- Return type:
numpy.ndarray
- signxai.utils.download_image(path, url=None)
Downloads an example image to path if no file exists there yet.
- signxai.utils.download_model(path, url=None)
Downloads an example model to path if no file exists there yet.
- signxai.utils.ensure_dir(file_path)
Ensures that a directory exists.
- Parameters:
file_path (str) – Path to check/create
- Returns:
None
- signxai.utils.remove_softmax(model)
Removes the softmax activation from a model.
- Parameters:
model – Model to process
- Returns:
Model with softmax removed
- Raises:
NotImplementedError – If the model's framework is not supported
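The reference above does not say exactly how normalize_heatmap scales values. A common convention for relevance maps, assumed here rather than taken from the signxai source, is to divide by the chosen percentile of the absolute values and clip to [-1, 1], so that a few extreme pixels do not wash out the rest of the map. The name `normalize_heatmap_sketch` is hypothetical:

```python
import numpy as np


def normalize_heatmap_sketch(heatmap: np.ndarray, percentile: float = 99) -> np.ndarray:
    """Hypothetical stand-in for signxai.utils.normalize_heatmap.

    Assumption: scale by the given percentile of |heatmap| and clip to [-1, 1].
    The real signxai implementation may use a different rule.
    """
    heatmap = np.asarray(heatmap, dtype=float)
    # Reference magnitude: values whose magnitude exceeds it are clipped to +/-1
    scale = np.percentile(np.abs(heatmap), percentile)
    if scale == 0:
        # Degenerate map (e.g. all zeros): nothing meaningful to scale
        return np.zeros_like(heatmap)
    return np.clip(heatmap / scale, -1.0, 1.0)


if __name__ == "__main__":
    h = np.array([[0.0, 1.0], [-2.0, 4.0]])
    print(normalize_heatmap_sketch(h, percentile=100))
```

With percentile=100 the map is simply divided by its largest absolute value (4.0 here), giving [[0, 0.25], [-0.5, 1.0]]. Lower percentiles such as the default 99 trade a little saturation at the extremes for more visible contrast elsewhere.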
Visualization Utilities
- signxai.utils.plot_relevancemap(relevance_map, ax=None, colorbar_ax=None, colorbar_kw=None, **kwargs)
Plots a relevance map as a heatmap.
- Parameters:
relevance_map (numpy.ndarray) – The relevance map to visualize
ax (matplotlib.axes.Axes, optional) – Matplotlib axes to plot on
colorbar_ax (matplotlib.axes.Axes, optional) – Axes for the colorbar
colorbar_kw (dict, optional) – Additional keyword arguments for the colorbar
**kwargs – Additional keyword arguments passed to imshow
- Returns:
Matplotlib image object
- Return type:
matplotlib.image.AxesImage
- signxai.utils.plot_comparison(original_image, explanations, method_names, figsize=(15, 6), cmap='seismic')
Plots multiple explanation methods side by side for comparison.
- Parameters:
original_image (numpy.ndarray) – The original input image
explanations (list) – List of explanations to compare
method_names (list) – Names of the methods for the explanations
figsize (tuple, optional) – Figure size
cmap (str, optional) – Colormap for the heatmaps
- Returns:
Matplotlib figure
- Return type:
matplotlib.figure.Figure
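As a rough picture of the layout plot_comparison describes, the hypothetical `plot_comparison_sketch` below places the original image first and then one heatmap per method in a single row, reusing the documented defaults figsize=(15, 6) and cmap='seismic'. The symmetric color limits (so that zero relevance falls on the white midpoint of 'seismic') and the "Original" title are assumptions, not documented signxai behavior:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend for scripts; omit in a notebook
import matplotlib.pyplot as plt
import numpy as np


def plot_comparison_sketch(original_image, explanations, method_names,
                           figsize=(15, 6), cmap="seismic"):
    """Hypothetical stand-in for signxai.utils.plot_comparison."""
    n = len(explanations) + 1
    # squeeze=False keeps a 2-D array of axes even when n == 1
    fig, axes = plt.subplots(1, n, figsize=figsize, squeeze=False)
    axes = axes[0]
    axes[0].imshow(original_image)
    axes[0].set_title("Original")
    for ax, expl, name in zip(axes[1:], explanations, method_names):
        # Symmetric limits: zero relevance maps to the colormap's midpoint
        vmax = float(np.max(np.abs(expl))) or 1.0
        ax.imshow(expl, cmap=cmap, vmin=-vmax, vmax=vmax)
        ax.set_title(name)
    for ax in axes:
        ax.axis("off")
    return fig
```

Computing the limits per heatmap keeps each panel readable on its own; sharing a single vmax across all panels would instead make magnitudes comparable between methods.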
Data Handling
- signxai.utils.batch_to_numpy(batch)
Converts a batch of tensors to a numpy array.
- Parameters:
batch (torch.Tensor or tf.Tensor or numpy.ndarray) – Batch of tensors
- Returns:
Batch as numpy array
- Return type:
numpy.ndarray
- signxai.utils.ensure_batch_dimension(x)
Ensures input has a batch dimension.
- Parameters:
x (numpy.ndarray) – Input tensor
- Returns:
Input with batch dimension
- Return type:
numpy.ndarray
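Both data-handling helpers can be approximated with plain NumPy and duck typing, which avoids importing PyTorch or TensorFlow. This is a hypothetical sketch, not the signxai code: it assumes PyTorch tensors expose `.detach()`, eager TensorFlow tensors expose `.numpy()`, and an unbatched image has exactly one dimension fewer than a batch (the `expected_ndim` parameter is invented for illustration):

```python
import numpy as np


def batch_to_numpy_sketch(batch):
    """Hypothetical stand-in for signxai.utils.batch_to_numpy."""
    if isinstance(batch, np.ndarray):
        return batch
    if hasattr(batch, "detach"):
        # torch.Tensor: drop the autograd graph and move to CPU first
        return batch.detach().cpu().numpy()
    if hasattr(batch, "numpy"):
        # Eager tf.Tensor
        return batch.numpy()
    # Fallback for lists, tuples and other array-likes
    return np.asarray(batch)


def ensure_batch_dimension_sketch(x, expected_ndim=4):
    """Hypothetical stand-in for signxai.utils.ensure_batch_dimension.

    Assumption: a batch has `expected_ndim` dimensions (e.g. N, H, W, C), so an
    input with one dimension fewer gets a leading batch axis of size 1.
    """
    x = np.asarray(x)
    if x.ndim == expected_ndim - 1:
        x = np.expand_dims(x, axis=0)
    return x


if __name__ == "__main__":
    image = np.zeros((224, 224, 3))
    print(ensure_batch_dimension_sketch(image).shape)  # (1, 224, 224, 3)
```

Checking for `.detach` before `.numpy` matters because calling `.numpy()` directly on a PyTorch tensor that requires gradients raises an error.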
Framework-Specific Utilities
TensorFlow Utilities
- signxai.utils.calculate_explanation_innvestigate(model, x, method, **kwargs)
Interface to iNNvestigate for explanation generation.
- Parameters:
model (tf.keras.Model) – TensorFlow model
x (numpy.ndarray) – Input tensor
method (str) – iNNvestigate method name
kwargs – Additional parameters for the method
- Returns:
Explanation for x produced by the selected iNNvestigate method
PyTorch Utilities
- signxai.utils.numpy_to_torch(array, requires_grad=True)
Converts a numpy array to a PyTorch tensor.
- Parameters:
array (numpy.ndarray) – Numpy array
requires_grad (bool, optional) – Whether the tensor requires gradients
- Returns:
PyTorch tensor
- Return type:
torch.Tensor
- signxai.utils.torch_to_numpy(tensor)
Converts a PyTorch tensor to a numpy array.
- Parameters:
tensor (torch.Tensor) – PyTorch tensor
- Returns:
Numpy array
- Return type:
numpy.ndarray