
Conversation

@nvchenghaoz (Collaborator) commented Nov 21, 2025

A_log is a per-layer parameter that stays constant during inference, so we precompute -exp(A_log) at compile time for better performance.
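For intuition, here is a minimal before/after sketch of the idea (a toy module and placeholder math, not the actual NemotronH/Mamba layer):

```python
import torch
import torch.nn as nn


class MambaLayerUnfused(nn.Module):
    """Recomputes -exp(A_log) on every forward call (the pre-fusion graph shape)."""

    def __init__(self, d_state: int = 16):
        super().__init__()
        self.A_log = nn.Parameter(torch.randn(d_state))

    def forward(self, x):
        A = -torch.exp(self.A_log.float())  # re-evaluated on each call
        return x * A.sum()                  # stand-in for the real SSM computation


class MambaLayerFused(nn.Module):
    """Carries the precomputed A so inference skips the exp/neg ops entirely."""

    def __init__(self, a_log: torch.Tensor):
        super().__init__()
        with torch.no_grad():
            self.A_fused = nn.Parameter(-torch.exp(a_log.float()), requires_grad=False)

    def forward(self, x):
        return x * self.A_fused.sum()       # stand-in for the real SSM computation
```

Because A_log never changes at inference time, both variants produce identical outputs, while the fused one avoids two elementwise kernels per layer per decode step.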

Summary by CodeRabbit

  • New Features
    • Added new transform capability to the auto-deploy configuration system for model optimization.
    • Expands available graph transformation patterns during automated model deployment phases.
    • Provides additional optimization flexibility in the deployment workflow for compatible model types.


@coderabbitai coderabbitai bot (Contributor) commented Nov 21, 2025

📝 Walkthrough

This PR adds a new Mamba/NemotronH model optimization that fuses the A_log parameter into A within Torch FX graphs. A new transform module is created and registered in the configuration to run during the post_load_fusion stage.

Changes

Cohort / File(s): Summary
  • Configuration (tensorrt_llm/_torch/auto_deploy/config/default.yaml): Added fuse_mamba_a_log transform entry assigned to the post_load_fusion stage.
  • Transform Implementation (tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py): New module introducing the FuseMambaALog transform class that scans Torch FX graphs for A_log parameter usage patterns and replaces them with a fused parameter A_fused computed as -exp(A_log.float()). Includes helper functions _get_attr_by_name() and _set_attr_by_name() for dotted-name attribute access and mutation.
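For orientation, a rough sketch of what such an FX rewrite can look like (simplified and using illustrative names, not the module's actual code: it ignores the intermediate dtype cast, aten string targets, and pre-existing A_fused parameters):

```python
import torch
from torch import fx


def fuse_a_log(gm: fx.GraphModule) -> int:
    """Illustrative sketch: fold `-exp(A_log)` chains into a precomputed parameter."""
    num_matches = 0
    for node in list(gm.graph.nodes):
        # Find get_attr nodes loading a parameter whose name ends in "A_log".
        if node.op != "get_attr" or not node.target.endswith("A_log"):
            continue
        if len(node.users) != 1:
            continue
        exp_node = next(iter(node.users))
        if exp_node.target not in (torch.exp, torch.ops.aten.exp.default):
            continue
        if len(exp_node.users) != 1:
            continue
        neg_node = next(iter(exp_node.users))
        if neg_node.target not in (torch.neg, torch.ops.aten.neg.default):
            continue

        # Precompute the constant and register it on the module.
        a_log = gm.get_parameter(node.target)
        with torch.no_grad():
            a_fused = torch.nn.Parameter(-torch.exp(a_log.float()), requires_grad=False)
        fused_name = node.target.replace("A_log", "A_fused").replace(".", "_")
        gm.register_parameter(fused_name, a_fused)

        # Rewire the graph: the whole exp/neg chain collapses to one get_attr.
        with gm.graph.inserting_before(node):
            fused_node = gm.graph.get_attr(fused_name)
        neg_node.replace_all_uses_with(fused_node)
        num_matches += 1

    # Drop the now-dead exp/neg/get_attr nodes. Note that the original A_log
    # parameter itself stays registered on the module unless explicitly removed.
    gm.graph.eliminate_dead_code()
    gm.recompile()
    return num_matches
```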

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Pattern matching logic: Verify correctness of bottom-up graph traversal and detection of A_log usage through casts and exp/neg operations.
  • Parameter fusion correctness: Ensure the fused parameter computation and non-trainable parameter creation work as expected.
  • Edge case handling: Review logic for existing A_fused parameters, missing attributes, and dead code elimination validity checks.
  • Graph rewriting: Confirm proper rewriting of graph references and elimination of dead nodes.

Pre-merge checks

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
  • Description check: ❓ Inconclusive. The PR description explains the rationale but lacks required sections such as test coverage details and PR checklist confirmation as specified in the template. Add test coverage information, confirm PR checklist items, and provide a more detailed description of the implementation approach and validation strategy.
✅ Passed checks (1 passed)
  • Title check: ✅ Passed. The title accurately describes the main change: precomputing A log for mamba layers as an optimization feature in AutoDeploy.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py (2)

58-88: Pattern detection logic is sound.

The bottom-up search starting from A_log parameters and walking forward to find the exp→neg pattern is a robust approach.

Optional: Consider using next(iter(...)) for single-element access.

At line 84, the linter suggests preferring next(iter(cursor.users.keys())) over list(cursor.users.keys())[0]. This is more idiomatic Python for accessing a single element from a set/dict keys view:

                    if len(cursor.users) != 1:
                        break
-                   cursor = list(cursor.users.keys())[0]
+                   cursor = next(iter(cursor.users.keys()))

89-99: Correct validation of the neg operation.

The check ensures the pattern terminates with a single neg operation as expected.

Optional: Same style improvement for line 93.

Similar to line 84, consider using next(iter(...)):

                if len(exp_node.users) != 1:
                    continue

-               neg_node = list(exp_node.users.keys())[0]
+               neg_node = next(iter(exp_node.users.keys()))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b1c9936 and 51154fb.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📚 Learning: 2025-10-20T16:54:09.824Z
Learnt from: nvchenghaoz
Repo: NVIDIA/TensorRT-LLM PR: 8469
File: tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py:6-6
Timestamp: 2025-10-20T16:54:09.824Z
Learning: In tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py, the import `from ...modules.mamba.layernorm_gated import _layer_norm_fwd` is correct and should not be changed to modules.fla.layernorm_gated. The _layer_norm_fwd function exists in both modules/mamba/layernorm_gated.py and modules/fla/layernorm_gated.py, but the mamba version is the intended implementation for this use case.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py (4)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (94-351)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (11-92)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • BaseTransform (217-504)
  • SharedConfig (61-66)
  • TransformInfo (121-178)
  • TransformRegistry (507-535)
tensorrt_llm/logger.py (1)
  • warning (132-133)
🪛 Ruff (0.14.5)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py

42-42: Unused method argument: cm

(ARG002)


43-43: Unused method argument: factory

(ARG002)


44-44: Unused method argument: shared_config

(ARG002)


84-84: Prefer next(iter(cursor.users.keys())) over single element slice

Replace with next(iter(cursor.users.keys()))

(RUF015)


93-93: Prefer next(iter(exp_node.users.keys())) over single element slice

Replace with next(iter(exp_node.users.keys()))

(RUF015)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_mamba_a_log.py (6)

16-26: LGTM! Helper functions for dotted attribute access.

The implementation correctly handles nested attribute traversal for both getting and setting attributes by dotted names.
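For readers unfamiliar with the pattern, dotted-name helpers of this kind typically look like the following (an approximation of the idea, not a copy of the file's code):

```python
import torch.nn as nn


def _get_attr_by_name(root: nn.Module, dotted: str):
    """Resolve a dotted name like 'layers.0.mixer.A_log' to the attribute it names."""
    obj = root
    for part in dotted.split("."):
        obj = getattr(obj, part)
    return obj


def _set_attr_by_name(root: nn.Module, dotted: str, value) -> None:
    """Set the attribute named by a dotted path on the owning submodule."""
    *parents, leaf = dotted.split(".")
    owner = root
    for part in parents:
        owner = getattr(owner, part)
    setattr(owner, leaf, value)
```

PyTorch also ships `nn.Module.get_parameter()` and `get_submodule()` for the read side, so helpers like these mainly earn their keep on the write side.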


42-44: Note: Unused parameters are part of the BaseTransform interface.

Static analysis flags these parameters as unused, but they're required by the BaseTransform._apply signature. All transforms must implement this interface even if they don't use every parameter.


49-50: LGTM! Comprehensive operation matching.

The sets include multiple variants (torch functions, aten ops, and string names) to robustly match different graph representations.


101-124: Excellent fusion logic with proper safeguards.

The implementation correctly:

  • Handles missing attributes with appropriate logging
  • Computes the fused value under torch.no_grad()
  • Avoids recreating A_fused if it already exists (important for multiple A_log usages)
  • Creates a non-trainable parameter with requires_grad=False
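In spirit, the parameter-creation path amounts to something like this (a hypothetical helper with flat attribute names, mirroring the bullets above rather than the exact source):

```python
import logging

import torch
import torch.nn as nn

log = logging.getLogger(__name__)


def _make_fused_param(root: nn.Module, a_log_name: str, a_fused_name: str):
    """Return the fused parameter, reusing an existing one if a prior match created it."""
    existing = getattr(root, a_fused_name, None)   # reuse across multiple A_log usages
    if existing is not None:
        return existing
    a_log = getattr(root, a_log_name, None)
    if a_log is None:
        log.warning("missing parameter %s; skipping fusion", a_log_name)
        return None
    with torch.no_grad():                          # keep the constant out of autograd
        fused = nn.Parameter(-torch.exp(a_log.float()), requires_grad=False)
    setattr(root, a_fused_name, fused)
    return fused
```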

127-130: LGTM! Clean graph rewrite.

The transformation correctly replaces the entire computation chain with a single reference to the precomputed parameter.


132-140: LGTM! Proper cleanup and metadata return.

Dead code elimination removes the obsolete computation chain, and the TransformInfo correctly reflects the transform's outcome.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

99-100: LGTM! Transform correctly registered in the pipeline.

The fuse_mamba_a_log transform is appropriately configured to run during the post_load_fusion stage, which ensures weights are loaded before computing the fused parameter.

Comment on lines +99 to +100
  fuse_mamba_a_log:
    stage: post_load_fusion
Collaborator

Instead of adding it to default.yaml, how about adding this to examples/auto_deploy/nano_v3.yaml?

@@ -0,0 +1,140 @@
"""Transform to fuse A_log into A for Mamba/NemotronH models."""
Collaborator

Looks like there is a memory leak. Could you also add a unit test that checks memory usage before and after this transformation?
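One possible shape for such a test, using a toy module and a stand-in fusion step rather than the repository's actual transform entry point (illustrative only):

```python
import gc

import torch
import torch.nn as nn


class _ToyMamba(nn.Module):
    """Minimal stand-in carrying an A_log parameter, used only for this test sketch."""

    def __init__(self, d_state: int = 16):
        super().__init__()
        self.A_log = nn.Parameter(torch.randn(d_state))

    def forward(self, x):
        return x * (-torch.exp(self.A_log.float())).sum()


def _fuse(mod: _ToyMamba) -> None:
    """Stand-in for the real transform: create A_fused and drop A_log."""
    with torch.no_grad():
        fused = nn.Parameter(-torch.exp(mod.A_log.float()), requires_grad=False)
    del mod.A_log          # without this, both tensors stay resident (the "leak")
    mod.A_fused = fused


def _param_bytes(mod: nn.Module) -> int:
    return sum(p.numel() * p.element_size() for p in mod.parameters())


def test_fuse_a_log_memory_does_not_grow():
    mod = _ToyMamba()
    before = _param_bytes(mod)
    _fuse(mod)
    gc.collect()
    after = _param_bytes(mod)
    # Fusion should replace A_log, not duplicate it.
    assert after <= before
```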
