1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -49,6 +49,7 @@ modelopt/torch/utils @NVIDIA/modelopt-torch-utils-codeowners
/examples/nemo_run @NVIDIA/modelopt-examples-megatron-codeowners
/examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners
/examples/pruning @NVIDIA/modelopt-torch-nas-prune-codeowners
/examples/specdec_bench @NVIDIA/modelopt-torch-speculative-codeowners
/examples/speculative_decoding @NVIDIA/modelopt-torch-speculative-codeowners
/examples/vlm_ptq @NVIDIA/modelopt-examples-vlm-codeowners
/examples/vllm_serve @NVIDIA/modelopt-examples-llm_ptq-codeowners
4 changes: 4 additions & 0 deletions CHANGELOG.rst
@@ -8,6 +8,10 @@ Model Optimizer Changelog (Linux)

- Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice messing up the ordering.

**New Features**

- Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/specdec_bench#speculative-decoding-benchmark>`_ for more details.

0.39 (2025-11-14)
^^^^^^^^^^^^^^^^^

49 changes: 49 additions & 0 deletions examples/specdec_bench/README.md
@@ -0,0 +1,49 @@
# Speculative Decoding (SpecDec) Bench

## Installation

This benchmark is meant to be a lightweight layer on top of an existing vLLM/SGLang/TRT-LLM installation. For example, no installation
is required when running in one of the following Docker images: `vllm/vllm-openai:v0.11.0` (vLLM), `lmsysorg/sglang:v0.5.4.post2` (SGLang), or
`nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc1` (TRT-LLM).

Next, change into the example directory:

```bash
cd examples/specdec_bench
```

## Purpose

Collect relevant metrics on acceptance rate, timing, and outputs for speculative decoding methods.
Acceptance rate refers to the average number of tokens generated per decoding iteration. For a standard autoregressive LLM, this number
is exactly 1; speculative decoding aims to push it above 1.
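
As a point of reference, the per-iteration acceptance rate can be summarized as a simple mean. The sketch below uses made-up token counts, not the benchmark's actual output format:

```python
# Hypothetical illustration: tokens committed per decoding iteration
# (accepted draft tokens + 1 bonus token from the target model).
tokens_per_iteration = [4, 2, 1, 3, 4]  # made-up values

acceptance_rate = sum(tokens_per_iteration) / len(tokens_per_iteration)
print(f"Mean tokens/iteration: {acceptance_rate:.2f}")  # 2.80
```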

## Getting Started

A basic example run script is provided which benchmarks MTBench (a standard set of 160 prompts spanning 8 categories).
MTBench is available [here](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts).
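
If you need a local `question.jsonl`, one possible way to materialize it is sketched below; the Hugging Face column names (`prompt`, `category`) and the output schema are assumptions, so verify them against the dataset before use:

```python
# Hypothetical sketch: export MTBench prompts to question.jsonl.
import json

from datasets import load_dataset

ds = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
with open("question.jsonl", "w") as f:
    for i, row in enumerate(ds):
        # Assumed schema for both the source rows and the output records.
        record = {"question_id": i, "category": row["category"], "turns": row["prompt"]}
        f.write(json.dumps(record) + "\n")
```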

### Running MTBench on GPT OSS + Eagle3

Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.

```bash
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b \
    --draft_model_dir /path/to/eagle --mtbench question.jsonl \
    --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 \
    --num_requests 80 --engine TRTLLM --concurrency 1
```

### Running Random ids on GPT OSS + Eagle3

Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.

```bash
python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b \
    --draft_model_dir /path/to/eagle --random_isl 1024 \
    --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 \
    --num_requests 40 --engine TRTLLM --concurrency 1
```

## Notes

The goal of this benchmark is to provide an easy way to configure, run, and compare speculative decoding implementations across frameworks in an apples-to-apples manner.
This benchmark sends requests in a single-threaded fashion, so running at large concurrency (>256) may introduce Python async scheduling delays and skew the metrics.
If larger concurrency is needed, it is recommended to fully deploy the model using `vllm serve`, `python -m sglang.launch_server`, or `trtllm-serve` (for vLLM, SGLang, or TRT-LLM respectively) and
use a more robust benchmarking client like NVIDIA AI Perf.
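
For example (the commands below are illustrative launch invocations only; speculative-decoding flags and port options vary by framework and are not shown):

```bash
# Illustrative only: launch a standalone server, then drive it with a
# dedicated benchmarking client such as NVIDIA AI Perf.
vllm serve openai/gpt-oss-120b                                   # vLLM
python -m sglang.launch_server --model-path openai/gpt-oss-120b  # SGLang
trtllm-serve openai/gpt-oss-120b                                 # TRT-LLM
```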
196 changes: 196 additions & 0 deletions examples/specdec_bench/run.py
@@ -0,0 +1,196 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import asyncio

import yaml
from specdec_bench import datasets, metrics, models, runners
from specdec_bench.utils import decode_chat, encode_chat, get_tokenizer, postprocess_base

engines_available = {
    "TRTLLM": models.TRTLLMPYTModel,
    "VLLM": models.VLLMModel,
    "SGLANG": models.SGLANGModel,
}


async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concurrency=10):
    """
    Async version of run_loop with concurrency control using a semaphore.

    Args:
        runner: The model runner instance
        dataset: The dataset containing requests
        tokenizer: The tokenizer instance
        output_length: Maximum output length
        postprocess: Callable applied to each decoded assistant response
        concurrency: Maximum number of concurrent requests (default: 10)

    Returns:
        The list of per-request message histories.
    """
    semaphore = asyncio.Semaphore(concurrency)
    max_length = output_length
    end_id = tokenizer.eos_token_id

    async def process_single_request(request, i):
        """Process a single request with all its conversation turns."""
        async with semaphore:
            messages = []
            if request.system_prompt is not None:
                messages.append({"role": "system", "content": request.system_prompt})

            for question in request.turns:
                messages.append({"role": "user", "content": question})
                entry_encoded = encode_chat(tokenizer, messages)

                # Run the async runner.run directly
                output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
                output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
                output_text = postprocess(output_text)
                messages.append({"role": "assistant", "content": output_text})

            return messages

    tasks = [process_single_request(request, i) for i, request in enumerate(dataset.data)]
    text_outputs = await asyncio.gather(*tasks, return_exceptions=True)

    # Check for any exceptions and handle them
    for i, result in enumerate(text_outputs):
        if isinstance(result, Exception):
            print(f"Error processing request {i}: {result}")
            raise result

    runner.process_metrics_final(text_outputs)
    return text_outputs


def run_simple(args):
    tokenizer = get_tokenizer(args.tokenizer)
    dataset_kwargs = args.runtime_params.get("dataset_kwargs", {})
    if args.mtbench is not None:
        dataset = datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs)
    elif args.random_isl is not None:
        dataset = datasets.RandomToken(
            tokenizer, args.random_isl, args.num_requests, **dataset_kwargs
        )
    engine_args = args.runtime_params.get("engine_args", {})
    sampling_kwargs = args.runtime_params.get("sampling_kwargs", {"temperature": 0})
    model_class = engines_available[args.engine]
    model = model_class(
        args.model_dir,
        max_concurrent_requests=args.concurrency,
        sampling_kwargs=sampling_kwargs,
        speculative_algorithm=args.speculative_algorithm,
        draft_model_dir=args.draft_model_dir,
        speculative_num_steps=args.draft_length,
        tensor_parallel_size=args.tp_size,
        moe_expert_parallel_size=args.ep_size,
        **engine_args,
    )

    metrics_list = [metrics.Timing(args.tp_size)]
    if args.aa_timing:
        metrics_list.append(metrics.AATiming(tokenizer))
    if args.mtbench is not None:
        metrics_list.insert(0, metrics.MTBench())
    else:
        metrics_list.insert(0, metrics.AcceptanceRate())
    runner = runners.SimpleRunner(model, metrics=metrics_list)

    postprocess = postprocess_base

    asyncio.run(
        run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency)
    )

    runner.clear_metrics()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--tokenizer", type=str, required=True, help="Path to the tokenizer directory"
    )
    parser.add_argument(
        "--mtbench", type=str, required=False, default=None, help="Path to the mtbench dataset"
    )
    parser.add_argument(
        "--random_isl",
        type=int,
        required=False,
        default=None,
        help="How many tokens the random input should be.",
    )
    parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to run")
    parser.add_argument(
        "--engine",
        type=str,
        required=False,
        default="TRTLLM",
        choices=list(engines_available.keys()),
        help="Engine to use",
    )
    parser.add_argument(
        "--speculative_algorithm",
        type=str,
        required=False,
        default="EAGLE3",
        choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"],
        help="Speculative algorithm to use",
    )
    parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory")
    parser.add_argument(
        "--draft_model_dir",
        type=str,
        required=False,
        default=None,
        help="Path to the draft model directory",
    )
    parser.add_argument(
        "--runtime_params",
        type=str,
        required=False,
        default=None,
        help="Path to the runtime params yaml file",
    )
    parser.add_argument(
        "--output_length", type=int, required=False, default=4096, help="Output length"
    )
    parser.add_argument("--draft_length", type=int, required=False, default=3, help="Draft length")
    parser.add_argument(
        "--tp_size", type=int, required=False, default=4, help="Tensor parallel size"
    )
    parser.add_argument(
        "--ep_size", type=int, required=False, default=2, help="Expert parallel size"
    )
    parser.add_argument(
        "--concurrency",
        type=int,
        required=False,
        default=1,
        help="Maximum number of concurrent requests",
    )
    parser.add_argument("--aa_timing", action="store_true", help="Enable AA timing metric")
    args = parser.parse_args()

    if args.runtime_params is not None:
        with open(args.runtime_params) as f:
            args.runtime_params = yaml.safe_load(f)
    else:
        args.runtime_params = {}

    assert args.mtbench is not None or args.random_isl is not None, (
        "Either mtbench or random_isl must be provided"
    )

    run_simple(args)
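
For reference, `run_simple` reads three optional keys from the `--runtime_params` YAML file (`dataset_kwargs`, `engine_args`, and `sampling_kwargs`). A minimal sketch of such a file, with illustrative values only, might look like:

```yaml
# Hypothetical runtime_params.yaml sketch; key names mirror what
# run_simple() reads, values are illustrative assumptions.
dataset_kwargs: {}        # forwarded to the dataset constructor
engine_args: {}           # forwarded to the engine/model constructor
sampling_kwargs:
  temperature: 0          # matches the script's built-in default
```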
14 changes: 14 additions & 0 deletions examples/specdec_bench/specdec_bench/__init__.py
@@ -0,0 +1,14 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
19 changes: 19 additions & 0 deletions examples/specdec_bench/specdec_bench/datasets/__init__.py
@@ -0,0 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import Dataset
from .base_hf import OpenMathInstructv2, OpenOrca, UltraChat
from .mtbench import MTBench
from .random_token import RandomToken
37 changes: 37 additions & 0 deletions examples/specdec_bench/specdec_bench/datasets/base.py
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Request:
    system_prompt: str | None = None
    turns: list[str] = field(default_factory=list)
    mm_content: Any | None = None  # TODO

    # not to be set by user
    output_turn_ids = None
    output_turn_text: list[str] = field(default_factory=list)


class Dataset:
    def __init__(self, path, **kwargs):
        self.data: list[Request] = []
        raise NotImplementedError

    def _preprocess(self):
        raise NotImplementedError
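
To illustrate the intended extension point, a hypothetical subclass might populate `self.data` with `Request` objects; the JSONL schema below is an assumption for illustration, not one of the shipped datasets:

```python
import json


class JsonlPrompts(Dataset):
    """Hypothetical dataset: one {"prompt": "..."} object per JSONL line."""

    def __init__(self, path, num_requests=None, **kwargs):
        # Deliberately do not call super().__init__(), which raises.
        self.data: list[Request] = []
        with open(path) as f:
            for line in f:
                entry = json.loads(line)  # assumed schema: {"prompt": str}
                self.data.append(Request(turns=[entry["prompt"]]))
        if num_requests is not None:
            self.data = self.data[:num_requests]

    def _preprocess(self):
        pass  # nothing to normalize in this simple format
```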