Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@ name: Check file size
on:
pull_request:
branches: [main]

# to run this workflow manually from the Actions tab
workflow_dispatch:

permissions:
contents: read
pull-requests: write

jobs:
sync-to-hub:
runs-on: ubuntu-latest
steps:
- name: Check large files
uses: ActionsDesk/lfs-warning@v2.0
with:
filesizelimit: 10485760 # this is 10MB
filesizelimit: 10485760 # this is 10MB
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,7 @@ EDA/
model_outputs/

# VS CODE
.vscode/
.vscode/

# Test images
test_images/
17 changes: 16 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ This repository contains a Gradio web application for eye disease detection usin
- Support for **multiple model architectures** (MobileNetV4, LeViT, EfficientViT, GENet, RegNetX)
- **Custom model loading** from saved model checkpoints
- **Visualization** of prediction probabilities
- **Attention heatmap visualization** using GradCAM to show which regions the model focuses on
- **Dockerized deployment** option

## Supported Eye Conditions
Expand Down Expand Up @@ -85,7 +86,21 @@ The system can detect the following eye conditions:
2. (Optional) Specify the path to your trained model file (.pth)
3. Select the model architecture (MobileNetV4, LeViT, EfficientViT, GENet, RegNetX)
4. Click "Analyze Image" to get the prediction
5. View the results and probability distribution
5. View the results including:
- Probability distribution across all disease classes
- Attention heatmap showing which regions the model focused on for its prediction

### Understanding the Attention Heatmap

The attention heatmap is generated using GradCAM (Gradient-weighted Class Activation Mapping), which visualizes the regions of the fundus image that the model considers most important for making its prediction:

- **Red/Yellow areas**: Regions the model focuses on most strongly
- **Blue/Green areas**: Regions with less influence on the prediction

This visualization helps in:
- Understanding the model's decision-making process
- Validating that the model is looking at clinically relevant features
- Building trust in the AI's predictions by making them interpretable

## Model Training

Expand Down
174 changes: 150 additions & 24 deletions gradio-inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,19 @@

import os
import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F

import gradio as gr
from PIL import Image
import logging

from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image

from main import get_transform

logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -80,49 +85,161 @@ def load_model(model_type: str = "efficientvit") -> nn.Module:
return model


def predict_image(image: np.ndarray, model_type: str) -> dict:
def get_target_layers(model, model_type):
"""
Get the target layers for GradCAM based on model type.

Args:
model: The model
model_type: Type of model

Returns:
target_layers: List of layers to use for GradCAM
"""
Predict eye disease from an uploaded image.
try:
if model_type == "mobilenetv4":
# For MobileNetV4, use the last convolutional layer in features
return [model.features[-1]]
elif model_type == "levit":
# For LeViT (transformer), use the last block
return [model.blocks[-1]]
elif model_type == "efficientvit":
# For EfficientViT, use the last stage
return [model.stages[-1]]
elif model_type == "gernet":
# For GENet, use the last stage
return [model.stages[-1]]
elif model_type == "regnetx":
# For RegNetX, use the last trunk layer
return [model.trunk[-1]]
else:
# Default: try to get the last feature layer
if hasattr(model, 'features'):
return [model.features[-1]]
elif hasattr(model, 'stages'):
return [model.stages[-1]]
elif hasattr(model, 'blocks'):
return [model.blocks[-1]]
else:
raise ValueError(f"Cannot determine target layer for model type: {model_type}")
except Exception as e:
logging.warning(f"Error getting target layer: {e}. Using fallback.")
# Fallback: try to get any reasonable last conv layer
for module in reversed(list(model.modules())):
if isinstance(module, nn.Conv2d):
return [module]
raise ValueError("Could not find suitable target layer for GradCAM")


def apply_heatmap_on_image(img, cam, alpha=0.4):
"""
Apply CAM heatmap overlay on the original image.

Args:
img: Original image (PIL Image or numpy array)
cam: Class activation map (grayscale, values 0-1)
alpha: Overlay transparency (not used with show_cam_on_image, kept for compatibility)

Returns:
Heatmap overlay image as numpy array
"""
# Convert PIL to numpy if needed
if isinstance(img, Image.Image):
img = np.array(img)

# Normalize image to 0-1 range for show_cam_on_image
img_float = img.astype(np.float32) / 255.0

# Resize CAM to match image size
h, w = img.shape[:2]
cam_resized = cv2.resize(cam, (w, h))

# Use pytorch_grad_cam utility to overlay
# This function expects img in 0-1 range and cam in 0-1 range
overlay = show_cam_on_image(img_float, cam_resized, use_rgb=True)

return overlay


def predict_image(image: np.ndarray, model_type: str) -> tuple[dict, np.ndarray]:
"""
Predict eye disease from an uploaded image and generate attention heatmap.

Args:
image: Input image from Gradio
model_path: Path to the model state dict
model_type: Type of model architecture

Returns:
Dictionary of class probabilities
Tuple of (Dictionary of class probabilities, Heatmap overlay image)
"""
try:

logging.info("Starting prediction...")

# Handle None image
if image is None:
logging.warning("No image provided.")
return {cls: 0.0 for cls in CLASSES}, None

