     create_block_mask,
     flex_attention,
 )
-
+from vllm.vllm_flash_attn import flash_attn_varlen_func
 
 __all__ = [
     "FlexAttentionWrapper",
@@ -103,6 +103,69 @@ def forward(
         with sdpa_kernel(self.sdpa_backends, set_priority=True):
             return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True)
 
+class VLLMCompatibleFlashAttention(torch.nn.Module):
+    """Wrapper around FlashAttention as used by vLLM"""
+    def __init__(self) -> None:
+        super().__init__()
+        self.flash_attn_varlen_func = flash_attn_varlen_func
+        from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
+        self.vllm_is_batch_invariant = vllm_is_batch_invariant
+
+    def forward(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        *,
+        scale: float | None = None,
+    ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]:
+        # TorchTitan always provides (batch, num_heads, seq_len, head_dim), while
+        # flash_attn_varlen_func expects a packed (total_tokens, num_heads, head_dim) layout.
+        # Move the head dim after the sequence dim here, then flatten batch/sequence below.
+        q = q.transpose(1, 2)  # -> (batch, seq_len, num_heads, head_dim)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
+
+        # Get dimensions
+        batch_size, seq_len, num_heads, head_dim = q.shape
+
+        # Convert to varlen format: flatten batch and sequence dimensions
+        # (batch, seqlen, nheads, headdim) -> (total_tokens, nheads, headdim)
+        q_varlen = q.reshape(-1, num_heads, head_dim)
+        k_varlen = k.reshape(-1, k.shape[2], head_dim)
+        v_varlen = v.reshape(-1, v.shape[2], head_dim)
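+        # k/v keep their own head count (k.shape[2], v.shape[2]), which may be smaller
+        # than num_heads, e.g. when grouped-query attention is used.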
+
+        # Create cumulative sequence lengths
+        # cu_seqlens: [0, seq_len, 2*seq_len, ..., batch_size*seq_len]
+        cu_seqlens = torch.arange(
+            0, (batch_size + 1) * seq_len, seq_len,
+            dtype=torch.int32, device=q.device
+        )
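+        # For example, batch_size=2 with seq_len=4 gives cu_seqlens = [0, 4, 8].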
+
+        # Call Flash Attention varlen (works with both standard flash-attn and vLLM's wrapper)
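+        # num_splits=1 forces a single split-KV pass so the accumulation order stays fixed,
+        # which is what vLLM's batch-invariant mode relies on; 0 lets the kernel choose.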
+        output_varlen = self.flash_attn_varlen_func(
+            q_varlen, k_varlen, v_varlen,
+            cu_seqlens_q=cu_seqlens,
+            cu_seqlens_k=cu_seqlens,
+            max_seqlen_q=seq_len,
+            max_seqlen_k=seq_len,
+            softmax_scale=scale,
+            causal=True,
+            num_splits=1 if self.vllm_is_batch_invariant() else 0,
+        )
+
+        # Convert back to batch format
+        # (total_tokens, nheads, headdim) -> (batch, seqlen, nheads, headdim)
+        output = output_varlen.reshape(batch_size, seq_len, num_heads, head_dim)
+
+        # Transpose back to (batch, num_heads, seq_len, head_dim) to match input format
+        output = output.transpose(1, 2)
+
+        return output
+
+
 
 # We cannot do inner function/closure because we won't be able to cache it --
 # if we an inner function, a new closure will be created every time
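
For context, here is a minimal sketch of how the new wrapper might be exercised. The shapes, dtype, and grouped-query head counts below are illustrative assumptions rather than values from the commit, the class is assumed to be importable from the module this diff touches, and vLLM's flash_attn_varlen_func needs a CUDA device.

```python
import torch

# Hypothetical smoke test for VLLMCompatibleFlashAttention (values chosen for illustration).
attn = VLLMCompatibleFlashAttention()

batch, num_heads, num_kv_heads, seq_len, head_dim = 2, 8, 4, 128, 64
q = torch.randn(batch, num_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn(batch, num_kv_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)
v = torch.randn(batch, num_kv_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16)

out = attn(q, k, v, scale=head_dim ** -0.5)
# The output comes back in the same (batch, num_heads, seq_len, head_dim) layout as q.
assert out.shape == q.shape
```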