From 4488ab58320a59a0ceaa49d2537899eaacf0e1af Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Fri, 7 Nov 2025 05:25:48 -0800 Subject: [PATCH 01/11] Add four logprobs mode in llm sampler torch backend. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 128 ++++++++++++++- .../_torch/pyexecutor/sampling_utils.py | 100 ++++++++++-- tensorrt_llm/executor/result.py | 36 ++++- tensorrt_llm/sampling_params.py | 20 ++- tests/unittest/llmapi/test_llm_pytorch.py | 149 ++++++++++++++++++ 5 files changed, 408 insertions(+), 25 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 5757f8efbc7..6bb8cfc7d5d 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -320,6 +320,8 @@ class _BatchedSamplingResult: batch_req_indices: torch.Tensor # Next tokens for all requests: batch_next_tokens_cuda_int: torch.Tensor + # Processed logits after temperature/penalties/top-k/top-p (for processed logprobs modes): + processed_logits_cuda: Optional[torch.Tensor] = None # Helper class for _PackedStepIndexer and _UnpackedStepIndexer, facilitating the @@ -1284,6 +1286,7 @@ def _sample_batched_by_strategy( req_num_steps: torch.Tensor, req_offsets: torch.Tensor, token_dtype: torch.dtype, + return_processed_logits: bool = False, ) -> _BatchedSamplingResult: grouped_requests = _group_requests_by_strategy_key( requests, @@ -1307,6 +1310,8 @@ def _sample_batched_by_strategy( batch_next_tokens_cuda_int = torch.empty( (logits_cuda.size(0),), device=cuda_device, dtype=token_dtype ) + # For processed logprobs: collect processed logits after temperature/top-k/top-p + processed_logits_list = [] if return_processed_logits else None batch_req_idx_offset_start = 0 batch_next_tokens_offset_start = 0 for (strategy_key, speculation_needs_probs), ( @@ -1353,6 +1358,7 @@ def _sample_batched_by_strategy( generator=generator_cuda, return_probs=speculation_needs_probs, group_logit_indices=logit_indices_for_sampler, + return_processed_logits=return_processed_logits, ) ) batch_next_tokens_offset_end = ( @@ -1362,6 +1368,10 @@ def _sample_batched_by_strategy( batch_next_tokens_offset_start:batch_next_tokens_offset_end ].copy_(group_next_tokens_cuda, non_blocking=True) + # Collect processed logits if requested + if return_processed_logits and group_processed_logits_cuda is not None: + processed_logits_list.append(group_processed_logits_cuda) + # Set LlmRequest.py_target_probs if speculation_needs_probs: assert group_softmax_cuda is not None @@ -1387,9 +1397,15 @@ def _sample_batched_by_strategy( if needs_d2t: self._apply_d2t(batch_next_tokens_cuda_int, model_outputs) + # Concatenate processed logits if collected + processed_logits_cuda = None + if return_processed_logits and processed_logits_list: + processed_logits_cuda = torch.cat(processed_logits_list, dim=0) + return _BatchedSamplingResult( batch_req_indices=batch_req_indices, batch_next_tokens_cuda_int=batch_next_tokens_cuda_int, + processed_logits_cuda=processed_logits_cuda, ) def _unbatch_sampling_results( @@ -1704,6 +1720,69 @@ def request_stop_words(request: LlmRequest, new_tokens: torch.Tensor): per_step[step][request_idx] = True return per_step + def _compute_processed_logprobs( + self, + batched_sampling_result: _BatchedSamplingResult, + requests: list[LlmRequest], + logits_cuda_indexer: _PackedStepIndexer, + req_num_steps: torch.Tensor, + 
logprobs_mode: str, + cuda_device: torch.device, + ) -> None: + """ + Compute processed logprobs from logits after temperature/penalties/top-k/top-p. + Updates request objects with the computed logprobs. + """ + processed_logits_cuda = batched_sampling_result.processed_logits_cuda + if processed_logits_cuda is None: + return + + # Get requests that need logprobs + logprobs_req_indices = [ + req_id for req_id, req in enumerate(requests) if req.py_num_logprobs + ] + logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices] + logprobs_logit_indices_cuda = logprobs_logit_indices.to( + device=cuda_device, non_blocking=True + ) + + # Apply log_softmax if mode is processed_logprobs + if logprobs_mode == "processed_logprobs": + processed_logprobs_cuda = F.log_softmax( + processed_logits_cuda[logprobs_logit_indices_cuda].to( + dtype=torch.float32, non_blocking=True + ), + dim=-1, + ) + elif logprobs_mode == "processed_logits": + processed_logprobs_cuda = processed_logits_cuda[logprobs_logit_indices_cuda].to( + dtype=torch.float32, non_blocking=True + ) + + # Compute top-k + topk_vals_cuda, topk_indices_cuda = torch.topk( + processed_logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 + ) + + # Transfer to CPU + topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True) + topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True) + topk_vals.copy_(topk_vals_cuda, non_blocking=True) + topk_indices.copy_(topk_indices_cuda, non_blocking=True) + + # Store in request objects + current_offset = 0 + for req_id, steps in zip( + logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist() + ): + req = requests[req_id] + next_offset = current_offset + steps + req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, : req.py_num_logprobs] + req.py_topk_logprobs_indices = topk_indices[ + current_offset:next_offset, : req.py_num_logprobs + ] + current_offset = next_offset + @nvtx_range("_process_requests") def _process_requests( self, @@ -1776,7 +1855,19 @@ def _process_requests( # Handle top-k logprobs. This is done outside the sampling loop, # because the returned logprobs are specified to not reflect temperature scaling, # top-k/top-p masking, etc. 
+ # Determine logprobs mode from first request that needs logprobs + logprobs_mode = "raw_logprobs" # default if return_log_probs: + for req in requests: + if req.py_num_logprobs: + logprob_params = getattr(req, "_logprob_params", None) + if logprob_params and hasattr(logprob_params, "logprobs_mode"): + logprobs_mode = logprob_params.logprobs_mode + break + + # Handle raw logprobs modes: compute BEFORE sampling (temperature/penalties/top-k/top-p) + # Raw modes return logprobs/logits before any sampling modifications + if return_log_probs and logprobs_mode.startswith("raw"): assert logits_cuda.dim() == 2, "logits should be 2D" logprobs_req_indices = [ req_id for req_id, req in enumerate(requests) if req.py_num_logprobs @@ -1785,10 +1876,21 @@ def _process_requests( logprobs_logit_indices_cuda = logprobs_logit_indices.to( device=logits_cuda.device, non_blocking=True ) - logprobs_cuda = F.log_softmax( - logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True), - dim=-1, - ) + + # Compute raw logprobs or logits based on mode + if logprobs_mode == "raw_logprobs": + logprobs_cuda = F.log_softmax( + logits_cuda[logprobs_logit_indices_cuda].to( + dtype=torch.float32, non_blocking=True + ), + dim=-1, + ) + elif logprobs_mode == "raw_logits": + # Return unnormalized logits + logprobs_cuda = logits_cuda[logprobs_logit_indices_cuda].to( + dtype=torch.float32, non_blocking=True + ) + topk_vals_cuda, topk_indices_cuda = torch.topk( logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 ) @@ -1813,6 +1915,11 @@ def _process_requests( current_offset = next_offset # Perform sampling in batches + # For processed modes, we need to capture logits after temperature/penalties/top-k/top-p + return_processed_logits_for_logprobs = return_log_probs and logprobs_mode.startswith( + "processed" + ) + batched_sampling_result = self._sample_batched_by_strategy( logits_cuda, requests, @@ -1822,8 +1929,21 @@ def _process_requests( req_offsets=req_offsets, req_num_steps=req_num_steps, token_dtype=new_tokens_cuda.dtype, + return_processed_logits=return_processed_logits_for_logprobs, ) + # Handle processed logprobs modes: compute AFTER sampling (temperature/penalties/top-k/top-p) + # Processed modes return logprobs/logits after all sampling modifications + if return_processed_logits_for_logprobs: + self._compute_processed_logprobs( + batched_sampling_result, + requests, + logits_cuda_indexer, + req_num_steps, + logprobs_mode, + cuda_device, + ) + # Fill results into output buffers new_tokens_host = self._unbatch_sampling_results( batched_sampling_result, diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index 35e64afe4c2..334947f252c 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -95,7 +95,8 @@ def top_k_sampling_batch( top_k: int, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: + return_processed_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: # NB: To be replaced by a more efficient implementation. 
return top_k_top_p_sampling_batch( logits, @@ -103,6 +104,7 @@ def top_k_sampling_batch( temperature=temperature, generator=generator, top_p=1, + return_processed_logits=return_processed_logits, ) @@ -112,7 +114,8 @@ def top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: + return_processed_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -120,6 +123,7 @@ def top_p_sampling_batch( top_k=logits.size(1), temperature=temperature, generator=generator, + return_processed_logits=return_processed_logits, ) @@ -128,7 +132,8 @@ def temperature_sampling_batch( *, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: + return_processed_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -136,6 +141,7 @@ def temperature_sampling_batch( top_k=logits.size(1), temperature=temperature, generator=generator, + return_processed_logits=return_processed_logits, ) @@ -146,10 +152,28 @@ def top_k_top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: + return_processed_logits: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Perform top-k and top-p sampling. + + Args: + logits: Input logits tensor [batch_size, vocab_size] + top_k: Top-k value + top_p: Top-p (nucleus sampling) value + temperature: Temperature for sampling + generator: Optional torch random generator + return_processed_logits: If True, return processed logits after temperature/top-k/top-p + + Returns: + Tuple of (sampled_tokens, softmax_probs, processed_logits) + processed_logits is None if return_processed_logits is False + """ logits_dim = logits.dim() assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" assert temperature > 0, "non-greedy sampling requires valid temperature" + # Store processed logits if requested (clone before in-place modifications) + processed_logits_to_return = logits.clone() if return_processed_logits else None logits = logits / max(temperature, 1e-5) batch_size, vocab_size = logits.size() @@ -186,24 +210,51 @@ def top_k_top_p_sampling_batch( ) logits = logits.masked_fill(indices_to_remove, float("-inf")) + # Update processed logits if requested (apply same temperature/top-k/top-p) + if return_processed_logits: + processed_logits_to_return = processed_logits_to_return / max(temperature, 1e-5) + if need_top_k: + processed_logits_to_return = torch.where( + processed_logits_to_return < min_values, + torch.full_like(processed_logits_to_return, float("-inf")), + processed_logits_to_return, + ) + if need_top_p: + processed_logits_to_return = processed_logits_to_return.masked_fill( + indices_to_remove, float("-inf") + ) + # compute probability distribution softmax = torch.softmax(logits, dim=-1) # sample from the distribution and generate result of [batch_size, 1] next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1) - return next_tokens, softmax + return next_tokens, softmax, processed_logits_to_return def greedy_search_sampling_batch( logits, *, return_probs: bool = True, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + 
return_processed_logits: bool = False, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Perform greedy sampling. + + Args: + logits: Input logits tensor + return_probs: If True, return softmax probabilities + return_processed_logits: If True, return processed logits + + Returns: + Tuple of (sampled_tokens, softmax_probs, processed_logits) + """ next_tokens = torch.argmax(logits, dim=-1) softmax: Optional[torch.Tensor] = None if return_probs: softmax = torch.softmax(logits, dim=-1) - return next_tokens, softmax + processed_logits = logits.clone() if return_processed_logits else None + return next_tokens, softmax, processed_logits def get_rejected_indices( @@ -254,7 +305,21 @@ def sample( *, generator: Optional[torch.Generator] = None, return_probs: bool = True, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return_processed_logits: bool = False, +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Sample from logits using the specified strategy. + + Args: + strategy: Sampling strategy tuple (strategy_name, *params) + logits: Input logits tensor + generator: Optional random generator + return_probs: If True, return softmax probabilities + return_processed_logits: If True, return processed logits after temperature/penalties/top-k/top-p + + Returns: + Tuple of (sampled_tokens, softmax_probs, processed_logits) + """ match strategy: case ("top_k", top_k, temperature): tokens, softmax = top_k_sampling_batch( @@ -264,29 +329,34 @@ def sample( generator=generator, ) case ("top_p", top_p, temperature): - tokens, softmax = top_p_sampling_batch( + tokens, softmax, processed_logits = top_p_sampling_batch( logits, top_p=top_p, generator=generator, temperature=temperature, + return_processed_logits=return_processed_logits, ) case ("top_k_top_p", top_k, top_p, temperature): - tokens, softmax = top_k_top_p_sampling_batch( + tokens, softmax, processed_logits = top_k_top_p_sampling_batch( logits, top_k=top_k, top_p=top_p, temperature=temperature, generator=generator, + return_processed_logits=return_processed_logits, ) case ("temperature", temperature): - tokens, softmax = temperature_sampling_batch( + tokens, softmax, processed_logits = temperature_sampling_batch( logits, temperature=temperature, generator=generator, + return_processed_logits=return_processed_logits, ) case ("greedy", None): - tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs) - return tokens, softmax + tokens, softmax, processed_logits = greedy_search_sampling_batch( + logits, return_probs=return_probs, return_processed_logits=return_processed_logits + ) + return tokens, softmax, processed_logits GenericStrategyKeyType = TypeVar("GenericStrategyKeyType") @@ -330,7 +400,8 @@ def sample_grouped_strategies( group_logit_indices: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, return_probs: bool, - ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + return_processed_logits: bool = False, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: if group_logit_indices is None: assert logits.size(0) == len(strategies) else: @@ -343,6 +414,7 @@ def sample_grouped_strategies( logits, generator=generator, return_probs=return_probs, + return_processed_logits=return_processed_logits, ) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index d47743cf8f0..425ce700a26 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -1007,6 +1007,7 @@ def 
compute_logprobs( context_logits: Optional[torch.Tensor], generation_logits: Optional[torch.Tensor], output_token_ids: Optional[list[int]], + apply_log_softmax: bool = True, ) -> LogProbsResult: """ Compute top-K logprobs from logits when engine doesn't provide them directly. @@ -1016,14 +1017,24 @@ def compute_logprobs( - Generation logprobs (from generation_logits, TRT backend): used when backend doesn't compute them in sampler (e.g., TRT). - Generation logprobs (PyTorch backend): not used; computed in sampler, not here. + Args: + k_prompt_logprobs: Number of top logprobs to return for prompt tokens + k_logprobs: Number of top logprobs to return for generated tokens + context_logits: Logits for context/prompt tokens + generation_logits: Logits for generated tokens + output_token_ids: Token IDs of generated outputs + apply_log_softmax: If True, apply log_softmax to logits. If False, logits are already log_softmax'ed or are raw unnormalized values (depending on logprobs_mode) + Returns: LogProbsResult, a NamedTuple containing: - prompt: Optional[List[Dict[token_id, Logprob]]] logprobs for prompt tokens. - generation: Optional[List[Dict[token_id, Logprob]]] logprobs for generated tokens. """ - def _topk_logprobs(logits: torch.Tensor, top_k: int, - tokens: Optional[list[int]]) -> TokenLogprobs: + def _topk_logprobs(logits: torch.Tensor, + top_k: int, + tokens: Optional[list[int]], + apply_softmax: bool = True) -> TokenLogprobs: if logits.dim() == 3: # reshape from [1, T, V] to [T, V] logits = logits.squeeze(0) @@ -1033,7 +1044,14 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int, # than output tokens. logits = logits[:len(tokens)] - logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1) + # Apply log_softmax only if requested (for logprobs modes) + # For logits modes, skip log_softmax to return raw unnormalized values + if apply_softmax: + logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), + dim=-1) + else: + logprobs = logits.to("cuda", dtype=torch.float32) + topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1) results: TokenLogprobs = [] @@ -1058,10 +1076,16 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int, return results prompt_logprobs = _topk_logprobs( - context_logits, k_prompt_logprobs, - None) if k_prompt_logprobs and context_logits is not None else None + context_logits, + k_prompt_logprobs, + None, + apply_softmax=apply_log_softmax + ) if k_prompt_logprobs and context_logits is not None else None generation_logprobs = _topk_logprobs( - generation_logits, k_logprobs, output_token_ids + generation_logits, + k_logprobs, + output_token_ids, + apply_softmax=apply_log_softmax ) if k_logprobs and generation_logits is not None else None return LogProbsResult(prompt=prompt_logprobs, diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index b7ad63821ad..5714c2451ca 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -2,7 +2,7 @@ import os from abc import ABC, abstractmethod from dataclasses import dataclass, field, fields -from typing import List, NamedTuple, Optional, Tuple, Union +from typing import List, Literal, NamedTuple, Optional, Tuple, Union import torch from pydantic import BaseModel @@ -10,6 +10,13 @@ from tensorrt_llm.bindings import executor as tllme from tensorrt_llm.logger import logger +# Logprobs mode type definition (similar to vLLM) +# - "raw_logits": return raw unnormalized logits before temperature/penalties/top-k/top-p +# - "raw_logprobs": return 
log-softmax of raw logits +# - "processed_logits": return unnormalized logits after temperature/penalties/top-k/top-p +# - "processed_logprobs": return log-softmax of processed logits +LogprobsMode = Literal["raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"] + @dataclass(slots=True, kw_only=True) class GuidedDecodingParams: @@ -44,6 +51,8 @@ class LogprobParams(NamedTuple): drop_context_logits: bool = False # Drop the geneation_logits once the logprobs are computed drop_generation_logits: bool = False + # Logprobs mode: controls whether to return logprobs before or after sampling modifications + logprobs_mode: LogprobsMode = "raw_logprobs" class LogitsProcessor(ABC): @@ -174,6 +183,12 @@ class SamplingParams: logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None. prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None. + logprobs_mode (str): Controls whether to return logprobs before or after sampling modifications. Defaults to "raw_logprobs". + Options: + - "raw_logits": Return raw unnormalized logits before temperature/penalties/top-k/top-p + - "raw_logprobs": Return log-softmax of raw logits + - "processed_logits": Return unnormalized logits after temperature/penalties/top-k/top-p + - "processed_logprobs": Return log-softmax of processed logits return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False. return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False. exclude_input_from_output (bool): Controls if output tokens in Result should include the input tokens. Defaults to True. @@ -250,6 +265,9 @@ class SamplingParams: return_perf_metrics: bool = False additional_model_outputs: Optional[List[str]] = None + # Logprobs mode: controls whether to return logprobs before or after sampling modifications + logprobs_mode: LogprobsMode = "raw_logprobs" + # Used in logprobs calculation in TRT flow to drop logits early if user did not explicitly request them. # Can be deprecated after migration to PyTorch backend. _context_logits_auto_enabled: bool = False diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 1bdd2dfbeb5..3c1ca6048c4 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -917,6 +917,155 @@ def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs, backend="pytorch") +@skip_ray +@pytest.mark.parametrize("logprobs_mode", [ + "raw_logits", + "raw_logprobs", + "processed_logits", + "processed_logprobs", +]) +def test_llm_logprobs_modes(logprobs_mode: str): + """ + Test that different logprobs modes work correctly in PyTorch backend. 
+ Validates that: + - logprobs modes return non-positive values + - logits modes return some positive values + - all modes return valid logprobs + """ + llm = LLM_torch( + llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompts = ["The future of AI is"] + sampling_params = SamplingParams( + max_tokens=5, + logprobs=3, + temperature=0.8, + top_k=50, + logprobs_mode=logprobs_mode, + ) + + outputs = list(llm.generate(prompts, sampling_params)) + assert len(outputs) == 1 + + output = outputs[0] + assert len(output.outputs) == 1 + logprobs_list = output.outputs[0].logprobs + + assert logprobs_list is not None + assert len(logprobs_list) > 0 + + # Collect all logprob values + all_values = [] + for token_logprobs in logprobs_list: + for logprob_obj in token_logprobs.values(): + all_values.append(logprob_obj.logprob) + + # Validate based on mode + if "logprobs" in logprobs_mode: + # Should have non-positive values + for val in all_values: + assert val <= 0.0, f"Mode {logprobs_mode} should have non-positive values, got {val}" + + if "logits" in logprobs_mode: + # Should have some positive values + has_positive = any(v > 0 for v in all_values) + assert has_positive, f"Mode {logprobs_mode} should have some positive values" + + +@skip_ray +@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) +def test_llm_raw_vs_processed_logprobs(temperature: float): + """ + Test that raw and processed logprobs differ when temperature != 1.0. + Raw logprobs are computed before temperature scaling. + Processed logprobs are computed after temperature scaling. + """ + llm = LLM_torch( + llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompt = ["The capital of France is"] + + # Get raw logprobs + raw_params = SamplingParams( + max_tokens=3, + logprobs=5, + temperature=temperature, + logprobs_mode="raw_logprobs", + seed=42, + ) + raw_outputs = list(llm.generate(prompt, raw_params)) + + # Get processed logprobs + processed_params = SamplingParams( + max_tokens=3, + logprobs=5, + temperature=temperature, + logprobs_mode="processed_logprobs", + seed=42, + ) + processed_outputs = list(llm.generate(prompt, processed_params)) + + # Compare first token logprobs + raw_first = raw_outputs[0].outputs[0].logprobs[0] + processed_first = processed_outputs[0].outputs[0].logprobs[0] + + # Find common tokens + common_ids = set(raw_first.keys()) & set(processed_first.keys()) + assert len(common_ids) > 0 + + token_id = list(common_ids)[0] + raw_val = raw_first[token_id].logprob + processed_val = processed_first[token_id].logprob + + if temperature != 1.0: + # Values should differ with temperature != 1.0 + assert raw_val != processed_val, \ + f"Raw and processed should differ with temperature={temperature}" + + +@skip_ray +def test_llm_logprobs_mode_backward_compatibility(): + """ + Test that default behavior is backward compatible (raw_logprobs). 
+ """ + llm = LLM_torch( + llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompt = ["Hello world"] + + # Explicit raw_logprobs + explicit_params = SamplingParams( + max_tokens=3, + logprobs=2, + temperature=0.8, + logprobs_mode="raw_logprobs", + seed=123, + ) + explicit_outputs = list(llm.generate(prompt, explicit_params)) + + # Default (should be raw_logprobs) + default_params = SamplingParams( + max_tokens=3, + logprobs=2, + temperature=0.8, + seed=123, + ) + default_outputs = list(llm.generate(prompt, default_params)) + + # Should produce same tokens + explicit_tokens = explicit_outputs[0].outputs[0].token_ids + default_tokens = default_outputs[0].outputs[0].token_ids + + assert explicit_tokens == default_tokens, \ + "Default should match explicit raw_logprobs" + + class TestLlmError: def test_max_num_token_check(self): From aef48e0ecfbdcc5ba31a4aa2bbb8dbc8d290edd9 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Sun, 9 Nov 2025 02:47:31 -0800 Subject: [PATCH 02/11] Add logprobs mode pytest. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tests/unittest/llmapi/test_logprobs_mode.py | 270 ++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 tests/unittest/llmapi/test_logprobs_mode.py diff --git a/tests/unittest/llmapi/test_logprobs_mode.py b/tests/unittest/llmapi/test_logprobs_mode.py new file mode 100644 index 00000000000..fabe056eb5c --- /dev/null +++ b/tests/unittest/llmapi/test_logprobs_mode.py @@ -0,0 +1,270 @@ +import pytest +from utils.llm_data import llm_models_root + +from tensorrt_llm import LLM +from tensorrt_llm.llmapi import KvCacheConfig +from tensorrt_llm.sampling_params import SamplingParams + +MODEL_PATH = llm_models_root() / "DeepSeek-V3-Lite/bf16" + + +@pytest.mark.parametrize( + "logprobs_mode", + [ + "raw_logits", + "raw_logprobs", + "processed_logits", + "processed_logprobs", + ], +) +@pytest.mark.parametrize("temperature", [0.0, 0.8]) +@pytest.mark.parametrize("top_k", [None, 50]) +def test_logprobs_mode_basic(logprobs_mode, temperature, top_k): + llm = LLM( + MODEL_PATH, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + sampling_params = SamplingParams( + max_tokens=5, + logprobs=3, + temperature=temperature, + top_k=top_k, + logprobs_mode=logprobs_mode, + ) + + prompts = ["The future of AI is"] + outputs = llm.generate(prompts, sampling_params=sampling_params) + + assert len(outputs) == 1 + output = outputs[0] + assert len(output.outputs) == 1 + completion = output.outputs[0] + + # Check that logprobs were returned + assert completion.logprobs is not None + assert len(completion.logprobs) > 0 + + # Collect all logprob values + all_logprob_values = [] + for token_logprobs in completion.logprobs: + for token_id, logprob_obj in token_logprobs.items(): + all_logprob_values.append(logprob_obj.logprob) + + # Validate based on mode + if "logprobs" in logprobs_mode: + for val in all_logprob_values: + assert val <= 0.0, ( + f"Logprobs mode {logprobs_mode} should have non-positive values, got {val}" + ) + + if "logits" in logprobs_mode: + has_positive = any(val > 0 for val in all_logprob_values) + assert has_positive, f"Logits mode {logprobs_mode} should have some positive values" + + del llm + + +@pytest.mark.parametrize("temperature", [0.5, 1.0]) +def test_raw_vs_processed_logprobs_difference(temperature): + llm = LLM( + MODEL_PATH, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), + ) + 
+ prompt = ["The capital of France is"] + + # Get raw logprobs + raw_params = SamplingParams( + max_tokens=3, + logprobs=5, + temperature=temperature, + top_k=20, + logprobs_mode="raw_logprobs", + seed=42, # Fix seed for reproducibility + ) + raw_outputs = llm.generate(prompt, sampling_params=raw_params) + + # Get processed logprobs + processed_params = SamplingParams( + max_tokens=3, + logprobs=5, + temperature=temperature, + top_k=20, + logprobs_mode="processed_logprobs", + seed=42, # Same seed + ) + processed_outputs = llm.generate(prompt, sampling_params=processed_params) + + # Extract logprobs from first token + raw_first_token_logprobs = raw_outputs[0].outputs[0].logprobs[0] + processed_first_token_logprobs = processed_outputs[0].outputs[0].logprobs[0] + + # Get a common token ID + common_token_ids = set(raw_first_token_logprobs.keys()) & set( + processed_first_token_logprobs.keys() + ) + assert len(common_token_ids) > 0, "Should have some common token IDs" + + token_id = list(common_token_ids)[0] + raw_val = raw_first_token_logprobs[token_id].logprob + processed_val = processed_first_token_logprobs[token_id].logprob + + if temperature != 1.0: + assert raw_val != processed_val, ( + f"Raw and processed logprobs should differ with temperature={temperature}" + ) + + del llm + + +def test_logprobs_mode_with_greedy_sampling(): + llm = LLM( + MODEL_PATH, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompt = ["Once upon a time"] + + for mode in ["raw_logprobs", "processed_logprobs", "raw_logits", "processed_logits"]: + sampling_params = SamplingParams( + max_tokens=4, + logprobs=3, + temperature=0.0, # Greedy sampling + logprobs_mode=mode, + ) + + outputs = llm.generate(prompt, sampling_params=sampling_params) + + assert len(outputs) == 1 + assert len(outputs[0].outputs[0].logprobs) > 0, ( + f"Mode {mode} should return logprobs even with greedy sampling" + ) + + # Check value ranges + logprob_vals = [ + logprob_obj.logprob + for token_logprobs in outputs[0].outputs[0].logprobs + for logprob_obj in token_logprobs.values() + ] + + if "logprobs" in mode: + assert all(v <= 0.0 for v in logprob_vals), ( + f"Mode {mode} should have non-positive values" + ) + + if "logits" in mode: + assert any(v > 0 for v in logprob_vals), f"Mode {mode} should have some positive values" + + del llm + + +def test_backward_compatibility(): + llm = LLM( + MODEL_PATH, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompt = ["Hello world"] + + # Test with explicit raw_logprobs + explicit_params = SamplingParams( + max_tokens=3, + logprobs=2, + temperature=0.8, + logprobs_mode="raw_logprobs", + seed=123, + ) + explicit_outputs = llm.generate(prompt, sampling_params=explicit_params) + + # Test with default (should be raw_logprobs) + default_params = SamplingParams( + max_tokens=3, + logprobs=2, + temperature=0.8, + seed=123, + ) + default_outputs = llm.generate(prompt, sampling_params=default_params) + + # Results should be identical (same sampled tokens, same logprobs) + explicit_tokens = explicit_outputs[0].outputs[0].token_ids + default_tokens = default_outputs[0].outputs[0].token_ids + + assert explicit_tokens == default_tokens, ( + "Default mode should produce same results as explicit raw_logprobs" + ) + + del llm + + +def test_logprobs_mode_with_top_p(): + """Test that processed modes correctly capture the effect of top-p sampling.""" + llm = LLM( + MODEL_PATH, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompt = ["The weather today is"] 
+ + params = SamplingParams( + max_tokens=2, + logprobs=10, # Request many logprobs to see the effect + temperature=1.0, + top_p=0.5, # Restrict to top 50% probability mass + logprobs_mode="processed_logits", + ) + + outputs = llm.generate(prompt, sampling_params=params) + + # Check that some logits are -inf (masked by top-p) + first_token_logprobs = outputs[0].outputs[0].logprobs[0] + logprob_values = [obj.logprob for obj in first_token_logprobs.values()] + print(f"logprob_values: {logprob_values}") + assert any(val == float("-inf") for val in logprob_values) + + del llm + + +@pytest.mark.parametrize("mode", ["raw_logprobs", "processed_logprobs"]) +def test_prompt_logprobs_with_modes(mode): + """Test that logprobs modes also work for prompt logprobs.""" + llm = LLM( + MODEL_PATH, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompt = ["Hello world, how are you?"] + + params = SamplingParams( + max_tokens=2, + logprobs=3, + prompt_logprobs=3, # Request prompt logprobs + logprobs_mode=mode, + temperature=0.8, + ) + + outputs = llm.generate(prompt, sampling_params=params) + + # Check that prompt logprobs were returned + prompt_logprobs = outputs[0].outputs[0].prompt_logprobs + assert prompt_logprobs is not None + assert len(prompt_logprobs) > 0 + + # Validate values based on mode + for token_logprobs in prompt_logprobs: + if token_logprobs: # Can be None for first token + for logprob_obj in token_logprobs.values(): + if "logprobs" in mode: + assert logprob_obj.logprob <= 0.0, ( + f"Prompt logprobs in mode {mode} should be non-positive" + ) + + del llm + + +if __name__ == "__main__": + # Run a quick smoke test + print("Running smoke test for logprobs modes...") + test_logprobs_mode_basic("raw_logprobs", 0.8, None) + test_logprobs_mode_basic("processed_logprobs", 0.8, None) + print("Smoke test passed!") From 3ae17cd29677d2d044d9d2875840d7ea71d81aed Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Wed, 12 Nov 2025 00:43:48 -0800 Subject: [PATCH 03/11] Fix raw logits return mixed with logprobs issue. 
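
For reviewers, a minimal usage sketch of how a logprobs mode is selected from the LLM API once this wiring is in place. It closely follows the tests added in this series; the model path and memory fraction are illustrative placeholders, not part of this change:

    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import KvCacheConfig
    from tensorrt_llm.sampling_params import SamplingParams

    llm = LLM("/path/to/model",  # placeholder model path
              kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4))

    # logprobs_mode controls whether the returned per-token logprobs reflect
    # the raw model distribution or the distribution after temperature/top-k/top-p.
    params = SamplingParams(max_tokens=5,
                            logprobs=3,
                            temperature=0.8,
                            top_k=50,
                            logprobs_mode="processed_logprobs")

    outputs = llm.generate(["The future of AI is"], sampling_params=params)
    print(outputs[0].outputs[0].logprobs)  # list of {token_id: Logprob} dicts
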
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 2 ++ tensorrt_llm/_torch/pyexecutor/sampler.py | 2 ++ tensorrt_llm/executor/base_worker.py | 5 +++++ tensorrt_llm/executor/executor.py | 3 ++- tests/unittest/llmapi/test_logprobs_mode.py | 10 ++++++++-- 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 01d3f35f876..1587b75ff62 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -797,6 +797,8 @@ def executor_request_to_llm_request( py_multimodal_data=getattr(executor_request, "py_multimodal_data", None), kv_cache_retention_config=executor_request.kv_cache_retention_config) + if hasattr(executor_request, "_logprob_params"): + llm_request._logprob_params = executor_request._logprob_params if child_req_ids: for child_id in child_req_ids: llm_request.create_child_request(child_id) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 6bb8cfc7d5d..5f1eacbec1f 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1890,6 +1890,8 @@ def _process_requests( logprobs_cuda = logits_cuda[logprobs_logit_indices_cuda].to( dtype=torch.float32, non_blocking=True ) + print(f"DEBUG: this is should be raw logits: {logprobs_cuda}") + print(f"DEBUG: logprobs_cuda.shape: {logprobs_cuda.shape}") topk_vals_cuda, topk_indices_cuda = torch.topk( logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 65423e3f8ed..6a2d2d35e8f 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -569,6 +569,11 @@ def _deduce_max_tokens(request: GenerationRequest, if self._is_pytorch_backend and request.scheduling_params is not None: executor_request.py_scheduling_params = request.scheduling_params + if self._is_pytorch_backend: + logprob_params = self._get_logprob_params(request) + if logprob_params is not None: + executor_request._logprob_params = logprob_params + if request.arrival_time is not None: executor_request.py_arrival_time = request.arrival_time diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index e7ab9192ad1..a204238780f 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -234,7 +234,8 @@ def _get_logprob_params( or self.postproc_config.num_postprocess_workers > 0, drop_generation_logits=( not request.sampling_params._need_return_generation_logits) - or self.postproc_config.num_postprocess_workers > 0) + or self.postproc_config.num_postprocess_workers > 0, + logprobs_mode=request.sampling_params.logprobs_mode) return logprob_params diff --git a/tests/unittest/llmapi/test_logprobs_mode.py b/tests/unittest/llmapi/test_logprobs_mode.py index fabe056eb5c..de758c16c1e 100644 --- a/tests/unittest/llmapi/test_logprobs_mode.py +++ b/tests/unittest/llmapi/test_logprobs_mode.py @@ -1,10 +1,14 @@ import pytest from utils.llm_data import llm_models_root +import tensorrt_llm from tensorrt_llm import LLM from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.sampling_params import SamplingParams +print(f"tensorrt_llm.__file__: {tensorrt_llm.__file__}") +# /home/dominicw/.local/lib/python3.12/site-packages/tensorrt_llm + MODEL_PATH = llm_models_root() / 
"DeepSeek-V3-Lite/bf16" @@ -22,11 +26,11 @@ def test_logprobs_mode_basic(logprobs_mode, temperature, top_k): llm = LLM( MODEL_PATH, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), ) sampling_params = SamplingParams( - max_tokens=5, + max_tokens=10, logprobs=3, temperature=temperature, top_k=top_k, @@ -35,6 +39,7 @@ def test_logprobs_mode_basic(logprobs_mode, temperature, top_k): prompts = ["The future of AI is"] outputs = llm.generate(prompts, sampling_params=sampling_params) + print(f"outputs: {outputs}") assert len(outputs) == 1 output = outputs[0] @@ -51,6 +56,7 @@ def test_logprobs_mode_basic(logprobs_mode, temperature, top_k): for token_id, logprob_obj in token_logprobs.items(): all_logprob_values.append(logprob_obj.logprob) + print(f"all_logprob_values: {all_logprob_values}") # Validate based on mode if "logprobs" in logprobs_mode: for val in all_logprob_values: From 0873078fbe409b882ea6df2e56aa07ce36734eae Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Thu, 13 Nov 2025 00:10:56 -0800 Subject: [PATCH 04/11] Add compare raw logits and context and generation logtis return on pytest. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 2 ++ tests/unittest/llmapi/test_logprobs_mode.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 5f1eacbec1f..cb39593173e 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1879,6 +1879,7 @@ def _process_requests( # Compute raw logprobs or logits based on mode if logprobs_mode == "raw_logprobs": + logits_cuda = raw_logits_cuda[:sum_steps] logprobs_cuda = F.log_softmax( logits_cuda[logprobs_logit_indices_cuda].to( dtype=torch.float32, non_blocking=True @@ -1887,6 +1888,7 @@ def _process_requests( ) elif logprobs_mode == "raw_logits": # Return unnormalized logits + logits_cuda = raw_logits_cuda[:sum_steps] logprobs_cuda = logits_cuda[logprobs_logit_indices_cuda].to( dtype=torch.float32, non_blocking=True ) diff --git a/tests/unittest/llmapi/test_logprobs_mode.py b/tests/unittest/llmapi/test_logprobs_mode.py index de758c16c1e..da0466e7140 100644 --- a/tests/unittest/llmapi/test_logprobs_mode.py +++ b/tests/unittest/llmapi/test_logprobs_mode.py @@ -35,6 +35,8 @@ def test_logprobs_mode_basic(logprobs_mode, temperature, top_k): temperature=temperature, top_k=top_k, logprobs_mode=logprobs_mode, + return_context_logits=True, + return_generation_logits=True, ) prompts = ["The future of AI is"] From c1bd781b7889d11f1b87cffcd76f4b8ae9bcc589 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Tue, 18 Nov 2025 01:20:35 -0800 Subject: [PATCH 05/11] Only keep processed logprobs in logprobs mode for now. 
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 139 +++--------- .../_torch/pyexecutor/sampling_utils.py | 128 ++++++----- tensorrt_llm/executor/result.py | 28 +-- tensorrt_llm/sampling_params.py | 19 +- tests/unittest/llmapi/test_llm_pytorch.py | 139 ++++++------ tests/unittest/llmapi/test_logprobs_mode.py | 204 +++++++++--------- 6 files changed, 298 insertions(+), 359 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index cb39593173e..62296df1b89 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -68,6 +68,7 @@ Strategy, UtilsSamplingParams, get_rejected_indices, + process_logits, resolve_sampling_strategy, sample, sample_rejected, @@ -320,8 +321,6 @@ class _BatchedSamplingResult: batch_req_indices: torch.Tensor # Next tokens for all requests: batch_next_tokens_cuda_int: torch.Tensor - # Processed logits after temperature/penalties/top-k/top-p (for processed logprobs modes): - processed_logits_cuda: Optional[torch.Tensor] = None # Helper class for _PackedStepIndexer and _UnpackedStepIndexer, facilitating the @@ -1286,7 +1285,6 @@ def _sample_batched_by_strategy( req_num_steps: torch.Tensor, req_offsets: torch.Tensor, token_dtype: torch.dtype, - return_processed_logits: bool = False, ) -> _BatchedSamplingResult: grouped_requests = _group_requests_by_strategy_key( requests, @@ -1310,8 +1308,6 @@ def _sample_batched_by_strategy( batch_next_tokens_cuda_int = torch.empty( (logits_cuda.size(0),), device=cuda_device, dtype=token_dtype ) - # For processed logprobs: collect processed logits after temperature/top-k/top-p - processed_logits_list = [] if return_processed_logits else None batch_req_idx_offset_start = 0 batch_next_tokens_offset_start = 0 for (strategy_key, speculation_needs_probs), ( @@ -1358,7 +1354,6 @@ def _sample_batched_by_strategy( generator=generator_cuda, return_probs=speculation_needs_probs, group_logit_indices=logit_indices_for_sampler, - return_processed_logits=return_processed_logits, ) ) batch_next_tokens_offset_end = ( @@ -1368,10 +1363,6 @@ def _sample_batched_by_strategy( batch_next_tokens_offset_start:batch_next_tokens_offset_end ].copy_(group_next_tokens_cuda, non_blocking=True) - # Collect processed logits if requested - if return_processed_logits and group_processed_logits_cuda is not None: - processed_logits_list.append(group_processed_logits_cuda) - # Set LlmRequest.py_target_probs if speculation_needs_probs: assert group_softmax_cuda is not None @@ -1397,15 +1388,9 @@ def _sample_batched_by_strategy( if needs_d2t: self._apply_d2t(batch_next_tokens_cuda_int, model_outputs) - # Concatenate processed logits if collected - processed_logits_cuda = None - if return_processed_logits and processed_logits_list: - processed_logits_cuda = torch.cat(processed_logits_list, dim=0) - return _BatchedSamplingResult( batch_req_indices=batch_req_indices, batch_next_tokens_cuda_int=batch_next_tokens_cuda_int, - processed_logits_cuda=processed_logits_cuda, ) def _unbatch_sampling_results( @@ -1720,69 +1705,6 @@ def request_stop_words(request: LlmRequest, new_tokens: torch.Tensor): per_step[step][request_idx] = True return per_step - def _compute_processed_logprobs( - self, - batched_sampling_result: _BatchedSamplingResult, - requests: list[LlmRequest], - logits_cuda_indexer: _PackedStepIndexer, - req_num_steps: torch.Tensor, - logprobs_mode: str, - cuda_device: torch.device, - ) -> None: 
- """ - Compute processed logprobs from logits after temperature/penalties/top-k/top-p. - Updates request objects with the computed logprobs. - """ - processed_logits_cuda = batched_sampling_result.processed_logits_cuda - if processed_logits_cuda is None: - return - - # Get requests that need logprobs - logprobs_req_indices = [ - req_id for req_id, req in enumerate(requests) if req.py_num_logprobs - ] - logprobs_logit_indices = logits_cuda_indexer[logprobs_req_indices] - logprobs_logit_indices_cuda = logprobs_logit_indices.to( - device=cuda_device, non_blocking=True - ) - - # Apply log_softmax if mode is processed_logprobs - if logprobs_mode == "processed_logprobs": - processed_logprobs_cuda = F.log_softmax( - processed_logits_cuda[logprobs_logit_indices_cuda].to( - dtype=torch.float32, non_blocking=True - ), - dim=-1, - ) - elif logprobs_mode == "processed_logits": - processed_logprobs_cuda = processed_logits_cuda[logprobs_logit_indices_cuda].to( - dtype=torch.float32, non_blocking=True - ) - - # Compute top-k - topk_vals_cuda, topk_indices_cuda = torch.topk( - processed_logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 - ) - - # Transfer to CPU - topk_vals = torch.empty_like(topk_vals_cuda, device="cpu", pin_memory=True) - topk_indices = torch.empty_like(topk_indices_cuda, device="cpu", pin_memory=True) - topk_vals.copy_(topk_vals_cuda, non_blocking=True) - topk_indices.copy_(topk_indices_cuda, non_blocking=True) - - # Store in request objects - current_offset = 0 - for req_id, steps in zip( - logprobs_req_indices, req_num_steps[logprobs_req_indices].tolist() - ): - req = requests[req_id] - next_offset = current_offset + steps - req.py_topk_logprobs_vals = topk_vals[current_offset:next_offset, : req.py_num_logprobs] - req.py_topk_logprobs_indices = topk_indices[ - current_offset:next_offset, : req.py_num_logprobs - ] - current_offset = next_offset - @nvtx_range("_process_requests") def _process_requests( self, @@ -1855,9 +1777,9 @@ def _process_requests( # Handle top-k logprobs. This is done outside the sampling loop, # because the returned logprobs are specified to not reflect temperature scaling, # top-k/top-p masking, etc. 
- # Determine logprobs mode from first request that needs logprobs - logprobs_mode = "raw_logprobs" # default + if return_log_probs: + logprobs_mode = None for req in requests: if req.py_num_logprobs: logprob_params = getattr(req, "_logprob_params", None) @@ -1865,9 +1787,6 @@ def _process_requests( logprobs_mode = logprob_params.logprobs_mode break - # Handle raw logprobs modes: compute BEFORE sampling (temperature/penalties/top-k/top-p) - # Raw modes return logprobs/logits before any sampling modifications - if return_log_probs and logprobs_mode.startswith("raw"): assert logits_cuda.dim() == 2, "logits should be 2D" logprobs_req_indices = [ req_id for req_id, req in enumerate(requests) if req.py_num_logprobs @@ -1877,23 +1796,34 @@ def _process_requests( device=logits_cuda.device, non_blocking=True ) - # Compute raw logprobs or logits based on mode - if logprobs_mode == "raw_logprobs": - logits_cuda = raw_logits_cuda[:sum_steps] + # Compute logprobs based on mode + if logprobs_mode == "processed_logprobs": + # Process logits with the same transformations as sampling (temperature, top-k, top-p) + # but without actually sampling + processed_logits_list = [] + for req_id in logprobs_req_indices: + req = requests[req_id] + strategy = _request_strategy(req, vocab_size=logits_cuda.size(1)) + req_logits_indices = logits_cuda_indexer[req_id] + req_logits = logits_cuda[req_logits_indices].to( + dtype=torch.float32, non_blocking=True + ) + # Apply the same processing as sampling would apply + processed_req_logits = process_logits(strategy, req_logits) + processed_logits_list.append(processed_req_logits) + # Concatenate all processed logits + processed_logits_cuda = torch.cat(processed_logits_list, dim=0) + # Apply log_softmax to get log probabilities + logprobs_cuda = F.log_softmax(processed_logits_cuda, dim=-1) + else: + # For raw_logprobs and other modes, use raw logits (before sampling modifications) + raw_logits_for_logprobs = raw_logits_cuda[:sum_steps] logprobs_cuda = F.log_softmax( - logits_cuda[logprobs_logit_indices_cuda].to( + raw_logits_for_logprobs[logprobs_logit_indices_cuda].to( dtype=torch.float32, non_blocking=True ), dim=-1, ) - elif logprobs_mode == "raw_logits": - # Return unnormalized logits - logits_cuda = raw_logits_cuda[:sum_steps] - logprobs_cuda = logits_cuda[logprobs_logit_indices_cuda].to( - dtype=torch.float32, non_blocking=True - ) - print(f"DEBUG: this is should be raw logits: {logprobs_cuda}") - print(f"DEBUG: logprobs_cuda.shape: {logprobs_cuda.shape}") topk_vals_cuda, topk_indices_cuda = torch.topk( logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1 @@ -1919,10 +1849,6 @@ def _process_requests( current_offset = next_offset # Perform sampling in batches - # For processed modes, we need to capture logits after temperature/penalties/top-k/top-p - return_processed_logits_for_logprobs = return_log_probs and logprobs_mode.startswith( - "processed" - ) batched_sampling_result = self._sample_batched_by_strategy( logits_cuda, @@ -1933,21 +1859,8 @@ def _process_requests( req_offsets=req_offsets, req_num_steps=req_num_steps, token_dtype=new_tokens_cuda.dtype, - return_processed_logits=return_processed_logits_for_logprobs, ) - # Handle processed logprobs modes: compute AFTER sampling (temperature/penalties/top-k/top-p) - # Processed modes return logprobs/logits after all sampling modifications - if return_processed_logits_for_logprobs: - self._compute_processed_logprobs( - batched_sampling_result, - requests, - logits_cuda_indexer, - req_num_steps, - 
logprobs_mode, - cuda_device, - ) - # Fill results into output buffers new_tokens_host = self._unbatch_sampling_results( batched_sampling_result, diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index 334947f252c..7b2fc2c81e4 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -95,8 +95,7 @@ def top_k_sampling_batch( top_k: int, temperature: float, generator: Optional[torch.Generator] = None, - return_processed_logits: bool = False, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -104,7 +103,6 @@ def top_k_sampling_batch( temperature=temperature, generator=generator, top_p=1, - return_processed_logits=return_processed_logits, ) @@ -114,7 +112,6 @@ def top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, - return_processed_logits: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( @@ -123,7 +120,6 @@ def top_p_sampling_batch( top_k=logits.size(1), temperature=temperature, generator=generator, - return_processed_logits=return_processed_logits, ) @@ -132,7 +128,6 @@ def temperature_sampling_batch( *, temperature: float, generator: Optional[torch.Generator] = None, - return_processed_logits: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( @@ -141,7 +136,6 @@ def temperature_sampling_batch( top_k=logits.size(1), temperature=temperature, generator=generator, - return_processed_logits=return_processed_logits, ) @@ -152,7 +146,6 @@ def top_k_top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, - return_processed_logits: bool = False, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Perform top-k and top-p sampling. 
@@ -163,17 +156,13 @@ def top_k_top_p_sampling_batch( top_p: Top-p (nucleus sampling) value temperature: Temperature for sampling generator: Optional torch random generator - return_processed_logits: If True, return processed logits after temperature/top-k/top-p Returns: - Tuple of (sampled_tokens, softmax_probs, processed_logits) - processed_logits is None if return_processed_logits is False + Tuple of (sampled_tokens, softmax_probs) """ logits_dim = logits.dim() assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" assert temperature > 0, "non-greedy sampling requires valid temperature" - # Store processed logits if requested (clone before in-place modifications) - processed_logits_to_return = logits.clone() if return_processed_logits else None logits = logits / max(temperature, 1e-5) batch_size, vocab_size = logits.size() @@ -210,51 +199,34 @@ def top_k_top_p_sampling_batch( ) logits = logits.masked_fill(indices_to_remove, float("-inf")) - # Update processed logits if requested (apply same temperature/top-k/top-p) - if return_processed_logits: - processed_logits_to_return = processed_logits_to_return / max(temperature, 1e-5) - if need_top_k: - processed_logits_to_return = torch.where( - processed_logits_to_return < min_values, - torch.full_like(processed_logits_to_return, float("-inf")), - processed_logits_to_return, - ) - if need_top_p: - processed_logits_to_return = processed_logits_to_return.masked_fill( - indices_to_remove, float("-inf") - ) - # compute probability distribution softmax = torch.softmax(logits, dim=-1) # sample from the distribution and generate result of [batch_size, 1] next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1) - return next_tokens, softmax, processed_logits_to_return + return next_tokens, softmax def greedy_search_sampling_batch( logits, *, return_probs: bool = True, - return_processed_logits: bool = False, -) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Perform greedy sampling. Args: logits: Input logits tensor return_probs: If True, return softmax probabilities - return_processed_logits: If True, return processed logits Returns: - Tuple of (sampled_tokens, softmax_probs, processed_logits) + Tuple of (sampled_tokens, softmax_probs) """ next_tokens = torch.argmax(logits, dim=-1) softmax: Optional[torch.Tensor] = None if return_probs: softmax = torch.softmax(logits, dim=-1) - processed_logits = logits.clone() if return_processed_logits else None - return next_tokens, softmax, processed_logits + return next_tokens, softmax def get_rejected_indices( @@ -299,13 +271,77 @@ def sample_rejected( return cast(int, new_token.item()) +def process_logits( + strategy: Strategy, + logits: torch.Tensor, +) -> torch.Tensor: + """ + Process logits according to the specified strategy (temperature, top-k, top-p) + without sampling. Returns processed logits ready for log_softmax. 
+ + Args: + strategy: Sampling strategy tuple (strategy_name, *params) + logits: Input logits tensor [batch_size, vocab_size] + + Returns: + Processed logits tensor [batch_size, vocab_size] + """ + logits = logits.clone() + match strategy: + case ("top_k", top_k, temperature): + logits = logits / max(temperature, 1e-5) + batch_size, vocab_size = logits.size() + if top_k < vocab_size: + values, _ = torch.topk(logits, top_k, dim=-1) + min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) + logits = torch.where( + logits < min_values, torch.full_like(logits, float("-inf")), logits + ) + case ("top_p", top_p, temperature): + logits = logits / max(temperature, 1e-5) + if top_p < 1: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, float("-inf")) + case ("top_k_top_p", top_k, top_p, temperature): + logits = logits / max(temperature, 1e-5) + batch_size, vocab_size = logits.size() + if top_k < vocab_size: + values, _ = torch.topk(logits, top_k, dim=-1) + min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) + logits = torch.where( + logits < min_values, torch.full_like(logits, float("-inf")), logits + ) + if top_p < 1: + sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = 0 + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits = logits.masked_fill(indices_to_remove, float("-inf")) + case ("temperature", temperature): + logits = logits / max(temperature, 1e-5) + case ("greedy", None): + # No processing needed for greedy + pass + return logits + + def sample( strategy: Strategy, logits: torch.Tensor, *, generator: Optional[torch.Generator] = None, return_probs: bool = True, - return_processed_logits: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Sample from logits using the specified strategy. 
@@ -315,10 +351,9 @@ def sample( logits: Input logits tensor generator: Optional random generator return_probs: If True, return softmax probabilities - return_processed_logits: If True, return processed logits after temperature/penalties/top-k/top-p Returns: - Tuple of (sampled_tokens, softmax_probs, processed_logits) + Tuple of (sampled_tokens, softmax_probs) """ match strategy: case ("top_k", top_k, temperature): @@ -329,34 +364,29 @@ def sample( generator=generator, ) case ("top_p", top_p, temperature): - tokens, softmax, processed_logits = top_p_sampling_batch( + tokens, softmax = top_p_sampling_batch( logits, top_p=top_p, generator=generator, temperature=temperature, - return_processed_logits=return_processed_logits, ) case ("top_k_top_p", top_k, top_p, temperature): - tokens, softmax, processed_logits = top_k_top_p_sampling_batch( + tokens, softmax = top_k_top_p_sampling_batch( logits, top_k=top_k, top_p=top_p, temperature=temperature, generator=generator, - return_processed_logits=return_processed_logits, ) case ("temperature", temperature): - tokens, softmax, processed_logits = temperature_sampling_batch( + tokens, softmax = temperature_sampling_batch( logits, temperature=temperature, generator=generator, - return_processed_logits=return_processed_logits, ) case ("greedy", None): - tokens, softmax, processed_logits = greedy_search_sampling_batch( - logits, return_probs=return_probs, return_processed_logits=return_processed_logits - ) - return tokens, softmax, processed_logits + tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs) + return tokens, softmax GenericStrategyKeyType = TypeVar("GenericStrategyKeyType") @@ -400,8 +430,7 @@ def sample_grouped_strategies( group_logit_indices: Optional[torch.Tensor] = None, generator: Optional[torch.Generator] = None, return_probs: bool, - return_processed_logits: bool = False, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if group_logit_indices is None: assert logits.size(0) == len(strategies) else: @@ -414,7 +443,6 @@ def sample_grouped_strategies( logits, generator=generator, return_probs=return_probs, - return_processed_logits=return_processed_logits, ) diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 425ce700a26..f92240e45c0 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -1007,7 +1007,6 @@ def compute_logprobs( context_logits: Optional[torch.Tensor], generation_logits: Optional[torch.Tensor], output_token_ids: Optional[list[int]], - apply_log_softmax: bool = True, ) -> LogProbsResult: """ Compute top-K logprobs from logits when engine doesn't provide them directly. @@ -1023,7 +1022,6 @@ def compute_logprobs( context_logits: Logits for context/prompt tokens generation_logits: Logits for generated tokens output_token_ids: Token IDs of generated outputs - apply_log_softmax: If True, apply log_softmax to logits. If False, logits are already log_softmax'ed or are raw unnormalized values (depending on logprobs_mode) Returns: LogProbsResult, a NamedTuple containing: @@ -1031,10 +1029,8 @@ def compute_logprobs( - generation: Optional[List[Dict[token_id, Logprob]]] logprobs for generated tokens. 
""" - def _topk_logprobs(logits: torch.Tensor, - top_k: int, - tokens: Optional[list[int]], - apply_softmax: bool = True) -> TokenLogprobs: + def _topk_logprobs(logits: torch.Tensor, top_k: int, + tokens: Optional[list[int]]) -> TokenLogprobs: if logits.dim() == 3: # reshape from [1, T, V] to [T, V] logits = logits.squeeze(0) @@ -1044,13 +1040,7 @@ def _topk_logprobs(logits: torch.Tensor, # than output tokens. logits = logits[:len(tokens)] - # Apply log_softmax only if requested (for logprobs modes) - # For logits modes, skip log_softmax to return raw unnormalized values - if apply_softmax: - logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), - dim=-1) - else: - logprobs = logits.to("cuda", dtype=torch.float32) + logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1) topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1) @@ -1076,16 +1066,10 @@ def _topk_logprobs(logits: torch.Tensor, return results prompt_logprobs = _topk_logprobs( - context_logits, - k_prompt_logprobs, - None, - apply_softmax=apply_log_softmax - ) if k_prompt_logprobs and context_logits is not None else None + context_logits, k_prompt_logprobs, + None) if k_prompt_logprobs and context_logits is not None else None generation_logprobs = _topk_logprobs( - generation_logits, - k_logprobs, - output_token_ids, - apply_softmax=apply_log_softmax + generation_logits, k_logprobs, output_token_ids ) if k_logprobs and generation_logits is not None else None return LogProbsResult(prompt=prompt_logprobs, diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 5714c2451ca..bfaab6c5500 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -10,12 +10,10 @@ from tensorrt_llm.bindings import executor as tllme from tensorrt_llm.logger import logger -# Logprobs mode type definition (similar to vLLM) -# - "raw_logits": return raw unnormalized logits before temperature/penalties/top-k/top-p -# - "raw_logprobs": return log-softmax of raw logits -# - "processed_logits": return unnormalized logits after temperature/penalties/top-k/top-p -# - "processed_logprobs": return log-softmax of processed logits -LogprobsMode = Literal["raw_logits", "raw_logprobs", "processed_logits", "processed_logprobs"] +# Logprobs mode: +# - "processed_logprobs": return log-softmax of greedy sampled logits +# TODO: add "return_raw_context_logits" and "return_raw_generation_logits" later +LogprobsMode = Literal["processed_logprobs"] @dataclass(slots=True, kw_only=True) @@ -52,7 +50,7 @@ class LogprobParams(NamedTuple): # Drop the geneation_logits once the logprobs are computed drop_generation_logits: bool = False # Logprobs mode: controls whether to return logprobs before or after sampling modifications - logprobs_mode: LogprobsMode = "raw_logprobs" + logprobs_mode: LogprobsMode = "processed_logprobs" class LogitsProcessor(ABC): @@ -183,11 +181,8 @@ class SamplingParams: logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None. prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None. - logprobs_mode (str): Controls whether to return logprobs before or after sampling modifications. Defaults to "raw_logprobs". + logprobs_mode (str): Controls return logprobs after sampling modifications. Defaults to "processed_logprobs". 
Options: - - "raw_logits": Return raw unnormalized logits before temperature/penalties/top-k/top-p - - "raw_logprobs": Return log-softmax of raw logits - - "processed_logits": Return unnormalized logits after temperature/penalties/top-k/top-p - "processed_logprobs": Return log-softmax of processed logits return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False. return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False. @@ -266,7 +261,7 @@ class SamplingParams: additional_model_outputs: Optional[List[str]] = None # Logprobs mode: controls whether to return logprobs before or after sampling modifications - logprobs_mode: LogprobsMode = "raw_logprobs" + logprobs_mode: LogprobsMode = "processed_logprobs" # Used in logprobs calculation in TRT flow to drop logits early if user did not explicitly request them. # Can be deprecated after migration to PyTorch backend. diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 3c1ca6048c4..19c300b87bd 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -918,21 +918,14 @@ def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs, @skip_ray -@pytest.mark.parametrize("logprobs_mode", [ - "raw_logits", - "raw_logprobs", - "processed_logits", - "processed_logprobs", -]) -def test_llm_logprobs_modes(logprobs_mode: str): +def test_llm_logprobs_modes(): """ - Test that different logprobs modes work correctly in PyTorch backend. + Test that processed_logprobs mode works correctly in PyTorch backend. Validates that: - - logprobs modes return non-positive values - - logits modes return some positive values - - all modes return valid logprobs + - processed_logprobs returns non-positive values (log probabilities) + - all values are valid logprobs """ - llm = LLM_torch( + llm = LLM( llama_model_path, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), ) @@ -943,7 +936,7 @@ def test_llm_logprobs_modes(logprobs_mode: str): logprobs=3, temperature=0.8, top_k=50, - logprobs_mode=logprobs_mode, + logprobs_mode="processed_logprobs", ) outputs = list(llm.generate(prompts, sampling_params)) @@ -962,94 +955,69 @@ def test_llm_logprobs_modes(logprobs_mode: str): for logprob_obj in token_logprobs.values(): all_values.append(logprob_obj.logprob) - # Validate based on mode - if "logprobs" in logprobs_mode: - # Should have non-positive values - for val in all_values: - assert val <= 0.0, f"Mode {logprobs_mode} should have non-positive values, got {val}" - - if "logits" in logprobs_mode: - # Should have some positive values - has_positive = any(v > 0 for v in all_values) - assert has_positive, f"Mode {logprobs_mode} should have some positive values" + # Validate that processed_logprobs returns non-positive values + for val in all_values: + assert val <= 0.0, f"processed_logprobs should have non-positive values, got {val}" @skip_ray @pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) -def test_llm_raw_vs_processed_logprobs(temperature: float): +def test_llm_processed_logprobs_with_temperature(temperature: float): """ - Test that raw and processed logprobs differ when temperature != 1.0. - Raw logprobs are computed before temperature scaling. - Processed logprobs are computed after temperature scaling. + Test that processed_logprobs correctly applies temperature scaling. 
""" - llm = LLM_torch( + llm = LLM( llama_model_path, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), ) prompt = ["The capital of France is"] - # Get raw logprobs - raw_params = SamplingParams( - max_tokens=3, - logprobs=5, - temperature=temperature, - logprobs_mode="raw_logprobs", - seed=42, - ) - raw_outputs = list(llm.generate(prompt, raw_params)) - - # Get processed logprobs - processed_params = SamplingParams( + # Get processed logprobs (after temperature/top-k/top-p modifications) + params = SamplingParams( max_tokens=3, logprobs=5, temperature=temperature, + top_k=20, logprobs_mode="processed_logprobs", seed=42, ) - processed_outputs = list(llm.generate(prompt, processed_params)) + outputs = list(llm.generate(prompt, params)) - # Compare first token logprobs - raw_first = raw_outputs[0].outputs[0].logprobs[0] - processed_first = processed_outputs[0].outputs[0].logprobs[0] + # Check first token logprobs + first_token_logprobs = outputs[0].outputs[0].logprobs[0] + assert len(first_token_logprobs) > 0, "Should have logprobs returned" - # Find common tokens - common_ids = set(raw_first.keys()) & set(processed_first.keys()) - assert len(common_ids) > 0 - - token_id = list(common_ids)[0] - raw_val = raw_first[token_id].logprob - processed_val = processed_first[token_id].logprob - - if temperature != 1.0: - # Values should differ with temperature != 1.0 - assert raw_val != processed_val, \ - f"Raw and processed should differ with temperature={temperature}" + # Validate that all values are non-positive (log probabilities) + for token_id, logprob_obj in first_token_logprobs.items(): + assert logprob_obj.logprob <= 0.0, ( + f"processed_logprobs should have non-positive values, got {logprob_obj.logprob}" + ) @skip_ray def test_llm_logprobs_mode_backward_compatibility(): """ - Test that default behavior is backward compatible (raw_logprobs). + Test that default behavior uses processed_logprobs. """ - llm = LLM_torch( + llm = LLM( llama_model_path, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), ) prompt = ["Hello world"] - # Explicit raw_logprobs + # Explicit processed_logprobs explicit_params = SamplingParams( max_tokens=3, logprobs=2, temperature=0.8, - logprobs_mode="raw_logprobs", + logprobs_mode="processed_logprobs", seed=123, ) explicit_outputs = list(llm.generate(prompt, explicit_params)) - # Default (should be raw_logprobs) + # Default (should be processed_logprobs) default_params = SamplingParams( max_tokens=3, logprobs=2, @@ -1063,7 +1031,54 @@ def test_llm_logprobs_mode_backward_compatibility(): default_tokens = default_outputs[0].outputs[0].token_ids assert explicit_tokens == default_tokens, \ - "Default should match explicit raw_logprobs" + "Default should match explicit processed_logprobs" + + +@skip_ray +def test_llm_processed_logprobs_with_top_k_top_p(): + """ + Test that processed_logprobs correctly applies top-k and top-p filtering. + This verifies the fix for processed_logprobs implementation. 
+ """ + llm = LLM( + llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompt = ["The future of technology"] + + # Test with top_k and top_p to ensure processed_logprobs applies filtering + params = SamplingParams( + max_tokens=2, + logprobs=15, # Request more logprobs than top_k to see filtering + temperature=1.0, + top_k=5, # Only keep top 5 tokens + top_p=0.9, # Restrict to top 90% probability mass + logprobs_mode="processed_logprobs", + ) + + outputs = list(llm.generate(prompt, params)) + assert len(outputs) == 1 + + # Check that logprobs were returned + logprobs_list = outputs[0].outputs[0].logprobs + assert logprobs_list is not None + assert len(logprobs_list) > 0 + + # Check first token logprobs + first_token_logprobs = logprobs_list[0] + logprob_values = [obj.logprob for obj in first_token_logprobs.values()] + + # Should have some -inf values (masked by top-k/top-p) + assert any(val == float("-inf") for val in logprob_values), ( + "processed_logprobs should have -inf values for tokens masked by top-k/top-p" + ) + + # All non-inf values should be non-positive (log probabilities) + non_inf_values = [v for v in logprob_values if v != float("-inf")] + if non_inf_values: + assert all(v <= 0.0 for v in non_inf_values), ( + "processed_logprobs non-inf values should be non-positive") class TestLlmError: diff --git a/tests/unittest/llmapi/test_logprobs_mode.py b/tests/unittest/llmapi/test_logprobs_mode.py index da0466e7140..1036958c1da 100644 --- a/tests/unittest/llmapi/test_logprobs_mode.py +++ b/tests/unittest/llmapi/test_logprobs_mode.py @@ -12,18 +12,9 @@ MODEL_PATH = llm_models_root() / "DeepSeek-V3-Lite/bf16" -@pytest.mark.parametrize( - "logprobs_mode", - [ - "raw_logits", - "raw_logprobs", - "processed_logits", - "processed_logprobs", - ], -) @pytest.mark.parametrize("temperature", [0.0, 0.8]) @pytest.mark.parametrize("top_k", [None, 50]) -def test_logprobs_mode_basic(logprobs_mode, temperature, top_k): +def test_logprobs_mode_basic(temperature, top_k): llm = LLM( MODEL_PATH, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), @@ -34,9 +25,10 @@ def test_logprobs_mode_basic(logprobs_mode, temperature, top_k): logprobs=3, temperature=temperature, top_k=top_k, - logprobs_mode=logprobs_mode, + logprobs_mode="processed_logprobs", return_context_logits=True, return_generation_logits=True, + seed=42, ) prompts = ["The future of AI is"] @@ -59,22 +51,15 @@ def test_logprobs_mode_basic(logprobs_mode, temperature, top_k): all_logprob_values.append(logprob_obj.logprob) print(f"all_logprob_values: {all_logprob_values}") - # Validate based on mode - if "logprobs" in logprobs_mode: - for val in all_logprob_values: - assert val <= 0.0, ( - f"Logprobs mode {logprobs_mode} should have non-positive values, got {val}" - ) - - if "logits" in logprobs_mode: - has_positive = any(val > 0 for val in all_logprob_values) - assert has_positive, f"Logits mode {logprobs_mode} should have some positive values" + # Validate that processed_logprobs returns non-positive values (log probabilities) + for val in all_logprob_values: + assert val <= 0.0, f"processed_logprobs mode should have non-positive values, got {val}" del llm -@pytest.mark.parametrize("temperature", [0.5, 1.0]) -def test_raw_vs_processed_logprobs_difference(temperature): +@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) +def test_processed_logprobs_with_temperature(temperature): llm = LLM( MODEL_PATH, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), @@ -82,45 +67,25 @@ def 
test_raw_vs_processed_logprobs_difference(temperature): prompt = ["The capital of France is"] - # Get raw logprobs - raw_params = SamplingParams( - max_tokens=3, - logprobs=5, - temperature=temperature, - top_k=20, - logprobs_mode="raw_logprobs", - seed=42, # Fix seed for reproducibility - ) - raw_outputs = llm.generate(prompt, sampling_params=raw_params) - - # Get processed logprobs - processed_params = SamplingParams( + # Get processed logprobs (after temperature/top-k/top-p modifications) + params = SamplingParams( max_tokens=3, logprobs=5, temperature=temperature, top_k=20, logprobs_mode="processed_logprobs", - seed=42, # Same seed + seed=42, ) - processed_outputs = llm.generate(prompt, sampling_params=processed_params) + outputs = llm.generate(prompt, sampling_params=params) # Extract logprobs from first token - raw_first_token_logprobs = raw_outputs[0].outputs[0].logprobs[0] - processed_first_token_logprobs = processed_outputs[0].outputs[0].logprobs[0] - - # Get a common token ID - common_token_ids = set(raw_first_token_logprobs.keys()) & set( - processed_first_token_logprobs.keys() - ) - assert len(common_token_ids) > 0, "Should have some common token IDs" - - token_id = list(common_token_ids)[0] - raw_val = raw_first_token_logprobs[token_id].logprob - processed_val = processed_first_token_logprobs[token_id].logprob + first_token_logprobs = outputs[0].outputs[0].logprobs[0] + assert len(first_token_logprobs) > 0, "Should have logprobs returned" - if temperature != 1.0: - assert raw_val != processed_val, ( - f"Raw and processed logprobs should differ with temperature={temperature}" + # Validate that all values are non-positive (log probabilities) + for token_id, logprob_obj in first_token_logprobs.items(): + assert logprob_obj.logprob <= 0.0, ( + f"processed_logprobs should have non-positive values, got {logprob_obj.logprob}" ) del llm @@ -134,35 +99,28 @@ def test_logprobs_mode_with_greedy_sampling(): prompt = ["Once upon a time"] - for mode in ["raw_logprobs", "processed_logprobs", "raw_logits", "processed_logits"]: - sampling_params = SamplingParams( - max_tokens=4, - logprobs=3, - temperature=0.0, # Greedy sampling - logprobs_mode=mode, - ) - - outputs = llm.generate(prompt, sampling_params=sampling_params) + sampling_params = SamplingParams( + max_tokens=4, + logprobs=3, + temperature=0.0, # Greedy sampling + logprobs_mode="processed_logprobs", + ) - assert len(outputs) == 1 - assert len(outputs[0].outputs[0].logprobs) > 0, ( - f"Mode {mode} should return logprobs even with greedy sampling" - ) + outputs = llm.generate(prompt, sampling_params=sampling_params) - # Check value ranges - logprob_vals = [ - logprob_obj.logprob - for token_logprobs in outputs[0].outputs[0].logprobs - for logprob_obj in token_logprobs.values() - ] + assert len(outputs) == 1 + assert len(outputs[0].outputs[0].logprobs) > 0, ( + "processed_logprobs should return logprobs even with greedy sampling" + ) - if "logprobs" in mode: - assert all(v <= 0.0 for v in logprob_vals), ( - f"Mode {mode} should have non-positive values" - ) + # Check value ranges - all should be non-positive (log probabilities) + logprob_vals = [ + logprob_obj.logprob + for token_logprobs in outputs[0].outputs[0].logprobs + for logprob_obj in token_logprobs.values() + ] - if "logits" in mode: - assert any(v > 0 for v in logprob_vals), f"Mode {mode} should have some positive values" + assert all(v <= 0.0 for v in logprob_vals), "processed_logprobs should have non-positive values" del llm @@ -175,17 +133,17 @@ def 
test_backward_compatibility(): prompt = ["Hello world"] - # Test with explicit raw_logprobs + # Test with explicit processed_logprobs explicit_params = SamplingParams( max_tokens=3, logprobs=2, temperature=0.8, - logprobs_mode="raw_logprobs", + logprobs_mode="processed_logprobs", seed=123, ) explicit_outputs = llm.generate(prompt, sampling_params=explicit_params) - # Test with default (should be raw_logprobs) + # Test with default (should be processed_logprobs) default_params = SamplingParams( max_tokens=3, logprobs=2, @@ -199,14 +157,13 @@ def test_backward_compatibility(): default_tokens = default_outputs[0].outputs[0].token_ids assert explicit_tokens == default_tokens, ( - "Default mode should produce same results as explicit raw_logprobs" + "Default mode should produce same results as explicit processed_logprobs" ) del llm def test_logprobs_mode_with_top_p(): - """Test that processed modes correctly capture the effect of top-p sampling.""" llm = LLM( MODEL_PATH, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), @@ -214,28 +171,35 @@ def test_logprobs_mode_with_top_p(): prompt = ["The weather today is"] + # Test processed_logprobs mode (should have -inf for masked tokens after log_softmax) params = SamplingParams( max_tokens=2, logprobs=10, # Request many logprobs to see the effect temperature=1.0, top_p=0.5, # Restrict to top 50% probability mass - logprobs_mode="processed_logits", + logprobs_mode="processed_logprobs", ) outputs = llm.generate(prompt, sampling_params=params) - # Check that some logits are -inf (masked by top-p) + # Check that some logprobs are -inf (masked by top-p) first_token_logprobs = outputs[0].outputs[0].logprobs[0] logprob_values = [obj.logprob for obj in first_token_logprobs.values()] - print(f"logprob_values: {logprob_values}") - assert any(val == float("-inf") for val in logprob_values) + print(f"processed_logprobs values: {logprob_values}") + assert any(val == float("-inf") for val in logprob_values), ( + "processed_logprobs should have -inf values for tokens masked by top-p" + ) + # All non-inf values should be non-positive (log probabilities) + non_inf_values = [v for v in logprob_values if v != float("-inf")] + if non_inf_values: + assert all(v <= 0.0 for v in non_inf_values), ( + "processed_logprobs non-inf values should be non-positive" + ) del llm -@pytest.mark.parametrize("mode", ["raw_logprobs", "processed_logprobs"]) -def test_prompt_logprobs_with_modes(mode): - """Test that logprobs modes also work for prompt logprobs.""" +def test_prompt_logprobs_with_processed_logprobs(): llm = LLM( MODEL_PATH, kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), @@ -247,7 +211,7 @@ def test_prompt_logprobs_with_modes(mode): max_tokens=2, logprobs=3, prompt_logprobs=3, # Request prompt logprobs - logprobs_mode=mode, + logprobs_mode="processed_logprobs", temperature=0.8, ) @@ -258,21 +222,61 @@ def test_prompt_logprobs_with_modes(mode): assert prompt_logprobs is not None assert len(prompt_logprobs) > 0 - # Validate values based on mode + # Validate values - processed_logprobs should be non-positive for token_logprobs in prompt_logprobs: if token_logprobs: # Can be None for first token for logprob_obj in token_logprobs.values(): - if "logprobs" in mode: - assert logprob_obj.logprob <= 0.0, ( - f"Prompt logprobs in mode {mode} should be non-positive" - ) + assert logprob_obj.logprob <= 0.0, ( + "Prompt logprobs in processed_logprobs mode should be non-positive" + ) + + del llm + + +def test_processed_logprobs_with_top_k(): + llm = LLM( + 
MODEL_PATH, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + ) + + prompt = ["The future of technology"] + + # Test with small top_k to ensure filtering is applied + params = SamplingParams( + max_tokens=2, + logprobs=20, # Request more logprobs than top_k to see filtering + temperature=1.0, + top_k=5, # Only keep top 5 tokens + logprobs_mode="processed_logprobs", + ) + + outputs = llm.generate(prompt, sampling_params=params) + + # Check that we have logprobs returned + first_token_logprobs = outputs[0].outputs[0].logprobs[0] + assert len(first_token_logprobs) > 0, "Should have logprobs returned" + + # With top_k=5, we should get at most 5 non-inf logprobs (plus potentially the sampled token) + logprob_values = [obj.logprob for obj in first_token_logprobs.values()] + non_inf_count = sum(1 for v in logprob_values if v != float("-inf")) + + # Should have at most top_k + 1 (top_k + sampled token if not in top_k) + assert non_inf_count <= 6, ( + f"With top_k=5, should have at most 6 non-inf logprobs, got {non_inf_count}" + ) + + # All values should be non-positive (log probabilities) + non_inf_values = [v for v in logprob_values if v != float("-inf")] + if non_inf_values: + assert all(v <= 0.0 for v in non_inf_values), ( + "processed_logprobs values should be non-positive" + ) del llm if __name__ == "__main__": # Run a quick smoke test - print("Running smoke test for logprobs modes...") - test_logprobs_mode_basic("raw_logprobs", 0.8, None) - test_logprobs_mode_basic("processed_logprobs", 0.8, None) - print("Smoke test passed!") + print("Running test for processed_logprobs mode...") + test_logprobs_mode_basic(0.8, None) + print("logprobs mode test passed!") From 666a894c8a8a2f7a156524b1b749d605147d6f00 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Wed, 19 Nov 2025 01:19:57 -0800 Subject: [PATCH 06/11] Update pytest for processed logprobs. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tests/unittest/llmapi/test_llm_pytorch.py | 120 ++++++--- tests/unittest/llmapi/test_logprobs_mode.py | 282 -------------------- 2 files changed, 82 insertions(+), 320 deletions(-) delete mode 100644 tests/unittest/llmapi/test_logprobs_mode.py diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 19c300b87bd..8f3a9f7dee5 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -918,25 +918,31 @@ def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs, @skip_ray -def test_llm_logprobs_modes(): +@pytest.mark.parametrize("temperature", [0.0, 0.8]) +@pytest.mark.parametrize("top_k", [None, 50]) +# temperature: 0.0 is greedy sampling +# top_k: None means all logits +def test_llm_logprobs_modes_basic(temperature, top_k): """ - Test that processed_logprobs mode works correctly in PyTorch backend. + Test processed_logprobs mode works correctly in PyTorch backend. 
Validates that: - processed_logprobs returns non-positive values (log probabilities) - - all values are valid logprobs """ llm = LLM( llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), ) prompts = ["The future of AI is"] sampling_params = SamplingParams( max_tokens=5, logprobs=3, - temperature=0.8, - top_k=50, + temperature=temperature, + top_k=top_k, logprobs_mode="processed_logprobs", + seed=42, + return_context_logits=True, + return_generation_logits=True, ) outputs = list(llm.generate(prompts, sampling_params)) @@ -955,20 +961,22 @@ def test_llm_logprobs_modes(): for logprob_obj in token_logprobs.values(): all_values.append(logprob_obj.logprob) - # Validate that processed_logprobs returns non-positive values + # Validate that processed_logprobs returns non-positive values (log probabilities) for val in all_values: assert val <= 0.0, f"processed_logprobs should have non-positive values, got {val}" + del llm + @skip_ray @pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) -def test_llm_processed_logprobs_with_temperature(temperature: float): +def test_llm_processed_logprobs_with_temperature(temperature): """ Test that processed_logprobs correctly applies temperature scaling. """ llm = LLM( llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), ) prompt = ["The capital of France is"] @@ -994,24 +1002,60 @@ def test_llm_processed_logprobs_with_temperature(temperature: float): f"processed_logprobs should have non-positive values, got {logprob_obj.logprob}" ) + del llm + + +@skip_ray +def test_llm_processed_logprobs_with_greedy_sampling(): + llm = LLM( + llama_model_path, + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), + ) + + prompt = ["Once upon a time"] + + sampling_params = SamplingParams( + max_tokens=10, + logprobs=3, + temperature=0.0, # Greedy sampling + logprobs_mode="processed_logprobs", + ) + + outputs = llm.generate(prompt, sampling_params=sampling_params) + + assert len(outputs) == 1 + assert len(outputs[0].outputs[0].logprobs) > 0, ( + "processed_logprobs should return logprobs even with greedy sampling") + + # Check value ranges - all should be non-positive (log probabilities) + logprob_vals = [ + logprob_obj.logprob for token_logprobs in outputs[0].outputs[0].logprobs + for logprob_obj in token_logprobs.values() + ] + + assert all( + v <= 0.0 for v in + logprob_vals), "processed_logprobs should have non-positive values" + + del llm + @skip_ray def test_llm_logprobs_mode_backward_compatibility(): """ - Test that default behavior uses processed_logprobs. + Test that default behavior without specifying logprobs_mode. 
""" llm = LLM( llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), ) - prompt = ["Hello world"] + prompt = ["once upon a time"] # Explicit processed_logprobs explicit_params = SamplingParams( - max_tokens=3, + max_tokens=10, logprobs=2, - temperature=0.8, logprobs_mode="processed_logprobs", seed=123, ) @@ -1019,9 +1063,8 @@ def test_llm_logprobs_mode_backward_compatibility(): # Default (should be processed_logprobs) default_params = SamplingParams( - max_tokens=3, + max_tokens=10, logprobs=2, - temperature=0.8, seed=123, ) default_outputs = list(llm.generate(prompt, default_params)) @@ -1030,50 +1073,51 @@ def test_llm_logprobs_mode_backward_compatibility(): explicit_tokens = explicit_outputs[0].outputs[0].token_ids default_tokens = default_outputs[0].outputs[0].token_ids - assert explicit_tokens == default_tokens, \ - "Default should match explicit processed_logprobs" + assert explicit_tokens == default_tokens, ( + "Default should match explicit processed_logprobs") + + del llm @skip_ray -def test_llm_processed_logprobs_with_top_k_top_p(): +@pytest.mark.parametrize("top_p", [0.5, 1.0]) +def test_llm_processed_logprobs_with_top_p(top_p): """ Test that processed_logprobs correctly applies top-k and top-p filtering. This verifies the fix for processed_logprobs implementation. """ llm = LLM( llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), + kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), ) prompt = ["The future of technology"] # Test with top_k and top_p to ensure processed_logprobs applies filtering params = SamplingParams( - max_tokens=2, - logprobs=15, # Request more logprobs than top_k to see filtering + max_tokens=5, + logprobs=3, temperature=1.0, - top_k=5, # Only keep top 5 tokens - top_p=0.9, # Restrict to top 90% probability mass + top_p=top_p, logprobs_mode="processed_logprobs", + seed=42, + return_context_logits=True, + return_generation_logits=True, ) outputs = list(llm.generate(prompt, params)) assert len(outputs) == 1 - # Check that logprobs were returned - logprobs_list = outputs[0].outputs[0].logprobs - assert logprobs_list is not None - assert len(logprobs_list) > 0 - - # Check first token logprobs - first_token_logprobs = logprobs_list[0] - logprob_values = [obj.logprob for obj in first_token_logprobs.values()] - - # Should have some -inf values (masked by top-k/top-p) - assert any(val == float("-inf") for val in logprob_values), ( - "processed_logprobs should have -inf values for tokens masked by top-k/top-p" - ) - + # Check that some logprobs are -inf (masked by top-p) across all generated tokens + # Note: With top_p, not every token position will have -inf values in the top-k logprobs + # We need to check across all tokens. 
+ all_logprobs = outputs[0].outputs[0].logprobs + for token_idx, token_logprobs in enumerate(all_logprobs): + logprob_values = [obj.logprob for obj in token_logprobs.values()] + if token_idx == 0: + print(f"First token processed_logprobs values: {logprob_values}") + if any(val == float("-inf") for val in logprob_values): + break # All non-inf values should be non-positive (log probabilities) non_inf_values = [v for v in logprob_values if v != float("-inf")] if non_inf_values: diff --git a/tests/unittest/llmapi/test_logprobs_mode.py b/tests/unittest/llmapi/test_logprobs_mode.py deleted file mode 100644 index 1036958c1da..00000000000 --- a/tests/unittest/llmapi/test_logprobs_mode.py +++ /dev/null @@ -1,282 +0,0 @@ -import pytest -from utils.llm_data import llm_models_root - -import tensorrt_llm -from tensorrt_llm import LLM -from tensorrt_llm.llmapi import KvCacheConfig -from tensorrt_llm.sampling_params import SamplingParams - -print(f"tensorrt_llm.__file__: {tensorrt_llm.__file__}") -# /home/dominicw/.local/lib/python3.12/site-packages/tensorrt_llm - -MODEL_PATH = llm_models_root() / "DeepSeek-V3-Lite/bf16" - - -@pytest.mark.parametrize("temperature", [0.0, 0.8]) -@pytest.mark.parametrize("top_k", [None, 50]) -def test_logprobs_mode_basic(temperature, top_k): - llm = LLM( - MODEL_PATH, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), - ) - - sampling_params = SamplingParams( - max_tokens=10, - logprobs=3, - temperature=temperature, - top_k=top_k, - logprobs_mode="processed_logprobs", - return_context_logits=True, - return_generation_logits=True, - seed=42, - ) - - prompts = ["The future of AI is"] - outputs = llm.generate(prompts, sampling_params=sampling_params) - print(f"outputs: {outputs}") - - assert len(outputs) == 1 - output = outputs[0] - assert len(output.outputs) == 1 - completion = output.outputs[0] - - # Check that logprobs were returned - assert completion.logprobs is not None - assert len(completion.logprobs) > 0 - - # Collect all logprob values - all_logprob_values = [] - for token_logprobs in completion.logprobs: - for token_id, logprob_obj in token_logprobs.items(): - all_logprob_values.append(logprob_obj.logprob) - - print(f"all_logprob_values: {all_logprob_values}") - # Validate that processed_logprobs returns non-positive values (log probabilities) - for val in all_logprob_values: - assert val <= 0.0, f"processed_logprobs mode should have non-positive values, got {val}" - - del llm - - -@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) -def test_processed_logprobs_with_temperature(temperature): - llm = LLM( - MODEL_PATH, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), - ) - - prompt = ["The capital of France is"] - - # Get processed logprobs (after temperature/top-k/top-p modifications) - params = SamplingParams( - max_tokens=3, - logprobs=5, - temperature=temperature, - top_k=20, - logprobs_mode="processed_logprobs", - seed=42, - ) - outputs = llm.generate(prompt, sampling_params=params) - - # Extract logprobs from first token - first_token_logprobs = outputs[0].outputs[0].logprobs[0] - assert len(first_token_logprobs) > 0, "Should have logprobs returned" - - # Validate that all values are non-positive (log probabilities) - for token_id, logprob_obj in first_token_logprobs.items(): - assert logprob_obj.logprob <= 0.0, ( - f"processed_logprobs should have non-positive values, got {logprob_obj.logprob}" - ) - - del llm - - -def test_logprobs_mode_with_greedy_sampling(): - llm = LLM( - MODEL_PATH, - 
kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - ) - - prompt = ["Once upon a time"] - - sampling_params = SamplingParams( - max_tokens=4, - logprobs=3, - temperature=0.0, # Greedy sampling - logprobs_mode="processed_logprobs", - ) - - outputs = llm.generate(prompt, sampling_params=sampling_params) - - assert len(outputs) == 1 - assert len(outputs[0].outputs[0].logprobs) > 0, ( - "processed_logprobs should return logprobs even with greedy sampling" - ) - - # Check value ranges - all should be non-positive (log probabilities) - logprob_vals = [ - logprob_obj.logprob - for token_logprobs in outputs[0].outputs[0].logprobs - for logprob_obj in token_logprobs.values() - ] - - assert all(v <= 0.0 for v in logprob_vals), "processed_logprobs should have non-positive values" - - del llm - - -def test_backward_compatibility(): - llm = LLM( - MODEL_PATH, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - ) - - prompt = ["Hello world"] - - # Test with explicit processed_logprobs - explicit_params = SamplingParams( - max_tokens=3, - logprobs=2, - temperature=0.8, - logprobs_mode="processed_logprobs", - seed=123, - ) - explicit_outputs = llm.generate(prompt, sampling_params=explicit_params) - - # Test with default (should be processed_logprobs) - default_params = SamplingParams( - max_tokens=3, - logprobs=2, - temperature=0.8, - seed=123, - ) - default_outputs = llm.generate(prompt, sampling_params=default_params) - - # Results should be identical (same sampled tokens, same logprobs) - explicit_tokens = explicit_outputs[0].outputs[0].token_ids - default_tokens = default_outputs[0].outputs[0].token_ids - - assert explicit_tokens == default_tokens, ( - "Default mode should produce same results as explicit processed_logprobs" - ) - - del llm - - -def test_logprobs_mode_with_top_p(): - llm = LLM( - MODEL_PATH, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - ) - - prompt = ["The weather today is"] - - # Test processed_logprobs mode (should have -inf for masked tokens after log_softmax) - params = SamplingParams( - max_tokens=2, - logprobs=10, # Request many logprobs to see the effect - temperature=1.0, - top_p=0.5, # Restrict to top 50% probability mass - logprobs_mode="processed_logprobs", - ) - - outputs = llm.generate(prompt, sampling_params=params) - - # Check that some logprobs are -inf (masked by top-p) - first_token_logprobs = outputs[0].outputs[0].logprobs[0] - logprob_values = [obj.logprob for obj in first_token_logprobs.values()] - print(f"processed_logprobs values: {logprob_values}") - assert any(val == float("-inf") for val in logprob_values), ( - "processed_logprobs should have -inf values for tokens masked by top-p" - ) - # All non-inf values should be non-positive (log probabilities) - non_inf_values = [v for v in logprob_values if v != float("-inf")] - if non_inf_values: - assert all(v <= 0.0 for v in non_inf_values), ( - "processed_logprobs non-inf values should be non-positive" - ) - - del llm - - -def test_prompt_logprobs_with_processed_logprobs(): - llm = LLM( - MODEL_PATH, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - ) - - prompt = ["Hello world, how are you?"] - - params = SamplingParams( - max_tokens=2, - logprobs=3, - prompt_logprobs=3, # Request prompt logprobs - logprobs_mode="processed_logprobs", - temperature=0.8, - ) - - outputs = llm.generate(prompt, sampling_params=params) - - # Check that prompt logprobs were returned - prompt_logprobs = outputs[0].outputs[0].prompt_logprobs - assert prompt_logprobs is not None - 
assert len(prompt_logprobs) > 0 - - # Validate values - processed_logprobs should be non-positive - for token_logprobs in prompt_logprobs: - if token_logprobs: # Can be None for first token - for logprob_obj in token_logprobs.values(): - assert logprob_obj.logprob <= 0.0, ( - "Prompt logprobs in processed_logprobs mode should be non-positive" - ) - - del llm - - -def test_processed_logprobs_with_top_k(): - llm = LLM( - MODEL_PATH, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.4), - ) - - prompt = ["The future of technology"] - - # Test with small top_k to ensure filtering is applied - params = SamplingParams( - max_tokens=2, - logprobs=20, # Request more logprobs than top_k to see filtering - temperature=1.0, - top_k=5, # Only keep top 5 tokens - logprobs_mode="processed_logprobs", - ) - - outputs = llm.generate(prompt, sampling_params=params) - - # Check that we have logprobs returned - first_token_logprobs = outputs[0].outputs[0].logprobs[0] - assert len(first_token_logprobs) > 0, "Should have logprobs returned" - - # With top_k=5, we should get at most 5 non-inf logprobs (plus potentially the sampled token) - logprob_values = [obj.logprob for obj in first_token_logprobs.values()] - non_inf_count = sum(1 for v in logprob_values if v != float("-inf")) - - # Should have at most top_k + 1 (top_k + sampled token if not in top_k) - assert non_inf_count <= 6, ( - f"With top_k=5, should have at most 6 non-inf logprobs, got {non_inf_count}" - ) - - # All values should be non-positive (log probabilities) - non_inf_values = [v for v in logprob_values if v != float("-inf")] - if non_inf_values: - assert all(v <= 0.0 for v in non_inf_values), ( - "processed_logprobs values should be non-positive" - ) - - del llm - - -if __name__ == "__main__": - # Run a quick smoke test - print("Running test for processed_logprobs mode...") - test_logprobs_mode_basic(0.8, None) - print("logprobs mode test passed!") From e9986fc451446529bd554363dca2fe86c3e755c3 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Wed, 19 Nov 2025 01:59:32 -0800 Subject: [PATCH 07/11] Fix and clean up. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 1 - .../_torch/pyexecutor/sampling_utils.py | 29 ++----------------- 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 62296df1b89..f896a21ef57 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1849,7 +1849,6 @@ def _process_requests( current_offset = next_offset # Perform sampling in batches - batched_sampling_result = self._sample_batched_by_strategy( logits_cuda, requests, diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index 7b2fc2c81e4..0058278afd7 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -112,7 +112,7 @@ def top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. 
return top_k_top_p_sampling_batch( logits, @@ -128,7 +128,7 @@ def temperature_sampling_batch( *, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -146,20 +146,7 @@ def top_k_top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Perform top-k and top-p sampling. - - Args: - logits: Input logits tensor [batch_size, vocab_size] - top_k: Top-k value - top_p: Top-p (nucleus sampling) value - temperature: Temperature for sampling - generator: Optional torch random generator - - Returns: - Tuple of (sampled_tokens, softmax_probs) - """ +) -> tuple[torch.Tensor, torch.Tensor]: logits_dim = logits.dim() assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" assert temperature > 0, "non-greedy sampling requires valid temperature" @@ -212,16 +199,6 @@ def greedy_search_sampling_batch( *, return_probs: bool = True, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Perform greedy sampling. - - Args: - logits: Input logits tensor - return_probs: If True, return softmax probabilities - - Returns: - Tuple of (sampled_tokens, softmax_probs) - """ next_tokens = torch.argmax(logits, dim=-1) softmax: Optional[torch.Tensor] = None if return_probs: From 0b900bd2227a0977510cf1f8b3e817a43a549b02 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Thu, 20 Nov 2025 01:26:16 -0800 Subject: [PATCH 08/11] Use function sample instead of process logits and change based on review comment. 
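
With this change, sample() returns (tokens, softmax, logprobs), where logprobs
is the log-softmax of the already-processed logits (temperature scaling plus
top-k/top-p masking applied), so callers no longer need the separate
process_logits() helper. A minimal call-pattern sketch for reviewers (the
strategy tuple values below are illustrative, not taken from the diff):

    strategy = ("top_k_top_p", 50, 0.9, 0.8)   # (name, top_k, top_p, temperature)
    tokens, softmax, logprobs = sample(strategy, logits, return_probs=True)
    # logprobs is log_softmax of the masked logits, so filtered tokens are -inf
    topk_vals, topk_ids = torch.topk(logprobs, k=3, dim=-1)
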
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 6 +- tensorrt_llm/_torch/pyexecutor/sampler.py | 17 +-- .../_torch/pyexecutor/sampling_utils.py | 106 ++++---------- tests/unittest/llmapi/test_llm_pytorch.py | 131 ++---------------- 4 files changed, 44 insertions(+), 216 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 1587b75ff62..b65123cd786 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -438,6 +438,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): """LlmRequest wraps `bindings.internal.batch_manager.LlmRequest` but detour some features to Python implementation""" + _logprob_params = None + def __init__( self, *args, @@ -797,8 +799,8 @@ def executor_request_to_llm_request( py_multimodal_data=getattr(executor_request, "py_multimodal_data", None), kv_cache_retention_config=executor_request.kv_cache_retention_config) - if hasattr(executor_request, "_logprob_params"): - llm_request._logprob_params = executor_request._logprob_params + llm_request._logprob_params = getattr(executor_request, "_logprob_params", + None) if child_req_ids: for child_id in child_req_ids: llm_request.create_child_request(child_id) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index f896a21ef57..2f586794b86 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -68,7 +68,6 @@ Strategy, UtilsSamplingParams, get_rejected_indices, - process_logits, resolve_sampling_strategy, sample, sample_rejected, @@ -975,7 +974,7 @@ def _process_draft_tokens_rejection_sampling( else _request_strategy(request, vocab_size=2**31) ) generator = self.get_generator(request.py_draft_logits.device) - _, draft_probs = sample( + _, draft_probs, _ = sample( draft_sampling_strategy, request.py_draft_logits, generator=generator, @@ -1800,7 +1799,7 @@ def _process_requests( if logprobs_mode == "processed_logprobs": # Process logits with the same transformations as sampling (temperature, top-k, top-p) # but without actually sampling - processed_logits_list = [] + logprobs_list = [] for req_id in logprobs_req_indices: req = requests[req_id] strategy = _request_strategy(req, vocab_size=logits_cuda.size(1)) @@ -1808,13 +1807,11 @@ def _process_requests( req_logits = logits_cuda[req_logits_indices].to( dtype=torch.float32, non_blocking=True ) - # Apply the same processing as sampling would apply - processed_req_logits = process_logits(strategy, req_logits) - processed_logits_list.append(processed_req_logits) - # Concatenate all processed logits - processed_logits_cuda = torch.cat(processed_logits_list, dim=0) - # Apply log_softmax to get log probabilities - logprobs_cuda = F.log_softmax(processed_logits_cuda, dim=-1) + # Use sample() to get processed logprobs (after temperature, top-k, top-p applied) + _, _, req_logprobs = sample(strategy, req_logits, return_probs=True) + logprobs_list.append(req_logprobs) + # Concatenate all logprobs + logprobs_cuda = torch.cat(logprobs_list, dim=0) else: # For raw_logprobs and other modes, use raw logits (before sampling modifications) raw_logits_for_logprobs = raw_logits_cuda[:sum_steps] diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py index 0058278afd7..148286c09e2 100644 --- 
a/tensorrt_llm/_torch/pyexecutor/sampling_utils.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils.py @@ -24,6 +24,7 @@ from typing import Generic, Literal, Optional, TypeAlias, TypeVar, cast import torch +import torch.nn.functional as F from tensorrt_llm.sampling_params import SamplingParams @@ -95,7 +96,7 @@ def top_k_sampling_batch( top_k: int, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -112,7 +113,7 @@ def top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -128,7 +129,7 @@ def temperature_sampling_batch( *, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # NB: To be replaced by a more efficient implementation. return top_k_top_p_sampling_batch( logits, @@ -146,7 +147,7 @@ def top_k_top_p_sampling_batch( top_p: float, temperature: float, generator: Optional[torch.Generator] = None, -) -> tuple[torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: logits_dim = logits.dim() assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" assert temperature > 0, "non-greedy sampling requires valid temperature" @@ -189,21 +190,26 @@ def top_k_top_p_sampling_batch( # compute probability distribution softmax = torch.softmax(logits, dim=-1) + # compute log probabilities + logprobs = F.log_softmax(logits, dim=-1) + # sample from the distribution and generate result of [batch_size, 1] next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1) - return next_tokens, softmax + return next_tokens, softmax, logprobs def greedy_search_sampling_batch( logits, *, return_probs: bool = True, -) -> tuple[torch.Tensor, Optional[torch.Tensor]]: +) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: next_tokens = torch.argmax(logits, dim=-1) softmax: Optional[torch.Tensor] = None + logprobs: Optional[torch.Tensor] = None if return_probs: softmax = torch.softmax(logits, dim=-1) - return next_tokens, softmax + logprobs = F.log_softmax(logits, dim=-1) + return next_tokens, softmax, logprobs def get_rejected_indices( @@ -248,71 +254,6 @@ def sample_rejected( return cast(int, new_token.item()) -def process_logits( - strategy: Strategy, - logits: torch.Tensor, -) -> torch.Tensor: - """ - Process logits according to the specified strategy (temperature, top-k, top-p) - without sampling. Returns processed logits ready for log_softmax. 
- - Args: - strategy: Sampling strategy tuple (strategy_name, *params) - logits: Input logits tensor [batch_size, vocab_size] - - Returns: - Processed logits tensor [batch_size, vocab_size] - """ - logits = logits.clone() - match strategy: - case ("top_k", top_k, temperature): - logits = logits / max(temperature, 1e-5) - batch_size, vocab_size = logits.size() - if top_k < vocab_size: - values, _ = torch.topk(logits, top_k, dim=-1) - min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) - logits = torch.where( - logits < min_values, torch.full_like(logits, float("-inf")), logits - ) - case ("top_p", top_p, temperature): - logits = logits / max(temperature, 1e-5) - if top_p < 1: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() - sorted_indices_to_remove[:, 0] = 0 - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove - ) - logits = logits.masked_fill(indices_to_remove, float("-inf")) - case ("top_k_top_p", top_k, top_p, temperature): - logits = logits / max(temperature, 1e-5) - batch_size, vocab_size = logits.size() - if top_k < vocab_size: - values, _ = torch.topk(logits, top_k, dim=-1) - min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) - logits = torch.where( - logits < min_values, torch.full_like(logits, float("-inf")), logits - ) - if top_p < 1: - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() - sorted_indices_to_remove[:, 0] = 0 - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove - ) - logits = logits.masked_fill(indices_to_remove, float("-inf")) - case ("temperature", temperature): - logits = logits / max(temperature, 1e-5) - case ("greedy", None): - # No processing needed for greedy - pass - return logits - - def sample( strategy: Strategy, logits: torch.Tensor, @@ -327,28 +268,28 @@ def sample( strategy: Sampling strategy tuple (strategy_name, *params) logits: Input logits tensor generator: Optional random generator - return_probs: If True, return softmax probabilities + return_probs: If True, return softmax probabilities and log probabilities Returns: - Tuple of (sampled_tokens, softmax_probs) + Tuple of (sampled_tokens, softmax_probs, logprobs) """ match strategy: case ("top_k", top_k, temperature): - tokens, softmax = top_k_sampling_batch( + tokens, softmax, logprobs = top_k_sampling_batch( logits, top_k=top_k, temperature=temperature, generator=generator, ) case ("top_p", top_p, temperature): - tokens, softmax = top_p_sampling_batch( + tokens, softmax, logprobs = top_p_sampling_batch( logits, top_p=top_p, generator=generator, temperature=temperature, ) case ("top_k_top_p", top_k, top_p, temperature): - tokens, softmax = top_k_top_p_sampling_batch( + tokens, softmax, logprobs = top_k_top_p_sampling_batch( logits, top_k=top_k, top_p=top_p, @@ -356,14 +297,16 @@ def sample( generator=generator, ) case ("temperature", temperature): - tokens, softmax = temperature_sampling_batch( + tokens, softmax, logprobs = temperature_sampling_batch( logits, temperature=temperature, 
generator=generator, ) case ("greedy", None): - tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs) - return tokens, softmax + tokens, softmax, logprobs = greedy_search_sampling_batch( + logits, return_probs=return_probs + ) + return tokens, softmax, logprobs GenericStrategyKeyType = TypeVar("GenericStrategyKeyType") @@ -415,12 +358,13 @@ def sample_grouped_strategies( assert all(strategy == group_key for strategy in strategies), "group must be consistent" - return sample( + tokens, probs, _ = sample( group_key, logits, generator=generator, return_probs=return_probs, ) + return tokens, probs class _AcceptSyncCompute: diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index 8f3a9f7dee5..df0d4c7b184 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -918,11 +918,13 @@ def test_llm_return_logprobs_streaming(prompt_logprobs, logprobs, @skip_ray -@pytest.mark.parametrize("temperature", [0.0, 0.8]) -@pytest.mark.parametrize("top_k", [None, 50]) -# temperature: 0.0 is greedy sampling -# top_k: None means all logits -def test_llm_logprobs_modes_basic(temperature, top_k): +@pytest.mark.parametrize("temperature", [None, 0.8, 1.0]) +@pytest.mark.parametrize("top_k", [None, 10, 0]) +@pytest.mark.parametrize("top_p", [None, 0.5, 1.0]) +# temperature: 0.0 is greedy sampling and will be covered by below test +# top_k: 0 means all logits +# top_p: 1 means no top-p filtering +def test_llm_logprobs_modes_basic(temperature, top_k, top_p): """ Test processed_logprobs mode works correctly in PyTorch backend. Validates that: @@ -939,6 +941,7 @@ def test_llm_logprobs_modes_basic(temperature, top_k): logprobs=3, temperature=temperature, top_k=top_k, + top_p=top_p, logprobs_mode="processed_logprobs", seed=42, return_context_logits=True, @@ -968,78 +971,6 @@ def test_llm_logprobs_modes_basic(temperature, top_k): del llm -@skip_ray -@pytest.mark.parametrize("temperature", [0.5, 1.0, 1.5]) -def test_llm_processed_logprobs_with_temperature(temperature): - """ - Test that processed_logprobs correctly applies temperature scaling. 
- """ - llm = LLM( - llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), - ) - - prompt = ["The capital of France is"] - - # Get processed logprobs (after temperature/top-k/top-p modifications) - params = SamplingParams( - max_tokens=3, - logprobs=5, - temperature=temperature, - top_k=20, - logprobs_mode="processed_logprobs", - seed=42, - ) - outputs = list(llm.generate(prompt, params)) - - # Check first token logprobs - first_token_logprobs = outputs[0].outputs[0].logprobs[0] - assert len(first_token_logprobs) > 0, "Should have logprobs returned" - - # Validate that all values are non-positive (log probabilities) - for token_id, logprob_obj in first_token_logprobs.items(): - assert logprob_obj.logprob <= 0.0, ( - f"processed_logprobs should have non-positive values, got {logprob_obj.logprob}" - ) - - del llm - - -@skip_ray -def test_llm_processed_logprobs_with_greedy_sampling(): - llm = LLM( - llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), - ) - - prompt = ["Once upon a time"] - - sampling_params = SamplingParams( - max_tokens=10, - logprobs=3, - temperature=0.0, # Greedy sampling - logprobs_mode="processed_logprobs", - ) - - outputs = llm.generate(prompt, sampling_params=sampling_params) - - assert len(outputs) == 1 - assert len(outputs[0].outputs[0].logprobs) > 0, ( - "processed_logprobs should return logprobs even with greedy sampling") - - # Check value ranges - all should be non-positive (log probabilities) - logprob_vals = [ - logprob_obj.logprob for token_logprobs in outputs[0].outputs[0].logprobs - for logprob_obj in token_logprobs.values() - ] - - assert all( - v <= 0.0 for v in - logprob_vals), "processed_logprobs should have non-positive values" - - del llm - - @skip_ray def test_llm_logprobs_mode_backward_compatibility(): """ @@ -1079,52 +1010,6 @@ def test_llm_logprobs_mode_backward_compatibility(): del llm -@skip_ray -@pytest.mark.parametrize("top_p", [0.5, 1.0]) -def test_llm_processed_logprobs_with_top_p(top_p): - """ - Test that processed_logprobs correctly applies top-k and top-p filtering. - This verifies the fix for processed_logprobs implementation. - """ - llm = LLM( - llama_model_path, - kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7), - ) - - prompt = ["The future of technology"] - - # Test with top_k and top_p to ensure processed_logprobs applies filtering - params = SamplingParams( - max_tokens=5, - logprobs=3, - temperature=1.0, - top_p=top_p, - logprobs_mode="processed_logprobs", - seed=42, - return_context_logits=True, - return_generation_logits=True, - ) - - outputs = list(llm.generate(prompt, params)) - assert len(outputs) == 1 - - # Check that some logprobs are -inf (masked by top-p) across all generated tokens - # Note: With top_p, not every token position will have -inf values in the top-k logprobs - # We need to check across all tokens. 
- all_logprobs = outputs[0].outputs[0].logprobs - for token_idx, token_logprobs in enumerate(all_logprobs): - logprob_values = [obj.logprob for obj in token_logprobs.values()] - if token_idx == 0: - print(f"First token processed_logprobs values: {logprob_values}") - if any(val == float("-inf") for val in logprob_values): - break - # All non-inf values should be non-positive (log probabilities) - non_inf_values = [v for v in logprob_values if v != float("-inf")] - if non_inf_values: - assert all(v <= 0.0 for v in non_inf_values), ( - "processed_logprobs non-inf values should be non-positive") - - class TestLlmError: def test_max_num_token_check(self): From defa0a34350e61efe98d47db72b466925b42af28 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Thu, 20 Nov 2025 18:24:24 -0800 Subject: [PATCH 09/11] Fix auto deploy demo logprobs return issue. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/_torch/auto_deploy/shim/demollm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index d0b93c2bd19..0f49bcb9aef 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -235,11 +235,11 @@ def _sample( logits_shape = logits.shape logits = logits.view(-1, logits_shape[-1]) # sampling_batch expects 2D logits if isinstance(sampling_params.top_k, int) and sampling_params.top_k > 1: - idx_next, probs = top_k_sampling_batch( + idx_next, probs, _ = top_k_sampling_batch( logits, top_k=sampling_params.top_k, temperature=1.0 ) else: - idx_next, probs = greedy_search_sampling_batch(logits) + idx_next, probs, _ = greedy_search_sampling_batch(logits) idx_next = idx_next.view(logits_shape[:-1]) return idx_next, probs From d0cbc3c4e8dd94dea11c119a2261dda29d1d7057 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Fri, 21 Nov 2025 22:18:12 -0800 Subject: [PATCH 10/11] Fix greedy search sampling batch return value mismatch in flashinfer sampling. 
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py index f8ce56a1672..734d320e079 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py +++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py @@ -137,7 +137,7 @@ def _sample_greedy_with_probs( group_logit_indices: Optional[torch.Tensor], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: probs = self._prepare_probs_with_temperature(logits, group_logit_indices, None) - new_tokens, _ = greedy_search_sampling_batch(probs, return_probs=False) + new_tokens, _, _ = greedy_search_sampling_batch(probs, return_probs=False) return new_tokens, probs @classmethod @@ -370,7 +370,8 @@ def sample( ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: if group_logit_indices is not None: logits = logits[group_logit_indices] - return greedy_search_sampling_batch(logits, return_probs=False) + tokens, probs, _ = greedy_search_sampling_batch(logits, return_probs=False) + return tokens, probs class TopKTopPSampleOnly(StrategyImplSampleOnly): def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.Tensor): From 7a5fe6461dbb1f2c81757d17d7dbf714eecaaef8 Mon Sep 17 00:00:00 2001 From: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> Date: Sat, 22 Nov 2025 04:59:11 -0800 Subject: [PATCH 11/11] Add logprobs mode in api stability test. Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com> --- tensorrt_llm/sampling_params.py | 2 +- tests/unittest/api_stability/references/sampling_params.yaml | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index bfaab6c5500..9ba52e464f5 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -181,7 +181,7 @@ class SamplingParams: logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None. prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None. - logprobs_mode (str): Controls return logprobs after sampling modifications. Defaults to "processed_logprobs". + logprobs_mode (Literal['processed_logprobs']): Controls return logprobs after sampling modifications. Defaults to "processed_logprobs". Options: - "processed_logprobs": Return log-softmax of processed logits return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False. diff --git a/tests/unittest/api_stability/references/sampling_params.yaml b/tests/unittest/api_stability/references/sampling_params.yaml index d6b3e6156e3..948aee0b654 100644 --- a/tests/unittest/api_stability/references/sampling_params.yaml +++ b/tests/unittest/api_stability/references/sampling_params.yaml @@ -15,5 +15,8 @@ methods: prompt_ignore_length: annotation: Optional[int] default: null + logprobs_mode: + annotation: Literal['processed_logprobs'] + default: processed_logprobs return_annotation: None properties: {}
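Taken together, the series above introduces a single "processed_logprobs" mode: the returned top-k logprobs are the log-softmax of the logits after temperature, penalties, and top-k/top-p masking have been applied, and the batched sampling helpers now return a (tokens, softmax, logprobs) triple that the flashinfer and auto-deploy call sites unpack. Below is a minimal, untested sketch of how the new option is exercised through the LLM API, mirroring the tests in this series; the model path and prompt are placeholders, and the import locations are assumptions based on the LLM API those tests use.

# Sketch only: request processed logprobs via SamplingParams (placeholder model path).
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

llm = LLM(
    "/path/to/llama",  # placeholder; the unit tests use llama_model_path
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.7),
)

params = SamplingParams(
    max_tokens=5,
    logprobs=3,                          # return the top-3 logprobs per generated token
    temperature=0.8,
    top_k=10,
    top_p=0.5,
    logprobs_mode="processed_logprobs",  # log-softmax of the post-temperature/top-k/top-p logits
    seed=42,
)

for output in llm.generate(["The capital of France is"], params):
    for token_logprobs in output.outputs[0].logprobs:
        # Every finite value is a log-probability (<= 0); tokens masked out by
        # top-k/top-p appear as -inf, matching the assertions in the tests above.
        assert all(obj.logprob <= 0.0 for obj in token_logprobs.values())

Since the SamplingParams annotation is Literal['processed_logprobs'] and the API-stability reference records the same default, "processed_logprobs" is currently both the only accepted value and the default, so existing callers that never set logprobs_mode keep their previous behavior.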