diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index 4bc4b74fd5d..9979aa20c7f 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -34,9 +34,9 @@ transforms:
   match_eager_attention:
     stage: pattern_matcher
     requires_shape_prop: true
-  match_grouped_attention_with_repeat_kv:
+  match_sdpa_to_torch_attention:
     stage: pattern_matcher
-  match_grouped_attention_without_repeat_kv:
+  match_grouped_attention:
     stage: pattern_matcher
   match_attention_layout:
     stage: pattern_matcher
diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
index 295325982d4..1e69bfaf5b0 100644
--- a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
+++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py
@@ -11,7 +11,6 @@
 
 from ...models.factory import ModelFactory
 from ...shim.interface import CachedSequenceInterface
-from ...utils.logger import ad_logger
 from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
 from ..interface import (
     BaseTransform,
@@ -374,9 +373,8 @@ def _call_attn(q, k, v, *, is_causal: bool, attn_mask=None, dropout_p=None, scal
     return torch.ops.auto_deploy.torch_attention.default(q, k, v, **kwargs)
 
 
-def make_grouped_attn_pair(
+def make_sdpa_to_torch_attn_pair(
     *,
-    repeat_kv: bool,
     is_causal: bool,
     has_scale: bool,
     enable_gqa: bool,
@@ -384,18 +382,128 @@
     has_dropout: bool,
 ) -> Tuple[Callable, Callable, List[str]]:
     """
-    Returns (pattern_fn, replacement_fn, argnames) with exact positional parity.
-
-    Arg order rules:
-      Base: (q, k, v)
-      +repeat_kv -> insert n_rep after (q, k, v)
-      +attn_mask -> include attn_mask after n_rep if repeat_kv else after (q, k, v)
-      +dropout -> include dropout_p after attn_mask or after n_rep/base if no attn_mask
-      +scale -> include scale last
+    Returns (pattern_fn, replacement_fn, argnames) for matching SDPA to torch_attention.
+
+    Pattern: torch_attention_sdpa --> torch_attention
     """
     argnames: List[str] = ["q", "k", "v"]
-    if repeat_kv:
-        argnames.append("n_rep")
+    if has_attn_mask:
+        argnames.append("attn_mask")
+    if has_dropout:
+        argnames.append("dropout_p")
+    if has_scale:
+        argnames.append("scale")
+
+    def pattern_fn(*args):
+        if len(args) != len(argnames):
+            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
+        m = dict(zip(argnames, args))
+        return _call_sdpa(
+            m["q"],
+            m["k"],
+            m["v"],
+            is_causal=is_causal,
+            enable_gqa=enable_gqa,
+            attn_mask=m.get("attn_mask"),
+            dropout_p=m.get("dropout_p"),
+            scale=m.get("scale"),
+        )
+
+    def replacement_fn(*args):
+        if len(args) != len(argnames):
+            raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
+        m = dict(zip(argnames, args))
+        return _call_attn(
+            m["q"],
+            m["k"],
+            m["v"],
+            is_causal=is_causal,
+            attn_mask=m.get("attn_mask"),
+            dropout_p=m.get("dropout_p"),
+            scale=m.get("scale"),
+        )
+
+    _attach_signature(pattern_fn, argnames)
+    _attach_signature(replacement_fn, argnames)
+    return pattern_fn, replacement_fn, argnames
+
+
+def generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern: Callable):
+    """
+    Generate patterns for matching SDPA to torch_attention.
+    Enumerates combinations across:
+    - is_causal: [False, True]
+    - has_scale: [False, True]
+    - enable_gqa: [False, True]
+    - has_attn_mask: [False, True]
+    - has_dropout: [False, True]
+    """
+    q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    k = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    v = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
+    attn_mask_tensor = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16)
+
+    dropout_val = 0.12345
+    scale_val = 0.56789
+
+    total = 0
+    axes = ((False, True),) * 5
+    for is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
+        pat_fn, rep_fn, argnames = make_sdpa_to_torch_attn_pair(
+            is_causal=is_causal,
+            has_scale=has_scale,
+            enable_gqa=enable_gqa,
+            has_attn_mask=has_attn_mask,
+            has_dropout=has_dropout,
+        )
+
+        value_map = {
+            "q": q,
+            "k": k,
+            "v": v,
+            "attn_mask": attn_mask_tensor,
+            "dropout_p": dropout_val,
+            "scale": scale_val,
+        }
+        dummy_args: List[object] = []
+        for name in argnames:
+            try:
+                dummy_args.append(value_map[name])
+            except KeyError:
+                raise RuntimeError(f"Unexpected arg name: {name}")
+
+        scalar_names = {"dropout_p", "scale"}
+        scalar_workaround: Dict[str, object] = {
+            n: value_map[n] for n in argnames if n in scalar_names
+        }
+        if not scalar_workaround:
+            scalar_workaround = None
+
+        register_ad_pattern(
+            search_fn=pat_fn,
+            replace_fn=rep_fn,
+            patterns=patterns,
+            dummy_args=dummy_args,
+            scalar_workaround=scalar_workaround,
+        )
+        total += 1
+    return total
+
+
+def make_repeat_kv_torch_attn_pair(
+    *,
+    is_causal: bool,
+    has_scale: bool,
+    has_attn_mask: bool,
+    has_dropout: bool,
+) -> Tuple[Callable, Callable, List[str]]:
+    """
+    Returns (pattern_fn, replacement_fn, argnames) for matching repeat_kv + torch_attention.
+
+    Pattern: repeat_kv(k, n_rep), repeat_kv(v, n_rep), torch_attention --> torch_attention
+    This handles GQA patterns where repeat_kv is explicitly applied before torch_attention.
+    """
+    argnames: List[str] = ["q", "k", "v", "n_rep"]
     if has_attn_mask:
         argnames.append("attn_mask")
     if has_dropout:
@@ -411,30 +519,27 @@ def pattern_fn(*args):
         q = m["q"]
         k = m["k"]
         v = m["v"]
+        n_rep = m["n_rep"]
 
-        if repeat_kv:
-            n_rep = m["n_rep"]
-            k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
-            v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
+        # Apply repeat_kv to k and v
+        k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep)
+        v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep)
 
