@@ -15,6 +15,7 @@
 from torch.nn.attention import sdpa_kernel, SDPBackend
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
+    AuxOutput,
     BlockMask,
     create_block_mask,
     flex_attention,
@@ -28,6 +29,26 @@
 FLEX_ATTN_MASK_T = tuple[str, int | None]


+class FlexAttentionWrapper(torch.nn.Module):
+    _flex_attn: ClassVar[Callable] = torch.compile(
+        flex_attention, mode="max-autotune-no-cudagraphs"
+    )
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward(self, *args: object, **kwargs: object) -> (
+        torch.Tensor | tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, AuxOutput]
+    ):
+        # 1. _flex_attn has to be a class variable; otherwise there will
+        #    be multiple compiled flex_attention instances, which can be slow.
+        # 2. `self._flex_attn` is not correct: `self` would be passed in
+        #    as the first argument, which would cause an error.
+        #    `FlexAttentionWrapper._flex_attn` is correct.
+        return FlexAttentionWrapper._flex_attn(*args, **kwargs)
+
+
 class FlexAttention(torch.nn.Module):
     """FlexAttention module that uses torch.nn.attention.flex_attention.

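Editorial aside, not part of the diff: the comment in `FlexAttentionWrapper.forward` about `self._flex_attn` comes down to how Python binds function-valued class attributes. A minimal standalone demo, using a plain function as a stand-in for the compiled callable:

def show_args(*args):
    return args

class Holder:
    fn = show_args  # function stored as a class attribute, like _flex_attn above

h = Holder()
print(Holder.fn(1, 2))  # (1, 2) -- class access: arguments pass through unchanged
print(h.fn(1, 2))       # (<Holder ...>, 1, 2) -- instance access binds the instance as the first argument

Accessing the attribute through the class avoids the extra `self`, which is why the wrapper calls `FlexAttentionWrapper._flex_attn(*args, **kwargs)`.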
@@ -46,11 +67,6 @@ class FlexAttention(torch.nn.Module):
     to the keys within the same block.
     """

-    # We registered flex_attention related attributes as class variables as we
-    # need to amortize the cost of compilation.
-    flex_attn: ClassVar[Callable] = torch.compile(
-        flex_attention, mode="max-autotune-no-cudagraphs"
-    )
     compiled_create_block_mask: ClassVar[Callable] = torch.compile(create_block_mask)
     used_attn_mask_types: ClassVar[set[FLEX_ATTN_MASK_T]] = set()
     # Attention mask type to the created BlockMask.
@@ -71,6 +87,7 @@ def __init__(
             raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.")
         self.attn_mask_type = attn_mask_type
         self.fixed_block_size = fixed_block_size
+        self.attention_fn_wrapper = FlexAttentionWrapper()

         FlexAttention.used_attn_mask_types.add(self.mask_key)
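Editorial aside: registering the wrapper as a submodule, instead of calling the compiled function directly, lets context-parallel code intercept the attention call at module granularity, which the `_DispatchMode.MODULE_WRAPPER` setting in the last hunk relies on. A hedged sketch of what this buys, assuming `"causal"` is a registered `attn_mask_type`:

attn = FlexAttention("causal")
print([name for name, _ in attn.named_modules()])   # ['', 'attention_fn_wrapper']
# The wrapper can now be targeted by standard nn.Module machinery, e.g. hooks:
attn.attention_fn_wrapper.register_forward_pre_hook(lambda mod, args: None)  # no-op hook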
@@ -86,7 +103,7 @@ def forward(
         scale: float | None = None,
     ) -> torch.Tensor:
         block_mask = FlexAttention.block_masks[self.mask_key]
-        return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale)
+        return self.attention_fn_wrapper(q, k, v, block_mask=block_mask, scale=scale)

     @staticmethod
     def _get_causal_mask_mod() -> _mask_mod_signature:
@@ -251,6 +268,11 @@ def init_attention_mask(
     # while we continue debugging accuracy issues. However, we want to evaluate
     # the user experience with CP enabled.
     if cp_mesh is not None:
+        from torch.distributed.tensor.experimental._attention import _DispatchMode
+
+        torch.distributed.tensor.experimental._attention._dispatch_mode = (
+            _DispatchMode.MODULE_WRAPPER
+        )
         FlexAttention.compiled_create_block_mask = functools.partial(
             create_cp_block_mask, device_mesh=cp_mesh
         )
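Editorial aside: outside of CP, the wrapper is called exactly like `flex_attention` itself. A minimal sketch, not part of the PR; the shapes, dtype, and CUDA device below are assumptions:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

q = k = v = torch.randn(1, 8, 256, 64, device="cuda", dtype=torch.bfloat16)
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=256, KV_LEN=256)
out = FlexAttentionWrapper()(q, k, v, block_mask=block_mask)  # forwards to the compiled flex_attention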