ConsisID for xdit (#405)

pkuhxy authored Dec 24, 2024
1 parent 1b589b7 commit 970d5ce

Showing 18 changed files with 1,491 additions and 19 deletions.
33 changes: 20 additions & 13 deletions README.md
@@ -1,7 +1,7 @@
<div align="center">
<!-- <h1>KTransformers</h1> -->
<p align="center">

<picture>
<img alt="xDiT" src="https://raw.githubusercontent.com/xdit-project/xdit_assets/main/XDiTlogo.png" width="50%">

@@ -22,6 +22,7 @@
- [📈 Performance](#perf)
- [HunyuanVideo](#perf_hunyuanvideo)
- [Mochi-1](#perf_mochi1)
- [ConsisID](#perf_consisid)
- [CogVideoX](#perf_cogvideox)
- [Flux.1](#perf_flux)
- [HunyuanDiT](#perf_hunyuandit)
@@ -94,6 +95,7 @@ Furthermore, xDiT incorporates optimization techniques from [DiTFastAttn](https:

<h2 id="updates">📢 Updates</h2>

* 🎉**December 24, 2024**: xDiT supports [ConsisID-Preview](https://github.com/PKU-YuanGroup/ConsisID) and achieves a 3.21x speedup compared to the official implementation! The inference scripts are [examples/consisid_example.py](examples/consisid_example.py) and [examples/consisid_usp_example.py](examples/consisid_usp_example.py).
* 🎉**December 7, 2024**: xDiT is the official parallel inference engine for [HunyuanVideo](https://github.com/Tencent/HunyuanVideo), reducing the 5-sec video generation latency from 31 minutes to 5 minutes on 8xH100!
* 🎉**November 28, 2024**: xDiT achieves 1.6 sec end-to-end latency for 28-step [Flux.1-Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) inference on 4xH100!
* 🎉**November 20, 2024**: xDiT supports [CogVideoX-1.5](https://huggingface.co/THUDM/CogVideoX1.5-5B) and achieves a 6.12x speedup compared to the implementation in diffusers!
@@ -117,11 +119,12 @@ Furthermore, xDiT incorporates optimization techniques from [DiTFastAttn](https:

| Model Name | CFG | SP | PipeFusion |
| --- | --- | --- | --- |
| [🎬 HunyuanVideo](https://github.com/Tencent/HunyuanVideo) | NA | ✔️ ||
| [🎬 ConsisID-Preview](https://github.com/PKU-YuanGroup/ConsisID) | ✔️ | ✔️ ||
| [🎬 CogVideoX1.5](https://huggingface.co/THUDM/CogVideoX1.5-5B) | ✔️ | ✔️ ||
| [🎬 Mochi-1](https://github.com/xdit-project/mochi-xdit) | ✔️ | ✔️ ||
| [🎬 CogVideoX](https://huggingface.co/THUDM/CogVideoX-2b) | ✔️ | ✔️ ||
| [🎬 Latte](https://huggingface.co/maxin-cn/Latte-1) || ✔️ ||
| [🔵 HunyuanDiT-v1.2-Diffusers](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) | ✔️ | ✔️ | ✔️ |
| [🟠 Flux](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | NA | ✔️ | ✔️ |
| [🔴 PixArt-Sigma](https://huggingface.co/PixArt-alpha/PixArt-Sigma-XL-2-1024-MS) | ✔️ | ✔️ | ✔️ |
@@ -163,33 +166,37 @@ Currently, if you need the parallel version of ComfyUI, please fill in this [app

1. [HunyuanVideo Performance Report](./docs/performance/hunyuanvideo.md)

<h3 id="perf_consisid">ConsisID-Preview</h3>

2. [ConsisID Performance Report](./docs/performance/ConsisID.md)

<h3 id="perf_cogvideox">Mochi1</h3>

3. [mochi1-xdit: Reducing the Inference Latency by 3.54x Compared to the Official Open Source Implementation!](https://github.com/xdit-project/mochi-xdit)

<h3 id="perf_cogvideox">CogVideo</h3>

4. [CogVideo Performance Report](./docs/performance/cogvideo.md)

<h3 id="perf_flux">Flux.1</h3>

5. [Flux Performance Report](./docs/performance/flux.md)

<h3 id="perf_latte">Latte</h3>

6. [Latte Performance Report](./docs/performance/latte.md)

<h3 id="perf_hunyuandit">HunyuanDiT</h3>

7. [HunyuanDiT Performance Report](./docs/performance/hunyuandit.md)

<h3 id="perf_sd3">SD3</h3>

8. [Stable Diffusion 3 Performance Report](./docs/performance/sd3.md)

<h3 id="perf_pixart">Pixart</h3>

9. [Pixart-Alpha Performance Report (legacy)](./docs/performance/pixart_alpha_legacy.md)


<h2 id="QuickStart">🚀 QuickStart</h2>
29 changes: 29 additions & 0 deletions docs/performance/ConsisID.md
@@ -0,0 +1,29 @@
## ConsisID Performance Report

[ConsisID](https://github.com/PKU-YuanGroup/ConsisID) is an identity-preserving text-to-video generation model that keeps faces consistent in the generated video through frequency decomposition. xDiT currently integrates USP techniques (Ulysses Attention and Ring Attention) together with CFG parallelization to accelerate inference, while PipeFusion support is still in progress. We conducted an in-depth analysis comparing single-GPU ConsisID inference, based on the diffusers library, with our parallelized version when generating a 49-frame (6-second) 720x480 video. Because the parallelization methods can be combined flexibly, different configurations yield different performance. In this report, we systematically evaluate xDiT's acceleration across 1 to 6 Nvidia H100 GPUs.

As shown in the table, the ConsisID model achieves a significant reduction in inference latency with Ulysses Attention, Ring Attention, or Classifier-Free Guidance (CFG) parallelization. Notably, CFG parallelization outperforms the other two techniques due to its lower communication overhead. By combining sequence parallelization and CFG parallelization, inference efficiency was further improved. With increased parallelism, inference latency continued to decrease. Under the optimal configuration, xDiT achieved a 3.21× speedup over single-GPU inference, reducing iteration time to just 0.72 seconds. For the default 50 iterations of ConsisID, this enables end-to-end generation of 49 frames in 35 seconds, with a GPU memory usage of 40 GB.
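
As a quick sanity check, the arithmetic below reproduces those headline numbers from the raw latencies in the table that follows (a minimal sketch, not part of the benchmark code):

```python
# Reproduce the reported speedup and per-iteration time from the table below.
baseline = 114.87  # seconds, 1x H100
best = 35.78       # seconds, 6x H100 (ulysses=1, ring=3, cfg=2)
steps = 50         # ConsisID's default number of inference steps

print(f"speedup:       {baseline / best:.2f}x")  # -> 3.21x
print(f"per iteration: {best / steps:.2f} s")    # -> 0.72 s
```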

### 720x480 Resolution (49 frames, 50 steps)


| N-GPUs | Ulysses Degree | Ring Degree | CFG Degree | Latency |
| :----: | :------------: | :---------: | :--------: | :-----: |
| 6 | 2 | 3 | 1 | 44.89s |
| 6 | 3 | 2 | 1 | 44.24s |
| 6 | 1 | 3 | 2 | 35.78s |
| 6 | 3 | 1 | 2 | 38.35s |
| 4 | 2 | 1 | 2 | 41.37s |
| 4 | 1 | 2 | 2 | 40.68s |
| 3 | 3 | 1 | 1 | 53.57s |
| 3 | 1 | 3 | 1 | 55.51s |
| 2 | 1 | 2 | 1 | 70.19s |
| 2 | 2 | 1 | 1 | 76.56s |
| 2 | 1 | 1 | 2 | 59.72s |
| 1 | 1 | 1 | 1 | 114.87s |
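
One relation is worth making explicit, inferred from the rows above rather than stated outright: the GPU count in every configuration is the product of the three parallel degrees. A minimal check in Python:

```python
# Every row satisfies: N-GPUs == Ulysses Degree * Ring Degree * CFG Degree.
rows = [
    (6, 2, 3, 1), (6, 3, 2, 1), (6, 1, 3, 2), (6, 3, 1, 2),
    (4, 2, 1, 2), (4, 1, 2, 2), (3, 3, 1, 1), (3, 1, 3, 1),
    (2, 1, 2, 1), (2, 2, 1, 1), (2, 1, 1, 2), (1, 1, 1, 1),
]
assert all(gpus == u * r * c for gpus, u, r, c in rows)
```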

## Resources

Learn more about ConsisID with the following resources.
- A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features.
- The research paper [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details.
23 changes: 23 additions & 0 deletions docs/performance/ConsisID_zh.md
@@ -0,0 +1,23 @@
## ConsisID Performance Report

[ConsisID](https://github.com/PKU-YuanGroup/ConsisID) is an identity-preserving text-to-video generation model that keeps faces consistent in the generated video through frequency decomposition. xDiT currently integrates USP techniques (Ulysses Attention and Ring Attention) together with CFG parallelization to accelerate inference, while PipeFusion support is still in progress. We conducted an in-depth analysis comparing single-GPU ConsisID inference, based on the diffusers library, with our parallelized version when generating a 49-frame (6-second) 720x480 video. Because the parallelization methods can be combined freely, different configurations yield different performance. In this report, we systematically benchmark xDiT's acceleration on 1 to 6 Nvidia H100 GPUs.

As the table shows, ConsisID achieves a significant reduction in inference latency with Ulysses Attention, Ring Attention, or Classifier-Free Guidance (CFG) parallelization. Notably, CFG parallelization outperforms the other two techniques thanks to its lower communication overhead. Combining sequence parallelization with CFG parallelization improves inference efficiency further, and latency keeps dropping as parallelism increases. Under the optimal configuration, xDiT achieves a 3.21x speedup over single-GPU inference, bringing each iteration down to 0.72 seconds. With ConsisID's default 50 iterations, a 49-frame video can be generated end to end in 35 seconds, using 40 GB of GPU memory.

### 720x480 Resolution (49 frames, 50 steps)


| N-GPUs | Ulysses Degree | Ring Degree | CFG Degree | Latency |
| :----: | :------------: | :---------: | :--------: | :-----: |
| 6 | 2 | 3 | 1 | 44.89s |
| 6 | 3 | 2 | 1 | 44.24s |
| 6 | 1 | 3 | 2 | 35.78s |
| 6 | 3 | 1 | 2 | 38.35s |
| 4 | 2 | 1 | 2 | 41.37s |
| 4 | 1 | 2 | 2 | 40.68s |
| 3 | 3 | 1 | 1 | 53.57s |
| 3 | 1 | 3 | 1 | 55.51s |
| 2 | 1 | 2 | 1 | 70.19s |
| 2 | 2 | 1 | 1 | 76.56s |
| 2 | 1 | 1 | 2 | 59.72s |
| 1 | 1 | 1 | 1 | 114.87s |
3 changes: 2 additions & 1 deletion examples/cogvideox_example.py
@@ -22,7 +22,8 @@ def main():

    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
    assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for CogVideo"

    pipe = xFuserCogVideoXPipeline.from_pretrained(
119 changes: 119 additions & 0 deletions examples/consisid_example.py
@@ -0,0 +1,119 @@
import logging
import os
import time
import torch
import torch.distributed

from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download

from xfuser import xFuserConsisIDPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
    get_world_group,
    get_runtime_state,
    is_dp_last_group,
)
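
# Hypothetical launch command (not from the original file): flag names follow the
# xFuserArgs CLI parsed below; the model path, prompt, and image path are placeholders.
#
#   torchrun --nproc_per_node=6 examples/consisid_example.py \
#       --model BestWishYsh/ConsisID-preview \
#       --ulysses_degree 1 --ring_degree 3 --use_cfg_parallel \
#       --height 480 --width 720 --num_frames 49 --num_inference_steps 50 \
#       --prompt "..." --img_file_path path/to/face.png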


def main():
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)

    engine_config, input_config = engine_args.create_config()
    local_rank = get_world_group().local_rank

    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
    assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for ConsisID"

    # 1. Prepare all the Checkpoints
    if not os.path.exists(engine_config.model_config.model):
        print("Base Model not found, downloading from Hugging Face...")
        snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=engine_config.model_config.model)
    else:
        print(f"Base Model already exists in {engine_config.model_config.model}, skipping download.")

    # 2. Load Pipeline
    device = torch.device(f"cuda:{local_rank}")
    pipe = xFuserConsisIDPipeline.from_pretrained(
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )
    if args.enable_sequential_cpu_offload:
        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} sequential CPU offload enabled")
    elif args.enable_model_cpu_offload:
        pipe.enable_model_cpu_offload(gpu_id=local_rank)
        logging.info(f"rank {local_rank} model CPU offload enabled")
    else:
        pipe = pipe.to(device)

    face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
        prepare_face_models(engine_config.model_config.model, device=device, dtype=torch.bfloat16)
    )

    if args.enable_tiling:
        pipe.vae.enable_tiling()

    if args.enable_slicing:
        pipe.vae.enable_slicing()

    # 3. Prepare Model Input
    id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
        face_helper_1,
        face_clip_model,
        face_helper_2,
        eva_transform_mean,
        eva_transform_std,
        face_main_model,
        device,
        torch.bfloat16,
        input_config.img_file_path,
        is_align_face=True,
    )
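    # id_cond and id_vit_hidden are the identity embeddings extracted from the
    # reference image; image and face_kps are the aligned face and its keypoints,
    # passed to the pipeline call below as identity conditioning.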

    # 4. Generate Identity-Preserving Video
    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    output = pipe(
        image=image,
        prompt=input_config.prompt[0],
        id_vit_hidden=id_vit_hidden,
        id_cond=id_cond,
        kps_cond=face_kps,
        height=input_config.height,
        width=input_config.width,
        num_frames=input_config.num_frames,
        num_inference_steps=input_config.num_inference_steps,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
        guidance_scale=6.0,
        use_dynamic_cfg=False,
    ).frames[0]

    end_time = time.time()
    elapsed_time = end_time - start_time
    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

    parallel_info = (
        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
        f"tp{engine_args.tensor_parallel_degree}_"
        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
    )
    if is_dp_last_group():
        resolution = f"{input_config.width}x{input_config.height}"
        output_filename = f"results/consisid_{parallel_info}_{resolution}.mp4"
        export_to_video(output, output_filename, fps=8)
        print(f"output saved to {output_filename}")

    if get_world_group().rank == get_world_group().world_size - 1:
        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
    get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
    main()