
Commit 727cbd2

feat: Implement SDPA op converter / lowering pass as extensions (#3534)
1 parent 29649eb commit 727cbd2

File tree

4 files changed: +298 −0 lines changed

examples/dynamo/register_sdpa.py
examples/dynamo/sdpa_converter.py
examples/dynamo/torch_export_flux_dev.py
py/torch_tensorrt/dynamo/lowering/__init__.py

examples/dynamo/register_sdpa.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
import copy
import logging
import operator
from typing import Callable, Sequence, Tuple

import torch
from sdpa_converter import *
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
from torch_tensorrt.dynamo.lowering import TORCH_TRT_DECOMPOSITIONS
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
    _aten_lowering_pass,
)
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)

logger = logging.getLogger(__name__)

# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention,
# and aten._scaled_dot_product_flash_attention, so that SDPA remains a standalone
# operator in the graph and the custom converter is invoked for it.
TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default)
TORCH_TRT_DECOMPOSITIONS.pop(
    torch.ops.aten._scaled_dot_product_efficient_attention.default
)
TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten._scaled_dot_product_flash_attention.default)

REPLACEABLE_ATEN_OPS = {
    torch.ops.aten._scaled_dot_product_efficient_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


@_aten_lowering_pass
def replace_variants_of_sdpa(
    gm: torch.fx.GraphModule, settings: CompilationSettings
) -> torch.fx.GraphModule:
    """Replace scaled_dot_product_attention with an equivalent
    implementation which can be accurately converted to TRT
    """
    attn_mask = None
    is_causal = True
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in REPLACEABLE_ATEN_OPS:
            if (
                node.target
                == torch.ops.aten._scaled_dot_product_efficient_attention.default
            ):
                if len(node.args) == 7:
                    (
                        query,
                        key,
                        value,
                        attn_bias,
                        compute_log_sumexp,
                        dropout_p,
                        is_causal,
                    ) = node.args
                elif len(node.args) == 5:
                    query, key, value, attn_mask, is_causal = node.args
                    dropout_p = 0.0
                else:
                    raise ValueError(
                        f"Unexpected number of arguments for {node.target} in the graph"
                    )
            elif (
                node.target
                == torch.ops.aten._scaled_dot_product_flash_attention.default
            ):
                if len(node.args) == 6:
                    query, key, value, dropout_p, is_causal, return_debug_mask = (
                        node.args
                    )
                elif len(node.args) == 3:
                    query, key, value = node.args
                    dropout_p = 0.0
                    is_causal = True
                else:
                    raise ValueError(
                        f"Unexpected number of arguments for {node.target} in the graph"
                    )
            if attn_mask is not None:
                logger.warning(
                    f"The current version of the SDPA converter does not support attn_mask for {node.target} in the graph. Ignoring it and using the is_causal=True configuration."
                )

            modified_input_args = (query, key, value, None, dropout_p, is_causal)

            # Create a new node with torch.nn.functional.scaled_dot_product_attention
            # The input args are (query, key, value, attn_mask, dropout_p, is_causal); kwargs has scale
            with gm.graph.inserting_after(node):
                new_node = gm.graph.call_function(
                    torch.nn.functional.scaled_dot_product_attention,
                    args=modified_input_args,
                    kwargs={"scale": node.kwargs.get("scale", None)},
                )

                # Deep copy encounters RuntimeError: Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). So we use copy instead.
                new_node.meta = copy.copy(node.meta)
                # Check if there's a getitem node following this attention node
                for user in list(node.users):
                    if user.op == "call_function" and user.target == operator.getitem:
                        # If the getitem is extracting the first element (the output tensor)
                        if user.args[1] == 0:
                            # Replace all uses of the getitem with the new attention node
                            user.replace_all_uses_with(new_node)
                            new_node.meta["val"] = new_node.meta["val"][0]
                # Replace all uses of the original node with the new node
                node.replace_all_uses_with(new_node)

            gm.graph.erase_node(node)

    # Clean up the graph
    clean_up_graph_after_modifications(gm)

    logger.info(
        "Replaced variants of scaled_dot_product_attention with torch.nn.functional.scaled_dot_product_attention"
    )
    return gm
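
For orientation, here is a minimal sketch of how these extension files are meant to be wired up, mirroring the torch_export_flux_dev.py change below: simply importing register_sdpa pops the SDPA decompositions and registers the lowering pass and converter as side effects, after which an ordinary Dynamo compile picks them up. The toy Attn module, tensor shapes, and precision here are illustrative assumptions, not part of the commit:

import torch
import torch_tensorrt

import register_sdpa  # noqa: F401  side-effect import: registers pass + converter


class Attn(torch.nn.Module):  # hypothetical toy model
    def forward(self, q, k, v):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)


# Hypothetical inputs: (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.half) for _ in range(3))
ep = torch.export.export(Attn().eval().cuda(), (q, k, v))
trt_model = torch_tensorrt.dynamo.compile(
    ep, inputs=(q, k, v), enabled_precisions={torch.half}
)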

examples/dynamo/sdpa_converter.py

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
import logging
import math
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch_tensorrt
from torch.fx.node import Target
from torch_tensorrt._enums import dtype
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
    SourceIR,
    cast_trt_tensor,
    get_trt_tensor,
)
from torch_tensorrt.fx.types import TRTTensor

logger = logging.getLogger(__name__)


def tril(
    ctx: ConversionContext,
    target: Union[Target, str],
    source_ir: Optional[SourceIR],
    name: str,
    row: TRTTensor,
    col: TRTTensor,
) -> TRTTensor:
    row_arange_tensor = impl.arange.arange(
        ctx, target, source_ir, name + "_arange_row", start=0, end=row, step=1
    )
    row_reshape_tensor = impl.shuffle.reshape(
        ctx, target, source_ir, name + "_reshape_row", row_arange_tensor, [row, 1]
    )

    col_arange_tensor = impl.arange.arange(
        ctx, target, source_ir, name + "_arange_col", start=0, end=col, step=1
    )
    col_reshape_tensor = impl.shuffle.reshape(
        ctx, target, source_ir, name + "_reshape_col", col_arange_tensor, [1, col]
    )

    mask = impl.elementwise.ge(
        ctx, target, source_ir, name + "_ge", row_reshape_tensor, col_reshape_tensor
    )
    return mask


@torch_tensorrt.dynamo.conversion.dynamo_tensorrt_converter(
    torch.nn.functional.scaled_dot_product_attention,
    enabled=True,
    supports_dynamic_shapes=True,
)
def scaled_dot_product_attention(
    ctx: torch_tensorrt.dynamo.conversion.ConversionContext,
    target: Target,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
    name: str,
) -> TRTTensor:
    # TODO: Handle attn_mask and is_causal arguments in the future
    query, key, value, attn_mask, dropout_p, is_causal = args
    logger.info(
        "Ignoring attn_mask and is_causal arguments provided by the original graph. "
        "This converter expects is_causal to be an input to the graph. For the prefill phase, is_causal=True, "
        "and for the generate phase, is_causal=False since we pass only 1 input token at a time"
    )

    # TODO: remove this once we have a better way to handle the causal mask
    scale = kwargs.get("scale", None)
    source_ir = SourceIR.ATEN
    # Implementation as described here: https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
    mm = impl.matmul.matrix_multiply(
        ctx,
        target,
        source_ir,
        name + "_mm",
        query,
        key,
        other_matrix_op=trt.MatrixOperation.TRANSPOSE,
    )
    if scale is None:
        scale = query.shape[-1]
        if scale < 0:
            # dynamic shape
            scale = impl.shape.shape(ctx, target, source_ir, name + "_shape", query, -1)
            sqrt_scaled = impl.unary.sqrt(ctx, target, source_ir, name + "_sqrt", scale)
        else:
            # static shape
            sqrt_scaled = math.sqrt(scale)
        scaled = impl.elementwise.div(
            ctx,
            target,
            source_ir,
            name + "_scale",
            mm,
            sqrt_scaled,
        )
    else:
        scaled = impl.elementwise.mul(
            ctx,
            target,
            source_ir,
            name + "_scale",
            mm,
            scale,
        )

    # If is_causal is True, we need to generate a causal mask
    if is_causal:
        L, S = query.shape[-2], key.shape[-2]
        if L >= 0 and S >= 0:
            # static shape
            attn_bias = np.zeros((L, S), dtype=dtype._from(query.dtype).to(np.dtype))
            temp_mask = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_), k=0))
            attn_bias = np.ma.array(attn_bias, mask=temp_mask).filled(float("-inf"))
            attn_bias = get_trt_tensor(ctx, attn_bias, name + "_attn_bias")
        else:
            # if either L or S has a dynamic shape
            if L < 0:
                L = impl.shape.shape(
                    ctx, target, source_ir, name + "_shape_0", query, 2
                )
            if S < 0:
                S = impl.shape.shape(ctx, target, source_ir, name + "_shape_1", key, 2)

            # generate the mask tensor
            tril_tensor = tril(ctx, target, source_ir, name + "_tril", L, S)

            temp_mask = impl.unary.logical_not(
                ctx, target, source_ir, name + "_logical_not", tril_tensor
            )
            temp_mask_casted = cast_trt_tensor(
                ctx, temp_mask, trt.float32, name + "_casted_bool", target, source_ir
            )
            one_minus_temp_mask = impl.elementwise.sub(
                ctx,
                target,
                source_ir,
                name + "_one_minus_temp_mask",
                1.0,
                temp_mask_casted,
            )
            attn_bias = impl.unary.log(
                ctx, target, source_ir, name + "_log", one_minus_temp_mask
            )

        scaled_add_attn_bias = impl.elementwise.add(
            ctx, target, source_ir, name + "_attn_bias_add", scaled, attn_bias
        )
    else:
        scaled_add_attn_bias = scaled

    # Create an if-conditional to check whether is_causal is True
    if isinstance(is_causal, TRTTensor):
        if_layer = ctx.net.add_if_conditional()
        condition, true_branch, false_branch = is_causal, scaled_add_attn_bias, scaled
        if_layer.set_condition(condition)
        output_layer = if_layer.add_output(true_branch, false_branch)
        scaled_add_attn_bias = output_layer.get_output(0)

    softmax = impl.normalization.softmax(
        ctx, target, source_ir, name + "_softmax", scaled_add_attn_bias, -1, False
    )
    out = impl.matmul.matrix_multiply(
        ctx,
        target,
        source_ir,
        name + "_out",
        softmax,
        value,
    )

    return out
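
One note on the dynamic-shape branch above: since tril() yields a boolean lower-triangular mask, log(1 - logical_not(tril)) evaluates to 0 at positions attention may see and -inf at masked positions, i.e. the same additive bias the static branch builds via np.ma. A self-contained numpy illustration of this identity (not part of the commit):

import numpy as np

L = S = 4
# True above the diagonal (masked), False on/below it (allowed) -- mirrors logical_not(tril)
not_tril = np.logical_not(np.tril(np.ones((L, S), dtype=np.bool_)))
with np.errstate(divide="ignore"):  # log(0) -> -inf is intentional here
    attn_bias = np.log(1.0 - not_tril.astype(np.float32))
print(attn_bias)
# [[  0. -inf -inf -inf]
#  [  0.   0. -inf -inf]
#  [  0.   0.   0. -inf]
#  [  0.   0.   0.   0.]]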

examples/dynamo/torch_export_flux_dev.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@
 we demonstrate optimizing the ``transformer`` component of the model (which typically consumes >95% of the e2e diffusion latency)
 """

+import register_sdpa  # Register SDPA as a standalone operator
+
 # %%
 # Import the following libraries
 # -----------------------------

py/torch_tensorrt/dynamo/lowering/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from ._decomposition_groups import (
+    TORCH_TRT_DECOMPOSITIONS,
     torch_disabled_decompositions,
     torch_enabled_decompositions,
 )
