Source code for signxai.tf_signxai

import numpy as np

# Prefer the shared validator from the sibling ``common`` package. If that
# import fails, install a no-op stand-in so this module can still be imported.
try:
    from ..common.validation import validate_input
except ImportError:
    def validate_input(*args, **kwargs):  # pragma: no cover
        """Fallback validator: accept any arguments, check nothing, return None."""
        return None

# Method names this package declares as supported. The dispatcher consults
# this list when deciding which error to raise for an unhandled method.
SUPPORTED_METHODS = [
    "gradient",
    "smoothgrad",
    "integrated_gradients",
    "guided_backprop",
    # Common LRP rules
    "lrp.epsilon",
    "lrp.alpha_beta",
    "lrp.alpha_2_beta_1",
    "lrp.z",
    "lrp.w_square",
    "lrp.flat",
    "deep_taylor",
    "input_t_gradient",
    "deconvnet",
    "occlusion",
    "grad_cam",
]


# Methods served by a dedicated function of the same name in
# ``.methods.wrappers``. Every one of them is called with the same keywords
# (x, model_no_softmax, neuron_selection, **kwargs).
_TF_WRAPPER_METHODS = (
    "gradient",
    "smoothgrad",
    "integrated_gradients",
    "guided_backprop",
    "grad_cam",
    "occlusion",
)

# Methods handed to the generic iNNvestigate handler when no dedicated wrapper
# ran. Any name starting with "lrp" is routed there as well.
_INNVESTIGATE_METHODS = (
    "gradient",
    "smoothgrad",
    "integrated_gradients",
    "guided_backprop",
    "lrp.epsilon",
    "lrp.alpha_beta",
    "lrp.alpha_2_beta_1",
    "lrp.z",
    "lrp.w_square",
    "lrp.flat",
    "deep_taylor",
    "input_t_gradient",
    "deconvnet",
)


def calculate_relevancemap(method: str,
                           x: np.ndarray,
                           model,
                           neuron_selection: int,
                           **kwargs):
    """Calculate the relevance map for an input using a TensorFlow-based method.

    Dispatch order:

    1. If ``.methods.wrappers`` imports and ``method`` has a dedicated wrapper
       there, that wrapper is called with ``model`` passed as
       ``model_no_softmax``.
    2. Otherwise, if ``method`` is a known iNNvestigate method or starts with
       ``"lrp"``, ``calculate_explanation_innvestigate`` from ``..utils.utils``
       is used.
    3. Otherwise a ``ValueError`` is raised.

    Args:
        method (str): The XAI method to use (e.g., "gradient", "lrp.epsilon").
        x (np.ndarray): The input data (e.g., image) as a NumPy array.
        model: The TensorFlow/Keras model (must be a model without softmax
            for many methods).
        neuron_selection (int): The index of the output neuron for which to
            generate the explanation.
        **kwargs: Additional arguments forwarded unchanged to the selected
            handler.

    Returns:
        np.ndarray: The calculated relevance map.

    Raises:
        ValueError: If ``method`` is not a string, ``x`` is not a NumPy array,
            the method is unsupported or has no usable handler, the generic
            handler fails, or the handler returns ``None``.
        ImportError: If ``calculate_explanation_innvestigate`` cannot be
            imported, or if the generic handler itself raises an
            ``ImportError`` (propagated unchanged).
    """
    # NOTE: validate_input(x, model) is intentionally not called here; the
    # module may only have the no-op fallback available.
    if not isinstance(method, str):  # pragma: no cover
        raise ValueError("Method argument must be a string.")
    if not isinstance(x, np.ndarray):  # pragma: no cover
        raise ValueError("Input x must be a NumPy array.")

    # Best effort: missing wrappers only disable the dedicated handlers.
    try:
        from .methods import wrappers as tf_method_wrappers
    except ImportError:  # pragma: no cover
        tf_method_wrappers = None
        print(
            "Warning: Could not import signxai.tf_signxai.methods.wrappers. TF-specific wrapped methods may not be available.")

    if tf_method_wrappers and method in _TF_WRAPPER_METHODS:
        # Dedicated wrappers take priority over the generic handler.
        wrapper = getattr(tf_method_wrappers, method)
        relevancemap = wrapper(x=x, model_no_softmax=model,
                               neuron_selection=neuron_selection, **kwargs)
    elif method in _INNVESTIGATE_METHODS or method.startswith('lrp'):
        # Import separately from the call so that only a genuinely missing
        # utility produces the "Failed to import" message.
        try:
            from ..utils.utils import calculate_explanation_innvestigate
        except ImportError as err:  # pragma: no cover
            raise ImportError(
                "Failed to import 'calculate_explanation_innvestigate' from signxai.utils.utils. Cannot proceed with iNNvestigate-based method.") from err

        try:
            relevancemap = calculate_explanation_innvestigate(
                method=method, x=x, model=model,
                neuron_selection=neuron_selection, **kwargs)
        except ImportError:  # pragma: no cover
            # Keep the exception type callers already see, but let the real
            # missing-dependency message through instead of masking it.
            raise
        except Exception as err:  # pragma: no cover
            raise ValueError(
                f"Method '{method}' failed in generic iNNvestigate handler. Error: {err}") from err
    elif method not in SUPPORTED_METHODS:  # pragma: no cover
        raise ValueError(
            f"Unsupported method: {method}. Supported methods are: {SUPPORTED_METHODS} or check iNNvestigate specific methods if applicable.")
    else:  # pragma: no cover
        # Declared as supported but nothing handled it, e.g. the wrappers
        # module failed to import and the method has no generic fallback.
        raise ValueError(
            f"Method '{method}' is listed as supported but could not be processed. Check tf_method_wrappers import and dispatch logic for this method.")

    if relevancemap is None:  # pragma: no cover
        raise ValueError(f"Method '{method}' could not be processed and did not produce a result. Check specific handlers.")

    return relevancemap