Skip to content
187 changes: 168 additions & 19 deletions python/sglang/srt/entrypoints/http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import logging
import multiprocessing as multiprocessing
import os
import random
import threading
import time
from http import HTTPStatus
Expand Down Expand Up @@ -776,6 +777,20 @@ def _wait_and_warmup(
image_token_text: str,
launch_callback: Optional[Callable[[], None]] = None,
):
def _generate_passkey_sample(length):
    """Build a passkey-retrieval prompt used as warm-up input.

    The prompt asks the model to find a passkey, buries the passkey sentence
    (repeated three times) between two equal runs of filler sentences, and
    ends with the start of the answer ("The passkey is **") so the model only
    has to complete the digits.

    Args:
        length: Size knob for the prompt. Each filler run is repeated
            ``int(length * 1024 / 24 / 2)`` times.
            NOTE(review): presumably ``length`` is the target prompt size in
            units of 1024 tokens and 24 approximates the token count of one
            filler sentence group -- confirm against the tokenizer.

    Returns:
        The full prompt string, wrapped in the chat markup selected by
        matching a substring of ``server_args.model_path`` (read from the
        enclosing scope).
    """
    needle = "The passkey is **000310**. " * 3
    sentence = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "
    n_repeat = int(length * 1024 / 24 / 2)
    padding = sentence * n_repeat

    # The instruction + haystack is the same for every model family; only the
    # surrounding chat markup differs.
    body = f"Your task is find the passkey value from the text. {padding} {needle} {padding}."

    # (model_path marker, text before the body, text after the body).
    # Checked in order; the first marker found in the model path wins.
    # NOTE(review): the Qwen3 suffix uses an empty "<think></think>" with no
    # newlines inside the tag -- confirm this is the intended template.
    templates = (
        (
            "Llama-4",
            "<|header_start|>user<|header_end|>\n\n",
            "<|eot|><|header_start|>assistant<|header_end|>\n\nThe passkey is **",
        ),
        (
            "Llama-3",
            "<|start_header_id|>user<|end_header_id|>\n\n",
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe passkey is **",
        ),
        (
            "Qwen3",
            "<|im_start|>user\n",
            "<|im_end|>\n<|im_start|>assistant\n<think></think>\n\nThe passkey is **",
        ),
    )

    for marker, prefix, suffix in templates:
        if marker in server_args.model_path:
            return prefix + body + suffix

    # Fallback: generic "### User / ### Response" layout for any other model.
    return "### User\n\n" + body + "\n\n### Response\n\nThe passkey is **"

