Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation
Causal Forcing significantly outperforms Self Forcing in both visual quality and motion dynamics, while keeping the same training budget and inference efficiency—enabling real-time, streaming video generation on a single RTX 4090.
Demo video: demo.mp4
The inference environment is identical to Self Forcing, so you can migrate directly using our configs and model.
conda create -n causal_forcing python=3.10 -y
conda activate causal_forcing
pip install -r requirements.txt
pip install git+https://github.com/openai/CLIP.git
pip install flash-attn --no-build-isolation
python setup.py develop
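Optionally, run a quick sanity check that the key dependencies import correctly before downloading models (a minimal sketch, not part of the repo):

```bash
# Verify that PyTorch, flash-attn, and CLIP are importable and CUDA is visible.
python -c "import torch, flash_attn, clip; print('CUDA available:', torch.cuda.is_available())"
```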
hf download Wan-AI/Wan2.1-T2V-1.3B --local-dir wan_models/Wan2.1-T2V-1.3B
hf download Wan-AI/Wan2.1-T2V-14B --local-dir wan_models/Wan2.1-T2V-14B
hf download zhuhz22/Causal-Forcing chunkwise/causal_forcing.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing framewise/causal_forcing.pt --local-dir checkpoints

We open-source both the frame-wise and chunk-wise models; the former is a setting that Self Forcing has chosen not to release.
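After the downloads above, the directory layout should look roughly as follows (inferred from the paths used by the commands in this README):

```
wan_models/
├── Wan2.1-T2V-1.3B/
└── Wan2.1-T2V-14B/
checkpoints/
├── framewise/causal_forcing.pt
└── chunkwise/causal_forcing.pt
```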
Frame-wise model (higher dynamic degree and more expressive):
python inference.py \
--config_path configs/causal_forcing_dmd_framewise.yaml \
--output_folder output/framewise \
--checkpoint_path checkpoints/framewise/causal_forcing.pt \
--data_path prompts/demos.txt \
--use_ema
# Note: this frame-wise config does not exist in Self Forcing; if you use its framework, migrate this config as well.
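To run the frame-wise model on your own prompts, pass your own text file via `--data_path`. A sketch (assuming one prompt per line, as in prompts/demos.txt; the file name my_prompts.txt is just an example):

```bash
# Write custom prompts, one per line (format assumed from prompts/demos.txt).
cat > prompts/my_prompts.txt << 'EOF'
A corgi running along a beach at sunset, cinematic lighting.
A timelapse of clouds drifting over a snow-capped mountain lake.
EOF
python inference.py \
    --config_path configs/causal_forcing_dmd_framewise.yaml \
    --output_folder output/my_prompts \
    --checkpoint_path checkpoints/framewise/causal_forcing.pt \
    --data_path prompts/my_prompts.txt \
    --use_ema
```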
Chunk-wise model (more stable):
python inference.py \
--config_path configs/causal_forcing_dmd_chunkwise.yaml \
--output_folder output/chunkwise \
--checkpoint_path checkpoints/chunkwise/causal_forcing.pt \
--data_path prompts/demos.txt

Stage 1: Autoregressive Diffusion Training (you can skip this stage by using our pretrained checkpoints)
First download the dataset (we provide a 6K toy dataset here):
hf download zhuhz22/Causal-Forcing-data --local-dir dataset
python utils/merge_and_get_clean.py

If the download gets stuck, press Ctrl+C and rerun the command to resume.
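For flaky connections, a small retry loop can save manual restarts, since hf download should skip files that are already complete and resume the rest (a sketch, not part of the repo):

```bash
# Keep retrying until the download finishes; completed files are not re-fetched.
until hf download zhuhz22/Causal-Forcing-data --local-dir dataset; do
    echo "Download interrupted, retrying in 10s..."
    sleep 10
done
```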
Then train the AR-diffusion model:
- Frame-wise:
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/ar_diffusion_tf_framewise.yaml \
    --logdir logs/ar_diffusion_framewise
- Chunk-wise:
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/ar_diffusion_tf_chunkwise.yaml \
    --logdir logs/ar_diffusion_chunkwise
We recommend training for at least 2K steps; more steps (e.g., 5-10K) lead to better performance.
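The commands above assume an 8-node setup where $MASTER_ADDR points to the rendezvous host. If you only have a single 8-GPU node, a single-node launch along the following lines should work, provided the per-GPU batch size in the config fits your hardware (a sketch, not something we have benchmarked):

```bash
# Single-node sketch: --standalone lets torchrun handle rendezvous locally.
torchrun --standalone --nproc_per_node=8 \
    train.py \
    --config_path configs/ar_diffusion_tf_framewise.yaml \
    --logdir logs/ar_diffusion_framewise
```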
Stage 2: Causal ODE Initialization (you can skip this stage by using our pretrained checkpoints)
If you have skipped Stage 1, you need to download the pretrained models:
hf download zhuhz22/Causal-Forcing framewise/ar_diffusion.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/ar_diffusion.pt --local-dir checkpoints

In this stage, first generate ODE paired data:
# for the frame-wise model
torchrun --nproc_per_node=8 \
get_causal_ode_data_framewise.py \
--generator_ckpt checkpoints/framewise/ar_diffusion.pt \
--rawdata_path dataset/clean_data \
--output_folder dataset/ODE6KCausal_framewise_latents
python utils/create_lmdb_iterative.py \
--data_path dataset/ODE6KCausal_framewise_latents \
--lmdb_path dataset/ODE6KCausal_framewise
# for the chunk-wise model
torchrun --nproc_per_node=8 \
get_causal_ode_data_chunkwise.py \
--generator_ckpt checkpoints/chunkwise/ar_diffusion.pt \
--rawdata_path dataset/clean_data \
--output_folder dataset/ODE6KCausal_chunkwise_latents
python utils/create_lmdb_iterative.py \
--data_path dataset/ODE6KCausal_chunkwise_latents \
--lmdb_path dataset/ODE6KCausal_chunkwise

Alternatively, you can directly download our prepared dataset (~300 GB):
hf download zhuhz22/Causal-Forcing-data --local-dir dataset
python utils/merge_lmdb.py

If the download gets stuck, press Ctrl+C and rerun the command to resume.
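Before training, you can sanity-check the ODE data (a sketch, assuming the frame-wise LMDB ends up at dataset/ODE6KCausal_framewise, as produced by the commands above):

```bash
# Print the number of entries in the ODE LMDB as a quick integrity check.
python -c "import lmdb; env = lmdb.open('dataset/ODE6KCausal_framewise', readonly=True, lock=False); print('entries:', env.stat()['entries'])"
```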
And then train ODE initialization models:
- Frame-wise:
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_ode_framewise.yaml \
    --logdir logs/causal_ode_framewise
- Chunk-wise:
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_ode_chunkwise.yaml \
    --logdir logs/causal_ode_chunkwise
We recommend training for at least 1K steps; more steps (e.g., 5-10K) lead to better performance.
Stage 3: DMD Training
This stage is compatible with Self Forcing training, so you can migrate seamlessly by using our configs and checkpoints.
First download the prompt dataset:
hf download gdhe17/Self-Forcing vidprom_filtered_extended.txt --local-dir prompts

If you have skipped Stage 2, you need to download the pretrained checkpoints:
hf download zhuhz22/Causal-Forcing framewise/causal_ode.pt --local-dir checkpoints
hf download zhuhz22/Causal-Forcing chunkwise/causal_ode.pt --local-dir checkpoints

And then train DMD models:
- Frame-wise model:
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_forcing_dmd_framewise.yaml \
    --logdir logs/causal_forcing_dmd_framewise
We recommend training for about 500 steps; more than 1K steps will reduce the dynamic degree.
- Chunk-wise model:
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
    --rdzv_backend=c10d \
    --rdzv_endpoint $MASTER_ADDR \
    train.py \
    --config_path configs/causal_forcing_dmd_chunkwise.yaml \
    --logdir logs/causal_forcing_dmd_chunkwise
We recommend training for 100-200 steps; more than 1K steps will reduce the dynamic degree.
These are the final models used to generate videos.
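To generate videos with a DMD checkpoint you trained yourself, point --checkpoint_path at it. For example (the path under logs/ is a hypothetical placeholder and depends on how your run saved checkpoints):

```bash
# Run inference with your own frame-wise DMD checkpoint (path is a placeholder).
python inference.py \
    --config_path configs/causal_forcing_dmd_framewise.yaml \
    --output_folder output/framewise_custom \
    --checkpoint_path logs/causal_forcing_dmd_framewise/<your_checkpoint>.pt \
    --data_path prompts/demos.txt \
    --use_ema
```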
This codebase is built on top of the open-source implementations of CausVid and Self Forcing, and the Wan2.1 repository.
If you find the method useful, please cite:
@article{zhu2026causalforcing,
  title={Autoregressive Diffusion Distillation Done Right for High-Quality Real-Time Interactive Video Generation},
  author={Zhu, Hongzhou and Zhao, Min and He, Guande and Li, Chongxuan and Zhu, Jun},
  journal={arXiv preprint},
  year={2026}
}