diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
index c74be8498..baa50d145 100644
--- a/.github/CODEOWNERS
+++ b/.github/CODEOWNERS
@@ -49,6 +49,7 @@ modelopt/torch/utils @NVIDIA/modelopt-torch-utils-codeowners
 /examples/nemo_run @NVIDIA/modelopt-examples-megatron-codeowners
 /examples/onnx_ptq @NVIDIA/modelopt-onnx-codeowners
 /examples/pruning @NVIDIA/modelopt-torch-nas-prune-codeowners
+/examples/specdec_bench @NVIDIA/modelopt-torch-speculative-codeowners
 /examples/speculative_decoding @NVIDIA/modelopt-torch-speculative-codeowners
 /examples/vlm_ptq @NVIDIA/modelopt-examples-vlm-codeowners
 /examples/vllm_serve @NVIDIA/modelopt-examples-llm_ptq-codeowners
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index ebb94d731..d68b4f04b 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -8,6 +8,10 @@ Model Optimizer Changelog (Linux)
 
 - Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice messing up the ordering.
 
+**New Features**
+
+- Add ``specdec_bench`` example to benchmark speculative decoding performance. See `examples/specdec_bench/README.md `_ for more details.
+
 0.39 (2025-11-14)
 ^^^^^^^^^^^^^^^^^
diff --git a/examples/specdec_bench/README.md b/examples/specdec_bench/README.md
new file mode 100644
index 000000000..b0e955b4a
--- /dev/null
+++ b/examples/specdec_bench/README.md
@@ -0,0 +1,49 @@
+# Speculative Decoding (SpecDec) Bench
+
+## Installation
+
+This benchmark is meant to be a lightweight layer on top of an existing vLLM/SGLang/TRT-LLM installation. For example, no installation
+is required if you are running in one of the following Docker images: `vllm/vllm-openai:v0.11.0` (vLLM), `lmsysorg/sglang:v0.5.4.post2` (SGLang), or
+`nvcr.io/nvidia/tensorrt-llm/release:1.2.0rc1` (TRT-LLM).
+
+Next, change into the example directory:
+
+```bash
+cd examples/specdec_bench
+```
+
+## Purpose
+
+Collect relevant metrics on acceptance rate, timing, and outputs for speculative decoding methods.
+Acceptance rate refers to the average number of tokens generated per decoding iteration. For a standard autoregressive LLM, this number
+is exactly 1; with speculative decoding, it can be as high as `draft_length + 1` when every draft token is accepted.
+
+## Getting Started
+
+A basic example run script is provided which benchmarks MTBench (a standard set of 80 two-turn prompts spanning 8 categories).
+MTBench is available [here](https://huggingface.co/datasets/HuggingFaceH4/mt_bench_prompts).
+
+### Running MTBench on GPT OSS + Eagle3
+
+Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.
+
+```bash
+python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --mtbench question.jsonl --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 80 --engine TRTLLM --concurrency 1
+
+```
+
+### Running Random Token IDs on GPT OSS + Eagle3
+
+Download `nvidia/gpt-oss-120b-Eagle3` to a local directory `/path/to/eagle`.
+
+```bash
+python3 run.py --model_dir openai/gpt-oss-120b --tokenizer openai/gpt-oss-120b --draft_model_dir /path/to/eagle --random_isl 1024 --tp_size 1 --ep_size 1 --draft_length 3 --output_length 4096 --num_requests 40 --engine TRTLLM --concurrency 1
+
+```
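+
+### Runtime Parameters (Optional)
+
+Additional engine, sampling, and dataset options can be passed through `--runtime_params`, which points to a YAML file read by `run.py` with three optional keys: `engine_args` (forwarded to the engine wrapper's constructor), `sampling_kwargs`, and `dataset_kwargs`. A minimal sketch is shown below; the individual entries are illustrative, and the accepted `engine_args` keys depend on the chosen engine (see `specdec_bench/models`):
+
+```yaml
+engine_args:
+  prefix_cache: false # engine-specific option, e.g. prefix caching
+sampling_kwargs:
+  temperature: 0 # greedy decoding (the default)
+  top_p: 1.0
+dataset_kwargs: {} # extra keyword arguments for the dataset class
+```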
+
+## Notes
+
+The goal of this benchmark is to provide an easy way to configure, run, and compare speculative decoding implementations across frameworks in an apples-to-apples manner.
+This benchmark sends requests in a single-threaded fashion, so running at large concurrency (>256) may introduce Python async scheduling delays that skew the metrics.
+If larger concurrency is needed, it is recommended to fully deploy the model using `vllm serve`, `python -m sglang.launch_server`, or `trtllm-serve` (for vLLM, SGLang, or TRT-LLM, respectively) and
+use a more robust benchmarking client like NVIDIA AI Perf.
diff --git a/examples/specdec_bench/run.py b/examples/specdec_bench/run.py
new file mode 100644
index 000000000..f6c4e33e0
--- /dev/null
+++ b/examples/specdec_bench/run.py
@@ -0,0 +1,196 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import asyncio
+
+import yaml
+from specdec_bench import datasets, metrics, models, runners
+from specdec_bench.utils import decode_chat, encode_chat, get_tokenizer, postprocess_base
+
+engines_available = {
+    "TRTLLM": models.TRTLLMPYTModel,
+    "VLLM": models.VLLMModel,
+    "SGLANG": models.SGLANGModel,
+}
+
+
+async def run_loop(runner, dataset, tokenizer, output_length, postprocess, concurrency=10):
+    """
+    Async version of run_loop with concurrency control using a semaphore.
+
+    Args:
+        runner: The model runner instance
+        dataset: The dataset containing requests
+        tokenizer: The tokenizer instance
+        output_length: Maximum output length
+        postprocess: Callable applied to each decoded assistant response
+        concurrency: Maximum number of concurrent requests (default: 10)
+    """
+    semaphore = asyncio.Semaphore(concurrency)
+    max_length = output_length
+    end_id = tokenizer.eos_token_id
+
+    async def process_single_request(request, i):
+        """Process a single request with all its conversation turns."""
+        async with semaphore:
+            messages = []
+            if request.system_prompt is not None:
+                messages.append({"role": "system", "content": request.system_prompt})
+
+            for question in request.turns:
+                messages.append({"role": "user", "content": question})
+                entry_encoded = encode_chat(tokenizer, messages)
+
+                # Run the async runner.run directly
+                output_tokens = await runner.run(entry_encoded, max_length, end_id, i)
+                output_text = decode_chat(tokenizer, output_tokens["output_ids"][0])
+                output_text = postprocess(output_text)
+                messages.append({"role": "assistant", "content": output_text})
+
+            return messages
+
+    tasks = [process_single_request(request, i) for i, request in enumerate(dataset.data)]
+    text_outputs = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Check for any exceptions and handle them
+    for i, result in enumerate(text_outputs):
+        if isinstance(result, Exception):
+            print(f"Error processing request {i}: {result}")
+            raise result
+
+    runner.process_metrics_final(text_outputs)
+    return text_outputs
+
+
+def run_simple(args):
+    tokenizer = get_tokenizer(args.tokenizer)
+    dataset_kwargs = args.runtime_params.get("dataset_kwargs", {})
+    if args.mtbench is not None:
+        dataset = datasets.MTBench(args.mtbench, args.num_requests, **dataset_kwargs)
+    elif args.random_isl is not None:
+        dataset = datasets.RandomToken(
+            tokenizer, args.random_isl, args.num_requests, **dataset_kwargs
+        )
+    engine_args = 
args.runtime_params.get("engine_args", {}) + sampling_kwargs = args.runtime_params.get("sampling_kwargs", {"temperature": 0}) + model_class = engines_available[args.engine] + model = model_class( + args.model_dir, + max_concurrent_requests=args.concurrency, + sampling_kwargs=sampling_kwargs, + speculative_algorithm=args.speculative_algorithm, + draft_model_dir=args.draft_model_dir, + speculative_num_steps=args.draft_length, + tensor_parallel_size=args.tp_size, + moe_expert_parallel_size=args.ep_size, + **engine_args, + ) + + metrics_list = [metrics.Timing(args.tp_size)] + if args.aa_timing: + metrics_list.append(metrics.AATiming(tokenizer)) + if args.mtbench is not None: + metrics_list.insert(0, metrics.MTBench()) + else: + metrics_list.insert(0, metrics.AcceptanceRate()) + runner = runners.SimpleRunner(model, metrics=metrics_list) + + postprocess = postprocess_base + + asyncio.run( + run_loop(runner, dataset, tokenizer, args.output_length, postprocess, args.concurrency) + ) + + runner.clear_metrics() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--tokenizer", type=str, required=True, help="Path to the tokenizer directory" + ) + parser.add_argument( + "--mtbench", type=str, required=False, default=None, help="Path to the mtbench dataset" + ) + parser.add_argument( + "--random_isl", + type=int, + required=False, + default=None, + help="How many tokens random input should be.", + ) + parser.add_argument("--num_requests", type=int, required=True, help="Number of requests to run") + parser.add_argument( + "--engine", + type=str, + required=False, + default="TRTLLM", + choices=list(engines_available.keys()), + help="Engine to use", + ) + parser.add_argument( + "--speculative_algorithm", + type=str, + required=False, + default="EAGLE3", + choices=["EAGLE3", "EAGLE", "DRAFT_TARGET", "NGRAM", "MTP", "NONE"], + help="Speculative algorithm to use", + ) + parser.add_argument("--model_dir", type=str, required=True, help="Path to the model directory") + parser.add_argument( + "--draft_model_dir", + type=str, + required=False, + default=None, + help="Path to the draft model directory", + ) + parser.add_argument( + "--runtime_params", + type=str, + required=False, + default=None, + help="Path to the runtime params yaml file", + ) + parser.add_argument( + "--output_length", type=int, required=False, default=4096, help="Output length" + ) + parser.add_argument("--draft_length", type=int, required=False, default=3, help="Draft length") + parser.add_argument( + "--tp_size", type=int, required=False, default=4, help="Tensor parallel size" + ) + parser.add_argument( + "--ep_size", type=int, required=False, default=2, help="Expert parallel size" + ) + parser.add_argument( + "--concurrency", + type=int, + required=False, + default=1, + help="Maximum number of concurrent requests", + ) + parser.add_argument("--aa_timing", action="store_true", help="Enable AA timing metric") + args = parser.parse_args() + + if args.runtime_params is not None: + with open(args.runtime_params) as f: + args.runtime_params = yaml.safe_load(f) + else: + args.runtime_params = {} + + assert args.mtbench is not None or args.random_isl is not None, ( + "Either mtbench or random_isl must be provided" + ) + + run_simple(args) diff --git a/examples/specdec_bench/specdec_bench/__init__.py b/examples/specdec_bench/specdec_bench/__init__.py new file mode 100644 index 000000000..3159bfe65 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/__init__.py @@ -0,0 +1,14 @@ +# 
SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/specdec_bench/specdec_bench/datasets/__init__.py b/examples/specdec_bench/specdec_bench/datasets/__init__.py new file mode 100644 index 000000000..64449d2b5 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/datasets/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import Dataset +from .base_hf import OpenMathInstructv2, OpenOrca, UltraChat +from .mtbench import MTBench +from .random_token import RandomToken diff --git a/examples/specdec_bench/specdec_bench/datasets/base.py b/examples/specdec_bench/specdec_bench/datasets/base.py new file mode 100644 index 000000000..587c04b07 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/datasets/base.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
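+
+# Shared dataset interface: a Request bundles one conversation (an optional
+# system prompt plus a list of user turns); Dataset subclasses populate
+# self.data in _preprocess().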
+
+from dataclasses import dataclass, field
+from typing import Any
+
+
+@dataclass
+class Request:
+    system_prompt: str | None = None
+    turns: list[str] = field(default_factory=list)
+    mm_content: Any | None = None  # TODO
+
+    # not to be set by user
+    output_turn_ids: Any | None = None  # annotated so it is a per-instance dataclass field
+    output_turn_text: list[str] = field(default_factory=list)
+
+
+class Dataset:
+    def __init__(self, path, **kwargs):
+        self.data: list[Request] = []
+        raise NotImplementedError
+
+    def _preprocess(self):
+        raise NotImplementedError
diff --git a/examples/specdec_bench/specdec_bench/datasets/base_hf.py b/examples/specdec_bench/specdec_bench/datasets/base_hf.py
new file mode 100644
index 000000000..6c7be3d8c
--- /dev/null
+++ b/examples/specdec_bench/specdec_bench/datasets/base_hf.py
@@ -0,0 +1,70 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+try:
+    from datasets import load_dataset
+except ImportError:
+    print("datasets is not installed.")
+    load_dataset = None
+
+
+from .base import Dataset, Request
+
+
+class BaseHF(Dataset):
+    def __init__(self, num_samples=100, **kwargs):
+        self.data: list[Request] = []  # list of list of questions.
+        self.num_samples = num_samples
+        self._preprocess()
+
+    def _preprocess(self):
+        dataset = self._load_dataset(self.num_samples)
+        for i, line in enumerate(dataset):
+            if i == self.num_samples:
+                break
+            self.data.append(self._single_line_process(line))
+
+    def _single_line_process(self, line):
+        raise NotImplementedError
+
+    def _load_dataset(self, num_samples):
+        raise NotImplementedError
+
+
+class OpenOrca(BaseHF):
+    def _single_line_process(self, line, **kwargs):
+        return Request(system_prompt=line["system_prompt"], turns=[line["question"]])
+
+    def _load_dataset(self, num_samples):
+        return load_dataset("Open-Orca/OpenOrca", split="train", streaming=True)
+
+
+class OpenMathInstructv2(BaseHF):
+    def _single_line_process(self, line, **kwargs):
+        return Request(system_prompt=None, turns=[line["problem"]])
+
+    def _load_dataset(self, num_samples):
+        return load_dataset("nvidia/OpenMathInstruct-2", split="train_1M", streaming=True)
+
+
+class UltraChat(BaseHF):
+    def _single_line_process(self, line, **kwargs):
+        return Request(
+            system_prompt=None, turns=[q for i, q in enumerate(line["data"]) if i % 2 == 0]
+        )
+
+    def _load_dataset(self, num_samples):
+        return load_dataset("stingning/ultrachat", split="train", streaming=True)
diff --git a/examples/specdec_bench/specdec_bench/datasets/mtbench.py b/examples/specdec_bench/specdec_bench/datasets/mtbench.py
new file mode 100644
index 000000000..f96fbef22
--- /dev/null
+++ b/examples/specdec_bench/specdec_bench/datasets/mtbench.py
@@ -0,0 +1,44 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from .base import Dataset, Request + +MTBENCH_TOPICS = [ + "writing", + "roleplay", + "reasoning", + "math", + "coding", + "extraction", + "stem", + "humanities", +] + + +class MTBench(Dataset): + def __init__(self, path, num_samples=80, **kwargs): + self.data: list[Request] = [] # list of list of questions. + self.num_samples = num_samples + self.path = path + self._preprocess() + + def _preprocess(self): + with open(self.path) as f: + for json_line in f: + line = json.loads(json_line) + self.data.append(Request(system_prompt=None, turns=line["turns"])) + self.data = self.data[: self.num_samples] diff --git a/examples/specdec_bench/specdec_bench/datasets/random_token.py b/examples/specdec_bench/specdec_bench/datasets/random_token.py new file mode 100644 index 000000000..972a0455c --- /dev/null +++ b/examples/specdec_bench/specdec_bench/datasets/random_token.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np + +from .base import Dataset, Request + + +class RandomToken(Dataset): + def __init__(self, tokenizer, input_len, num_samples=20, **kwargs): + self.data: list[Request] = [] # list of list of questions. + self.num_samples = num_samples + self.input_len = input_len + self.tokenizer = tokenizer + self._preprocess() + + def _preprocess(self): + np.random.seed(0) + tokenizer = self.tokenizer + num_prompts = self.num_samples + offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) + for i in range(num_prompts): + prompt = tokenizer.decode( + [ + (offsets[i] + i + j) % tokenizer.vocab_size + for j in range(int(self.input_len * 1.5)) + ] + ) + re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ + : (self.input_len) + ] + prompt = tokenizer.decode(re_encoded_sequence) + self.data.append(Request(system_prompt=None, turns=[prompt])) + self.data = self.data[: self.num_samples] diff --git a/examples/specdec_bench/specdec_bench/metrics/__init__.py b/examples/specdec_bench/specdec_bench/metrics/__init__.py new file mode 100644 index 000000000..b61616830 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/metrics/__init__.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .aa_timing import AATiming
+from .acceptance_rate import AcceptanceRate
+from .base import Metric
+from .mtbench import MTBench
+from .timing import Timing
diff --git a/examples/specdec_bench/specdec_bench/metrics/aa_timing.py b/examples/specdec_bench/specdec_bench/metrics/aa_timing.py
new file mode 100644
index 000000000..a4084a9c3
--- /dev/null
+++ b/examples/specdec_bench/specdec_bench/metrics/aa_timing.py
@@ -0,0 +1,61 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try:
+    import tiktoken
+except ImportError:
+    tiktoken = None
+from .base import Metric
+from .timing import compute_statistics
+
+
+class AATiming(Metric):
+    def __init__(self, base_tokenizer):
+        super().__init__()
+        self.timing = []
+        self.name = "aa_timing"
+        if tiktoken is None:
+            raise ImportError(
+                "Please install tiktoken to use the AATiming metric, or remove the metric from the run command"
+            )
+        self.enc = tiktoken.get_encoding("cl100k_base")
+        self.base_tokenizer = base_tokenizer
+        self.total_tokens = []
+
+    def process_step(self, step_outputs, new_turn=True):
+        self.timing.append(step_outputs["token_times"])
+        target_tokens = [
+            t for tok_list in step_outputs["output_ids"] for tok in tok_list for t in tok
+        ]
+        target_text = self.base_tokenizer.decode(target_tokens)
+        target_tokens = self.enc.encode(target_text, disallowed_special=())
+        self.total_tokens.append(len(target_tokens))
+
+    def process_final(self, text_outputs):
+        gen_tp_time = []
+        start_time = min([t[0] for t in self.timing])
+        end_time = max([t[-1] for t in self.timing])
+        self.out["AA Output TPS"] = sum(self.total_tokens) / (end_time - start_time)
+        for tokens, times in zip(self.total_tokens, self.timing):
+            if len(times) > 2:
+                gen_tp_time.append((tokens - 1) / (times[-1] - times[1]))
+        if gen_tp_time:
+            self.out["AA Generation Tokens Per Second"] = compute_statistics(gen_tp_time)
+        for k, v in self.out.items():
+            print(k, v)
+        self.write()
+
+    def clear(self):
+        self.timing = []
+        self.total_tokens = []
diff --git a/examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py b/examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py
new file mode 100644
index 000000000..22f10091a
--- /dev/null
+++ b/examples/specdec_bench/specdec_bench/metrics/acceptance_rate.py
@@ -0,0 +1,90 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +from .base import Metric + + +class AcceptanceRate(Metric): + def __init__(self): + super().__init__() + self.prompt_ar = [] + self.name = "acceptance_rate" + + def process_step(self, step_outputs, new_turn=True): + if new_turn: + self.prompt_ar.append([]) + for i, beam_output in enumerate(step_outputs["output_ids"]): + for output_id_iter in beam_output: + self.prompt_ar[-1].append(len(output_id_iter)) + + def _get_lengths(self, turn, lengths): + for j in turn: + if j not in lengths: + lengths[j] = 0 + lengths[j] += 1 + + def _process_lengths(self, lengths): + lengths = dict(sorted(lengths.items(), key=lambda x: x[0])) + self.out["Acceptance_Length_Histogram"] = lengths + print("Acceptance Length Histogram") + print(lengths) + sum_lengths = sum(lengths.values()) + running_len = sum_lengths + prev_ratio = 1 + self.out["Conditional_Acceptance_Rate"] = {} + print("Conditional acceptance rate") + for k, v in lengths.items(): + print(k, running_len / sum_lengths / prev_ratio) + self.out["Conditional_Acceptance_Rate"][k] = running_len / sum_lengths / prev_ratio + prev_ratio = running_len / sum_lengths + running_len -= v + + def process_final(self, text_outputs): + i = 0 + lengths = {} + self.out["Request_AR"] = {} + while i < len(self.prompt_ar): + turn_1 = self.prompt_ar[i] + self.out["Request_AR"][i] = sum(turn_1) / len(turn_1) + self._get_lengths(turn_1, lengths) + print(i, self.out["Request_AR"][i]) + i += 1 + average_ar = sum(self.out["Request_AR"].values()) / len(self.out["Request_AR"]) + print("Average AR:", average_ar) + self.out["Average_AR"] = average_ar + self._process_lengths(lengths) + self.write() + self._format_write_output(text_outputs) + + def clear(self): + self.prompt_ar = [] + + def _format_write_output(self, outputs): + with open(os.path.join(self.directory, "responses.jsonl"), "w") as outfile: + for i, messages in enumerate(outputs): + q_id = i + out_line = {} + out_line["question_id"] = q_id + if messages[0]["role"] == "system": + out_line["system_prompt"] = messages[0]["content"] + q_turns = [c["content"] for c in messages if c["role"] == "user"] + a_turns = [c["content"] for c in messages if c["role"] == "assistant"] + out_line["turns"] = q_turns + out_line["choices"] = [{"index": 0, "turns": a_turns}] + json.dump(out_line, outfile) + outfile.write("\n") diff --git a/examples/specdec_bench/specdec_bench/metrics/base.py b/examples/specdec_bench/specdec_bench/metrics/base.py new file mode 100644 index 000000000..3092aa8d3 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/metrics/base.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + + +class Metric: + directory = "./" + + def __init__(self): + self.out = {} + self.name = "metric" + + def process_step(self, step_outputs, new_turn=True): + raise NotImplementedError + + def process_final(self, text_outputs): + raise NotImplementedError + + def clear(self): + raise NotImplementedError + + def write(self): + os.makedirs(self.directory, exist_ok=True) + if self.out: + filename = os.path.join(self.directory, f"{self.name}.json") + if os.path.exists(filename): + with open(filename) as json_file: + existing_data = json.load(json_file) + existing_data.append(self.out) + else: + existing_data = [self.out] + + with open(filename, "w") as json_file: + json.dump(existing_data, json_file, indent=4) + + @classmethod + def update_directory(cls, new_dir): + cls.directory = new_dir diff --git a/examples/specdec_bench/specdec_bench/metrics/mtbench.py b/examples/specdec_bench/specdec_bench/metrics/mtbench.py new file mode 100644 index 000000000..2b6d8727b --- /dev/null +++ b/examples/specdec_bench/specdec_bench/metrics/mtbench.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
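+
+# MT-Bench-specific acceptance-rate metric: each question contributes two
+# turns, and per-category averages are reported (10 questions per category,
+# question ids 81-160).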
+ +import json +import os + +from .acceptance_rate import AcceptanceRate + +MTBENCH_TOPICS = [ + "writing", + "roleplay", + "reasoning", + "math", + "coding", + "extraction", + "stem", + "humanities", +] + + +class MTBench(AcceptanceRate): + def process_final(self, text_outputs): + i = 0 + lengths = {} + self.out["Request_AR"] = {} + while i < len(self.prompt_ar): + turn_1 = self.prompt_ar[i] + turn_2 = self.prompt_ar[i + 1] + q_id = i // 2 + mtbench_topic = MTBENCH_TOPICS[q_id // 10] + self.out["Request_AR"][q_id] = sum(turn_1 + turn_2) / len(turn_1 + turn_2) + self._get_lengths(turn_1, lengths) + self._get_lengths(turn_2, lengths) + print(mtbench_topic, sum(turn_1 + turn_2) / len(turn_1 + turn_2)) + i += 2 + per_category = [[] for _ in range(len(MTBENCH_TOPICS))] + for q_id, ar in self.out["Request_AR"].items(): + per_category[q_id // 10].append(ar) + self.out["Category_AR"] = {} + for i, category in enumerate(per_category): + if len(category) > 0: + category_ar = sum(category) / len(category) + self.out["Category_AR"][MTBENCH_TOPICS[i]] = category_ar + print(f"{MTBENCH_TOPICS[i]} Average AR: {category_ar}") + average_ar = sum(self.out["Request_AR"].values()) / len(self.out["Request_AR"]) + print("Average AR:", average_ar) + self.out["Average_AR"] = average_ar + self._process_lengths(lengths) + self.write() + self._format_write_output(text_outputs) + + def _format_write_output(self, outputs): + with open(os.path.join(self.directory, "mtbench_responses.jsonl"), "w") as outfile: + for i, messages in enumerate(outputs): + q_id = i + 81 + out_line = {} + out_line["question_id"] = q_id + out_line["category"] = MTBENCH_TOPICS[(q_id - 81) // 10] + q_turns = [c["content"] for c in messages if c["role"] == "user"] + a_turns = [c["content"] for c in messages if c["role"] == "assistant"] + out_line["turns"] = q_turns + out_line["choices"] = [{"index": 0, "turns": a_turns}] + json.dump(out_line, outfile) + outfile.write("\n") diff --git a/examples/specdec_bench/specdec_bench/metrics/timing.py b/examples/specdec_bench/specdec_bench/metrics/timing.py new file mode 100644 index 000000000..270ea697c --- /dev/null +++ b/examples/specdec_bench/specdec_bench/metrics/timing.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
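+
+# Timing metric: consumes the per-step "token_times" timestamps recorded by the
+# model wrappers and derives output TPS (total and per GPU), end-to-end request
+# time, TTFT, and per-generation-step statistics.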
+
+import numpy as np
+
+from .base import Metric
+
+
+class Timing(Metric):
+    def __init__(self, tp_size):
+        super().__init__()
+        self.timing = []
+        self.name = "timing"
+        self.total_tokens = []
+        self.tp_size = tp_size
+
+    def process_step(self, step_outputs, new_turn=True):
+        self.timing.append(step_outputs["token_times"])
+        self.total_tokens.append(
+            sum([sum([len(j) for j in i]) for i in step_outputs["output_ids"]])
+        )
+
+    def process_final(self, text_outputs):
+        e2e_time = []
+        ttft_time = []
+        tpot_time = []
+        gen_tp_time = []
+        start_time = min([t[0] for t in self.timing])
+        end_time = max([t[-1] for t in self.timing])
+        self.out["Output TPS"] = sum(self.total_tokens) / (end_time - start_time)
+        self.out["Output TPS/gpu"] = self.out["Output TPS"] / self.tp_size
+        for tokens, times in zip(self.total_tokens, self.timing):
+            e2e_time.append(times[-1] - times[0])
+            ttft_time.append(times[1] - times[0])
+            if len(times) > 2:
+                gen_tp_time.append((tokens - 1) / (times[-1] - times[1]))
+                tpot_time.extend([a - b for a, b in zip(times[1:], times[:-1])])
+        self.out["E2E Request Time"] = compute_statistics(e2e_time)
+        self.out["TTFT Time"] = compute_statistics(ttft_time)
+        if tpot_time:
+            self.out["Request Generation Step Time"] = compute_statistics(tpot_time)
+            self.out["Request Generation Tokens Per Second"] = compute_statistics(gen_tp_time)
+        for k, v in self.out.items():
+            print(k, v)
+        self.write()
+
+    def clear(self):
+        self.timing = []
+        self.total_tokens = []
+
+
+def compute_statistics(data, quantiles=(0.25, 0.5, 0.75)):
+    # Convert the data to a numpy array for easier calculations
+    data = np.array(data)
+
+    # Compute the statistics
+    min_val = np.min(data)
+    max_val = np.max(data)
+    mean_val = np.mean(data)
+    std_val = np.std(data)
+
+    # Compute quantiles (default: 25th, 50th, and 75th percentiles)
+    quantile_vals = np.percentile(data, [q * 100 for q in quantiles])
+
+    # Return the results as a dictionary
+    stats = {
+        "min": f"{min_val:.4f}",
+        "max": f"{max_val:.4f}",
+        "mean": f"{mean_val:.4f}",
+        "std": f"{std_val:.4f}",
+        "quantiles": {f"{q}": f"{v:.4f}" for q, v in zip(quantiles, quantile_vals)},
+    }
+
+    return stats
diff --git a/examples/specdec_bench/specdec_bench/models/__init__.py b/examples/specdec_bench/specdec_bench/models/__init__.py
new file mode 100644
index 000000000..8897d6de3
--- /dev/null
+++ b/examples/specdec_bench/specdec_bench/models/__init__.py
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
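+
+# Backend registry: every wrapper implements the async
+# run(prompt_ids, max_length, end_id, request_id) interface declared in base.py,
+# so runners can drive any engine interchangeably.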
+ +from .base import Model +from .sglang import SGLANGModel +from .trtllm import TRTLLMModel +from .trtllm_torch_api import TRTLLMPYTModel +from .vllm import VLLMModel diff --git a/examples/specdec_bench/specdec_bench/models/base.py b/examples/specdec_bench/specdec_bench/models/base.py new file mode 100644 index 000000000..5f3a9616a --- /dev/null +++ b/examples/specdec_bench/specdec_bench/models/base.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class Model: + def __init__(self, model_dir, tokenizer, max_draft_length): + raise NotImplementedError + + async def run(self, prompt_ids, max_length, end_id, request_id): + """ + prompt_ids is list of tokens + output is list of list of tokens + len(output) = beam width + len(output[i]) = tokens produced per step? + """ + raise NotImplementedError + + def stop(self): + pass diff --git a/examples/specdec_bench/specdec_bench/models/sglang.py b/examples/specdec_bench/specdec_bench/models/sglang.py new file mode 100644 index 000000000..534303569 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/models/sglang.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
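+
+# SGLang backend: translates the benchmark's speculative algorithm names onto
+# SGLang's (MTP -> EAGLE, DRAFT_TARGET -> STANDALONE, NGRAM -> LOOKAHEAD) and
+# rebuilds per-step token chunks from the streamed completion-token counts.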
+
+import itertools
+import time
+
+from .base import Model
+
+try:
+    import sglang as sgl
+except ImportError:
+    print("sglang is not installed.")
+    sgl = None
+
+
+class SGLANGModel(Model):
+    def __init__(
+        self, model_dir, max_concurrent_requests, sampling_kwargs, use_draft_logits=False, **kwargs
+    ):
+        speculative_algorithm = kwargs.get("speculative_algorithm")
+        if speculative_algorithm == "MTP":
+            speculative_algorithm = "EAGLE"
+        elif speculative_algorithm == "DRAFT_TARGET":
+            speculative_algorithm = "STANDALONE"
+        elif speculative_algorithm == "NGRAM":
+            speculative_algorithm = "LOOKAHEAD"
+        elif speculative_algorithm == "NONE":
+            speculative_algorithm = None
+        if speculative_algorithm is not None:
+            # https://github.com/sgl-project/sglang/pull/3582
+            self.model = sgl.Engine(
+                model_path=model_dir,
+                skip_tokenizer_init=True,
+                mem_fraction_static=0.7,
+                disable_overlap_schedule=kwargs.get("disable_overlap_schedule", True),
+                tp_size=kwargs.get("tensor_parallel_size", 1),
+                speculative_algorithm=speculative_algorithm,
+                speculative_num_steps=kwargs.get("speculative_num_steps", 3),
+                speculative_eagle_topk=kwargs.get("speculative_eagle_topk", 1),
+                speculative_num_draft_tokens=kwargs.get("speculative_num_draft_tokens", 4),
+                torch_compile_max_bs=max_concurrent_requests,
+                attention_backend=kwargs.get("attention_backend"),
+                enable_torch_compile=kwargs.get("enable_torch_compile", False),
+                cuda_graph_max_bs=max_concurrent_requests,
+            )
+        else:
+            self.model = sgl.Engine(
+                model_path=model_dir,
+                skip_tokenizer_init=True,
+                mem_fraction_static=0.7,
+                disable_overlap_schedule=kwargs.get("disable_overlap_schedule", True),
+                tp_size=kwargs.get("tensor_parallel_size", 1),
+                torch_compile_max_bs=max_concurrent_requests,
+                attention_backend=kwargs.get("attention_backend"),
+                enable_torch_compile=kwargs.get("enable_torch_compile", False),
+                cuda_graph_max_bs=max_concurrent_requests,
+            )
+
+        self.sampling_config = sampling_kwargs
+
+    async def run(self, prompt_ids, max_length, end_id, request_id):
+        timing = []
+        output_dict = {}
+        self.sampling_config["max_new_tokens"] = max_length
+        self.sampling_config["stop_token_ids"] = [end_id]
+        timing.append(time.perf_counter())
+        assert self.sampling_config.get("beam_width", 1) == 1
+        beam_lens = [[] for _ in range(self.sampling_config.get("beam_width", 1))]
+        outputs = []
+        result = await self.model.async_generate(
+            sampling_params=self.sampling_config, input_ids=prompt_ids, stream=True
+        )
+        async for chunk in result:
+            timing.append(time.perf_counter())
+            outputs = chunk["output_ids"]
+            beam_lens[0].append(chunk["meta_info"]["completion_tokens"])
+
+        if end_id == outputs[-1]:
+            beam_lens[0].pop(-1)
+            outputs.pop(-1)
+        reformatted_output_ids = [[] for _ in range(self.sampling_config.get("beam_width", 1))]
+        for beam_idx, beam_len in enumerate(beam_lens):
+            response = outputs
+            if beam_len[0] != 0:
+                reformatted_output_ids[beam_idx].append(response[: beam_len[0]])
+            for s, e in itertools.pairwise(beam_len):
+                reformatted_output_ids[beam_idx].append(response[s:e])
+            if len(response) > beam_len[-1]:
+                reformatted_output_ids[beam_idx].append(response[beam_len[-1] :])
+        output_dict["output_ids"] = reformatted_output_ids
+        output_dict["output_logits"] = None
+        output_dict["token_times"] = timing
+        return output_dict
diff --git a/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py b/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py
new file mode 100644
index 000000000..4d4e6c92c
--- /dev/null
+++ 
b/examples/specdec_bench/specdec_bench/models/trtllm_torch_api.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import time + +try: + import tensorrt_llm.bindings.executor as trtllm + from tensorrt_llm import LLM, SamplingParams + from tensorrt_llm.llmapi import ( + CudaGraphConfig, + DraftTargetDecodingConfig, + EagleDecodingConfig, + KvCacheConfig, + MoeConfig, + MTPDecodingConfig, + NGramDecodingConfig, + ) +except ImportError: + print("Failed to import tensorrt_llm._torch") + trtllm = None + + +from .base import Model + + +class TRTLLMPYTModel(Model): + def __init__( + self, model_path, max_concurrent_requests, sampling_kwargs, use_draft_logits=False, **kwargs + ): + self.model = create_executor(model_path, max_concurrent_requests, kwargs) + self.sampling_kwargs = sampling_kwargs + + async def run(self, prompt_ids, max_length, end_id, request_id): + output_dict = {} + sampling_config = check_sampling_config(self.sampling_kwargs, max_length, end_id) + outputs = [] + timing = [time.perf_counter()] + beam_lens = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))] + async for output in self.model.generate_async( + prompt_ids, + streaming=not sampling_config.use_beam_search, + sampling_params=sampling_config, + ): + for beam in output.outputs: + beam_lens[beam.index].append(len(beam.token_ids)) + outputs.append(output.outputs) + timing.append(time.perf_counter()) + reformatted_output_ids = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))] + for beam_idx, beam_len in enumerate(beam_lens): + response = outputs[-1][beam_idx] + if beam_len[0] != 0: + reformatted_output_ids[beam_idx].append(response.token_ids[: beam_len[0]]) + for s, e in itertools.pairwise(beam_len): + reformatted_output_ids[beam_idx].append(response.token_ids[s:e]) + if len(response.token_ids) > beam_len[-1]: + reformatted_output_ids[beam_idx].append(response.token_ids[beam_len[-1] :]) + output_dict["output_ids"] = reformatted_output_ids + output_dict["output_logits"] = None + output_dict["token_times"] = timing + return output_dict + + +def create_executor(model_path: str, max_concurrent_requests, kwargs): + disable_overlap_schedule = kwargs.get("disable_overlap_schedule", False) + if kwargs.get("speculative_algorithm", None) == "DRAFT_TARGET": + specdec = DraftTargetDecodingConfig( + max_draft_len=kwargs.get("speculative_num_steps", 3), + speculative_model_dir=kwargs.get("draft_model_dir", None), + ) + disable_overlap_schedule = True + + elif kwargs.get("speculative_algorithm", None) == "EAGLE3": + specdec = EagleDecodingConfig( + max_draft_len=kwargs.get("speculative_num_steps", 3), + speculative_model_dir=kwargs.get("draft_model_dir", None), + eagle3_one_model=kwargs.get("use_one_model", True), + eagle3_layers_to_capture=kwargs.get("eagle3_layers_to_capture", None), + ) + disable_overlap_schedule = not 
kwargs.get("use_one_model", True) + + elif kwargs.get("speculative_algorithm", None) == "MTP": + specdec = MTPDecodingConfig( + num_nextn_predict_layers=kwargs.get("speculative_num_steps", 3), + use_relaxed_acceptance_for_thinking=kwargs.get("relaxed_acceptance", False), + relaxed_topk=kwargs.get("relaxed_topk", 10), + relaxed_delta=kwargs.get("relaxed_delta", 0.6), + ) + elif kwargs.get("speculative_algorithm", None) == "NGRAM": + specdec = NGramDecodingConfig( + max_draft_len=kwargs.get("speculative_num_steps", 5), + max_matching_ngram_size=kwargs.get("max_matching_ngram_size", 3), + is_keep_all=True, + is_use_oldest=True, + is_public_pool=True, + ) + elif kwargs.get("speculative_algorithm", None) == "NONE": + specdec = None + else: + specdec = None + + kv_cache_config = KvCacheConfig( + enable_block_reuse=kwargs.get("prefix_cache", False), + free_gpu_memory_fraction=0.75, + ) + + cuda_graph_config = CudaGraphConfig( + batch_sizes=[max_concurrent_requests], + enable_padding=True, + ) + + model = LLM( + model=model_path, + tensor_parallel_size=kwargs.get("tensor_parallel_size", 4), + moe_expert_parallel_size=kwargs.get("moe_expert_parallel_size", 2), + disable_overlap_scheduler=disable_overlap_schedule, + cuda_graph_config=cuda_graph_config, + enable_chunked_prefill=kwargs.get("enable_chunked_prefill", False), + kv_cache_config=kv_cache_config, + speculative_config=specdec, + enable_attention_dp=kwargs.get("enable_attention_dp", False), + max_batch_size=max_concurrent_requests, + moe_config=MoeConfig(backend=kwargs.get("moe_backend", "TRTLLM")), + sampler_type="TorchSampler", + ) + return model + + +def check_sampling_config(sampling_config, max_length, end_id): + return SamplingParams( + use_beam_search=sampling_config.get("beam_width", 1) > 1, + n=sampling_config.get("beam_width", 1), # beam_width=1 for inflight batching + top_k=sampling_config.get("top_k", None), # SizeType topK + top_p=sampling_config.get("top_p", None), + seed=sampling_config.get("seed", None), + temperature=sampling_config.get("temperature", 1), + max_tokens=max_length, + end_id=end_id, + detokenize=False, + ) diff --git a/examples/specdec_bench/specdec_bench/models/vllm.py b/examples/specdec_bench/specdec_bench/models/vllm.py new file mode 100644 index 000000000..b0d4642d2 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/models/vllm.py @@ -0,0 +1,140 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
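+
+# vLLM backend: wraps AsyncLLM and configures speculative decoding through a
+# speculative_config dict (method: eagle / eagle3 / ngram / draft_target / mtp);
+# per-step token counts from the stream are used to rebuild accepted-token chunks.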
+ +import asyncio +import time + +from .base import Model + +try: + from vllm import SamplingParams + from vllm.engine.arg_utils import AsyncEngineArgs + from vllm.inputs import TokensPrompt + from vllm.v1.engine.async_llm import AsyncLLM +except ImportError: + print("vllm is not installed.") + vllm = None + + +class VLLMModel(Model): + def __init__(self, model_dir, max_concurrent_requests, sampling_kwargs, **kwargs): + specdec = None + if kwargs.get("speculative_algorithm") == "EAGLE3": + specdec = { + "method": "eagle3", + "model": kwargs.get("draft_model_dir"), + "num_speculative_tokens": kwargs.get("speculative_num_steps", 3), + } + elif kwargs.get("speculative_algorithm") == "EAGLE": + specdec = { + "method": "eagle", + "model": kwargs.get("draft_model_dir"), + "num_speculative_tokens": kwargs.get("speculative_num_steps", 3), + } + elif kwargs.get("speculative_algorithm") == "NGRAM": + specdec = { + "method": "ngram", + "num_speculative_tokens": kwargs.get("speculative_num_steps", 3), + "prompt_lookup_max": kwargs.get("max_matching_ngram_size", 3), # No idea here + } + elif kwargs.get("speculative_algorithm") == "DRAFT_TARGET": + specdec = { + "method": "draft_target", + "model": kwargs.get("draft_model_dir"), + "num_speculative_tokens": kwargs.get("speculative_num_steps", 3), + } + elif kwargs.get("speculative_algorithm") == "MTP": + specdec = { + "method": "mtp", + "num_speculative_tokens": kwargs.get("speculative_num_steps", 3), + } + elif kwargs.get("speculative_algorithm") == "NONE": + specdec = None + engine_args = AsyncEngineArgs( + model=model_dir, + trust_remote_code=True, + tensor_parallel_size=kwargs.get("tensor_parallel_size", 1), + enable_expert_parallel=kwargs.get("moe_expert_parallel_size", 1) > 1, + enable_prefix_caching=kwargs.get("prefix_cache", False), + speculative_config=specdec, + max_num_seqs=max_concurrent_requests, + skip_tokenizer_init=False, + ) + self.model = AsyncLLM.from_engine_args(engine_args) + self.sampling_kwargs = sampling_kwargs + # https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py + self.sampling_config = SamplingParams( + detokenize=False, + temperature=sampling_kwargs.get("temperature", 1.0), + top_p=sampling_kwargs.get("top_p", 1.0), + top_k=sampling_kwargs.get("top_k", 0), + ) + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + async def run(self, prompt_ids, max_length, end_id, request_id): + output_dict = {} + self.sampling_config.max_tokens = max_length + self.sampling_config.stop_token_ids = [end_id] + + outputs, timing, full_tokens = await self.generate(prompt_ids, request_id) + + reformatted_output_ids = [[] for _ in range(self.sampling_kwargs.get("beam_width", 1))] + start = 0 + timing_to_strip = [] + for i in range(len(outputs)): + if outputs[i] == start: + timing_to_strip.append(i) + continue + if i == len(outputs) - 1: + if full_tokens[-1] == end_id: + if outputs[i] - start == 1: + timing_to_strip.append(i) + else: + reformatted_output_ids[0].append(full_tokens[start : outputs[i] - 1]) + break + reformatted_output_ids[0].append(full_tokens[start : outputs[i]]) + start = outputs[i] + output_dict["output_ids"] = reformatted_output_ids + output_dict["output_logits"] = None + output_dict["token_times"] = [ + timing[i] for i in range(len(timing)) if i not in timing_to_strip + ] + return output_dict + + async def generate(self, prompt_ids, request_id): + timing = [] + timing.append(time.perf_counter()) + outputs = [] + full_tokens = [] + async for output in self.model.generate( + 
request_id=str(request_id), + prompt=TokensPrompt(prompt_token_ids=prompt_ids), + sampling_params=self.sampling_config, + ): + for completion in output.outputs: + outputs.append(len(completion.token_ids)) + timing.append(time.perf_counter()) + full_tokens = completion.token_ids + if output.finished: + break + return outputs, timing, full_tokens + + def stop(self): + try: + self.loop.run_until_complete(self.model.shutdown()) + self.loop.close() + except Exception: + pass diff --git a/examples/specdec_bench/specdec_bench/runners/__init__.py b/examples/specdec_bench/specdec_bench/runners/__init__.py new file mode 100644 index 000000000..61a85c769 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/runners/__init__.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import BaseRunner +from .simple import SimpleRunner diff --git a/examples/specdec_bench/specdec_bench/runners/base.py b/examples/specdec_bench/specdec_bench/runners/base.py new file mode 100644 index 000000000..794180ef5 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/runners/base.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class BaseRunner: + def __init__(self, model, metrics): + # initialize the accelerate or the hf model + self.model = model + self.metrics = metrics + self.prompt_ar = [] + + async def run(self, prompt_ids, max_length, end_id, sampling_kwargs): + raise NotImplementedError() + + def process_metrics_final(self, text_outputs): + [metric.process_final(text_outputs) for metric in self.metrics] + + def process_metrics_step(self, step_outputs, new_turn=True): + [metric.process_step(step_outputs, new_turn) for metric in self.metrics] + + def clear_metrics(self): + [metric.clear() for metric in self.metrics] + + def stop(self): + self.model.stop() diff --git a/examples/specdec_bench/specdec_bench/runners/simple.py b/examples/specdec_bench/specdec_bench/runners/simple.py new file mode 100644 index 000000000..3b1458648 --- /dev/null +++ b/examples/specdec_bench/specdec_bench/runners/simple.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .base import BaseRunner + + +class SimpleRunner(BaseRunner): + def __init__(self, model, metrics): + # initialize the accelerate or the hf model + self.model = model + self.metrics = metrics + self.prompt_ar = [] + + async def run(self, prompt_ids, max_length, end_id, request_id): + model_output = await self.model.run(prompt_ids, max_length, end_id, request_id) + self.process_metrics_step(model_output) + output_ids = model_output["output_ids"] + flattened_output_ids = [[] for _ in range(len(output_ids))] + for i, beam_output in enumerate(output_ids): + for output_id_iter in beam_output: + flattened_output_ids[i].extend(output_id_iter) + + return { + "output_ids": flattened_output_ids, + "output_logits": model_output.get("output_logits", None), + } diff --git a/examples/specdec_bench/specdec_bench/utils.py b/examples/specdec_bench/specdec_bench/utils.py new file mode 100644 index 000000000..d605f0b4b --- /dev/null +++ b/examples/specdec_bench/specdec_bench/utils.py @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from transformers import AutoTokenizer + + +def get_tokenizer(path): + return AutoTokenizer.from_pretrained(path) + + +def encode_chat(tokenizer, messages): + return tokenizer.encode( + tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True), + add_special_tokens=False, + ) + + +def decode_chat(tokenizer, out_tokens): + return tokenizer.decode(out_tokens) + + +def read_json(path): + if path is not None: + with open(path) as f: + data = json.load(f) + return data + return {} + + +def postprocess_base(text): + return text + + +def postprocess_gptoss(text): + return text.split("<|channel|>final<|message|>")[-1] diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 8186b0cd5..7ef7f3791 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -179,6 +179,10 @@ Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/Ten Please refer to [SGLang Doc: Speculative Decoding](https://docs.sglang.ai/advanced_features/speculative_decoding.html#EAGLE-3-Decoding) for detailed usage. 
+### SpecDec Bench
+
+One can also use [examples/specdec_bench](../specdec_bench) to validate trained Eagle3 checkpoints across frameworks (vLLM, SGLang, TRT-LLM) and datasets, measuring acceptance rates and end-to-end timing.
+
 ### Deploying Quantized model
 
 See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md).