-
Notifications
You must be signed in to change notification settings - Fork 0
Add GradCAM attention heatmap visualization to Gradio interface using pytorch-grad-cam #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
da65c71
9fb382a
403da8d
d3c47c8
616a9cb
f344728
b4cda7f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -182,4 +182,7 @@ EDA/ | |
| model_outputs/ | ||
|
|
||
| # VS CODE | ||
| .vscode/ | ||
| .vscode/ | ||
|
|
||
| # Test images | ||
| test_images/ | ||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -4,14 +4,19 @@ | |||||
|
|
||||||
| import os | ||||||
| import numpy as np | ||||||
| import cv2 | ||||||
|
|
||||||
| import torch | ||||||
| import torch.nn as nn | ||||||
| import torch.nn.functional as F | ||||||
|
|
||||||
| import gradio as gr | ||||||
| from PIL import Image | ||||||
| import logging | ||||||
|
|
||||||
| from pytorch_grad_cam import GradCAM | ||||||
| from pytorch_grad_cam.utils.image import show_cam_on_image | ||||||
|
|
||||||
| from main import get_transform | ||||||
|
|
||||||
| logging.basicConfig(level=logging.INFO) | ||||||
|
|
@@ -80,49 +85,161 @@ def load_model(model_type: str = "efficientvit") -> nn.Module: | |||||
| return model | ||||||
|
|
||||||
|
|
||||||
| def predict_image(image: np.ndarray, model_type: str) -> dict: | ||||||
def get_target_layers(model, model_type):
    """
    Get the target layers for GradCAM based on model type.

    Known model types are mapped to the attribute that holds their final
    feature container, and the last element of that container is returned.
    Unknown model types probe ``features``, ``stages`` and ``blocks`` in that
    order. If that lookup fails for any reason (missing attribute, empty or
    non-indexable container, nothing to probe), a warning is logged and the
    last ``nn.Conv2d`` found in ``model.modules()`` is used instead.

    Args:
        model: The model
        model_type: Type of model

    Returns:
        target_layers: List of layers to use for GradCAM

    Raises:
        ValueError: If neither the lookup nor the Conv2d fallback finds a layer.
    """
    # Attribute holding the last feature container for each known model type.
    container_by_type = {
        "mobilenetv4": "features",
        "levit": "blocks",
        "efficientvit": "stages",
        "gernet": "stages",
        "regnetx": "trunk",
    }

    try:
        container_name = container_by_type.get(model_type)
        if container_name is None:
            # Default: try to get the last feature layer
            container_name = next(
                (name for name in ("features", "stages", "blocks") if hasattr(model, name)),
                None,
            )
        if container_name is None:
            raise ValueError(f"Cannot determine target layer for model type: {model_type}")
        return [getattr(model, container_name)[-1]]
    except (AttributeError, IndexError, KeyError, TypeError, ValueError) as e:
        # Only lookup failures trigger the fallback; unrelated errors propagate.
        logging.warning(f"Error getting target layer: {e}. Using fallback.")
        # Fallback: try to get any reasonable last conv layer
        for module in reversed(list(model.modules())):
            if isinstance(module, nn.Conv2d):
                return [module]
        raise ValueError("Could not find suitable target layer for GradCAM") from e
|
|
||||||
|
|
||||||
def apply_heatmap_on_image(img, cam, alpha=0.4):
    """
    Apply CAM heatmap overlay on the original image.

    The image is coerced to three channels before blending, so grayscale,
    single-channel and RGBA inputs are accepted as well as RGB.

    Args:
        img: Original image (PIL Image or numpy array); pixel values are
            divided by 255, so a 0-255 range is expected
        cam: Class activation map (grayscale, values 0-1)
        alpha: Overlay transparency (not used with show_cam_on_image, kept for compatibility)

    Returns:
        Heatmap overlay image as numpy array
    """
    # Convert PIL to numpy if needed; convert("RGB") normalizes L/RGBA/P modes
    if isinstance(img, Image.Image):
        img = np.array(img.convert("RGB"))
    else:
        img = np.asarray(img)

    # The overlay is an RGB heatmap, so the image must be H x W x 3 to blend with it
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.ndim == 3 and img.shape[2] == 1:
        img = np.repeat(img, 3, axis=2)
    elif img.ndim == 3 and img.shape[2] == 4:
        # Drop the alpha channel
        img = img[..., :3]

    # Normalize image to 0-1 range for show_cam_on_image
    img_float = img.astype(np.float32) / 255.0

    # Resize CAM to match image size (cv2.resize takes (width, height))
    h, w = img.shape[:2]
    cam_resized = cv2.resize(cam, (w, h))

    # Use pytorch_grad_cam utility to overlay
    # This function expects img in 0-1 range and cam in 0-1 range
    overlay = show_cam_on_image(img_float, cam_resized, use_rgb=True)

    return overlay
