The advent of foundation models has ushered in a transformative era in computational pathology, enabling the extraction of rich, transferable image features for a broad range of downstream pathology tasks. However, site-specific signatures and demographic biases persist in these features, leading to shortcut learning and unfair predictions, ultimately compromising model generalizability and fairness across diverse clinical sites and demographic groups.
This repository implements FLEX, a novel framework that enhances cross-domain generalization and demographic fairness of pathology foundation models, thus facilitating accurate diagnosis across diverse pathology tasks. FLEX employs a task-specific information bottleneck, informed by visual and textual domain knowledge, to promote:
- Generalizability across clinical settings
- Fairness across demographic groups
- Adaptability to specific pathology tasks
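The exact objective is defined in this repository's code; purely as orientation, the sketch below illustrates how a slide-level task loss could be combined with a prompt-alignment (InfoNCE) term and an information-bottleneck KL term, mirroring the `--w_infonce` and `--w_kl` weights listed under the key parameters further down. This is a simplified, hypothetical illustration, not the actual FLEX implementation.

```python
import torch
import torch.nn.functional as F

def illustrative_objective(slide_emb, prompt_emb, logits, labels,
                           mu, logvar, w_infonce=14.0, w_kl=14.0):
    """Hypothetical combination of the loss terms hinted at by the training
    flags below; NOT the repository's actual implementation."""
    # Slide-level classification loss from the MIL head.
    task_loss = F.cross_entropy(logits, labels)

    # InfoNCE-style alignment of slide embeddings with class prompt embeddings
    # (temperature 0.07 is an assumption).
    sim = F.normalize(slide_emb, dim=-1) @ F.normalize(prompt_emb, dim=-1).t()
    infonce = F.cross_entropy(sim / 0.07, labels)

    # KL term of a variational information bottleneck, pushing the compressed
    # representation toward a standard Gaussian prior.
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    return task_loss + w_infonce * infonce + w_kl * kl
```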
Key features:
- Cross-domain generalization: Significantly improves diagnostic performance on data from unseen sites
- Demographic fairness: Reduces performance gaps between demographic groups
- Versatility: Compatible with various vision-language models
- Scalability: Adaptable to varying training data sizes
- Seamless integration: Works with multiple instance learning frameworks
- Clone the repository:

  ```bash
  git clone https://github.com/HKU-MedAI/FLEX
  cd FLEX
  ```
- Create and activate a virtual environment, and install the dependencies:

  ```bash
  conda env create -f environment.yml
  conda activate flex
  ```
- Fresh installation with conda environment creation: ~10-15 minutes on a normal desktop computer
- Installation of additional dependencies: ~5 minutes
- Download of pretrained models (optional): ~10-20 minutes depending on internet speed
Prepare your data in the following structure:
```
Dataset/
├── TCGA-BRCA/
│   ├── features/
│   │   ├── ...
│   ├── tcga-brca_label.csv
│   ├── tcga-brca_label_her2.csv
│   └── ...
├── TCGA-NSCLC/
└── ...
```
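As a quick sanity check of the layout above, the snippet below shows one way to read a task label CSV and a per-slide feature file. The column names (`slide_id`, `label`) and the `.pt` feature format are assumptions based on common CLAM-style pipelines, not guarantees made by this repository; adjust them to your actual files.

```python
import os
import pandas as pd
import torch

data_root = "Dataset/TCGA-BRCA"

# Slide-level labels; column names are assumptions, check the actual CSV header.
labels = pd.read_csv(os.path.join(data_root, "tcga-brca_label.csv"))
print(labels.head())

# CLAM-style pipelines typically store one feature bag per slide, e.g. <slide_id>.pt
# holding a [num_patches, feature_dim] tensor; verify against your extractor's output.
slide_id = labels.iloc[0]["slide_id"]
features = torch.load(os.path.join(data_root, "features", f"{slide_id}.pt"))
print(features.shape)
```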
Organize visual prompts in the following structure:
```
prompts/
├── BRCA/
│   ├── 0/
│   │   ├── image1.png
│   │   └── ...
│   └── 1/
│       ├── image1.png
│       └── ...
├── BRCA_HER2/
└── ...
```
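For illustration, visual prompts arranged as above (one sub-folder per class label) can be loaded with torchvision's `ImageFolder`, which maps each sub-folder name (`0`, `1`, ...) to a class index. The resize target below is an assumption; match it to your foundation model's expected input size.

```python
from torchvision import datasets, transforms

# Each class sub-folder (0/, 1/, ...) under prompts/BRCA becomes one class index.
prompt_set = datasets.ImageFolder(
    root="prompts/BRCA",
    transform=transforms.Compose([
        transforms.Resize((224, 224)),  # input size assumed, adjust to your model
        transforms.ToTensor(),
    ]),
)

image, label = prompt_set[0]
print(image.shape, label, prompt_set.classes)
```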
- Generate site-preserved Monte Carlo Cross-Validation (SP-MCCV) splits for your dataset (a conceptual sketch of site-preserved splitting follows these steps):

  ```bash
  python generate_sitepreserved_splits.py
  python generate_sp_mccv_splits.py
  ```
- Extract features (if not already done) using CLAM or TRIDENT.
- Train the FLEX model and evaluate the performance:

  ```bash
  bash ./scripts/train_flex.sh
  ```
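Conceptually, the split-generation step above keeps all slides from the same tissue source site on one side of every train/test split, so site-specific staining or scanner signatures cannot leak between training and evaluation. The sketch below illustrates the idea with scikit-learn's `GroupShuffleSplit`, using the TCGA tissue source site code (the two characters after the first hyphen of the barcode) as the group; the repository's own scripts may differ in details such as label stratification, and the CSV column names are assumptions.

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

labels = pd.read_csv("Dataset/TCGA-BRCA/tcga-brca_label.csv")  # column names assumed

# TCGA barcodes look like TCGA-XX-YYYY-...; the two characters after the first
# hyphen encode the tissue source site.
sites = labels["slide_id"].str.slice(5, 7)

# Monte Carlo cross-validation with whole sites held out of each test split.
splitter = GroupShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
for fold, (train_idx, test_idx) in enumerate(splitter.split(labels, groups=sites)):
    print(f"fold {fold}: {len(train_idx)} train slides, {len(test_idx)} test slides, "
          f"{sites.iloc[test_idx].nunique()} held-out sites")
```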
To reproduce the results in our paper:
- Download the datasets mentioned in the paper (TCGA-BRCA, TCGA-NSCLC, TCGA-STAD, TCGA-CRC).
- Run the following commands:

  ```bash
  # Generate splits
  python generate_sitepreserved_splits.py
  python generate_sp_mccv_splits.py

  # Train model
  bash ./scripts/train_flex.sh

  # For each task, modify the task parameter in train_flex.sh and run the script
  bash ./scripts/train_flex.sh
  ```
- For specific tasks or customizations, refer to the key parameters section below.
- `--task`: Task name (e.g., BRCA, NSCLC, STAD_LAUREN)
- `--data_root_dir`: Path to the data directory
- `--split_suffix`: Split suffix (e.g., sitepre5_fold3)
- `--exp_code`: Experiment code for logging and saving results
- `--model_type`: Model type (default: flex)
- `--base_mil`: Base MIL framework (default: abmil)
- `--slide_align`: Whether to align at the slide level (default: 1)
- `--w_infonce`: Weight for the InfoNCE loss (default: 14)
- `--w_kl`: Weight for the KL loss (default: 14)
- `--len_prompt`: Number of learnable textual prompt tokens
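To make the flag list above concrete, here is a hedged sketch of how an argparse-based entry point could declare these options. The defaults are taken from the list above; the types and anything not stated there are assumptions, not the repository's actual CLI definition.

```python
import argparse

# Illustrative parser mirroring the key parameters above; not the repository's actual CLI.
parser = argparse.ArgumentParser(description="FLEX training (illustrative)")
parser.add_argument("--task", type=str, required=True, help="Task name, e.g. BRCA, NSCLC, STAD_LAUREN")
parser.add_argument("--data_root_dir", type=str, help="Path to the data directory")
parser.add_argument("--split_suffix", type=str, help="Split suffix, e.g. sitepre5_fold3")
parser.add_argument("--exp_code", type=str, help="Experiment code for logging and saving results")
parser.add_argument("--model_type", type=str, default="flex", help="Model type")
parser.add_argument("--base_mil", type=str, default="abmil", help="Base MIL framework")
parser.add_argument("--slide_align", type=int, default=1, help="Whether to align at the slide level")
parser.add_argument("--w_infonce", type=float, default=14, help="Weight for the InfoNCE loss")
parser.add_argument("--w_kl", type=float, default=14, help="Weight for the KL loss")
parser.add_argument("--len_prompt", type=int, help="Number of learnable textual prompt tokens")

args = parser.parse_args()
print(args)
```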
FLEX has been evaluated on 16 clinically relevant tasks and demonstrates:
- Improved performance on unseen clinical sites
- Reduced performance gap between seen and unseen sites
- Enhanced fairness across demographic groups
For detailed results, refer to our paper.
This project is licensed under the Apache-2.0 license.
This project was built on top of amazing prior works, including CLAM, CONCH, QuiltNet, PathGen-CLIP, and PreservedSiteCV. We thank the authors for their great work.