Official Implementation of Diffusion Bridge Implicit Models (ICLR 2025) and Consistency Diffusion Bridge Models (NeurIPS 2024).
To install all packages in this codebase along with their dependencies, run

```shell
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install blobfile piq matplotlib opencv-python joblib lmdb scipy clean-fid easydict torchmetrics rich ipdb wandb
```
Please put (or link) the datasets under `assets/datasets/`.

- For Edges2Handbags, please follow the instructions from here. The resulting folder structure should be `assets/datasets/edges2handbags/train` and `assets/datasets/edges2handbags/val`.
- For DIODE, please download the training dataset and the data list from here. The resulting folder structure should be `assets/datasets/DIODE/train` and `assets/datasets/DIODE/data_list`.
- For ImageNet, please download the dataset from here. The resulting folder structure should be `assets/datasets/ImageNet/train` and `assets/datasets/ImageNet/val`.
We also provide automatic downloading scripts:

```shell
cd assets/datasets
bash download_extract_edges2handbags.sh
bash download_extract_DIODE.sh
bash download_extract_ImageNet.sh
```

After downloading, the DIODE dataset requires preprocessing:

```shell
python preprocess_depth.py
```
DBIM offers a suite of fast samplers tailored for Denoising Diffusion Bridge Models (DDBMs). We have cleaned the codebase to support a broad range of diffusion bridges, enabling unified training and sampling workflows. We also streamline deployment by replacing the cumbersome MPI-based distributed launcher with the more efficient and engineer-friendly `torchrun`.
Please put the downloaded checkpoints under `assets/ckpts/`.

For image translation, we directly adopt the pretrained checkpoints from DDBM:

- Edges2Handbags: `e2h_ema_0.9999_420000.pt`
- DIODE: `diode_ema_0.9999_440000.pt`

We remove the dependency on external packages such as `flash_attn` in this codebase, since their functionality is already supported natively by PyTorch. After downloading the two checkpoints above, please run `python preprocess_ckpt.py` to complete the conversion.

For image restoration:

- Center 128x128 inpainting on ImageNet 256x256: `imagenet256_inpaint_ema_0.9999_400000.pt`
To generate samples, run

```shell
bash scripts/sample.sh $DATASET_NAME $NFE $SAMPLER ($AUX)
```

- `$DATASET_NAME` can be chosen from `e2h`/`diode`/`imagenet_inpaint_center`.
- `$NFE` is the number of function evaluations, which is proportional to the sampling time.
- `$SAMPLER` can be chosen from `heun`/`dbim`/`dbim_high_order`.
  - `heun` is the vanilla sampler of DDBM, which alternates SDE and ODE steps. In this case, `$AUX` is not required.
  - `dbim` and `dbim_high_order` are our proposed samplers. When using `dbim`, `$AUX` corresponds to $\eta$, which controls the stochasticity level (a floating-point value in $[0,1]$). When using `dbim_high_order`, `$AUX` corresponds to the order (2 or 3).
- The samples will be saved to `workdir/`.
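For concreteness, a few illustrative invocations (the dataset and sampler names come from the list above; the NFE and $\eta$ values are arbitrary examples, not recommended settings):

```shell
# Deterministic DBIM (eta = 0) with 10 NFE on Edges2Handbags
bash scripts/sample.sh e2h 10 dbim 0.0

# Stochastic DBIM (eta = 1) with 50 NFE on DIODE
bash scripts/sample.sh diode 50 dbim 1.0

# Third-order DBIM on ImageNet inpainting; the order is passed as $AUX
bash scripts/sample.sh imagenet_inpaint_center 20 dbim_high_order 3

# Vanilla DDBM sampler; no $AUX argument
bash scripts/sample.sh e2h 100 heun
```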
Before evaluating the image translation results, please download the reference statistics from DDBM and put them under `assets/stats/`:

- Reference stats for Edges2Handbags: `edges2handbags_ref_64_data.npz`
- Reference stats for DIODE: `diode_ref_256_data.npz`
The evaluation proceeds automatically when you specify the same dataset and sampler arguments as in sampling:

```shell
bash scripts/evaluate.sh $DATASET_NAME $NFE $SAMPLER ($AUX)
```
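For example, to evaluate samples generated with deterministic DBIM at 10 NFE on Edges2Handbags (illustrative values matching the sampling example above):

```shell
bash scripts/evaluate.sh e2h 10 dbim 0.0
```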
We provide the script for training diffusion bridge models on the ImageNet 256x256 inpainting (center 128x128) task. As ImageNet 256x256 is a challenging dataset, we follow I2SB and initialize the network with the pretrained diffusion model ADM. To start training, run

```shell
python download_diffusion.py
bash scripts/train_bridge.sh
```

Training for 400k iterations takes around 2 weeks on 8xA100 GPUs. We recommend using multiple nodes and modifying `scripts/train_bridge.sh` accordingly.
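Since the launcher is `torchrun`, multi-node training typically only requires changing the launch flags. A hypothetical two-node sketch is below; the actual training entry point and its arguments are whatever `scripts/train_bridge.sh` currently invokes, so substitute them for the placeholder:

```shell
# Run on each of the 2 nodes; NODE_RANK is 0 on the master node and 1 on the other,
# and MASTER_ADDR is the IP of the rank-0 node. Port 29500 is an arbitrary choice.
torchrun --nnodes=2 --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR --master_port=29500 \
    --nproc_per_node=8 \
    <training entry point and arguments from scripts/train_bridge.sh>
```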
TODO
This codebase is built upon DDBM and I2SB.
If you find this method and/or code useful, please consider citing
```bibtex
@article{zheng2024diffusion,
  title={Diffusion Bridge Implicit Models},
  author={Zheng, Kaiwen and He, Guande and Chen, Jianfei and Bao, Fan and Zhu, Jun},
  journal={arXiv preprint arXiv:2405.15885},
  year={2024}
}
```
and
```bibtex
@article{he2024consistency,
  title={Consistency Diffusion Bridge Models},
  author={He, Guande and Zheng, Kaiwen and Chen, Jianfei and Bao, Fan and Zhu, Jun},
  journal={arXiv preprint arXiv:2410.22637},
  year={2024}
}
```