|
|
||||||
|
|
||||||
| def predict_image(image: np.ndarray, model_type: str) -> tuple[dict, np.ndarray]: | ||||||
| """ | ||||||
| Predict eye disease from an uploaded image and generate attention heatmap. | ||||||
|
|
||||||
| Args: | ||||||
| image: Input image from Gradio | ||||||
| model_path: Path to the model state dict | ||||||
| model_type: Type of model architecture | ||||||
|
|
||||||
| Returns: | ||||||
| Dictionary of class probabilities | ||||||
| Tuple of (Dictionary of class probabilities, Heatmap overlay image) | ||||||
| """ | ||||||
| try: | ||||||
|
|
||||||
| logging.info("Starting prediction...") | ||||||
|
|
||||||
| # Handle None image | ||||||
| if image is None: | ||||||
| logging.warning("No image provided.") | ||||||
| return {cls: 0.0 for cls in CLASSES}, None | ||||||
|
|
||||||
| # Load model | ||||||
| model = load_model(model_type) | ||||||
| model.to(device) | ||||||
|
|
||||||
| # Preprocess image | ||||||
| logging.info("Preprocessing image...") | ||||||
| if image is None: | ||||||
| logging.warning("No image provided.") | ||||||
| return {cls: 0.0 for cls in CLASSES} | ||||||
| transform = get_transform() | ||||||
| if image is None: | ||||||
| return {cls: 0.0 for cls in CLASSES} | ||||||
|
|
||||||
| # Convert numpy array to PIL Image | ||||||
| img = Image.fromarray(image).convert("RGB") | ||||||
| img_tensor = transform(img).unsqueeze(0).to(device) | ||||||
| # Convert numpy array to PIL Image and keep original for heatmap | ||||||
| img_pil = Image.fromarray(image).convert("RGB") | ||||||
| img_tensor = transform(img_pil).unsqueeze(0).to(device) | ||||||
| logging.info("Image preprocessed successfully.") | ||||||
|
|
||||||
| # Make prediction | ||||||
| with torch.no_grad(): | ||||||
| outputs = model(img_tensor) | ||||||
| probabilities = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().numpy() | ||||||
|
|
||||||
| # Return probabilities for each class | ||||||
| return {cls: float(prob) for cls, prob in zip(CLASSES, probabilities)} | ||||||
| # Get target layers for GradCAM | ||||||
| try: | ||||||
| target_layers = get_target_layers(model, model_type) | ||||||
| logging.info(f"Using target layers: {target_layers}") | ||||||
|
|
||||||
| # Initialize GradCAM from pytorch_grad_cam library | ||||||
| cam_extractor = GradCAM(model=model, target_layers=target_layers) | ||||||
|
|
||||||
| # Generate CAM - the library handles forward and backward passes | ||||||
| grayscale_cam = cam_extractor(input_tensor=img_tensor, targets=None) | ||||||
|
|
||||||
| # Get the CAM for the first image in batch | ||||||
| cam = grayscale_cam[0, :] | ||||||
|
|
||||||
| # Get model prediction | ||||||
| with torch.no_grad(): | ||||||
| outputs = model(img_tensor) | ||||||
|
|
||||||
| # Generate heatmap overlay | ||||||
| heatmap_overlay = apply_heatmap_on_image(img_pil, cam) | ||||||
|
|
||||||
| # Clean up | ||||||
| del cam_extractor | ||||||
|
||||||
| del cam_extractor | |
Copilot
AI
Oct 23, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The traceback module is imported inside the exception handler. This import should be moved to the top of the file with other imports for better code organization and to avoid repeated imports during runtime.
Copilot
AI
Oct 23, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The traceback module is imported inside the exception handler. This import should be moved to the top of the file with other imports for better code organization and to avoid repeated imports during runtime.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Moving the model to device on every prediction call is inefficient. The model should be moved to device once during loading in the
`load_model` function, not repeatedly in `predict_image`.