# Review notes (patch preamble: everything before the first "diff --git" header is commentary).
#
# This is a unified diff over attention_dispatch.py and transformer_flux.py whose line breaks
# and indentation have been collapsed; restore them before trying to apply it.
#
# What the patch does: every funcol.all_to_all_single(t, None, None, group) call is replaced by
# torch.distributed.all_to_all_single(x_out, t, None, None, group, async_op=True), writing into
# a buffer allocated with torch.empty_like(t, device='npu').
#   - Both _all_to_all_single hunks call handler.wait() immediately after issuing the collective.
#   - ulysses_preforward now returns (x, handler); the three call sites in
#     _context_parallel_forward_la unpack the pair and call <name>_handler.wait() where they
#     previously called _wait_tensor(...).
#   - The else-branch of _context_parallel_forward_la now returns (hidden_states, handler), and
#     the forward() hunk unpacks `attn_output, attn_handler = self.attn(...)` and waits on it.
#
# NOTE(review): in the _context_parallel_forward_atb_fa hunk the result of
#   torch.distributed.all_to_all_single(..., async_op=True) is not bound to a name, nothing in
#   the hunk waits on it, and `return hidden_states` is left as an unchanged context line. The
#   forward() hunk unpacks two values from self.attn(...). How self.attn dispatches is not
#   visible here; confirm whether this path can feed forward(), and whether x_out is read
#   before the collective completes.
# NOTE(review): device='npu' is hard-coded in every new torch.empty_like call (one hunk spells
#   it "npu"). Presumably the input tensor is already on the target device, which would make the
#   override unnecessary and non-portable; confirm.
# NOTE(review): both _all_to_all_single hunks still run `x = _wait_tensor(x)` after
#   handler.wait(). _wait_tensor is defined outside this patch; presumably it is now redundant,
#   to be confirmed.
# NOTE(review): the previous funcol calls and _wait_tensor calls are kept as commented-out code,
#   along with a "# verify all_to_all" marker; this looks like work-in-progress to clean up
#   before merge.
diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index eaa3b7799a4e..8670caf9e545 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -980,7 +980,14 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: # buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3') # ValueError: Tensors must be contiguous x = x.flatten() - x = funcol.all_to_all_single(x, None, None, group) + + # verify all_to_all + # x = funcol.all_to_all_single(x, None, None, group) + x_out = torch.empty_like(x, device='npu') + handler = torch.distributed.all_to_all_single(x_out, x, None, None, group, async_op=True) + handler.wait() + x = x_out + + x = x.reshape(shape) x = _wait_tensor(x) return x diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index 0c68e9917371..c934067fea20 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -89,7 +89,12 @@ def _wait_tensor(tensor): def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor: shape = x.shape x = x.flatten() - x = funcol.all_to_all_single(x, None, None, group) + # x = funcol.all_to_all_single(x, None, None, group) + x_out = torch.empty_like(x, device="npu") + handler = torch.distributed.all_to_all_single(x_out, x, None, None, group, async_op=True) + handler.wait() + x = x_out + + x = x.reshape(shape) x = _wait_tensor(x) return x @@ -106,8 +111,12 @@ def ulysses_preforward( ): x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous() x = x.flatten() - x = funcol.all_to_all_single(x, None, None, group) - return x + # x = funcol.all_to_all_single(x, None, None, group) + x_out = torch.empty_like(x, device='npu') + handler = torch.distributed.all_to_all_single(x_out, x, None, None, group, async_op=True) + x = x_out + + return 
x, handler class FluxAttnProcessor: _attention_backend = None @@ -216,7 +225,7 @@ def _context_parallel_forward_la( B, S_KV_LOCAL, H, D = value.shape H_LOCAL = H // world_size - value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + value_all, value_handler = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) query = attn.to_q(hidden_states) query = query.unflatten(-1, (attn.heads, -1)) @@ -229,7 +238,7 @@ def _context_parallel_forward_la( if image_rotary_emb is not None: query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1) _, S_Q_LOCAL, _, _ = query.shape - query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) + query_all, query_handler = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL) key = attn.to_k(hidden_states) key = key.unflatten(-1, (attn.heads, -1)) @@ -241,18 +250,21 @@ def _context_parallel_forward_la( key = torch.cat([encoder_key, key], dim=1) if image_rotary_emb is not None: key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1) - key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) + key_all, key_handler = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL) B, S, N, D = B, world_size * S_KV_LOCAL, H_LOCAL, D - value_all = _wait_tensor(value_all) + # value_all = _wait_tensor(value_all) + value_handler.wait() value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() # value_all = value_all.view(B * S, N * D) - query_all = _wait_tensor(query_all) + # query_all = _wait_tensor(query_all) + query_handler.wait() query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() # query_all = query_all.view(B * S, N * D) - key_all = _wait_tensor(key_all) + # key_all = _wait_tensor(key_all) + key_handler.wait() key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, 
D).flatten(0, 1).permute(1, 0, 2, 3).contiguous() # key_all = key_all.view(B * S, N * D) @@ -292,9 +304,13 @@ def _context_parallel_forward_la( return hidden_states, encoder_hidden_states else: out = out.flatten() - out = funcol.all_to_all_single(out, None, None, group) + # out = funcol.all_to_all_single(out, None, None, group) + x_out = torch.empty_like(out, device='npu') + handler = torch.distributed.all_to_all_single(x_out, out, None, None, group, async_op=True) + out = x_out + hidden_states = out.reshape(world_size, H_LOCAL, B, S_Q_LOCAL, D).flatten(0, 1).permute(1, 2, 0, 3) - return hidden_states + return hidden_states, handler def _context_parallel_forward_atb_fa( self, @@ -385,7 +401,11 @@ def _context_parallel_forward_atb_fa( return hidden_states, encoder_hidden_states else: out = out.flatten() - out = funcol.all_to_all_single(out, None, None, group) + # out = funcol.all_to_all_single(out, None, None, group) + x_out = torch.empty_like(out, device='npu') + torch.distributed.all_to_all_single(x_out, out, None, None, group, async_op=True) + out = x_out + hidden_states = out.reshape(world_size, H_LOCAL, B, S_Q_LOCAL, D).flatten(0, 1).permute(1, 2, 0, 3) return hidden_states @@ -640,13 +660,14 @@ def forward( norm_hidden_states, gate = self.norm(hidden_states, emb=temb) # mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states)) joint_attention_kwargs = joint_attention_kwargs or {} - attn_output = self.attn( + attn_output, attn_handler = self.attn( hidden_states=norm_hidden_states, image_rotary_emb=image_rotary_emb, **joint_attention_kwargs, ) mlp_hidden_states = self.proj_mlp(norm_hidden_states) - attn_output = _wait_tensor(attn_output) + # attn_output = _wait_tensor(attn_output) + attn_handler.wait() attn_output = attn_output.contiguous() if attn_output.ndim == 4: attn_output = attn_output.flatten(2, 3)