# Load model
model = load_model(model_type)
model.to(device)

Comment on lines +185 to 186
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moving the model to device on every prediction call is inefficient. The model should be moved to device once during loading in the load_model function, not repeatedly in predict_image.

Suggested change
model.to(device)

Copilot uses AI. Check for mistakes.
# Preprocess image
logging.info("Preprocessing image...")
if image is None:
logging.warning("No image provided.")
return {cls: 0.0 for cls in CLASSES}
transform = get_transform()
if image is None:
return {cls: 0.0 for cls in CLASSES}

# Convert numpy array to PIL Image
img = Image.fromarray(image).convert("RGB")
img_tensor = transform(img).unsqueeze(0).to(device)
# Convert numpy array to PIL Image and keep original for heatmap
img_pil = Image.fromarray(image).convert("RGB")
img_tensor = transform(img_pil).unsqueeze(0).to(device)
logging.info("Image preprocessed successfully.")

# Make prediction
with torch.no_grad():
outputs = model(img_tensor)
probabilities = torch.nn.functional.softmax(outputs, dim=1)[0].cpu().numpy()

# Return probabilities for each class
return {cls: float(prob) for cls, prob in zip(CLASSES, probabilities)}
# Get target layers for GradCAM
try:
target_layers = get_target_layers(model, model_type)
logging.info(f"Using target layers: {target_layers}")

# Initialize GradCAM from pytorch_grad_cam library
cam_extractor = GradCAM(model=model, target_layers=target_layers)

# Generate CAM - the library handles forward and backward passes
grayscale_cam = cam_extractor(input_tensor=img_tensor, targets=None)

# Get the CAM for the first image in batch
cam = grayscale_cam[0, :]

# Get model prediction
with torch.no_grad():
outputs = model(img_tensor)

# Generate heatmap overlay
heatmap_overlay = apply_heatmap_on_image(img_pil, cam)

# Clean up
del cam_extractor
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Explicitly deleting cam_extractor is unnecessary in Python as it will be automatically garbage collected when it goes out of scope. This deletion doesn't provide significant benefit and adds clutter.

Suggested change
del cam_extractor

Copilot uses AI. Check for mistakes.

except Exception as e:
logging.error(f"Error generating heatmap: {e}")
import traceback
traceback.print_exc()
Comment on lines +222 to +223
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The traceback module is imported inside the exception handler. This import should be moved to the top of the file with other imports for better code organization and to avoid repeated imports during runtime.

Copilot uses AI. Check for mistakes.
# Fallback: just do prediction without heatmap
with torch.no_grad():
outputs = model(img_tensor)
heatmap_overlay = np.array(img_pil) # Return original image

# Get probabilities
probabilities = F.softmax(outputs, dim=1)[0].cpu().detach().numpy()

# Return probabilities and heatmap
result_dict = {cls: float(prob) for cls, prob in zip(CLASSES, probabilities)}

logging.info("Prediction completed successfully.")
return result_dict, heatmap_overlay

except Exception as e:
logging.error(f"Error during prediction: {e}")
return {cls: 0.0 for cls in CLASSES}
import traceback
traceback.print_exc()
Comment on lines +240 to +241
Copy link

Copilot AI Oct 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The traceback module is imported inside the exception handler. This import should be moved to the top of the file with other imports for better code organization and to avoid repeated imports during runtime.

Copilot uses AI. Check for mistakes.
return {cls: 0.0 for cls in CLASSES}, None


def main():
Expand Down Expand Up @@ -158,20 +275,21 @@ def main():

with gr.Column():
output_chart = gr.Label(label="Prediction")
output_heatmap = gr.Image(label="Attention Heatmap")

# Process the image when the button is clicked
submit_btn.click(
fn=predict_image,
inputs=[input_image, model_type],
outputs=output_chart,
outputs=[output_chart, output_heatmap],
)

# Examples section
gr.Markdown("### Examples (Please add your own example images)")
gr.Examples(
examples=[], # Add example paths here
inputs=input_image,
outputs=[output_chart],
outputs=[output_chart, output_heatmap],
fn=predict_image,
cache_examples=True,
)
Expand All @@ -187,7 +305,15 @@ def main():
- Enter the path to your trained model file (.pth)
- Select the model architecture that was used for training
3. **Analyze**: Click the "Analyze Image" button to get results
4. **Interpret results**: The system will show the detected condition and probability distribution
4. **Interpret results**: The system will show the detected condition, probability distribution, and an attention heatmap

## Attention Heatmap:

The attention heatmap visualizes which regions of the fundus image the model is focusing on when making its prediction.
- **Red/Yellow areas**: Regions the model considers most important for the diagnosis
- **Blue/Green areas**: Regions with less influence on the prediction

This helps in understanding and validating the model's decision-making process.

## Model Information:

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@ readme = "README.md"
requires-python = ">=3.12.9"
dependencies = [
"gradio>=5.29.0",
"grad-cam>=1.5.0",
"matplotlib>=3.10.3",
"opencv-python>=4.8.0",
"pandas>=2.2.3",
"pillow>=10.0.0",
"scikit-learn>=1.6.1",
"seaborn>=0.13.2",
"timm>=1.0.15",
Expand Down
Loading