Skip to content

Commit 30e7b79

Browse files
Fix and clean up.
Signed-off-by: Wangshanshan <[email protected]>
1 parent 70247fb commit 30e7b79

File tree

2 files changed: +3 −27 lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -1849,7 +1849,6 @@ def _process_requests(
             current_offset = next_offset

         # Perform sampling in batches
-
         batched_sampling_result = self._sample_batched_by_strategy(
             logits_cuda,
             requests,
```

tensorrt_llm/_torch/pyexecutor/sampling_utils.py

Lines changed: 3 additions & 26 deletions
```diff
@@ -112,7 +112,7 @@ def top_p_sampling_batch(
     top_p: float,
     temperature: float,
     generator: Optional[torch.Generator] = None,
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     # NB: To be replaced by a more efficient implementation.
     return top_k_top_p_sampling_batch(
         logits,
```
```diff
@@ -128,7 +128,7 @@ def temperature_sampling_batch(
     *,
     temperature: float,
     generator: Optional[torch.Generator] = None,
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+) -> tuple[torch.Tensor, torch.Tensor]:
     # NB: To be replaced by a more efficient implementation.
     return top_k_top_p_sampling_batch(
         logits,
```
```diff
@@ -146,20 +146,7 @@ def top_k_top_p_sampling_batch(
     top_p: float,
     temperature: float,
     generator: Optional[torch.Generator] = None,
-) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-    """
-    Perform top-k and top-p sampling.
-
-    Args:
-        logits: Input logits tensor [batch_size, vocab_size]
-        top_k: Top-k value
-        top_p: Top-p (nucleus sampling) value
-        temperature: Temperature for sampling
-        generator: Optional torch random generator
-
-    Returns:
-        Tuple of (sampled_tokens, softmax_probs)
-    """
+) -> tuple[torch.Tensor, torch.Tensor]:
     logits_dim = logits.dim()
     assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
     assert temperature > 0, "non-greedy sampling requires valid temperature"
```
```diff
@@ -212,16 +199,6 @@ def greedy_search_sampling_batch(
     *,
     return_probs: bool = True,
 ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
-    """
-    Perform greedy sampling.
-
-    Args:
-        logits: Input logits tensor
-        return_probs: If True, return softmax probabilities
-
-    Returns:
-        Tuple of (sampled_tokens, softmax_probs)
-    """
     next_tokens = torch.argmax(logits, dim=-1)
     softmax: Optional[torch.Tensor] = None
     if return_probs:
```

0 commit comments

Comments
 (0)