
Commit f4f2f85

Replace all args to generate_benchmark_report with a scenario
1 parent 8563e13 commit f4f2f85
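
After this commit, generate_benchmark_report no longer takes a dozen individual keyword arguments: it takes a single Scenario object, and the CLI resolves --scenario (a registered name or a JSON config file) into that object before delegating. A minimal usage sketch of the new call shape, assuming Scenario.update accepts any subset of the fields the diff passes through (backend, model, data, data_type, tokenizer, rate_type, rate, max_seconds, max_requests); the target URL and emulated-data config below are placeholders:

    # Sketch only: Scenario's defaults and validation live in
    # guidellm.benchmark.scenario, which this commit does not show.
    from guidellm.benchmark.scenario import Scenario
    from guidellm.main import generate_benchmark_report

    scenario = Scenario()  # start from the built-in defaults
    scenario.update(       # override selected fields, as the CLI wrapper does
        backend="openai_http",
        data_type="emulated",
        data="prompt_tokens=512,generated_tokens=128",  # hypothetical emulated config
        rate_type="sweep",
        max_seconds=120,
    )

    report = generate_benchmark_report(
        target="http://localhost:8000",  # placeholder OpenAI-compatible server
        scenario=scenario,
        output_path="benchmark_report.json",
    )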

File tree

1 file changed (+55, -49 lines)

src/guidellm/main.py

@@ -6,6 +6,7 @@
 from transformers import AutoTokenizer  # type: ignore[import-untyped]
 
 from guidellm.backend import Backend, BackendType
+from guidellm.benchmark.scenario import Scenario, ScenarioManager
 from guidellm.core import GuidanceReport, TextGenerationBenchmarkReport
 from guidellm.executor import Executor, ProfileGenerationMode
 from guidellm.request import (
@@ -19,7 +20,7 @@
 __all__ = ["generate_benchmark_report"]
 
 # FIXME: Remove
-SCENARIOS = Literal["rag", "short"]
+SCENARIOS = ScenarioManager()
 
 @click.command()
 @click.option(
@@ -33,7 +34,7 @@
 )
 @click.option(
     "--scenario",
-    type=cli_params.Union(click.File(mode='r'), click.Choice(get_args(SCENARIOS))),
+    type=cli_params.Union(click.File(mode='r'), click.Choice(SCENARIOS.list())),
     default=None,
     help=(
         "TODO: A scenario or path to config"
@@ -42,7 +43,7 @@
 @click.option(
     "--backend",
     type=click.Choice(get_args(BackendType)),
-    default="openai_http",
+    default=None,
     help=(
         "The backend to use for benchmarking. "
         "The default is OpenAI Server enabling compatability with any server that "
@@ -61,7 +62,7 @@
 @click.option(
     "--data",
     type=str,
-    required=True,
+    default=None,
     help=(
         "The data source to use for benchmarking. "
         "Depending on the data-type, it should be a "
@@ -74,7 +75,7 @@
 @click.option(
     "--data-type",
     type=click.Choice(["emulated", "file", "transformers"]),
-    required=True,
+    default=None,
     help=(
         "The type of data to use for benchmarking. "
         "Use 'emulated' for synthetic data, 'file' for a file, or 'transformers' "
@@ -96,7 +97,7 @@
 @click.option(
     "--rate-type",
     type=click.Choice(get_args(ProfileGenerationMode)),
-    default="sweep",
+    default=None,
     help=(
         "The type of request rate to use for benchmarking. "
         "Use sweep to run a full range from synchronous to throughput (default), "
@@ -119,7 +120,7 @@
 @click.option(
     "--max-seconds",
     type=int,
-    default=120,
+    default=None,
     help=(
         "The maximum number of seconds for each benchmark run. "
         "Either max-seconds, max-requests, or both must be set. "
@@ -164,25 +165,35 @@
 )
 def generate_benchmark_report_cli(
     target: str,
-    scenario: Optional[Union[IO[Any], SCENARIOS]],
-    backend: BackendType,
+    scenario: Optional[Union[IO[Any], str]],
+    backend: Optional[BackendType],
     model: Optional[str],
     data: Optional[str],
-    data_type: Literal["emulated", "file", "transformers"],
+    data_type: Optional[Literal["emulated", "file", "transformers"]],
     tokenizer: Optional[str],
-    rate_type: ProfileGenerationMode,
+    rate_type: Optional[ProfileGenerationMode],
     rate: Optional[float],
     max_seconds: Optional[int],
     max_requests: Union[Literal["dataset"], int, None],
-    output_path: str,
+    output_path: Optional[str],
     enable_continuous_refresh: bool,
 ):
     """
     Generate a benchmark report for a specified backend and dataset.
     """
-    generate_benchmark_report(
-        target=target,
-        scenario=scenario,
+
+    if isinstance(scenario, str):
+        defaults = SCENARIOS[scenario]
+    elif isinstance(scenario, IO):
+        defaults = Scenario.from_json(scenario.read())
+        SCENARIOS["custom"] = defaults
+    elif scenario is None:
+        defaults = Scenario()
+    else:
+        raise ValueError("Invalid scenario type")
+
+    # Update defaults with CLI args
+    defaults.update(
         backend=backend,
         model=model,
         data=data,
@@ -191,25 +202,20 @@ def generate_benchmark_report_cli(
         rate_type=rate_type,
         rate=rate,
         max_seconds=max_seconds,
-        max_requests=max_requests,
+        max_requests=max_requests
+    )
+
+    generate_benchmark_report(
+        target=target,
+        scenario=defaults,
         output_path=output_path,
         cont_refresh_table=enable_continuous_refresh,
     )
 
 
 def generate_benchmark_report(
     target: str,
-    data: Optional[str],
-    data_type: Literal["emulated", "file", "transformers"],
-    scenario: Optional[Union[IO[Any], SCENARIOS]],
-    backend: BackendType = "openai_http",
-    backend_kwargs: Optional[Mapping[str, Any]] = None,
-    model: Optional[str] = None,
-    tokenizer: Optional[str] = None,
-    rate_type: ProfileGenerationMode = "sweep",
-    rate: Optional[float] = None,
-    max_seconds: Optional[int] = 120,
-    max_requests: Union[Literal["dataset"], int, None] = None,
+    scenario: Scenario,
     output_path: Optional[str] = None,
     cont_refresh_table: bool = False,
 ) -> GuidanceReport:
@@ -236,22 +242,22 @@ def generate_benchmark_report(
     :param backend_kwargs: Additional keyword arguments for the backend.
     """
     logger.info(
-        "Generating benchmark report with target: {}, backend: {}", target, backend
+        "Generating benchmark report with target: {}, backend: {}", target, scenario.backend
     )
 
     # Create backend
     backend_inst = Backend.create(
-        type_=backend,
+        type_=scenario.backend,
         target=target,
-        model=model,
-        **(backend_kwargs or {}),
+        model=scenario.model,
+        **(scenario.backend_kwargs or {}),
     )
     backend_inst.validate()
 
     request_generator: RequestGenerator
 
     # Create tokenizer and request generator
-    tokenizer_inst = tokenizer
+    tokenizer_inst = scenario.tokenizer
     if not tokenizer_inst:
         try:
             tokenizer_inst = AutoTokenizer.from_pretrained(backend_inst.model)
@@ -261,44 +267,44 @@ def generate_benchmark_report(
                 "--tokenizer must be provided for request generation"
             ) from err
 
-    if data_type == "emulated":
+    if scenario.data_type == "emulated":
         request_generator = EmulatedRequestGenerator(
-            config=data, tokenizer=tokenizer_inst
+            config=scenario.data, tokenizer=tokenizer_inst
         )
-    elif data_type == "file":
-        request_generator = FileRequestGenerator(path=data, tokenizer=tokenizer_inst)
-    elif data_type == "transformers":
+    elif scenario.data_type == "file":
+        request_generator = FileRequestGenerator(path=scenario.data, tokenizer=tokenizer_inst)
+    elif scenario.data_type == "transformers":
         request_generator = TransformersDatasetRequestGenerator(
-            dataset=data, tokenizer=tokenizer_inst
+            dataset=scenario.data, tokenizer=tokenizer_inst
         )
     else:
-        raise ValueError(f"Unknown data type: {data_type}")
+        raise ValueError(f"Unknown data type: {scenario.data_type}")
 
-    if data_type == "emulated" and max_requests == "dataset":
+    if scenario.data_type == "emulated" and scenario.max_requests == "dataset":
         raise ValueError("Cannot use 'dataset' for emulated data")
 
     # Create executor
     executor = Executor(
         backend=backend_inst,
         request_generator=request_generator,
-        mode=rate_type,
-        rate=rate if rate_type in ("constant", "poisson") else None,
+        mode=scenario.rate_type,
+        rate=scenario.rate if scenario.rate_type in ("constant", "poisson") else None,
         max_number=(
-            len(request_generator) if max_requests == "dataset" else max_requests
+            len(request_generator) if scenario.max_requests == "dataset" else scenario.max_requests
        ),
-        max_duration=max_seconds,
+        max_duration=scenario.max_seconds,
     )
 
     # Run executor
     logger.debug(
         "Running executor with args: {}",
         {
-            "backend": backend,
+            "backend": scenario.backend,
             "request_generator": request_generator,
-            "mode": rate_type,
-            "rate": rate,
-            "max_number": max_requests,
-            "max_duration": max_seconds,
+            "mode": scenario.rate_type,
+            "rate": scenario.rate,
+            "max_number": scenario.max_requests,
+            "max_duration": scenario.max_seconds,
         },
     )
     report = asyncio.run(_run_executor_for_result(executor))
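
The commit shows only how Scenario and ScenarioManager are used, not their definitions. A minimal sketch consistent with the operations the diff relies on (SCENARIOS[name], SCENARIOS["custom"] = ..., SCENARIOS.list(), Scenario.from_json(...), Scenario().update(...)); the real classes in guidellm.benchmark.scenario may differ:

    # Hypothetical sketch of the surface area the diff depends on.
    import json
    from typing import Dict, List


    class Scenario:
        """Bundle of benchmark settings; field names mirror the diff."""

        _FIELDS = (
            "backend", "model", "data", "data_type", "tokenizer",
            "rate_type", "rate", "max_seconds", "max_requests", "backend_kwargs",
        )

        def __init__(self, **fields) -> None:
            for name in self._FIELDS:
                setattr(self, name, fields.get(name))

        @classmethod
        def from_json(cls, raw: str) -> "Scenario":
            return cls(**json.loads(raw))

        def update(self, **overrides) -> None:
            # Only non-None CLI values override the scenario's defaults.
            for key, value in overrides.items():
                if value is not None:
                    setattr(self, key, value)


    class ScenarioManager:
        """Registry of named scenarios supporting [], []=, and list()."""

        def __init__(self) -> None:
            self._scenarios: Dict[str, Scenario] = {}

        def __getitem__(self, name: str) -> Scenario:
            return self._scenarios[name]

        def __setitem__(self, name: str, scenario: Scenario) -> None:
            self._scenarios[name] = scenario

        def list(self) -> List[str]:
            # Names exposed through click.Choice for --scenario.
            return list(self._scenarios.keys())

One caveat worth flagging: isinstance(scenario, IO) is unlikely to match at runtime, since the file object produced by click.File does not subclass typing.IO; checking against io.IOBase (or simply hasattr(scenario, "read")) would be more reliable.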
