Knowledge-Guided Adaptation of Pathology Foundation Models Improves Cross-domain Generalization and Demographic Fairness

Overview

The advent of foundation models has ushered in a transformative era in computational pathology, enabling the extraction of rich, transferable image features for a broad range of downstream pathology tasks. However, site-specific signatures and demographic biases persist in these features, encouraging shortcut learning and unfair predictions that ultimately compromise generalizability across clinical sites and fairness across demographic groups.

This repository implements FLEX, a novel framework that enhances cross-domain generalization and demographic fairness of pathology foundation models, thus facilitating accurate diagnosis across diverse pathology tasks. FLEX employs a task-specific information bottleneck, informed by visual and textual domain knowledge, to promote:

  • Generalizability across clinical settings
  • Fairness across demographic groups
  • Adaptability to specific pathology tasks

[Figure: overview of the FLEX framework]
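
To make the mechanism concrete, the sketch below pairs a variational information bottleneck (whose KL term penalizes the compressed features for carrying nuisance information such as site signatures) with an InfoNCE term that aligns those features with knowledge embeddings, e.g., encoded prompts. This is a minimal illustration, not the FLEX implementation: the dimensions, the Gaussian parameterization, and the one-to-one image/knowledge pairing are all assumptions, and the loss weights simply echo the --w_infonce and --w_kl defaults listed under Key Parameters.

    # Minimal sketch of a knowledge-guided information bottleneck -- NOT the
    # FLEX implementation; dimensions and pairing below are assumptions.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class IBHead(nn.Module):
        """Compress features into a low-dim bottleneck t ~ N(mu, sigma^2)."""
        def __init__(self, in_dim=512, bottleneck_dim=128):
            super().__init__()
            self.mu = nn.Linear(in_dim, bottleneck_dim)
            self.logvar = nn.Linear(in_dim, bottleneck_dim)

        def forward(self, z):
            mu, logvar = self.mu(z), self.logvar(z)
            t = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterize
            # KL(q(t|z) || N(0, I)) caps how much information t retains about z
            kl = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(-1).mean()
            return t, kl

    def infonce(img, txt, tau=0.07):
        """Symmetric InfoNCE: pull each feature toward its paired knowledge
        embedding and away from the other pairs in the batch."""
        img, txt = F.normalize(img, dim=-1), F.normalize(txt, dim=-1)
        logits = img @ txt.t() / tau
        labels = torch.arange(len(img), device=img.device)
        return 0.5 * (F.cross_entropy(logits, labels)
                      + F.cross_entropy(logits.t(), labels))

    head = IBHead()
    feats = torch.randn(8, 512)       # e.g., pooled slide features (assumed dim)
    knowledge = torch.randn(8, 128)   # e.g., matched prompt embeddings (assumed)
    t, kl = head(feats)
    loss = 14 * infonce(t, knowledge) + 14 * kl   # weights echo --w_infonce/--w_kl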

Features

  • Cross-domain generalization: Significantly improves diagnostic performance on data from unseen sites
  • Demographic fairness: Reduces performance gaps between demographic groups
  • Versatility: Compatible with various vision-language models
  • Scalability: Adaptable to varying training data sizes
  • Seamless integration: Works with multiple instance learning frameworks
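
Because FLEX plugs into standard MIL pipelines, a generic attention-based MIL (ABMIL) pooler, the default selected by the --base_mil flag under Key Parameters, is sketched below for context. It is an illustrative stand-in in the spirit of Ilse et al. (2018), not the repository's module, and the dimensions are assumptions.

    # Generic attention-based MIL (ABMIL) pooling -- an illustrative stand-in
    # for the base MIL framework FLEX plugs into, not the repo's exact module.
    import torch
    import torch.nn as nn

    class ABMIL(nn.Module):
        def __init__(self, in_dim=512, hidden_dim=256, n_classes=2):
            super().__init__()
            self.attn = nn.Sequential(
                nn.Linear(in_dim, hidden_dim), nn.Tanh(), nn.Linear(hidden_dim, 1))
            self.classifier = nn.Linear(in_dim, n_classes)

        def forward(self, bag):  # bag: (n_patches, in_dim) patch features
            weights = torch.softmax(self.attn(bag), dim=0)  # attention over patches
            slide_feature = (weights * bag).sum(dim=0)      # weighted mean pooling
            return self.classifier(slide_feature), weights

    logits, attn = ABMIL()(torch.randn(1000, 512))  # one slide = one bag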

Installation

Setup

  1. Clone the repository:

    git clone https://github.com/HKU-MedAI/FLEX
    cd FLEX
  2. Create and activate a virtual environment, and install the dependencies:

    conda env create -f environment.yml
    conda activate flex

Install Time

  • Fresh installation with conda environment creation: ~10-15 minutes on a normal desktop computer
  • Installation of additional dependencies: ~5 minutes
  • Download of pretrained models (optional): ~10-20 minutes depending on internet speed

Instructions for Use

Data Preparation

Prepare your data in the following structure:

Dataset/
├── TCGA-BRCA/
│   ├── features/
│   │   ├── ...
│   ├── tcga-brca_label.csv
│   ├── tcga-brca_label_her2.csv
│   └── ...
├── TCGA-NSCLC/
└── ...
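
Before training, it can help to sanity-check this layout. The snippet below is a hypothetical helper: the .pt extension assumes CLAM-style feature files, and the slide_id column name is an assumption to match against your label CSVs.

    # Hypothetical layout check -- the 'slide_id' column and '.pt' feature
    # extension are assumptions; adjust to your actual label files.
    from pathlib import Path
    import pandas as pd

    root = Path("Dataset/TCGA-BRCA")
    labels = pd.read_csv(root / "tcga-brca_label.csv")
    feature_stems = {p.stem for p in (root / "features").rglob("*.pt")}

    missing = set(labels["slide_id"].astype(str)) - feature_stems
    print(f"{len(labels)} labeled slides, {len(missing)} missing feature files")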

Visual Prompts

Organize visual prompts in the following structure:

prompts/
├── BRCA/
│   ├── 0/
│   │   ├── image1.png
│   │   └── ...
│   └── 1/
│       ├── image1.png
│       └── ...
├── BRCA_HER2/
└── ...
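
A minimal loader for this class-per-subfolder layout might look like the sketch below (illustrative only; the PNG-only glob and the per-task directory argument are assumptions).

    # Hypothetical loader for class-organized visual prompts; subfolder names
    # (0/, 1/, ...) are treated as class labels, matching the layout above.
    from pathlib import Path
    from PIL import Image

    def load_prompts(task_dir):
        prompts = {}
        for class_dir in sorted(Path(task_dir).iterdir()):
            if class_dir.is_dir():
                prompts[class_dir.name] = [Image.open(p).convert("RGB")
                                           for p in sorted(class_dir.glob("*.png"))]
        return prompts

    prompts = load_prompts("prompts/BRCA")
    print({cls: len(imgs) for cls, imgs in prompts.items()})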

Running on Your Data

  1. Generate site-preserved Monte Carlo Cross-Validation (SP-MCCV) splits for your dataset (a minimal sketch of the site-preserving idea follows this list):

    python generate_sitepreserved_splits.py
    python generate_sp_mccv_splits.py
  2. Extract features (if not already done) using CLAM or TRIDENT.

  3. Train the FLEX model and evaluate the performance:

    bash ./scripts/train_flex.sh
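
For intuition, the sketch below illustrates the site-preserving idea behind the split scripts: all slides from a contributing site land on the same side of every split, so site-specific staining and scanner signatures cannot leak between train and test. It uses scikit-learn's GroupShuffleSplit as a stand-in for the repository's SP-MCCV logic, and the label path and slide_id column are assumptions.

    # Illustrative stand-in for SP-MCCV using GroupShuffleSplit -- not the
    # repository's split code; file path and column name are assumptions.
    import pandas as pd
    from sklearn.model_selection import GroupShuffleSplit

    labels = pd.read_csv("Dataset/TCGA-BRCA/tcga-brca_label.csv")
    # In TCGA barcodes (TCGA-XX-XXXX-...), the second field identifies the
    # tissue source site that contributed the slide.
    sites = labels["slide_id"].str.split("-").str[1]

    splitter = GroupShuffleSplit(n_splits=5, test_size=0.3, random_state=0)
    for fold, (tr, te) in enumerate(splitter.split(labels, groups=sites)):
        assert not set(sites.iloc[tr]) & set(sites.iloc[te])  # no site overlap
        print(f"fold {fold}: {len(tr)} train / {len(te)} test slides")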

Reproduction Instructions

To reproduce the results in our paper:

  1. Download the datasets mentioned in the paper (TCGA-BRCA, TCGA-NSCLC, TCGA-STAD, TCGA-CRC).

  2. Extract features using CLAM or TRIDENT.

  3. Run the following commands:

    # Generate splits
    python generate_sitepreserved_splits.py
    python generate_sp_mccv_splits.py
    
    # Train and evaluate; for each task, set the task parameter in
    # train_flex.sh and rerun the script
    bash ./scripts/train_flex.sh
  4. For specific tasks or customizations, refer to the key parameters section below.

Key Parameters

  • --task: Task name (e.g., BRCA, NSCLC, STAD_LAUREN)
  • --data_root_dir: Path to the data directory
  • --split_suffix: Split suffix (e.g., sitepre5_fold3)
  • --exp_code: Experiment code for logging and saving results
  • --model_type: Model type (default: flex)
  • --base_mil: Base MIL framework (default: abmil)
  • --slide_align: Whether to align at the slide level (default: 1)
  • --w_infonce: Weight for InfoNCE loss (default: 14)
  • --w_kl: Weight for KL loss (default: 14)
  • --len_prompt: Number of learnable textual prompt tokens
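
To illustrate how these flags combine, here is a hypothetical direct invocation. The actual entry-point script and its defaults live inside scripts/train_flex.sh, so treat the script name and every value below as examples to adapt.

    # Hypothetical invocation -- check scripts/train_flex.sh for the real
    # entry-point name; all flag values here are illustrative.
    python main.py \
        --task BRCA \
        --data_root_dir ./Dataset \
        --split_suffix sitepre5_fold3 \
        --exp_code flex_brca_abmil \
        --model_type flex \
        --base_mil abmil \
        --slide_align 1 \
        --w_infonce 14 \
        --w_kl 14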

Evaluation Results

FLEX has been evaluated on 16 clinically relevant tasks and demonstrates:

  • Improved performance on unseen clinical sites
  • Reduced performance gap between seen and unseen sites
  • Enhanced fairness across demographic groups

For detailed results, refer to our paper.

License

This project is licensed under the Apache-2.0 license.

Acknowledgments

This project was built on top of amazing prior works, including CLAM, CONCH, QuiltNet, PathGen-CLIP, and PreservedSiteCV. We thank the authors for their great work.
