Empty file added my_changes.txt
Empty file.
src/transformers/models/layoutlmv3/configuration_layoutlmv3.py
@@ -106,6 +106,10 @@ class LayoutLMv3Config(PreTrainedConfig):

    model_type = "layoutlmv3"

    # Support flags for attention implementations
    _supports_sdpa = True
    _supports_flash_attn_2 = True
Contributor comment on lines +109 to +111:

These flags go under the pretrained model class, e.g. in llama

_supports_flash_attn = True
_supports_sdpa = True
_supports_flex_attn = True
_can_compile_fullgraph = True
_supports_attention_backend = True
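
For reference, a minimal sketch of where these flags would live for this model, assuming the existing LayoutLMv3PreTrainedModel class in modeling_layoutlmv3.py and keeping the flag names used in this PR (which flags the model can honestly claim is for the PR author to verify):

from ...modeling_utils import PreTrainedModel
from .configuration_layoutlmv3 import LayoutLMv3Config


class LayoutLMv3PreTrainedModel(PreTrainedModel):
    config_class = LayoutLMv3Config
    base_model_prefix = "layoutlmv3"

    # Attention backend support flags belong on the pretrained model class,
    # not on the config (sketch only; mirror the llama example above).
    _supports_sdpa = True
    _supports_flash_attn_2 = True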


    def __init__(
        self,
        vocab_size=50265,
@@ -138,6 +142,7 @@ def __init__(
        num_channels=3,
        patch_size=16,
        classifier_dropout=None,
        _attn_implementation="eager",
        **kwargs,
    ):
        super().__init__(
@@ -173,6 +178,7 @@ def __init__(
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.classifier_dropout = classifier_dropout
        self._attn_implementation = _attn_implementation
Contributor comment:
We handle this already, no need to modify anything in the config
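
To make the point concrete, a short usage sketch (not part of the PR; microsoft/layoutlmv3-base is the standard checkpoint): the attention backend is selected when the model is loaded, so no _attn_implementation default needs to be stored in the config:

from transformers import LayoutLMv3Model

# The backend is chosen at load time; the config needs no _attn_implementation entry.
model = LayoutLMv3Model.from_pretrained(
    "microsoft/layoutlmv3-base",
    attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
)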



__all__ = ["LayoutLMv3Config"]
182 changes: 170 additions & 12 deletions src/transformers/models/layoutlmv3/modeling_layoutlmv3.py
@@ -24,6 +24,8 @@
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN

# SDPA and Flash Attention support
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
@@ -35,12 +37,25 @@
from ...pytorch_utils import apply_chunking_to_forward
from ...utils import (
    auto_docstring,
    is_flash_attn_2_available,
    logging,
    torch_int,
)
from .configuration_layoutlmv3 import LayoutLMv3Config


if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import pad_input, unpad_input
else:
    # Define dummy functions for when flash_attn is not available; calling any of
    # them without the package installed raises a clear error.
    def flash_attn_func(*args, **kwargs):
        raise ImportError("flash_attn is not available")

    def flash_attn_varlen_func(*args, **kwargs):
        raise ImportError("flash_attn is not available")

    def unpad_input(*args, **kwargs):
        raise ImportError("flash_attn is not available")

    def pad_input(*args, **kwargs):
        raise ImportError("flash_attn is not available")


logger = logging.get_logger(__name__)


@@ -285,10 +300,11 @@ def forward(
        # Changing the computational order into QT(K/√d) alleviates the problem. (https://huggingface.co/papers/2105.13290)
        attention_scores = torch.matmul(query_layer / math.sqrt(self.attention_head_size), key_layer.transpose(-1, -2))

        if self.has_relative_attention_bias and self.has_spatial_attention_bias:
            attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
        elif self.has_relative_attention_bias:
            attention_scores += rel_pos / math.sqrt(self.attention_head_size)
        if self.has_relative_attention_bias and rel_pos is not None:
            if self.has_spatial_attention_bias and rel_2d_pos is not None:
                attention_scores += (rel_pos + rel_2d_pos) / math.sqrt(self.attention_head_size)
            else:
                attention_scores += rel_pos / math.sqrt(self.attention_head_size)

        if attention_mask is not None:
            # Apply the attention mask (precomputed for all layers in RobertaModel forward() function)
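
As a quick aside on the reordering described in the hunk above (an illustrative snippet, not part of the diff): scaling the query before the matmul produces the same scores as scaling the product, while keeping the intermediate values smaller, which is what matters for fp16 stability:

import math

import torch

q = torch.randn(2, 12, 8, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 12, 8, 64)
d = q.size(-1)

scores_scaled_after = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d)
scores_scaled_before = torch.matmul(q / math.sqrt(d), k.transpose(-1, -2))
print(torch.allclose(scores_scaled_after, scores_scaled_before, atol=1e-5))  # True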
@@ -328,32 +344,174 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return hidden_states


# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
# Adapted from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Attention with LayoutLMv2->LayoutLMv3
# Enhanced with unified attention implementation supporting eager, SDPA, and FlashAttention-2
Contributor comment on lines +347 to +348:
It would be better if we could change the upstream model instead. Or only copy relevant parts.

BUT, we also shouldn't modify this class by much. You should follow the structure of Bert for example:

  • Wrapper before the actual attention classes (without the additional cross attn logic for this model here)

    class BertAttention(nn.Module):
        def __init__(self, config, is_causal=False, layer_idx=None, is_cross_attention=False):
            super().__init__()
            self.is_cross_attention = is_cross_attention
            attention_class = BertCrossAttention if is_cross_attention else BertSelfAttention
            self.self = attention_class(config, is_causal=is_causal, layer_idx=layer_idx)
            self.output = BertSelfOutput(config)

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            past_key_values: Optional[Cache] = None,
            cache_position: Optional[torch.Tensor] = None,
            **kwargs: Unpack[TransformersKwargs],
        ) -> tuple[torch.Tensor]:
            attention_mask = attention_mask if not self.is_cross_attention else encoder_attention_mask
            attention_output, attn_weights = self.self(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                cache_position=cache_position,
                **kwargs,
            )
            attention_output = self.output(attention_output, hidden_states)
            return attention_output, attn_weights
  • The actual attention class

    def eager_attention_forward(
        module: nn.Module,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        scaling: Optional[float] = None,
        dropout: float = 0.0,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ):
        if scaling is None:
            scaling = query.size(-1) ** -0.5

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attn_weights = torch.matmul(query, key.transpose(2, 3))

        # Relative positional embeddings
        if module.position_embedding_type == "relative_key" or module.position_embedding_type == "relative_key_query":
            query_length, key_length = query.shape[2], key.shape[2]
            if use_cache:
                position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=query.device).view(-1, 1)
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=query.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            positional_embedding = module.distance_embedding(distance + module.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query.dtype)  # fp16 compatibility

            if module.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
                attn_weights = attn_weights + relative_position_scores
            elif module.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key, positional_embedding)
                attn_weights = attn_weights + relative_position_scores_query + relative_position_scores_key

        # Scaling is shifted in case of embeddings being relative
        attn_weights = attn_weights * scaling

        if attention_mask is not None and attention_mask.ndim == 4:
            attention_mask = attention_mask[:, :, :, : key.shape[-2]]
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)

        attn_output = torch.matmul(attn_weights, value)
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output, attn_weights

    class BertSelfAttention(nn.Module):
        def __init__(self, config, position_embedding_type=None, is_causal=False, layer_idx=None):
            super().__init__()
            if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
                raise ValueError(
                    f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                    f"heads ({config.num_attention_heads})"
                )

            self.config = config
            self.num_attention_heads = config.num_attention_heads
            self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
            self.all_head_size = self.num_attention_heads * self.attention_head_size
            self.scaling = self.attention_head_size**-0.5

            self.query = nn.Linear(config.hidden_size, self.all_head_size)
            self.key = nn.Linear(config.hidden_size, self.all_head_size)
            self.value = nn.Linear(config.hidden_size, self.all_head_size)

            self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
            self.position_embedding_type = position_embedding_type or getattr(
                config, "position_embedding_type", "absolute"
            )
            if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
                self.max_position_embeddings = config.max_position_embeddings
                self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

            self.is_decoder = config.is_decoder
            self.is_causal = is_causal
            self.layer_idx = layer_idx

        def forward(
            self,
            hidden_states: torch.Tensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            past_key_value: Optional[Cache] = None,
            cache_position: Optional[torch.Tensor] = None,
            **kwargs: Unpack[TransformersKwargs],
        ) -> tuple[torch.Tensor]:
            input_shape = hidden_states.shape[:-1]
            hidden_shape = (*input_shape, -1, self.attention_head_size)

            # get all proj
            query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
            key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
            value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

            if past_key_value is not None:
                # decoder-only bert can have a simple dynamic cache for example
                current_past_key_value = past_key_value
                if isinstance(past_key_value, EncoderDecoderCache):
                    current_past_key_value = past_key_value.self_attention_cache

                # save all key/value_layer to cache to be re-used for fast auto-regressive generation
                key_layer, value_layer = current_past_key_value.update(
                    key_layer,
                    value_layer,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )

            attention_interface: Callable = eager_attention_forward
            if self.config._attn_implementation != "eager":
                if self.position_embedding_type != "absolute":
                    raise ValueError(
                        f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
                        'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
                    )
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

            attn_output, attn_weights = attention_interface(
                self,
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
                dropout=0.0 if not self.training else self.dropout.p,
                scaling=self.scaling,
                # only for relevant for non-absolute positional embeddings
                use_cache=past_key_value is not None,
                **kwargs,
            )
            attn_output = attn_output.reshape(*input_shape, -1).contiguous()
            return attn_output, attn_weights

Currently, the wrapper class is modified and a lot of custom paths are introduced. We use our own integrations with the interface here (linking the files as examples below as well)

attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
    if self.position_embedding_type != "absolute":
        raise ValueError(
            f"You are using {self.config._attn_implementation} as attention type. However, non-absolute "
            'positional embeddings can not work with them. Please load the model with `attn_implementation="eager"`.'
        )
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

What needs to be handled:

  • the mask creation
  • adapting the new attention class (as per Bert for example; a rough sketch for LayoutLMv3 follows below)
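
A rough sketch of what that could look like for LayoutLMv3 (assumptions: eager_attention_forward and ALL_ATTENTION_FUNCTIONS are wired up as in the Bert example above, the module-level imports of modeling_layoutlmv3.py are available, and the relative/spatial bias is folded into the additive mask, which works for eager and sdpa but not for flash_attention_2, so that backend would still need to be rejected or fall back when the bias is enabled):

class LayoutLMv3SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.scaling = self.attention_head_size**-0.5

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

        self.has_relative_attention_bias = config.has_relative_attention_bias
        self.has_spatial_attention_bias = config.has_spatial_attention_bias

    def forward(self, hidden_states, attention_mask=None, rel_pos=None, rel_2d_pos=None, **kwargs):
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.attention_head_size)

        query_layer = self.query(hidden_states).view(*hidden_shape).transpose(1, 2)
        key_layer = self.key(hidden_states).view(*hidden_shape).transpose(1, 2)
        value_layer = self.value(hidden_states).view(*hidden_shape).transpose(1, 2)

        # Fold the relative / spatial bias into the additive mask so the attention
        # interface only sees a single bias tensor (scaled to match the existing
        # `rel_pos / sqrt(head_size)` behaviour).
        if self.has_relative_attention_bias and rel_pos is not None:
            bias = rel_pos + rel_2d_pos if (self.has_spatial_attention_bias and rel_2d_pos is not None) else rel_pos
            bias = bias * self.scaling
            attention_mask = bias if attention_mask is None else attention_mask + bias

        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            dropout=0.0 if not self.training else self.dropout.p,
            scaling=self.scaling,
            **kwargs,
        )
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        return attn_output, attn_weights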

class LayoutLMv3Attention(nn.Module):
    """
    Unified LayoutLMv3 attention module with support for eager, SDPA, and FlashAttention-2.
    """

    def __init__(self, config):
        super().__init__()
        self.self = LayoutLMv3SelfAttention(config)
        self.output = LayoutLMv3SelfOutput(config)

        # Store attention implementation config
        self.is_causal = False
        self._attn_implementation = config._attn_implementation if hasattr(config, "_attn_implementation") else "eager"

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
Contributor comment: Head mask is deprecated
        output_attentions=False,
        rel_pos=None,
        rel_2d_pos=None,
    ):
        self_outputs = self.self(
            hidden_states,
            attention_mask,
            output_attentions,
            rel_pos=rel_pos,
            rel_2d_pos=rel_2d_pos,
        )
        # Dispatch to appropriate attention implementation
        if self._attn_implementation == "flash_attention_2":
            # Check for unsupported features
            if output_attentions:
                logger.warning_once(
                    "FlashAttention-2 does not support output_attentions, falling back to eager attention."
                )
                self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
            elif self.self.has_relative_attention_bias and rel_pos is not None:
                logger.warning_once(
                    "FlashAttention-2 does not support relative position bias, falling back to eager attention."
                )
                self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
            else:
                self_outputs = self._flash_attention_forward(hidden_states, attention_mask)

        elif self._attn_implementation == "sdpa":
            # Check for unsupported features
            if output_attentions:
                logger.warning_once("SDPA does not support output_attentions, falling back to eager attention.")
                self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
            elif self.self.has_relative_attention_bias and rel_pos is not None:
                # SDPA doesn't support relative position bias
                self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)
            else:
                self_outputs = self._sdpa_attention_forward(hidden_states, attention_mask, head_mask)

        else:  # eager
            self_outputs = self.self(hidden_states, attention_mask, output_attentions, rel_pos, rel_2d_pos)

        attention_output = self.output(self_outputs[0], hidden_states)
        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
        return outputs

    def _sdpa_attention_forward(self, hidden_states, attention_mask=None, head_mask=None):
        batch_size, seq_len, _ = hidden_states.size()

        # Get Q, K, V
        query_layer = self.self.query(hidden_states)
        key_layer = self.self.key(hidden_states)
        value_layer = self.self.value(hidden_states)

        query_layer = query_layer.view(
            batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size
        ).transpose(1, 2)
        key_layer = key_layer.view(
            batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size
        ).transpose(1, 2)
        value_layer = value_layer.view(
            batch_size, seq_len, self.self.num_attention_heads, self.self.attention_head_size
        ).transpose(1, 2)

        # Convert attention mask to boolean format for SDPA
        attn_mask = None
        if attention_mask is not None:
            # SDPA expects 2D mask, but we might have 4D extended mask
            if attention_mask.dim() == 4:
                # Convert 4D extended mask to 2D: (batch_size, 1, 1, seq_len) -> (batch_size, seq_len)
                attn_mask = attention_mask.squeeze(1).squeeze(1) >= 0
            elif attention_mask.dim() == 2:
                attn_mask = attention_mask >= 0
            else:
                # For other dimensions, try to squeeze to 2D
                attn_mask = attention_mask.squeeze() >= 0

            # Expand mask to be broadcastable over heads and queries: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)

        # SDPA doesn't support head_mask, fallback if needed
        if head_mask is not None:
            return self.self(hidden_states, attention_mask, output_attentions=False, rel_pos=None, rel_2d_pos=None)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attn_mask,
            dropout_p=self.self.dropout.p if self.training else 0.0,
            scale=1.0 / math.sqrt(self.self.attention_head_size),
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.self.all_head_size)

        return (attn_output,)

    def _flash_attention_forward(self, hidden_states, attention_mask=None):
        batch_size, seq_length, _ = hidden_states.size()

        # Get Q, K, V
        query_states = self.self.query(hidden_states)
        key_states = self.self.key(hidden_states)
        value_states = self.self.value(hidden_states)

        query_states = query_states.view(
            batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size
        )
        key_states = key_states.view(
            batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size
        )
        value_states = value_states.view(
            batch_size, seq_length, self.self.num_attention_heads, self.self.attention_head_size
        )

        if attention_mask is not None:
            # Unpad for variable length sequences
            query_states = query_states.view(-1, self.self.num_attention_heads, self.self.attention_head_size)
            key_states = key_states.view(-1, self.self.num_attention_heads, self.self.attention_head_size)
            value_states = value_states.view(-1, self.self.num_attention_heads, self.self.attention_head_size)

            query_states, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(query_states, attention_mask)
            key_states = key_states[indices_q]
            value_states = value_states[indices_q]

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_q,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_q,
                dropout_p=self.self.dropout.p if self.training else 0.0,
                softmax_scale=1.0 / math.sqrt(self.self.attention_head_size),
                causal=self.is_causal,
            )

            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, seq_length)
        else:
            attn_output = flash_attn_func(
                query_states,
                key_states,
                value_states,
                dropout_p=self.self.dropout.p if self.training else 0.0,
                softmax_scale=1.0 / math.sqrt(self.self.attention_head_size),
                causal=self.is_causal,
            )

        attn_output = attn_output.reshape(batch_size, seq_length, self.self.all_head_size)

        return (attn_output,)
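
One usage note on this path (illustrative, not part of the diff; microsoft/layoutlmv3-base is the standard checkpoint): the flash_attn kernels only run on CUDA with fp16/bf16 tensors, so a caller opting into this backend would typically load the model as:

import torch
from transformers import LayoutLMv3Model

model = LayoutLMv3Model.from_pretrained(
    "microsoft/layoutlmv3-base",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,  # flash_attn requires fp16 or bf16
).to("cuda")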


# Copied from transformers.models.layoutlmv2.modeling_layoutlmv2.LayoutLMv2Layer with LayoutLMv2->LayoutLMv3
class LayoutLMv3Layer(GradientCheckpointingLayer):
Contributor comment:
To be removed?

Empty file.