Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/diffusers/models/attention_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,7 +980,14 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
# buf25 = torch.ops._c10d_functional.all_to_all_single.default(buf24, [1, 1], [1, 1], '3')
# ValueError: Tensors must be contiguous
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)

# verify all_to_all
# x = funcol.all_to_all_single(x, None, None, group)
x_out = torch.empty_like(x, device='npu')
handler = torch.distributed.all_to_all_single(x_out, x, None, None, group, async_op=True)
handler.wait()
x = x_out

x = x.reshape(shape)
x = _wait_tensor(x)
return x
Expand Down
49 changes: 35 additions & 14 deletions src/diffusers/models/transformers/transformer_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,12 @@ def _wait_tensor(tensor):
def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
    """Exchange equal-sized chunks of ``x`` between the ranks of ``group``.

    ``x`` is flattened, passed through a single all-to-all with default
    (``None``) split sizes, and the received data is reshaped back to the
    shape ``x`` had on entry.

    Args:
        x: Tensor to exchange.
        group: Process group handed to ``torch.distributed.all_to_all_single``.

    Returns:
        The received tensor, with the same shape as the input ``x``.
    """
    shape = x.shape
    x = x.flatten()

    # Exactly one exchange: the superseded functional-collective call is
    # dropped so the data is not sent through all-to-all a second time.
    # The receive buffer inherits dtype and device from ``x`` rather than
    # being pinned to a specific backend device string.
    received = torch.empty_like(x)

    # Blocking call: the previous code launched the op asynchronously and
    # waited on the handle on the very next line, so nothing overlapped
    # with the transfer anyway.
    torch.distributed.all_to_all_single(received, x, None, None, group)

    received = received.reshape(shape)
    # Kept so the return path stays as before; ``_wait_tensor`` is defined
    # elsewhere in this file.
    return _wait_tensor(received)
Expand All @@ -106,8 +111,12 @@ def ulysses_preforward(
):
x = x.reshape(B, S_LOCAL, world_size, H_LOCAL, D).permute(2, 1, 0, 3, 4).contiguous()
x = x.flatten()
x = funcol.all_to_all_single(x, None, None, group)
return x
# x = funcol.all_to_all_single(x, None, None, group)
x_out = torch.empty_like(x, device='npu')
handler = torch.distributed.all_to_all_single(x_out, x, None, None, group, async_op=True)
x = x_out

return x, handler

class FluxAttnProcessor:
_attention_backend = None
Expand Down Expand Up @@ -216,7 +225,7 @@ def _context_parallel_forward_la(

B, S_KV_LOCAL, H, D = value.shape
H_LOCAL = H // world_size
value_all = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)
value_all, value_handler = ulysses_preforward(value, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)

query = attn.to_q(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
Expand All @@ -229,7 +238,7 @@ def _context_parallel_forward_la(
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
_, S_Q_LOCAL, _, _ = query.shape
query_all = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL)
query_all, query_handler = ulysses_preforward(query, group, world_size, B, S_Q_LOCAL, H, D, H_LOCAL)

key = attn.to_k(hidden_states)
key = key.unflatten(-1, (attn.heads, -1))
Expand All @@ -241,18 +250,21 @@ def _context_parallel_forward_la(
key = torch.cat([encoder_key, key], dim=1)
if image_rotary_emb is not None:
key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
key_all = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)
key_all, key_handler = ulysses_preforward(key, group, world_size, B, S_KV_LOCAL, H, D, H_LOCAL)

B, S, N, D = B, world_size * S_KV_LOCAL, H_LOCAL, D
value_all = _wait_tensor(value_all)
# value_all = _wait_tensor(value_all)
value_handler.wait()
value_all = value_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
# value_all = value_all.view(B * S, N * D)

query_all = _wait_tensor(query_all)
# query_all = _wait_tensor(query_all)
query_handler.wait()
query_all = query_all.reshape(world_size, S_Q_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
# query_all = query_all.view(B * S, N * D)

key_all = _wait_tensor(key_all)
# key_all = _wait_tensor(key_all)
key_handler.wait()
key_all = key_all.reshape(world_size, S_KV_LOCAL, B, H_LOCAL, D).flatten(0, 1).permute(1, 0, 2, 3).contiguous()
# key_all = key_all.view(B * S, N * D)

Expand Down Expand Up @@ -292,9 +304,13 @@ def _context_parallel_forward_la(
return hidden_states, encoder_hidden_states
else:
out = out.flatten()
out = funcol.all_to_all_single(out, None, None, group)
# out = funcol.all_to_all_single(out, None, None, group)
x_out = torch.empty_like(out, device='npu')
handler = torch.distributed.all_to_all_single(x_out, out, None, None, group, async_op=True)
out = x_out

hidden_states = out.reshape(world_size, H_LOCAL, B, S_Q_LOCAL, D).flatten(0, 1).permute(1, 2, 0, 3)
return hidden_states
return hidden_states, handler

def _context_parallel_forward_atb_fa(
self,
Expand Down Expand Up @@ -385,7 +401,11 @@ def _context_parallel_forward_atb_fa(
return hidden_states, encoder_hidden_states
else:
out = out.flatten()
out = funcol.all_to_all_single(out, None, None, group)
# out = funcol.all_to_all_single(out, None, None, group)
x_out = torch.empty_like(out, device='npu')
torch.distributed.all_to_all_single(x_out, out, None, None, group, async_op=True)
out = x_out

hidden_states = out.reshape(world_size, H_LOCAL, B, S_Q_LOCAL, D).flatten(0, 1).permute(1, 2, 0, 3)
return hidden_states

Expand Down Expand Up @@ -640,13 +660,14 @@ def forward(
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
# mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
attn_output, attn_handler = self.attn(
hidden_states=norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
mlp_hidden_states = self.proj_mlp(norm_hidden_states)
attn_output = _wait_tensor(attn_output)
# attn_output = _wait_tensor(attn_output)
attn_handler.wait()
attn_output = attn_output.contiguous()
if attn_output.ndim == 4:
attn_output = attn_output.flatten(2, 3)
Expand Down