Source code for signxai.torch_signxai.utils

# signxai/torch_signxai/utils.py

import torch
import torch.nn as nn
import json
import os
from typing import Union, Type, List, Optional, Dict, Any, Tuple


# 1. Corrected top-level remove_softmax
def remove_softmax(model: nn.Module) -> nn.Module:
    """Strip a trailing ``Softmax``/``LogSoftmax`` layer from ``model`` in place.

    Two layouts are handled:

    * a head stored under ``model.classifier``, ``model.fc`` or
      ``model.output_layer`` that is an ``nn.Sequential`` ending in a softmax
      layer -- the attribute is replaced by a new ``nn.Sequential`` holding
      every entry except the last one;
    * ``model`` itself being an ``nn.Sequential`` ending in a softmax layer --
      the last entry is removed from the container.

    A model that matches neither layout (for example one that already
    outputs logits) is returned untouched.

    Args:
        model: The network to modify.

    Returns:
        The same ``model`` object that was passed in.
    """
    softmax_types = (nn.Softmax, nn.LogSoftmax)

    for attr_name in ("classifier", "fc", "output_layer"):
        head = getattr(model, attr_name, None)
        if (
            isinstance(head, nn.Sequential)
            and len(head) > 0
            and isinstance(head[-1], softmax_types)
        ):
            # Iterate the container itself rather than ``children()``:
            # ``children()`` de-duplicates shared module instances, which
            # would silently drop e.g. a ReLU object that is reused at
            # several positions of the head.
            setattr(model, attr_name, nn.Sequential(*list(head)[:-1]))
            return model

    if (
        isinstance(model, nn.Sequential)
        and len(model) > 0
        and isinstance(model[-1], softmax_types)
    ):
        # Remove the last entry by its registered name so the names of the
        # remaining layers (and therefore their state_dict keys) are kept.
        last_name = list(model._modules)[-1]
        delattr(model, last_name)

    return model
# 2. NoSoftmaxWrapper class
class NoSoftmaxWrapper(nn.Module):
    """Forwarding wrapper intended for models whose softmax is not wanted.

    The wrapped model is switched to evaluation mode when the wrapper is
    built. Calling the wrapper runs the wrapped model and, when that model
    returns a tuple, hands back only its first element.

    NOTE(review): no layer is removed or bypassed here; the wrapper relies
    on the wrapped model already producing the desired output -- confirm
    against callers.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Put the wrapped network into eval mode.
        self.model.eval()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``x`` and return its (first) output."""
        result = self.model(x)
        return result[0] if isinstance(result, tuple) else result
# 3. decode_predictions and its private helpers
def _default_class_index_path() -> Optional[str]:
    """Best-effort lookup of the default ``imagenet_class_index.json``.

    Tries ``../../data/`` relative to this file first, then the ``data``
    directory of the installed ``signxai`` package. Prints a warning and
    returns ``None`` when neither location holds the file.
    """
    try:
        here = os.path.dirname(os.path.abspath(__file__))
        repo_path = os.path.join(here, "..", "..", "data", "imagenet_class_index.json")
        if os.path.exists(repo_path):
            return repo_path

        package_path = None
        try:
            import signxai
            package_path = os.path.join(
                os.path.dirname(signxai.__file__), "data", "imagenet_class_index.json"
            )
        except ImportError:
            pass
        if package_path and os.path.exists(package_path):
            return package_path

        raise FileNotFoundError(
            f"Default class index not found. Checked: {repo_path}"
            + (" and package path." if package_path else ".")
        )
    except Exception as e:
        # Deliberately broad: label lookup is cosmetic, so any failure while
        # probing for the file must degrade to generic labels, not crash.
        print(f"Warning: Could not determine default ImageNet class index path. {e}")
        return None


def _load_class_index(class_list_path: Optional[str]) -> Optional[Dict[str, Tuple[str, str]]]:
    """Read the class-index JSON; return ``None`` if it is missing or unreadable."""
    if not class_list_path or not os.path.exists(class_list_path):
        return None
    try:
        with open(class_list_path, encoding="utf-8") as f:
            raw = json.load(f)
        return {idx_str: (entry[0], entry[1]) for idx_str, entry in raw.items()}
    except (OSError, ValueError, TypeError, KeyError, IndexError, AttributeError) as e:
        print(f"Error loading ImageNet class index from {class_list_path}: {e}.")
        return None


def decode_predictions(preds: torch.Tensor,
                       top: int = 5,
                       class_list_path: Optional[str] = None
                       ) -> List[List[Tuple[str, str, float]]]:
    """Decode the predictions of an ImageNet model into labelled top-k lists.

    Args:
        preds: Scores for one sample (1-D) or a batch (last dim = classes).
            Rows that are not already a probability distribution (values in
            [0, 1] summing to 1) are treated as logits and passed through
            softmax.
        top: Number of highest-scoring classes to report per sample. Values
            larger than the number of classes are clamped to that number.
        class_list_path: Path to a JSON file mapping class index strings to
            ``[id, name]`` pairs. When ``None``, a default location is
            searched; if no file can be loaded, generic ``idx_i``/``Class_i``
            labels are used instead.

    Returns:
        One list per sample, each holding ``(id, class_name, probability)``
        tuples ordered from most to least probable.
    """
    if class_list_path is None:
        class_list_path = _default_class_index_path()

    labels = _load_class_index(class_list_path)
    if labels is None:
        print("Warning: ImageNet class index not loaded. Predictions will be generic.")
        labels = {str(i): (f"idx_{i}", f"Class_{i}") for i in range(preds.shape[-1])}

    if preds.ndim == 1:
        preds = preds.unsqueeze(0)

    row_sums = preds.sum(dim=-1)
    already_probabilities = bool(
        (preds.min() >= 0.0)
        and (preds.max() <= 1.0)
        and torch.all(torch.isclose(row_sums, torch.ones_like(row_sums)))
    )
    if not already_probabilities:
        preds = torch.softmax(preds, dim=-1)

    # torch.topk raises when k exceeds the size of the dimension, so never
    # ask for more entries than there are classes.
    k = min(top, preds.shape[-1])
    top_probs, top_idxs = torch.topk(preds, k, dim=-1)

    final_output: List[List[Tuple[str, str, float]]] = []
    for idx_row, prob_row in zip(top_idxs.tolist(), top_probs.tolist()):
        sample_predictions: List[Tuple[str, str, float]] = []
        for class_idx, probability in zip(idx_row, prob_row):
            class_idx_str = str(class_idx)
            imagenet_id, class_name = labels.get(
                class_idx_str, (class_idx_str, "UnknownClassName")
            )
            sample_predictions.append((imagenet_id, class_name, probability))
        final_output.append(sample_predictions)
    return final_output