112 changes: 1 addition & 111 deletions src/transformers/integrations/executorch.py
@@ -11,7 +11,6 @@
# specific language governing permissions and limitations under the License.

import logging
from collections.abc import Callable
from typing import Optional

import torch
@@ -24,13 +23,7 @@
StaticCache,
)
from ..generation.configuration_utils import GenerationConfig
from ..masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
_ignore_causal_mask_sdpa,
_is_torch_greater_or_equal_than_2_5,
prepare_padding_mask,
)
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ..modeling_utils import PreTrainedModel
from ..pytorch_utils import (
is_torch_greater_or_equal,
is_torch_greater_or_equal_than_2_3,
@@ -229,10 +222,6 @@ def __init__(
"Using `StaticCache` for export as `layer_types` is not specified or `sliding_window` is `null` in the config."
)
self.model = TorchExportableModuleWithStaticCache(model, batch_size, max_cache_len, device)
# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
self.model.model.config._attn_implementation = "sdpa_without_vmap"
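For context, a minimal sketch of this registration pattern outside the class, written against the same `register` API shown in the diff. The mask factory `causal_mask_no_vmap` and the name "my_sdpa_variant" are hypothetical, and the absolute import paths are assumed to mirror the relative imports above.

import torch

from transformers.masking_utils import ALL_MASK_ATTENTION_FUNCTIONS
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


def causal_mask_no_vmap(batch_size, cache_position, kv_length, kv_offset=0, **kwargs):
    # Hypothetical mask factory: a plain boolean causal mask built without vmap.
    kv_arange = torch.arange(kv_length, device=cache_position.device) + kv_offset
    mask = kv_arange <= cache_position.view(-1, 1)
    return mask[None, None, :, :].expand(batch_size, -1, -1, -1)


# Register the mask factory under a new name, reuse the stock sdpa attention kernel under
# that same name, then point the model config at the new implementation name
# (`model` would be any PreTrainedModel instance, hypothetical here):
ALL_MASK_ATTENTION_FUNCTIONS.register("my_sdpa_variant", causal_mask_no_vmap)
ALL_ATTENTION_FUNCTIONS.register("my_sdpa_variant", ALL_ATTENTION_FUNCTIONS["sdpa"])
# model.config._attn_implementation = "my_sdpa_variant"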

def forward(
self,
@@ -768,11 +757,6 @@ def convert_and_export_with_cache(

import torch.export._trace

# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"

with torch.no_grad():
# TODO: The default inputs only work for text models. We need to add support for vision/audio models.
example_input_ids = (
@@ -1036,11 +1020,6 @@ def export_with_dynamic_cache(
if not is_torch_greater_or_equal_than_2_3:
raise ImportError("torch >= 2.3 is required.")

# This is the same as sdpa, but mask creation does not use `vmap` which is not exportable
ALL_MASK_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", sdpa_mask_without_vmap)
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"

register_dynamic_cache_export_support()

with torch.no_grad():
@@ -1109,92 +1088,3 @@ def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
value = value_list[idx] if idx < len(value_list) else None
cache.update(key, value, idx)
return cache


def sdpa_mask_without_vmap(
Author note: no longer needed, as `vmap` was the reason we needed this workaround in the first place.

batch_size: int,
cache_position: torch.Tensor,
kv_length: int,
kv_offset: int = 0,
mask_function: Optional[Callable] = None,
attention_mask: Optional[torch.Tensor] = None,
local_size: Optional[int] = None,
allow_is_causal_skip: bool = True,
allow_torch_fix: bool = True,
**kwargs,
) -> Optional[torch.Tensor]:
"""
Create a 4D boolean mask of shape `(batch_size, 1, query_length, kv_length)` where a value of True indicates that
the element should take part in the attention computation, and False that it should not.
This is similar to `masking_utils.sdpa_mask` but does not use `vmap`, which is incompatible with export.
Args:
batch_size (`int`):
The batch size of the input sequence.
cache_position (`torch.Tensor`):
A tensor of shape (query_length,) indicating the current indices of the input sequence elements.
kv_length (`int`):
The size that the key and value states will have during the attention computation.
kv_offset (`int`, optional):
An optional offset indicating the first position that the key and value states refer to.
mask_function (`Callable`):
The mask factory function describing the mask pattern.
attention_mask (`torch.Tensor`, optional):
The 2D attention mask corresponding to padded tokens of shape (batch_size, number_of_seen_tokens+q_length)
local_size (`int`, optional):
The size of the local attention, if we do not use full attention. This is used only if `allow_is_causal_skip=True`
to try to skip mask creation if possible.
allow_is_causal_skip (`bool`, optional):
Whether to allow returning `None` for the mask when the `is_causal` argument of `torch.sdpa` can be used
instead. Defaults to `True`.
allow_torch_fix (`bool`, optional):
Whether to update the mask when a query is not attending to any tokens, to work around a bug in older torch
versions. We need an arg to skip it when using eager. Defaults to `True`.
"""

q_length = cache_position.shape[0]
# Potentially pad the 2D mask, and slice it correctly
padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset)

# Under specific conditions, we can avoid materializing the mask, instead relying on the `is_causal` argument
if allow_is_causal_skip and _ignore_causal_mask_sdpa(padding_mask, q_length, kv_length, local_size):
return None

# Similar to `kv_arange = torch.arange(start=kv_offset, end=kv_offset + kv_length, device=cache_position.device)`
# but without data-dependent slicing (i.e. torch.compile friendly)
kv_arange = torch.arange(kv_length, device=cache_position.device)
kv_arange += kv_offset
reshaped_cache_position = cache_position.view(-1, 1)

# This is a bit hacky as a way to know which pattern we are using, but all mask creation functions actually
# forward the config through kwargs anyway, so we can rely on it here
# Usually, the `mask_function` is the only entry point defining the pattern - we could loop over it,
# but this is more efficient
sliding_window = getattr(kwargs["config"], "sliding_window", None)
chunk_size = getattr(kwargs["config"], "attention_chunk_size", None)

if sliding_window is not None and chunk_size is not None:
raise ValueError("Cannot use both `sliding_window` and `attention_chunk_size`")

# Simplest and most efficient way to obtain a causal mask
causal_mask = kv_arange <= reshaped_cache_position
# If using sliding window, add the sliding mask
if sliding_window is not None:
sliding_mask_overlay = kv_arange > reshaped_cache_position - sliding_window
causal_mask *= sliding_mask_overlay
# If using chunk attention, add the chunked mask
elif chunk_size is not None:
chunked_mask_overlay = kv_arange // chunk_size == reshaped_cache_position // chunk_size
causal_mask *= chunked_mask_overlay

causal_mask = causal_mask[None, None, :, :].expand(batch_size, -1, -1, -1)
if padding_mask is not None:
causal_mask = causal_mask * padding_mask[:, None, None, :]

# Due to a bug in some older torch versions, we need to update the mask in case a query is not attending to any
# tokens (due to padding). See details in https://github.com/pytorch/pytorch/issues/110213
if not _is_torch_greater_or_equal_than_2_5 and allow_torch_fix:
causal_mask |= torch.all(~causal_mask, dim=-1, keepdim=True)
return causal_mask
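
As an illustration (not part of the diff), a minimal worked example of the mask pattern the removed function constructs, under the assumption of 4 query/key positions, `kv_offset = 0`, and a sliding window of 2:

import torch

# Tiny worked example of the boolean mask produced above.
cache_position = torch.arange(4)            # query indices 0..3
kv_arange = torch.arange(4)                 # kv indices 0..3
reshaped_cache_position = cache_position.view(-1, 1)

# Causal part: each query attends only to keys at or before its own position.
causal_mask = kv_arange <= reshaped_cache_position
# Sliding-window part: additionally drop keys more than `sliding_window - 1` positions back.
sliding_window = 2
causal_mask &= kv_arange > reshaped_cache_position - sliding_window

print(causal_mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [False,  True,  True, False],
#         [False, False,  True,  True]])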