Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions wan/distributed/xdit_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,14 @@ def pad_freqs(original_tensor, target_len):
return padded_tensor


@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs_list):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3)
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
x_i = x[i, :s].reshape(1, s, n, c)
cos, sin = freqs_list[i]
x_i = rotary_position_embedding(x_i, cos, sin, rotated_mode="rotated_interleaved", fused=True)
output.append(x_i)
return torch.cat(output).float()
cos, sin = freqs_list[0]
return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True)


def usp_dit_forward(
Expand Down
2 changes: 1 addition & 1 deletion wan/modules/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def attention(
opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD")
else:
out = attention_forward(q, k, v,
opt_mode="manual", op_type="fused_attn_score", layout="BNSD")
opt_mode="manual", op_type="fused_attn_score", layout="BSND")
return out.to(qtype)
elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
return flash_attention(
Expand Down
11 changes: 2 additions & 9 deletions wan/modules/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,14 @@ def rope_params(max_seq_len, dim, theta=10000):
return freqs


@amp.autocast(enabled=False)
def rope_apply(x, grid_sizes, freqs_list):
"""
x: [B, L, N, C].
grid_sizes: [B, 3].
freqs: [M, C // 2].
"""
s, n, c = x.size(1), x.size(2), x.size(3)
output = []
for i, (f, h, w) in enumerate(grid_sizes.tolist()):
x_i = x[i, :s].reshape(1, s, n, c)
cos, sin = freqs_list[i]
x_i = rotary_position_embedding(x_i, cos, sin, rotated_mode="rotated_interleaved", fused=True)
output.append(x_i)
return torch.cat(output).float()
cos, sin = freqs_list[0]
return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True)


class WanRMSNorm(nn.Module):
Expand Down