-
Notifications
You must be signed in to change notification settings - Fork 31k
SDPA and FlashAttention-2 support for LayoutLMv3 #41801
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
SDPA and FlashAttention-2 support for LayoutLMv3 #41801
Conversation
- Removed duplicate imports - Fixed attention component access patterns - Added missing _upad_input method for Flash Attention 2 - Corrected relative position bias handling - Added proper fallbacks for unsupported features
- Remove unsupported head_mask parameter from super().forward() calls - Fix attention mask shape handling for SDPA (convert 4D mask to boolean format) - Maintain proper mask application in relative position bias fallback path
…ad_mask from super calls, and fixed attention mask format for SDPA.
- Fix attention mask conversion for SDPA (use >= 0 instead of > 0) - Fix _upad_input to use actual unpadded tensors from unpad_input
- Optimize _upad_input to call unpad_input only once for self-attention - Remove unused head_mask parameter from forward methods
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This approach is sadly outdated, I made a comment on what to do instead. I suspect that the tests also passed because the model didn't update their _supports_xxx flag. Without them these tests for FA, SDPA won't be run.
| return outputs | ||
|
|
||
|
|
||
| class LayoutLMv3SdpaAttention(LayoutLMv3Attention): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry but this approach is outdated and we changed it to a unified class instead. An older Bert version should give a good idea over here
transformers/src/transformers/models/bert/modeling_bert.py
Lines 121 to 258 in 9db58ab
| def eager_attention_forward( | |
| module: nn.Module, | |
| query: torch.Tensor, | |
| key: torch.Tensor, | |
| value: torch.Tensor, | |
| attention_mask: Optional[torch.Tensor], | |
| scaling: Optional[float] = None, | |
| dropout: float = 0.0, | |
| use_cache: Optional[bool] = None, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ): | |
| if scaling is None: | |
| scaling = query.size(-1) ** -0.5 | |
| # Take the dot product between "query" and "key" to get the raw attention scores. | |
| attn_weights = torch.matmul(query, key.transpose(2, 3)) | |
| # Relative positional embeddings | |
| if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query": | |
| query_length, key_length = query.shape[2], key.shape[2] | |
| if use_cache: | |
| position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1) | |
| else: | |
| position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) | |
| position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1) | |
| distance = position_ids_l - position_ids_r | |
| positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1) | |
| positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility | |
| if module.position_embedding_type == "relative_key": | |
| relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) | |
| attn_weights = attn_weights + relative_position_scores | |
| elif module.position_embedding_type == "relative_key_query": | |
| relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding) | |
| relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding) | |
| attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key | |
| # Scaling is shifted in case of embeddings being relative | |
| attn_weights = attn_weights * scaling | |
| if attention_mask is not None and attention_mask.ndim == 4: | |
| attention_mask = attention_mask[:, :, :, : key.shape[-2]] | |
| attn_weights = attn_weights + attention_mask | |
| attn_weights = nn.functional.softmax(attn_weights, dim=-1) | |
| attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) | |
| attn_output = torch.matmul(attn_weights, value) | |
| attn_output = attn_output.transpose(1, 2).contiguous() | |
| return attn_output, attn_weights | |
| class BertSelfAttention(nn.Module): | |
| def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None): | |
| super().__init__() | |
| if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): | |
| raise ValueError( | |
| f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " | |
| f"heads ({config.num_attention_heads})" | |
| ) | |
| self.config = config | |
| self.num_attention_heads = config.num_attention_heads | |
| self.attention_head_size = int(config.hidden_size / config.num_attention_heads) | |
| self.all_head_size = self.num_attention_heads * self.attention_head_size | |
| self.scaling = self.attention_head_size**-0.5 | |
| self.query = nn.Linear(config.hidden_size, self.all_head_size) | |
| self.key = nn.Linear(config.hidden_size, self.all_head_size) | |
| self.value = nn.Linear(config.hidden_size, self.all_head_size) | |
| self.dropout = nn.Dropout(config.attention_probs_dropout_prob) | |
| self.position_embedding_type = position_embedding_type or getattr( | |
| config, "position_embedding_type", "absolute" | |
| ) | |
| if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": | |
| self.max_position_embeddings = config.max_position_embeddings | |
| self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) | |
| self.is_decoder = config.is_decoder | |
| self.is_causal = is_causal | |
| self.layer_idx = layer_idx | |
| def forward( | |
| self, | |
| hidden_states: torch.Tensor, | |
| attention_mask: Optional[torch.FloatTensor] = None, | |
| past_key_value: Optional[Cache] = None, | |
| cache_position: Optional[torch.Tensor] = None, | |
| **kwargs: Unpack[TransformersKwargs], | |
| ) -> tuple[torch.Tensor]: | |
| input_shape = hidden_states.shape[:-1] | |
| hidden_shape = (*input_shape, -1, self.attention_head_size) | |
| # get all proj | |
| query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2) | |
| key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2) | |
| value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2) | |
| if past_key_value is not None: | |
| # decoder-only bert can have a simple dynamic cache for example | |
| current_past_key_value = past_key_value | |
| if isinstance(past_key_value, EncoderDecoderCache): | |
| current_past_key_value = past_key_value.self_attention_cache | |
| # save all key/value_layer to cache to be re-used for fast auto-regressive generation | |
| key_layer, value_layer = current_past_key_value.update( | |
| key_layer, | |
| value_layer, | |
| self.layer_idx, | |
| {"cache_position": cache_position}, | |
| ) | |
| attention_interface: Callable = eager_attention_forward | |
| if self.config._attn_implementation != "eager": | |
| if self.position_embedding_type != "absolute": | |
| raise ValueError( | |
| f"You are using {self.config._attn_implementation} as attention type. However, non-absolute " | |
| 'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.' | |
| ) | |
| attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] | |
| attn_output, attn_weights = attention_interface( | |
| self, | |
| query_layer, | |
| key_layer, | |
| value_layer, | |
| attention_mask, | |
| dropout=0.0 if not self.training else self.dropout.p, | |
| scaling=self.scaling, | |
| # only for relevant for non-absolute positional embeddings | |
| use_cache=past_key_value is not None, | |
| **kwargs, | |
| ) | |
| attn_output = attn_output.reshape(*input_shape, -1).contiguous() | |
| return attn_output, attn_weights |
the relative positions are no longer in main (deprecated) but since they are used here, it's a better pointer
We will also need to update the mask creations to something along
transformers/src/transformers/models/bert/modeling_bert.py
Lines 749 to 753 in 91b5a68
| attention_mask = create_bidirectional_mask( | |
| config=self.config, | |
| input_embeds=embedding_output, | |
| attention_mask=attention_mask, | |
| ) |
And you will need _supports_xxx flags in order to indicate that the model really supports these attention flavors, e.g.
transformers/src/transformers/models/bert/modeling_bert.py
Lines 562 to 564 in 91b5a68
| _supports_flash_attn = True | |
| _supports_sdpa = True | |
| _supports_flex_attn = True |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the feedback @vasqu! You're right - I've refactored to the unified class approach following the BERT pattern. I added the _supports_sdpa and _supports_flash_attn_2 flags to the config and updated the mask creation to use create_bidirectional_mask.
Just pushed the changes - let me know if there's anything else that needs adjusting. Thanks again!
| batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size | ||
| ).transpose(1, 2) | ||
|
|
||
| if self.self.has_relative_attention_bias and rel_pos is not None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a fallback to eager attention, no?
- Replace separate SDPA and FlashAttention classes with unified LayoutLMv3Attention - Add _supports_sdpa and _supports_flash_attn_2 config flags - Add _attn_implementation parameter to config - Implement runtime dispatch based on attention implementation - Add proper fallback logic for unsupported features - Fix linting and formatting issues - All tests passing
- Add proper None checks for relative position bias - Fix attention mask shape handling for SDPA - Add fallback imports for flash_attn functions - Ensure proper mask broadcasting for scaled_dot_product_attention - All attention implementations now work correctly
|
[For maintainers] Suggested jobs to run (before merge) run-slow: layoutlmv3 |
What does this PR do?
Adds SDPA and FlashAttention-2 support to LayoutLMv3 following the same pattern as other models. Fully backward compatible.
SDPA converts masks to boolean format. FA2 uses
_upad_inputfor variable-length sequences and avoids redundant unpads. Both fall back gracefully when needed. FA2 is O(N) memory vs O(N²).Fixes #35467
Changes
LayoutLMv3SdpaAttentionusingtorch.nn.functional.scaled_dot_product_attentionLayoutLMv3FlashAttention2withflash_attn_func/flash_attn_varlen_funcLayoutLMv3Attentionoutput_attentions=True/ relative position bias is usedTesting
test_modeling_layoutlmv3.pyBefore submitting
Who can review?
@vasqu @ArthurZucker @Cyrilvallez