headers = {}
url = server_args.url()
if server_args.api_key:
Expand Down Expand Up @@ -805,7 +820,9 @@ def _wait_and_warmup(

# Send a warmup request
request_name = "/generate" if model_info["is_generation"] else "/encode"
max_new_tokens = 128 if model_info["is_generation"] else 1
max_new_tokens = (
int(os.getenv("PASSKEY_DECODE_LEN", 128)) if model_info["is_generation"] else 1
)
# if os.getenv('SGLANG_DEBUG_EXIT_WARMUP', '0') == '1':
# max_new_tokens = 10
json_data = {
Expand All @@ -823,17 +840,8 @@ def _wait_and_warmup(
if server_args.dp_size == 1:
json_data["input_ids"] = json_data["input_ids"][0]
else:
passkey = "The passkey is **000310**. " * 3
filler = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again. "
repeat = int(int(os.getenv("PASSKEY_LEN", "8")) * 1024 / 24 / 2)
if "Llama-4" in server_args.model_path:
text = f"<|header_start|>user<|header_end|>\n\nYour task is find the passkey value from the text. {filler * repeat} {passkey} {filler * repeat}.<|eot|><|header_start|>assistant<|header_end|>\n\nThe passkey is **"
elif "Llama-3" in server_args.model_path:
text = f"<|start_header_id|>user<|end_header_id|>\n\nYour task is find the passkey value from the text. {filler * repeat} {passkey} {filler * repeat}.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\nThe passkey is **"
elif "Qwen3" in server_args.model_path:
text = f"<|im_start|>user\nYour task is find the passkey value from the text. {filler * repeat} {passkey} {filler * repeat}.<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\nThe passkey is **"
else:
text = f"### User\n\nYour task is find the passkey value from the text. {filler * repeat} {passkey} {filler * repeat}.\n\n### Response\n\nThe passkey is **"
passkey_len = int(os.getenv("PASSKEY_LEN", "8"))
text = _generate_passkey_sample(passkey_len)

json_data["text"] = [text] * server_args.dp_size
# TODO Workaround the bug that embedding errors for list of size 1
Expand All @@ -850,14 +858,155 @@ def _wait_and_warmup(

try:
if server_args.disaggregation_mode == "null":
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=6000,
)
warmup_all_seq_lens = os.getenv("SRT_WARMUP_ALL_SEQ_LENS", "0") == "1"
if warmup_all_seq_lens:
import tqdm
import transformers

step_size = 64

safe_zero = lambda x: x if x is not None else 0
context_size = safe_zero(server_args.chunked_prefill_size)
context_size = max(safe_zero(server_args.context_length), context_size)
assert context_size > 0, "consider pass explicit --context-length"

chunk_size = safe_zero(server_args.chunked_prefill_size)
chunk_size = chunk_size if chunk_size > 0 else context_size

tokenizer_path = model_info["tokenizer_path"]
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)

text = _generate_passkey_sample(context_size // 1024)
input_ids = tokenizer.encode(text)[:context_size]
num_decode = 10
step_size = 1024

logger.info(
f"Start warmup all sequences. max_context={context_size}, model={tokenizer_path}"
)

trial_sizes = []
for i_chunk in range(0, context_size, chunk_size):
max_context_len = min(context_size, i_chunk + chunk_size)
real_chunk_size = max_context_len - i_chunk
while real_chunk_size > 1:
trial_all_size = int(i_chunk + real_chunk_size)
trial_sizes.append((int(trial_all_size), int(0)))
real_chunk_size /= 2.0

trial_all_size = max_context_len
trial_prefix_size = trial_all_size
while trial_prefix_size > 1:
if (trial_all_size > 1024) and (
int(trial_all_size - trial_prefix_size) > (num_decode + 1)
):
trial_sizes.append(
(
int(trial_prefix_size),
int(trial_all_size - trial_prefix_size),
)
)
trial_prefix_size /= 2.0

logger.info(f"Prefix, Input")
for t_prefix, t_input in trial_sizes:
logger.info(f"{t_prefix}, {t_input}")

for trial_prefix, trial_input in tqdm.tqdm(
trial_sizes, dynamic_ncols=True
):
trial_input -= num_decode + 1

if trial_input < 1:
continue

input_ids = np.random.randint(10, 1000, (context_size,)).tolist()
new_input_ids = np.random.randint(
10, 1000, (context_size,)
).tolist()

prefix_input_ids = input_ids[: (trial_input + trial_prefix)]
cache_input_ids = new_input_ids[: (trial_input + trial_prefix)]

text_for_prefix = tokenizer.decode(prefix_input_ids)
text_for_cache = tokenizer.decode(
prefix_input_ids[:trial_prefix] + cache_input_ids[trial_prefix:]
)

if len(text_for_prefix) > step_size:

json_data["text"] = text_for_prefix
json_data["sampling_params"]["max_new_tokens"] = num_decode

t_start = time.time()
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=6000,
)
assert res.status_code == 200, f"{res}"
t_end = time.time()

logger.info(
f"[WARMUP] {(trial_prefix, trial_input)} (no-prefix) took {(t_end - t_start):.2f} s"
)

if (len(text_for_cache) > step_size) and (trial_input > 0):

json_data["text"] = text_for_cache
json_data["sampling_params"]["max_new_tokens"] = num_decode

t_start = time.time()
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=6000,
)
assert res.status_code == 200, f"{res}"
t_end = time.time()

logger.info(
f"[WARMUP] {(trial_prefix, trial_input)} (with-prefix) took {(t_end - t_start):.2f} s"
)

if (len(text_for_cache) > step_size) and (trial_input == 0):

json_data["text"] = text_for_cache
json_data["sampling_params"]["max_new_tokens"] = num_decode

t_start = time.time()
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=6000,
)
assert res.status_code == 200, f"{res}"
t_end = time.time()

logger.info(
f"[WARMUP] {(trial_prefix + trial_input, 0)} (all-prefix) took {(t_end - t_start):.2f} s"
)

requests.get(
url + "/flush_cache",
json=json_data,
headers=headers,
timeout=6000,
)

logger.info("[WARM-UP DONE]")
else:
res = requests.post(
url + request_name,
json=json_data,
headers=headers,
timeout=6000,
)
assert res.status_code == 200, f"{res}"
print(res.json())
logger.info(f"Warm-up result: {res.json()}")
if os.getenv("SGLANG_DEBUG_EXIT_WARMUP", "0") == "1":
print("shutdown after warmup")
kill_process_tree(os.getpid())
Expand Down
3 changes: 3 additions & 0 deletions python/sglang/srt/layers/attention/flashattention_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,9 @@ def forward_extend(
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)

if key_cache.dtype == torch.float8_e4m3fn:
q = q.to(key_cache.dtype)

result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
Expand Down
Loading
Loading