-        return _call_sdpa(
+        return _call_attn(
             q,
             k,
             v,
             is_causal=is_causal,
-            enable_gqa=enable_gqa,
             attn_mask=m.get("attn_mask"),
             dropout_p=m.get("dropout_p"),
             scale=m.get("scale"),
         )
 
-    # Replacement: torch_attention.default mirroring the positional signature exactly.
-    # We do NOT pass enable_gqa here (it’s SDPA-only). We accept n_rep to mirror signature,
-    # but we don’t need to use it in the replacement graph.
     def replacement_fn(*args):
         if len(args) != len(argnames):
             raise TypeError(f"Expected {len(argnames)} args {tuple(argnames)}, got {len(args)}")
         m = dict(zip(argnames, args))
+        # Replacement: just call torch_attention directly (no repeat_kv needed)
         return _call_attn(
             m["q"],
             m["k"],
@@ -445,37 +550,19 @@ def replacement_fn(*args):
             scale=m.get("scale"),
         )
 
-    # Pattern matcher needs to see explicit arg names
     _attach_signature(pattern_fn, argnames)
     _attach_signature(replacement_fn, argnames)
-
     return pattern_fn, replacement_fn, argnames
 
 
-def generate_and_register_grouped_attn_patterns(
-    patterns, register_ad_pattern: Callable, only_repeat_kv: bool = None
-):
+def generate_and_register_repeat_kv_torch_attn_patterns(patterns, register_ad_pattern: Callable):
     """
-    Auto-generate all grouped attention patterns across these axes:
-      1) repeat_kv: [False, True]
-      2) is_causal: [False, True]
-      3) has_scale: [False, True]
-      4) enable_gqa: [False, True] (only a kwarg to SDPA side)
-      5) has_attn_mask: [False, True]
-      6) has_dropout: [False, True]
-
-    Args:
-        patterns: The ADPatternMatcherPass instance to register patterns to
-        register_ad_pattern: The function to call to register each pattern
-        only_repeat_kv: If True, only register patterns with repeat_kv=True.
-                        If False, only register patterns with repeat_kv=False.
-                        If None, register all patterns.
-
-    For each valid combo, we:
-      - build pattern/replacement functions with exact-arg parity
-      - build dummy args matching the signature (with CUDA fp16 tensors etc.)
-      - build scalar_workaround dict for any scalars/n_rep present
-      - call register_ad_pattern(...)
+    Generate patterns for matching repeat_kv + torch_attention.
+    Enumerates combinations across:
+    - is_causal: [False, True]
+    - has_scale: [False, True]
+    - has_attn_mask: [False, True]
+    - has_dropout: [False, True]
     """
     q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16)
     k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16)
