diff --git a/src/transformers/integrations/executorch.py b/src/transformers/integrations/executorch.py index 0d4910732528..635ab7abe744 100644 --- a/src/transformers/integrations/executorch.py +++ b/src/transformers/integrations/executorch.py @@ -11,7 +11,6 @@ # specific language governing permissions and limitations under the License. import logging -from collections.abc import Callable from typing import Optional import torch @@ -24,13 +23,7 @@ StaticCache, ) from ..generation.configuration_utils import GenerationConfig -from ..masking_utils import ( - ALL_MASK_ATTENTION_FUNCTIONS, - _ignore_causal_mask_sdpa, - _is_torch_greater_or_equal_than_2_5, - prepare_padding_mask, -) -from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ..modeling_utils import PreTrainedModel from ..pytorch_utils import ( is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3, @@ -229,10 +222,6 @@ def __init__( "Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config." ) self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device) - # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable - ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) - ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - self.model.model.config._attn_implementation = "sdpa_without_vmap" def forward( self, @@ -768,11 +757,6 @@ def convert_and_export_with_cache( import torch.export._trace - # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable - ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) - ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - model.config._attn_implementation = "sdpa_without_vmap" - with torch.no_grad(): # TODO: The default inputs only work for text models. We need to add support for vision/audio models. example_input_ids = ( @@ -1036,11 +1020,6 @@ def export_with_dynamic_cache( if not is_torch_greater_or_equal_than_2_3: raise ImportError("torch >= 2.3 is required.") - # This is the same as sdpa, but mask creation does not use `vmap` which is not exportable - ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap) - ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"]) - model.config._attn_implementation = "sdpa_without_vmap" - register_dynamic_cache_export_support() with torch.no_grad(): @@ -1109,92 +1088,3 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context): value = value_list[idx] if idx < len(value_list) else None cache.update(key, value, idx) return cache - - -def sdpa_mask_without_vmap( - batch_size: int, - cache_position: torch.Tensor, - kv_length: int, - kv_offset: int = 0, - mask_function: Optional[Callable] = None, - attention_mask: Optional[torch.Tensor] = None, - local_size: Optional[int] = None, - allow_is_causal_skip: bool = True, - allow_torch_fix: bool = True, - **kwargs, -) -> Optional[torch.Tensor]: - """ - Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that - the element should take part in the attention computation, and False that it should not. - - This is similar to `masking_utils.sdpa_mask` but does not use `vmap` which is incompatible with export. - - Args: - batch_size (`int`): - The batch size of the input sequence. 
- cache_position (`torch.Tensor`): - A tensor of shape (query_length,) indicating the current indices of the input sequence elements. - kv_length (`int`): - The size that the key and value states will have during the attention computation. - kv_offset (`int`, optional): - An optional offset to indicate at which first position the key and values states will refer to. - mask_function (`Callable`): - The mask factory function describing the mask pattern. - attention_mask (`torch.Tensor`, optional): - The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) - local_size (`int`, optional): - The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` - to try to skip mask creation if possible. - allow_is_causal_skip (`bool`, optional): - Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in - `torch.sdpa` instead. Default to `True`. - allow_torch_fix (`bool`, optional): - Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older - versions. We need an arg to skip it when using eager. By default `True`. - - """ - - q_length = cache_position.shape[0] - # Potentially pad the 2D mask, and slice it correctly - padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) - - # Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument - if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size): - return None - - # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` - # but without data-dependent slicing (i.e. torch.compile friendly) - kv_arange = torch.arange(kv_length, device=cache_position.device) - kv_arange += kv_offset - reshaped_cache_position = cache_position.view(-1, 1) - - # This is a bit hacky to know what pattern we are using, but all mask creation function actually forward - # the config through kwargs anyway, so it allows to rely on it - # Usually, the `mask_function` is the only entry-point to define the pattern - we could do for loops over it, - # but this is more efficient - sliding_window = getattr(kwargs["config"], "sliding_window", None) - chunk_size = getattr(kwargs["config"], "attention_chunk_size", None) - - if sliding_window is not None and chunk_size is not None: - raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`") - - # Simplest and most efficient way to obtain a causal mask - causal_mask = kv_arange <= reshaped_cache_position - # If using sliding window, add the sliding mask - if sliding_window is not None: - sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window - causal_mask *= sliding_mask_overlay - # If using chunk attention, add the chunked mask - elif chunk_size is not None: - chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size - causal_mask *= chunked_mask_overlay - - causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1) - if padding_mask is not None: - causal_mask = causal_mask * padding_mask[:, None, None, :] - - # Due to a bug in some older torch version, we need to update the mask in case a query is not attending to any - # tokens (due to padding). 
See details in https://github.com/pytorch/pytorch/issues/110213 - if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix: - causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True) - return causal_mask diff --git a/src/transformers/masking_utils.py b/src/transformers/masking_utils.py index 8f0a8a04fd60..d2e7d08016ca 100644 --- a/src/transformers/masking_utils.py +++ b/src/transformers/masking_utils.py @@ -82,8 +82,10 @@ def causal_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) def bidirectional_mask_function(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: """ This creates a full bidirectional mask. + + NOTE: It is important to keep an index-based version for non-vmap expansion. """ - return q_idx.new_ones((), dtype=torch.bool) + return q_idx >= 0 def sliding_window_overlay(sliding_window: int) -> Callable: @@ -133,8 +135,6 @@ def chunked_causal_mask_function(chunk_size: int, left_padding: torch.Tensor) -> """ This return the mask_function function to create a chunked attention mask. """ - if not _is_torch_greater_or_equal_than_2_6: - return and_masks(_legacy_chunked_overlay(chunk_size), causal_mask_function) return and_masks(chunked_overlay(chunk_size, left_padding), causal_mask_function) @@ -175,52 +175,17 @@ def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool: return inner_mask -def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: - """ - Used to vmap our mask_functions over the q_idx and kv_idx dimensions of the inputs. Optionally, vmap over - the batch and head indices as well if `bh_indices=True`. - Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive - functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different). - - Args: - mask_function (`Callable`): - The mask_function to vmap. - bh_indices (`bool`, optional): - Whether to vmap over the batch and head indices as well, or only q and kv indices. - - Returns: - Callable: The vmapped function. - """ - # We vmap the function 2 times, broadcasting the [q_idx, kv_idx] dimensions - dimensions = [(None, None, None, 0), (None, None, 0, None)] - if bh_indices: - # We extend broadcasting over the [batch_idx, head_idx] dimensions - dimensions.extend([(None, 0, None, None), (0, None, None, None)]) - - for dims in dimensions: - mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0) - return mask_function - - def prepare_padding_mask( - attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int, _slice: bool = True + attention_mask: Optional[torch.Tensor], kv_length: int, kv_offset: int ) -> Optional[torch.Tensor]: """ - From the 2D attention mask, prepare the correct padding mask to use by potentially padding it, and slicing - according to the `kv_offset` if `_slice` is `True`. + From the 2D attention mask, prepare the correct padding mask to use by potentially padding it. """ local_padding_mask = attention_mask if attention_mask is not None: # Pad it if necessary if (padding_length := kv_length + kv_offset - attention_mask.shape[-1]) > 0: local_padding_mask = torch.nn.functional.pad(attention_mask, (0, padding_length)) - # For flex, we should not slice them, only use an offset - if _slice: - # Equivalent to: `local_padding_mask = attention_mask[:, kv_offset : kv_offset + kv_length]`, - # but without data-dependent slicing (i.e. 
torch.compile friendly)
-            mask_indices = torch.arange(kv_length, device=local_padding_mask.device)
-            mask_indices += kv_offset
-            local_padding_mask = local_padding_mask[:, mask_indices]
     return local_padding_mask
@@ -282,7 +247,38 @@ def _ignore_bidirectional_mask_sdpa(padding_mask: Optional[torch.Tensor]) -> boo
     return False
 
 
-def sdpa_mask_recent_torch(
+def _vmap_expansion_sdpa(mask_function: Callable) -> Callable:
+    """
+    Used to vmap our mask_functions over all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
+    Using vmap here allows us to keep the performance of vectorized ops, while having a single set of primitive
+    functions between attention interfaces (i.e. between flex and sdpa/eager, FA2 being a bit different).
+    """
+    # We vmap the function over all 4 dimensions, broadcasting [b_idx, h_idx, q_idx, kv_idx]
+    dimensions = [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]
+    for dims in dimensions:
+        mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
+    return mask_function
+
+
+def _non_vmap_expansion_sdpa(batch_size, cache_position, kv_arange):
+    """
+    Used to broadcast our mask_functions over all 4 dimensions (b_idx, h_idx, q_idx, kv_idx) of the inputs.
+    Allows the usage of any index-based mask function without relying on vmap.
+
+    NOTE: This is limited to index-based functions only and is not guaranteed to work otherwise.
+
+    Reference:
+    - https://github.com/huggingface/optimum-onnx/blob/c123e8f4fab61b54a8e0e31ce74462bcacca576e/optimum/exporters/onnx/model_patcher.py#L362-L365
+    """
+    device = cache_position.device
+    batch_indices = torch.arange(batch_size, dtype=torch.long, device=device)[:, None, None, None]
+    head_indices = torch.arange(1, dtype=torch.long, device=device)[None, :, None, None]
+    q_indices = cache_position[None, None, :, None]
+    kv_indices = kv_arange.to(dtype=torch.long)[None, None, None, :]
+    return batch_indices, head_indices, q_indices, kv_indices
+
+
+def sdpa_mask(
     batch_size: int,
     cache_position: torch.Tensor,
     kv_length: int,
@@ -292,6 +288,8 @@ def sdpa_mask_recent_torch(
     local_size: Optional[int] = None,
     allow_is_causal_skip: bool = True,
     allow_is_bidirectional_skip: bool = False,
+    allow_torch_fix: bool = True,
+    use_vmap: bool = False,
     **kwargs,
 ) -> Optional[torch.Tensor]:
     """
@@ -324,6 +322,12 @@ def sdpa_mask_recent_torch(
         allow_is_bidirectional_skip (`bool`, optional):
             Whether to allow to return `None` for the mask under conditions where we do not have to add any bias,
             i.e. full attention without any padding. Default to `False`.
+        allow_torch_fix (`bool`, optional):
+            Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older
+            versions. We need an arg to skip it when using eager. By default `True`.
+        use_vmap (`bool`, optional):
+            Whether to use `vmap` during the mask construction or not. Allows powerful custom patterns that may not be
+            index-based (at the cost of speed). By default `False`.
 
     ## Creating a simple causal mask:
 
@@ -391,8 +395,9 @@ def sdpa_mask_recent_torch(
 
     """
     q_length = cache_position.shape[0]
-    # Potentially pad the 2D mask, and slice it correctly
-    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+
+    # Potentially pad the 2D mask
+    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
 
     # Under specific conditions, we can avoid materializing the mask
     # 1. 
Causal masks can rely on the `is_causal` argument @@ -402,128 +407,46 @@ def sdpa_mask_recent_torch( if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask): return None - # vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the - # padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function - if mask_function is bidirectional_mask_function: - if padding_mask is not None: - # used for slicing without data-dependent slicing - mask_indices = torch.arange(kv_length, device=cache_position.device) + kv_offset - return padding_mask[:, None, None, mask_indices].expand(-1, -1, q_length, -1) - else: - return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device) - - # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` - # but without data-dependent slicing (i.e. torch.compile friendly) - kv_arange = torch.arange(kv_length, device=cache_position.device) - kv_arange += kv_offset - # Potentially add the padding 2D mask if padding_mask is not None: mask_function = and_masks(mask_function, padding_mask_function(padding_mask)) - batch_arange = torch.arange(batch_size, device=cache_position.device) - head_arange = torch.arange(1, device=cache_position.device) - # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor from - # scalar tensor (it internally calls `.item()` which vmap does not allow, but this context works around it - # We don't need to add an offset to the mask_function either, as we vmap directly the correct indices for k and kv indices - with TransformGetItemToIndex(): - causal_mask = _vmap_for_bhqkv(mask_function)(batch_arange, head_arange, cache_position, kv_arange) - - return causal_mask - - -def sdpa_mask_older_torch( - batch_size: int, - cache_position: torch.Tensor, - kv_length: int, - kv_offset: int = 0, - mask_function: Callable = causal_mask_function, - attention_mask: Optional[torch.Tensor] = None, - local_size: Optional[int] = None, - allow_is_causal_skip: bool = True, - allow_torch_fix: bool = True, - allow_is_bidirectional_skip: bool = False, - **kwargs, -) -> Optional[torch.Tensor]: - """ - NOTE: This function is only used when torch version is torch<2.5 - see `sdpa_mask_recent_torch` otherwise. - - Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that - the element should take part in the attention computation, and False that it should not. - If `allow_torch_fix=True` (the default), rows corresponding to query tokens that do not attend - to any other tokens (due to padding) will be fully attended to instead, in order to avoid `nan` propagation (this does - not change the final result). - - Args: - batch_size (`int`): - The batch size of the input sequence. - cache_position (`torch.Tensor`): - A tensor of shape (query_length,) indicating the current indices of the input sequence elements. - kv_length (`int`): - The size that the key and value states will have during the attention computation. - kv_offset (`int`, optional): - An optional offset to indicate at which first position the key and values states will refer to. - mask_function (`Callable`): - The mask factory function describing the mask pattern. 
- attention_mask (`torch.Tensor`, optional): - The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length) - local_size (`int`, optional): - The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True` - to try to skip mask creation if possible. - allow_is_causal_skip (`bool`, optional): - Whether to allow to return `None` for the mask under conditions where we can use the `is_causal` argument in - `torch.sdpa` instead. Default to `True`. - allow_torch_fix (`bool`, optional): - Whether to update the mask in case a query is not attending to any tokens, to solve a bug in torch's older - versions. We need an arg to skip it when using eager. By default `True`. - allow_is_bidirectional_skip (`bool`, optional): - Whether to allow to return `None` for the mask under conditions where we do not have to add any bias, - i.e. full attention without any padding. Default to `False`. - """ - q_length = cache_position.shape[0] - # Potentially pad the 2D mask, and slice it correctly - padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset) - - # Under specific conditions, we can avoid materializing the mask - # 1. Causal masks can rely on the `is_causal` argument - # 2. Bidirectional do not need any further processing (no bias) - if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, kv_offset, local_size): - return None - if allow_is_bidirectional_skip and _ignore_bidirectional_mask_sdpa(padding_mask): - return None - - # vmap can incur performance issues as reported in #41566 for bidirectional mask as we only need to expand the - # padding mask. Thus, we allow early exit here if we do not detect any modification to the base mask function - if mask_function is bidirectional_mask_function: - if padding_mask is not None: - return padding_mask[:, None, None, :].expand(-1, -1, q_length, -1) - else: - return torch.ones(batch_size, 1, q_length, kv_length, dtype=torch.bool, device=cache_position.device) - # Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)` # but without data-dependent slicing (i.e. torch.compile friendly) kv_arange = torch.arange(kv_length, device=cache_position.device) kv_arange += kv_offset - # This creates the 4D mask easily. 
Note that we do not include vmap over the batch_idx dimension as well,
-    # as vmap cannot handle slicing a tensor from scalar tensor (it internally calls `.item()` which vmap does not allow
-    # However, in more recent version of Pytorch, a trick was introduced to handle it - which is the reason we have
-    # `sdpa_mask_recent_torch`, as it allows more general `mask_function`
-    causal_mask = _vmap_for_bhqkv(mask_function, bh_indices=False)(None, None, cache_position, kv_arange)
-    causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
-    if padding_mask is not None:
-        causal_mask = causal_mask * padding_mask[:, None, None, :]
+    # Actual mask creation
+    # Option 1: Fast non-vmap mask creation (default)
+    if not use_vmap:
+        # Apply mask function element-wise through broadcasting
+        attention_mask = mask_function(*_non_vmap_expansion_sdpa(batch_size, cache_position, kv_arange))
+        # Expand the mask to match batch size and query length if they weren't used in the mask function
+        attention_mask = attention_mask.expand(batch_size, -1, q_length, kv_length)
+
+    # Option 2: Vmap mask creation (torch>=2.6 and custom patterns)
+    elif _is_torch_greater_or_equal_than_2_6:
+        batch_arange = torch.arange(batch_size, device=cache_position.device)
+        head_arange = torch.arange(1, device=cache_position.device)
+        # This creates the 4D mask easily. Note that we need this context manager as vmap cannot handle slicing a tensor
+        # from a scalar tensor (it internally calls `.item()`, which vmap does not allow), but this context works around it.
+        # We don't need to add an offset to the mask_function either, as we directly vmap over the correct q and kv indices.
+        with TransformGetItemToIndex():
+            attention_mask = _vmap_expansion_sdpa(mask_function)(batch_arange, head_arange, cache_position, kv_arange)
+
+    # Option 3: Error out since it indicates that the user did something custom, which they shouldn't have (torch<2.6)
+    else:
+        raise ValueError(
+            "The vmap functionality for mask creation is only supported from torch>=2.6. "
+            "Please update your torch version or use `use_vmap=False` with index-based masks."
+        )
 
     # Due to a bug in versions of torch<2.5, we need to update the mask in case a query is not attending to any
     # tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
     if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
-        causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
-    return causal_mask
-
+        attention_mask = attention_mask | torch.all(~attention_mask, dim=-1, keepdim=True)
-# We use the version with newer torch whenever possible, as it is more general and can handle arbitrary mask functions
-# (especially mask_function indexing a tensor, such as the padding mask function)
-sdpa_mask = sdpa_mask_recent_torch if _is_torch_greater_or_equal_than_2_6 else sdpa_mask_older_torch
+    return attention_mask
 
 
 def eager_mask(
@@ -534,6 +457,7 @@ def eager_mask(
     mask_function: Callable = causal_mask_function,
     attention_mask: Optional[torch.Tensor] = None,
     dtype: torch.dtype = torch.float32,
+    use_vmap: bool = False,
     **kwargs,
 ) -> torch.Tensor:
     """
@@ -556,10 +480,14 @@ def eager_mask(
            The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
        dtype (`torch.dtype`, optional):
            The dtype to use for the mask. By default, `torch.float32`.
+        use_vmap (`bool`, optional):
+            Whether to use `vmap` during the mask construction or not. 
Allows powerful custom patterns that may not be
+            index-based (at the cost of speed). By default `False`.
     """
     # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
     _ = kwargs.pop("allow_is_causal_skip", None)
     _ = kwargs.pop("allow_is_bidirectional_skip", None)
+    _ = kwargs.pop("allow_torch_fix", None)
     mask = sdpa_mask(
         batch_size=batch_size,
         cache_position=cache_position,
@@ -570,6 +498,7 @@ def eager_mask(
         allow_is_causal_skip=False,
         allow_is_bidirectional_skip=False,
         allow_torch_fix=False,
+        use_vmap=use_vmap,
         **kwargs,
     )
     min_dtype = torch.finfo(dtype).min
@@ -655,7 +584,7 @@ def flex_attention_mask(
     if not _is_torch_greater_or_equal_than_2_6 and pad_len > 0:
         attention_mask = torch.nn.functional.pad(attention_mask, value=0, pad=(0, pad_len))
 
-    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+    padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)
     mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
 
     # Add the offsets on top (because flex interface only allows length, not start and end indices)
@@ -851,6 +780,11 @@ def create_causal_mask(
         mask_factory_function = causal_mask_function
     mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
 
+    # Defaulting to using non-vmap based mask creations except when detecting
+    # users passing custom mask functions (as we cannot guarantee that they
+    # are properly index-based as required by our implementation).
+    use_vmap = False
+
     # Do not allow skip if we are compiling (this is to match BC)
     # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it
     if _is_torch_xpu_available:
@@ -867,14 +801,16 @@
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
         mask_factory_function = or_masks(mask_factory_function, or_mask_function)
         allow_is_causal_skip = False
+        use_vmap = True
     if and_mask_function is not None:
         if not _is_torch_greater_or_equal_than_2_6:
             raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6")
         mask_factory_function = and_masks(mask_factory_function, and_mask_function)
         allow_is_causal_skip = False
+        use_vmap = True
 
     # If we detected packing format
-    if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6:
+    if packed_sequence_mask is not None:
         mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask))
         allow_is_causal_skip = False
@@ -889,6 +825,7 @@
         allow_is_causal_skip=allow_is_causal_skip,  # additional kwarg for sdpa
         dtype=dtype,  # Additional kwarg for eager
         config=config,  # Pass the config as well, in case someone wants to easily have their own mask_interface
+        use_vmap=use_vmap,  # Short-circuit to non-vmap expansions for the mask
     )
     return causal_mask
@@ -942,6 +879,10 @@ def create_bidirectional_mask(
     # Allow skipping the mask creation except we have additional masking operators (and/or masks)
     allow_is_bidirectional_skip = True
 
+    # Defaulting to using non-vmap based mask creations except when detecting
+    # users passing custom mask functions (as we cannot guarantee that they
+    # are properly index-based as required by our implementation). 
+ use_vmap = False # Allow slight deviations from the base mask # Note that it is very important to apply this before any other deviations of the mask (such as packed sequence mask, @@ -951,11 +892,13 @@ def create_bidirectional_mask( raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_bidirectional_skip = False + use_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_bidirectional_skip = False + use_vmap = True # We now create the mask attention_mask = mask_interface( @@ -970,6 +913,7 @@ def create_bidirectional_mask( allow_is_bidirectional_skip=allow_is_bidirectional_skip, dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask ) return attention_mask @@ -1032,6 +976,10 @@ def create_sliding_window_causal_mask( mask_factory_function = sliding_window_causal_mask_function(sliding_window) mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + # Defaulting to using non-vmap based mask creations except when detecting + # users passing custom mask functions (as we cannot guarantee that they + # are properly index-based as required by our implementation). + use_vmap = False # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False) @@ -1044,14 +992,16 @@ def create_sliding_window_causal_mask( raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False + use_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False + use_vmap = True # If we detected packing format - if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + if packed_sequence_mask is not None: mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) allow_is_causal_skip = False @@ -1067,6 +1017,7 @@ def create_sliding_window_causal_mask( local_size=sliding_window, # Additional kwarg for sdpa dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask ) return causal_mask @@ -1140,20 +1091,13 @@ def create_chunked_causal_mask( left_padding_tokens = (attention_mask.cumsum(dim=-1) == torch.zeros_like(attention_mask)).sum(dim=-1) else: left_padding_tokens = torch.zeros(batch_size, device=cache_position.device, dtype=int) - # Raise a warning for older versions if the problematic left-padding situation arises - if ( - not _is_torch_greater_or_equal_than_2_6 - and kv_length + kv_offset > chunk_size - and (left_padding_tokens > 0).any() - ): - 
logger.warning_once( - "Due to limitations of your current torch version, we cannot correctly account for the left-padding " - "when computing the chunked attention pattern. This will lead to a wrong attention mask for the padded " - "sequences. Behavior will be undefined. Please upgrade to `torch>=2.6` to solve this issue." - ) mask_factory_function = chunked_causal_mask_function(chunk_size, left_padding_tokens) mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + # Defaulting to using non-vmap based mask creations except when detecting + # users passing custom mask functions (as we cannot guarantee that they + # are properly index-based as required by our implementation). + use_vmap = False # Do not allow skip if we are compiling (this is to match BC) # TODO: cyril -> probably revisit and remove this, but a lot of tests rely on it allow_is_causal_skip = not getattr(past_key_values, "is_compileable", False) @@ -1166,14 +1110,16 @@ def create_chunked_causal_mask( raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = or_masks(mask_factory_function, or_mask_function) allow_is_causal_skip = False + use_vmap = True if and_mask_function is not None: if not _is_torch_greater_or_equal_than_2_6: raise ValueError("Using `or_mask_function` or `and_mask_function` arguments require torch>=2.6") mask_factory_function = and_masks(mask_factory_function, and_mask_function) allow_is_causal_skip = False + use_vmap = True # If we detected packing format - if packed_sequence_mask is not None and _is_torch_greater_or_equal_than_2_6: + if packed_sequence_mask is not None: mask_factory_function = and_masks(mask_factory_function, packed_sequence_mask_function(packed_sequence_mask)) allow_is_causal_skip = False @@ -1189,6 +1135,7 @@ def create_chunked_causal_mask( local_size=chunk_size, # Additional kwarg for sdpa dtype=dtype, # Additional kwarg for eager config=config, # Pass the config as well, in case someone wants to easily have their own mask_interface + use_vmap=use_vmap, # Short-circuit to non-vmap expansions for the mask ) return causal_mask
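
Note (illustration, not part of the patch): the core of this change is that mask functions are expanded through plain index broadcasting instead of `torch.vmap`. Below is a minimal, self-contained sketch of what `_non_vmap_expansion_sdpa` plus the "Option 1" branch of `sdpa_mask` compute, using the library's causal pattern and made-up toy sizes; the real helpers additionally handle devices, padding masks and the skip conditions.

import torch

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    # Purely index-based: operates on broadcast integer tensors, no tensor slicing or `.item()` calls
    return kv_idx <= q_idx

batch_size, q_length, kv_length, kv_offset = 2, 3, 5, 0
cache_position = torch.arange(q_length)          # query positions
kv_arange = torch.arange(kv_length) + kv_offset  # key/value positions

# Broadcast the four index dimensions to (batch, head, q, kv), as _non_vmap_expansion_sdpa does
batch_idx = torch.arange(batch_size)[:, None, None, None]
head_idx = torch.arange(1)[None, :, None, None]
q_idx = cache_position[None, None, :, None]
kv_idx = kv_arange[None, None, None, :]

mask = causal_mask_function(batch_idx, head_idx, q_idx, kv_idx)  # shape (1, 1, 3, 5) here
mask = mask.expand(batch_size, -1, q_length, kv_length)          # shape (2, 1, 3, 5)
print(mask[0, 0])  # True where the kv position is <= the query position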
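
Note (illustration, not part of the patch): the `use_vmap=True` path kept for custom patterns corresponds to the nested `torch.vmap` expansion in `_vmap_expansion_sdpa` (torch>=2.6). A minimal sketch with the same toy causal function follows; it produces the same boolean mask as the non-vmap expansion. Mask functions that index into captured tensors additionally rely on the `TransformGetItemToIndex` context manager used in `masking_utils`, which is omitted here.

import torch

def causal_mask_function(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

batch_arange = torch.arange(2)    # batch indices
head_arange = torch.arange(1)     # single head dimension, broadcast by attention later
cache_position = torch.arange(3)  # query positions
kv_arange = torch.arange(5)       # key/value positions

# Vmap the scalar mask function over the kv, q, head and batch dimensions in turn
mask_fn = causal_mask_function
for dims in [(None, None, None, 0), (None, None, 0, None), (None, 0, None, None), (0, None, None, None)]:
    mask_fn = torch.vmap(mask_fn, in_dims=dims, out_dims=0)

vmap_mask = mask_fn(batch_arange, head_arange, cache_position, kv_arange)
print(vmap_mask.shape)  # torch.Size([2, 1, 3, 5])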