Skip to content

Commit 7222e1d

Browse files

Commit message: Let bench_one_batch_server use sharegpt data to make expert distribution more natural (sgl-project#5573)
1 parent 505eec4 commit 7222e1d

2 files changed

Lines changed: 33 additions & 19 deletions

File tree

python/sglang/bench_one_batch_server.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323
import requests
2424

25+
from sglang.bench_serving import get_tokenizer, sample_random_requests
2526
from sglang.srt.entrypoints.http_server import launch_server
2627
from sglang.srt.server_args import ServerArgs
2728
from sglang.srt.utils import kill_process_tree
@@ -117,16 +118,19 @@ def run_one_case(
117118
input_len_step_percentage: float,
118119
run_name: str,
119120
result_filename: str,
121+
tokenizer,
120122
):
121123
requests.post(url + "/flush_cache")
122-
input_lens = [
123-
int(input_len * (1 + (i - (batch_size - 1) / 2) * input_len_step_percentage))
124-
for i in range(batch_size)
125-
]
126-
input_ids = [
127-
[int(x) for x in np.random.randint(0, high=16384, size=(input_lens[i],))]
128-
for i in range(batch_size)
129-
]
124+
input_requests = sample_random_requests(
125+
input_len=input_len,
126+
output_len=output_len,
127+
num_prompts=batch_size,
128+
range_ratio=1.0,
129+
tokenizer=tokenizer,
130+
dataset_path="",
131+
random_sample=True,
132+
return_text=False,
133+
)
130134

131135
use_structured_outputs = False
132136
if use_structured_outputs:
@@ -145,8 +149,7 @@ def run_one_case(
145149
response = requests.post(
146150
url + "/generate",
147151
json={
148-
# "text": texts,
149-
"input_ids": input_ids,
152+
"input_ids": [input_ids for input_ids, _, _ in input_requests],
150153
"sampling_params": {
151154
"temperature": temperature,
152155
"max_new_tokens": output_len,
@@ -228,6 +231,9 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
228231
else:
229232
proc, base_url = launch_server_process(server_args)
230233

234+
tokenizer_id = server_args.tokenizer_path or server_args.model_path
235+
tokenizer = get_tokenizer(tokenizer_id)
236+
231237
# warmup
232238
if not bench_args.skip_warmup:
233239
print("=" * 8 + " Warmup Begin " + "=" * 8)
@@ -241,6 +247,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
241247
input_len_step_percentage=bench_args.input_len_step_percentage,
242248
run_name="",
243249
result_filename="",
250+
tokenizer=tokenizer,
244251
)
245252
print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
246253

python/sglang/bench_serving.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -471,6 +471,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
471471
def get_tokenizer(
472472
pretrained_model_name_or_path: str,
473473
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
474+
assert (
475+
pretrained_model_name_or_path is not None
476+
and pretrained_model_name_or_path != ""
477+
)
474478
if pretrained_model_name_or_path.endswith(
475479
".json"
476480
) or pretrained_model_name_or_path.endswith(".model"):
@@ -832,6 +836,7 @@ def sample_random_requests(
832836
tokenizer: PreTrainedTokenizerBase,
833837
dataset_path: str,
834838
random_sample: bool = True,
839+
return_text: bool = True,
835840
) -> List[DatasetRow]:
836841
input_lens = np.random.randint(
837842
max(int(input_len * range_ratio), 1),
@@ -892,10 +897,12 @@ def sample_random_requests(
892897
else:
893898
ratio = (input_lens[i] + prompt_len - 1) // prompt_len
894899
input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
895-
prompt = tokenizer.decode(input_ids)
900+
input_content = input_ids
901+
if return_text:
902+
input_content = tokenizer.decode(input_content)
896903
input_requests.append(
897904
DatasetRow(
898-
prompt=prompt,
905+
prompt=input_content,
899906
prompt_len=int(input_lens[i]),
900907
output_len=int(output_lens[i]),
901908
)
@@ -905,15 +912,15 @@ def sample_random_requests(
905912
offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
906913
input_requests = []
907914
for i in range(num_prompts):
908-
prompt = tokenizer.decode(
909-
[
910-
(offsets[i] + i + j) % tokenizer.vocab_size
911-
for j in range(input_lens[i])
912-
]
913-
)
915+
input_content = [
916+
(offsets[i] + i + j) % tokenizer.vocab_size
917+
for j in range(input_lens[i])
918+
]
919+
if return_text:
920+
input_content = tokenizer.decode(input_content)
914921
input_requests.append(
915922
DatasetRow(
916-
prompt=prompt,
923+
prompt=input_content,
917924
prompt_len=int(input_lens[i]),
918925
output_len=int(output_lens[i]),
919926
)

0 commit comments

Comments (0)