4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -34,9 +34,9 @@ transforms:
match_eager_attention:
stage: pattern_matcher
requires_shape_prop: true
match_grouped_attention_with_repeat_kv:
match_sdpa_to_torch_attention:
stage: pattern_matcher
match_grouped_attention_without_repeat_kv:
match_grouped_attention:
stage: pattern_matcher
match_attention_layout:
stage: pattern_matcher
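For reference, the two renamed keys must match the names under which the transforms register themselves; the linkage to attention.py below is just the registry decorator:

# The renamed YAML keys resolve to the transforms registered below:
@TransformRegistry.register("match_sdpa_to_torch_attention")
class MatchSDPAToTorchAttention(BaseTransform): ...

@TransformRegistry.register("match_grouped_attention")
class MatchRepeatKVWithTorchAttention(BaseTransform): ...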
256 changes: 166 additions & 90 deletions tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
@@ -11,7 +11,6 @@

from ...models.factory import ModelFactory
from ...shim.interface import CachedSequenceInterface
from ...utils.logger import ad_logger
from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
from ..interface import (
BaseTransform,
@@ -374,28 +373,137 @@ def _call_attn(q, k, v, *, is_causal: bool, attn_mask=None, dropout_p=None, scale=None):
return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs)
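For context, `_call_sdpa` (used by the patterns below but defined above this hunk) is the SDPA-side twin of this helper. A minimal sketch of its assumed shape, inferred from its call sites and from the `torch_attention_sdpa` op named in the docstrings; the defaults and the kwarg filtering are assumptions, not the PR's verbatim code:

def _call_sdpa(q, k, v, *, is_causal: bool, enable_gqa: bool = False,
               attn_mask=None, dropout_p=None, scale=None):
    # Assumed mirror of _call_attn: forward only the kwargs that are set, so
    # the traced pattern graph stays minimal for each enumerated combination.
    kwargs = {"is_causal": is_causal, "enable_gqa": enable_gqa}
    if attn_mask is not None:
        kwargs["attn_mask"] = attn_mask
    if dropout_p is not None:
        kwargs["dropout_p"] = dropout_p
    if scale is not None:
        kwargs["scale"] = scale
    return torch.ops.auto_deploy.torch_attention_sdpa.default(q, k, v, **kwargs)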


def make_grouped_attn_pair(
def make_sdpa_to_torch_attn_pair(
*,
repeat_kv: bool,
is_causal: bool,
has_scale: bool,
enable_gqa: bool,
has_attn_mask: bool,
has_dropout: bool,
) -> Tuple[Callable, Callable, List[str]]:
"""
Returns (pattern_fn, replacement_fn, argnames) with exact positional parity.

Arg order rules:
Base: (q, k, v)
+repeat_kv -> insert n_rep after (q, k, v)
+attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v)
+dropout -> include dropout_p after attn_mask or after n_rep/base if no attn_mask
+scale -> include scale last
Returns (pattern_fn, replacement_fn, argnames) for matching SDPA to torch_attention.

Pattern: torch_attention_sdpa --> torch_attention
"""
argnames: List[str] = ["q", "k", "v"]
if repeat_kv:
argnames.append("n_rep")
if has_attn_mask:
argnames.append("attn_mask")
if has_dropout:
argnames.append("dropout_p")
if has_scale:
argnames.append("scale")

def pattern_fn(*args):
if len(args) != len(argnames):
raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
m = dict(zip(argnames, args))
return _call_sdpa(
m["q"],
m["k"],
m["v"],
is_causal=is_causal,
enable_gqa=enable_gqa,
attn_mask=m.get("attn_mask"),
dropout_p=m.get("dropout_p"),
scale=m.get("scale"),
)

def replacement_fn(*args):
if len(args) != len(argnames):
raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
m = dict(zip(argnames, args))
return _call_attn(
m["q"],
m["k"],
m["v"],
is_causal=is_causal,
attn_mask=m.get("attn_mask"),
dropout_p=m.get("dropout_p"),
scale=m.get("scale"),
)

_attach_signature(pattern_fn, argnames)
_attach_signature(replacement_fn, argnames)
return pattern_fn, replacement_fn, argnames
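To make the argument-parity contract concrete, a small usage sketch (values illustrative; only the names from this diff are real):

# Sketch: one of the 32 enumerated combinations.
pat_fn, rep_fn, argnames = make_sdpa_to_torch_attn_pair(
    is_causal=True, has_scale=True, enable_gqa=False,
    has_attn_mask=False, has_dropout=False,
)
assert argnames == ["q", "k", "v", "scale"]
# pat_fn(q, k, v, scale) traces the SDPA op with is_causal=True and that scale;
# rep_fn(q, k, v, scale) traces torch_attention with the same positional args.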


def generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern: Callable):
"""
Generate patterns for matching SDPA to torch_attention.
Enumerates combinations across:
- is_causal: [False, True]
- has_scale: [False, True]
- enable_gqa: [False, True]
- has_attn_mask: [False, True]
- has_dropout: [False, True]
"""
q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
v = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
attn_mask_tensor = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)

dropout_val = 0.12345
scale_val = 0.56789

total = 0
axes = ((False, True),) * 5
for is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
pat_fn, rep_fn, argnames = make_sdpa_to_torch_attn_pair(
is_causal=is_causal,
has_scale=has_scale,
enable_gqa=enable_gqa,
has_attn_mask=has_attn_mask,
has_dropout=has_dropout,
)

value_map = {
"q": q,
"k": k,
"v": v,
"attn_mask": attn_mask_tensor,
"dropout_p": dropout_val,
"scale": scale_val,
}
dummy_args: List[object] = []
for name in argnames:
try:
dummy_args.append(value_map[name])
except KeyError:
raise RuntimeError(f"Unexpected arg name: {name}")

scalar_names = {"dropout_p", "scale"}
scalar_workaround: Dict[str, object] = {
n: value_map[n] for n in argnames if n in scalar_names
}
if not scalar_workaround:
scalar_workaround = None

register_ad_pattern(
search_fn=pat_fn,
replace_fn=rep_fn,
patterns=patterns,
dummy_args=dummy_args,
scalar_workaround=scalar_workaround,
)
total += 1
return total
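Five independent boolean axes give 2**5 = 32 registered search/replace pairs. A sketch of how a caller drives this, mirroring the `_apply` methods further down (the no-argument ADPatternMatcherPass construction is an assumption):

patterns = ADPatternMatcherPass()  # assumed no-arg construction
n = generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern)
assert n == 32  # 2 ** 5 boolean axes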


def make_repeat_kv_torch_attn_pair(
*,
is_causal: bool,
has_scale: bool,
has_attn_mask: bool,
has_dropout: bool,
) -> Tuple[Callable, Callable, List[str]]:
"""
Returns (pattern_fn, replacement_fn, argnames) for matching repeat_kv + torch_attention.

Pattern: repeat_kv(k, n_rep), repeat_kv(v, n_rep), torch_attention --> torch_attention
This handles GQA patterns where repeat_kv is explicitly applied before torch_attention.
"""
argnames: List[str] = ["q", "k", "v", "n_rep"]
if has_attn_mask:
argnames.append("attn_mask")
if has_dropout:
@@ -411,30 +519,27 @@ def pattern_fn(*args):
q = m["q"]
k = m["k"]
v = m["v"]
n_rep = m["n_rep"]

if repeat_kv:
n_rep = m["n_rep"]
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
# Apply repeat_kv to k and v
k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)

return _call_sdpa(
return _call_attn(
q,
k,
v,
is_causal=is_causal,
enable_gqa=enable_gqa,
attn_mask=m.get("attn_mask"),
dropout_p=m.get("dropout_p"),
scale=m.get("scale"),
)

# Replacement: torch_attention.default mirroring the positional signature exactly.
# We do NOT pass enable_gqa here (it’s SDPA-only). We accept n_rep to mirror signature,
# but we don’t need to use it in the replacement graph.
def replacement_fn(*args):
if len(args) != len(argnames):
raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
m = dict(zip(argnames, args))
# Replacement: just call torch_attention directly (no repeat_kv needed)
return _call_attn(
m["q"],
m["k"],
@@ -445,37 +550,19 @@ def replacement_fn(*args):
scale=m.get("scale"),
)

# Pattern matcher needs to see explicit arg names
_attach_signature(pattern_fn, argnames)
_attach_signature(replacement_fn, argnames)

return pattern_fn, replacement_fn, argnames
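In graph terms the rewrite looks like this (a sketch; it assumes, as the docstring above implies, that torch_attention performs the KV-head expansion internally for GQA):

# Before (traced from a GQA module):
#   k_rep = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
#   v_rep = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
#   out   = torch.ops.auto_deploy.torch_attention(q, k_rep, v_rep, is_causal=True)
# After (replacement drops the explicit expansion):
#   out   = torch.ops.auto_deploy.torch_attention(q, k, v, is_causal=True)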


def generate_and_register_grouped_attn_patterns(
patterns, register_ad_pattern: Callable, only_repeat_kv: bool = None
):
def generate_and_register_repeat_kv_torch_attn_patterns(patterns, register_ad_pattern: Callable):
"""
Auto-generate all grouped attention patterns across these axes:
1) repeat_kv: [False, True]
2) is_causal: [False, True]
3) has_scale: [False, True]
4) enable_gqa: [False, True] (only a kwarg to SDPA side)
5) has_attn_mask: [False, True]
6) has_dropout: [False, True]

Args:
patterns: The ADPatternMatcherPass instance to register patterns to
register_ad_pattern: The function to call to register each pattern
only_repeat_kv: If True, only register patterns with repeat_kv=True.
If False, only register patterns with repeat_kv=False.
If None, register all patterns.

For each valid combo, we:
- build pattern/replacement functions with exact-arg parity
- build dummy args matching the signature (with CUDA fp16 tensors etc.)
- build scalar_workaround dict for any scalars/n_rep present
- call register_ad_pattern(...)
Generate patterns for matching repeat_kv + torch_attention.
Enumerates combinations across:
- is_causal: [False, True]
- has_scale: [False, True]
- has_attn_mask: [False, True]
- has_dropout: [False, True]
"""
q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
@@ -487,24 +574,15 @@ def generate_and_register_grouped_attn_patterns(
n_rep_val = 7

total = 0
axes = ((False, True),) * 6
for repeat_kv, is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
if only_repeat_kv is not None:
if only_repeat_kv and not repeat_kv:
continue # Skip patterns without repeat_kv
if not only_repeat_kv and repeat_kv:
continue # Skip patterns with repeat_kv

pat_fn, rep_fn, argnames = make_grouped_attn_pair(
repeat_kv=repeat_kv,
axes = ((False, True),) * 4
for is_causal, has_scale, has_attn_mask, has_dropout in product(*axes):
pat_fn, rep_fn, argnames = make_repeat_kv_torch_attn_pair(
is_causal=is_causal,
has_scale=has_scale,
enable_gqa=enable_gqa,
has_attn_mask=has_attn_mask,
has_dropout=has_dropout,
)

# Build dummy args in the same positional order
value_map = {
"q": q,
"k": k1,
@@ -539,12 +617,17 @@ def generate_and_register_grouped_attn_patterns(
return total
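For one of the 16 combinations, the registration reduces to roughly the following (a sketch; the scalar_workaround entry for n_rep mirrors the dropout_p/scale handling shown earlier and is an assumption, since n_rep is traced as a Python int rather than a tensor):

pat_fn, rep_fn, argnames = make_repeat_kv_torch_attn_pair(
    is_causal=True, has_scale=False, has_attn_mask=False, has_dropout=False,
)
assert argnames == ["q", "k", "v", "n_rep"]
register_ad_pattern(
    search_fn=pat_fn,
    replace_fn=rep_fn,
    patterns=patterns,
    dummy_args=[q, k1, v1, 7],       # k1/v1 carry a single KV head (see above)
    scalar_workaround={"n_rep": 7},  # assumed, following the earlier scalar handling
)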


@TransformRegistry.register("match_grouped_attention_with_repeat_kv")
class MatchGroupedAttentionWithRepeatKV(BaseTransform):
@TransformRegistry.register("match_sdpa_to_torch_attention")
class MatchSDPAToTorchAttention(BaseTransform):
"""
Match and replace grouped attention patterns WITH repeat_kv to
torch.ops.auto_deploy.torch_attention.
Match SDPA patterns and replace them with torch.ops.auto_deploy.torch_attention.

This handles:
- sdpa --> torch_attention
- repeat_kv + sdpa --> repeat_kv + torch_attention (the leftover repeat_kv is
folded away by match_grouped_attention afterwards)

This transform should run BEFORE match_grouped_attention so that all SDPA
calls are converted first.
"""

def _apply(
@@ -554,32 +637,33 @@ def _apply(
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
def register_grouped_attention_with_repeat_kv(patterns: ADPatternMatcherPass):
return generate_and_register_grouped_attn_patterns(
patterns, register_ad_pattern, only_repeat_kv=True
)
def register_sdpa_to_torch_attention(patterns: ADPatternMatcherPass):
return generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern)

num_grouped_patterns = _apply_pattern(
gm, "Grouped Attention (with repeat_kv)", register_grouped_attention_with_repeat_kv
num_patterns = _apply_pattern(
gm, "SDPA to Torch Attention", register_sdpa_to_torch_attention
)

info = TransformInfo(
skipped=False,
num_matches=num_grouped_patterns,
num_matches=num_patterns,
is_clean=False,
has_valid_shapes=False,
)
return gm, info
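The ordering contract matches the default.yaml change at the top of this diff; the net effect of the two passes on a GQA module is:

# 1. match_sdpa_to_torch_attention:
#      repeat_kv + sdpa            -->  repeat_kv + torch_attention
# 2. match_grouped_attention:
#      repeat_kv + torch_attention -->  torch_attention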


@TransformRegistry.register("match_grouped_attention_without_repeat_kv")
class MatchGroupedAttentionWithoutRepeatKV(BaseTransform):
@TransformRegistry.register("match_grouped_attention")
class MatchRepeatKVWithTorchAttention(BaseTransform):
"""
Match and replace grouped attention patterns WITHOUT repeat_kv to
torch.ops.auto_deploy.torch_attention.
Match repeat_kv + torch_attention patterns and replace them with torch_attention.

This transform should run AFTER match_grouped_attention_with_repeat_kv
to avoid incorrectly matching patterns that should have repeat_kv.
This handles:
- repeat_kv + torch_attention --> torch_attention (drops the explicit repeat_kv
expansion, which the grouped torch_attention op performs internally)

This transform should run AFTER match_sdpa_to_torch_attention to ensure
we match the repeat_kv + torch_attention pattern correctly.
"""

def _apply(
@@ -589,26 +673,18 @@ def _apply(
factory: ModelFactory,
shared_config: SharedConfig,
) -> Tuple[GraphModule, TransformInfo]:
def register_grouped_attention_without_repeat_kv(patterns: ADPatternMatcherPass):
return generate_and_register_grouped_attn_patterns(
patterns, register_ad_pattern, only_repeat_kv=False
def register_repeat_kv_with_torch_attention(patterns: ADPatternMatcherPass):
return generate_and_register_repeat_kv_torch_attn_patterns(
patterns, register_ad_pattern
)

num_grouped_patterns = _apply_pattern(
gm,
"Grouped Attention (without repeat_kv)",
register_grouped_attention_without_repeat_kv,
num_patterns = _apply_pattern(
gm, "Repeat KV with Torch Attention", register_repeat_kv_with_torch_attention
)

if num_grouped_patterns == 0:
ad_logger.warning(
"Fail to find any Group Attention Pattern (without repeat_kv), "
"output or performance may be incorrect"
)

info = TransformInfo(
skipped=False,
num_matches=num_grouped_patterns,
num_matches=num_patterns,
is_clean=False,
has_valid_shapes=False,
)
@@ -446,10 +446,10 @@ def _get_match_grouped_attention_optimizer() -> Callable:
"cleanup_noop_slice": {
"stage": "post_export",
},
"match_grouped_attention_with_repeat_kv": {
"match_sdpa_to_torch_attention": {
"stage": "pattern_matcher",
},
"match_grouped_attention_without_repeat_kv": {
"match_grouped_attention": {
"stage": "pattern_matcher",
},
}