
Commit 211318d

Add throughput benchmarking script (vllm-project#133)
1 parent 337871c commit 211318d

12 files changed (+145, -257 lines)

benchmark/benchmark_attention.py

-165
This file was deleted.

benchmark/benchmark_cache.py

-81
This file was deleted.

benchmarks/README.md

+8
# Benchmarking CacheFlow

## Downloading the ShareGPT dataset

You can download the dataset by running:

```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
```
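As a usage sketch (not part of the committed files), the downloaded JSON can then be passed to the new script via something like `python benchmarks/benchmark_throughput.py --dataset ShareGPT_V3_unfiltered_cleaned_split.json --model facebook/opt-125m --num-prompts 1000`; the flags are the ones defined in the script below, and the model shown is simply its default.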
File renamed without changes.

benchmarks/benchmark_throughput.py

+104
```python
import argparse
import json
import random
import time
from typing import List, Tuple

from cacheflow import LLM, SamplingParams
from transformers import PreTrainedTokenizerBase


def sample_requests(
    dataset_path: str,
    num_requests: int,
    tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[List[int], int]]:
    # Load the dataset.
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
    prompt_token_ids = tokenizer(prompts).input_ids
    completions = [completion for _, completion in dataset]
    completion_token_ids = tokenizer(completions).input_ids
    tokenized_dataset = []
    for i in range(len(dataset)):
        output_len = len(completion_token_ids[i])
        tokenized_dataset.append((prompt_token_ids[i], output_len))
    # Filter out if the prompt length + output length is greater than 2048.
    tokenized_dataset = [
        (prompt_token_ids, output_len)
        for prompt_token_ids, output_len in tokenized_dataset
        if len(prompt_token_ids) + output_len <= 2048
    ]

    # Sample the requests.
    sampled_requests = random.sample(tokenized_dataset, num_requests)
    return sampled_requests


def main(args: argparse.Namespace):
    print(args)
    random.seed(args.seed)

    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.tensor_parallel_size,
        seed=args.seed,
    )
    tokenizer = llm.get_tokenizer()
    requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    # Add the requests to the server.
    for prompt_token_ids, output_len in requests:
        sampling_params = SamplingParams(
            n=args.n,
            temperature=0.0 if args.use_beam_search else 1.0,
            top_p=1.0,
            use_beam_search=args.use_beam_search,
            ignore_eos=True,
            max_tokens=output_len,
        )
        # FIXME(woosuk): Do not use internal method.
        llm._add_request(
            prompt="",
            sampling_params=sampling_params,
            prompt_token_ids=prompt_token_ids,
        )

    start = time.time()
    # FIXME(woosuk): Do not use internal method.
    llm._run_server(use_tqdm=True)
    end = time.time()
    total_num_tokens = sum(
        len(prompt_token_ids) + output_len
        for prompt_token_ids, output_len in requests
    )
    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark the throughput.")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset.")
    parser.add_argument("--model", type=str, default="facebook/opt-125m")
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n", type=int, default=1,
                        help="Number of generated sequences per prompt.")
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--seed", type=int, default=0)
    args = parser.parse_args()
    main(args)
```
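Note that the reported figure counts both prompt and generated tokens: `total_num_tokens` sums each request's prompt length and requested output length, and that total is divided by the wall-clock time of the single `llm._run_server` call, so the number reflects end-to-end engine throughput rather than decode-only speed.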
File renamed without changes.

cacheflow/__init__.py

+2 -1
```diff
@@ -1,5 +1,5 @@
 from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput
+from cacheflow.outputs import RequestOutput, CompletionOutput
 from cacheflow.sampling_params import SamplingParams
 from cacheflow.server.arg_utils import ServerArgs
 from cacheflow.server.llm_server import LLMServer
@@ -9,6 +9,7 @@
     "LLM",
     "SamplingParams",
     "RequestOutput",
+    "CompletionOutput",
     "LLMServer",
     "ServerArgs",
     "initialize_cluster",
```

cacheflow/core/scheduler.py

+3
```diff
@@ -87,6 +87,9 @@ def add_seq_group(self, seq_group: SequenceGroup) -> None:
     def has_unfinished_seqs(self) -> bool:
         return self.waiting or self.running or self.swapped

+    def get_num_unfinished_seq_groups(self) -> int:
+        return len(self.waiting) + len(self.running) + len(self.swapped)
+
     def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
         # Blocks that need to be swaped or copied before model execution.
         blocks_to_swap_in: Dict[int, int] = {}
```
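A minimal sketch of how the new counter might be polled, assuming an already-constructed `Scheduler` instance named `scheduler` and a hypothetical `execute_step()` driver function (neither call site is shown in this commit):

```python
# Hypothetical polling loop; `scheduler` and `execute_step` are assumptions,
# not part of this commit.
while scheduler.has_unfinished_seqs():
    pending = scheduler.get_num_unfinished_seq_groups()
    print(f"{pending} sequence groups still waiting, running, or swapped")
    execute_step()  # one scheduling + model-execution iteration
```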
