diff --git a/examples/run.sh b/examples/run.sh
index ca96f394..e71744bd 100644
--- a/examples/run.sh
+++ b/examples/run.sh
@@ -28,7 +28,7 @@ declare -A MODEL_CONFIGS=(
     ["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
     ["Flux"]="flux_example.py /cfs/dit/FLUX.1-schnell 4"
     ["HunyuanDiT"]="hunyuandit_example.py /mnt/models/SD/HunyuanDiT-v1.2-Diffusers 50"
-    ["CogVideoX"]="cogvideox_example.py /cfs/dit/CogVideoX-2b 1"
+    ["CogVideoX"]="cogvideox_example.py /cfs/dit/CogVideoX-2b 20"
 )
 
 if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
diff --git a/xfuser/model_executor/layers/embeddings.py b/xfuser/model_executor/layers/embeddings.py
index 6f6d67dc..5d8e05ac 100644
--- a/xfuser/model_executor/layers/embeddings.py
+++ b/xfuser/model_executor/layers/embeddings.py
@@ -115,6 +115,7 @@ def __init__(
         super().__init__(
             module=patch_embedding,
         )
+        self.module: CogVideoXPatchEmbed
 
     def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         r"""
@@ -124,6 +125,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
             image_embeds (`torch.Tensor`):
                 Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
         """
+        sum_height = (
+            get_runtime_state().input_config.height
+            // get_runtime_state().vae_scale_factor_spatial
+        )
         text_embeds = self.text_proj(text_embeds)
 
         batch, num_frames, channels, height, width = image_embeds.shape
@@ -133,28 +138,47 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
         image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
         image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
-
-        if get_runtime_state().patch_mode:
-            start, end = get_runtime_state().pp_patches_token_start_end_idx_global[
-                get_runtime_state().pipeline_patch_idx
-            ]
-            image_embeds = image_embeds[
-                :,
-                start:end,
-                :,
-            ]
-        else:
-            image_embeds_list = [
-                image_embeds[
+
+        if self.use_positional_embeddings or self.use_learned_positional_embeddings:
+            if self.use_learned_positional_embeddings and (self.sample_width != width or self.sample_height != sum_height):
+                raise ValueError(
+                    "It is currently not possible to generate videos at a different resolution that the defaults. This should only be the case with 'THUDM/CogVideoX-5b-I2V'."
+                    "If you think this is incorrect, please open an issue at https://github.com/huggingface/diffusers/issues."
+                )
+
+            pre_time_compression_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
+
+            if (
+                self.sample_height != sum_height
+                or self.sample_width != width
+                or self.sample_frames != pre_time_compression_frames
+            ):
+                pos_embedding = self._get_positional_embeddings(sum_height, width, pre_time_compression_frames)
+                pos_embedding = pos_embedding.to(image_embeds.device, dtype=image_embeds.dtype)
+            else:
+                pos_embedding = self.pos_embedding
+
+            # extract the image part of the positional embedding
+            pos_embedding = pos_embedding[:, self.max_text_seq_length :]
+
+            # slice the positional embedding
+            post_patch_height = sum_height // self.patch_size
+            post_patch_width = width // self.patch_size
+            post_time_compression_frames = (pre_time_compression_frames - 1) // self.temporal_compression_ratio + 1
+
+            pos_embed_list = [
+                pos_embedding[
                     :,
-                    get_runtime_state()
-                    .pp_patches_token_start_end_idx_global[i][0] : get_runtime_state()
-                    .pp_patches_token_start_end_idx_global[i][1],
+                    post_patch_height * post_patch_width * i + get_runtime_state().pp_patches_token_start_end_idx_global[0][0]:
+                    post_patch_height * post_patch_width * i + get_runtime_state().pp_patches_token_start_end_idx_global[0][1],
                     :,
                 ]
-                for i in range(get_runtime_state().num_pipeline_patch)
+                for i in range(post_time_compression_frames)
             ]
-            image_embeds = torch.cat(image_embeds_list, dim=1)
+            pos_embedding = torch.cat(pos_embed_list, dim=1)
+
+            image_embeds = image_embeds + pos_embedding
+
         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
         ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
diff --git a/xfuser/model_executor/models/transformers/cogvideox_transformer_3d.py b/xfuser/model_executor/models/transformers/cogvideox_transformer_3d.py
index 48793152..51a74a66 100644
--- a/xfuser/model_executor/models/transformers/cogvideox_transformer_3d.py
+++ b/xfuser/model_executor/models/transformers/cogvideox_transformer_3d.py
@@ -41,7 +41,7 @@ def __init__(
     ):
         super().__init__(
             transformer=transformer,
-            submodule_classes_to_wrap=[nn.Conv2d],
+            submodule_classes_to_wrap=[nn.Conv2d, CogVideoXPatchEmbed],
             submodule_name_to_wrap=["attn1"]
         )
 
diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py
index 4ca79c07..4cb2be6f 100644
--- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py
+++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py
@@ -226,8 +226,7 @@ def __call__(
             max_sequence_length=max_sequence_length,
             device=device,
         )
-        if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        prompt_embeds = self._process_cfg_split_batch_latte(prompt_embeds, negative_prompt_embeds)
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
@@ -272,9 +271,11 @@
                 if self.interrupt:
                     continue
 
-                latent_model_input = (
-                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                )
+                if do_classifier_free_guidance:
+                    latent_model_input = torch.cat(
+                        [latents] * (2 // get_classifier_free_guidance_world_size())
+                    )
+
                 latent_model_input = self.scheduler.scale_model_input(
                     latent_model_input, t
                 )
@@ -295,21 +296,15 @@
                 # perform guidance
                 if use_dynamic_cfg:
                     self._guidance_scale = 1 + guidance_scale * (
-                        (
-                            1
-                            - math.cos(
-                                math.pi
-                                * (
-                                    (num_inference_steps - t.item())
-                                    / num_inference_steps
-                                )
-                                ** 5.0
-                            )
-                        )
-                        / 2
+                        (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
                     )
                 if do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    if get_classifier_free_guidance_world_size() == 1:
+                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    elif get_classifier_free_guidance_world_size() == 2:
+                        noise_pred_uncond, noise_pred_text = get_cfg_group().all_gather(
+                            noise_pred, separate_tensors=True
+                        )
                     noise_pred = noise_pred_uncond + self.guidance_scale * (
                         noise_pred_text - noise_pred_uncond
                     )
@@ -344,9 +339,7 @@
                         "negative_prompt_embeds", negative_prompt_embeds
                     )
 
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
-                ):
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
         if get_sequence_parallel_world_size() > 1: