Commit 4111724

Specdec Bench: Initial
Signed-off-by: Izzy Putterman <[email protected]>
1 parent ea1fcb0 commit 4111724

25 files changed, with 1447 additions and 0 deletions

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ modelopt/torch/utils @NVIDIA/modelopt-torch-utils-codeowners
 /examples/nemo_run @NVIDIA/modelopt-examples-megatron-codeowners
 /examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners
 /examples/pruning @NVIDIA/modelopt-torch-nas-prune-codeowners
+/examples/specdec_bench @NVIDIA/modelopt-torch-speculative-codeowners
 /examples/speculative_decoding @NVIDIA/modelopt-torch-speculative-codeowners
 /examples/vlm_ptq @NVIDIA/modelopt-examples-vlm-codeowners
 /examples/vllm_serve @NVIDIA/modelopt-examples-llm_ptq-codeowners

CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
@@ -1,6 +1,15 @@
 Model Optimizer Changelog (Linux)
 =================================
 
+
+0.40 (2025-xx-xx)
+^^^^^^^^^^^^^^^^^
+
+**New Features**
+
+- Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.
+
+
 0.39 (2025-11-07)
 ^^^^^^^^^^^^^^^^^

examples/specdec_bench/README.md

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
# Speculative Decoding (SpecDec) Bench

## Installation

This benchmark is meant to be a lightweight layer on top of an existing vLLM/SGLang/TRT-LLM installation. For example, no installation
is required when running in one of the following Docker images: `vllm/vllm-openai:v0.11.0` (vLLM), `lmsysorg/sglang:v0.5.4.post2` (SGLang), or
`nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc1` (TRT-LLM).

Next, change into the example directory:

```bash
cd examples/specdec_bench
```

## Purpose

Collect relevant metrics on acceptance rate, timing, and outputs for speculative decoding methods.
Acceptance rate here refers to the number of tokens generated on every decoding iteration. For a standard autoregressive LLM, this number
is exactly 1.
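
As a rough illustration of the metric (a sketch, not the benchmark's own implementation), the figure described above can be read as the mean number of tokens emitted per decoding iteration. The function name `mean_tokens_per_iteration` and the sample counts are made up for this example, and the sketch assumes that each iteration emits the accepted draft tokens plus one token from the target model.

```python
def mean_tokens_per_iteration(tokens_per_step: list[int]) -> float:
    """Average number of tokens emitted per decoding iteration."""
    if not tokens_per_step:
        raise ValueError("need at least one decoding iteration")
    return sum(tokens_per_step) / len(tokens_per_step)


# Plain autoregressive decoding emits exactly one token per iteration.
autoregressive = [1, 1, 1, 1]

# Hypothetical speculative run with draft length 3: under the assumption above,
# an iteration emits between 1 token (all drafts rejected) and 4 tokens
# (all drafts accepted).
speculative = [4, 2, 1, 3]

print(mean_tokens_per_iteration(autoregressive))  # 1.0
print(mean_tokens_per_iteration(speculative))  # 2.5
```

A value well above 1 suggests the draft model is predicting the target model's tokens often enough to save iterations.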

## Getting Started

A basic example run script is provided which benchmarks MTBench (a standard 160 prompts spanning 8 categories).
MTBench is available [here](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts).

### Running MTBench on GPT OSS + Eagle3

Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.

```bash
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --mtbench question.jsonl --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 80 --engine TRTLLM --concurrency 1
```

### Running Random ids on GPT OSS + Eagle3

Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.

```bash
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --random_isl 1024 --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 40 --engine TRTLLM --concurrency 1
```
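
`run.py` also accepts an optional `--runtime_params` YAML file, from which it reads three top-level keys: `dataset_kwargs`, `engine_args`, and `sampling_kwargs` (the last defaulting to `{"temperature": 0}`). The sketch below mirrors that lookup on an already-parsed dictionary. The helper name `resolve_runtime_params` and the `top_p` value are illustrative assumptions and not part of the benchmark.

```python
from __future__ import annotations


def resolve_runtime_params(runtime_params: dict | None) -> dict:
    """Apply the same per-key defaults that run.py uses for runtime params."""
    params = runtime_params or {}
    return {
        "dataset_kwargs": params.get("dataset_kwargs", {}),
        "engine_args": params.get("engine_args", {}),
        # Greedy decoding unless the file provides its own sampling settings.
        "sampling_kwargs": params.get("sampling_kwargs", {"temperature": 0}),
    }


# No file given: every section falls back to its default.
defaults = resolve_runtime_params(None)
print(defaults["sampling_kwargs"])

# Hypothetical file contents after YAML parsing.
custom = resolve_runtime_params({"sampling_kwargs": {"temperature": 0.7, "top_p": 0.9}})
print(custom["sampling_kwargs"])
```

Note that a `sampling_kwargs` section replaces the default wholesale rather than being merged with it, so a file that sets only `top_p` would no longer carry `temperature: 0`.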

examples/specdec_bench/run.py

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import asyncio

import yaml
from specdec_bench import datasets, metrics, models, runners
from specdec_bench.utils import decode_chat, encode_chat, get_tokenizer, postprocess_base

engines_available = {
    "TRTLLM": models.TRTLLMPYTModel,
    "VLLM": models.VLLMModel,
    "SGLANG": models.SGLANGModel,
}


async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concurrency=10):
    """
    Async version of run_loop with concurrency control using a semaphore.

    Args:
        runner: The model runner instance
        dataset: The dataset containing requests
        tokenizer: The tokenizer instance
        output_length: Maximum output length
        postprocess: Callable applied to each decoded output text
        concurrency: Maximum number of concurrent requests (default: 10)
    """
    semaphore = asyncio.Semaphore(concurrency)
    max_length = output_length
    end_id = tokenizer.eos_token_id

    async def process_single_request(request, i):
        """Process a single request with all its conversation turns."""
        async with semaphore:
            messages = []
            if request.system_prompt is not None:
                messages.append({"role": "system", "content": request.system_prompt})

            for question in request.turns:
                messages.append({"role": "user", "content": question})
                entry_encoded = encode_chat(tokenizer, messages)

                # Run the async runner.run directly
                output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
                output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
                output_text = postprocess(output_text)
                messages.append({"role": "assistant", "content": output_text})

            return messages

    tasks = [process_single_request(request, i) for i, request in enumerate(dataset.data)]
    text_outputs = await asyncio.gather(*tasks, return_exceptions=True)

    # Check for any exceptions and handle them
    for i, result in enumerate(text_outputs):
        if isinstance(result, Exception):
            print(f"Error processing request {i}: {result}")
            raise result

    runner.process_metrics_final(text_outputs)
    return text_outputs


def run_simple(args):
    tokenizer = get_tokenizer(args.tokenizer)
    dataset_kwargs = args.runtime_params.get("dataset_kwargs", {})
    # --mtbench takes precedence when both dataset options are given.
    if args.mtbench is not None:
        dataset = datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs)
    elif args.random_isl is not None:
        dataset = datasets.RandomToken(
            tokenizer, args.random_isl, args.num_requests, **dataset_kwargs
        )
    engine_args = args.runtime_params.get("engine_args", {})
    sampling_kwargs = args.runtime_params.get("sampling_kwargs", {"temperature": 0})
    model_class = engines_available[args.engine]
    model = model_class(
        args.model_dir,
        max_concurrent_requests=args.concurrency,
        sampling_kwargs=sampling_kwargs,
        speculative_algorithm=args.speculative_algorithm,
        draft_model_dir=args.draft_model_dir,
        speculative_num_steps=args.draft_length,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.ep_size,
        **engine_args,
    )

    metrics_list = [metrics.Timing(), metrics.AATiming(tokenizer)]
    if args.mtbench is not None:
        metrics_list.insert(0, metrics.MTBench())
    else:
        metrics_list.insert(0, metrics.AcceptanceRate())
    runner = runners.SimpleRunner(model, metrics=metrics_list)

    postprocess = postprocess_base

    asyncio.run(
        run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency)
    )

    runner.clear_metrics()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer", type=str, required=True, help="Path to the tokenizer directory"
    )
    parser.add_argument(
        "--mtbench", type=str, required=False, default=None, help="Path to the mtbench dataset"
    )
    parser.add_argument(
        "--random_isl",
        type=int,
        required=False,
        default=None,
        help="How many tokens random input should be.",
    )
    parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to run")
    parser.add_argument(
        "--engine",
        type=str,
        required=False,
        default="TRTLLM",
        choices=list(engines_available.keys()),
        help="Engine to use",
    )
    parser.add_argument(
        "--speculative_algorithm",
        type=str,
        required=False,
        default="EAGLE3",
        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"],
        help="Speculative algorithm to use",
    )
    parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory")
    parser.add_argument(
        "--draft_model_dir",
        type=str,
        required=False,
        default=None,
        help="Path to the draft model directory",
    )
    parser.add_argument(
        "--runtime_params",
        type=str,
        required=False,
        default=None,
        help="Path to the runtime params yaml file",
    )
    parser.add_argument(
        "--output_length", type=int, required=False, default=4096, help="Output length"
    )
    parser.add_argument("--draft_length", type=int, required=False, default=3, help="Draft length")
    parser.add_argument(
        "--tp_size", type=int, required=False, default=4, help="Tensor parallel size"
    )
    parser.add_argument(
        "--ep_size", type=int, required=False, default=2, help="Expert parallel size"
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        required=False,
        default=1,
        help="Maximum number of concurrent requests",
    )
    args = parser.parse_args()

    # Replace the YAML path with its parsed contents (or an empty dict).
    if args.runtime_params is not None:
        with open(args.runtime_params) as f:
            args.runtime_params = yaml.safe_load(f)
    else:
        args.runtime_params = {}

    assert args.mtbench is not None or args.random_isl is not None, (
        "Either mtbench or random_isl must be provided"
    )

    run_simple(args)
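
The concurrency pattern in `run_loop` (a semaphore bounding in-flight requests, `asyncio.gather` with `return_exceptions=True`, then re-raising the first failure) can be tried in isolation. The following is a minimal sketch under stated assumptions: `run_bounded` and `fake_worker` are hypothetical names, and the worker sleeps instead of calling an inference engine.

```python
import asyncio


async def run_bounded(items, worker, concurrency=2):
    """Run worker(item, i) for every item with at most `concurrency` in flight."""
    semaphore = asyncio.Semaphore(concurrency)

    async def guarded(item, i):
        # Only `concurrency` coroutines can hold the semaphore at once.
        async with semaphore:
            return await worker(item, i)

    results = await asyncio.gather(
        *(guarded(item, i) for i, item in enumerate(items)),
        return_exceptions=True,
    )
    # With return_exceptions=True, failures come back as values; surface the first.
    for i, result in enumerate(results):
        if isinstance(result, Exception):
            raise RuntimeError(f"request {i} failed") from result
    return results


async def fake_worker(prompt, i):
    # Hypothetical stand-in for an engine call.
    await asyncio.sleep(0.01)
    return f"{i}:{prompt.upper()}"


outputs = asyncio.run(run_bounded(["a", "b", "c"], fake_worker, concurrency=2))
print(outputs)  # ['0:A', '1:B', '2:C']
```

Because `asyncio.gather` returns results in the order the awaitables were passed, the outputs line up with the input requests even when the requests finish out of order.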
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import Dataset
from .base_hf import OpenMathInstructv2, OpenOrca, UltraChat
from .mtbench import MTBench
from .random_token import RandomToken
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Request:
    system_prompt: str | None = None
    turns: list[str] = field(default_factory=list)
    mm_content: Any | None = None  # TODO

    # not to be set by user
    # (unannotated, so this is a plain class attribute rather than a dataclass field)
    output_turn_ids = None
    output_turn_text: list[str] = field(default_factory=list)


class Dataset:
    # Base class: subclasses are expected to override both methods.
    def __init__(self, path, **kwargs):
        self.data: list[Request] = []
        raise NotImplementedError

    def _preprocess(self):
        raise NotImplementedError
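
To show how the `Dataset` base class above might be extended, here is a minimal sketch of a subclass that builds `Request` objects from in-memory prompts. `InlinePrompts` and its parameters are hypothetical and not part of this commit; the `Request` and `Dataset` definitions are trimmed copies included only so the example runs on its own.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Request:
    # Trimmed copy of the Request dataclass, kept to the fields used here.
    system_prompt: str | None = None
    turns: list[str] = field(default_factory=list)


class Dataset:
    def __init__(self, path, **kwargs):
        self.data: list[Request] = []
        raise NotImplementedError

    def _preprocess(self):
        raise NotImplementedError


class InlinePrompts(Dataset):
    """Hypothetical dataset built from a list of prompt strings."""

    def __init__(self, prompts, num_samples=None, system_prompt=None):
        # Do not call super().__init__(): the base constructor always raises.
        self.data = []
        self.prompts = list(prompts)[:num_samples]
        self.system_prompt = system_prompt
        self._preprocess()

    def _preprocess(self):
        # One single-turn request per prompt.
        self.data = [
            Request(system_prompt=self.system_prompt, turns=[prompt])
            for prompt in self.prompts
        ]


dataset = InlinePrompts(["Hello", "Explain speculative decoding."], num_samples=1)
print(dataset.data)
```

Since `run_loop` only reads `dataset.data` and each request's `system_prompt` and `turns`, any subclass that fills `data` this way should be usable with it.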
