Commit 4389efd

Remove caching for attention masks (#2117)
We remove the `lru_cache` for attention masks because, in the `get_attention_mask()` function, `and_masks(*mask_mods)` returns a new callable (a different object id) on every call. Since `create_attention_mask` uses all of its arguments as the cache key, the fresh object id always causes a cache miss, so the cache never hits.

Before the change (llama3 debugmodel_flex_attn):
[Screenshot 2025-12-09 at 1.27.45 PM: https://github.com/user-attachments/assets/e9af2597-9d94-4478-8136-8b9b8c35d9e6]

After the change:
[Screenshot 2025-12-09 at 1.29.56 PM: https://github.com/user-attachments/assets/756a7d09-b47f-434f-8ff6-40098b265a03]
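For illustration, here is a minimal, self-contained sketch of that failure mode. It is not torchtitan code: `compose_masks` and `cached_create_mask` are made-up stand-ins for `and_masks` and the previously cached `create_attention_mask`.

```python
# Minimal sketch of the cache-miss behavior described above.
# `compose_masks` and `cached_create_mask` are hypothetical stand-ins, not torchtitan APIs.
import functools


def compose_masks(*mask_mods):
    """Like flex_attention.and_masks: returns a NEW callable on every call."""

    def combined(b, h, q_idx, kv_idx):
        out = mask_mods[0](b, h, q_idx, kv_idx)
        for mod in mask_mods[1:]:
            out = out & mod(b, h, q_idx, kv_idx)
        return out

    return combined


@functools.lru_cache(4)
def cached_create_mask(mask_mod, seq_len):
    # lru_cache keys on every argument; a distinct callable object means a distinct key.
    print(f"cache miss (mask_mod id={id(mask_mod)}, seq_len={seq_len})")
    return (mask_mod, seq_len)  # placeholder for the real BlockMask


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


for _ in range(3):
    # Each iteration composes a new callable, so the cache never hits.
    cached_create_mask(compose_masks(causal), 128)
```

After this change, `create_attention_mask` is just a thin wrapper around the `torch.compile`d `create_block_mask`, with no caching layer in between.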
1 parent a632855 commit 4389efd

File tree: 1 file changed (+8, -17 lines)

torchtitan/models/attention.py: 8 additions & 17 deletions
@@ -6,7 +6,6 @@
 #
 # Copyright (c) Meta Platforms, Inc. All Rights Reserved.
 
-import functools
 from collections.abc import Callable
 from typing import ClassVar, NamedTuple
 
@@ -171,22 +170,19 @@ def forward(
         return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True)
 
 
-# We cannot do inner function/closure because we won't be able to cache it --
-# if we an inner function, a new closure will be created every time
-# `get_causal_mask_mod` is called.
-def _causal_mask(
-    b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
-) -> torch.Tensor:
-    """Causal mask that prevents attention to future tokens."""
-    return q_idx >= kv_idx
-
-
 def get_causal_mask_mod() -> _mask_mod_signature:
     """Returns a causal mask modifier for flex attention.
 
     Returns:
         A mask modifier function that implements causal masking.
     """
+
+    def _causal_mask(
+        b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor
+    ) -> torch.Tensor:
+        """Causal mask that prevents attention to future tokens."""
+        return q_idx >= kv_idx
+
     return _causal_mask
 
 
@@ -275,13 +271,8 @@ def sliding_window_mod(
 _compiled_create_block_mask = torch.compile(create_block_mask)
 
 
-@functools.lru_cache(4)
 def create_attention_mask(*args, **kwargs):
-    """Create an attention mask using compiled create_block_mask.
-
-    This function is cached to avoid recreating BlockMasks for the same
-    arguments.
-    """
+    """Create an attention mask using compiled create_block_mask."""
     return _compiled_create_block_mask(*args, **kwargs)
 
 
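For context, here is a rough usage sketch of how a mask mod like the one returned by `get_causal_mask_mod()` feeds into FlexAttention's block-mask construction. It is a simplified illustration, not the torchtitan call site; `and_masks` and `create_block_mask` are existing `torch.nn.attention.flex_attention` APIs in recent PyTorch releases, while the surrounding wiring is assumed. With the identity-keyed cache gone, returning a fresh closure from `get_causal_mask_mod()` on each call is harmless.

```python
# Simplified usage sketch, not the torchtitan call site.
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask


def get_causal_mask_mod():
    # Returning a fresh closure per call is fine now that nothing caches on its identity.
    def _causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    return _causal_mask


mask_mod = and_masks(get_causal_mask_mod())  # composes one or more mask mods into one
block_mask = create_block_mask(mask_mod, B=None, H=None, Q_LEN=128, KV_LEN=128, device="cpu")
print(block_mask)
```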