From 3e264d3ef2191f547dce95f7ccde8f47cf457c3c Mon Sep 17 00:00:00 2001 From: Vedant Nanda Date: Sun, 16 Mar 2025 11:43:15 -0400 Subject: [PATCH 1/2] add key to config and increase the number of concurrent requests allowed --- oai_server_benchmark.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/oai_server_benchmark.py b/oai_server_benchmark.py index 8149a9c..7940291 100644 --- a/oai_server_benchmark.py +++ b/oai_server_benchmark.py @@ -74,8 +74,9 @@ def run_benchmark(client, model: str, conversations, temperature: float, max_tok """Run a benchmark for one batch of conversations concurrently.""" start_time = time.perf_counter() results = [] + batch_size = len(conversations) - with concurrent.futures.ThreadPoolExecutor() as executor: + with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor: futures = [ executor.submit(call_server_completion, client, model, conv, temperature, max_tokens) for conv in conversations @@ -118,6 +119,8 @@ def main(): help="Dataset key (e.g., 'aime', 'conversation')") parser.add_argument("--api_base", type=str, default="http://localhost:8000/v1", help="Base URL of the vLLM server API.") + parser.add_argument("--api_key", type=str, default="sk-dummy", + help="API key for the server. 
Defaults to 'sk-dummy'.") parser.add_argument("--batch_sizes", type=str, default="1,2,4,8", help="Comma-separated batch sizes (e.g., '1,2,4,8').") parser.add_argument("--num_runs", type=int, default=3, @@ -135,7 +138,7 @@ def main(): args = parser.parse_args() batch_sizes = [int(bs.strip()) for bs in args.batch_sizes.split(",") if bs.strip()] - client = OpenAI(api_key="sk-dummy", base_url=args.api_base) + client = OpenAI(api_key=args.api_key, base_url=args.api_base) results = { "metadata": { From 9b809ff2958ae8633f578e1771495e72daea5c49 Mon Sep 17 00:00:00 2001 From: Pieter Delobelle Date: Mon, 24 Mar 2025 13:47:29 -0400 Subject: [PATCH 2/2] enable running a dummy batch --- oai_server_benchmark.py | 53 ++++++++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/oai_server_benchmark.py b/oai_server_benchmark.py index 7940291..6afdbf7 100644 --- a/oai_server_benchmark.py +++ b/oai_server_benchmark.py @@ -6,6 +6,7 @@ from datetime import datetime from datasets import load_dataset from openai import OpenAI +import random class DatasetHandler: @staticmethod @@ -20,9 +21,34 @@ def aime_handler(client, model, item): max_tokens=1 ).usage.prompt_tokens return messages, prompt_tokens + + @staticmethod + def dummy_handler(client, model, item, num_tokens): + """Handler for dummy datasets that sample specified number of tokens.""" + sampled_tokens = random.choices(["apple"], k=num_tokens) + messages = [ + {"role": "user", "content": " ".join(sampled_tokens)} + ] + prompt_tokens = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=1 + ).usage.prompt_tokens + return messages, prompt_tokens def get_dataset_config(dataset_key: str): """Return dataset configuration based on key.""" + + if dataset_key.startswith("dummy_"): + try: + num_tokens = int(dataset_key.split("_")[1]) + return { + "dataset": "", + "handler": lambda client, model, item: DatasetHandler.dummy_handler(client, model, item, num_tokens) + } + 
except (IndexError, ValueError): + assert False, "Invalid dummy dataset format. Should be dummy_<num_tokens>" + configs = { "aime": { "dataset": "gneubig/aime-1983-2024", @@ -37,16 +63,27 @@ def create_sample_conversations(client, model: str, dataset_key: str, num_sample if not dataset_config: raise ValueError(f"Unknown dataset key: {dataset_key}") - ds = load_dataset(dataset_config["dataset"]) - sampled_dataset = ds["train"].shuffle(seed=seed).select(range(num_samples)) - conversations = [] total_input_tokens = 0 + # Special handling for dummy datasets + if dataset_key.startswith("dummy_"): + # For dummy datasets, create synthetic data + for _ in range(num_samples): + # Create a dummy item (can be anything, it will be handled by the dummy handler) + dummy_item = {"dummy": True} + messages, prompt_tokens = dataset_config["handler"](client, model, dummy_item) + conversations.append(messages) + total_input_tokens += prompt_tokens + else: + # For real datasets, load from HuggingFace + ds = load_dataset(dataset_config["dataset"]) + sampled_dataset = ds["train"].shuffle(seed=seed).select(range(num_samples)) + + for item in sampled_dataset: + messages, prompt_tokens = dataset_config["handler"](client, model, item) + conversations.append(messages) + total_input_tokens += prompt_tokens avg_input_tokens = total_input_tokens / len(conversations) if conversations else 0 return conversations, avg_input_tokens @@ -116,7 +153,7 @@ def main(): parser.add_argument("--model", type=str, required=True, help="Model tag to use (e.g., 'distill-llama-8b').") parser.add_argument("--dataset_key", type=str, default="aime", - help="Dataset key (e.g., 'aime', 'conversation')") + help="Dataset key (e.g., 'aime', 'dummy_300')") parser.add_argument("--api_base", type=str, default="http://localhost:8000/v1", help="Base URL of the vLLM
server API.") parser.add_argument("--api_key", type=str, default="sk-dummy",