[None][feat] AutoDeploy: Precompute the A log for mamba layers #9344
Conversation
Signed-off-by: Chenghao Zhang <[email protected]>
📝 Walkthrough
This PR adds a new Mamba/NemotronH model optimization that fuses the A_log parameter into A within Torch FX graphs. A new transform module is created and registered in the configuration to run during the post_load_fusion stage.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning, 1 inconclusive) · ✅ Passed checks (1 passed)
Actionable comments posted: 0
🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py (2)
58-88: Pattern detection logic is sound. The bottom-up search starting from A_log parameters and walking forward to find the exp→neg pattern is a robust approach; a rough sketch of this walk follows below.

Optional: Consider using `next(iter(...))` for single-element access. At line 84, the linter suggests preferring `next(iter(cursor.users.keys()))` over `list(cursor.users.keys())[0]`. This is more idiomatic Python for accessing a single element from a set/dict keys view:

```diff
 if len(cursor.users) != 1:
     break
-cursor = list(cursor.users.keys())[0]
+cursor = next(iter(cursor.users.keys()))
```
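For context, a rough sketch of what this bottom-up walk could look like (illustrative only; op sets, naming, and structure are assumptions rather than the PR's actual code):

```python
import torch
from torch.fx import GraphModule

# Assumed variants; the real transform matches torch functions, aten ops, and string names.
_EXP_OPS = {torch.exp, torch.ops.aten.exp.default, "exp"}
_NEG_OPS = {torch.neg, torch.ops.aten.neg.default, "neg"}


def find_a_log_chains(gm: GraphModule):
    """Collect (a_log_node, exp_node, neg_node) triples for every -exp(A_log) pattern."""
    chains = []
    for node in gm.graph.nodes:
        # Start from get_attr nodes whose dotted target ends in "A_log" (assumed naming).
        if node.op != "get_attr" or not str(node.target).endswith("A_log"):
            continue
        cursor = node
        # Walk forward through single-user nodes until exp is reached (or the chain forks).
        while cursor.target not in _EXP_OPS and len(cursor.users) == 1:
            cursor = next(iter(cursor.users.keys()))
        if cursor.target not in _EXP_OPS or len(cursor.users) != 1:
            continue
        neg_node = next(iter(cursor.users.keys()))
        if neg_node.target in _NEG_OPS:
            chains.append((node, cursor, neg_node))
    return chains
```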
89-99: Correct validation of the neg operation. The check ensures the pattern terminates with a single neg operation as expected.

Optional: Same style improvement for line 93. Similar to line 84, consider using `next(iter(...))`:

```diff
 if len(exp_node.users) != 1:
     continue
-neg_node = list(exp_node.users.keys())[0]
+neg_node = next(iter(exp_node.users.keys()))
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunk)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py (1 hunk)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.
Applied to files:
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py (4)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
ModelFactory (94-351)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
CachedSequenceInterface (11-92)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
BaseTransform (217-504), SharedConfig (61-66), TransformInfo (121-178), TransformRegistry (507-535)
tensorrt_llm/logger.py (1)
warning (132-133)
🪛 Ruff (0.14.5)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
42-42: Unused method argument: cm
(ARG002)
43-43: Unused method argument: factory
(ARG002)
44-44: Unused method argument: shared_config
(ARG002)
84-84: Prefer next(iter(cursor.users.keys())) over single element slice
Replace with next(iter(cursor.users.keys()))
(RUF015)
93-93: Prefer next(iter(exp_node.users.keys())) over single element slice
Replace with next(iter(exp_node.users.keys()))
(RUF015)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py (6)
16-26: LGTM! Helper functions for dotted attribute access. The implementation correctly handles nested attribute traversal for both getting and setting attributes by dotted names.
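A plausible shape for such helpers (hypothetical names, shown only to illustrate the dotted-name traversal):

```python
import torch.nn as nn


def get_attr_by_name(root: nn.Module, dotted: str):
    """Resolve names like 'model.layers.0.mixer.A_log' by walking attributes."""
    obj = root
    for part in dotted.split("."):
        obj = getattr(obj, part)
    return obj


def set_attr_by_name(root: nn.Module, dotted: str, value) -> None:
    """Set the leaf attribute of a dotted path on its parent module."""
    *parents, leaf = dotted.split(".")
    obj = root
    for part in parents:
        obj = getattr(obj, part)
    setattr(obj, leaf, value)
```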
42-44: Note: Unused parameters are part of the BaseTransform interface. Static analysis flags these parameters as unused, but they're required by the BaseTransform._apply signature. All transforms must implement this interface even if they don't use every parameter.
49-50: LGTM! Comprehensive operation matching. The sets include multiple variants (torch functions, aten ops, and string names) to robustly match different graph representations.
101-124: Excellent fusion logic with proper safeguards (a minimal sketch follows this list). The implementation correctly:
- Handles missing attributes with appropriate logging
- Computes the fused value under torch.no_grad()
- Avoids recreating A_fused if it already exists (important for multiple A_log usages)
- Creates a non-trainable parameter with requires_grad=False
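A minimal sketch of that flow, assuming an `A_fused` naming convention and dotted parameter names (not the PR's exact code):

```python
import torch
import torch.nn as nn


def fuse_a_log(gm: nn.Module, a_log_name: str) -> str:
    """Register A_fused = -exp(A_log) next to A_log and return its dotted name."""
    *parents, leaf = a_log_name.split(".")
    owner = gm
    for part in parents:
        owner = getattr(owner, part)
    # Reuse the parameter if an earlier -exp(A_log) chain already created it.
    if not hasattr(owner, "A_fused"):
        a_log = getattr(owner, leaf)
        with torch.no_grad():
            fused = -torch.exp(a_log)
        owner.register_parameter("A_fused", nn.Parameter(fused, requires_grad=False))
    return ".".join(parents + ["A_fused"])
```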
127-130: LGTM! Clean graph rewrite. The transformation correctly replaces the entire computation chain with a single reference to the precomputed parameter.
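For illustration, the rewrite step could look roughly like this (function and variable names are assumptions):

```python
from torch.fx import GraphModule, Node


def rewrite_chain(gm: GraphModule, neg_node: Node, fused_name: str) -> None:
    """Swap the -exp(A_log) chain's output for a single read of the fused parameter."""
    with gm.graph.inserting_before(neg_node):
        fused_node = gm.graph.get_attr(fused_name)
    neg_node.replace_all_uses_with(fused_node)
```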
132-140: LGTM! Proper cleanup and metadata return. Dead code elimination removes the obsolete computation chain, and the TransformInfo correctly reflects the transform's outcome.
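After all chains are rewritten, the standard FX cleanup pass removes the now-dead nodes (sketch):

```python
# Drop the unreachable exp/neg nodes and the A_log reads that only fed them,
# then regenerate the module's forward code from the edited graph.
gm.graph.eliminate_dead_code()
gm.recompile()
```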
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)
99-100: LGTM! Transform correctly registered in the pipeline. The fuse_mamba_a_log transform is appropriately configured to run during the post_load_fusion stage, which ensures weights are loaded before computing the fused parameter.
```yaml
fuse_mamba_a_log:
  stage: post_load_fusion
```
Instead of adding it to default.yaml, how about adding this to examples/auto_deploy/nano_v3.yaml?
Signed-off-by: Chenghao Zhang <[email protected]>
Signed-off-by: Chenghao Zhang <[email protected]>
@@ -0,0 +1,140 @@
"""Transform to fuse A_log into A for Mamba/NemotronH models."""
Looks like there is a memory leak. Could you also add a unit test that checks memory usage before and after this transformation?
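Something along these lines could work as a starting point (the model builder and transform entry point are hypothetical placeholders, and the tolerance is arbitrary):

```python
import torch


def test_fuse_mamba_a_log_memory():
    model = build_test_nemotron_h().cuda()        # hypothetical test-model builder
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    before = torch.cuda.memory_allocated()

    gm = apply_fuse_mamba_a_log(model)            # hypothetical transform entry point
    del model
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    after = torch.cuda.memory_allocated()

    # A_fused should roughly replace the A_log-derived temporaries; anything much
    # larger than the fused parameters themselves would point at retained references.
    assert gm is not None
    assert after - before < 1 << 20, f"{after - before} bytes retained after transform"
```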
Signed-off-by: Chenghao Zhang <[email protected]>
A_log is a per-layer parameter and is constant during inference, so precompute -exp(A_log) at compile time for better performance.
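Conceptually (illustrative shapes and names):

```python
import torch

a_log = torch.randn(64)                  # per-layer parameter, fixed at inference time

# Before: the graph re-derives A from A_log on every forward pass.
a_every_step = -torch.exp(a_log)

# After: the same value is computed once when the model is loaded/compiled,
# and the graph just reads the stored tensor.
a_precomputed = (-torch.exp(a_log)).detach()

assert torch.equal(a_every_step, a_precomputed)
```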