
ImageFlowNet: Forecasting Multiscale Image-Level Trajectories of Disease Progression with Irregularly-Sampled Longitudinal Medical Images


ChenLiu-1996/ImageFlowNet


ImageFlowNet

Krishnaswamy Lab, Yale University


This is the authors' implementation of ImageFlowNet. The official codebase is maintained in the Lab GitHub repo.

A Glimpse into the Methods

Citation

@article{liu2024imageflownet,
  title={ImageFlowNet: Forecasting Multiscale Trajectories of Disease Progression with Irregularly-Sampled Longitudinal Medical Images},
  author={Liu, Chen and Xu, Ke and Shen, Liangbo L and Huguet, Guillaume and Wang, Zilong and Tong, Alexander and Bzdok, Danilo and Stewart, Jay and Wang, Jay C and Del Priore, Lucian V and Krishnaswamy, Smita},
  journal={arXiv preprint arXiv:2406.14794},
  year={2024}
}

Abstract

Advances in medical imaging technologies have facilitated the collection of longitudinal images to track disease progression, yet predictive modeling of such data remains challenging due to high dimensionality, irregular sampling, and data sparsity. To address these issues, we propose ImageFlowNet, a novel model designed to forecast disease trajectories from initial images while preserving spatial details. ImageFlowNet first learns multiscale joint representation spaces across patients and time points, then optimizes deterministic or stochastic flow fields within these spaces using a position-parameterized neural ODE/SDE framework. The model leverages a UNet architecture to create robust multiscale representations and mitigates data scarcity by combining knowledge from all patients. We provide theoretical insights that support our formulation of ODEs, and motivate our regularizations involving high-level visual features, latent space organization, and trajectory smoothness. We validate ImageFlowNet on three longitudinal medical image datasets depicting progression in geographic atrophy, multiple sclerosis, and glioblastoma, demonstrating its ability to effectively forecast disease progression and outperform existing methods. Our contributions include the development of ImageFlowNet, its theoretical underpinnings, and empirical validation on real-world datasets.
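To make "position-parameterized" concrete: one reading is that the flow field depends only on the current state z and not on time t (dz/dt = f(z)), so a single field can be integrated across any gap between irregularly sampled visits. The NumPy sketch below illustrates that idea under our own assumptions; it is not the repository's implementation. `field` is a hypothetical toy stand-in (f(z) = -z) for the learned network, the ODE uses a fixed-step RK4 solver, and the SDE variant adds constant diagonal noise via Euler-Maruyama. The Dependencies section lists torchdiffeq and torchsde, which provide such solvers for PyTorch.

```python
import numpy as np


def field(z: np.ndarray) -> np.ndarray:
    """Hypothetical position-parameterized vector field f(z).

    It takes only the state z (no time argument). In a real model a
    learned network would replace this toy linear decay."""
    return -z


def integrate_ode(f, z0, t0: float, t1: float, n_steps: int = 100) -> np.ndarray:
    """Fixed-step RK4 solve of dz/dt = f(z) from t0 to t1."""
    z = np.asarray(z0, dtype=float)
    h = (t1 - t0) / n_steps
    for _ in range(n_steps):
        k1 = f(z)
        k2 = f(z + 0.5 * h * k1)
        k3 = f(z + 0.5 * h * k2)
        k4 = f(z + h * k3)
        z = z + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return z


def integrate_sde(f, sigma: float, z0, t0: float, t1: float,
                  n_steps: int = 100, rng=None) -> np.ndarray:
    """Euler-Maruyama solve of dz = f(z) dt + sigma dW (assumes t1 >= t0)."""
    rng = np.random.default_rng() if rng is None else rng
    z = np.asarray(z0, dtype=float)
    h = (t1 - t0) / n_steps
    for _ in range(n_steps):
        dw = rng.normal(0.0, np.sqrt(h), size=z.shape)  # Brownian increment
        z = z + f(z) * h + sigma * dw
    return z


# Irregular visit times: the same field serves every interval.
z0 = np.ones((4, 4))  # stand-in for a latent feature map
for t in [0.3, 1.7, 4.0]:
    zt = integrate_ode(field, z0, 0.0, t)
    print(t, zt[0, 0], np.exp(-t))  # RK4 closely tracks the exact solution exp(-t)
```

Because f takes no time argument, only the integration limits change with the visit dates; setting `sigma=0` reduces the SDE solver to a deterministic Euler step.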

Repository Hierarchy

ImageFlowNet
    ├── comparison: some comparisons are in the `src` folder instead.
    |   └── interpolation
    |
    ├── checkpoints: only for segmentor model weights. Other model weights in `results`.
    |
    ├── data: folders containing data files.
    |   ├── brain_LUMIERE: Brain Glioblastoma
    |   ├── brain_MS: Brain Multiple Sclerosis
    |   └── retina_ucsf: Retinal Geographic Atrophy
    |
    ├── external_src: other repositories or code.
    |
    ├── results: generated results, including training log, model weights, and evaluation results.
    |
    └── src
        ├── data_utils
        ├── datasets
        ├── nn
        ├── preprocessing
        ├── utils
        └── *.py: some main scripts

Pre-trained weights

We have uploaded pre-trained weights for the retinal geographic atrophy dataset (`retina_ucsf`).

  1. The weights for the segmentor can be found in checkpoints/segment_retinaUCSF_seed1.pty
  2. The weights for the ImageFlowNetODE models are available on Google Drive. Place them at results/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1/run_1/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1_best_pred_psnr.pty and results/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1/run_1/retina_ucsf_ImageFlowNetODE_smoothness-0.100_latent-0.001_contrastive-0.010_invariance-0.000_seed_1_best_seg_dice.pty.
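The long checkpoint paths above follow a visible pattern: dataset, model, the four regularization coefficients printed with three decimals, the seed, and a run index. The helper below rebuilds those paths so you can check that downloaded weights landed in the right place. The naming pattern is inferred from the two example paths only, and `run_name`/`checkpoint_paths` are hypothetical helpers, not functions from the repository; we also assume `run_{N}` corresponds to `--run-count N`.

```python
from pathlib import Path


def run_name(dataset: str, model: str, smoothness: float, latent: float,
             contrastive: float, invariance: float, seed: int) -> str:
    """Rebuild the run-folder name (pattern inferred from the example paths)."""
    return (f"{dataset}_{model}"
            f"_smoothness-{smoothness:.3f}_latent-{latent:.3f}"
            f"_contrastive-{contrastive:.3f}_invariance-{invariance:.3f}"
            f"_seed_{seed}")


def checkpoint_paths(root: str = "results", run_count: int = 1, **kwargs) -> dict:
    """Return the expected best-PSNR and best-Dice checkpoint paths."""
    name = run_name(**kwargs)
    run_dir = Path(root) / name / f"run_{run_count}"
    return {metric: run_dir / f"{name}_best_{metric}.pty"
            for metric in ("pred_psnr", "seg_dice")}


# Coefficients of the released retina_ucsf ImageFlowNetODE weights.
paths = checkpoint_paths(dataset="retina_ucsf", model="ImageFlowNetODE",
                         smoothness=0.1, latent=0.001, contrastive=0.01,
                         invariance=0.0, seed=1)
for metric, path in paths.items():
    print(metric, path, "found" if path.exists() else "missing")
```

Run it from the repository root so the relative `results/` path resolves.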

Reproduce the results

Image registration

cd src/preprocessing
python test_registration.py

Training a segmentation network (only for quantitative evaluation purposes)

cd src/
python train_segmentor.py
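The checkpoint names `best_pred_psnr` and `best_seg_dice` suggest two evaluation criteria: PSNR between predicted and ground-truth images, and Dice between segmentations produced by this segmentor. The sketch below gives the standard textbook definitions in NumPy. It is an assumption for illustration; the repository's own metric code may differ in details such as the intensity range or smoothing constant.

```python
import numpy as np


def psnr(pred: np.ndarray, target: np.ndarray, data_range: float = 1.0) -> float:
    """Peak signal-to-noise ratio in dB; assumes intensities in [0, data_range]."""
    mse = np.mean((pred.astype(float) - target.astype(float)) ** 2)
    if mse == 0:
        return float("inf")
    return float(10.0 * np.log10(data_range ** 2 / mse))


def dice(mask_a: np.ndarray, mask_b: np.ndarray, eps: float = 1e-8) -> float:
    """Dice coefficient 2|A and B| / (|A| + |B|) between two binary masks.

    eps keeps the ratio defined when both masks are empty (then Dice = 1)."""
    a = mask_a.astype(bool)
    b = mask_b.astype(bool)
    intersection = np.logical_and(a, b).sum()
    return float((2.0 * intersection + eps) / (a.sum() + b.sum() + eps))


gt = np.zeros((8, 8))
gt[2:6, 2:6] = 1.0
pred = gt + 0.1                     # uniform error of 0.1, so MSE is about 0.01
print(psnr(pred, gt))               # about 20 dB
print(dice(pred > 0.5, gt > 0.5))   # identical masks give 1.0
```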

Training the main network.

cd src/
# ImageFlowNet_{ODE}
python train_2pt_all.py --model ImageFlowNetODE --random-seed 1
python train_2pt_all.py --model ImageFlowNetODE --random-seed 1 --mode test --run-count 1

# ImageFlowNet_{SDE}
python train_2pt_all.py --model ImageFlowNetSDE --random-seed 1
python train_2pt_all.py --model ImageFlowNetSDE --random-seed 1 --mode test --run-count 1

Some common arguments.

--dataset-name: name of the dataset (`retina_ucsf`, `brain_ms`, `brain_gbm`).
--segmentor-ckpt: path to the segmentor checkpoint, used both when training the segmentor and when loading it for evaluation.

Ablations.

  1. Flow field formulation.
python train_2pt_all.py --model ODEUNet
python train_2pt_all.py --model ImageFlowNetODE
  2. Single-scale vs multiscale ODEs.
python train_2pt_all.py --model ImageFlowNetODE --ode-location 'bottleneck'
python train_2pt_all.py --model ImageFlowNetODE --ode-location 'all_resolutions'
python train_2pt_all.py --model ImageFlowNetODE --ode-location 'all_connections' # default
  3. Visual feature regularization.
python train_2pt_all.py --model ImageFlowNetODE --coeff-latent 0.1
  4. Contrastive learning regularization.
python train_2pt_all.py --model ImageFlowNetODE --coeff-contrastive 0.1
  5. Trajectory smoothness regularization.
python train_2pt_all.py --model ImageFlowNetODE --coeff-smoothness 0.1
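To queue all of the ablation runs above rather than typing them one by one, a small driver can build the command lines and optionally execute them from `src/`. The sketch below uses only the flags shown in this README; the `ABLATIONS` table and helper names are hypothetical, and passing `--random-seed` is an optional addition (the ablation commands above omit it). Note that `ImageFlowNetODE` with no `--ode-location` and the explicit `all_connections` entry should be the same configuration, since that location is marked as the default.

```python
import shlex
import subprocess

BASE = ["python", "train_2pt_all.py"]

# One entry per ablation command listed in this README.
ABLATIONS = {
    "flow_field_odeunet": ["--model", "ODEUNet"],
    "flow_field_imageflownet": ["--model", "ImageFlowNetODE"],
    "ode_bottleneck": ["--model", "ImageFlowNetODE", "--ode-location", "bottleneck"],
    "ode_all_resolutions": ["--model", "ImageFlowNetODE", "--ode-location", "all_resolutions"],
    "ode_all_connections": ["--model", "ImageFlowNetODE", "--ode-location", "all_connections"],
    "latent_reg": ["--model", "ImageFlowNetODE", "--coeff-latent", "0.1"],
    "contrastive_reg": ["--model", "ImageFlowNetODE", "--coeff-contrastive", "0.1"],
    "smoothness_reg": ["--model", "ImageFlowNetODE", "--coeff-smoothness", "0.1"],
}


def build_commands(seed=None):
    """Return (name, argv) pairs; append --random-seed only if a seed is given."""
    commands = []
    for name, args in ABLATIONS.items():
        cmd = BASE + args
        if seed is not None:
            cmd = cmd + ["--random-seed", str(seed)]
        commands.append((name, cmd))
    return commands


def run_all(commands, cwd="src", dry_run=True):
    """Print each command; execute it from `cwd` when dry_run is False."""
    for name, cmd in commands:
        print(f"[{name}] " + " ".join(shlex.quote(arg) for arg in cmd))
        if not dry_run:
            subprocess.run(cmd, cwd=cwd, check=True)


run_all(build_commands(seed=1))  # dry run: prints the commands only
```

Runs execute sequentially with `check=True`, so a failing configuration stops the sweep instead of silently continuing.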

Comparisons

Image interpolation/extrapolation methods.

cd comparison/interpolation
python run_baseline_interp.py --method linear
python run_baseline_interp.py --method cubic_spline
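The README does not describe how the `linear` and `cubic_spline` baselines work internally. A natural reading is per-pixel interpolation or extrapolation along the time axis over registered images, so that the same pixel tracks the same anatomy across visits. The sketch below follows that assumption; it is not the code in `run_baseline_interp.py`. It uses NumPy and `scipy.interpolate.CubicSpline` (SciPy is pulled in by scikit-learn), and both function names are hypothetical.

```python
import numpy as np
from scipy.interpolate import CubicSpline


def linear_extrapolate(times, images, t_query):
    """Per-pixel straight line through the two most recent observations."""
    (t0, t1), (x0, x1) = times[-2:], images[-2:]
    slope = (x1 - x0) / (t1 - t0)
    return x1 + slope * (t_query - t1)


def cubic_spline_extrapolate(times, images, t_query):
    """Per-pixel cubic spline along time; evaluates outside the observed range too."""
    spline = CubicSpline(np.asarray(times, dtype=float), np.stack(images),
                         axis=0, extrapolate=True)
    return spline(t_query)


# Irregularly spaced visits (times must be strictly increasing).
times = [0.0, 0.7, 2.0, 3.1]
images = [np.full((3, 3), 0.2 + 0.05 * t) for t in times]
print(linear_extrapolate(times, images, 5.0)[0, 0])
print(cubic_spline_extrapolate(times, images, 5.0)[0, 0])
```

Both baselines are purely per-patient and per-pixel, which is exactly the limitation the abstract's pooling of knowledge across patients is meant to address.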

Time-conditional UNet.

cd src
python train_2pt_all.py --model T_UNet --random-seed 1 --mode train
python train_2pt_all.py --model T_UNet --random-seed 1 --mode test --run-count 1

Time-aware diffusion model (Image-to-Image Schrödinger Bridge)

cd src
python train_2pt_all.py --model I2SBUNet --random-seed 1
python train_2pt_all.py --model I2SBUNet --random-seed 1 --mode test --run-count 1

Style-based Manifold Extrapolation (Nat. Mach. Intell. 2022).

conda deactivate
conda activate stylegan

cd src/preprocessing
python 04_unpack_retina_UCSF.py

cd ../../comparison/style_manifold_extrapolation/stylegan2-ada-pytorch
python train.py --outdir=../training-runs --data='../../../data/retina_ucsf/UCSF_images_final_unpacked_256x256/' --gpus=1

Datasets

  1. Retinal Geographic Atrophy dataset from METforMIN study (UCSF).
  2. Brain Multiple Sclerosis dataset.
  3. Brain Glioblastoma dataset.

Data preparation and preprocessing

  1. Retinal Geographic Atrophy dataset.
  • Put data under: data/retina_ucsf/Images/
cd src/preprocessing
python 01_preprocess_retina_UCSF.py
python 02_register_retina_UCSF.py
python 03_crop_retina_UCSF.py
  2. Brain Multiple Sclerosis dataset.
  • Put data under: data/brain_MS/brain_MS_images/trainX/ after unzipping.
cd src/preprocessing
python 01_preprocess_brain_MS.py
  3. Brain Glioblastoma dataset.
  • Put data under: data/brain_LUMIERE/ after unzipping.
cd src/preprocessing
python 01_preprocess_brain_GBM.py
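The preprocessing steps above can be wrapped in one driver that checks the raw-data folder and then runs each dataset's scripts in order from `src/preprocessing`. This is a hypothetical convenience sketch: the script names, their order, and the data folders are copied from the steps above, the dictionary keys reuse the `--dataset-name` values, and `preprocess` is not part of the repository.

```python
import subprocess
import sys
from pathlib import Path

# dataset name -> (expected raw-data folder, scripts in the order listed above)
PIPELINES = {
    "retina_ucsf": ("data/retina_ucsf/Images", [
        "01_preprocess_retina_UCSF.py",
        "02_register_retina_UCSF.py",
        "03_crop_retina_UCSF.py",
    ]),
    "brain_ms": ("data/brain_MS/brain_MS_images/trainX", ["01_preprocess_brain_MS.py"]),
    "brain_gbm": ("data/brain_LUMIERE", ["01_preprocess_brain_GBM.py"]),
}


def preprocess(dataset: str, repo_root: str = ".", dry_run: bool = True):
    """Run (or, by default, just print) the preprocessing scripts for one dataset."""
    data_dir, scripts = PIPELINES[dataset]
    root = Path(repo_root)
    if not dry_run and not (root / data_dir).is_dir():
        raise FileNotFoundError(f"expected raw data under {root / data_dir}")
    commands = [[sys.executable, script] for script in scripts]
    for cmd in commands:
        if dry_run:
            print(" ".join(cmd))
        else:
            # Equivalent to `cd src/preprocessing` followed by `python <script>`.
            subprocess.run(cmd, cwd=root / "src" / "preprocessing", check=True)
    return commands


preprocess("retina_ucsf")  # dry run: prints the three retina commands
```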

Segment Anything Model (SAM)

SAM is used only by test_registration.py to aid visualization; it is not used anywhere else.

cd external_src/
mkdir SAM && cd SAM
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

Dependencies

We developed the codebase in a Miniconda environment, created as follows:

# Optional: Update to libmamba solver.
conda update -n base conda
conda install -n base conda-libmamba-solver
conda config --set solver libmamba

conda create --name imageflownet pytorch==1.12.1 torchvision==0.13.1 torchaudio==0.12.1 cudatoolkit=11.3 -c pytorch -c nvidia -c anaconda -c conda-forge
conda activate imageflownet
conda install scikit-learn scikit-image pillow matplotlib seaborn tqdm -c pytorch -c anaconda -c conda-forge
conda install read-roi -c conda-forge
python -m pip install -U albumentations
python -m pip install timm
python -m pip install opencv-python
python -m pip install git+https://github.com/facebookresearch/segment-anything.git
python -m pip install monai
python -m pip install torchdiffeq
python -m pip install torch-ema
python -m pip install torchcde
python -m pip install torchsde
python -m pip install phate
python -m pip install psutil
python -m pip install ninja

# For 3D registration
python -m pip install antspyx
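After creating the environment, a quick check that every package listed above is importable can save a failed training run. Several import names differ from the install names (for example `scikit-image` imports as `skimage` and `antspyx` as `ants`); the mapping below uses the usual import names for these packages and is our assumption, so adjust it if a version differs. The check uses `importlib.util.find_spec`, which locates a module without importing it.

```python
import importlib.util

# install name -> import name
REQUIRED = {
    "pytorch": "torch",
    "torchvision": "torchvision",
    "torchaudio": "torchaudio",
    "scikit-learn": "sklearn",
    "scikit-image": "skimage",
    "pillow": "PIL",
    "matplotlib": "matplotlib",
    "seaborn": "seaborn",
    "tqdm": "tqdm",
    "read-roi": "read_roi",
    "albumentations": "albumentations",
    "timm": "timm",
    "opencv-python": "cv2",
    "segment-anything": "segment_anything",
    "monai": "monai",
    "torchdiffeq": "torchdiffeq",
    "torch-ema": "torch_ema",
    "torchcde": "torchcde",
    "torchsde": "torchsde",
    "phate": "phate",
    "psutil": "psutil",
    "ninja": "ninja",
    "antspyx": "ants",  # only needed for 3D registration
}


def missing_packages(required=REQUIRED):
    """Return install names whose top-level import module cannot be found."""
    return [pkg for pkg, module in required.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("all packages found" if not missing else "missing: " + ", ".join(missing))
```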

Acknowledgements

We adapted some of the code from

  1. I²SB: Image-to-Image Schrödinger Bridge
