Source code for signxai.torch_signxai.methods.signed

"""Implementation of SIGN thresholding methods for PyTorch."""
import torch
import numpy as np


def calculate_sign_mu(relevance_map, mu=0.0, vlow=-1, vhigh=1):
    """Threshold a relevance map into a two-valued (sign-like) map.

    Every element ``>= mu`` becomes ``vhigh``; every other element becomes
    ``vlow``. The thresholding rule is intended to mirror the TensorFlow
    implementation of the same method.

    Args:
        relevance_map: A ``torch.Tensor``, a numpy array, or any array-like
            accepted by ``np.asanyarray`` (e.g. a nested list of numbers).
        mu: Threshold; elements greater than or equal to it get ``vhigh``
            (default 0.0).
        vlow: Value assigned to elements below the threshold (default -1).
        vhigh: Value assigned to elements at or above the threshold
            (default 1).

    Returns:
        A new tensor (for tensor input) or numpy array (otherwise) with the
        same shape and dtype as the input, holding only ``vlow`` / ``vhigh``.
        The input itself is not modified. Elements for which the comparison
        ``>= mu`` is false (including NaN) receive ``vlow``.
    """
    if isinstance(relevance_map, torch.Tensor):
        # full_like keeps the input's dtype and device; start from "all low"
        # and overwrite the entries that reach the threshold.
        sign_map = torch.full_like(relevance_map, float(vlow))
        sign_map[relevance_map >= mu] = float(vhigh)
        return sign_map

    # asanyarray is a no-op for ndarrays (and their subclasses) but lets plain
    # Python sequences through; a bare list would otherwise fail on `>= mu`.
    values = np.asanyarray(relevance_map)
    sign_map = np.full_like(values, vlow)
    sign_map[values >= mu] = vhigh
    return sign_map