Skip to content

Commit 87269e8

Browse files
committed
add default sampler config for perf test
Signed-off-by: Ruodi Lu <[email protected]>
1 parent e051a05 commit 87269e8

File tree

2 files changed

+44
-10
lines changed

2 files changed

+44
-10
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# -*- coding: utf-8 -*-
16+
"""
17+
Sampler options config for trtllm-bench perf tests
18+
"""
19+
20+
21+
def get_sampler_options_config(model_label: str) -> dict:
22+
"""
23+
Return the sampler options config corresponding to the model label.
24+
Args:
25+
model_label: model label from self._config.to_string()
26+
Returns:
27+
dict: sampler options config
28+
"""
29+
base_config = {
30+
'top_k': 4,
31+
'top_p': 0.5,
32+
'temperature': 0.5,
33+
}
34+
return base_config

tests/integration/defs/perf/test_perf.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from ..conftest import get_llm_root, llm_models_root, trt_environment
3131
from .pytorch_model_config import get_model_yaml_config
32-
from .sample_options_config import get_sample_options_config
32+
from .sampler_options_config import get_sampler_options_config
3333
from .utils import (AbstractPerfScriptTestClass, PerfBenchScriptTestCmds,
3434
PerfDisaggScriptTestCmds, PerfMetricType,
3535
PerfServerClientBenchmarkCmds, generate_test_nodes)
@@ -1683,15 +1683,15 @@ def get_trtllm_bench_command(self, engine_dir):
16831683
benchmark_cmd += [
16841684
f"--extra_llm_api_options={autodeploy_config_path}"
16851685
]
1686-
# for sample options
1687-
sample_options_path = os.path.join(engine_dir, "sample_options.yml")
1688-
if not os.path.exists(sample_options_path):
1689-
os.makedirs(os.path.dirname(sample_options_path), exist_ok=True)
1690-
sample_config = get_sample_options_config(self._config.to_string())
1691-
print_info(f"sample options config: {sample_config}")
1692-
with open(sample_options_path, 'w') as f:
1693-
yaml.dump(sample_config, f, default_flow_style=False)
1694-
benchmark_cmd += [f"--sample_options={sample_options_path}"]
1686+
# for sampler options
1687+
sampler_options_path = os.path.join(engine_dir, "sampler_options.yml")
1688+
if not os.path.exists(sampler_options_path):
1689+
os.makedirs(os.path.dirname(sampler_options_path), exist_ok=True)
1690+
sampler_config = get_sampler_options_config(self._config.to_string())
1691+
print_info(f"sampler options config: {sampler_config}")
1692+
with open(sampler_options_path, 'w') as f:
1693+
yaml.dump(sampler_config, f, default_flow_style=False)
1694+
benchmark_cmd += [f"--sampler_options={sampler_options_path}"]
16951695
return benchmark_cmd
16961696

16971697
def get_commands(self):

0 commit comments

Comments
 (0)