Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 50 additions & 10 deletions oai_server_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import datetime
from datasets import load_dataset
from openai import OpenAI
import random

class DatasetHandler:
@staticmethod
Expand All @@ -20,9 +21,34 @@ def aime_handler(client, model, item):
max_tokens=1
).usage.prompt_tokens
return messages, prompt_tokens

@staticmethod
def dummy_handler(client, model: str, item, num_tokens: int):
    """Build a synthetic single-turn conversation of repeated filler words.

    The prompt is the word "apple" repeated ``num_tokens`` times, joined by
    single spaces. ``num_tokens`` is a word count, not an exact tokenizer
    count: the real prompt-token count is whatever the server reports for
    the request, and that reported value is what gets returned.

    Args:
        client: OpenAI-compatible client; only
            ``client.chat.completions.create`` is used.
        model: Model tag forwarded to the server.
        item: Dataset row. Unused; accepted only so this handler has the
            same calling shape as the dataset-backed handlers.
        num_tokens: Number of filler words to place in the prompt.

    Returns:
        A ``(messages, prompt_tokens)`` pair: the chat messages list and the
        ``usage.prompt_tokens`` value reported by the server.

    Raises:
        ValueError: If ``num_tokens`` is negative.
    """
    if num_tokens < 0:
        raise ValueError(
            f"num_tokens must be non-negative, got {num_tokens}"
        )

    # Plain repetition: drawing "randomly" from a one-word population always
    # gives the same result and only perturbs the global RNG state.
    filler_word = "apple"
    prompt = " ".join([filler_word] * num_tokens)
    messages = [{"role": "user", "content": prompt}]

    # A one-token completion is issued only to read the server-side
    # prompt-token count from the response's usage block.
    response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=1,
    )
    return messages, response.usage.prompt_tokens

def get_dataset_config(dataset_key: str):
"""Return dataset configuration based on key."""

if dataset_key.startswith("dummy_"):
try:
num_tokens = int(dataset_key.split("_")[1])
return {
"dataset": "",
"handler": lambda client, model, item: DatasetHandler.dummy_handler(client, model, item, num_tokens)
}
except (IndexError, ValueError):
assert False, "Invalid dummy dataset format. Should be dummy_<number>"

configs = {
"aime": {
"dataset": "gneubig/aime-1983-2024",
Expand All @@ -37,16 +63,27 @@ def create_sample_conversations(client, model: str, dataset_key: str, num_sample
if not dataset_config:
raise ValueError(f"Unknown dataset key: {dataset_key}")

ds = load_dataset(dataset_config["dataset"])
sampled_dataset = ds["train"].shuffle(seed=seed).select(range(num_samples))

conversations = []
total_input_tokens = 0

for item in sampled_dataset:
messages, prompt_tokens = dataset_config["handler"](client, model, item)
conversations.append(messages)
total_input_tokens += prompt_tokens
# Special handling for dummy datasets
if dataset_key.startswith("dummy_"):
# For dummy datasets, create synthetic data
for _ in range(num_samples):
# Create a dummy item (can be anything, it will be handled by the dummy handler)
dummy_item = {"dummy": True}
messages, prompt_tokens = dataset_config["handler"](client, model, dummy_item)
conversations.append(messages)
total_input_tokens += prompt_tokens
else:
# For real datasets, load from HuggingFace
ds = load_dataset(dataset_config["dataset"])
sampled_dataset = ds["train"].shuffle(seed=seed).select(range(num_samples))

for item in sampled_dataset:
messages, prompt_tokens = dataset_config["handler"](client, model, item)
conversations.append(messages)
total_input_tokens += prompt_tokens

avg_input_tokens = total_input_tokens / len(conversations) if conversations else 0
return conversations, avg_input_tokens
Expand Down Expand Up @@ -74,8 +111,9 @@ def run_benchmark(client, model: str, conversations, temperature: float, max_tok
"""Run a benchmark for one batch of conversations concurrently."""
start_time = time.perf_counter()
results = []
batch_size = len(conversations)

with concurrent.futures.ThreadPoolExecutor() as executor:
with concurrent.futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
futures = [
executor.submit(call_server_completion, client, model, conv, temperature, max_tokens)
for conv in conversations
Expand Down Expand Up @@ -115,9 +153,11 @@ def main():
parser.add_argument("--model", type=str, required=True,
help="Model tag to use (e.g., 'distill-llama-8b').")
parser.add_argument("--dataset_key", type=str, default="aime",
help="Dataset key (e.g., 'aime', 'conversation')")
help="Dataset key (e.g., 'aime', 'dummy_300')")
parser.add_argument("--api_base", type=str, default="http://localhost:8000/v1",
help="Base URL of the vLLM server API.")
parser.add_argument("--api_key", type=str, default="sk-dummy",
help="API key for the server. Defaults to 'sk-dummy'.")
parser.add_argument("--batch_sizes", type=str, default="1,2,4,8",
help="Comma-separated batch sizes (e.g., '1,2,4,8').")
parser.add_argument("--num_runs", type=int, default=3,
Expand All @@ -135,7 +175,7 @@ def main():
args = parser.parse_args()

batch_sizes = [int(bs.strip()) for bs in args.batch_sizes.split(",") if bs.strip()]
client = OpenAI(api_key="sk-dummy", base_url=args.api_base)
client = OpenAI(api_key=args.api_key, base_url=args.api_base)

results = {
"metadata": {
Expand Down