4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/shim/demollm.py
@@ -235,11 +235,11 @@ def _sample(
logits_shape = logits.shape
logits = logits.view(-1, logits_shape[-1]) # sampling_batch expects 2D logits
if isinstance(sampling_params.top_k, int) and sampling_params.top_k > 1:
idx_next, probs = top_k_sampling_batch(
idx_next, probs, _ = top_k_sampling_batch(
logits, top_k=sampling_params.top_k, temperature=1.0
)
else:
idx_next, probs = greedy_search_sampling_batch(logits)
idx_next, probs, _ = greedy_search_sampling_batch(logits)
idx_next = idx_next.view(logits_shape[:-1])
return idx_next, probs

4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -438,6 +438,8 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
"""LlmRequest wraps `bindings.internal.batch_manager.LlmRequest`
but detour some features to Python implementation"""

_logprob_params = None

def __init__(
self,
*args,
@@ -797,6 +799,8 @@ def executor_request_to_llm_request(
py_multimodal_data=getattr(executor_request, "py_multimodal_data",
None),
kv_cache_retention_config=executor_request.kv_cache_retention_config)
llm_request._logprob_params = getattr(executor_request, "_logprob_params",
None)
if child_req_ids:
for child_id in child_req_ids:
llm_request.create_child_request(child_id)
43 changes: 38 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -974,7 +974,7 @@ def _process_draft_tokens_rejection_sampling(
else _request_strategy(request, vocab_size=2**31)
)
generator = self.get_generator(request.py_draft_logits.device)
_, draft_probs = sample(
_, draft_probs, _ = sample(
draft_sampling_strategy,
request.py_draft_logits,
generator=generator,
@@ -1776,7 +1776,16 @@ def _process_requests(
# Handle top-k logprobs. This is done outside the sampling loop because,
# except in "processed_logprobs" mode, the returned logprobs are specified
# to not reflect temperature scaling, top-k/top-p masking, etc.

if return_log_probs:
logprobs_mode = None
for req in requests:
if req.py_num_logprobs:
logprob_params = getattr(req, "_logprob_params", None)
if logprob_params and hasattr(logprob_params, "logprobs_mode"):
logprobs_mode = logprob_params.logprobs_mode
break

assert logits_cuda.dim() == 2, "logits should be 2D"
logprobs_req_indices = [
req_id for req_id, req in enumerate(requests) if req.py_num_logprobs
@@ -1785,10 +1794,34 @@
logprobs_logit_indices_cuda = logprobs_logit_indices.to(
device=logits_cuda.device, non_blocking=True
)
logprobs_cuda = F.log_softmax(
logits_cuda[logprobs_logit_indices_cuda].to(dtype=torch.float32, non_blocking=True),
dim=-1,
)

# Compute logprobs based on mode
if logprobs_mode == "processed_logprobs":
# Process logits with the same transformations as sampling (temperature, top-k, top-p)
# but without actually sampling
logprobs_list = []
for req_id in logprobs_req_indices:
req = requests[req_id]
strategy = _request_strategy(req, vocab_size=logits_cuda.size(1))
req_logits_indices = logits_cuda_indexer[req_id]
req_logits = logits_cuda[req_logits_indices].to(
dtype=torch.float32, non_blocking=True
)
# Use sample() to get processed logprobs (after temperature, top-k, top-p applied)
_, _, req_logprobs = sample(strategy, req_logits, return_probs=True)
logprobs_list.append(req_logprobs)
# Concatenate all logprobs
logprobs_cuda = torch.cat(logprobs_list, dim=0)
else:
# For raw_logprobs and other modes, use raw logits (before sampling modifications)
raw_logits_for_logprobs = raw_logits_cuda[:sum_steps]
logprobs_cuda = F.log_softmax(
raw_logits_for_logprobs[logprobs_logit_indices_cuda].to(
dtype=torch.float32, non_blocking=True
),
dim=-1,
)

topk_vals_cuda, topk_indices_cuda = torch.topk(
logprobs_cuda, k=max(req.py_num_logprobs for req in requests), dim=-1
)
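The two branches added to `_process_requests` differ only in which logits reach the final log-softmax: in `processed_logprobs` mode the request's sampling transformations are applied first, while the default path works from the raw logits. A minimal, self-contained sketch of that distinction (the helper names and the simple top-k mask below are illustrative, not the sampler's actual code):

```python
import torch
import torch.nn.functional as F


def raw_logprobs(logits: torch.Tensor) -> torch.Tensor:
    # Default path: log-softmax over the unmodified logits.
    return F.log_softmax(logits.float(), dim=-1)


def processed_logprobs(logits: torch.Tensor, temperature: float = 1.0, top_k: int = 0) -> torch.Tensor:
    # "processed_logprobs" path: apply the same transformations a sampler would
    # (temperature scaling and, here, a simple top-k mask), then log-softmax, so
    # masked-out tokens come back as -inf and the kept tokens are renormalized.
    scaled = logits.float() / max(temperature, 1e-6)
    if 0 < top_k < scaled.size(-1):
        kth = torch.topk(scaled, k=top_k, dim=-1).values[..., -1, None]
        scaled = scaled.masked_fill(scaled < kth, float("-inf"))
    return F.log_softmax(scaled, dim=-1)


logits = torch.randn(2, 8)                           # [num_tokens, vocab_size]
print(raw_logprobs(logits)[0])                       # all entries finite
print(processed_logprobs(logits, 0.7, top_k=3)[0])   # only 3 finite entries per row
```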
51 changes: 36 additions & 15 deletions tensorrt_llm/_torch/pyexecutor/sampling_utils.py
@@ -24,6 +24,7 @@
from typing import Generic, Literal, Optional, TypeAlias, TypeVar, cast

import torch
import torch.nn.functional as F

from tensorrt_llm.sampling_params import SamplingParams

@@ -95,7 +96,7 @@ def top_k_sampling_batch(
top_k: int,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
@@ -112,7 +113,7 @@ def top_p_sampling_batch(
top_p: float,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
@@ -128,7 +129,7 @@ def temperature_sampling_batch(
*,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# NB: To be replaced by a more efficient implementation.
return top_k_top_p_sampling_batch(
logits,
@@ -146,7 +147,7 @@ def top_k_top_p_sampling_batch(
top_p: float,
temperature: float,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
logits_dim = logits.dim()
assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
assert temperature > 0, "non-greedy sampling requires valid temperature"
@@ -189,21 +190,26 @@ def top_k_top_p_sampling_batch(
# compute probability distribution
softmax = torch.softmax(logits, dim=-1)

# compute log probabilities
logprobs = F.log_softmax(logits, dim=-1)

# sample from the distribution and generate result of [batch_size, 1]
next_tokens = torch.multinomial(softmax, num_samples=1, generator=generator).squeeze(-1)
return next_tokens, softmax
return next_tokens, softmax, logprobs


def greedy_search_sampling_batch(
logits,
*,
return_probs: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
next_tokens = torch.argmax(logits, dim=-1)
softmax: Optional[torch.Tensor] = None
logprobs: Optional[torch.Tensor] = None
if return_probs:
softmax = torch.softmax(logits, dim=-1)
return next_tokens, softmax
logprobs = F.log_softmax(logits, dim=-1)
return next_tokens, softmax, logprobs


def get_rejected_indices(
@@ -254,39 +260,53 @@ def sample(
*,
generator: Optional[torch.Generator] = None,
return_probs: bool = True,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Sample from logits using the specified strategy.

Args:
strategy: Sampling strategy tuple (strategy_name, *params)
logits: Input logits tensor
generator: Optional random generator
return_probs: If True, return softmax probabilities and log probabilities

Returns:
Tuple of (sampled_tokens, softmax_probs, logprobs)
"""
match strategy:
case ("top_k", top_k, temperature):
tokens, softmax = top_k_sampling_batch(
tokens, softmax, logprobs = top_k_sampling_batch(
logits,
top_k=top_k,
temperature=temperature,
generator=generator,
)
case ("top_p", top_p, temperature):
tokens, softmax = top_p_sampling_batch(
tokens, softmax, logprobs = top_p_sampling_batch(
logits,
top_p=top_p,
generator=generator,
temperature=temperature,
)
case ("top_k_top_p", top_k, top_p, temperature):
tokens, softmax = top_k_top_p_sampling_batch(
tokens, softmax, logprobs = top_k_top_p_sampling_batch(
logits,
top_k=top_k,
top_p=top_p,
temperature=temperature,
generator=generator,
)
case ("temperature", temperature):
tokens, softmax = temperature_sampling_batch(
tokens, softmax, logprobs = temperature_sampling_batch(
logits,
temperature=temperature,
generator=generator,
)
case ("greedy", None):
tokens, softmax = greedy_search_sampling_batch(logits, return_probs=return_probs)
return tokens, softmax
tokens, softmax, logprobs = greedy_search_sampling_batch(
logits, return_probs=return_probs
)
return tokens, softmax, logprobs


GenericStrategyKeyType = TypeVar("GenericStrategyKeyType")
@@ -338,12 +358,13 @@ def sample_grouped_strategies(

assert all(strategy == group_key for strategy in strategies), "group must be consistent"

return sample(
tokens, probs, _ = sample(
group_key,
logits,
generator=generator,
return_probs=return_probs,
)
return tokens, probs


class _AcceptSyncCompute:
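With this change every sampling helper in `sampling_utils.py` returns a `(tokens, softmax, logprobs)` triple. A hypothetical caller, assuming the module path shown in this diff and the strategy tuples matched inside `sample()`:

```python
import torch

from tensorrt_llm._torch.pyexecutor.sampling_utils import sample

logits = torch.randn(4, 320)            # [batch_size, vocab_size], toy vocabulary
gen = torch.Generator().manual_seed(0)

# Every strategy arm now yields a 3-tuple; callers that only need tokens and
# probabilities simply discard the third element, as the updated call sites do.
tokens, probs, logprobs = sample(("top_k", 50, 0.8), logits, generator=gen)
greedy_tokens, _, greedy_logprobs = sample(("greedy", None), logits)

assert logprobs.shape == logits.shape   # per-position log-softmax over the vocab
```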
5 changes: 5 additions & 0 deletions tensorrt_llm/executor/base_worker.py
@@ -569,6 +569,11 @@ def _deduce_max_tokens(request: GenerationRequest,
if self._is_pytorch_backend and request.scheduling_params is not None:
executor_request.py_scheduling_params = request.scheduling_params

if self._is_pytorch_backend:
logprob_params = self._get_logprob_params(request)
if logprob_params is not None:
executor_request._logprob_params = logprob_params

if request.arrival_time is not None:
executor_request.py_arrival_time = request.arrival_time

3 changes: 2 additions & 1 deletion tensorrt_llm/executor/executor.py
@@ -234,7 +234,8 @@ def _get_logprob_params(
or self.postproc_config.num_postprocess_workers > 0,
drop_generation_logits=(
not request.sampling_params._need_return_generation_logits)
or self.postproc_config.num_postprocess_workers > 0)
or self.postproc_config.num_postprocess_workers > 0,
logprobs_mode=request.sampling_params.logprobs_mode)

return logprob_params

8 changes: 8 additions & 0 deletions tensorrt_llm/executor/result.py
@@ -1016,6 +1016,13 @@ def compute_logprobs(
- Generation logprobs (from generation_logits, TRT backend): used when backend doesn't compute them in sampler (e.g., TRT).
- Generation logprobs (PyTorch backend): not used; computed in sampler, not here.

Args:
k_prompt_logprobs: Number of top logprobs to return for prompt tokens
k_logprobs: Number of top logprobs to return for generated tokens
context_logits: Logits for context/prompt tokens
generation_logits: Logits for generated tokens
output_token_ids: Token IDs of generated outputs

Returns:
LogProbsResult, a NamedTuple containing:
- prompt: Optional[List[Dict[token_id, Logprob]]] logprobs for prompt tokens.
@@ -1034,6 +1041,7 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
logits = logits[:len(tokens)]

logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1)

topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)

results: TokenLogprobs = []
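`_topk_logprobs` follows the usual pattern: cast to float32, take a log-softmax over the vocabulary, then keep the k largest entries per position. A plain-dict sketch of that pattern (the real helper wraps each value in a `Logprob` and returns `TokenLogprobs`, which is not reproduced here):

```python
import torch
import torch.nn.functional as F


def topk_logprobs_sketch(logits: torch.Tensor, top_k: int) -> list[dict[int, float]]:
    # Log-softmax in float32 for numerical stability, then top-k per position,
    # mapping token id -> logprob for each generated position.
    logprobs = F.log_softmax(logits.to(dtype=torch.float32), dim=-1)
    topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)
    return [
        {int(tok): float(val) for tok, val in zip(idx_row, val_row)}
        for idx_row, val_row in zip(topk_indices, topk_vals)
    ]


# e.g. 3 generated positions over a toy vocabulary of 10 tokens, top-2 logprobs each
print(topk_logprobs_sketch(torch.randn(3, 10), top_k=2))
```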
15 changes: 14 additions & 1 deletion tensorrt_llm/sampling_params.py
@@ -2,14 +2,19 @@
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, fields
from typing import List, NamedTuple, Optional, Tuple, Union
from typing import List, Literal, NamedTuple, Optional, Tuple, Union

import torch
from pydantic import BaseModel

from tensorrt_llm.bindings import executor as tllme
from tensorrt_llm.logger import logger

# Logprobs mode:
# - "processed_logprobs": return log-softmax of greedy sampled logits
# TODO: add "return_raw_context_logits" and "return_raw_generation_logits" later
LogprobsMode = Literal["processed_logprobs"]


@dataclass(slots=True, kw_only=True)
class GuidedDecodingParams:
@@ -44,6 +49,8 @@ class LogprobParams(NamedTuple):
drop_context_logits: bool = False
# Drop the generation_logits once the logprobs are computed
drop_generation_logits: bool = False
# Logprobs mode: controls whether to return logprobs before or after sampling modifications
logprobs_mode: LogprobsMode = "processed_logprobs"


class LogitsProcessor(ABC):
@@ -174,6 +181,9 @@ class SamplingParams:

logprobs (int, optional): Number of log probabilities to return per output token. Defaults to None.
prompt_logprobs (int, optional): Number of log probabilities to return per prompt token. Defaults to None.
logprobs_mode (str): Controls whether returned logprobs reflect sampling modifications (temperature, top-k, top-p). Defaults to "processed_logprobs".
Options:
- "processed_logprobs": Return log-softmax of processed logits
return_context_logits (bool): Controls if Result should contain the context logits. Defaults to False.
return_generation_logits (bool): Controls if Result should contain the generation logits. Defaults to False.
exclude_input_from_output (bool): Controls if output tokens in Result should include the input tokens. Defaults to True.
@@ -250,6 +260,9 @@ class SamplingParams:
return_perf_metrics: bool = False
additional_model_outputs: Optional[List[str]] = None

# Logprobs mode: controls whether to return logprobs before or after sampling modifications
logprobs_mode: LogprobsMode = "processed_logprobs"

# Used in logprobs calculation in TRT flow to drop logits early if user did not explicitly request them.
# Can be deprecated after migration to PyTorch backend.
_context_logits_auto_enabled: bool = False
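From the user side, the new field is just another `SamplingParams` knob. A hypothetical request configuration, assuming the standard fields documented above alongside the new `logprobs_mode`:

```python
from tensorrt_llm.sampling_params import SamplingParams

# Ask for the top-2 logprobs per generated token, computed from the processed
# logits (currently the only LogprobsMode value defined).
params = SamplingParams(
    temperature=0.7,
    top_k=50,
    logprobs=2,
    logprobs_mode="processed_logprobs",
)
```

In this mode the logprobs attached to each output token reflect the same temperature/top-k/top-p processing the sampler applied, rather than a log-softmax of the raw logits.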