Skip to content

Commit 7cbcc0c

Browse files
committed
Specdec Bench: Initial
Signed-off-by: Izzy Putterman <[email protected]>
1 parent ea1fcb0 commit 7cbcc0c

File tree

23 files changed

+1425
-0
lines changed

23 files changed

+1425
-0
lines changed

examples/specdec_bench/README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Speculative Decoding (SpecDec) Bench
2+
3+
## Purpose
4+
5+
Collect relevant metrics on acceptance rate, timing, and outputs for Speculative Decoding methods.
6+
Acceptance rate refers to the number of tokens generated on every iteration. For a standard Autoregressive LLM, this number
7+
is just 1.
8+
9+
## Getting Started
10+
11+
A basic example run script is provided that benchmarks MTBench (a standard set of 160 prompts spanning 8 categories).
12+
The MTBench prompts are available on the Hugging Face Hub as [HuggingFaceH4/mt_bench_prompts](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts).
13+
14+
### Running MTBench on GPT OSS + Eagle3
15+
16+
Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.
17+
18+
```bash
19+
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --mtbench question.jsonl --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 80 --engine TRTLLM --concurrency 1
20+
21+
```
22+
23+
### Running Random ids on GPT OSS + Eagle3
24+
25+
Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.
26+
27+
```bash
28+
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --random_isl 1024 --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 40 --engine TRTLLM --concurrency 1
29+
30+
```

examples/specdec_bench/run.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import argparse
17+
import asyncio
18+
19+
import yaml
20+
from specdec_bench import datasets, metrics, models, runners
21+
from specdec_bench.utils import decode_chat, encode_chat, get_tokenizer, postprocess_base
22+
23+
# Lookup table from the engine name given on the command line (--engine) to the
# model wrapper class that run_simple instantiates for it.
engines_available = {
    "TRTLLM": models.TRTLLMPYTModel,
    "VLLM": models.VLLMModel,
    "SGLANG": models.SGLANGModel,
}
28+
29+
30+
async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concurrency=10):
    """
    Run every request in ``dataset`` through ``runner`` with bounded concurrency.

    Each request is handled as a multi-turn conversation: the user turns are
    sent one at a time, and the post-processed model reply is appended to the
    chat history before the next turn is encoded.

    Args:
        runner: The model runner instance. It must provide an awaitable
            ``run(prompt, max_length, end_id, request_index)`` and a
            ``process_metrics_final(outputs)`` method.
        dataset: The dataset containing requests (read from ``dataset.data``)
        tokenizer: The tokenizer instance
        output_length: Maximum output length
        postprocess: Callable applied to each decoded reply before it is stored
        concurrency: Maximum number of concurrent requests (default: 10)

    Returns:
        A list with one chat-message list per request, in dataset order.

    Raises:
        ValueError: If ``concurrency`` is smaller than 1.
        BaseException: The first failure (in dataset order) captured while
            processing a request is re-raised once all requests have settled.
    """
    if concurrency < 1:
        # Semaphore(0) would leave every request waiting forever.
        raise ValueError(f"concurrency must be >= 1, got {concurrency}")

    semaphore = asyncio.Semaphore(concurrency)
    max_length = output_length
    end_id = tokenizer.eos_token_id

    async def process_single_request(request, i):
        """Process a single request with all its conversation turns."""
        async with semaphore:
            messages = []
            if request.system_prompt is not None:
                messages.append({"role": "system", "content": request.system_prompt})

            for question in request.turns:
                messages.append({"role": "user", "content": question})
                entry_encoded = encode_chat(tokenizer, messages)

                # Run the async runner.run directly
                output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
                output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
                output_text = postprocess(output_text)
                messages.append({"role": "assistant", "content": output_text})

            return messages

    tasks = [process_single_request(request, i) for i, request in enumerate(dataset.data)]
    text_outputs = await asyncio.gather(*tasks, return_exceptions=True)

    # Check for any exceptions and handle them. BaseException (not Exception) is
    # used so that a cancelled request, which surfaces as asyncio.CancelledError,
    # is reported instead of being passed on to the metrics as if it were output.
    for i, result in enumerate(text_outputs):
        if isinstance(result, BaseException):
            print(f"Error processing request {i}: {result}")
            raise result

    runner.process_metrics_final(text_outputs)
    return text_outputs
75+
76+
77+
def run_simple(args):
    """
    Build the dataset, model, metrics and runner from parsed CLI arguments and run the benchmark.

    Args:
        args: Namespace of parsed command-line options. Reads ``tokenizer``,
            ``mtbench``, ``random_isl``, ``num_requests``, ``engine``,
            ``model_dir``, ``concurrency``, ``speculative_algorithm``,
            ``draft_model_dir``, ``draft_length``, ``tp_size``, ``ep_size``,
            ``output_length`` and ``runtime_params`` (a dict that may carry
            ``dataset_kwargs``, ``engine_args`` and ``sampling_kwargs``).

    Raises:
        ValueError: If neither ``args.mtbench`` nor ``args.random_isl`` is set.
    """
    # Fail fast, before the tokenizer or model is loaded; without this check
    # ``dataset`` would be unbound further down when no data source is given.
    use_mtbench = args.mtbench is not None
    if not use_mtbench and args.random_isl is None:
        raise ValueError("Either mtbench or random_isl must be provided")

    # runtime_params can be None (e.g. the result of loading an empty YAML file).
    runtime_params = args.runtime_params or {}

    tokenizer = get_tokenizer(args.tokenizer)
    dataset_kwargs = runtime_params.get("dataset_kwargs", {})
    if use_mtbench:
        dataset = datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs)
    else:
        dataset = datasets.RandomToken(
            tokenizer, args.random_isl, args.num_requests, **dataset_kwargs
        )
    engine_args = runtime_params.get("engine_args", {})
    sampling_kwargs = runtime_params.get("sampling_kwargs", {"temperature": 0})
    model_class = engines_available[args.engine]
    model = model_class(
        args.model_dir,
        max_concurrent_requests=args.concurrency,
        sampling_kwargs=sampling_kwargs,
        speculative_algorithm=args.speculative_algorithm,
        draft_model_dir=args.draft_model_dir,
        speculative_num_steps=args.draft_length,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.ep_size,
        **engine_args,
    )

    # The first metric depends on the data source; the timing metrics always follow it.
    metrics_list = [metrics.Timing(), metrics.AATiming(tokenizer)]
    if use_mtbench:
        metrics_list.insert(0, metrics.MTBench())
    else:
        metrics_list.insert(0, metrics.AcceptanceRate())
    runner = runners.SimpleRunner(model, metrics=metrics_list)

    postprocess = postprocess_base

    asyncio.run(
        run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency)
    )

    runner.clear_metrics()
115+
116+
117+
# Command-line entry point: parse the options, load the optional runtime
# parameter file, then run the benchmark.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer", type=str, required=True, help="Path to the tokenizer directory"
    )
    parser.add_argument(
        "--mtbench", type=str, required=False, default=None, help="Path to the mtbench dataset"
    )
    parser.add_argument(
        "--random_isl",
        type=int,
        required=False,
        default=None,
        help="How many tokens random input should be.",
    )
    parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to run")
    parser.add_argument(
        "--engine",
        type=str,
        required=False,
        default="TRTLLM",
        choices=list(engines_available.keys()),
        help="Engine to use",
    )
    parser.add_argument(
        "--speculative_algorithm",
        type=str,
        required=False,
        default="EAGLE3",
        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"],
        help="Speculative algorithm to use",
    )
    parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory")
    parser.add_argument(
        "--draft_model_dir",
        type=str,
        required=False,
        default=None,
        help="Path to the draft model directory",
    )
    parser.add_argument(
        "--runtime_params",
        type=str,
        required=False,
        default=None,
        help="Path to the runtime params yaml file",
    )
    parser.add_argument(
        "--output_length", type=int, required=False, default=4096, help="Output length"
    )
    parser.add_argument("--draft_length", type=int, required=False, default=3, help="Draft length")
    parser.add_argument(
        "--tp_size", type=int, required=False, default=4, help="Tensor parallel size"
    )
    parser.add_argument(
        "--ep_size", type=int, required=False, default=2, help="Expert parallel size"
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        required=False,
        default=1,
        help="Maximum number of concurrent requests",
    )
    args = parser.parse_args()

    # Report a missing data source as a regular usage error. An assert would be
    # stripped when Python runs with -O, letting the run continue without a dataset.
    if args.mtbench is None and args.random_isl is None:
        parser.error("Either mtbench or random_isl must be provided")

    if args.runtime_params is not None:
        with open(args.runtime_params) as f:
            # safe_load returns None for an empty document; fall back to an
            # empty dict so the later .get() lookups keep working.
            args.runtime_params = yaml.safe_load(f) or {}
    else:
        args.runtime_params = {}

    run_simple(args)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from .base import Dataset