@@ -487,24 +574,15 @@ def generate_and_register_grouped_attn_patterns(
     n_rep_val = 7
 
     total = 0
-    axes = ((False, True),) * 6
-    for repeat_kv, is_causal, has_scale, enable_gqa, has_attn_mask, has_dropout in product(*axes):
-        if only_repeat_kv is not None:
-            if only_repeat_kv and not repeat_kv:
-                continue  # Skip patterns without repeat_kv
-            if not only_repeat_kv and repeat_kv:
-                continue  # Skip patterns with repeat_kv
-
-        pat_fn, rep_fn, argnames = make_grouped_attn_pair(
-            repeat_kv=repeat_kv,
+    axes = ((False, True),) * 4
+    for is_causal, has_scale, has_attn_mask, has_dropout in product(*axes):
+        pat_fn, rep_fn, argnames = make_repeat_kv_torch_attn_pair(
            is_causal=is_causal,
             has_scale=has_scale,
-            enable_gqa=enable_gqa,
             has_attn_mask=has_attn_mask,
             has_dropout=has_dropout,
         )
 
-        # Build dummy args in the same positional order
         value_map = {
             "q": q,
             "k": k1,
@@ -539,12 +617,17 @@ def generate_and_register_grouped_attn_patterns(
     return total
 
 
-@TransformRegistry.register("match_grouped_attention_with_repeat_kv")
-class MatchGroupedAttentionWithRepeatKV(BaseTransform):
+@TransformRegistry.register("match_sdpa_to_torch_attention")
+class MatchSDPAToTorchAttention(BaseTransform):
     """
-    Match and replace grouped attention patterns WITH repeat_kv to
-    torch.ops.auto_deploy.torch_attention.
+    Match and replace SDPA patterns to torch.ops.auto_deploy.torch_attention.
+    This handles:
+    - sdpa --> torch_attention
+    - repeat_kv + sdpa --> repeat_kv + torch_attention (repeat_kv is removed afterwards)
+
+    This transform should run BEFORE match_grouped_attention to ensure
+    SDPA calls are converted first.
""" def _apply( @@ -554,32 +637,33 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - def register_grouped_attention_with_repeat_kv(patterns: ADPatternMatcherPass): - return generate_and_register_grouped_attn_patterns( - patterns, register_ad_pattern, only_repeat_kv=True - ) + def register_sdpa_to_torch_attention(patterns: ADPatternMatcherPass): + return generate_and_register_sdpa_to_torch_attn_patterns(patterns, register_ad_pattern) - num_grouped_patterns = _apply_pattern( - gm, "Grouped Attention (with repeat_kv)", register_grouped_attention_with_repeat_kv + num_patterns = _apply_pattern( + gm, "SDPA to Torch Attention", register_sdpa_to_torch_attention ) info = TransformInfo( skipped=False, - num_matches=num_grouped_patterns, + num_matches=num_patterns, is_clean=False, has_valid_shapes=False, ) return gm, info -@TransformRegistry.register("match_grouped_attention_without_repeat_kv") -class MatchGroupedAttentionWithoutRepeatKV(BaseTransform): +@TransformRegistry.register("match_grouped_attention") +class MatchRepeatKVWithTorchAttention(BaseTransform): """ - Match and replace grouped attention patterns WITHOUT repeat_kv to - torch.ops.auto_deploy.torch_attention. + Match and replace repeat_kv + torch_attention patterns to torch_attention. - This transform should run AFTER match_grouped_attention_with_repeat_kv - to avoid incorrectly matching patterns that should have repeat_kv. + This handles: + - repeat_kv + torch_attention --> torch_attention (removes redundant repeat_kv) + - torch_attention --> torch_attention (identity, catches any remaining patterns) + + This transform should run AFTER match_sdpa_to_torch_attention to ensure + we match the repeat_kv + torch_attention pattern correctly. 
""" def _apply( @@ -589,26 +673,18 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - def register_grouped_attention_without_repeat_kv(patterns: ADPatternMatcherPass): - return generate_and_register_grouped_attn_patterns( - patterns, register_ad_pattern, only_repeat_kv=False + def register_repeat_kv_with_torch_attention(patterns: ADPatternMatcherPass): + return generate_and_register_repeat_kv_torch_attn_patterns( + patterns, register_ad_pattern ) - num_grouped_patterns = _apply_pattern( - gm, - "Grouped Attention (without repeat_kv)", - register_grouped_attention_without_repeat_kv, + num_patterns = _apply_pattern( + gm, "Repeat KV with Torch Attention", register_repeat_kv_with_torch_attention ) - if num_grouped_patterns == 0: - ad_logger.warning( - "Fail to find any Group Attention Pattern (without repeat_kv), " - "output or performance may be incorrect" - ) - info = TransformInfo( skipped=False, - num_matches=num_grouped_patterns, + num_matches=num_patterns, is_clean=False, has_valid_shapes=False, ) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py index e38b10ca145..c3b18303171 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py @@ -446,10 +446,10 @@ def _get_match_grouped_attention_optimizer() -> Callable: "cleanup_noop_slice": { "stage": "post_export", }, - "match_grouped_attention_with_repeat_kv": { + "match_sdpa_to_torch_attention": { "stage": "pattern_matcher", }, - "match_grouped_attention_without_repeat_kv": { + "match_grouped_attention": { "stage": "pattern_matcher", }, } diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py index 9474d1b2833..3ae7775c6af 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py @@ -41,10 +41,10 @@ def _joint_transform(gm: GraphModule) -> None: "match_eager_attention": { "stage": "pattern_matcher", }, - "match_grouped_attention_with_repeat_kv": { + "match_sdpa_to_torch_attention": { "stage": "pattern_matcher", }, - "match_grouped_attention_without_repeat_kv": { + "match_grouped_attention": { "stage": "pattern_matcher", }, "match_attention_layout": {