Unified API
The SignXAI unified API provides a framework-agnostic interface for explainable AI methods.
Overview
The unified API automatically detects whether you’re using PyTorch or TensorFlow models and routes your requests to the appropriate backend implementation.
Main Interface
Main Function
explain(model, input_data, method, **kwargs)
Get explanations for a model’s predictions using any supported XAI method.
- param model:
The neural network model (PyTorch or TensorFlow)
- param input_data:
Input data for which to generate explanations
- param method:
Name of the explanation method to use, for example ‘gradient’ or ‘lrp_epsilon’ (call list_methods() for the full list)
- param kwargs:
Additional method-specific parameters
- return:
An explanation array with the same shape as the input
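To make the routing behind explain() concrete, here is a minimal sketch of how a framework-agnostic front end could dispatch to per-framework backends. The registry `_BACKENDS`, the `register` decorator, `detect_framework`, and `ToyLinearModel` are all hypothetical stand-ins, not SignXAI internals. The toy model computes w · x, so its gradient is simply w.

```python
from typing import Any, Callable, Dict, List


class UnsupportedMethodError(Exception):
    """Hypothetical stand-in for the error named in the Error Handling section."""


# Hypothetical registry: framework name -> method name -> implementation.
_BACKENDS: Dict[str, Dict[str, Callable[..., Any]]] = {}


def register(framework: str, method: str) -> Callable:
    """Decorator that records an implementation for a (framework, method) pair."""
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        _BACKENDS.setdefault(framework, {})[method] = fn
        return fn
    return decorator


class ToyLinearModel:
    """Stand-in model computing w . x; its gradient with respect to x is w."""

    def __init__(self, weights: List[float], framework: str) -> None:
        self.weights = list(weights)
        self.framework = framework


def detect_framework(model: Any) -> str:
    """Placeholder detector; a real one would inspect the model's type."""
    return model.framework


@register("pytorch", "gradient")
@register("tensorflow", "gradient")
def _gradient(model: ToyLinearModel, input_data: List[float], **kwargs: Any) -> List[float]:
    # One gradient value per input feature, so the output matches the input shape.
    return list(model.weights)


def explain(model: Any, input_data: Any, method: str, **kwargs: Any) -> Any:
    """Route the request to the implementation registered for the model's framework."""
    framework = detect_framework(model)
    implementation = _BACKENDS.get(framework, {}).get(method)
    if implementation is None:
        raise UnsupportedMethodError(f"{method!r} is not available for {framework}")
    return implementation(model, input_data, **kwargs)


model = ToyLinearModel([0.5, -2.0], framework="tensorflow")
print(explain(model, [1.0, 3.0], "gradient"))  # [0.5, -2.0]
```

Keeping one registry keyed by framework means adding a backend never touches the caller-facing signature, which is the property the unified API advertises.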
Framework Detection
get_framework(model)
Automatically detect the framework of a given model.
- param model:
The model to check
- return:
‘pytorch’ or ‘tensorflow’
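One plausible way to implement such detection, sketched here under the assumption that inspecting module names is enough, is to walk the model class's method resolution order and look at the top-level package each class comes from. This is a heuristic, not necessarily what SignXAI's get_framework does, and `FrameworkDetectionError` is a local stand-in for the error of that name.

```python
from typing import Any


class FrameworkDetectionError(Exception):
    """Hypothetical stand-in for the error named in the Error Handling section."""


def get_framework(model: Any) -> str:
    """Guess the framework from the modules of the model's class hierarchy.

    Walking the MRO matters: a user-defined nn.Module subclass lives in the
    user's own module, but torch.nn.Module itself lives under 'torch'.
    """
    for cls in type(model).__mro__:
        root = cls.__module__.split(".")[0]
        if root == "torch":
            return "pytorch"
        if root in ("tensorflow", "keras", "tf_keras"):
            return "tensorflow"
    raise FrameworkDetectionError(
        f"cannot determine the framework of {type(model).__name__}"
    )
```

With PyTorch installed, get_framework(torch.nn.Linear(2, 2)) returns ‘pytorch’. Note that Keras 3 models can also run on a PyTorch backend, which a module-name check alone cannot distinguish.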
Common Parameters
All methods support these common parameters:
neuron_selection: Target neuron/class for explanation
batchsize: Batch size for processing (PyTorch only)
postprocess: Post-processing function to apply
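The following sketch shows what neuron_selection and batchsize typically control. It assumes, as is common in XAI toolkits but not confirmed on this page, that omitting neuron_selection explains the highest-scoring output. Both helper names are hypothetical.

```python
from typing import Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def select_neuron(scores: Sequence[float], neuron_selection: Optional[int] = None) -> int:
    """Return the index of the output neuron (class) to explain.

    Assumption: with no explicit selection, the highest-scoring output is used.
    """
    if neuron_selection is None:
        return max(range(len(scores)), key=lambda i: scores[i])
    if not 0 <= neuron_selection < len(scores):
        raise IndexError(f"neuron_selection {neuron_selection} is out of range")
    return neuron_selection


def iter_batches(items: Sequence[T], batchsize: int) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most `batchsize` items."""
    if batchsize < 1:
        raise ValueError("batchsize must be positive")
    for start in range(0, len(items), batchsize):
        yield list(items[start:start + batchsize])


print(select_neuron([0.1, 2.5, -1.0]))               # 1 (argmax)
print(select_neuron([0.1, 2.5, -1.0], 2))            # 2 (explicit)
print(list(iter_batches([10, 11, 12, 13, 14], 2)))   # [[10, 11], [12, 13], [14]]
```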
Method-Specific Parameters
Individual methods accept additional, method-specific parameters:
Gradient-based methods:
postprocess: ‘abs’, ‘square’, or custom function
mu: Threshold μ used by the SIGN variants (e.g. gradient_x_sign_mu)
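The sketch below illustrates both options for gradient-based methods in plain Python. For the SIGN variant it assumes the thresholded sign sign_μ(x) = +1 when x >= μ and -1 otherwise, which is our reading of the SIGN method; check the SignXAI source before relying on the exact form. The function names are hypothetical.

```python
from typing import Callable, List, Optional, Sequence, Union

Postprocess = Optional[Union[str, Callable[[List[float]], List[float]]]]


def apply_postprocess(relevance: Sequence[float], postprocess: Postprocess = None) -> List[float]:
    """Apply 'abs', 'square', a custom callable, or nothing (None)."""
    if postprocess is None:
        return list(relevance)
    if postprocess == "abs":
        return [abs(r) for r in relevance]
    if postprocess == "square":
        return [r * r for r in relevance]
    if callable(postprocess):
        return list(postprocess(list(relevance)))
    raise ValueError(f"unknown postprocess option: {postprocess!r}")


def sign_mu(value: float, mu: float = 0.0) -> float:
    """Thresholded sign: +1.0 if value >= mu, otherwise -1.0 (assumed SIGN form)."""
    return 1.0 if value >= mu else -1.0


def gradient_x_sign_mu(gradient: Sequence[float], input_data: Sequence[float],
                       mu: float = 0.0) -> List[float]:
    """Multiply each gradient entry by the thresholded sign of the matching input."""
    return [g * sign_mu(x, mu) for g, x in zip(gradient, input_data)]


grad = [0.2, -0.4, 1.0]
x = [0.3, 0.7, 0.5]
print(gradient_x_sign_mu(grad, x, mu=0.5))       # [-0.2, -0.4, 1.0]
print(apply_postprocess([-2.0, 3.0], "square"))  # [4.0, 9.0]
```

Inputs below μ flip the sign of their gradient, so raising μ changes which features count as "present".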
LRP methods:
epsilon: Stabilization parameter for LRP-ε
alpha/beta: Parameters α and β of the LRP-αβ rule
layer_rule: Custom rules for specific layers
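To show what epsilon does, here is the textbook LRP-ε rule for a single dense layer: with z_j = Σ_i a_i w_ij + b_j, each input receives R_i = Σ_j a_i w_ij / (z_j + ε·sign(z_j)) · R_j. This plain-Python sketch is a generic formulation of the rule, not SignXAI's layer-wise implementation, and `lrp_epsilon_dense` is a hypothetical name.

```python
from typing import List, Sequence


def lrp_epsilon_dense(
    a: Sequence[float],
    weights: Sequence[Sequence[float]],
    bias: Sequence[float],
    relevance_out: Sequence[float],
    epsilon: float = 0.01,
) -> List[float]:
    """Propagate relevance through one dense layer with the LRP-epsilon rule.

    a:             input activations, length n_in
    weights:       n_in x n_out matrix; weights[i][j] connects input i to output j
    bias:          length n_out
    relevance_out: relevance of each output neuron, length n_out
    """
    n_in, n_out = len(a), len(bias)
    # Pre-activations z_j = sum_i a_i * w_ij + b_j.
    z = [sum(a[i] * weights[i][j] for i in range(n_in)) + bias[j] for j in range(n_out)]
    # Push each denominator away from zero by epsilon, keeping its sign.
    s = [relevance_out[j] / (z[j] + (epsilon if z[j] >= 0 else -epsilon)) for j in range(n_out)]
    # R_i = a_i * sum_j w_ij * s_j
    return [a[i] * sum(weights[i][j] * s[j] for j in range(n_out)) for i in range(n_in)]


a = [1.0, 2.0]
W = [[1.0, -1.0], [0.5, 2.0]]
b = [0.0, 0.0]
print(lrp_epsilon_dense(a, W, b, [1.0, 0.0], epsilon=0.0))  # [0.5, 0.5]: relevance conserved
print(lrp_epsilon_dense(a, W, b, [1.0, 0.0], epsilon=1.0))  # about 1/3 each: epsilon absorbs some
```

With ε = 0 and zero bias the total relevance is conserved; a larger ε absorbs relevance from weakly activated neurons, which yields sparser, less noisy explanations. In practice ε must stay positive to avoid division by zero when some z_j is 0.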
Integrated Gradients:
reference: Reference/baseline input
steps: Number of integration steps
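The role of reference and steps can be seen in a midpoint Riemann-sum sketch of Integrated Gradients: IG_i = (x_i - x'_i) · ∫₀¹ ∂F/∂x_i(x' + t(x - x')) dt, where x' is the reference. The all-zero default baseline and the function name are assumptions for illustration; the toy model F(x) = Σ x_i² has the analytic gradient 2x.

```python
from typing import Callable, List, Optional, Sequence

GradFn = Callable[[List[float]], List[float]]


def integrated_gradients(
    grad_fn: GradFn,
    x: Sequence[float],
    reference: Optional[Sequence[float]] = None,
    steps: int = 50,
) -> List[float]:
    """Approximate IG with a midpoint Riemann sum over `steps` sub-intervals."""
    if reference is None:
        reference = [0.0] * len(x)  # assumption: all-zero baseline by default
    total = [0.0] * len(x)
    for k in range(steps):
        t = (k + 0.5) / steps  # midpoint of the k-th sub-interval
        point = [r + t * (xi - r) for xi, r in zip(x, reference)]
        for i, g in enumerate(grad_fn(point)):
            total[i] += g
    return [(xi - r) * g_sum / steps for xi, r, g_sum in zip(x, reference, total)]


# Toy model F(x) = sum(x_i ** 2), whose gradient is 2 * x.
def grad_square(p: List[float]) -> List[float]:
    return [2.0 * v for v in p]


attr = integrated_gradients(grad_square, [1.0, -2.0, 3.0], steps=20)
print([round(v, 6) for v in attr])  # [1.0, 4.0, 9.0]
```

The attributions sum to F(x) - F(x') = 14 (the completeness property); for real networks the gradient is not linear along the path, so more steps give a closer approximation.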
SmoothGrad:
noise: Noise level used when sampling perturbed copies of the input
samples: Number of noisy samples whose gradients are averaged
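As a sketch of how these two parameters interact, SmoothGrad averages the gradient over samples Gaussian-perturbed copies of the input. Interpreting noise as a standard deviation relative to the input's value range is an assumption (SignXAI may scale it differently), and the seed argument is added here only for reproducibility.

```python
import random
from typing import Callable, List, Sequence

GradFn = Callable[[List[float]], List[float]]


def smoothgrad(
    grad_fn: GradFn,
    x: Sequence[float],
    noise: float = 0.1,
    samples: int = 25,
    seed: int = 0,
) -> List[float]:
    """Average gradients over `samples` Gaussian-perturbed copies of x."""
    rng = random.Random(seed)  # seeded so results are reproducible
    # Assumption: `noise` is the standard deviation relative to the input's value range.
    sigma = noise * (max(x) - min(x))
    total = [0.0] * len(x)
    for _ in range(samples):
        noisy = [xi + rng.gauss(0.0, sigma) for xi in x]
        for i, g in enumerate(grad_fn(noisy)):
            total[i] += g
    return [g_sum / samples for g_sum in total]


# For F(x) = sum(x_i ** 2) the result approaches 2 * x as `samples` grows.
print(smoothgrad(lambda p: [2.0 * v for v in p], [1.0, -2.0, 3.0], samples=500))
```

More samples reduce the variance of the average at a proportional cost in forward and backward passes.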
Grad-CAM:
last_conv: Name or index of the last convolutional layer
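The layer chosen by last_conv supplies the activations that Grad-CAM weights. The standard computation is sketched below in plain Python: each channel is weighted by its spatially averaged gradient, the channels are summed, and a ReLU keeps positive evidence. The nested-list layout and function name are assumptions; upsampling the map to the input resolution is omitted.

```python
from typing import List, Sequence

Map = Sequence[Sequence[float]]  # one [height][width] feature map


def grad_cam(activations: Sequence[Map], gradients: Sequence[Map]) -> List[List[float]]:
    """Grad-CAM heatmap from last-conv-layer activations and the target's gradients.

    Both arguments are [channels][height][width]. Each channel is weighted by its
    spatially averaged gradient; the weighted sum is passed through a ReLU.
    """
    height, width = len(activations[0]), len(activations[0][0])
    heat = [[0.0] * width for _ in range(height)]
    for act, grad in zip(activations, gradients):
        weight = sum(sum(row) for row in grad) / (height * width)
        for r in range(height):
            for c in range(width):
                heat[r][c] += weight * act[r][c]
    return [[max(v, 0.0) for v in row] for row in heat]


acts = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 2.0], [2.0, 0.0]]]
grads = [[[1.0, 1.0], [1.0, 1.0]], [[-0.5, -0.5], [-0.5, -0.5]]]
print(grad_cam(acts, grads))  # [[1.0, 0.0], [0.0, 1.0]]
```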
Usage Examples
Basic usage with automatic framework detection:
```python
from signxai import explain, list_methods

# List all available methods
available_methods = list_methods()
print(f"Available methods: {available_methods}")

# tf_model, torch_model and input_data are assumed to be defined already.

# TensorFlow model
explanation_tf = explain(
    tf_model,
    input_data,
    method='gradient'
)

# PyTorch model
explanation_pt = explain(
    torch_model,
    input_data,
    method='gradient'
)
```
Using method-specific parameters:
```python
from signxai import explain

# LRP with epsilon
explanation = explain(
    model,
    input_data,
    method='lrp_epsilon',
    epsilon=0.01
)

# Input × Gradient method
explanation = explain(
    model,
    input_data,
    method='input_t_gradient'
)

# SIGN variant with custom mu
explanation = explain(
    model,
    input_data,
    method='gradient_x_sign_mu',
    mu=0.5
)
```
Error Handling
The unified API provides consistent error messages across frameworks:
UnsupportedMethodError: Method not available for the detected framework
InvalidParameterError: Invalid parameter for the chosen method
FrameworkDetectionError: Unable to determine model framework
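Because the exception types are shared across frameworks, callers can write one fallback path for both backends. The sketch below tries a list of methods in order and moves on only when a method is unsupported. The local `UnsupportedMethodError` class is a stand-in (this page does not show where the real one is imported from), and `explain_with_fallback` is a hypothetical helper.

```python
from typing import Any, Callable, List, Sequence, Tuple


class UnsupportedMethodError(Exception):
    """Stand-in; import the real class from wherever SignXAI defines it."""


def explain_with_fallback(
    explain_fn: Callable[..., Any],
    model: Any,
    input_data: Any,
    methods: Sequence[str],
    **kwargs: Any,
) -> Tuple[str, Any]:
    """Try `methods` in order and return (method, explanation) for the first that works.

    Only UnsupportedMethodError triggers a fallback; other errors, such as an
    invalid parameter or a failed framework detection, still reach the caller.
    """
    skipped: List[str] = []
    for method in methods:
        try:
            return method, explain_fn(model, input_data, method=method, **kwargs)
        except UnsupportedMethodError:
            skipped.append(method)
    raise UnsupportedMethodError(f"none of {skipped} is supported for this model")
```

A call such as explain_with_fallback(explain, model, input_data, ['lrp_epsilon', 'gradient']) would then return ('gradient', ...) for a model whose framework lacks LRP-ε. Letting InvalidParameterError propagate is deliberate: a bad parameter is a caller bug that a fallback would only hide.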