Commit 953dbe1

qwen3 vllm compat

1 parent 91bdcc0 commit 953dbe1

File tree

2 files changed (+426 -0 lines)

torchtitan/models/attention.py

64 additions & 0 deletions
@@ -20,6 +20,7 @@
     create_block_mask,
     flex_attention,
 )
+from vllm.vllm_flash_attn import flash_attn_varlen_func


 __all__ = [
@@ -103,6 +104,69 @@ def forward(
         with sdpa_kernel(self.sdpa_backends, set_priority=True):
             return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True)

+class VLLMCompatibleFlashAttention(torch.nn.Module):
+    """Wrapper around FlashAttention as used by vLLM."""
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+        self.vllm_is_batch_invariant = vllm_is_batch_invariant
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        *,
+        scale: float | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]:
+        # flash_attn_varlen_func expects packed (total_tokens, num_heads, head_dim).
+        # TorchTitan passes (batch, num_heads, seq_len, head_dim), so transpose to
+        # (batch, seq_len, num_heads, head_dim) first, then flatten batch and sequence.
+        q = q.transpose(1, 2)  # -> (batch, seq_len, num_heads, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        batch_size, seq_len, num_heads, head_dim = q.shape
+
+        # Convert to varlen format: flatten batch and sequence dimensions.
+        # (batch, seq_len, num_heads, head_dim) -> (total_tokens, num_heads, head_dim)
+        q_varlen = q.reshape(-1, num_heads, head_dim)
+        k_varlen = k.reshape(-1, k.shape[2], head_dim)
+        v_varlen = v.reshape(-1, v.shape[2], head_dim)
+
+        # Cumulative sequence lengths for the packed layout:
+        # cu_seqlens = [0, seq_len, 2*seq_len, ..., batch_size*seq_len]
+        cu_seqlens = torch.arange(
+            0, (batch_size + 1) * seq_len, seq_len,
+            dtype=torch.int32, device=q.device,
+        )
+
+        # Call Flash Attention varlen (works with both standard flash-attn and vLLM's wrapper).
+        output_varlen = self.flash_attn_varlen_func(
+            q_varlen, k_varlen, v_varlen,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=seq_len,
+            max_seqlen_k=seq_len,
+            softmax_scale=scale,
+            causal=True,
+            num_splits=1 if self.vllm_is_batch_invariant() else 0,
+        )
+
+        # Convert back to batch format:
+        # (total_tokens, num_heads, head_dim) -> (batch, seq_len, num_heads, head_dim)
+        output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim)
+
+        # Transpose back to (batch, num_heads, seq_len, head_dim) to match the input layout.
+        output = output.transpose(1, 2)
+
+        return output
+
+

 # We cannot do inner function/closure because we won't be able to cache it --
 # if we use an inner function, a new closure will be created every time
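
For illustration only (not part of the commit): the pack/unpack that the new module performs can be checked with plain PyTorch, without calling flash_attn_varlen_func. The shapes below are made up; the sketch only shows that flattening equal-length sequences and building cu_seqlens with torch.arange round-trips cleanly.

import torch

batch_size, num_heads, seq_len, head_dim = 2, 4, 8, 16  # arbitrary example shapes
q = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Pack: (batch, num_heads, seq_len, head_dim) -> (total_tokens, num_heads, head_dim)
q_packed = q.transpose(1, 2).reshape(-1, num_heads, head_dim)
assert q_packed.shape == (batch_size * seq_len, num_heads, head_dim)

# Cumulative sequence lengths marking where each sequence starts in the packed layout.
cu_seqlens = torch.arange(0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32)
# tensor([ 0,  8, 16], dtype=torch.int32) for this example

# Unpack: the inverse of the pack, restoring the original layout exactly.
q_restored = q_packed.reshape(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)
assert torch.equal(q_restored, q)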
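
Also for illustration: a quick parity check of the new module against SDPA might look like the sketch below. This is untested and assumes vLLM (with its flash-attn wheel) is installed and a CUDA device is available; FlashAttention needs fp16/bf16 inputs, and the loose tolerances account for kernel differences.

import torch
import torch.nn.functional as F
from torchtitan.models.attention import VLLMCompatibleFlashAttention

# Arbitrary example shapes: (batch, num_heads, seq_len, head_dim)
q, k, v = (torch.randn(2, 4, 8, 64, device="cuda", dtype=torch.bfloat16) for _ in range(3))

attn = VLLMCompatibleFlashAttention()
out_flash = attn(q, k, v)                                           # causal flash-attn path
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # reference

torch.testing.assert_close(out_flash, out_sdpa, atol=2e-2, rtol=2e-2)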
