Skip to content

Commit b778200

Browse files
gty111 and ywang96
authored
[Benchmark] Refactor sample_requests in benchmark_throughput (vllm-project#3613)
Co-authored-by: Roger Wang <[email protected]>
1 parent 819a309 commit b778200

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

benchmarks/benchmark_throughput.py

+15-16
Original file line numberDiff line numberDiff line change
@@ -29,22 +29,23 @@ def sample_requests(
2929
dataset = [(data["conversations"][0]["value"],
3030
data["conversations"][1]["value"]) for data in dataset]
3131

32-
# Tokenize the prompts and completions.
33-
prompts = [prompt for prompt, _ in dataset]
34-
prompt_token_ids = tokenizer(prompts).input_ids
35-
completions = [completion for _, completion in dataset]
36-
completion_token_ids = tokenizer(completions).input_ids
37-
tokenized_dataset = []
38-
for i in range(len(dataset)):
39-
output_len = len(completion_token_ids[i])
40-
if fixed_output_len is not None:
41-
output_len = fixed_output_len
42-
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
32+
# Shuffle the dataset.
33+
random.shuffle(dataset)
4334

44-
# Filter out too long sequences.
35+
# Filter out sequences that are too long or too short
4536
filtered_dataset: List[Tuple[str, int, int]] = []
46-
for prompt, prompt_token_ids, output_len in tokenized_dataset:
37+
for i in range(len(dataset)):
38+
if len(filtered_dataset) == num_requests:
39+
break
40+
41+
# Tokenize the prompts and completions.
42+
prompt = dataset[i][0]
43+
prompt_token_ids = tokenizer(prompt).input_ids
44+
completion = dataset[i][1]
45+
completion_token_ids = tokenizer(completion).input_ids
4746
prompt_len = len(prompt_token_ids)
47+
output_len = len(completion_token_ids
48+
) if fixed_output_len is None else fixed_output_len
4849
if prompt_len < 4 or output_len < 4:
4950
# Prune too short sequences.
5051
continue
@@ -53,9 +54,7 @@ def sample_requests(
5354
continue
5455
filtered_dataset.append((prompt, prompt_len, output_len))
5556

56-
# Sample the requests.
57-
sampled_requests = random.sample(filtered_dataset, num_requests)
58-
return sampled_requests
57+
return filtered_dataset
5958

6059

6160
def run_vllm(

0 commit comments

Comments
 (0)