From bc18d27864e7bcfc3a12913f0a73d097da84d7a3 Mon Sep 17 00:00:00 2001 From: chengzeyi Date: Mon, 23 Dec 2024 17:14:56 +0800 Subject: [PATCH] make hunyuan video work with more resolutions and update the performance table --- docs/performance/hunyuanvideo.md | 3 +- examples/hunyuan_video_usp_example.py | 42 ++++++++++++++++-------- xfuser/core/distributed/runtime_state.py | 9 ++++- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/docs/performance/hunyuanvideo.md b/docs/performance/hunyuanvideo.md index e2149238..0e04d275 100644 --- a/docs/performance/hunyuanvideo.md +++ b/docs/performance/hunyuanvideo.md @@ -10,7 +10,7 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil |----------|--------|---------|---------|---------| | H100 | 1,904.08 | 925.04 | 514.08 | 337.58 | | H20 | 6,639.17 | 3,400.55 | 1,762.86 | 940.97 | -| L20 | 6,043.88 | | | | +| L20 | 6,043.88 | 3,271.44 | 2,080.05 | | @@ -22,5 +22,6 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil |----------|--------|---------|---------|---------| | H100 | 1,735.01 | 934.09 | 645.45 | 367.02 | | H20 | 6,621.46 | 3,400.55 | 2,310.48 | 1,214.67 | +| L20 | 6,039.08 | 3,260.62 | 2,070.96 | | diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py index a276a388..0360bf22 100644 --- a/examples/hunyuan_video_usp_example.py +++ b/examples/hunyuan_video_usp_example.py @@ -1,6 +1,6 @@ # from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_hunyuan_video.py import functools -from typing import Any, Dict, Union +from typing import Any, Dict, Union, Optional import logging import time @@ -8,6 +8,7 @@ from diffusers import DiffusionPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import scale_lora_layers, unscale_lora_layers, USE_PEFT_BACKEND from diffusers.utils import export_to_video from xfuser import xFuserArgs @@ -45,8 +46,22 @@ def new_forward( encoder_attention_mask: torch.Tensor, pooled_projections: torch.Tensor, guidance: torch.Tensor = None, + attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logging.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.") + batch_size, num_channels, num_frames, height, width = hidden_states.shape assert batch_size % get_classifier_free_guidance_world_size( @@ -68,13 +83,14 @@ def new_forward( encoder_attention_mask) hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1) + hidden_states = hidden_states.flatten(1, 3) + hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(), dim=0)[get_classifier_free_guidance_rank()] hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(), - dim=2)[get_sequence_parallel_rank()] - hidden_states = hidden_states.flatten(1, 3) + dim=-2)[get_sequence_parallel_rank()] encoder_attention_mask = encoder_attention_mask[0].to(torch.bool) encoder_hidden_states_indices = torch.arange( @@ -103,11 +119,7 @@ def new_forward( freqs_cos, freqs_sin = image_rotary_emb def get_rotary_emb_chunk(freqs): - dim_thw = freqs.shape[-1] - freqs = freqs.reshape(num_frames, -1, dim_thw) - freqs = freqs.chunk(get_sequence_parallel_world_size(), dim=-2)[ - get_sequence_parallel_rank()] - freqs = freqs.reshape(-1, dim_thw) + freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()] return freqs freqs_cos = get_rotary_emb_chunk(freqs_cos) @@ -166,17 +178,21 @@ def custom_forward(*inputs): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch_size // get_classifier_free_guidance_world_size(), + hidden_states = get_sp_group().all_gather(hidden_states, dim=-2) + hidden_states = get_cfg_group().all_gather(hidden_states, dim=0) + + hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, - post_patch_height // get_sequence_parallel_world_size(), + post_patch_height, post_patch_width, -1, p_t, p, p) - hidden_states = get_sp_group().all_gather(hidden_states, dim=2) - hidden_states = get_cfg_group().all_gather(hidden_states, dim=0) - hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7) hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + if not return_dict: return (hidden_states, ) diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py index 7de0606f..dd8ec07d 100644 --- a/xfuser/core/distributed/runtime_state.py +++ b/xfuser/core/distributed/runtime_state.py @@ -103,7 +103,12 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): pipeline=pipeline, parallel_config=config.parallel_config ) self.cogvideox = False + self.hunyuan_video = False if pipeline.__class__.__name__.startswith(("CogVideoX", "HunyuanVideo")): + if pipeline.__class__.__name__.startswith("CogVideoX"): + self.cogvideox = True + else: + self.hunyuan_video = True self._set_cogvideox_parameters( vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial, vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal, @@ -194,7 +199,6 @@ def _set_cogvideox_parameters( self.backbone_patch_size = backbone_patch_size self.backbone_inner_dim = backbone_inner_dim self.backbone_in_channel = backbone_in_channel - self.cogvideox = True def set_patched_mode(self, patch_mode: bool): self.patch_mode = patch_mode @@ -259,6 +263,9 @@ def _video_input_size_change( self.input_config.batch_size = batch_size or self.input_config.batch_size if self.cogvideox: self._calc_cogvideox_patches_metadata() + elif self.hunyuan_video: + # TODO: implement the hunyuan video patches metadata + pass else: self._calc_patches_metadata() self._reset_recv_buffer()