From dfcb545d983fce4a34ccffb8f78efe98592844a4 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 24 Oct 2025 20:39:18 +0200 Subject: [PATCH 01/10] atmpt 1 --- src/transformers/integrations/executorch.py | 20 +-- src/transformers/masking_utils.py | 172 +++++++------------- 2 files changed, 69 insertions(+), 123 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 0d4910732528..51d95dd2604b 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -229,10 +229,10 @@ def __init__( "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." ) self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device) - # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + """# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - self.model.model.config._attn_implementation = "sdpa_without_vmap" + self.model.model.config._attn_implementation = "sdpa_without_vmap""" def forward( self, @@ -768,10 +768,10 @@ def convert_and_export_with_cache( import torch.export._trace - # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable - ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) - ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - model.config._attn_implementation = "sdpa_without_vmap" + """# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) + ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) 
+ self.model.model.config._attn_implementation = "sdpa_without_vmap""" with torch.no_grad(): # TODO: The default inputs only work for text models. We need to add support for vision/audio models. @@ -1036,10 +1036,10 @@ def export_with_dynamic_cache( if not is_torch_greater_or_equal_than_2_3: raise ImportError("torch >= 2.3 is required.") - # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable - ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) - ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - model.config._attn_implementation = "sdpa_without_vmap" + """# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable + ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) + ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) + self.model.model.config._attn_implementation = "sdpa_without_vmap""" register_dynamic_cache_export_support() diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 8f0a8a04fd60..44f13d5bc907 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -224,6 +224,9 @@ def prepare_padding_mask( return local_padding_mask +sdpa_needs_sclicing = not _is_torch_greater_or_equal_than_2_6 + + def _ignore_causal_mask_sdpa( padding_mask: Optional[torch.Tensor], query_length: int, @@ -282,7 +285,16 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> boo return False -def sdpa_mask_recent_torch( +def _non_vmap_expansion_sdpa(cache_position, batch_size, kv_length, kv_offset): + device = cache_position.device + q_indices = cache_position[None, None, :, None] + head_indices = torch.arange(1, dtype=torch.long, device=device)[None, :, None, None] + batch_indices = torch.arange(batch_size, dtype=torch.long, device=device)[:, None, None, None] + kv_indices = 
torch.arange(kv_length, dtype=torch.long, device=device)[None, None, None, :] + kv_offset + return batch_indices, head_indices, q_indices, kv_indices + + +def sdpa_mask( batch_size: int, cache_position: torch.Tensor, kv_length: int, @@ -292,6 +304,8 @@ def sdpa_mask_recent_torch( local_size: Optional[int] = None, allow_is_causal_skip: bool = True, allow_is_bidirectional_skip: bool = False, + allow_torch_fix: bool = True, + use_vmap: bool = False, **kwargs, ) -> Optional[torch.Tensor]: """ @@ -324,6 +338,9 @@ def sdpa_mask_recent_torch( allow_is_bidirectional_skip (`bool`, optional): Whether to allow to return `None` for the mask under conditions where we do not have to add any bias, i.e. full attention without any padding. Default to `False`. + allow_torch_fix (`bool`, optional): + Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older + versions. We need an arg to skip it when using eager. By default `True`. ## Creating a simple causal mask: @@ -392,7 +409,7 @@ def sdpa_mask_recent_torch( """ q_length = cache_position.shape[0] # Potentially pad the 2D mask, and slice it correctly - padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False) + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=sdpa_needs_sclicing and not use_vmap) # Under specific conditions, we can avoid materializing the mask # 1. Causal masks can rely on the `is_causal` argument @@ -402,130 +419,47 @@ def sdpa_mask_recent_torch( if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask): return None - # vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the - # padding mask. 
Thus, we allow early exit here if we do not detect any modification to the base mask function - if mask_function is bidirectional_mask_function: - if padding_mask is not None: - # used for slicing without data-dependent slicing - mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset - return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1) - else: - return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device) - # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` # but without data-dependent slicing (i.e. torch.compile friendly) kv_arange = torch.arange(kv_length, device=cache_position.device) kv_arange += kv_offset + # This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well, + # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow + # However, in more recent version of Pytorch, a trick was introduced to handle it + if sdpa_needs_sclicing and use_vmap: + attention_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange) + attention_mask = attention_mask[None, None, :, :].expand(batch_size, -1, -1, -1) + # Potentially add the padding 2D mask if padding_mask is not None: - mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) - - batch_arange = torch.arange(batch_size, device=cache_position.device) - head_arange = torch.arange(1, device=cache_position.device) - # This creates the 4D mask easily. 
Note that we need this context manager as vmap cannot handle slicing a tensor from - # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it - # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices - with TransformGetItemToIndex(): - causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) - - return causal_mask - - -def sdpa_mask_older_torch( - batch_size: int, - cache_position: torch.Tensor, - kv_length: int, - kv_offset: int = 0, - mask_function: Callable = causal_mask_function, - attention_mask: Optional[torch.Tensor] = None, - local_size: Optional[int] = None, - allow_is_causal_skip: bool = True, - allow_torch_fix: bool = True, - allow_is_bidirectional_skip: bool = False, - **kwargs, -) -> Optional[torch.Tensor]: - """ - NOTE: This function is only used when torch version is torch<2.5 - see `sdpa_mask_recent_torch` otherwise. - - Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that - the element should take part in the attention computation, and False that it should not. - If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend - to any other tokens (due to padding) will be fully attended to instead, in order to avoid `nan` propagation (this does - not change the final result). - - Args: - batch_size (`int`): - The batch size of the input sequence. - cache_position (`torch.Tensor`): - A tensor of shape (query_length,) indicating the current indices of the input sequence elements. - kv_length (`int`): - The size that the key and value states will have during the attention computation. - kv_offset (`int`, optional): - An optional offset to indicate at which first position the key and values states will refer to. - mask_function (`Callable`): - The mask factory function describing the mask pattern. 
- attention_mask (`torch.Tensor`, optional): - The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) - local_size (`int`, optional): - The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` - to try to skip mask creation if possible. - allow_is_causal_skip (`bool`, optional): - Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in - `torch.sdpa` instead. Default to `True`. - allow_torch_fix (`bool`, optional): - Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older - versions. We need an arg to skip it when using eager. By default `True`. - allow_is_bidirectional_skip (`bool`, optional): - Whether to allow to return `None` for the mask under conditions where we do not have to add any bias, - i.e. full attention without any padding. Default to `False`. - """ - q_length = cache_position.shape[0] - # Potentially pad the 2D mask, and slice it correctly - padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) - - # Under specific conditions, we can avoid materializing the mask - # 1. Causal masks can rely on the `is_causal` argument - # 2. Bidirectional do not need any further processing (no bias) - if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size): - return None - if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask): - return None - - # vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the - # padding mask. 
Thus, we allow early exit here if we do not detect any modification to the base mask function - if mask_function is bidirectional_mask_function: - if padding_mask is not None: - return padding_mask[:, None, None, :].expand(-1, -1, q_length, -1) + if sdpa_needs_sclicing and use_vmap: + attention_mask = attention_mask * padding_mask[:, None, None, :] else: - return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device) + mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) - # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` - # but without data-dependent slicing (i.e. torch.compile friendly) - kv_arange = torch.arange(kv_length, device=cache_position.device) - kv_arange += kv_offset - - # This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well, - # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow - # However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have - # `sdpa_mask_recent_torch`, as it allows more general `mask_function` - causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange) - causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1) - if padding_mask is not None: - causal_mask = causal_mask * padding_mask[:, None, None, :] + if not use_vmap: + # Apply mask function element-wise through broadcasting + causal_mask = mask_function(*_non_vmap_expansion_sdpa(cache_position, batch_size, kv_length, kv_offset)) + # Expand the mask to match batch size and query length if they weren't used in the mask function + causal_mask = causal_mask.expand(batch_size, -1, q_length, kv_length) + else: + # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any + # tokens (due to 
padding). See details in https://github.com/pytorch/pytorch/issues/110213 + if sdpa_needs_sclicing and not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: + causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) + else: + batch_arange = torch.arange(batch_size, device=cache_position.device) + head_arange = torch.arange(1, device=cache_position.device) + # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from + # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it + # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices + with TransformGetItemToIndex(): + causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) - # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any - # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213 - if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: - causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) return causal_mask -# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions -# (especially mask_function indexing a tensor, such as the padding mask function) -sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch - - def eager_mask( batch_size: int, cache_position: torch.Tensor, @@ -534,6 +468,7 @@ def eager_mask( mask_function: Callable = causal_mask_function, attention_mask: Optional[torch.Tensor] = None, dtype: torch.dtype = torch.float32, + uses_vmap: bool = False, **kwargs, ) -> torch.Tensor: """ @@ -570,6 +505,7 @@ def eager_mask( allow_is_causal_skip=False, allow_is_bidirectional_skip=False, allow_torch_fix=False, + uses_vmap=uses_vmap, **kwargs, ) min_dtype = torch.finfo(dtype).min 
@@ -862,21 +798,25 @@ def create_causal_mask( # Allow slight deviations from causal mask # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask, # padding mask, etc) as the resulting mask may otherwise not be correct! + uses_vmap = False if or_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False + uses_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False + uses_vmap = True # If we detected packing format if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) allow_is_causal_skip = False + uses_vmap = True # We now create the mask causal_mask = mask_interface( @@ -889,6 +829,7 @@ def create_causal_mask( allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + uses_vmap=uses_vmap, # Short-circuit to non-vmap expansions for the mask ) return causal_mask @@ -946,16 +887,19 @@ def create_bidirectional_mask( # Allow slight deviations from the base mask # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask, # padding mask, etc) as the resulting mask may otherwise not be correct! 
+ uses_vmap = False if or_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_bidirectional_skip = False + uses_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_bidirectional_skip = False + uses_vmap = True # We now create the mask attention_mask = mask_interface( @@ -968,8 +912,10 @@ def create_bidirectional_mask( # Additional kwargs for sdpa allow_is_causal_skip=False, allow_is_bidirectional_skip=allow_is_bidirectional_skip, + allow_torch_fix=True, dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + uses_vmap=uses_vmap, # Short-circuit to non-vmap expansions for the mask ) return attention_mask From b87a139f5f13f78c196e56ec6f80c291f4dcdef3 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 24 Oct 2025 21:33:02 +0200 Subject: [PATCH 02/10] fixup masking to work correctly with old torch --- src/transformers/masking_utils.py | 48 +++++++++++++++---------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 44f13d5bc907..169ac899f9fa 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -224,9 +224,6 @@ def prepare_padding_mask( return local_padding_mask -sdpa_needs_sclicing = not _is_torch_greater_or_equal_than_2_6 - - def _ignore_causal_mask_sdpa( padding_mask: Optional[torch.Tensor], query_length: int, @@ -408,8 +405,10 @@ def sdpa_mask( """ q_length = cache_position.shape[0] + # Potentially pad the 2D mask, and slice it correctly - 
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=sdpa_needs_sclicing and not use_vmap) + sdpa_needs_sclicing = (not _is_torch_greater_or_equal_than_2_6) and use_vmap + padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=sdpa_needs_sclicing) # Under specific conditions, we can avoid materializing the mask # 1. Causal masks can rely on the `is_causal` argument @@ -427,37 +426,37 @@ def sdpa_mask( # This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well, # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow # However, in more recent version of Pytorch, a trick was introduced to handle it - if sdpa_needs_sclicing and use_vmap: + if sdpa_needs_sclicing: attention_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange) attention_mask = attention_mask[None, None, :, :].expand(batch_size, -1, -1, -1) # Potentially add the padding 2D mask if padding_mask is not None: - if sdpa_needs_sclicing and use_vmap: + if sdpa_needs_sclicing: attention_mask = attention_mask * padding_mask[:, None, None, :] else: mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) if not use_vmap: # Apply mask function element-wise through broadcasting - causal_mask = mask_function(*_non_vmap_expansion_sdpa(cache_position, batch_size, kv_length, kv_offset)) + attention_mask = mask_function(*_non_vmap_expansion_sdpa(cache_position, batch_size, kv_length, kv_offset)) # Expand the mask to match batch size and query length if they weren't used in the mask function - causal_mask = causal_mask.expand(batch_size, -1, q_length, kv_length) - else: - # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any - # tokens (due to padding). 
See details in https://github.com/pytorch/pytorch/issues/110213 - if sdpa_needs_sclicing and not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: - causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) - else: - batch_arange = torch.arange(batch_size, device=cache_position.device) - head_arange = torch.arange(1, device=cache_position.device) - # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from - # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it - # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices - with TransformGetItemToIndex(): - causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) + attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length) + elif _is_torch_greater_or_equal_than_2_6: + batch_arange = torch.arange(batch_size, device=cache_position.device) + head_arange = torch.arange(1, device=cache_position.device) + # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from + # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it + # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices + with TransformGetItemToIndex(): + attention_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) + + # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any + # tokens (due to padding). 
See details in https://github.com/pytorch/pytorch/issues/110213 + if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: + attention_mask |= torch.all(~attention_mask, dim=-1, keepdim=True) - return causal_mask + return attention_mask def eager_mask( @@ -813,10 +812,11 @@ def create_causal_mask( uses_vmap = True # If we detected packing format - if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + if packed_sequence_mask is not None: + if uses_vmap and not _is_torch_greater_or_equal_than_2_6: + raise ValueError("TODO: custom patterns not possible with padding free") mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) allow_is_causal_skip = False - uses_vmap = True # We now create the mask causal_mask = mask_interface( From 969dab558ada6ec9f92ff44b01b0bc0d19118b61 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 24 Oct 2025 22:07:07 +0200 Subject: [PATCH 03/10] few changes to make things a bit cleaner --- src/transformers/masking_utils.py | 42 ++++++++++++++++--------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 169ac899f9fa..5440a79c1aa2 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -282,12 +282,12 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> boo return False -def _non_vmap_expansion_sdpa(cache_position, batch_size, kv_length, kv_offset): +def _non_vmap_expansion_sdpa(batch_size, cache_position, kv_arange): device = cache_position.device - q_indices = cache_position[None, None, :, None] - head_indices = torch.arange(1, dtype=torch.long, device=device)[None, :, None, None] batch_indices = torch.arange(batch_size, dtype=torch.long, device=device)[:, None, None, None] - kv_indices = torch.arange(kv_length, dtype=torch.long, device=device)[None, None, None, :] + kv_offset + head_indices = 
torch.arange(1, dtype=torch.long, device=device)[None, :, None, None] + q_indices = cache_position[None, None, :, None] + kv_indices = kv_arange.to(dtype=torch.long)[None, None, None, :] return batch_indices, head_indices, q_indices, kv_indices @@ -302,7 +302,7 @@ def sdpa_mask( allow_is_causal_skip: bool = True, allow_is_bidirectional_skip: bool = False, allow_torch_fix: bool = True, - use_vmap: bool = False, + use_vmap: bool = False, # TODO: docstring **kwargs, ) -> Optional[torch.Tensor]: """ @@ -418,28 +418,21 @@ def sdpa_mask( if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask): return None - # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` - # but without data-dependent slicing (i.e. torch.compile friendly) - kv_arange = torch.arange(kv_length, device=cache_position.device) - kv_arange += kv_offset - - # This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well, - # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow - # However, in more recent version of Pytorch, a trick was introduced to handle it - if sdpa_needs_sclicing: - attention_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange) - attention_mask = attention_mask[None, None, :, :].expand(batch_size, -1, -1, -1) - # Potentially add the padding 2D mask if padding_mask is not None: if sdpa_needs_sclicing: - attention_mask = attention_mask * padding_mask[:, None, None, :] + padding_mask = padding_mask[:, None, None, :] else: mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) + # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` + # but without data-dependent slicing (i.e. 
torch.compile friendly) + kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_arange += kv_offset + if not use_vmap: # Apply mask function element-wise through broadcasting - attention_mask = mask_function(*_non_vmap_expansion_sdpa(cache_position, batch_size, kv_length, kv_offset)) + attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_size, cache_position, kv_arange)) # Expand the mask to match batch size and query length if they weren't used in the mask function attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length) elif _is_torch_greater_or_equal_than_2_6: @@ -450,6 +443,15 @@ def sdpa_mask( # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices with TransformGetItemToIndex(): attention_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) + else: + # This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well, + # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow + # However, in more recent version of Pytorch, a trick was introduced to handle it + attention_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange) + attention_mask = attention_mask[None, None, :, :].expand(batch_size, -1, -1, -1) + + if padding_mask is not None: + attention_mask = attention_mask * padding_mask # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213 @@ -797,7 +799,7 @@ def create_causal_mask( # Allow slight deviations from causal mask # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask, # padding mask, etc) as the resulting mask may otherwise not be correct! 
- uses_vmap = False + uses_vmap = False # TODO --> move a bit, describe, and copy to other fns if or_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") From 513c8ef7640117348b0fcb9bcc347e2666709f18 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 24 Oct 2025 22:18:42 +0200 Subject: [PATCH 04/10] oopsie --- src/transformers/masking_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 5440a79c1aa2..71638a5a203f 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -496,6 +496,7 @@ def eager_mask( # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf _ = kwargs.pop("allow_is_causal_skip", None) _ = kwargs.pop("allow_is_bidirectional_skip", None) + _ = kwargs.pop("allow_torch_fix", None) mask = sdpa_mask( batch_size=batch_size, cache_position=cache_position, @@ -914,7 +915,6 @@ def create_bidirectional_mask( # Additional kwargs for sdpa allow_is_causal_skip=False, allow_is_bidirectional_skip=allow_is_bidirectional_skip, - allow_torch_fix=True, dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface uses_vmap=uses_vmap, # Short-circuit to non-vmap expansions for the mask From 466acaba77d285ffd6fde52b22c083b7e45468f5 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 24 Oct 2025 22:34:13 +0200 Subject: [PATCH 05/10] fix integer overflow on bidirectional masks via indexing fn --- src/transformers/masking_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 71638a5a203f..43b9380b7671 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -82,8 +82,10 @@ def 
causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: """ This creates a full bidirectional mask. + + NOTE: It is important to keep and index-based version for non-vmap expansion. """ - return q_idx.new_ones((), dtype=torch.bool) + return q_idx >= 0 def sliding_window_overlay(sliding_window: int) -> Callable: From bbaf41d82dca3a9fba6d84b751dda64eb3ffbcb8 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 24 Oct 2025 22:36:19 +0200 Subject: [PATCH 06/10] rm executorch workarounds --> still need to handle on sliding etc fns properly --- src/transformers/integrations/executorch.py | 112 +------------------- 1 file changed, 1 insertion(+), 111 deletions(-) diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 51d95dd2604b..635ab7abe744 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -11,7 +11,6 @@ # specific language governing permissions and limitations under the License. import logging -from collections.abc import Callable from typing import Optional import torch @@ -24,13 +23,7 @@ StaticCache, ) from ..generation.configuration_utils import GenerationConfig -from ..masking_utils import ( - ALL_MASK_ATTENTION_FUNCTIONS, - _ignore_causal_mask_sdpa, - _is_torch_greater_or_equal_than_2_5, - prepare_padding_mask, -) -from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ..modeling_utils import PreTrainedModel from ..pytorch_utils import ( is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3, @@ -229,10 +222,6 @@ def __init__( "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." 
) self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device) - """# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable - ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) - ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - self.model.model.config._attn_implementation = "sdpa_without_vmap""" def forward( self, @@ -768,11 +757,6 @@ def convert_and_export_with_cache( import torch.export._trace - """# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable - ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) - ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - self.model.model.config._attn_implementation = "sdpa_without_vmap""" - with torch.no_grad(): # TODO: The default inputs only work for text models. We need to add support for vision/audio models. 
example_input_ids = ( @@ -1036,11 +1020,6 @@ def export_with_dynamic_cache( if not is_torch_greater_or_equal_than_2_3: raise ImportError("torch >= 2.3 is required.") - """# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable - ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) - ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - self.model.model.config._attn_implementation = "sdpa_without_vmap""" - register_dynamic_cache_export_support() with torch.no_grad(): @@ -1109,92 +1088,3 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context): value = value_list[idx] if idx < len(value_list) else None cache.update(key, value, idx) return cache - - -def sdpa_mask_without_vmap( - batch_size: int, - cache_position: torch.Tensor, - kv_length: int, - kv_offset: int = 0, - mask_function: Optional[Callable] = None, - attention_mask: Optional[torch.Tensor] = None, - local_size: Optional[int] = None, - allow_is_causal_skip: bool = True, - allow_torch_fix: bool = True, - **kwargs, -) -> Optional[torch.Tensor]: - """ - Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that - the element should take part in the attention computation, and False that it should not. - - This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export. - - Args: - batch_size (`int`): - The batch size of the input sequence. - cache_position (`torch.Tensor`): - A tensor of shape (query_length,) indicating the current indices of the input sequence elements. - kv_length (`int`): - The size that the key and value states will have during the attention computation. - kv_offset (`int`, optional): - An optional offset to indicate at which first position the key and values states will refer to. - mask_function (`Callable`): - The mask factory function describing the mask pattern. 
- attention_mask (`torch.Tensor`, optional): - The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) - local_size (`int`, optional): - The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` - to try to skip mask creation if possible. - allow_is_causal_skip (`bool`, optional): - Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in - `torch.sdpa` instead. Default to `True`. - allow_torch_fix (`bool`, optional): - Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older - versions. We need an arg to skip it when using eager. By default `True`. - - """ - - q_length = cache_position.shape[0] - # Potentially pad the 2D mask, and slice it correctly - padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) - - # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument - if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size): - return None - - # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` - # but without data-dependent slicing (i.e. 
torch.compile friendly) - kv_arange = torch.arange(kv_length, device=cache_position.device) - kv_arange += kv_offset - reshaped_cache_position = cache_position.view(-1, 1) - - # This is a bit hacky to know what pattern we are using, but all mask creation function actually forward - # the config through kwargs anyway, so it allows to rely on it - # Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it, - # but this is more efficient - sliding_window = getattr(kwargs["config"], "sliding_window", None) - chunk_size = getattr(kwargs["config"], "attention_chunk_size", None) - - if sliding_window is not None and chunk_size is not None: - raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`") - - # Simplest and most efficient way to obtain a causal mask - causal_mask = kv_arange <= reshaped_cache_position - # If using sliding window, add the sliding mask - if sliding_window is not None: - sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window - causal_mask *= sliding_mask_overlay - # If using chunk attention, add the chunked mask - elif chunk_size is not None: - chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size - causal_mask *= chunked_mask_overlay - - causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1) - if padding_mask is not None: - causal_mask = causal_mask * padding_mask[:, None, None, :] - - # Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any - # tokens (due to padding). 
See details in https://github.com/pytorch/pytorch/issues/110213 - if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: - causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) - return causal_mask From 65357d9b24dfc065a0488fc1e42ba40a8c1b05b2 Mon Sep 17 00:00:00 2001 From: vasqu Date: Fri, 24 Oct 2025 22:36:50 +0200 Subject: [PATCH 07/10] typo --- src/transformers/masking_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 43b9380b7671..8f54930d42c8 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -83,7 +83,7 @@ def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_id """ This creates a full bidirectional mask. - NOTE: It is important to keep and index-based version for non-vmap expansion. + NOTE: It is important to keep an index-based version for non-vmap expansion. """ return q_idx >= 0 From aaaaec2b31bf24aa8f822f220c08854779f2b875 Mon Sep 17 00:00:00 2001 From: vasqu Date: Tue, 28 Oct 2025 18:12:43 +0100 Subject: [PATCH 08/10] docs, fix older torch inplace issue, proper kwarg handling --- src/transformers/masking_utils.py | 79 +++++++++++++++++++++++++------ 1 file changed, 65 insertions(+), 14 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 8f54930d42c8..eb7652a65566 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -285,6 +285,13 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> boo def _non_vmap_expansion_sdpa(batch_size, cache_position, kv_arange): + """ + Broadcasts indices along their non-responsible dimensions. + Allows the usage of any index-based mask function. 
+ + Reference: + - https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365 + """ device = cache_position.device batch_indices = torch.arange(batch_size, dtype=torch.long, device=device)[:, None, None, None] head_indices = torch.arange(1, dtype=torch.long, device=device)[None, :, None, None] @@ -304,7 +311,7 @@ def sdpa_mask( allow_is_causal_skip: bool = True, allow_is_bidirectional_skip: bool = False, allow_torch_fix: bool = True, - use_vmap: bool = False, # TODO: docstring + use_vmap: bool = False, **kwargs, ) -> Optional[torch.Tensor]: """ @@ -340,6 +347,9 @@ def sdpa_mask( allow_torch_fix (`bool`, optional): Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older versions. We need an arg to skip it when using eager. By default `True`. + use_vmap (`bool`, optional): + Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be + index-based (at the cost of speed). By default `False`.
## Creating a simple causal mask: @@ -432,11 +442,15 @@ def sdpa_mask( kv_arange = torch.arange(kv_length, device=cache_position.device) kv_arange += kv_offset + # Actual mask creation + # Option 1: Fast non-vmap mask creation (default) if not use_vmap: # Apply mask function element-wise through broadcasting attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_size, cache_position, kv_arange)) # Expand the mask to match batch size and query length if they weren't used in the mask function attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length) + + # Option 2: Vmap mask creation (torch>=2.6 and custom patterns) elif _is_torch_greater_or_equal_than_2_6: batch_arange = torch.arange(batch_size, device=cache_position.device) head_arange = torch.arange(1, device=cache_position.device) @@ -445,6 +459,8 @@ def sdpa_mask( # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices with TransformGetItemToIndex(): attention_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) + + # Option 3: Limited vmap mask creation (torch<2.6 and custom patterns) else: # This creates the 4D mask easily. Note that we do not include vmap over the batch_idx dimension as well, # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow @@ -458,7 +474,7 @@ def sdpa_mask( # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any # tokens (due to padding). 
See details in https://github.com/pytorch/pytorch/issues/110213 if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: - attention_mask |= torch.all(~attention_mask, dim=-1, keepdim=True) + attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True) return attention_mask @@ -471,7 +487,7 @@ def eager_mask( mask_function: Callable = causal_mask_function, attention_mask: Optional[torch.Tensor] = None, dtype: torch.dtype = torch.float32, - uses_vmap: bool = False, + use_vmap: bool = False, **kwargs, ) -> torch.Tensor: """ @@ -494,6 +510,9 @@ def eager_mask( The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) dtype (`torch.dtype`, optional): The dtype to use for the mask. By default, `torch.float32`. + use_vmap (`bool`, optional): + Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be + index-based (at the cost of speed). By default `False`. """ # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf _ = kwargs.pop("allow_is_causal_skip", None) @@ -509,7 +528,7 @@ def eager_mask( allow_is_causal_skip=False, allow_is_bidirectional_skip=False, allow_torch_fix=False, - uses_vmap=uses_vmap, + use_vmap=use_vmap, **kwargs, ) min_dtype = torch.finfo(dtype).min @@ -791,6 +810,11 @@ def create_causal_mask( mask_factory_function = causal_mask_function mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + # Defaulting to using non-vmap based mask creations except when detecting + # users passing custom mask functions (as we cannot guarantee that they + # are properly index-based as required by our implementation).
+ use_vmap = False + # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it if _is_torch_xpu_available: @@ -802,24 +826,25 @@ def create_causal_mask( # Allow slight deviations from causal mask # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask, # padding mask, etc) as the resulting mask may otherwise not be correct! - uses_vmap = False # TODO --> move a bit, describe, and copy to other fns if or_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False - uses_vmap = True + use_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False - uses_vmap = True + use_vmap = True # If we detected packing format if packed_sequence_mask is not None: - if uses_vmap and not _is_torch_greater_or_equal_than_2_6: - raise ValueError("TODO: custom patterns not possible with padding free") + if use_vmap and not _is_torch_greater_or_equal_than_2_6: + raise ValueError( + "Packed masking along custom patterns (i.e. and/or masks) is only allowed from `torch>=2.6.0`." 
+ ) mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) allow_is_causal_skip = False @@ -834,7 +859,7 @@ def create_causal_mask( allow_is_causal_skip=allow_is_causal_skip, # additional kwarg for sdpa dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface - uses_vmap=uses_vmap, # Short-circuit to non-vmap expansions for the mask + use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask ) return causal_mask @@ -888,23 +913,26 @@ def create_bidirectional_mask( # Allow skipping the mask creation except we have additional masking operators (and/or masks) allow_is_bidirectional_skip = True + # Defaulting to using non-vmap based mask creations except when detecting + # users passing custom mask functions (as we cannot guarantee that they + # are properly index-based as required by our implementation). + use_vmap = False # Allow slight deviations from the base mask # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask, # padding mask, etc) as the resulting mask may otherwise not be correct! 
- uses_vmap = False if or_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_bidirectional_skip = False - uses_vmap = True + use_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_bidirectional_skip = False - uses_vmap = True + use_vmap = True # We now create the mask attention_mask = mask_interface( @@ -919,7 +947,7 @@ def create_bidirectional_mask( allow_is_bidirectional_skip=allow_is_bidirectional_skip, dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface - uses_vmap=uses_vmap, # Short-circuit to non-vmap expansions for the mask + use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask ) return attention_mask @@ -982,6 +1010,10 @@ def create_sliding_window_causal_mask( mask_factory_function = sliding_window_causal_mask_function(sliding_window) mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + # Defaulting to using non-vmap based mask creations except when detecting + # users passing custom mask functions (as we cannot guarantee that they + # are properly index-based as required by our implementation). 
+ use_vmap = False # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False) @@ -994,14 +1026,20 @@ def create_sliding_window_causal_mask( raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False + use_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False + use_vmap = True # If we detected packing format if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + if use_vmap and not _is_torch_greater_or_equal_than_2_6: + raise ValueError( + "Packed masking along custom patterns (i.e. and/or masks) is only allowed from `torch>=2.6.0`." 
+ ) mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) allow_is_causal_skip = False @@ -1017,6 +1055,7 @@ def create_sliding_window_causal_mask( local_size=sliding_window, # Additional kwarg for sdpa dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask ) return causal_mask @@ -1083,6 +1122,7 @@ def create_chunked_causal_mask( ) batch_size, dtype = input_embeds.shape[0], input_embeds.dtype + # TODO: check if we can use non-vmap here + if the old torch restriction is still valid then # For chunked attention and batched inputs, we need to take the number of left padding tokens into account # to start the chunk from the actual start of the sequence for the padded sequence if attention_mask is not None: @@ -1104,6 +1144,10 @@ def create_chunked_causal_mask( mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens) mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + # Defaulting to using non-vmap based mask creations except when detecting + # users passing custom mask functions (as we cannot guarantee that they + # are properly index-based as required by our implementation). 
+ use_vmap = False # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False) @@ -1116,14 +1160,20 @@ def create_chunked_causal_mask( raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False + use_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False + use_vmap = True # If we detected packing format if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + if use_vmap and not _is_torch_greater_or_equal_than_2_6: + raise ValueError( + "Packed masking along custom patterns (i.e. and/or masks) is only allowed from `torch>=2.6.0`." 
+ ) mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) allow_is_causal_skip = False @@ -1139,6 +1189,7 @@ def create_chunked_causal_mask( local_size=chunk_size, # Additional kwarg for sdpa dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask ) return causal_mask From 539bafad6124fb205d8d2ca3769b6740bcfba06a Mon Sep 17 00:00:00 2001 From: vasqu Date: Tue, 28 Oct 2025 18:35:38 +0100 Subject: [PATCH 09/10] chunked works with non vmap and older torch, add warning on non guaranteed masks --- src/transformers/masking_utils.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index eb7652a65566..5e6752b0e198 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -135,8 +135,6 @@ def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) -> """ This return the mask_function function to create a chunked attention mask. """ - if not _is_torch_greater_or_equal_than_2_6: - return and_masks(_legacy_chunked_overlay(chunk_size), causal_mask_function) return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function) @@ -462,6 +460,11 @@ def sdpa_mask( # Option 3: Limited vmap mask creation (torch<2.6 and custom patterns) else: + logger.warning_once( + "Using vmap mask creation (custom patterns with and/or masks) has limited capabilities under " + "`torch<2.6.0`. We cannot guarantee that the correct mask is constructed." + ) + # This creates the 4D mask easily. 
Note that we do not include vmap over the batch_idx dimension as well, # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow # However, in more recent version of Pytorch, a trick was introduced to handle it @@ -1122,7 +1125,6 @@ def create_chunked_causal_mask( ) batch_size, dtype = input_embeds.shape[0], input_embeds.dtype - # TODO: check if we can use non-vmap here + if the old torch restriction is still valid then # For chunked attention and batched inputs, we need to take the number of left padding tokens into account # to start the chunk from the actual start of the sequence for the padded sequence if attention_mask is not None: @@ -1130,17 +1132,6 @@ def create_chunked_causal_mask( left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1) else: left_padding_tokens = torch.zeros(batch_size, device=cache_position.device, dtype=int) - # Raise a warning for older versions if the problematic left-padding situation arises - if ( - not _is_torch_greater_or_equal_than_2_6 - and kv_length + kv_offset > chunk_size - and (left_padding_tokens > 0).any() - ): - logger.warning_once( - "Due to limitations of your current torch version, we cannot correctly account for the left-padding " - "when computing the chunked attention pattern. This will lead to a wrong attention mask for the padded " - "sequences. Behavior will be undefined. Please upgrade to `torch>=2.6` to solve this issue." 
- ) mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens) mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] From 01848e3bdca9b9b984e6e2f376eb85548a8aa3ab Mon Sep 17 00:00:00 2001 From: vasqu Date: Tue, 28 Oct 2025 19:06:22 +0100 Subject: [PATCH 10/10] lift unnecessary restriction on older torch --- src/transformers/masking_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 5e6752b0e198..71b15adf714a 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -1038,7 +1038,7 @@ def create_sliding_window_causal_mask( use_vmap = True # If we detected packing format - if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + if packed_sequence_mask is not None: if use_vmap and not _is_torch_greater_or_equal_than_2_6: raise ValueError( "Packed masking along custom patterns (i.e. and/or masks) is only allowed from `torch>=2.6.0`." @@ -1160,7 +1160,7 @@ def create_chunked_causal_mask( use_vmap = True # If we detected packing format - if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + if packed_sequence_mask is not None: if use_vmap and not _is_torch_greater_or_equal_than_2_6: raise ValueError( "Packed masking along custom patterns (i.e. and/or masks) is only allowed from `torch>=2.6.0`."