
Conversation

Collaborator

@nvxuanyuc nvxuanyuc commented Jul 22, 2025

[feat] Implement pytorch sampler for MTP

Description

  • Adds support for advanced sampling in the PyTorch path for MTP with speculative decoding
    • Previously, only greedy mode was supported.
    • Implements temperature, top-p, top-k, and min-p sampling parameters in Python when using MTP speculative decoding (for DeepSeek) @pathorn [#5627]
  • Adds support for returning log-probs from the PyTorch sampler (related to [#5620])

The default behavior of the MTP PyTorch decoder remains greedy sampling. Advanced sampling can be enabled via the enable_mixed_sampler flag in TorchLlmArgs.
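For illustration, a hedged usage sketch (not taken from this PR): it assumes the enable_mixed_sampler flag is forwarded to the PyTorch backend through the LLM constructor alongside an MTP speculative-decoding config, so the exact wiring may differ between versions.

# Hedged sketch: the model path is a placeholder and enable_mixed_sampler is
# assumed to be accepted as an LLM constructor kwarg (it lives in TorchLlmArgs).
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import MTPDecodingConfig

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # placeholder checkpoint path
    speculative_config=MTPDecodingConfig(num_nextn_predict_layers=1),
    enable_mixed_sampler=True,  # default is False, i.e. greedy MTP decoding
)
outputs = llm.generate(
    ["Explain speculative decoding in one sentence."],
    SamplingParams(temperature=0.8, top_p=0.9, top_k=50),
)
print(outputs[0].outputs[0].text)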

Test Coverage

  • Added a test for greedy mode (temperature <= 1e-2) using the new PyTorch sampler.
  • Tests for advanced sampling modes are not yet included.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@nvxuanyuc nvxuanyuc requested review from a team as code owners July 22, 2025 03:37
Contributor

coderabbitai bot commented Jul 22, 2025

📝 Walkthrough

This change introduces advanced multi-token prediction (MTP) sampling support in the PyTorch execution engine, enabling per-token sampling parameter control for speculative decoding. It adds new fields and methods to propagate and utilize sampling parameters (temperature, top-k, top-p, min-p) throughout the model engine, speculative metadata, and MTP worker. A new batch sampling function and corresponding tests are also included.

Changes

Cohort / File(s) and Change Summary:

  • Advanced MTP Sampler Integration (tensorrt_llm/_torch/pyexecutor/model_engine.py): Adds advanced MTP sampler support: tracks per-request sampling params, propagates them for speculative decoding, and updates control flow to handle advanced sampler mode.
  • PyTorch-native Sampling Utilities (tensorrt_llm/_torch/pyexecutor/sampler.py): Introduces modular PyTorch-native sampling functions (temperature, top-k, top-p, min-p, batch sampling) and a unified batch sampling interface.
  • Speculative Metadata Extensions (tensorrt_llm/_torch/speculative/interface.py): Adds optional tensor fields to SpecMetadata for storing per-request sampling parameters.
  • MTP Worker and Metadata Updates (tensorrt_llm/_torch/speculative/mtp.py): Adds CUDA tensor fields and setup/update methods to MTPSpecMetadata, enables advanced sampling in MTPWorker when configured, and integrates the new batch sampler.
  • Model Config Flag (tensorrt_llm/_torch/model_config.py): Adds an enable_mixed_sampler boolean field to the ModelConfig dataclass, defaulting to False.
  • Advanced Sampler Unit Test (tests/unittest/_torch/speculative/test_mtp.py): Adds a parameterized unit test for sample_and_accept_draft_tokens with the advanced PyTorch sampler in greedy mode, verifying correct behavior with deterministic settings.

Sequence Diagram(s)

sequenceDiagram
    participant User
    participant PyTorchModelEngine
    participant MTPWorker
    participant SpecMetadata
    participant Sampler

    User->>PyTorchModelEngine: forward(requests, ...)
    PyTorchModelEngine->>PyTorchModelEngine: _prepare_tp_inputs()
    PyTorchModelEngine->>SpecMetadata: update_advanced_mtp_sampling_params(...)
    PyTorchModelEngine->>SpecMetadata: _set_up_advanced_mtp_sampling(...)
    PyTorchModelEngine->>MTPWorker: sample_and_accept_draft_tokens(input_ids, logits, spec_metadata, ...)
    alt enable_mixed_sampler
        MTPWorker->>Sampler: sampling_batch(logits, temperatures, top_k, top_p, min_p)
        Sampler-->>MTPWorker: sampled_tokens, log_probs
    else
        MTPWorker->>MTPWorker: greedy_sample(logits)
    end
    MTPWorker-->>PyTorchModelEngine: accepted_tokens
    PyTorchModelEngine-->>User: output

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers

  • HuiGao-NV
  • Funatiq


🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Nitpick comments (3)
tests/unittest/_torch/speculative/test_mtp.py (1)

370-370: Fix line length violation.

The line exceeds the 120-character limit as flagged by static analysis.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in 
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

264-267: Verify the global impact of setting torch.manual_seed(0).

Setting a global PyTorch seed in the constructor could have unintended side effects on other operations. Consider:

  1. This affects all PyTorch random operations, not just sampling
  2. It might interfere with user-controlled randomness
  3. Consider using a local generator instead of global seed

Consider using a dedicated random generator for sampling operations:

-        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
-        # operations that avoid torch.multinomial's CPU-GPU sync overhead
-        torch.manual_seed(0)
+        # Create dedicated generator for consistent multi-GPU sampling
+        # to avoid torch.multinomial's CPU-GPU sync overhead
+        self.sampling_generator = torch.Generator(device='cuda')
+        self.sampling_generator.manual_seed(0)

Then pass this generator to sampling operations that need deterministic behavior.


1163-1195: Consider moving helper functions to class level for better organization.

These helper functions are defined inside _prepare_tp_inputs but could be reused elsewhere. Consider making them class methods or static methods.

Move these functions to class level:

-        def get_request_temperature(request: LlmRequest) -> float:
-            if not request.sampling_config.temperature:
-                return 0.7
-            temperature = request.sampling_config.temperature[0]
-            if 0 < temperature < 1e-2:
-                # temperature less than 0.01 may cause numerical errors
-                temperature = 0.01
-            return temperature
+    @staticmethod
+    def _get_request_temperature(request: LlmRequest) -> float:
+        if not request.sampling_config.temperature:
+            return 0.7
+        temperature = request.sampling_config.temperature[0]
+        if 0 < temperature < 1e-2:
+            # temperature less than 0.01 may cause numerical errors
+            temperature = 0.01
+        return temperature

Apply similar changes to the other helper functions and update the call sites accordingly.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between fddb7f1 and cec9318.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)


🔇 Additional comments (20)
tensorrt_llm/_torch/speculative/interface.py (1)

135-142: LGTM! Clean addition of sampling parameter fields.

The new optional tensor fields for sampling parameters (temperatures, top_k, top_p, min_p) are well-structured and follow the existing pattern in the SpecMetadata dataclass. The type annotations and comments are clear and appropriate.
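As a rough illustration of the pattern the review describes (field names taken from the review text; the dataclass shown is a simplified stand-in, not the actual SpecMetadata definition):

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class SpecMetadataSketch:
    # Per-request sampling parameters, populated only when mixed sampling is enabled
    temperatures: Optional[torch.Tensor] = None
    top_k: Optional[torch.Tensor] = None
    top_p: Optional[torch.Tensor] = None
    min_p: Optional[torch.Tensor] = None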

examples/llm-api/quickstart_advanced.py (2)

115-117: LGTM! Clean addition of command-line argument.

The new --use_advanced_mtp_sampler flag follows the established pattern for boolean command-line arguments with an appropriate default value.


169-170: LGTM! Proper integration of the new flag.

The use_advanced_mtp_sampler parameter is correctly passed to the MTPDecodingConfig constructor, maintaining consistency with the command-line argument.

tensorrt_llm/_torch/speculative/mtp.py (2)

11-11: LGTM! Appropriate import addition.

The import of sampling_batch function is correctly added to support the advanced MTP sampler functionality.


825-833: LGTM! Well-structured conditional sampling logic.

The implementation demonstrates good practices:

  1. Backward compatibility: Maintains the existing greedy sampling as the default behavior
  2. Clear conditional logic: The flag-based switching is easy to understand and maintain
  3. Future-proofing: Acknowledges the unused target_log_probs for future log probability support
  4. Clean integration: The advanced sampler integrates seamlessly with the existing acceptance algorithm

The approach minimizes risk while enabling the new advanced sampling functionality.
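A hedged sketch of that flag-based switch, pieced together from the walkthrough and sequence diagram in this thread (variable and attribute names are assumptions, not the PR's exact code):

import torch
from tensorrt_llm._torch.pyexecutor.sampler import sampling_batch  # import added by this PR

def sample_target_tokens(logits, spec_metadata, enable_mixed_sampler: bool):
    """Advanced batch sampling when the flag is set; otherwise the existing greedy path."""
    if enable_mixed_sampler:
        # returns (sampled_tokens, log_probs); log_probs are kept for future log-prob support
        return sampling_batch(logits, spec_metadata.temperatures, spec_metadata.top_k,
                              spec_metadata.top_p, spec_metadata.min_p)
    return torch.argmax(logits, dim=-1), None  # greedy: no log-probs computed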

tensorrt_llm/llmapi/llm_args.py (1)

417-422: LGTM! Configuration improvements enhance usability.

The changes improve the MTPDecodingConfig class by:

  1. Making several fields optional with sensible conservative defaults
  2. Adding the new use_advanced_mtp_sampler flag to enable the advanced sampling feature
  3. Following consistent patterns with other boolean configuration flags

The default values are appropriate:

  • num_nextn_predict_layers=1 maintains backward compatibility
  • Boolean flags default to False for conservative behavior
  • relaxed_topk=1 and relaxed_delta=0. provide safe starting points

This provides a clean API where users can enable advanced sampling by simply setting use_advanced_mtp_sampler=True without having to specify all the other parameters.

tests/unittest/_torch/speculative/test_mtp.py (1)

333-401: LGTM! Good test coverage for advanced MTP sampler in greedy mode.

The test implementation correctly validates the advanced PyTorch sampler functionality with proper setup of sampling parameters to enforce greedy behavior. The deterministic seeding and reuse of existing test cases ensures consistency and reproducibility.

However, note that this test only covers greedy mode (temperature ≤ 0.01). Consider adding future tests for actual advanced sampling modes (temperature > 0.01) to validate the full functionality of the advanced sampler.

tensorrt_llm/_torch/pyexecutor/model_engine.py (5)

20-20: LGTM!

The import is necessary for accessing sampling configurations from request objects.


284-286: LGTM!

Clear and logical detection of advanced MTP sampler mode.


382-398: LGTM!

Appropriate CUDA tensor allocations for sampling parameters with correct sizes and data types.


1229-1234: LGTM!

Sampling parameters are correctly collected and replicated for each token position across different request types.

Also applies to: 1317-1326, 1356-1365, 1398-1407


1511-1526: LGTM!

Efficient non-blocking CUDA tensor copies and proper assignment to spec_metadata for advanced MTP sampling.

Also applies to: 1601-1607

tensorrt_llm/_torch/pyexecutor/sampler.py (8)

4-4: LGTM: Clean import addition

The additional typing imports are necessary for the new type annotations in the sampling functions.


154-167: LGTM: Well-implemented sampling function

The function correctly implements top-k and top-p filtering with efficient in-place operations. The use of custom random sampling to avoid CPU-GPU synchronization is a good performance optimization.


169-178: LGTM: Clever sampling implementation

This function uses the Gumbel-max trick effectively to avoid CPU-GPU synchronization. The mathematical approach is sound and the performance justification is clear.
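A minimal sketch of that trick (illustrative, not the PR's exact code): drawing from a categorical distribution by dividing probabilities by Exponential(1) noise and taking argmax, which stays on-device and avoids torch.multinomial's CPU-GPU sync.

import torch

def random_sample_sketch(probs: torch.Tensor) -> torch.Tensor:
    """probs: [batch, vocab] rows summing to 1; returns sampled token ids, shape [batch]."""
    q = torch.empty_like(probs).exponential_(1.0)  # i.i.d. Exponential(1) noise
    # argmax of probs / q is distributed as Categorical(probs) (Gumbel-max equivalence)
    return probs.div(q).argmax(dim=-1).view(-1)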


180-198: LGTM: Correct min-p implementation

The adaptive probability thresholding logic is correctly implemented, using the standard approach of scaling min_p by the maximum probability per sequence.


200-232: LGTM: Comprehensive top-k/top-p implementation

The function correctly implements both top-k and top-p filtering with proper handling of edge cases like the "at least one" guarantee. The sorting and scatter approach ensures correctness.
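For reference, a hedged sketch of the sort-and-scatter approach described here (an illustration of the technique, not the PR's exact implementation); k and p are per-row tensors, and filtered positions are pushed to -inf before sampling.

import torch

def apply_top_k_top_p_sketch(logits: torch.Tensor, k: torch.Tensor,
                             p: torch.Tensor) -> torch.Tensor:
    vocab = logits.size(-1)
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)

    # Top-k: with ascending order, keep only the last k entries of each row
    ranks = torch.arange(vocab, device=logits.device).expand_as(sorted_logits)
    drop_k = ranks < (vocab - k.unsqueeze(-1))
    sorted_logits.masked_fill_(drop_k, float("-inf"))

    # Top-p: drop the low-probability tail whose cumulative mass stays below 1 - p
    probs = sorted_logits.softmax(dim=-1)
    drop_p = probs.cumsum(dim=-1) <= (1.0 - p.unsqueeze(-1))
    drop_p[..., -1] = False  # "at least one" guarantee: never drop the best token
    sorted_logits.masked_fill_(drop_p, float("-inf"))

    # Scatter the masked values back into the original token order
    return torch.empty_like(logits).scatter_(dim=-1, index=sorted_idx, src=sorted_logits)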


234-236: LGTM: Simple and correct greedy sampling

Clean implementation using argmax with proper tensor reshaping.


238-244: LGTM: Efficient temperature scaling

Correct implementation with efficient in-place operations and proper broadcasting.
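A one-line sketch of that scaling (illustrative): per-row temperatures are broadcast across the vocab dimension and divided in place.

import torch

def apply_temperature_sketch(logits: torch.Tensor, temperatures: torch.Tensor) -> torch.Tensor:
    # temperatures: shape [batch]; assumes values are clamped away from zero upstream
    return logits.div_(temperatures.unsqueeze(dim=1))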


246-264: LGTM: Well-designed batch sampling function

This function effectively combines all sampling techniques with proper handling of greedy vs. random sampling. The temperature threshold logic and log-probability calculation are correctly implemented.

@nvxuanyuc nvxuanyuc requested review from amukkara and jhaotingc July 22, 2025 16:23
@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch 2 times, most recently from 4a68f67 to 84d09a0 Compare July 22, 2025 16:34
@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (3)
tests/unittest/_torch/speculative/test_mtp.py (1)

370-370: Fix line length violation.

The comment line exceeds the 120-character limit flagged by static analysis.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in 
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

264-266: Consider making the deterministic seed configurable.

The hardcoded seed value of 0 ensures consistent multi-GPU sampling, but consider making this configurable through the PyTorchConfig to provide flexibility for different use cases while maintaining the performance benefits of avoiding CPU-GPU synchronization.

-        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
-        # operations that avoid torch.multinomial's CPU-GPU sync overhead
-        torch.manual_seed(0)
+        # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
+        # operations that avoid torch.multinomial's CPU-GPU sync overhead
+        seed = getattr(pytorch_backend_config, 'sampling_seed', 0)
+        torch.manual_seed(seed)

1163-1194: LGTM: Well-designed helper functions with proper edge case handling.

The helper functions correctly extract sampling parameters with appropriate defaults and constraints. The temperature clamping to avoid numerical errors and top_k max value handling are particularly well thought out.

Consider extracting the magic numbers to constants:

+TEMPERATURE_MIN_THRESHOLD = 1e-2
+TEMPERATURE_MIN_VALUE = 0.01
+TOP_K_DISABLED_VALUE = 2147483647  # Max int32

 def get_request_temperature(request: LlmRequest) -> float:
     if not request.sampling_config.temperature:
         return 0.7
     temperature = request.sampling_config.temperature[0]
-    if 0 < temperature < 1e-2:
-        temperature = 0.01
+    if 0 < temperature < TEMPERATURE_MIN_THRESHOLD:
+        temperature = TEMPERATURE_MIN_VALUE
     return temperature
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cec9318 and 84d09a0.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.703Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

🚧 Files skipped from review as they are similar to previous changes (6)
  • tensorrt_llm/llmapi/tokenizer.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (8)
tests/unittest/_torch/speculative/test_mtp.py (2)

333-401: LGTM! Well-structured test for the advanced MTP sampler.

The new test method effectively validates that the advanced PyTorch sampler produces identical results to the standard sampler when configured for greedy mode. The test design is solid:

  • Proper parameterization reusing existing test cases
  • Deterministic seeding for reproducible results
  • Correct configuration of sampling parameters to enforce greedy mode (temperature ≤ 0.01)
  • Appropriate assertions matching the reference implementation

369-374: Greedy sampling parameters verified

Confirmed that in tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_temperature, any temperature below 0.01 is clamped up to 0.01. Therefore, extending temperatures with 0.01 correctly enforces the intended greedy sampling boundary. No changes required.

tensorrt_llm/_torch/pyexecutor/model_engine.py (6)

20-20: LGTM: Import addition is necessary for new functionality.

The LlmRequest import is properly placed and required for accessing sampling configuration in the advanced MTP sampler.


284-285: LGTM: Correct logic for advanced MTP sampler detection.

The boolean flag correctly identifies when the advanced MTP sampler should be active by checking all necessary conditions in the proper sequence.


382-398: LGTM: Proper CUDA tensor allocation for sampling parameters.

The tensor allocation correctly sizes buffers for batch_size × (max_draft_len + 1) elements, uses appropriate data types, and efficiently allocates only when the advanced sampler is enabled.


1157-1161: LGTM: Correct parameter replication for draft tokens.

The sampling parameter lists are properly initialized and populated with the correct replication pattern for each request type, ensuring parameters are available for both the main token and all draft tokens.

Also applies to: 1229-1233, 1318-1326, 1357-1365, 1399-1407


1512-1526: LGTM: Efficient CUDA tensor copying and metadata integration.

The implementation uses pinned memory and non-blocking copies for optimal performance, properly slices tensors to match actual usage, and cleanly integrates with the existing speculative decoding metadata structure.

Also applies to: 1602-1607
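A tiny sketch of the pinned-memory, non-blocking copy pattern described in this item (names and shapes are assumptions, not the PR's code):

import torch

def stage_and_copy_sketch(values: list[float], device_buf: torch.Tensor) -> None:
    """Copy per-token sampling params into a preallocated CUDA buffer without blocking."""
    host = torch.tensor(values, dtype=device_buf.dtype, pin_memory=True)  # pinned staging tensor
    device_buf[:host.numel()].copy_(host, non_blocking=True)              # async host-to-device copy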


264-266: Excellent implementation of advanced MTP sampler support.

The changes successfully add support for advanced sampling parameters in MTP speculative decoding with:

  • Proper memory management through pre-allocated CUDA tensors
  • Efficient parameter extraction and replication logic
  • Clean integration with existing speculative decoding infrastructure
  • Good performance considerations (pinned memory, non-blocking transfers)
  • Appropriate conditional activation based on configuration flags

The implementation follows established patterns in the codebase and maintains backward compatibility while adding the new functionality.

Also applies to: 284-285, 382-398, 1157-1607

@tensorrt-cicd
Collaborator

PR_Github #12593 [ run ] triggered by Bot

@jhaotingc jhaotingc requested a review from netanel-haber July 22, 2025 16:57
@jhaotingc
Collaborator

Hi @netanel-haber, @nvxuanyuc has addressed the comments in PR 5627; would you mind reviewing this PR again? Thanks.

@tensorrt-cicd
Collaborator

PR_Github #12593 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9371 completed with status: 'FAILURE'

@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 84d09a0 to 7f31555 Compare July 23, 2025 03:10
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)

370-370: Fix line length violation.

The line exceeds the 120-character limit enforced by the linter.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in 
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 84d09a0 and 7f31555.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

✅ Files skipped from review due to trivial changes (2)
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/llmapi/llm_args.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/llmapi/tokenizer.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)

333-401: LGTM! Well-structured test for the advanced MTP sampler.

The test correctly validates the advanced PyTorch sampler functionality by:

  • Using deterministic seeding for reproducible results
  • Properly configuring sampling parameters to enforce greedy mode (temperature ≤ 0.01)
  • Reusing existing test cases to ensure consistent behavior with the default sampler
  • Following the established test patterns in the codebase

The sampling parameter configuration looks correct for greedy mode testing.

@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 7f31555 to 607dbc5 Compare July 23, 2025 04:10
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)

370-370: Fix line length violation.

The line exceeds the 120 character limit. Consider breaking it into multiple lines for better readability.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7f31555 and 607dbc5.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

✅ Files skipped from review due to trivial changes (2)
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (5)
  • tensorrt_llm/_torch/speculative/mtp.py
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/llmapi/tokenizer.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tests/unittest/_torch/speculative/test_mtp.py (2)

333-401: Well-structured test for advanced MTP sampler in greedy mode.

The test method is properly implemented with correct parameterization, deterministic seed setting, and appropriate sampling parameter configuration to enforce greedy mode behavior. The test logic follows the established patterns and should effectively validate the advanced sampler functionality.


363-386: Sampling Parameter Threshold Confirmed

The model engine clamps any non-zero temperature below 0.01 up to 0.01 and treats temperatures ≤ 0.01 as greedy mode. Your test’s use of temperature = 0.01 correctly hits that boundary.

No changes required.

@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #12650 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12650 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9404 completed with status: 'FAILURE'

@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 607dbc5 to 2953667 Compare July 23, 2025 16:37
@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (2)
tests/unittest/_torch/speculative/test_mtp.py (2)

370-370: Fix line length violation.

The line exceeds the 120 character limit flagged by the linter.

-            # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+            # sampling default config vals set in 
+            # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]

388-401: Test execution looks correct but consider broader test coverage.

The test execution properly validates that the advanced sampler produces the same results as the original implementation in greedy mode, which is the expected behavior.

However, this test only covers greedy mode. Consider adding tests for the actual advanced sampling modes (temperature > 0.01, top-k < max_int, etc.) to fully validate the new functionality.

Would you like me to help generate additional test cases for non-greedy sampling modes to improve coverage of the advanced sampler functionality?

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 607dbc5 and 2953667.

📒 Files selected for processing (8)
  • examples/llm-api/quickstart_advanced.py (2 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (1 hunks)
  • tensorrt_llm/llmapi/tokenizer.py (1 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py

370-370: Line too long (123 > 120)

(E501)

✅ Files skipped from review due to trivial changes (1)
  • tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • tensorrt_llm/llmapi/llm_args.py
  • examples/llm-api/quickstart_advanced.py
  • tensorrt_llm/_torch/speculative/mtp.py
  • tensorrt_llm/llmapi/tokenizer.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tests/unittest/_torch/speculative/test_mtp.py (3)

333-340: LGTM! Good test structure and deterministic setup.

The test method signature follows the existing pattern and the deterministic seed ensures consistent behavior across runs, which is important for multi-GPU sampling scenarios.


342-346: Correct configuration for advanced sampler testing.

The test properly enables the advanced MTP sampler feature through the use_advanced_mtp_sampler=True flag, which is the key differentiator from the original test method.


363-387: Well-implemented parameter setup for greedy sampling mode.

The sampling parameters are correctly configured to enforce greedy behavior:

  • Temperature set to 0.01 (at the greedy boundary)
  • top_k set to max int value (no filtering)
  • top_p set to 1.0 (no filtering)
  • min_p set to 0.0 (no filtering)

The logic properly accounts for each batch's draft tokens plus one additional token.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

200-232: Top-k/top-p implementation is efficient but addresses past review concerns.

The implementation is mathematically correct and handles the complexity of combined top-k and top-p filtering. However, this directly relates to the past review comments about optimization when both k and p have default values.

From the past comments, I can see the author explained that device-side conditional execution isn't feasible with native torch ops and would break CUDA graph capture. The current approach of unconditional filtering is the correct design choice for maintaining performance and graph compatibility.

The sorting-based approach is necessary for top-p but can be expensive for large vocabularies. The current implementation is optimal given the constraints.

🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)

154-167: Function implementation is correct but consider parameter validation.

The forward_native function correctly combines top-k/top-p filtering with random sampling. The in-place modification warning is helpful.

However, consider adding basic parameter validation to ensure robustness:

def forward_native(
    logits: torch.Tensor,
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    PyTorch-native implementation of top-k and top-p sampling.

    The logits tensor may be updated in-place.
    """
+   if logits.dim() != 2:
+       raise ValueError(f"Expected 2D logits tensor, got {logits.dim()}D")
    logits = apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9c4da6b and 0ee79b3.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/model_config.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (10 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (1 hunks)
  • tensorrt_llm/_torch/speculative/mtp.py (5 hunks)
  • tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
  • tensorrt_llm/_torch/model_config.py
  • tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/unittest/_torch/speculative/test_mtp.py
  • tensorrt_llm/_torch/speculative/mtp.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)

Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Learnt from: moraxu
PR: #6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Learnt from: yiqingy0
PR: #5198
File: jenkins/mergeWaiveList.py:0-0
Timestamp: 2025-07-22T08:33:49.109Z
Learning: In the TensorRT-LLM waive list merging system, removed lines are always located at the end of the merge waive lists, which is why the mergeWaiveList.py script uses reverse traversal - it's an optimization for this specific domain constraint.

Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.

🔇 Additional comments (16)
tensorrt_llm/_torch/pyexecutor/sampler.py (5)

4-4: Import addition looks good.

The Optional import is correctly added and used throughout the new sampling functions.


169-178: Excellent CPU-GPU synchronization avoidance technique.

The random_sample function cleverly uses the Gumbel-max trick (exponential noise + argmax) to avoid torch.multinomial's CPU-GPU synchronization. This is a well-known technique for maintaining CUDA graph compatibility.

The implementation is mathematically sound and performance-optimized.


180-198: Min-p implementation is mathematically correct.

The adaptive probability thresholding implementation correctly:

  1. Converts logits to probabilities
  2. Finds maximum probability per sequence
  3. Scales min_p threshold by the maximum probability
  4. Masks tokens below the adaptive threshold

The logic aligns with the min-p sampling strategy used in modern language models.
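A compact sketch of those four steps (illustrative, not the PR's exact code):

import torch

def apply_min_p_sketch(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    probs = torch.softmax(logits, dim=-1)                  # 1. logits -> probabilities
    max_probs = probs.max(dim=-1, keepdim=True).values     # 2. per-row peak probability
    threshold = min_p.unsqueeze(-1) * max_probs            # 3. adaptive cutoff per row
    return logits.masked_fill(probs < threshold, float("-inf"))  # 4. mask the tail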


234-236: Greedy sampling implementation is correct and efficient.

Simple argmax implementation that correctly flattens the output to 1D as expected by the sampling interface.


238-244: Temperature scaling is efficient with in-place operation.

The in-place division correctly scales logits by temperature while avoiding unnecessary tensor allocation. This is optimal for performance.

tensorrt_llm/_torch/pyexecutor/model_engine.py (11)

20-20: LGTM!

The import of LlmRequest is necessary for accessing sampling configuration parameters in the new helper function.


282-283: LGTM!

The initialization logic correctly combines the three required conditions for enabling advanced MTP sampling: speculative decoding, MTP mode, and mixed sampler configuration.


301-301: LGTM!

The enable_mixed_sampler parameter is correctly passed to the model loading function, maintaining consistency with other configuration parameters.


1179-1223: LGTM!

The helper function is well-implemented with several good practices:

  • Temperature clamping to 0.01 prevents numerical instability
  • Uses torch.iinfo(torch.int32).max instead of magic numbers for disabled top_k
  • Proper handling of None/empty sampling config values
  • Clear separation of concerns with individual parameter extraction functions
  • Correct extension of parameters for draft tokens + main token
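As an illustration of the helper style reviewed here, a hedged sketch of a top-k extractor in the same spirit as the get_request_temperature helper quoted earlier (the exact signature and defaults are assumptions):

import torch

def get_request_top_k_sketch(request) -> int:
    top_k = request.sampling_config.top_k[0] if request.sampling_config.top_k else 0
    if top_k <= 0:
        top_k = torch.iinfo(torch.int32).max  # sentinel value that disables top-k filtering
    return top_k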

1173-1178: LGTM!

The sampling parameter lists are properly initialized when advanced MTP sampling is enabled.


1257-1258: LGTM!

Sampling parameters are correctly collected for context requests with no draft tokens.


1343-1344: LGTM!

Sampling parameters are correctly collected for extend requests, passing the appropriate draft token length.


1375-1376: LGTM!

Sampling parameters are correctly collected for requests with previous batch data, using the maximum draft length.


1410-1411: LGTM!

Sampling parameters are correctly collected for generation requests, using the maximum draft length for consistency.


1606-1608: LGTM!

The spec_metadata is correctly updated with the collected sampling parameters when advanced MTP sampling is enabled.


1159-1161: LGTM!

The advanced MTP sampling setup is correctly invoked during the forward pass with appropriate batch size and draft length parameters.

Comment on lines 246 to 264
def sampling_batch(logits: torch.Tensor, temperatures: torch.Tensor,
                   top_k: torch.Tensor, top_p: torch.Tensor,
                   min_p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    raw_probs = torch.softmax(logits, dim=-1)
    greedy_sampled = greedy_sample(logits)
    logits = apply_temperature(logits, temperatures)
    logits = apply_min_p(logits, min_p)
    random_sampled = forward_native(logits, top_k, top_p)
    next_tokens = torch.where(
        temperatures <= 1e-2,  # Match the clamping threshold
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,  # Reuse tensor
    )
    token_probs = torch.gather(raw_probs, dim=1,
                               index=next_tokens.unsqueeze(1)).squeeze(-1)
    log_probs = torch.log(token_probs)
    return next_tokens, log_probs

Contributor

🛠️ Refactor suggestion

Unified sampling function is well-designed but has a potential numerical stability issue.

The sampling_batch function effectively combines all sampling strategies with a temperature threshold for greedy vs. random sampling. The design is sound, but there's a potential numerical issue:

The log probability calculation at Line 262 could produce NaN values if token_probs contains zeros due to precision issues:

    token_probs = torch.gather(raw_probs, dim=1,
                               index=next_tokens.unsqueeze(1)).squeeze(-1)
-   log_probs = torch.log(token_probs)
+   log_probs = torch.log(token_probs.clamp(min=1e-8))

This prevents log(0) = -inf issues that could propagate through the system.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def sampling_batch(logits: torch.Tensor, temperatures: torch.Tensor,
                   top_k: torch.Tensor, top_p: torch.Tensor,
                   min_p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    raw_probs = torch.softmax(logits, dim=-1)
    greedy_sampled = greedy_sample(logits)
    logits = apply_temperature(logits, temperatures)
    logits = apply_min_p(logits, min_p)
    random_sampled = forward_native(logits, top_k, top_p)
    next_tokens = torch.where(
        temperatures <= 1e-2,  # Match the clamping threshold
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,  # Reuse tensor
    )
    token_probs = torch.gather(
        raw_probs,
        dim=1,
        index=next_tokens.unsqueeze(1),
    ).squeeze(-1)
    log_probs = torch.log(token_probs.clamp(min=1e-8))
    return next_tokens, log_probs
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/sampler.py around lines 246 to 264, the log
probability calculation uses torch.log on token_probs which may contain zeros,
causing NaN values. To fix this, clamp token_probs to a small positive value
(e.g., 1e-10) before applying torch.log to avoid log(0) and ensure numerical
stability.

@nvxuanyuc
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #13420 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #13420 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10039 completed with status: 'FAILURE'

@nvxuanyuc
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #13502 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #13502 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10115 completed with status: 'SUCCESS'

# Default to greedy mode. If true, use advanced pytorch sampling strategy.
self.enable_mixed_sampler = False
if self.model_config is not None:
    self.enable_mixed_sampler = self.model_config.enable_mixed_sampler
Collaborator

@ixlmar ixlmar Jul 30, 2025

Nitpick: This could be a @property rather than a copy, to avoid potential consistency issues in the future.

@nvxuanyuc nvxuanyuc removed the Community want to contribute PRs initiated from Community label Oct 16, 2025
@nvxuanyuc nvxuanyuc force-pushed the deepinfra_mtp_torch_sampler branch from 0ee79b3 to bf7910b Compare October 20, 2025 20:28
@nvxuanyuc nvxuanyuc requested review from a team as code owners October 20, 2025 20:28
@nvxuanyuc nvxuanyuc changed the title [TRTLLM-5627] feat: Implement pytorch sampler for MTP [None][feat] Implement advanced sampling for one model path mtp/eagle Oct 20, 2025
@IzzyPutterman
Collaborator

/bot run

@tensorrt-cicd
Collaborator

PR_Github #22067 [ run ] triggered by Bot. Commit: 81e29f4

Signed-off-by: Izzy Putterman <[email protected]>
Filters logits using adaptive probability thresholding.
"""
# Convert logits to probability distribution
probability_values = torch.nn.functional.softmax(logits, dim=-1)
Collaborator

This effectively neutralizes the temperature, right? We apply the temperature and then softmax again, which undoes the scaling.

Collaborator

Or perhaps in sampling_batch_spec_dec_one_model, we should remove the first softmax and put one just before the sort in apply_top_k_top_p.

Collaborator

Never mind, I'm wrong here; I misread something.

The logits tensor may be updated in-place.
"""
logits = apply_top_k_top_p(logits, k, p)
probs = logits.softmax(dim=-1, dtype=torch.float32)
Collaborator

I think we can technically skip this softmax

Collaborator

I think I am also wrong here

Collaborator

I think @IzzyPutterman is right: apply_min_p evaluates the softmax of the temperature-scaled logits and uses that to mask out some of the logits (set to -inf). The probs could be masked in the same way (set to 0). The resulting probs can (see the paragraph below) then be reused in apply_top_k_top_p, which masks out more logits/probs.

Every time logits/probs are masked, it is sufficient to renormalize the probs such that they sum to one, which is much cheaper than computing softmax. This is probably also why https://docs.flashinfer.ai/api/sampling.html uses function names like ..._renorm_probs.

Note that much of this is already worked out in #8581, albeit using flashinfer.sampling.
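A sketch of the renormalize-instead-of-resoftmax idea described above (illustrative only, not code from this PR or #8581): once probabilities exist, each subsequent masking step only needs a multiply and a division by the new row sum.

```python
import torch

def mask_and_renorm(probs: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    """Zero out masked-off tokens and renormalize each row, avoiding another softmax.

    probs: [batch, vocab] probabilities summing to 1 per row.
    keep:  [batch, vocab] boolean mask of tokens to keep (at least one True per row).
    """
    probs = probs * keep
    return probs / probs.sum(dim=-1, keepdim=True)
```

The same probs tensor can then be passed through the min-p mask and the top-k/top-p mask in turn without recomputing the softmax.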

@tensorrt-cicd
Collaborator

PR_Github #22067 [ run ] completed with state SUCCESS. Commit: 81e29f4
/LLM/main/L0_MergeRequest_PR pipeline #16640 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
If a top-p is used, this function will sort the logits tensor,
Collaborator

As a perf optimization, should we skip the expensive sorting / softmax / cumsum ops for top_p >=1?

Collaborator

If top_p is 1, we can skip the expensive sorting / softmax / cumsum ops.

In the latest TensorRT-LLM version, this is already implemented. Please refer to https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/sampling_utils.py#L159-L171.

Collaborator

Skipping is not possible here: in regular decoding the sampling is not captured in a CUDA graph, but this part is. So unless there is a kernel that determines whether to skip (like checkAllTopP in cpp/kernels/samplingTopPKernel.cu), there is no way to branch on the CPU-side need_top_p flag.
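A sketch of why no host-side branch is needed in that case (illustrative; assumes per-request top_p already lives in a GPU tensor): with the conventional cumulative-probability mask, rows with top_p == 1.0 simply keep every token, so the same captured kernel sequence serves all requests.

```python
import torch

def top_p_keep_mask(sorted_probs: torch.Tensor, top_p: torch.Tensor) -> torch.Tensor:
    """Keep-mask over tokens sorted by descending probability; graph-friendly (no CPU flag).

    sorted_probs: [batch, vocab], each row sorted descending and summing to 1.
    top_p:        [batch] nucleus thresholds; 1.0 degenerates to keep-all.
    """
    cum_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    keep = cum_before < top_p.unsqueeze(-1)
    keep[..., 0] = True  # always keep the most likely token
    return keep
```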

Collaborator

@ixlmar ixlmar left a comment

Partially reviewed

return logits


def apply_top_k_top_p(
Collaborator

Note that we already have

def top_k_top_p_sampling_batch(

and some related functions. We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the Tensor-type k, p, temperature, etc.).

Also note that I am working on FlashInfer.sampling based alternatives for those functions. This upcoming PR brings support for Tensor-type k, p, temperature, etc. when FlashInfer is used. If you integrate the improvements made here for the non-FlashInfer case, this could give a quite nice feature set.

Ideally, this PR could (i) improve the existing sampling routines and (ii) use them via

class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):

Collaborator

Put up #8581 (work in progress!) to give an idea of what to expect.

Collaborator

See also TRTLLM-7723 (and TRTLLM-7152) for scope of ongoing work.

Collaborator

I think I tested FlashInfer with CUDA graphs and it was breaking a bunch. With the generator objects it's quite annoying in TRT-LLM, because in warmup we alternate between CUDA graph warmup and non-CUDA-graph warmup, which breaks.

Collaborator

Worth a double check, of course; perhaps there is an easy way around it.

Collaborator

@ixlmar I think the current implementation of top-k/top-p only allows all requests to share the same top-k/top-p values, rather than letting individual requests use different values; please correct me if I'm wrong.

The current logic in model_engine.py does not parse all the sampling params into GPU tensors for CUDA graphs; this PR enables that.

Collaborator

@IzzyPutterman The idea of #8581 is to allow choosing between the sampling routines we have today in sampling_utils.py and those provided by FlashInfer. Both will be available as implementations of GroupedStrategySampler. SimpleGroupedStrategySampler uses the former sampling routines (non FlashInfer) and is already available in main.

Collaborator

@jhaotingc Correct. This was what I meant in the first comment:

We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the Tensor-type k, p, temperature, etc.).

Ideally, this PR could extend SimpleGroupedStrategySampler to allow for Tensor-type k, p, temperature, etc., in the same way as FlashInferGroupedStrategySampler does it for FlashInfer.sampling in #8581. If the GroupedStrategySampler abstraction is not viable (e.g. because the host data structures interfere with CUDA graphs), then I think we should extend top_k_top_p_sampling_batch (and the related functions) directly (promote scalar arguments to accept any broadcastable torch.Tensor) and reuse them here.
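A sketch of what promoting scalar k and p to broadcastable tensors could look like for the masking step (illustrative only; not the existing top_k_top_p_sampling_batch implementation, and names/shapes are assumptions):

```python
import torch

def top_k_top_p_logits(logits: torch.Tensor, k: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Mask logits with per-request top-k and top-p (k: [batch] int, assumed >= 1; p: [batch] float)."""
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
    ranks = torch.arange(logits.size(-1), device=logits.device)
    topk_drop = ranks.unsqueeze(0) >= k.unsqueeze(-1)        # rank-based per-request top-k
    sorted_probs = sorted_logits.softmax(dim=-1)
    cum_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    topp_drop = cum_before >= p.unsqueeze(-1)                # nucleus mask per request
    topp_drop[..., 0] = False                                # always keep the top token
    sorted_logits = sorted_logits.masked_fill(topk_drop | topp_drop, float("-inf"))
    # Scatter the masked logits back to the original token order.
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
```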

@jhaotingc
Collaborator

@ixlmar @IzzyPutterman would you mind sharing the latest plan regarding this PR? Is the plan to merge this in and then add more performant kernels?

@ixlmar we can give you access to this branch if you'd like to develop on top of it.

@ixlmar
Collaborator

ixlmar commented Oct 23, 2025

@ixlmar @IzzyPutterman would you mind sharing the latest plan regarding this PR? Is the plan to merge this in and then add more performant kernels?

@ixlmar we can give you access to this branch if you'd like to develop on top of it.

@jhaotingc Thanks. As far as I can tell, the code that I think we should be extending and reusing (top_k_top_p_sampling_batch) is already in main and my ongoing and planned work is so far orthogonal to what is happening here. My main concern is to improve the existing sampling code rather than introducing the separate but functionally overlapping apply_top_k_top_p from this PR, which I fear will make things harder to maintain in the long run.

min_p = []

# advanced mtp sampling's request preprocessing helper functions
def collect_req_spec_dec_sampling_params(request: LlmRequest,
Collaborator

The resolution of request.sampling_config to sampling strategy has been cleaned up in #8132. See PR description for the intended semantics. The relevant function is

def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:

The existing function covers various corner cases already (e.g. temperature=0, top_p=1, etc.) and has extensive unit tests. Consider reusing this function here (perhaps make it "public", i.e., rename to something that does not start with _).
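For readers unfamiliar with that helper, a hypothetical sketch of the kind of corner-case handling such a resolver performs (the real Strategy type and rules in sampling_utils.py differ; this is only to illustrate why reusing it avoids re-deriving these cases):

```python
# Hypothetical resolver for illustration only; not the actual _request_strategy.
def resolve_strategy(temperature: float | None, top_k: int | None, top_p: float | None):
    if not temperature:                              # temperature 0/None degenerates to greedy
        return ("greedy", None)
    use_top_p = top_p is not None and top_p < 1.0    # top_p == 1 adds nothing
    if top_k and use_top_p:
        return ("top_k_top_p", (top_k, top_p, temperature))
    if top_k:
        return ("top_k", (top_k, temperature))
    if use_top_p:
        return ("top_p", (top_p, temperature))
    return ("temperature", (temperature,))
```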

"""
q = torch.empty_like(probs)
q.exponential_()
return probs.div_(q).argmax(dim=-1).view(-1)
Collaborator

I have to admit that I am not familiar with this sampling scheme. If you happen to have a literature reference at hand, I would be curious to learn more (perhaps also include a comment stating the name of the method).

BTW, TorchSampler is using torch.multinomial and I did not notice any stream syncs. Code:

next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)

Collaborator

(Disclaimer: I might well have overlooked that torch.multinomial is syncing, so I would be curious to hear more on this.)
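On the sampling scheme itself: dividing the probabilities by i.i.d. Exponential(1) noise and taking the argmax is the exponential-race form of the Gumbel-max trick (taking -log of the Exp(1) noise gives Gumbel noise), and it draws exactly from the categorical distribution given by probs. A small self-contained check with assumed names, comparing it against torch.multinomial:

```python
import torch

torch.manual_seed(0)
probs = torch.tensor([[0.1, 0.2, 0.7]]).expand(100_000, -1)

# Exponential-race sampling: argmax(p_i / E_i), E_i ~ Exp(1), picks index i
# with probability p_i, matching torch.multinomial in distribution.
noise = torch.empty_like(probs).exponential_()
race = (probs / noise).argmax(dim=-1)
multi = torch.multinomial(probs, num_samples=1).squeeze(-1)

for name, s in (("race", race), ("multinomial", multi)):
    print(name, torch.bincount(s, minlength=3) / s.numel())  # both ~[0.1, 0.2, 0.7]
```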

pin_memory=True)
self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)

def _set_up_advanced_sampling(self, batch_size: int, max_draft_len: int):
Collaborator

The code changes in mtp.py and eagle3.py look very similar; perhaps something could be reused.
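As a purely hypothetical sketch of that reuse (names and buffer set are assumptions, not the PR's actual fields), both drafters could call a shared helper instead of duplicating the buffer setup:

```python
import torch

def set_up_advanced_sampling_buffers(num_tokens: int, device: str = "cuda") -> dict[str, torch.Tensor]:
    """Allocate per-token sampling-parameter buffers once; an MTP or Eagle3
    one-model worker could hold the returned dict instead of duplicating this setup."""
    return {
        "temperature": torch.ones(num_tokens, device=device),
        "top_k": torch.zeros(num_tokens, dtype=torch.int32, device=device),
        "top_p": torch.ones(num_tokens, device=device),
        "min_p": torch.zeros(num_tokens, device=device),
    }
```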
