
Commit 4e8f991

qwen3 vllm compat
1 parent 91bdcc0 commit 4e8f991

File tree

2 files changed: +426, -1 lines

torchtitan/models/attention.py

Lines changed: 64 additions & 1 deletion
@@ -20,7 +20,7 @@
     create_block_mask,
     flex_attention,
 )
-
+from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 __all__ = [
     "FlexAttentionWrapper",
@@ -103,6 +103,69 @@ def forward(
         with sdpa_kernel(self.sdpa_backends, set_priority=True):
             return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True)
 
+class VLLMCompatibleFlashAttention(torch.nn.Module):
+    """Wrapper around FlashAttention as used by vLLM."""
+    def __init__(self) -> None:
+        super().__init__()
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+        self.vllm_is_batch_invariant = vllm_is_batch_invariant
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        *,
+        scale: float | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]:
+        # Flash Attention varlen expects (batch, seq_len, num_heads, head_dim).
+        # The input from TorchTitan is always (batch, num_heads, seq_len, head_dim),
+        # so we need to transpose to (batch, seq_len, num_heads, head_dim).
+
+        # Input is (batch, num_heads, seq_len, head_dim) - need to transpose
+        q = q.transpose(1, 2)  # -> (batch, seq_len, num_heads, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # Get dimensions
+        batch_size, seq_len, num_heads, head_dim = q.shape
+
+        # Convert to varlen format: flatten batch and sequence dimensions
+        # (batch, seq_len, num_heads, head_dim) -> (total_tokens, num_heads, head_dim)
+        q_varlen = q.reshape(-1, num_heads, head_dim)
+        k_varlen = k.reshape(-1, k.shape[2], head_dim)
+        v_varlen = v.reshape(-1, v.shape[2], head_dim)
+
+        # Create cumulative sequence lengths
+        # cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len]
+        cu_seqlens = torch.arange(
+            0, (batch_size + 1) * seq_len, seq_len,
+            dtype=torch.int32, device=q.device
+        )
+
+        # Call Flash Attention varlen (works with both standard flash-attn and vLLM's wrapper)
+        output_varlen = self.flash_attn_varlen_func(
+            q_varlen, k_varlen, v_varlen,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=seq_len,
+            max_seqlen_k=seq_len,
+            softmax_scale=scale,
+            causal=True,
+            num_splits=1 if self.vllm_is_batch_invariant() else 0,
+        )
+
+        # Convert back to batch format
+        # (total_tokens, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim)
+        output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim)
+
+        # Transpose back to (batch, num_heads, seq_len, head_dim) to match input format
+        output = output.transpose(1, 2)
+
+        return output
+
+
 
 # We cannot do inner function/closure because we won't be able to cache it --
 # if we use an inner function, a new closure will be created every time
