[None][feat] Implement advanced sampling for one model path mtp/eagle #6245
base: main
Conversation
📝 Walkthrough

This change introduces advanced multi-token prediction (MTP) sampling support in the PyTorch execution engine, enabling per-token sampling parameter control for speculative decoding. It adds new fields and methods to propagate and utilize sampling parameters (temperature, top-k, top-p, min-p) throughout the model engine, speculative metadata, and MTP worker. A new batch sampling function and corresponding tests are also included.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant User
    participant PyTorchModelEngine
    participant MTPWorker
    participant SpecMetadata
    participant Sampler
    User->>PyTorchModelEngine: forward(requests, ...)
    PyTorchModelEngine->>PyTorchModelEngine: _prepare_tp_inputs()
    PyTorchModelEngine->>SpecMetadata: update_advanced_mtp_sampling_params(...)
    PyTorchModelEngine->>SpecMetadata: _set_up_advanced_mtp_sampling(...)
    PyTorchModelEngine->>MTPWorker: sample_and_accept_draft_tokens(input_ids, logits, spec_metadata, ...)
    alt enable_mixed_sampler
        MTPWorker->>Sampler: sampling_batch(logits, temperatures, top_k, top_p, min_p)
        Sampler-->>MTPWorker: sampled_tokens, log_probs
    else
        MTPWorker->>MTPWorker: greedy_sample(logits)
    end
    MTPWorker-->>PyTorchModelEngine: accepted_tokens
    PyTorchModelEngine-->>User: output
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes

Suggested reviewers
Actionable comments posted: 1
🧹 Nitpick comments (3)
tests/unittest/_torch/speculative/test_mtp.py (1)
370-370: Fix line length violation. The line exceeds the 120-character limit as flagged by static analysis.
```diff
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
```

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
264-267: Verify the global impact of setting torch.manual_seed(0). Setting a global PyTorch seed in the constructor could have unintended side effects on other operations. Consider:
- This affects all PyTorch random operations, not just sampling
- It might interfere with user-controlled randomness
- Consider using a local generator instead of global seed
Consider using a dedicated random generator for sampling operations:
```diff
- # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
- # operations that avoid torch.multinomial's CPU-GPU sync overhead
- torch.manual_seed(0)
+ # Create dedicated generator for consistent multi-GPU sampling
+ # to avoid torch.multinomial's CPU-GPU sync overhead
+ self.sampling_generator = torch.Generator(device='cuda')
+ self.sampling_generator.manual_seed(0)
```

Then pass this generator to sampling operations that need deterministic behavior.
1163-1195: Consider moving helper functions to class level for better organization. These helper functions are defined inside `_prepare_tp_inputs` but could be reused elsewhere. Consider making them class methods or static methods.

Move these functions to class level:
```diff
-def get_request_temperature(request: LlmRequest) -> float:
-    if not request.sampling_config.temperature:
-        return 0.7
-    temperature = request.sampling_config.temperature[0]
-    if 0 < temperature < 1e-2:
-        # temperature less than 0.01 may cause numerical errors
-        temperature = 0.01
-    return temperature
+@staticmethod
+def _get_request_temperature(request: LlmRequest) -> float:
+    if not request.sampling_config.temperature:
+        return 0.7
+    temperature = request.sampling_config.temperature[0]
+    if 0 < temperature < 1e-2:
+        # temperature less than 0.01 may cause numerical errors
+        temperature = 0.01
+    return temperature
```

Apply similar changes to the other helper functions and update the call sites accordingly.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
🔇 Additional comments (20)
tensorrt_llm/_torch/speculative/interface.py (1)
135-142: LGTM! Clean addition of sampling parameter fields. The new optional tensor fields for sampling parameters (`temperatures`, `top_k`, `top_p`, `min_p`) are well-structured and follow the existing pattern in the `SpecMetadata` dataclass. The type annotations and comments are clear and appropriate.

examples/llm-api/quickstart_advanced.py (2)
115-117: LGTM! Clean addition of command-line argument. The new `--use_advanced_mtp_sampler` flag follows the established pattern for boolean command-line arguments with an appropriate default value.
169-170: LGTM! Proper integration of the new flag. The `use_advanced_mtp_sampler` parameter is correctly passed to the `MTPDecodingConfig` constructor, maintaining consistency with the command-line argument.

tensorrt_llm/_torch/speculative/mtp.py (2)
11-11: LGTM! Appropriate import addition. The import of the `sampling_batch` function is correctly added to support the advanced MTP sampler functionality.
825-833: LGTM! Well-structured conditional sampling logic. The implementation demonstrates good practices:
- Backward compatibility: Maintains the existing greedy sampling as the default behavior
- Clear conditional logic: The flag-based switching is easy to understand and maintain
- Future-proofing: Acknowledges the unused `target_log_probs` for future log probability support
- Clean integration: The advanced sampler integrates seamlessly with the existing acceptance algorithm
The approach minimizes risk while enabling the new advanced sampling functionality.
tensorrt_llm/llmapi/llm_args.py (1)
417-422: LGTM! Configuration improvements enhance usability. The changes improve the `MTPDecodingConfig` class by:
- Making several fields optional with sensible conservative defaults
- Adding the new `use_advanced_mtp_sampler` flag to enable the advanced sampling feature
- Following consistent patterns with other boolean configuration flags
The default values are appropriate:
- `num_nextn_predict_layers=1` maintains backward compatibility
- Boolean flags default to `False` for conservative behavior
- `relaxed_topk=1` and `relaxed_delta=0.` provide safe starting points

This provides a clean API where users can enable advanced sampling by simply setting `use_advanced_mtp_sampler=True` without having to specify all the other parameters.

tests/unittest/_torch/speculative/test_mtp.py (1)
333-401: LGTM! Good test coverage for advanced MTP sampler in greedy mode. The test implementation correctly validates the advanced PyTorch sampler functionality with proper setup of sampling parameters to enforce greedy behavior. The deterministic seeding and reuse of existing test cases ensures consistency and reproducibility.
However, note that this test only covers greedy mode (temperature ≤ 0.01). Consider adding future tests for actual advanced sampling modes (temperature > 0.01) to validate the full functionality of the advanced sampler.
tensorrt_llm/_torch/pyexecutor/model_engine.py (5)
20-20: LGTM! The import is necessary for accessing sampling configurations from request objects.

284-286: LGTM! Clear and logical detection of advanced MTP sampler mode.

382-398: LGTM! Appropriate CUDA tensor allocations for sampling parameters with correct sizes and data types.

1229-1234: LGTM! Sampling parameters are correctly collected and replicated for each token position across different request types.
Also applies to: 1317-1326, 1356-1365, 1398-1407
1511-1526: LGTM! Efficient non-blocking CUDA tensor copies and proper assignment to spec_metadata for advanced MTP sampling.
Also applies to: 1601-1607
tensorrt_llm/_torch/pyexecutor/sampler.py (8)
4-4: LGTM: Clean import addition. The additional typing imports are necessary for the new type annotations in the sampling functions.

154-167: LGTM: Well-implemented sampling function. The function correctly implements top-k and top-p filtering with efficient in-place operations. The use of custom random sampling to avoid CPU-GPU synchronization is a good performance optimization.

169-178: LGTM: Clever sampling implementation. This function uses the Gumbel-max trick effectively to avoid CPU-GPU synchronization. The mathematical approach is sound and the performance justification is clear.

180-198: LGTM: Correct min-p implementation. The adaptive probability thresholding logic is correctly implemented, using the standard approach of scaling min_p by the maximum probability per sequence.

200-232: LGTM: Comprehensive top-k/top-p implementation. The function correctly implements both top-k and top-p filtering with proper handling of edge cases like the "at least one" guarantee. The sorting and scatter approach ensures correctness.
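For readers unfamiliar with the sort-and-scatter approach mentioned here, the sketch below illustrates batched top-k/top-p masking; the function and tensor names are illustrative assumptions, not the PR's exact code:

```python
import torch


def mask_top_k_top_p(logits: torch.Tensor, k: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
    """Mask logits outside the per-row top-k set and top-p nucleus with -inf.

    logits: [batch, vocab]; k: [batch] int (>= 1); p: [batch] float in (0, 1].
    """
    sorted_logits, sorted_idx = logits.sort(dim=-1, descending=False)  # ascending order
    vocab = logits.size(-1)

    # Top-k: drop everything below the k-th largest value in each row.
    ranks = torch.arange(vocab, device=logits.device).expand_as(sorted_logits)
    top_k_mask = ranks < (vocab - k.unsqueeze(-1))  # True = outside top-k
    sorted_logits = sorted_logits.masked_fill(top_k_mask, float("-inf"))

    # Top-p: keep the smallest suffix (largest values) whose probability mass >= p.
    probs = sorted_logits.softmax(dim=-1)
    cum_probs = probs.cumsum(dim=-1)
    top_p_mask = cum_probs <= (1.0 - p.unsqueeze(-1))  # True = outside the nucleus
    top_p_mask[..., -1] = False  # "at least one" guarantee: never drop the argmax
    sorted_logits = sorted_logits.masked_fill(top_p_mask, float("-inf"))

    # Scatter the masked values back to the original vocabulary order.
    return torch.full_like(logits, float("-inf")).scatter(-1, sorted_idx, sorted_logits)
```

Note that a "disabled" top-k value such as torch.iinfo(torch.int32).max naturally becomes a no-op here, because the cutoff index goes negative and nothing is masked.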
234-236: LGTM: Simple and correct greedy sampling. Clean implementation using argmax with proper tensor reshaping.

238-244: LGTM: Efficient temperature scaling. Correct implementation with efficient in-place operations and proper broadcasting.

246-264: LGTM: Well-designed batch sampling function. This function effectively combines all sampling techniques with proper handling of greedy vs. random sampling. The temperature threshold logic and log-probability calculation are correctly implemented.
4a68f67 to 84d09a0
/bot run --disable-fail-fast
Actionable comments posted: 0
🧹 Nitpick comments (3)
tests/unittest/_torch/speculative/test_mtp.py (1)
370-370: Fix line length violation. The comment line exceeds the 120-character limit flagged by static analysis.
```diff
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
```

tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
264-266: Consider making the deterministic seed configurable. The hardcoded seed value of 0 ensures consistent multi-GPU sampling, but consider making this configurable through the PyTorchConfig to provide flexibility for different use cases while maintaining the performance benefits of avoiding CPU-GPU synchronization.
```diff
- # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
- # operations that avoid torch.multinomial's CPU-GPU sync overhead
- torch.manual_seed(0)
+ # Set deterministic seed for consistent multi-GPU sampling using PyTorch RNG
+ # operations that avoid torch.multinomial's CPU-GPU sync overhead
+ seed = getattr(pytorch_backend_config, 'sampling_seed', 0)
+ torch.manual_seed(seed)
```
1163-1194: LGTM: Well-designed helper functions with proper edge case handling. The helper functions correctly extract sampling parameters with appropriate defaults and constraints. The temperature clamping to avoid numerical errors and top_k max value handling are particularly well thought out.
Consider extracting the magic numbers to constants:
```diff
+TEMPERATURE_MIN_THRESHOLD = 1e-2
+TEMPERATURE_MIN_VALUE = 0.01
+TOP_K_DISABLED_VALUE = 2147483647  # Max int32

 def get_request_temperature(request: LlmRequest) -> float:
     if not request.sampling_config.temperature:
         return 0.7
     temperature = request.sampling_config.temperature[0]
-    if 0 < temperature < 1e-2:
-        temperature = 0.01
+    if 0 < temperature < TEMPERATURE_MIN_THRESHOLD:
+        temperature = TEMPERATURE_MIN_VALUE
     return temperature
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.703Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.374Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
🚧 Files skipped from review as they are similar to previous changes (6)
- tensorrt_llm/llmapi/tokenizer.py
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/_torch/speculative/mtp.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (8)
tests/unittest/_torch/speculative/test_mtp.py (2)
333-401: LGTM! Well-structured test for the advanced MTP sampler. The new test method effectively validates that the advanced PyTorch sampler produces identical results to the standard sampler when configured for greedy mode. The test design is solid:
- Proper parameterization reusing existing test cases
- Deterministic seeding for reproducible results
- Correct configuration of sampling parameters to enforce greedy mode (temperature ≤ 0.01)
- Appropriate assertions matching the reference implementation
369-374: Greedy sampling parameters verified. Confirmed that in `tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_temperature`, any temperature below 0.01 is clamped up to 0.01. Therefore, extending `temperatures` with 0.01 correctly enforces the intended greedy sampling boundary. No changes required.

tensorrt_llm/_torch/pyexecutor/model_engine.py (6)
20-20: LGTM: Import addition is necessary for new functionality. The `LlmRequest` import is properly placed and required for accessing sampling configuration in the advanced MTP sampler.

284-285: LGTM: Correct logic for advanced MTP sampler detection. The boolean flag correctly identifies when the advanced MTP sampler should be active by checking all necessary conditions in the proper sequence.

382-398: LGTM: Proper CUDA tensor allocation for sampling parameters. The tensor allocation correctly sizes buffers for batch_size × (max_draft_len + 1) elements, uses appropriate data types, and efficiently allocates only when the advanced sampler is enabled.
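As a rough illustration of the allocation-plus-staging pattern described here and in the copy comment below, assuming hypothetical buffer names and sizes rather than the PR's exact code:

```python
import torch

batch_size, max_draft_len = 32, 3
num_tokens = batch_size * (max_draft_len + 1)

# Persistent device buffers, sized once for the worst case so they can be
# reused across iterations (and safely captured by CUDA graphs).
temperatures_cuda = torch.empty(num_tokens, dtype=torch.float32, device="cuda")
top_k_cuda = torch.empty(num_tokens, dtype=torch.int32, device="cuda")

# Per-iteration host values are staged in pinned memory and copied without
# forcing a synchronization on the main stream.
temperatures_host = torch.full((num_tokens,), 0.7, dtype=torch.float32).pin_memory()
temperatures_cuda.copy_(temperatures_host, non_blocking=True)
```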
1157-1161: LGTM: Correct parameter replication for draft tokens. The sampling parameter lists are properly initialized and populated with the correct replication pattern for each request type, ensuring parameters are available for both the main token and all draft tokens.
Also applies to: 1229-1233, 1318-1326, 1357-1365, 1399-1407
1512-1526: LGTM: Efficient CUDA tensor copying and metadata integration. The implementation uses pinned memory and non-blocking copies for optimal performance, properly slices tensors to match actual usage, and cleanly integrates with the existing speculative decoding metadata structure.
Also applies to: 1602-1607
264-266: Excellent implementation of advanced MTP sampler support. The changes successfully add support for advanced sampling parameters in MTP speculative decoding with:
- Proper memory management through pre-allocated CUDA tensors
- Efficient parameter extraction and replication logic
- Clean integration with existing speculative decoding infrastructure
- Good performance considerations (pinned memory, non-blocking transfers)
- Appropriate conditional activation based on configuration flags
The implementation follows established patterns in the codebase and maintains backward compatibility while adding the new functionality.
Also applies to: 284-285, 382-398, 1157-1607
PR_Github #12593 [ run ] triggered by Bot
Hi @netanel-haber, @nvxuanyuc has addressed the comments in PR 5627, would you mind reviewing this PR again? Thanks.
PR_Github #12593 [ run ] completed with state
84d09a0 to 7f31555
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)
370-370: Fix line length violation. The line exceeds the 120-character limit enforced by the linter.
```diff
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
✅ Files skipped from review due to trivial changes (2)
- tensorrt_llm/_torch/speculative/interface.py
- tensorrt_llm/llmapi/llm_args.py
🚧 Files skipped from review as they are similar to previous changes (5)
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/llmapi/tokenizer.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)
333-401: LGTM! Well-structured test for the advanced MTP sampler. The test correctly validates the advanced PyTorch sampler functionality by:
- Using deterministic seeding for reproducible results
- Properly configuring sampling parameters to enforce greedy mode (temperature ≤ 0.01)
- Reusing existing test cases to ensure consistent behavior with the default sampler
- Following the established test patterns in the codebase
The sampling parameter configuration looks correct for greedy mode testing.
7f31555 to 607dbc5
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/unittest/_torch/speculative/test_mtp.py (1)
370-370: Fix line length violation. The line exceeds the 120-character limit. Consider breaking it into multiple lines for better readability.
```diff
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
✅ Files skipped from review due to trivial changes (2)
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (5)
- tensorrt_llm/_torch/speculative/mtp.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/llmapi/tokenizer.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tests/unittest/_torch/speculative/test_mtp.py (2)
333-401: Well-structured test for advanced MTP sampler in greedy mode. The test method is properly implemented with correct parameterization, deterministic seed setting, and appropriate sampling parameter configuration to enforce greedy mode behavior. The test logic follows the established patterns and should effectively validate the advanced sampler functionality.
363-386: Sampling parameter threshold confirmed. The model engine clamps any non-zero temperature below 0.01 up to 0.01 and treats temperatures ≤ 0.01 as greedy mode. Your test's use of `temperature = 0.01` correctly hits that boundary. No changes required.
/bot run --disable-fail-fast
PR_Github #12650 [ run ] triggered by Bot
PR_Github #12650 [ run ] completed with state
607dbc5 to 2953667
/bot run --disable-fail-fast
Actionable comments posted: 0
🧹 Nitpick comments (2)
tests/unittest/_torch/speculative/test_mtp.py (2)
370-370: Fix line length violation. The line exceeds the 120-character limit flagged by the linter.
```diff
- # sampling default config vals set in [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
+ # sampling default config vals set in
+ # [tensorrt_llm/_torch/pyexecutor/model_engine.py:get_request_[param_name]]
```
388-401: Test execution looks correct but consider broader test coverage. The test execution properly validates that the advanced sampler produces the same results as the original implementation in greedy mode, which is the expected behavior.
However, this test only covers greedy mode. Consider adding tests for the actual advanced sampling modes (temperature > 0.01, top-k < max_int, etc.) to fully validate the new functionality.
Would you like me to help generate additional test cases for non-greedy sampling modes to improve coverage of the advanced sampler functionality?
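A sketch of what such a non-greedy test could look like, purely illustrative: the `sampling_batch` signature follows this PR, but the shape and finiteness assertions are assumptions, not an existing test.

```python
import torch

from tensorrt_llm._torch.pyexecutor.sampler import sampling_batch  # added in this PR


def test_sampling_batch_non_greedy_shapes():
    torch.manual_seed(0)
    batch, vocab = 4, 32
    logits = torch.randn(batch, vocab)
    temperatures = torch.full((batch,), 0.8)                 # > 1e-2, so the random path is taken
    top_k = torch.full((batch,), 10, dtype=torch.int32)
    top_p = torch.full((batch,), 0.9)
    min_p = torch.zeros(batch)

    tokens, log_probs = sampling_batch(logits, temperatures, top_k, top_p, min_p)

    assert tokens.shape == (batch,)
    assert log_probs.shape == (batch,)
    # Sampled ids must be valid vocabulary indices and log-probs finite.
    assert tokens.min() >= 0 and tokens.max() < vocab
    assert torch.isfinite(log_probs).all()
```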
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/llm-api/quickstart_advanced.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (2 hunks)
- tensorrt_llm/llmapi/llm_args.py (1 hunks)
- tensorrt_llm/llmapi/tokenizer.py (1 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
🪛 Ruff (0.12.2)
tests/unittest/_torch/speculative/test_mtp.py
370-370: Line too long (123 > 120)
(E501)
✅ Files skipped from review due to trivial changes (1)
- tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (6)
- tensorrt_llm/llmapi/llm_args.py
- examples/llm-api/quickstart_advanced.py
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/llmapi/tokenizer.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tests/unittest/_torch/speculative/test_mtp.py (3)
333-340: LGTM! Good test structure and deterministic setup. The test method signature follows the existing pattern and the deterministic seed ensures consistent behavior across runs, which is important for multi-GPU sampling scenarios.

342-346: Correct configuration for advanced sampler testing. The test properly enables the advanced MTP sampler feature through the `use_advanced_mtp_sampler=True` flag, which is the key differentiator from the original test method.

363-387: Well-implemented parameter setup for greedy sampling mode. The sampling parameters are correctly configured to enforce greedy behavior:
- Temperature set to 0.01 (at the greedy boundary)
- top_k set to max int value (no filtering)
- top_p set to 1.0 (no filtering)
- min_p set to 0.0 (no filtering)
The logic properly accounts for each batch's draft tokens plus one additional token.
Actionable comments posted: 1
♻️ Duplicate comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
200-232: Top-k/top-p implementation is efficient and relates to past review concerns. The implementation is mathematically correct and handles the complexity of combined top-k and top-p filtering. It also relates directly to the past review comments about optimization when both k and p have default values.
From the past comments, I can see the author explained that device-side conditional execution isn't feasible with native torch ops and would break CUDA graph capture. The current approach of unconditional filtering is the correct design choice for maintaining performance and graph compatibility.
The sorting-based approach is necessary for top-p but can be expensive for large vocabularies. The current implementation is optimal given the constraints.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
154-167: Function implementation is correct but consider parameter validation. The `forward_native` function correctly combines top-k/top-p filtering with random sampling. The in-place modification warning is helpful. However, consider adding basic parameter validation to ensure robustness:
```diff
 def forward_native(
     logits: torch.Tensor,
     k: Optional[torch.Tensor],
     p: Optional[torch.Tensor],
 ) -> torch.Tensor:
     """
     PyTorch-native implementation of top-k and top-p sampling.
     The logits tensor may be updated in-place.
     """
+    if logits.dim() != 2:
+        raise ValueError(f"Expected 2D logits tensor, got {logits.dim()}D")
     logits = apply_top_k_top_p(logits, k, p)
     probs = logits.softmax(dim=-1, dtype=torch.float32)
     return random_sample(probs)
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- tensorrt_llm/_torch/model_config.py (1 hunks)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (10 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (2 hunks)
- tensorrt_llm/_torch/speculative/interface.py (1 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (5 hunks)
- tests/unittest/_torch/speculative/test_mtp.py (1 hunks)
✅ Files skipped from review due to trivial changes (2)
- tensorrt_llm/_torch/model_config.py
- tensorrt_llm/_torch/speculative/interface.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/unittest/_torch/speculative/test_mtp.py
- tensorrt_llm/_torch/speculative/mtp.py
🧰 Additional context used
🧠 Learnings (2)
📓 Common learnings
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
tensorrt_llm/_torch/pyexecutor/model_engine.py (4)
Learnt from: yechank-nvidia
PR: #6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using from_shared_tensor() is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call strip_for_generation() to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Learnt from: moraxu
PR: #6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.
Learnt from: yiqingy0
PR: #5198
File: jenkins/mergeWaiveList.py:0-0
Timestamp: 2025-07-22T08:33:49.109Z
Learning: In the TensorRT-LLM waive list merging system, removed lines are always located at the end of the merge waive lists, which is why the mergeWaiveList.py script uses reverse traversal - it's an optimization for this specific domain constraint.
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache() and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
🔇 Additional comments (16)
tensorrt_llm/_torch/pyexecutor/sampler.py (5)
4-4: Import addition looks good. The Optional import is correctly added and used throughout the new sampling functions.
169-178: Excellent CPU-GPU synchronization avoidance technique. The `random_sample` function cleverly uses the Gumbel-max trick (exponential noise + argmax) to avoid `torch.multinomial`'s CPU-GPU synchronization. This is a well-known technique for maintaining CUDA graph compatibility. The implementation is mathematically sound and performance-optimized.
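For reference, the trick can be demonstrated in isolation. The standalone sketch below (not the PR's `random_sample`) shows why dividing probabilities by Exponential(1) noise and taking the argmax draws from the categorical distribution defined by each row, which is the Gumbel-max trick applied to probabilities instead of log-probabilities:

```python
import torch


def multinomial_free_sample(probs: torch.Tensor) -> torch.Tensor:
    """Sample one token per row from `probs` without torch.multinomial.

    p_i / E_i with E_i ~ Exponential(1) is maximized at index i with
    probability p_i / sum(p), so the argmax is a categorical draw that
    stays entirely on the GPU.
    """
    noise = torch.empty_like(probs).exponential_()  # i.i.d. Exponential(1) noise
    return probs.div(noise).argmax(dim=-1)


# Quick empirical check of the claimed equivalence (illustrative only).
if __name__ == "__main__":
    torch.manual_seed(0)
    probs = torch.tensor([[0.1, 0.2, 0.7]]).expand(100_000, -1)
    counts = torch.bincount(multinomial_free_sample(probs), minlength=3).float()
    print(counts / counts.sum())  # approximately tensor([0.1, 0.2, 0.7])
```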
180-198: Min-p implementation is mathematically correct. The adaptive probability thresholding implementation correctly:
- Converts logits to probabilities
- Finds maximum probability per sequence
- Scales min_p threshold by the maximum probability
- Masks tokens below the adaptive threshold
The logic aligns with the min-p sampling strategy used in modern language models.
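A minimal standalone sketch of the rule described above, assuming [batch, vocab] logits and a per-row min_p tensor (illustrative, not the PR's function):

```python
import torch


def min_p_mask(logits: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    """Mask tokens whose probability falls below min_p times the row's peak probability.

    logits: [batch, vocab]; min_p: [batch] in [0, 1]. Returns masked logits.
    """
    probs = logits.softmax(dim=-1)
    max_probs = probs.amax(dim=-1, keepdim=True)   # per-row peak probability
    threshold = min_p.unsqueeze(-1) * max_probs    # adaptive cutoff per row
    return logits.masked_fill(probs < threshold, float("-inf"))
```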
234-236: Greedy sampling implementation is correct and efficient. Simple argmax implementation that correctly flattens the output to 1D as expected by the sampling interface.

238-244: Temperature scaling is efficient with in-place operation. The in-place division correctly scales logits by temperature while avoiding unnecessary tensor allocation. This is optimal for performance.
tensorrt_llm/_torch/pyexecutor/model_engine.py (11)
20-20: LGTM! The import of `LlmRequest` is necessary for accessing sampling configuration parameters in the new helper function.

282-283: LGTM! The initialization logic correctly combines the three required conditions for enabling advanced MTP sampling: speculative decoding, MTP mode, and mixed sampler configuration.

301-301: LGTM! The `enable_mixed_sampler` parameter is correctly passed to the model loading function, maintaining consistency with other configuration parameters.

1179-1223: LGTM! The helper function is well-implemented with several good practices:
- Temperature clamping to 0.01 prevents numerical instability
- Uses `torch.iinfo(torch.int32).max` instead of magic numbers for disabled top_k
- Proper handling of None/empty sampling config values
- Clear separation of concerns with individual parameter extraction functions
- Correct extension of parameters for draft tokens + main token
1173-1178: LGTM! The sampling parameter lists are properly initialized when advanced MTP sampling is enabled.

1257-1258: LGTM! Sampling parameters are correctly collected for context requests with no draft tokens.

1343-1344: LGTM! Sampling parameters are correctly collected for extend requests, passing the appropriate draft token length.

1375-1376: LGTM! Sampling parameters are correctly collected for requests with previous batch data, using the maximum draft length.

1410-1411: LGTM! Sampling parameters are correctly collected for generation requests, using the maximum draft length for consistency.

1606-1608: LGTM! The spec_metadata is correctly updated with the collected sampling parameters when advanced MTP sampling is enabled.

1159-1161: LGTM! The advanced MTP sampling setup is correctly invoked during the forward pass with appropriate batch size and draft length parameters.
```python
def sampling_batch(logits: torch.Tensor, temperatures: torch.Tensor,
                   top_k: torch.Tensor, top_p: torch.Tensor,
                   min_p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    raw_probs = torch.softmax(logits, dim=-1)
    greedy_sampled = greedy_sample(logits)
    logits = apply_temperature(logits, temperatures)
    logits = apply_min_p(logits, min_p)
    random_sampled = forward_native(logits, top_k, top_p)
    next_tokens = torch.where(
        temperatures <= 1e-2,  # Match the clamping threshold
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,  # Reuse tensor
    )
    token_probs = torch.gather(raw_probs, dim=1,
                               index=next_tokens.unsqueeze(1)).squeeze(-1)
    log_probs = torch.log(token_probs)
    return next_tokens, log_probs
```
🛠️ Refactor suggestion
Unified sampling function is well-designed but has a potential numerical stability issue.
The sampling_batch function effectively combines all sampling strategies with a temperature threshold for greedy vs. random sampling. The design is sound, but there's a potential numerical issue:
The log probability calculation at Line 262 could produce -inf values if token_probs contains zeros due to precision issues:
```diff
 token_probs = torch.gather(raw_probs, dim=1,
                            index=next_tokens.unsqueeze(1)).squeeze(-1)
-log_probs = torch.log(token_probs)
+log_probs = torch.log(token_probs.clamp(min=1e-8))
```

This prevents log(0) = -inf issues that could propagate through the system.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
def sampling_batch(logits: torch.Tensor, temperatures: torch.Tensor,
                   top_k: torch.Tensor, top_p: torch.Tensor,
                   min_p: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    raw_probs = torch.softmax(logits, dim=-1)
    greedy_sampled = greedy_sample(logits)
    logits = apply_temperature(logits, temperatures)
    logits = apply_min_p(logits, min_p)
    random_sampled = forward_native(logits, top_k, top_p)
    next_tokens = torch.where(
        temperatures <= 1e-2,  # Match the clamping threshold
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,  # Reuse tensor
    )
    token_probs = torch.gather(
        raw_probs,
        dim=1,
        index=next_tokens.unsqueeze(1),
    ).squeeze(-1)
    log_probs = torch.log(token_probs.clamp(min=1e-8))
    return next_tokens, log_probs
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/sampler.py around lines 246 to 264, the log
probability calculation uses torch.log on token_probs which may contain zeros,
causing NaN values. To fix this, clamp token_probs to a small positive value
(e.g., 1e-10) before applying torch.log to avoid log(0) and ensure numerical
stability.
/bot run --disable-fail-fast
PR_Github #13420 [ run ] triggered by Bot
PR_Github #13420 [ run ] completed with state
/bot run
PR_Github #13502 [ run ] triggered by Bot
PR_Github #13502 [ run ] completed with state
```python
# Default to greedy mode. If true, use advanced pytorch sampling strategy.
self.enable_mixed_sampler = False
if self.model_config is not None:
    self.enable_mixed_sampler = self.model_config.enable_mixed_sampler
```
Nitpick: This could be a @property rather than a copy, to avoid potential consistency issues in the future.
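A sketch of the suggested `@property` variant; the enclosing class is abbreviated and assumed for illustration, and only the attribute names from the snippet above are taken from the PR:

```python
class MTPWorker:  # enclosing class assumed/abbreviated for illustration
    @property
    def enable_mixed_sampler(self) -> bool:
        # Defer to the model config whenever it is present, so this flag can
        # never drift out of sync with self.model_config; default to greedy mode.
        if self.model_config is not None:
            return self.model_config.enable_mixed_sampler
        return False
```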
…agle one model path Signed-off-by: Xuanyu Chen <[email protected]>
0ee79b3 to bf7910b
…er logic Signed-off-by: Xuanyu Chen <[email protected]>
/bot run
PR_Github #22067 [ run ] triggered by Bot. Commit:
Signed-off-by: Izzy Putterman <[email protected]>
```python
    Filters logits using adaptive probability thresholding.
    """
    # Convert logits to probability distribution
    probability_values = torch.nn.functional.softmax(logits, dim=-1)
```
this effectively neutralizes the temperature right? We apply temp then softmax again which undoes the scaling
Or perhaps in sampling_batch_spec_dec_one_model, we should remove the first softmax and put one just before the sort in the apply top_k top_P
nvm im wrong here, misread something
```python
    The logits tensor may be updated in-place.
    """
    logits = apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
```
I think we can technically skip this softmax
I think I am also wrong here
I think @IzzyPutterman is right: apply_min_p evaluates the softmax of the temperature-scaled logits and uses that to mask out some of the logits (set to -inf). The probs could be masked in the same way (set to 0). The resulting probs can (mind paragraph below) then be reused in apply_top_k_top_p, which masks out more logits/probs.
Every time logits/probs are masked, it is sufficient to renormalize the probs such that they sum to one, which is much cheaper than computing softmax. This is probably also why https://docs.flashinfer.ai/api/sampling.html uses function names like ..._renorm_probs.
Note that much of this is already worked out in #8581, albeit using flashinfer.sampling.
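A rough sketch of the renormalization idea (illustrative only, not the code in #8581 or this PR):

```python
import torch


def mask_and_renorm(probs: torch.Tensor, keep: torch.Tensor) -> torch.Tensor:
    """Zero out filtered tokens and rescale the survivors to sum to one.

    probs: [batch, vocab] probabilities computed once via softmax.
    keep:  [batch, vocab] boolean mask of tokens surviving the current filter.
    Renormalizing is a masked multiply and a division, which is cheaper than
    re-running softmax after each filtering stage (min-p, top-k, top-p).
    """
    probs = probs * keep
    return probs / probs.sum(dim=-1, keepdim=True)
```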
PR_Github #22067 [ run ] completed with state
```python
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.
    If a top-p is used, this function will sort the logits tensor,
```
As a perf optimization, should we skip the expensive sorting / softmax / cumsum ops for top_p >=1?
If top_p is 1, we can skip the expensive sorting / softmax / cumsum ops.
In the latest trt llm version, it is already implemented. Please refer to https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/sampling_utils.py#L159-L171.
The skipping is not possible because in regular decoding the sampling is not captured in cuda graph.
This part is captured in the CUDA graph, so unless there is a kernel that determines whether to skip or not (like cpp/kernels/samplingTopPKernel.cu checkAllTopP), there is no way to check with the CPU flag need_top_p.
Partially reviewed
```python
    return logits


def apply_top_k_top_p(
```
Note that we already have
```python
def top_k_top_p_sampling_batch(
```
and some related functions. We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the Tensor-type k, p, temperature, etc.).
Also note that I am working on FlashInfer.sampling based alternatives for those functions. This upcoming PR brings support for Tensor-type k, p, temperature, etc. when FlashInfer is used. If you integrate the improvements made here for the non-FlashInfer case, this could give a quite nice feature set.
Ideally, this PR could (i) improve the existing sampling routines and (ii) use them via
```python
class SimpleGroupedStrategySampler(GroupedStrategySampler[Strategy]):
```
Cf. #6245 (comment)
Put up #8581 (work in progress!) to give an idea of what to expect.
See also TRTLLM-7723 (and TRTLLM-7152) for scope of ongoing work.
I think I tested FlashInfer with CUDA graphs and it was breaking a bunch. With the generator objects it's quite annoying in TRT-LLM because in warmup we alternate between CUDA graph warmup and non-CUDA-graph warmup, which breaks.
Worth a double check ofc, perhaps there is an easy way around it
@ixlmar I think the current implementation of TopK/TopP only allows all requests to have the same TopK/TopP value instead of individual requests having different values; please correct me if I'm wrong.
The current logic in model_engine.py didn't parse out all the sampling params into GPU tensors for CUDA graphs; this PR enables that.
@IzzyPutterman The idea of #8581 is to allow choosing between the sampling routines we have today in sampling_utils.py and those provided by FlashInfer. Both will be available as implementations of GroupedStrategySampler. SimpleGroupedStrategySampler uses the former sampling routines (non FlashInfer) and is already available in main.
@jhaotingc Correct. This was what I meant in the first comment:
We should not duplicate this basic functionality, so let's use the existing functions and extend them as necessary (adding the Tensor-type k, p, temperature, etc.).
Ideally, this PR could extend SimpleGroupedStrategySampler to allow for Tensor-type k, p, temperature, etc., in the same way as FlashInferGroupedStrategySampler does it for FlashInfer.sampling in #8581. If the GroupedStrategySampler abstraction is not viable (e.g. because the host data structures interfere with CUDA graphs), then I think we should extend top_k_top_p_sampling_batch (and the related functions) directly (promote scalar arguments to accept any broadcastable torch.Tensor) and reuse them here.
@ixlmar @IzzyPutterman would you mind sharing the latest plan regarding this PR? Is the plan to merge this in and then add more performant kernels? @ixlmar we can give you access to this branch if you'd like to develop on top of this.
@jhaotingc Thanks. As far as I can tell, the code that I think we should be extending and reusing (
```python
    min_p = []

    # advanced mtp sampling's request preprocessing helper functions
    def collect_req_spec_dec_sampling_params(request: LlmRequest,
```
The resolution of request.sampling_config to sampling strategy has been cleaned up in #8132. See PR description for the intended semantics. The relevant function is
```python
def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy:
```
The existing function covers various corner cases already (e.g. temperature=0, top_p=1, etc.) and has extensive unit tests. Consider reusing this function here (perhaps make it "public", i.e., rename to something that does not start with _).
| """ | ||
| q = torch.empty_like(probs) | ||
| q.exponential_() | ||
| return probs.div_(q).argmax(dim=-1).view(-1) |
I have to admit that I am not familiar with this sampling scheme. If you happen to have a literature reference at hand, I would be curious to learn more (perhaps also include a comment stating the name of the method).
BTW, TorchSampler is using torch.multinomial and I did not notice any stream syncs. Code:
```python
next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
```
(Disclaimer: I might well have overlooked that torch.multinomial is syncing, so I would be curious to hear more on this.)
```python
                                     pin_memory=True)
        self.slot_ids[:num_seqs].copy_(mtp_slot_ids, non_blocking=True)

    def _set_up_advanced_sampling(self, batch_size: int, max_draft_len: int):
```
The code changes in mtp.py and eagle3.py look very similar; perhaps something could be reused.
[feat] Implement pytorch sampler for MTP
Description
The default behavior of the MTP pytorch decoder remains greedy sampling. Advanced sampling can be enabled via the `enable_mixed_sampler` flag in `TorchLlmArgs`.

Test Coverage

- Unit test covering temperature <= 1e-2 using the new PyTorch sampler.

GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.
Run
/bot [-h|--help] to print this help message.

See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.