Skip to content

Commit f444fe2

Browse files
authored
[None][test] fix a typo in perf test sampler config (#8726)
Signed-off-by: Ruodi Lu <[email protected]> Signed-off-by: ruodil <[email protected]> Co-authored-by: Ruodi Lu <[email protected]>
1 parent b828b64 commit f444fe2

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# -*- coding: utf-8 -*-
16+
"""
17+
Sampler options config for trtllm-bench perf tests
18+
"""
19+
20+
21+
def get_sampler_options_config(model_label: str) -> dict:
22+
"""
23+
Return the sampler options config corresponding to the model label.
24+
Args:
25+
model_label: model label from self._config.to_string()
26+
Returns:
27+
dict: sampler options config
28+
"""
29+
base_config = {
30+
'top_k': 4,
31+
'top_p': 0.5,
32+
'temperature': 0.5,
33+
}
34+
return base_config

tests/integration/defs/perf/test_perf.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from ..conftest import get_llm_root, llm_models_root, trt_environment
3131
from .pytorch_model_config import get_model_yaml_config
32+
from .sampler_options_config import get_sampler_options_config
3233
from .utils import (AbstractPerfScriptTestClass, PerfBenchScriptTestCmds,
3334
PerfDisaggScriptTestCmds, PerfMetricType,
3435
PerfServerClientBenchmarkCmds, generate_test_nodes)
@@ -1684,6 +1685,15 @@ def get_trtllm_bench_command(self, engine_dir):
16841685
benchmark_cmd += [
16851686
f"--extra_llm_api_options={autodeploy_config_path}"
16861687
]
1688+
# for sampler options
1689+
sampler_options_path = os.path.join(engine_dir, "sampler_options.yml")
1690+
if not os.path.exists(sampler_options_path):
1691+
os.makedirs(os.path.dirname(sampler_options_path), exist_ok=True)
1692+
sampler_config = get_sampler_options_config(self._config.to_string())
1693+
print_info(f"sampler options config: {sampler_config}")
1694+
with open(sampler_options_path, 'w') as f:
1695+
yaml.dump(sampler_config, f, default_flow_style=False)
1696+
benchmark_cmd += [f"--sampler_options={sampler_options_path}"]
16871697
return benchmark_cmd
16881698

16891699
def get_commands(self):

0 commit comments

Comments
 (0)