diff --git a/wan/distributed/xdit_context_parallel.py b/wan/distributed/xdit_context_parallel.py index 1c56b2b..7a4aedc 100644 --- a/wan/distributed/xdit_context_parallel.py +++ b/wan/distributed/xdit_context_parallel.py @@ -27,21 +27,14 @@ def pad_freqs(original_tensor, target_len): return padded_tensor -@amp.autocast(enabled=False) def rope_apply(x, grid_sizes, freqs_list): """ x: [B, L, N, C]. grid_sizes: [B, 3]. freqs: [M, C // 2]. """ - s, n, c = x.size(1), x.size(2), x.size(3) - output = [] - for i, (f, h, w) in enumerate(grid_sizes.tolist()): - x_i = x[i, :s].reshape(1, s, n, c) - cos, sin = freqs_list[i] - x_i = rotary_position_embedding(x_i, cos, sin, rotated_mode="rotated_interleaved", fused=True) - output.append(x_i) - return torch.cat(output).float() + cos, sin = freqs_list[0] + return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True) def usp_dit_forward( diff --git a/wan/modules/attention.py b/wan/modules/attention.py index 9c7cbfd..1e1a659 100644 --- a/wan/modules/attention.py +++ b/wan/modules/attention.py @@ -161,7 +161,7 @@ def attention( opt_mode="manual", op_type="ascend_laser_attention", layout="BNSD") else: out = attention_forward(q, k, v, - opt_mode="manual", op_type="fused_attn_score", layout="BNSD") + opt_mode="manual", op_type="fused_attn_score", layout="BSND") return out.to(qtype) elif FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE: return flash_attention( diff --git a/wan/modules/model.py b/wan/modules/model.py index 93f6607..6e270b8 100644 --- a/wan/modules/model.py +++ b/wan/modules/model.py @@ -39,21 +39,14 @@ def rope_params(max_seq_len, dim, theta=10000): return freqs -@amp.autocast(enabled=False) def rope_apply(x, grid_sizes, freqs_list): """ x: [B, L, N, C]. grid_sizes: [B, 3]. freqs: [M, C // 2]. """ - s, n, c = x.size(1), x.size(2), x.size(3) - output = [] - for i, (f, h, w) in enumerate(grid_sizes.tolist()): - x_i = x[i, :s].reshape(1, s, n, c) - cos, sin = freqs_list[i] - x_i = rotary_position_embedding(x_i, cos, sin, rotated_mode="rotated_interleaved", fused=True) - output.append(x_i) - return torch.cat(output).float() + cos, sin = freqs_list[0] + return rotary_position_embedding(x, cos, sin, rotated_mode="rotated_interleaved", fused=True) class WanRMSNorm(nn.Module):