diff --git a/pyproject.toml b/pyproject.toml index 3416e3809cd..6e281577e56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ extend_skip_glob = [ "tests/unittest/_torch/modeling/test_modeling_mistral.py", "tests/unittest/_torch/modeling/test_modeling_pixtral.py", "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py", + "tests/unittest/_torch/sampler/test_torch_sampler.py", ] [tool.yapf] @@ -65,6 +66,7 @@ ignore_patterns = [ "tests/unittest/_torch/modeling/test_modeling_mistral.py", "tests/unittest/_torch/modeling/test_modeling_pixtral.py", "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py", + "tests/unittest/_torch/sampler/test_torch_sampler.py", ] [tool.codespell] @@ -144,6 +146,7 @@ include = [ "tests/unittest/_torch/modeling/test_modeling_mistral.py", "tests/unittest/_torch/modeling/test_modeling_pixtral.py", "tests/unittest/_torch/models/checkpoints/hf/test_weight_loader.py", + "tests/unittest/_torch/sampler/test_torch_sampler.py", ] exclude = [ "**3rdparty/**", diff --git a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py index f7c89773c5a..3799628268c 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/demollm.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/demollm.py @@ -234,8 +234,10 @@ def _sample( ) -> Tuple[torch.Tensor, torch.Tensor]: logits_shape = logits.shape logits = logits.view(-1, logits_shape[-1]) # sampling_batch expects 2D logits - if isinstance(sampling_params.top_k, int): - idx_next, probs = top_k_sampling_batch(logits, sampling_params.top_k) + if isinstance(sampling_params.top_k, int) and sampling_params.top_k > 1: + idx_next, probs = top_k_sampling_batch( + logits, top_k=sampling_params.top_k, temperature=1.0 + ) else: idx_next, probs = greedy_search_sampling_batch(logits) idx_next = idx_next.view(logits_shape[:-1]) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index f50f789ca75..fdd95760d30 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -6,7 +6,7 @@ from collections.abc import Iterable from dataclasses import dataclass from itertools import repeat -from typing import Any, List, Literal, Optional, cast +from typing import Any, List, Literal, Optional, TypeVar, cast import torch import torch.nn.functional as F @@ -26,6 +26,7 @@ GptDecoderBatched) from tensorrt_llm.executor.result import Logprob from tensorrt_llm.mapping import Mapping +from tensorrt_llm.sampling_params import SamplingParams from ..speculative.spec_tree_manager import SpecTreeManager from .finish_reason import FinishedState @@ -195,84 +196,75 @@ def is_generation_model(self) -> bool: def top_k_sampling_batch( logits, - top_k=50, - generator: Optional[torch.Generator] = None + *, + top_k: int, + temperature: float, + generator: Optional[torch.Generator] = None, ) -> tuple[torch.Tensor, torch.Tensor]: - logits_dim = logits.dim() - if logits_dim == 1: - logits = logits.unsqueeze(0) - # logits should be 2D :[batch_size, vocab_size] - batch_size, vocab_size = logits.size() + # NB: To be replaced by a more efficient implementation. 
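+    # top-k-only sampling is expressed as combined top-k/top-p sampling with
+    # top_p=1, so the nucleus (top-p) mask in top_k_top_p_sampling_batch becomes
+    # a no-op (need_top_p is False) and only temperature scaling plus the top-k
+    # mask are applied.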
+ return top_k_top_p_sampling_batch( + logits, + top_k=top_k, + temperature=temperature, + generator=generator, + top_p=1, + ) - # get first top_k logits of each sample and their indices - if top_k > 0: - values, indices = torch.topk(logits, top_k, dim=-1) - min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) - # set the logits who is less than first top_k logits to -inf - logits = torch.where(logits < min_values, - torch.full_like(logits, float('-inf')), logits) +def top_p_sampling_batch( + logits: torch.Tensor, + *, + top_p: float, + temperature: float, + generator: Optional[torch.Generator] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + # NB: To be replaced by a more efficient implementation. + return top_k_top_p_sampling_batch( + logits, + top_p=top_p, + top_k=logits.size(1), + temperature=temperature, + generator=generator, + ) - # compute probability distribution - softmax = torch.softmax(logits, dim=-1) - # sample from the distribution and generate result of [batch_size, 1] - next_tokens = torch.multinomial(softmax, num_samples=1, - generator=generator).squeeze(-1) - return next_tokens, softmax +def temperature_sampling_batch( + logits: torch.Tensor, + *, + temperature: float, + generator: Optional[torch.Generator] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + # NB: To be replaced by a more efficient implementation. + return top_k_top_p_sampling_batch( + logits, + top_p=1, + top_k=logits.size(1), + temperature=temperature, + generator=generator, + ) -def top_p_sampling_batch( +def top_k_top_p_sampling_batch( logits: torch.Tensor, *, - top_p: float = 0.9, - temperature: float = 1.0, + top_k: int, + top_p: float, + temperature: float, generator: Optional[torch.Generator] = None ) -> tuple[torch.Tensor, torch.Tensor]: logits_dim = logits.dim() assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" + assert temperature > 0, "non-greedy sampling requires valid temperature" + logits = logits / max(temperature, 1e-5) + batch_size, vocab_size = logits.size() - if temperature != 0: - logits = logits / max(temperature, 1e-5) - - # sort the logits of each sample in descending order - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - - # compute cumulative probability distribution of each sample - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), - dim=-1) - # get the location of top_p - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() - sorted_indices_to_remove[:, 0] = 0 - - # set the logits to -inf whose is outside top_p - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove) - logits = logits.masked_fill(indices_to_remove, float('-inf')) - - # compute probability distribution - softmax = torch.softmax(logits, dim=-1) - - # sample from the distribution and generate result of [batch_size, 1] - next_tokens = torch.multinomial(softmax, num_samples=1, - generator=generator).squeeze(-1) - return next_tokens, softmax - + assert top_k > 1, "non-greedy sampling requires valid top_k" + need_top_k = top_k < vocab_size + assert top_p > 0, "non-greedy sampling requires valid top_p" + need_top_p = top_p < 1 -def top_k_top_p_sampling_batch(logits: torch.Tensor, - *, - top_k: int, - top_p: float, - temperature: float = 1.0, - generator: Optional[torch.Generator] = None): - logits_dim = logits.dim() - assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" - if temperature != 
0: - logits = logits / max(temperature, 1e-5) - batch_size, vocab_size = logits.size() - # get first top_k logits of each sample and their indices - if top_k > 0: + # top-K: mask out logits not belonging to the top-K for each sample + if need_top_k: values, _ = torch.topk(logits, top_k, dim=-1) min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) @@ -280,21 +272,28 @@ def top_k_top_p_sampling_batch(logits: torch.Tensor, logits = torch.where(logits < min_values, torch.full_like(logits, float('-inf')), logits) - sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) - - # compute cumulative probability distribution of each sample - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), - dim=-1) - - # get the location of top_p - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() - sorted_indices_to_remove[:, 0] = 0 - - # set the logits to -inf whose is outside top_p - indices_to_remove = sorted_indices_to_remove.scatter( - 1, sorted_indices, sorted_indices_to_remove) - logits = logits.masked_fill(indices_to_remove, float('-inf')) + # top-p: mask out logits outside the nucleus + if need_top_p: + sorted_logits, sorted_indices = torch.sort(logits, + descending=True, + dim=-1) + + # compute cumulative probability distribution of each sample + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), + dim=-1) + + # get the location of top_p + # NB: Currently selecting the smallest index with cumulative_probs > top_p. + # Thus, top_p -> 0 resembles greedy; agreement requires torch.sort(..., stable=True). + sorted_indices_to_remove = cumulative_probs > top_p + sorted_indices_to_remove[:, + 1:] = sorted_indices_to_remove[:, :-1].clone() + sorted_indices_to_remove[:, 0] = 0 + + # set the logits to -inf for token indices outside top_p + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove) + logits = logits.masked_fill(indices_to_remove, float('-inf')) # compute probability distribution softmax = torch.softmax(logits, dim=-1) @@ -359,48 +358,100 @@ def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor, return new_token -TopK = tuple[Literal["top_k"], int] +TemperatureOnly = tuple[Literal["temperature"], float] +TopK = tuple[Literal["top_k"], int, float] TopP = tuple[Literal["top_p"], float, float] TopKTopP = tuple[Literal["top_k_top_p"], int, float, float] Greedy = tuple[Literal["greedy"], None] GREEDY: Greedy = ("greedy", None) -Strategy = TopK | TopP | Greedy | TopKTopP - - -def _request_strategy(request: LlmRequest) -> Strategy: - # top_p and top_K with temperature=0.0 reduces to greedy - # sampling - temperature = request.sampling_config.temperature - if temperature is not None: - temperature = temperature[0] - if temperature == 0.0: - return GREEDY - - if request.sampling_config.top_k is not None and len( - request.sampling_config.top_k - ) > 0 and request.sampling_config.top_p is not None and len( - request.sampling_config.top_p) > 0: - return ("top_k_top_p", request.sampling_config.top_k[0], - request.sampling_config.top_p[0], temperature) - elif request.sampling_config.top_p is not None and len( - request.sampling_config.top_p) > 0: - top_p = request.sampling_config.top_p[0] - return ("top_p", top_p, temperature) - elif request.sampling_config.top_k is not None and len( - request.sampling_config.top_k) > 0: - return ("top_k", request.sampling_config.top_k[0]) - else: +Strategy = TopK 
| TopP | Greedy | TopKTopP | TemperatureOnly + +T = TypeVar('T') + + +# Due to tensorrt_llm::runtime::SamplingConfig using vectors, params +# in LlmRequest.sampling_params are either None or single-element lists. +# This helper method simplifies code using such params. +def _unwrap_singleton(p: Optional[List[T]]) -> Optional[T]: + if p is None: + return None + t, = p + return t + + +@dataclass(frozen=True, kw_only=True) +class TorchSamplerSamplingParams: + """Subset of tensorrt_llm::runtime::SamplingConfig handled by TorchSampler.""" + temperature: Optional[float] + top_p: Optional[float] + top_k: Optional[int] + + +def _request_get_sampling_params( + request: LlmRequest) -> TorchSamplerSamplingParams: + sampling_config = request.sampling_config + temperature = _unwrap_singleton( + cast(Optional[List[float]], sampling_config.temperature)) + top_p = _unwrap_singleton(cast(Optional[List[float]], + sampling_config.top_p)) + top_k = _unwrap_singleton(cast(Optional[List[int]], sampling_config.top_k)) + + return TorchSamplerSamplingParams( + temperature=temperature, + top_p=top_p, + top_k=top_k, + ) + + +def _request_strategy(request: LlmRequest, *, vocab_size: int) -> Strategy: + # The semantics are specified in the doc-string of SamplingParams + + params = _request_get_sampling_params(request) + temperature = params.temperature + top_p = params.top_p + top_k = params.top_k + + if SamplingParams.params_imply_greedy_decoding( + temperature=temperature, + top_p=top_p, + top_k=top_k, + ): return GREEDY + # --- resolving default values + # NB: not greedy, hence temperature != 0 if specified + temperature = temperature or 1.0 + + # NB: not greedy, hence top_p != 0 if specified + top_p = top_p or 1.0 + # NB: not greedy, hence top_k != 1 if specified + # (0 and vocab_size are equivalent) + top_k = top_k or vocab_size + + assert top_k > 1, "non-greedy sampling requires valid top_k" + need_top_k = top_k < vocab_size + assert top_p > 0, "non-greedy sampling requires valid top_p" + need_top_p = top_p < 1 + + if need_top_p: + if need_top_k: + return ("top_k_top_p", top_k, top_p, temperature) + return ("top_p", top_p, temperature) + if need_top_k: + return ("top_k", top_k, temperature) + return ("temperature", temperature) + def _group_requests_by_sampling_strategy( requests: Iterable[LlmRequest], *, - pin_memory: bool = False) -> dict[Strategy, torch.Tensor]: + pin_memory: bool = False, + vocab_size: int) -> dict[Strategy, torch.Tensor]: # NB: Client code relies on request indices in returned torch.Tensor being sorted. 
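+    # Illustrative example: three requests resolving to GREEDY, ("top_k", 8, 1.0)
+    # and GREEDY again yield {("greedy", None): tensor([0, 2]),
+    # ("top_k", 8, 1.0): tensor([1])}; enumeration order keeps each per-strategy
+    # index tensor sorted.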
strategy_dict: dict[Strategy, list[int]] = defaultdict(list) for req_index, req in enumerate(requests): - strategy_dict[_request_strategy(req)].append(req_index) + strategy_dict[_request_strategy( + req, vocab_size=vocab_size)].append(req_index) return { strategy: torch.tensor(indices, pin_memory=pin_memory, @@ -418,23 +469,32 @@ def sample( ) -> tuple[torch.Tensor, torch.Tensor]: filter_softmax = True match strategy: - case ("top_k", top_k): - tokens, softmax = top_k_sampling_batch(logits, top_k, generator) + case ("top_k", top_k, temperature): + tokens, softmax = top_k_sampling_batch(logits, + top_k=top_k, + temperature=temperature, + generator=generator) case ("top_p", top_p, temperature): tokens, softmax = top_p_sampling_batch( logits, top_p=top_p, generator=generator, - **(dict(temperature=temperature) - if temperature is not None else dict())) + temperature=temperature, + ) case ("top_k_top_p", top_k, top_p, temperature): tokens, softmax = top_k_top_p_sampling_batch( logits, top_k=top_k, top_p=top_p, + temperature=temperature, + generator=generator, + ) + case ("temperature", temperature): + tokens, softmax = temperature_sampling_batch( + logits, + temperature=temperature, generator=generator, - **(dict(temperature=temperature) - if temperature is not None else dict())) + ) case ("greedy", None): tokens, softmax = greedy_search_sampling_batch( logits, softmax_indices=softmax_indices) @@ -1070,7 +1130,11 @@ def _tree_sampling_batch(self, requests: list[LlmRequest], def _process_draft_tokens_rejection_sampling( self, request: LlmRequest, new_tokens: list[list[list[int]]], new_tokens_tensor: torch.Tensor) -> int: - sampling_strategy = _request_strategy(request) + # FIXME: Passing a dummy vocab_size could result in unnecessary + # filtering of vocab_size logits, out of vocab_size in + # total. The 'sample' below should generally be avoided + # by retaining the draft_probs during drafting (TRTLLM-7772). + sampling_strategy = _request_strategy(request, vocab_size=2**31) generator = self.get_generator(request.py_draft_logits.device) _, draft_probs = sample(sampling_strategy, request.py_draft_logits, @@ -1329,7 +1393,7 @@ def _sample_batched_by_strategy( dim=-1) requests_by_strategy = _group_requests_by_sampling_strategy( - requests, pin_memory=True) + requests, pin_memory=True, vocab_size=logits_cuda.size(1)) generator_cuda = self.get_generator(cuda_device) # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens @@ -1706,8 +1770,17 @@ def _process_requests( @override def should_provide_draft_probs(self, request: LlmRequest) -> bool: + params = _request_get_sampling_params(request) + temperature = params.temperature + top_p = params.top_p + top_k = params.top_k + # Do not request draft probs when sampling is greedy. - return _request_strategy(request) is not GREEDY + return not SamplingParams.params_imply_greedy_decoding( + temperature=temperature, + top_p=top_p, + top_k=top_k, + ) class Algorithms: diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 1e194b25ad9..a1818a282ae 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -154,13 +154,25 @@ class SamplingParams: best_of (int, optional): Number of sequences to consider for best output. Defaults to None. use_beam_search (bool): Whether to use beam search. Defaults to False. - top_k (int, optional): Controls number of logits to sample from. None means using C++ runtime default 0, i.e., all logits. Defaults to None. 
- top_p (float, optional): Controls the top-P probability to sample from. None means using C++ runtime default 0.f. Defaults to None. + top_k (int, optional): Controls number of logits to sample from. Can assume non-negative values, where 0 means 'all logits'. Defaults to None. + The value None is treated as "not specified" in the following. + If neither temperature, top_p, nor top_k are specified, sampling is greedy. + If temperature > 0 and/or top_p < 1 are specified, sampling will proceed accordingly and top_k will default to top_k = 0. + Setting top_k = 1 results in greedy sampling. + top_p (float, optional): Controls the top-P probability to sample from. Can have values between 0 and 1. Defaults to None. + The value None is treated as "not specified" in the following. + If neither temperature, top_p, nor top_k are specified, sampling is greedy. + If temperature > 0 and/or top_k > 1 are specified, sampling will proceed accordingly and top_p will default to top_p = 1. + Setting top_p = 0 should result in greedy sampling, but is currently disallowed in the backend. top_p_min (float, optional): Controls decay in the top-P algorithm. topPMin is lower-bound. None means using C++ runtime default 1.e-6. Defaults to None. top_p_reset_ids (int, optional): Controls decay in the top-P algorithm. Indicates where to reset the decay. None means using C++ runtime default 1. Defaults to None. top_p_decay (float, optional): Controls decay in the top-P algorithm. The decay value. None means using C++ runtime default 1.f. Defaults to None. seed (int, optional): Controls the random seed used by the random number generator in sampling. None means using C++ runtime default 0. Defaults to None. - temperature (float, optional): Controls the modulation of logits when sampling new tokens. It can have values > 0.f. None means using C++ runtime default 1.0f. Defaults to None. + temperature (float, optional): Controls the modulation of logits when sampling new tokens. It can have values >= 0.f. Defaults to None. + The value None is treated as "not specified" in the following. + If neither temperature, top_p, nor top_k are specified, sampling is greedy. + If top_p < 1 and/or top_k > 1 are specified, sampling will proceed accordingly and temperature will default to temperature = 1. + Setting temperature = 0 results in greedy sampling. min_tokens (int, optional): Lower bound on the number of tokens to generate. Values < 1 have no effect. None means using C++ runtime default 1. Defaults to None. beam_search_diversity_rate (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None. repetition_penalty (float, optional): Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f. Values < 1.f encourages repetition, values > 1.f discourages it. None means using C++ runtime default 1.f. Defaults to None. @@ -296,11 +308,19 @@ def _validate(self): For instance, while the greedy decoding with n > 1 is capable in the Executor class of C++ runtime, the LLM API disallows such combination. 
""" - if self.best_of < self.n: + if self.top_p is not None and (self.top_p < 0 or self.top_p > 1): + raise ValueError(f"require 0 <= top_p <= 1, got top_p={self.top_p}") + if self.top_k is not None and self.top_k < 0: + raise ValueError(f"require top_k >= 0, got top_k={self.top_k}") + if self.temperature is not None and self.temperature < 0: + raise ValueError(f"require temperature >= 0, got temperature={self.temperature}") + + if self.best_of is not None and self.best_of < self.n: raise ValueError(f"best_of ({self.best_of}) cannot be less than n ({self.n})") if ( - self.best_of > 1 + self.best_of is not None + and self.best_of > 1 and self._greedy_decoding and not os.environ.get("TLLM_ALLOW_N_GREEDY_DECODING", None) ): @@ -324,12 +344,25 @@ def _validate(self): self.logprobs = self.logprobs and int(self.logprobs) self.prompt_logprobs = self.prompt_logprobs and int(self.prompt_logprobs) + # NB: Static, because downstream code only holds instances of + # bindings.SamplingConfig (not SamplingParams). + @staticmethod + def params_imply_greedy_decoding( + *, temperature: Optional[float], top_p: Optional[float], top_k: Optional[int] + ): + return ( + (temperature is None and top_p is None and top_k is None) + or top_k == 1 + or top_p == 0.0 + or temperature == 0 + ) + @property def _greedy_decoding(self) -> bool: - return ( - not self.use_beam_search - and (self.top_k is None or self.top_k == 1) - and (self.top_p is None or self.top_p == 0.0) + return not self.use_beam_search and self.params_imply_greedy_decoding( + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, ) @property diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index d45fc865cb4..6bb7181a3d8 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -14,6 +14,7 @@ l0_a10: backend: pytorch tests: # ------------- PyTorch tests --------------- + - unittest/_torch/sampler/test_torch_sampler.py - unittest/_torch/modeling/test_modeling_mistral.py - unittest/_torch/modeling/test_modeling_pixtral.py # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py new file mode 100644 index 00000000000..b51de96ff4a --- /dev/null +++ b/tests/unittest/_torch/sampler/test_torch_sampler.py @@ -0,0 +1,303 @@ +from itertools import product +from typing import Optional, cast + +import pytest +from utils.util import force_ampere + +from tensorrt_llm._torch.pyexecutor.sampler import ( + GREEDY, + LlmRequest, + TorchSampler, + _request_strategy, +) +from tensorrt_llm.bindings import SamplingConfig +from tensorrt_llm.sampling_params import SamplingParams + + +@force_ampere +class TestStrategySelection: + VOCAB_SIZE = 1000 + TOP_K_VALS = [None, 0, 1, 42, 1000] + TOP_P_VALS = [None, 0, 0.42, 1] + TEMPERATURE_VALS = [None, 0, 1.42] + + # For non-greedy sampling, the following choices have no effect. + TOP_P_NEUTRAL_VALS = [None, 1] + TOP_K_NEUTRAL_VALS = [None, 0, VOCAB_SIZE] + TEMPERATURE_NEUTRAL_VALS = [None, 1] + + TEMPERATURE_NOT_GREEDY = [0.42] + [t for t in TEMPERATURE_NEUTRAL_VALS if t is not None] + + class MockLlmRequest: + sampling_config: SamplingConfig + + def _check_params(self, params: SamplingParams): + # cf. description of 'top_p' in doc-string of SamplingParams and + # test_top_p_0_disallowed below. 
+ if params.top_p == 0: + pytest.skip("top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig") + + def test_top_p_0_disallowed(self): + # If this xpasses, update _check_params and doc-string of SamplingParams. + params = SamplingParams(top_p=0) + pytest.xfail("top_p = 0 disallowed by tensorrt_llm::executor::SamplingConfig") + params._get_sampling_config() + + def _build_mock_llm_request(self, params: SamplingParams) -> LlmRequest: + request = self.MockLlmRequest() + request.sampling_config = SamplingConfig(params._get_sampling_config()) + return cast(LlmRequest, request) + + def test_defaults(self): + # NB: The code in _request_strategy relies on the default values below. + default_params = SamplingParams() + assert default_params.top_k is None + assert default_params.top_p is None + assert default_params.temperature is None + + def test_defaults_config(self): + # NB: The code in _request_strategy relies on the default values below. + default_config = SamplingParams()._get_sampling_config() + assert default_config.top_k is None + assert default_config.top_p is None + assert default_config.temperature is None + + def test_defaults_request(self): + # NB: The code in _request_strategy relies on the default values below. + request = self._build_mock_llm_request(SamplingParams()) + default_config = request.sampling_config + assert default_config.top_k is None + assert default_config.top_p is None + assert default_config.temperature is None + + def test_default_is_greedy(self): + request = self._build_mock_llm_request(SamplingParams()) + assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY + + @pytest.mark.parametrize( + "top_p, top_k", + [ + pytest.param(top_p, top_k) + # https://stackoverflow.com/a/75421799, does not work with nested loops + for (top_k, top_p) in product(TOP_K_VALS, TOP_P_VALS) + ], + ) + def test_temperature_0_is_greedy(self, top_p: Optional[float], top_k: Optional[int]): + params = SamplingParams(temperature=0, top_p=top_p, top_k=top_k) + self._check_params(params) + request = self._build_mock_llm_request(params) + assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY + + @pytest.mark.parametrize( + "temperature, top_k", + [ + pytest.param(temperature, top_k) + # https://stackoverflow.com/a/75421799, does not work with nested loops + for (temperature, top_k) in product(TEMPERATURE_VALS, TOP_K_VALS) + ], + ) + def test_top_p_0_is_greedy(self, temperature: Optional[float], top_k: Optional[int]): + params = SamplingParams(top_p=0, temperature=temperature, top_k=top_k) + self._check_params(params) + request = self._build_mock_llm_request(params) + assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY + + @pytest.mark.parametrize( + "temperature, top_p", + [ + pytest.param(temperature, top_p) + # https://stackoverflow.com/a/75421799, does not work with nested loops + for (temperature, top_p) in product(TEMPERATURE_VALS, TOP_P_VALS) + ], + ) + def test_top_k_1_is_greedy(self, temperature: Optional[float], top_p: Optional[float]): + params = SamplingParams(top_p=top_p, temperature=temperature, top_k=1) + self._check_params(params) + request = self._build_mock_llm_request(params) + assert _request_strategy(request, vocab_size=self.VOCAB_SIZE) is GREEDY + + @pytest.mark.parametrize( + "temperature, trivial_top_p, trivial_top_k", + [ + pytest.param(temperature, top_p, top_k) + # https://stackoverflow.com/a/75421799, does not work with nested loops + for (temperature, top_k, top_p) in product( + TEMPERATURE_NOT_GREEDY, 
TOP_K_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS + ) + ], + ) + def test_temperature_only( + self, temperature: float, trivial_top_p: Optional[float], trivial_top_k: Optional[int] + ): + params = SamplingParams(temperature=temperature, top_p=trivial_top_p, top_k=trivial_top_k) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE) + assert len(strat) == 2 + assert strat[0] == "temperature" + assert strat[1] == pytest.approx(temperature) + + @pytest.mark.parametrize( + "trivial_temperature, trivial_top_k", + [ + pytest.param(temperature, top_k) + # https://stackoverflow.com/a/75421799, does not work with nested loops + for (temperature, top_k) in product(TEMPERATURE_NEUTRAL_VALS, TOP_K_NEUTRAL_VALS) + ], + ) + def test_top_p_only(self, trivial_temperature: Optional[float], trivial_top_k: Optional[int]): + params = SamplingParams(top_p=0.42, temperature=trivial_temperature, top_k=trivial_top_k) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE) + assert len(strat) == 3 + assert strat[0] == "top_p" + assert strat[1] == pytest.approx(0.42) + assert strat[2] == pytest.approx(1.0) + + @pytest.mark.parametrize( + "trivial_top_k", + [ + pytest.param(top_k) + for top_k in TOP_K_NEUTRAL_VALS # https://stackoverflow.com/a/75421799 + ], + ) + def test_top_p_with_temperature(self, trivial_top_k: Optional[int]): + params = SamplingParams(top_p=0.42, temperature=0.9, top_k=trivial_top_k) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE) + assert len(strat) == 3 + assert strat[0] == "top_p" + assert strat[1] == pytest.approx(0.42) + assert strat[2] == pytest.approx(0.9) + + @pytest.mark.parametrize( + "trivial_temperature, trivial_top_p", + [ + pytest.param(temperature, top_p) + # https://stackoverflow.com/a/75421799, does not work with nested loops + for (temperature, top_p) in product(TEMPERATURE_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS) + ], + ) + def test_top_k_only(self, trivial_temperature: Optional[float], trivial_top_p: Optional[float]): + params = SamplingParams(top_k=42, temperature=trivial_temperature, top_p=trivial_top_p) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE) + assert len(strat) == 3 + assert strat[0] == "top_k" + assert strat[1] == 42 + assert strat[2] == pytest.approx(1.0) + + @pytest.mark.parametrize( + "trivial_top_p", + [ + pytest.param(top_p) + for top_p in TOP_P_NEUTRAL_VALS # https://stackoverflow.com/a/75421799 + ], + ) + def test_top_k_with_temperature(self, trivial_top_p: Optional[float]): + params = SamplingParams(top_k=42, temperature=0.9, top_p=trivial_top_p) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE) + assert len(strat) == 3 + assert strat[0] == "top_k" + assert strat[1] == 42 + assert strat[2] == pytest.approx(0.9) + + @pytest.mark.parametrize( + "trivial_temperature", + [ + pytest.param(temperature) + for temperature in TEMPERATURE_NEUTRAL_VALS # https://stackoverflow.com/a/75421799 + ], + ) + def test_top_k_top_p(self, trivial_temperature: Optional[float]): + params = SamplingParams(top_k=42, top_p=0.7, temperature=trivial_temperature) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = 
_request_strategy(request, vocab_size=self.VOCAB_SIZE) + assert len(strat) == 4 + assert strat[0] == "top_k_top_p" + assert strat[1] == 42 + assert strat[2] == pytest.approx(0.7) + assert strat[3] == pytest.approx(1.0) + + def test_top_k_top_p_with_temperature(self): + params = SamplingParams(top_k=42, top_p=0.7, temperature=0.9) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE) + assert len(strat) == 4 + assert strat[0] == "top_k_top_p" + assert strat[1] == 42 + assert strat[2] == pytest.approx(0.7) + assert strat[3] == pytest.approx(0.9) + + def test_param_validation(self): + with pytest.raises(ValueError, match="require temperature >= 0, got temperature=-1"): + SamplingParams(temperature=-1) + + with pytest.raises(ValueError, match="require 0 <= top_p <= 1, got top_p=-1"): + SamplingParams(top_p=-1) + + with pytest.raises(ValueError, match="require 0 <= top_p <= 1, got top_p=2"): + SamplingParams(top_p=2) + + with pytest.raises(ValueError, match="require top_k >= 0, got top_k=-1"): + SamplingParams(top_k=-1) + + @pytest.mark.parametrize( + "top_k, top_p", + [ + pytest.param(top_k, top_p) + # https://stackoverflow.com/a/75421799, does not work with nested loops + for (top_k, top_p) in product(TOP_K_NEUTRAL_VALS, TOP_P_NEUTRAL_VALS) + if (top_k, top_p) != (None, None) + ], + ) + def test_trivial_top_k_top_p_not_greedy(self, top_k: Optional[int], top_p: Optional[float]): + params = SamplingParams(top_k=top_k, top_p=top_p) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE) + assert len(strat) == 2 + assert strat[0] == "temperature" + assert strat[1] == pytest.approx(1.0) + + @pytest.fixture + def torch_sampler(self) -> TorchSampler: + return TorchSampler( + TorchSampler.Args( + max_seq_len=123, + max_draft_len=3, + max_num_sequences=12, + max_beam_width=1, + max_total_draft_tokens=3, + ) + ) + + @pytest.mark.parametrize( + "temperature, top_p, top_k", + [ + pytest.param(temperature, top_p, top_k) + # https://stackoverflow.com/a/75421799, does not work with nested loops + for (temperature, top_p, top_k) in product(TEMPERATURE_VALS, TOP_P_VALS, TOP_K_VALS) + ], + ) + def test_should_provide_draft_probs_consistency( + self, + temperature: Optional[float], + top_p: Optional[float], + top_k: Optional[int], + torch_sampler: TorchSampler, + ): + params = SamplingParams(top_k=top_k, top_p=top_p, temperature=temperature) + self._check_params(params) + request = self._build_mock_llm_request(params) + strat = _request_strategy(request, vocab_size=self.VOCAB_SIZE) + is_greedy = strat is GREEDY + + assert torch_sampler.should_provide_draft_probs(request) == (not is_greedy) diff --git a/tests/unittest/_torch/sampler/test_trtllm_sampler.py b/tests/unittest/_torch/sampler/test_trtllm_sampler.py index 37227f9b53f..dec50239c13 100644 --- a/tests/unittest/_torch/sampler/test_trtllm_sampler.py +++ b/tests/unittest/_torch/sampler/test_trtllm_sampler.py @@ -23,6 +23,7 @@ def create_llm(model_dir): enable_chunked_prefill=True, cuda_graph_config=CudaGraphConfig(), kv_cache_config=trt_kv_cache_config, + sampler_type="TRTLLMSampler", max_num_tokens= 128 # Only one request longer than max_num_tokens is required to test chunked prefill ) diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py index 359b3f33411..49fa9d3845c 100644 --- a/tests/unittest/conftest.py +++ b/tests/unittest/conftest.py @@ -18,6 +18,7 @@ import 
traceback from typing import Any +import _pytest.outcomes import pytest import torch import tqdm @@ -65,8 +66,9 @@ def pytest_pyfunc_call(pyfuncitem) -> Any: return (yield) # NB: _pytest.outcomes.OutcomeException subclasses BaseException except BaseException as e: - print(f"TEST RAISED ERROR: {e}") - traceback.print_exception(e) + if not isinstance(e, _pytest.outcomes.Skipped): + print(f"TEST RAISED ERROR: {e}") + traceback.print_exception(e) raise
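For reference, the refactored top_k_top_p_sampling_batch above scales logits by the temperature, masks everything outside the top-k, masks everything outside the top-p nucleus, and then samples from the renormalized distribution. The standalone PyTorch sketch below mirrors those steps so the intended semantics can be eyeballed; it is illustrative only (the function name is made up, and the generator plumbing and the need_top_k/need_top_p preconditions of the patched code are simplified).

import torch


def sketch_top_k_top_p_sample(logits: torch.Tensor, *, top_k: int, top_p: float,
                              temperature: float) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative re-implementation of the temperature/top-k/top-p filtering steps."""
    # logits: [batch_size, vocab_size]; assumes temperature > 0, top_k > 1, 0 < top_p <= 1
    vocab_size = logits.size(1)
    logits = logits / max(temperature, 1e-5)

    # top-k: keep only the k largest logits of each row (ties with the k-th value survive)
    if top_k < vocab_size:
        kth_values = torch.topk(logits, top_k, dim=-1).values[:, -1:]
        logits = logits.masked_fill(logits < kth_values, float("-inf"))

    # top-p: keep the shortest prefix of the descending-sorted distribution whose
    # cumulative probability reaches top_p (the token crossing the threshold is kept)
    if top_p < 1:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
        remove_sorted = cumulative > top_p
        remove_sorted[:, 1:] = remove_sorted[:, :-1].clone()
        remove_sorted[:, 0] = False
        remove = remove_sorted.scatter(1, sorted_indices, remove_sorted)
        logits = logits.masked_fill(remove, float("-inf"))

    probs = torch.softmax(logits, dim=-1)
    next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
    return next_tokens, probs


# With top_p=1 the nucleus mask is skipped, matching the delegation in
# top_k_sampling_batch; with top_k=vocab_size only the nucleus mask applies.
tokens, probs = sketch_top_k_top_p_sample(
    torch.randn(2, 16), top_k=4, top_p=0.9, temperature=0.8)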