Skip to content

Commit

Permalink
feat: complete sequence parallelism for CogVideo (xdit-project#285)
Browse files: browse the repository at this point in the history
  • Loading branch information
xibosun authored and feifeibear committed Oct 25, 2024
1 parent c680956 commit 8f968c0
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 41 deletions.
2 changes: 1 addition & 1 deletion examples/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ declare -A MODEL_CONFIGS=(
["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
["Flux"]="flux_example.py /cfs/dit/FLUX.1-schnell 4"
["HunyuanDiT"]="hunyuandit_example.py /mnt/models/SD/HunyuanDiT-v1.2-Diffusers 50"
["CogVideoX"]="cogvideox_example.py /cfs/dit/CogVideoX-2b 1"
["CogVideoX"]="cogvideox_example.py /cfs/dit/CogVideoX-2b 20"
)

if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
Expand Down
60 changes: 42 additions & 18 deletions xfuser/model_executor/layers/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
super().__init__(
module=patch_embedding,
)
self.module: CogVideoXPatchEmbed

def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
r"""
Expand All @@ -124,6 +125,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
image_embeds (`torch.Tensor`):
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
"""
sum_height = (
get_runtime_state().input_config.height
// get_runtime_state().vae_scale_factor_spatial
)
text_embeds = self.text_proj(text_embeds)
batch, num_frames, channels, height, width = image_embeds.shape

Expand All @@ -133,28 +138,47 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]

if get_runtime_state().patch_mode:
start, end = get_runtime_state().pp_patches_token_start_end_idx_global[
get_runtime_state().pipeline_patch_idx
]
image_embeds = image_embeds[
:,
start:end,
:,
]
else:
image_embeds_list = [
image_embeds[

if self.use_positional_embeddings or self.use_learned_positional_embeddings:
if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != sum_height):
raise ValueError(
"It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
"If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
)

pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1

if (
self.sample_height != sum_height
or self.sample_width != width
or self.sample_frames != pre_time_compression_frames
):
pos_embedding = self._get_positional_embeddings(sum_height, width, pre_time_compression_frames)
pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
else:
pos_embedding = self.pos_embedding

# extract the image part of the positional embedding
pos_embedding = pos_embedding[:, self.max_text_seq_length :]

# slice the positional embedding
post_patch_height = sum_height // self.patch_size
post_patch_width = width // self.patch_size
post_time_compression_frames = (pre_time_compression_frames - 1) // self.temporal_compression_ratio + 1

pos_embed_list = [
pos_embedding[
:,
get_runtime_state()
.pp_patches_token_start_end_idx_global[i][0] : get_runtime_state()
.pp_patches_token_start_end_idx_global[i][1],
post_patch_height * post_patch_width * i + get_runtime_state().pp_patches_token_start_end_idx_global[0][0]:
post_patch_height * post_patch_width * i + get_runtime_state().pp_patches_token_start_end_idx_global[0][1],
:,
]
for i in range(get_runtime_state().num_pipeline_patch)
for i in range(post_time_compression_frames)
]
image_embeds = torch.cat(image_embeds_list, dim=1)
pos_embedding = torch.cat(pos_embed_list, dim=1)

image_embeds = image_embeds + pos_embedding

embeds = torch.cat(
[text_embeds, image_embeds], dim=1
).contiguous() # [batch, seq_length + num_frames x height x width, channels]
Expand Down
2 changes: 1 addition & 1 deletion (file path not captured)
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
):
super().__init__(
transformer=transformer,
submodule_classes_to_wrap=[nn.Conv2d],
submodule_classes_to_wrap=[nn.Conv2d, CogVideoXPatchEmbed],
submodule_name_to_wrap=["attn1"]
)

Expand Down
35 changes: 14 additions & 21 deletions xfuser/model_executor/pipelines/pipeline_cogvideox.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,7 @@ def __call__(
max_sequence_length=max_sequence_length,
device=device,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_embeds = self._process_cfg_split_batch_latte(prompt_embeds, negative_prompt_embeds)

# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
Expand Down Expand Up @@ -272,9 +271,11 @@ def __call__(
if self.interrupt:
continue

latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
if do_classifier_free_guidance:
latent_model_input = torch.cat(
[latents] * (2 // get_classifier_free_guidance_world_size())
)

latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
Expand All @@ -295,21 +296,15 @@ def __call__(
# perform guidance
if use_dynamic_cfg:
self._guidance_scale = 1 + guidance_scale * (
(
1
- math.cos(
math.pi
* (
(num_inference_steps - t.item())
/ num_inference_steps
)
** 5.0
)
)
/ 2
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
if get_classifier_free_guidance_world_size() == 1:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
elif get_classifier_free_guidance_world_size() == 2:
noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
noise_pred, separate_tensors=True
)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
Expand Down Expand Up @@ -344,9 +339,7 @@ def __call__(
"negative_prompt_embeds", negative_prompt_embeds
)

if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()

if get_sequence_parallel_world_size() > 1:
Expand Down

0 comments on commit 8f968c0

Please sign in to comment.