diff --git a/my_changes.txt b/my_changes.txt
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
index d67d4a446422..0eae2af49829 100644
--- a/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
+++ b/src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
@@ -106,6 +106,10 @@ class LayoutLMv3Config(PreTrainedConfig):
 
     model_type = "layoutlmv3"
 
+    # Support flags for attention implementations
+    _supports_sdpa = True
+    _supports_flash_attn_2 = True
+
     def __init__(
         self,
         vocab_size=50265,
@@ -138,6 +142,7 @@ def __init__(
         num_channels=3,
         patch_size=16,
         classifier_dropout=None,
+        _attn_implementation="eager",
         **kwargs,
     ):
         super().__init__(
@@ -173,6 +178,7 @@ def __init__(
         self.num_channels = num_channels
         self.patch_size = patch_size
         self.classifier_dropout = classifier_dropout
+        self._attn_implementation = _attn_implementation
 
 
 __all__ = ["LayoutLMv3Config"]
diff --git a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
index 3aa97051f855..6fc88a7de0ce 100644
--- a/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
+++ b/src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -24,6 +24,8 @@
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
+
+# SDPA and Flash Attention support
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
     BaseModelOutput,
@@ -35,12 +37,25 @@
 from ...pytorch_utils import apply_chunking_to_forward
 from ...utils import (
     auto_docstring,
+    is_flash_attn_2_available,
     logging,
     torch_int,
 )
 from .configuration_layoutlmv3 import LayoutLMv3Config
 
+if is_flash_attn_2_available():
+    from flash_attn import flash_attn_func, flash_attn_varlen_func
+    from flash_attn.bert_padding import pad_input, unpad_input
+else:
+    # Define dummy functions for when flash_attn is not available
+    def unpad_input(*args, **kwargs):
+        raise ImportError("flash_attn is not available")
+
+    def pad_input(*args, **kwargs):
+        raise ImportError("flash_attn is not available")
+
 
 logger = logging.get_logger(__name__)
 
@@ -285,10 +300,11 @@ def forward(
         # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
         attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))
-        if self.has_relative_attention_bias and self.has_spatial_attention_bias:
-            attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
-        elif self.has_relative_attention_bias:
-            attention_scores += rel_pos / math.sqrt(self.attention_head_size)
+        if self.has_relative_attention_bias and rel_pos is not None:
+            if self.has_spatial_attention_bias and rel_2d_pos is not None:
+                attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
+            else:
+                attention_scores += rel_pos / math.sqrt(self.attention_head_size)
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
@@ -328,32 +344,174 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
+# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
+# Enhanced with unified attention implementation supporting eager, SDPA, and FlashAttention-2
 class LayoutLMv3Attention(nn.Module):
+    """
+    Unified LayoutLMv3 attention module with support for eager, SDPA, and FlashAttention-2.
+    """
+
     def __init__(self, config):
         super().__init__()
         self.self = LayoutLMv3SelfAttention(config)
         self.output = LayoutLMv3SelfOutput(config)
 
+        # Store attention implementation config
+        self.is_causal = False
+        self._attn_implementation = config._attn_implementation if hasattr(config, "_attn_implementation") else "eager"
+
     def forward(
         self,
         hidden_states,
         attention_mask=None,
+        head_mask=None,
         output_attentions=False,
         rel_pos=None,
         rel_2d_pos=None,
     ):
-        self_outputs = self.self(
-            hidden_states,
-            attention_mask,
-            output_attentions,
-            rel_pos=rel_pos,
-            rel_2d_pos=rel_2d_pos,
-        )
+        # Dispatch to appropriate attention implementation
+        if self._attn_implementation == "flash_attention_2":
+            # Check for unsupported features
+            if output_attentions:
+                logger.warning_once(
+                    "FlashAttention-2 does not support output_attentions, falling back to eager attention."
+                )
+                self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
+            elif self.self.has_relative_attention_bias and rel_pos is not None:
+                logger.warning_once(
+                    "FlashAttention-2 does not support relative position bias, falling back to eager attention."
+                )
+                self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
+            else:
+                self_outputs = self._flash_attention_forward(hidden_states, attention_mask)
+
+        elif self._attn_implementation == "sdpa":
+            # Check for unsupported features
+            if output_attentions:
+                logger.warning_once("SDPA does not support output_attentions, falling back to eager attention.")
+                self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
+            elif self.self.has_relative_attention_bias and rel_pos is not None:
+                # SDPA doesn't support relative position bias
+                self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
+            else:
+                self_outputs = self._sdpa_attention_forward(hidden_states, attention_mask, head_mask)
+
+        else:  # eager
+            self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
+
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
 
+    def _sdpa_attention_forward(self, hidden_states, attention_mask=None, head_mask=None):
+        batch_size, seq_len, _ = hidden_states.size()
+
+        # Get Q, K, V
+        query_layer = self.self.query(hidden_states)
+        key_layer = self.self.key(hidden_states)
+        value_layer = self.self.value(hidden_states)
+
+        query_layer = query_layer.view(
+            batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size
+        ).transpose(1, 2)
+        key_layer = key_layer.view(
+            batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size
+        ).transpose(1, 2)
+        value_layer = value_layer.view(
+            batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size
+        ).transpose(1, 2)
+
+        # Convert attention mask to boolean format for SDPA
+        attn_mask = None
+        if attention_mask is not None:
+            # SDPA expects 2D mask, but we might have 4D extended mask
+            if attention_mask.dim() == 4:
+                # Convert 4D extended mask to 2D: (batch_size, 1, 1, seq_len) -> (batch_size, seq_len)
+                attn_mask = attention_mask.squeeze(1).squeeze(1) >= 0
+            elif attention_mask.dim() == 2:
+                attn_mask = attention_mask >= 0
+            else:
+                # For other dimensions, try to squeeze to 2D
+                attn_mask = attention_mask.squeeze() >= 0
+
+            # Expand mask so it broadcasts over heads and query positions: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
+            if attn_mask.dim() == 2:
+                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
+
+        # SDPA doesn't support head_mask, fall back to eager attention if needed
+        if head_mask is not None:
+            return self.self(hidden_states, attention_mask, output_attentions=False, rel_pos=None, rel_2d_pos=None)
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_layer,
+            key_layer,
+            value_layer,
+            attn_mask=attn_mask,
+            dropout_p=self.self.dropout.p if self.training else 0.0,
+            scale=1.0 / math.sqrt(self.self.attention_head_size),
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(batch_size, seq_len, self.self.all_head_size)
+
+        return (attn_output,)
+
+    def _flash_attention_forward(self, hidden_states, attention_mask=None):
+        batch_size, seq_length, _ = hidden_states.size()
+
+        # Get Q, K, V
+        query_states = self.self.query(hidden_states)
+        key_states = self.self.key(hidden_states)
+        value_states = self.self.value(hidden_states)
+
+        query_states = query_states.view(
+            batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size
+        )
+        key_states = key_states.view(
+            batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size
+        )
+        value_states = value_states.view(
+            batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size
+        )
+
+        if attention_mask is not None:
+            # Unpad for variable length sequences. unpad_input expects a 2D padding mask and
+            # (batch_size, seq_len, ...) inputs; it flattens the batch/sequence dimensions itself.
+            if attention_mask.dim() == 4:
+                attention_mask = attention_mask.squeeze(1).squeeze(1) >= 0
+
+            query_states, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask)
+            key_states = key_states.reshape(-1, self.self.num_attention_heads, self.self.attention_head_size)[indices_q]
+            value_states = value_states.reshape(-1, self.self.num_attention_heads, self.self.attention_head_size)[indices_q]
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_q,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_k=max_seqlen_q,
+                dropout_p=self.self.dropout.p if self.training else 0.0,
+                softmax_scale=1.0 / math.sqrt(self.self.attention_head_size),
+                causal=self.is_causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, seq_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states,
+                key_states,
+                value_states,
+                dropout_p=self.self.dropout.p if self.training else 0.0,
+                softmax_scale=1.0 / math.sqrt(self.self.attention_head_size),
+                causal=self.is_causal,
+            )
+
+        attn_output = attn_output.reshape(batch_size, seq_length, self.self.all_head_size)
+
+        return (attn_output,)
+
 
 # Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
 class LayoutLMv3Layer(GradientCheckpointingLayer):
diff --git a/src/transformers/models/layoutlmv3/my_changes.txt b/src/transformers/models/layoutlmv3/my_changes.txt
new file mode 100644
index 000000000000..e69de29bb2d1
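Usage sketch (not part of the patch): the support flags and the dispatch above are meant to let callers
pick the attention backend through the standard attn_implementation argument of from_pretrained. The
snippet below is an illustrative assumption of that flow, using the public microsoft/layoutlmv3-base
checkpoint and assuming the support flags are honored where transformers checks them (on the model class
rather than only on the config). The FlashAttention-2 path additionally needs the flash-attn package, a
CUDA device, and fp16/bf16 weights.

# Illustrative only; not part of the diff above.
import torch
from transformers import AutoModel

checkpoint = "microsoft/layoutlmv3-base"  # assumed public checkpoint

# PyTorch scaled_dot_product_attention backend (works on CPU or GPU).
model_sdpa = AutoModel.from_pretrained(checkpoint, attn_implementation="sdpa")

# FlashAttention-2 backend: requires flash-attn, a CUDA device, and half precision.
model_fa2 = AutoModel.from_pretrained(
    checkpoint,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to("cuda")

Note that when the relative and spatial attention biases are enabled, as they are in the released
LayoutLMv3 checkpoints, the dispatch added above falls back to eager attention, so the SDPA and
FlashAttention-2 paths only take effect when those biases are disabled in the config.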