FewMedDiff是一个基于扩散模型的少样本医学图像生成框架,旨在解决医学图像领域数据稀缺的问题。通过创新的多模态融合、对比学习和生成模型技术,从有限的医学图像中生成高质量、多样化的图像,用于医学图像处理任务的数据增强。
-
多模态融合模块:整合不同医学成像模态(如MRI中的T1、T2、T1ce、FLAIR)的信息,生成包含丰富特征的融合表示。
-
双对比学习策略:
- 多视角对比学习:对同一图像的不同增强版本进行对比学习
- 多模态对比学习:学习不同模态之间的语义关系
-
RAP损失函数:综合考虑重建损失(Reconstruction)、对抗损失(Adversarial)和感知损失(Perceptual),优化生成图像质量。
-
多尺度判别器:在多个分辨率上同时进行判别,平衡全局结构一致性和局部细节真实性。
-
无条件扩散模型:基于去噪扩散概率模型(DDPM)的图像生成技术。
models/multimodal_fusion.py: 多模态融合模块models/contrastive_learning.py: 双对比学习模块models/diffusion.py: 扩散模型实现models/discriminator.py: 多尺度判别器datasets/dataset.py: 数据加载和预处理train.py: 训练脚本generate.py: 图像生成脚本config.py: 配置管理
pip install torch torchvision tqdm numpy matplotlib scikit-image
pip install nibabel SimpleITK # 用于医学图像处理训练过程分为两个阶段:对比学习预训练和扩散模型训练。
# 仅进行对比学习预训练
python train.py --stage contrastive
# 仅训练扩散模型
python train.py --stage diffusion
# 完整训练流程
python train.py --stage all可选参数:
--device: 指定计算设备 (默认: 可用的CUDA设备或CPU)--batch_size: 批处理大小--num_epochs: 训练轮数--resume: 从检查点恢复训练
使用训练好的模型生成医学图像:
# 生成图像
python generate.py --checkpoint checkpoints/diffusion_best.pth --num_samples 100 --mode generate
# 生成真实图像与生成图像的对比
python generate.py --checkpoint checkpoints/diffusion_best.pth --mode compare
# 评估生成图像质量
python generate.py --checkpoint checkpoints/diffusion_best.pth --mode evaluate可选参数:
--output_dir: 输出目录--batch_size: 批处理大小--seed: 随机种子--device: 计算设备
FewMedDiff支持以下医学图像数据集:
-
BraTS:脑肿瘤分割多模态MRI数据集,包含T1、T1ce、T2和FLAIR四种模态。
-
LUNA16:肺结节CT图像数据集,我们将同一CT图像转换为不同窗口设置(原始、肺窗、纵隔窗)作为多模态输入。
使用以下指标评估生成图像质量:
- FID (Fréchet Inception Distance):测量生成图像与真实图像的特征分布差异
- KID (Kernel Inception Distance):无偏估计的分布相似度度量
- IS (Inception Score):评估生成图像的质量和多样性
- SSIM (Structural Similarity Index):结构相似性指标
- PSNR (Peak Signal-to-Noise Ratio):峰值信噪比
生成的医学图像可用于各种医学图像处理任务的数据增强,如目标检测、分割等。我们的实验表明,使用生成图像增强训练集可以显著提高下游任务的性能,特别是在数据有限的情况下。
如果您在研究中使用了FewMedDiff,请引用我们的论文:
@article{fewmeddiff2024,
title={FewMedDiff: Diffusion Model for Medical Image Generation and Application in Data Augmentation for Object Detection Tasks},
author={Anonymous},
journal={Under Review},
year={2024}
}