
Conversation


@jackiehimel commented Oct 22, 2025

What does this PR do?

Adds SDPA and FlashAttention-2 support to LayoutLMv3 following the same pattern as other models. Fully backward compatible.

The SDPA path converts attention masks to the boolean format expected by torch.nn.functional.scaled_dot_product_attention. The FlashAttention-2 path uses _upad_input to handle variable-length sequences and avoids redundant unpad calls. Both paths fall back gracefully to standard attention when needed, and FlashAttention-2 reduces attention memory from O(N²) to O(N).
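For illustration, a minimal sketch of the mask handling described above (sdpa_with_bool_mask is a hypothetical helper, not code from this PR; it follows the >= 0 conversion mentioned in the commits):

import torch
import torch.nn.functional as F

def sdpa_with_bool_mask(query, key, value, attention_mask=None, dropout_p=0.0):
    # attention_mask is assumed to be the usual additive mask (0 where kept,
    # a large negative value where masked), broadcastable to
    # (batch, heads, q_len, kv_len). SDPA also accepts a boolean mask where
    # True means "attend", so `>= 0` selects the kept positions.
    bool_mask = attention_mask >= 0 if attention_mask is not None else None
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=bool_mask, dropout_p=dropout_p
    )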

Fixes #35467

Changes

  • Added LayoutLMv3SdpaAttention using torch.nn.functional.scaled_dot_product_attention
  • Added LayoutLMv3FlashAttention2 with flash_attn_func / flash_attn_varlen_func
  • Both inherit from LayoutLMv3Attention
  • Fall back to standard (eager) attention when the backends are unavailable, when output_attentions=True, or when relative position bias is used (see the sketch after this list)
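A rough illustration of the fallback conditions above (function and argument names are hypothetical, not the PR's internals):

def select_attention_impl(config, output_attentions, uses_relative_position_bias):
    # SDPA and FlashAttention-2 cannot return attention probabilities and
    # cannot add LayoutLMv3's relative position bias inside the fused kernel,
    # so those cases fall back to the standard eager implementation.
    if output_attentions or uses_relative_position_bias:
        return "eager"
    impl = getattr(config, "_attn_implementation", "eager")
    return impl if impl in ("sdpa", "flash_attention_2") else "eager"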

Testing

  • 121 tests passed in test_modeling_layoutlmv3.py
  • Manually verified forward passes with/without attention masks

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@vasqu @ArthurZucker @Cyrilvallez

jackiehimel added 9 commits October 22, 2025 16:39
- Removed duplicate imports
- Fixed attention component access patterns
- Added missing _upad_input method for Flash Attention 2
- Corrected relative position bias handling
- Added proper fallbacks for unsupported features
- Remove unsupported head_mask parameter from super().forward() calls
- Fix attention mask shape handling for SDPA (convert 4D mask to boolean format)
- Maintain proper mask application in relative position bias fallback path
…ad_mask from super calls, and fixed attention mask format for SDPA.
- Fix attention mask conversion for SDPA (use >= 0 instead of > 0)
- Fix _upad_input to use actual unpadded tensors from unpad_input
- Optimize _upad_input to call unpad_input only once for self-attention
- Remove unused head_mask parameter from forward methods
@vasqu (Contributor) left a comment


This approach is sadly outdated; I made a comment on what to do instead. I suspect the tests also passed because the model didn't update its _supports_xxx flags. Without them, the FA and SDPA tests won't be run.

return outputs


class LayoutLMv3SdpaAttention(LayoutLMv3Attention):

Sorry, but this approach is outdated and we changed it to a unified class instead. An older Bert version should give a good idea here:

def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: Optional[float] = None,
    dropout: float = 0.0,
    use_cache: Optional[bool] = None,
    **kwargs: Unpack[TransformersKwargs],
):
    if scaling is None:
        scaling = query.size(-1) ** -0.5

    # Take the dot product between "query" and "key" to get the raw attention scores.
    attn_weights = torch.matmul(query, key.transpose(2, 3))

    # Relative positional embeddings
    if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
        query_length, key_length = query.shape[2], key.shape[2]
        if use_cache:
            position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
        else:
            position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
        position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
        distance = position_ids_l - position_ids_r

        positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility

        if module.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
            attn_weights = attn_weights + relative_position_scores
        elif module.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
            attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key

    # Scaling is shifted in case of embeddings being relative
    attn_weights = attn_weights * scaling

    if attention_mask is not None and attention_mask.ndim == 4:
        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

    attn_output = torch.matmul(attn_weights, value)
    attn_output = attn_output.transpose(1, 2).contiguous()
    return attn_output, attn_weights

class BertSelfAttention(nn.Module):
    def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder
        self.is_causal = is_causal
        self.layer_idx = layer_idx

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.Tensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        # get all proj
        query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
        key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

        if past_key_value is not None:
            # decoder-only bert can have a simple dynamic cache for example
            current_past_key_value = past_key_value
            if isinstance(past_key_value, EncoderDecoderCache):
                current_past_key_value = past_key_value.self_attention_cache

            # save all key/value_layer to cache to be re-used for fast auto-regressive generation
            key_layer, value_layer = current_past_key_value.update(
                key_layer,
                value_layer,
                self.layer_idx,
                {"cache_position": cache_position},
            )

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.position_embedding_type != "absolute":
                raise ValueError(
                    f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
                    'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
                )
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout.p,
            scaling=self.scaling,
            # only for relevant for non-absolute positional embeddings
            use_cache=past_key_value is not None,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return attn_output, attn_weights

The relative positions are no longer on main (they are deprecated), but since they are used here, this older version is the better pointer.

We will also need to update the mask creation to something along the lines of:

attention_mask = create_bidirectional_mask(
    config=self.config,
    input_embeds=embedding_output,
    attention_mask=attention_mask,
)

And you will need _supports_xxx flags in order to indicate that the model really supports these attention flavors, e.g.

_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
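
For illustration, these flags would presumably sit on the LayoutLMv3 pretrained-model base class; a hedged sketch of the placement, not the actual diff:

from transformers import LayoutLMv3Config, PreTrainedModel


class LayoutLMv3PreTrainedModel(PreTrainedModel):
    config_class = LayoutLMv3Config
    base_model_prefix = "layoutlmv3"

    # Advertise the supported attention backends so that
    # from_pretrained(..., attn_implementation=...) and the common
    # FA/SDPA tests actually exercise them.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True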

@jackiehimel (Author) replied

Thanks for the feedback @vasqu! You're right - I've refactored to the unified class approach following the BERT pattern. I added the _supports_sdpa and _supports_flash_attn_2 flags to the config and updated the mask creation to use create_bidirectional_mask.

Just pushed the changes - let me know if there's anything else that needs adjusting. Thanks again!
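
As a usage sketch once the refactor lands (this is the standard transformers loading API, not code specific to this PR):

import torch
from transformers import LayoutLMv3Model

# Pick the attention backend at load time; FlashAttention-2 additionally
# requires a half-precision dtype and a compatible GPU.
model = LayoutLMv3Model.from_pretrained(
    "microsoft/layoutlmv3-base",
    attn_implementation="sdpa",  # or "flash_attention_2"
    torch_dtype=torch.bfloat16,
)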

batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size
).transpose(1, 2)

if self.self.has_relative_attention_bias and rel_pos is not None:

This is a fallback to eager attention, no?

jackiehimel and others added 7 commits October 24, 2025 20:20
- Replace separate SDPA and FlashAttention classes with unified LayoutLMv3Attention
- Add _supports_sdpa and _supports_flash_attn_2 config flags
- Add _attn_implementation parameter to config
- Implement runtime dispatch based on attention implementation
- Add proper fallback logic for unsupported features
- Fix linting and formatting issues
- All tests passing
- Add proper None checks for relative position bias
- Fix attention mask shape handling for SDPA
- Add fallback imports for flash_attn functions
- Ensure proper mask broadcasting for scaled_dot_product_attention
- All attention implementations now work correctly
@github-actions

[For maintainers] Suggested jobs to run (before merge)

run-slow: layoutlmv3