17+
from .base_hf import OpenMathInstructv2, OpenOrca, UltraChat
18+
from .mtbench import MTBench
19+
from .random_token import RandomToken
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from dataclasses import dataclass, field
17+
from typing import Any
18+
19+
20+
@dataclass
21+
class Request:
22+
system_prompt: str | None = None
23+
turns: list[str] = field(default_factory=list)
24+
mm_content: Any | None = None # TODO
25+
26+
# not to be set by user
27+
output_turn_ids = None
28+
output_turn_text: list[str] = field(default_factory=list)
29+
30+
31+
class Dataset:
    """Abstract dataset interface; concrete datasets populate ``self.data`` with requests."""

    def __init__(self, path, **kwargs):
        """Declare the ``data`` attribute, then raise ``NotImplementedError``.

        Args:
            path: Location of the dataset (unused here).
            **kwargs: Extra options (unused here).
        """
        self.data: list[Request] = []
        raise NotImplementedError

    def _preprocess(self):
        """Preprocessing hook; raises ``NotImplementedError`` in this base class."""
        raise NotImplementedError
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
17+
try:
    from datasets import load_dataset
except ImportError:
    print("datasets is not installed.")
    datasets = None
    # Also bind the name that is actually imported above; otherwise a missing
    # package leaves ``load_dataset`` undefined and later use is a NameError.
    load_dataset = None
22+
23+
24+
from .base import Dataset, Request
25+
26+
27+
class BaseHF(Dataset):
    """Base class for datasets whose records come from ``_load_dataset``.

    Subclasses implement ``_load_dataset`` (return an iterable of records) and
    ``_single_line_process`` (turn one record into a ``Request``).
    """

    def __init__(self, num_samples=100, **kwargs):
        """Load and convert up to ``num_samples`` records into ``self.data``.

        Args:
            num_samples: Maximum number of records to keep. ``None`` or a
                negative value places no limit on the number of records.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        self.data: list[Request] = []  # one Request per loaded record
        self.num_samples = num_samples
        self._preprocess()

    def _preprocess(self):
        """Fill ``self.data`` with the converted records of the loaded dataset."""
        from itertools import islice

        dataset = self._load_dataset(self.num_samples)
        # islice stops after exactly ``limit`` records, so no extra record is
        # pulled from the source just to notice that the limit was reached.
        # None/negative keeps the previous "no limit" behaviour.
        limit = self.num_samples
        if limit is not None and limit < 0:
            limit = None
        self.data.extend(self._single_line_process(line) for line in islice(dataset, limit))

    def _single_line_process(self, line, **kwargs):
        """Convert one dataset record into a ``Request``; implemented by subclasses."""
        raise NotImplementedError

    def _load_dataset(self, num_samples):
        """Return an iterable of raw dataset records; implemented by subclasses."""
        raise NotImplementedError
45+
46+
47+
class OpenOrca(BaseHF):
    """Single-turn requests built from the ``Open-Orca/OpenOrca`` dataset."""

    def _single_line_process(self, line, **kwargs):
        """Build a request from the record's ``system_prompt`` and ``question`` entries."""
        system_prompt = line["system_prompt"]
        question = line["question"]
        return Request(system_prompt=system_prompt, turns=[question])

    def _load_dataset(self, num_samples):
        """Open the ``train`` split in streaming mode; ``num_samples`` is not used here."""
        stream = load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)
        return stream
53+
54+
55+
class OpenMathInstructv2(BaseHF):
    """Single-turn requests built from the ``nvidia/OpenMathInstruct-2`` dataset."""

    def _single_line_process(self, line, **kwargs):
        """Build a request, without a system prompt, from the record's ``problem`` entry."""
        problem = line["problem"]
        return Request(system_prompt=None, turns=[problem])

    def _load_dataset(self, num_samples):
        """Open the ``train_1M`` split in streaming mode; ``num_samples`` is not used here."""
        stream = load_dataset("nvidia/OpenMathInstruct-2", split="train_1M", streaming=True)
        return stream
61+
62+
63+
class UltraChat(BaseHF):
    """Multi-turn requests built from the ``stingning/ultrachat`` dataset."""

    def _single_line_process(self, line, **kwargs):
        """Build a request whose turns are the even-indexed entries of ``line["data"]``."""
        # NOTE(review): presumably the even positions are the user side of an
        # alternating dialogue - confirm against the dataset card.
        even_entries = [text for position, text in enumerate(line["data"]) if not position % 2]
        return Request(system_prompt=None, turns=even_entries)

    def _load_dataset(self, num_samples):
        """Open the ``train`` split in streaming mode; ``num_samples`` is not used here."""
        stream = load_dataset("stingning/ultrachat", split="train", streaming=True)
        return stream

0 commit comments

Comments
 (0)