
Commit

scatter among height rather than time
chengzeyi committed Dec 20, 2024
1 parent dea542e commit aedab06
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions examples/hunyuan_video_usp_example.py
@@ -67,6 +67,15 @@ def new_forward(
     timestep,
     encoder_attention_mask)

+hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1)
+hidden_states = torch.chunk(hidden_states,
+    get_classifier_free_guidance_world_size(),
+    dim=0)[get_classifier_free_guidance_rank()]
+hidden_states = torch.chunk(hidden_states,
+    get_sequence_parallel_world_size(),
+    dim=2)[get_sequence_parallel_rank()]
+hidden_states = hidden_states.flatten(1, 3)
+
 encoder_attention_mask = encoder_attention_mask[0].to(torch.bool)
 encoder_hidden_states_indices = torch.arange(
     encoder_hidden_states.shape[1],
@@ -81,12 +90,6 @@ def new_forward(
 else:
     get_runtime_state().split_text_embed_in_sp = True

-hidden_states = torch.chunk(hidden_states,
-    get_classifier_free_guidance_world_size(),
-    dim=0)[get_classifier_free_guidance_rank()]
-hidden_states = torch.chunk(hidden_states,
-    get_sequence_parallel_world_size(),
-    dim=-2)[get_sequence_parallel_rank()]
 encoder_hidden_states = torch.chunk(
     encoder_hidden_states,
     get_classifier_free_guidance_world_size(),
@@ -100,9 +103,11 @@ def new_forward(
 freqs_cos, freqs_sin = image_rotary_emb

 def get_rotary_emb_chunk(freqs):
-    freqs = torch.chunk(freqs,
-        get_sequence_parallel_world_size(),
-        dim=0)[get_sequence_parallel_rank()]
+    dim_thw = freqs.shape[-1]
+    freqs = freqs.reshape(num_frames, -1, dim_thw)
+    freqs = freqs.chunk(get_sequence_parallel_world_size(), dim=-2)[
+        get_sequence_parallel_rank()]
+    freqs = freqs.reshape(-1, dim_thw)
     return freqs

 freqs_cos = get_rotary_emb_chunk(freqs_cos)
@@ -161,13 +166,14 @@ def custom_forward(*inputs):
 hidden_states = self.norm_out(hidden_states, temb)
 hidden_states = self.proj_out(hidden_states)

-hidden_states = get_sp_group().all_gather(hidden_states, dim=-2)
-hidden_states = get_cfg_group().all_gather(hidden_states, dim=0)
-
-hidden_states = hidden_states.reshape(batch_size,
+hidden_states = hidden_states.reshape(batch_size // get_classifier_free_guidance_world_size(),
     post_patch_num_frames,
-    post_patch_height,
+    post_patch_height // get_sequence_parallel_world_size(),
     post_patch_width, -1, p_t, p, p)
+
+hidden_states = get_sp_group().all_gather(hidden_states, dim=2)
+hidden_states = get_cfg_group().all_gather(hidden_states, dim=0)
+
 hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
 hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

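The first hunk above is the heart of the change: the flat token sequence is folded back into its (batch, frames, height, width, channels) grid and scattered across sequence-parallel ranks along the height axis (dim=2) instead of along time. Below is a standalone sketch of that scatter step, not taken from the commit: the tensor sizes and the cfg/sp world sizes and ranks are made up, and plain torch.chunk indexing stands in for the xfuser helper calls used in the example.

import torch

# Hypothetical sizes: the batch holds cond + uncond so CFG parallelism can split it,
# and height must be divisible by the sequence-parallel world size.
batch_size = 2
num_frames, height, width, channels = 8, 16, 28, 64
cfg_world_size, cfg_rank = 2, 0
sp_world_size, sp_rank = 4, 1

# Flat token sequence as produced by the patch embedding: (B, T*H*W, C).
hidden_states = torch.randn(batch_size, num_frames * height * width, channels)

# Recover the (B, T, H, W, C) grid, split the batch across CFG ranks and the
# height axis across sequence-parallel ranks, then flatten back to a sequence.
hidden_states = hidden_states.reshape(batch_size, num_frames, height, width, -1)
hidden_states = hidden_states.chunk(cfg_world_size, dim=0)[cfg_rank]
hidden_states = hidden_states.chunk(sp_world_size, dim=2)[sp_rank]
hidden_states = hidden_states.flatten(1, 3)

print(hidden_states.shape)  # torch.Size([1, 896, 64]): 8 frames * 4 local rows * 28 columns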

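The get_rotary_emb_chunk hunk keeps the rotary frequencies aligned with that height split: viewed as (num_frames, height*width, dim), the flattened spatial axis is laid out height-major, so a contiguous 1/sp_world_size chunk of it is exactly height // sp_world_size rows of width. A small self-contained check of that equivalence, again with hypothetical sizes in place of the runtime state:

import torch

num_frames, height, width, dim_thw = 8, 16, 28, 64
sp_world_size, sp_rank = 4, 1

freqs = torch.randn(num_frames * height * width, dim_thw)

def get_rotary_emb_chunk(freqs):
    # (T*H*W, D) -> (T, H*W, D); chunking the middle axis takes a contiguous
    # block of height-major rows, i.e. a band of the height dimension.
    freqs = freqs.reshape(num_frames, -1, dim_thw)
    freqs = freqs.chunk(sp_world_size, dim=-2)[sp_rank]
    return freqs.reshape(-1, dim_thw)

chunk = get_rotary_emb_chunk(freqs)

# Reference: slice the height axis explicitly and flatten back.
grid = freqs.reshape(num_frames, height, width, dim_thw)
reference = grid.chunk(sp_world_size, dim=1)[sp_rank].reshape(-1, dim_thw)
assert torch.equal(chunk, reference)
print(chunk.shape)  # torch.Size([896, 64])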

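On the output side, the last hunk reshapes each rank's flat sequence back into its local patch grid before gathering, all-gathers along the height axis (dim=2) and the CFG batch axis (dim=0), and only then unpatchifies. The sketch below emulates the sequence-parallel gather with torch.cat over locally fabricated shards in place of get_sp_group().all_gather, omits the CFG gather, and uses assumed patch sizes and channel counts rather than values taken from the model.

import torch

batch_size = 2
post_patch_num_frames, post_patch_height, post_patch_width = 8, 16, 28
p_t, p = 1, 2          # assumed temporal / spatial patch sizes
out_channels = 16      # assumed latent channel count
sp_world_size = 4

# Per-rank transformer outputs: flat sequences over the local height shard.
local_seq_len = post_patch_num_frames * (post_patch_height // sp_world_size) * post_patch_width
shards = [torch.randn(batch_size, local_seq_len, out_channels * p_t * p * p)
          for _ in range(sp_world_size)]

gathered = []
for shard in shards:
    # Same reshape the commit applies before all_gather: expose the local
    # height axis at dim=2 so shards can be concatenated along it.
    shard = shard.reshape(batch_size, post_patch_num_frames,
                          post_patch_height // sp_world_size,
                          post_patch_width, -1, p_t, p, p)
    gathered.append(shard)
hidden_states = torch.cat(gathered, dim=2)  # stand-in for all_gather(dim=2)

# Unpatchify exactly as in the example: (B, C, T*p_t, H*p, W*p).
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
print(hidden_states.shape)  # torch.Size([2, 16, 8, 32, 56])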