Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions tests/test_responses_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,56 @@ def test_previous_response_id_reuses_prior_context(self, client):
assert second_messages[2]["role"] == "user"
assert second_messages[2]["content"] == "Follow-up prompt"

def test_developer_role_is_normalized_to_system(self, client):
    """A `developer`-role input message must reach the engine as `system`.

    Sends a user turn followed by a developer note and verifies the server
    rewrites the developer message's role to "system" (and orders it first).
    """
    import vllm_mlx.server as srv

    engine = _mock_engine(_output("Ready"))
    srv._engine = engine

    payload = {
        "model": "test-model",
        "input": [
            {"type": "message", "role": "user", "content": "Hi"},
            {"type": "message", "role": "developer", "content": "Be terse"},
        ],
    }
    response = client.post("/v1/responses", json=payload)

    assert response.status_code == 200
    sent = engine.chat.call_args.kwargs["messages"]
    # The developer note is promoted to a leading system message.
    assert sent[0]["role"] == "system"
    assert sent[0]["content"] == "Be terse"
    # The original user turn follows it unchanged.
    assert sent[1]["role"] == "user"
    assert sent[1]["content"] == "Hi"

def test_instructions_and_developer_message_are_merged(self, client):
    """`instructions` plus a developer message collapse into ONE system message.

    Verifies the server does not emit two separate system messages: the
    request-level instructions and the developer note are merged into a
    single leading system message whose text contains both.
    """
    import vllm_mlx.server as srv

    engine = _mock_engine(_output("Ready"))
    srv._engine = engine

    payload = {
        "model": "test-model",
        "instructions": "System instructions",
        "input": [
            {"type": "message", "role": "developer", "content": "Developer note"},
            {"type": "message", "role": "user", "content": "Hi"},
        ],
    }
    response = client.post("/v1/responses", json=payload)

    assert response.status_code == 200
    sent = engine.chat.call_args.kwargs["messages"]
    # Exactly one system message, placed first, containing both sources.
    system_messages = [msg for msg in sent if msg["role"] == "system"]
    assert len(system_messages) == 1
    assert sent[0]["role"] == "system"
    assert "System instructions" in sent[0]["content"]
    assert "Developer note" in sent[0]["content"]
    # The user turn comes right after the merged system message.
    assert sent[1]["role"] == "user"
def test_function_call_output_input_is_mapped_cleanly(self, client):
import vllm_mlx.server as srv

Expand Down
23 changes: 21 additions & 2 deletions vllm_mlx/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,12 @@ def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]:
if isinstance(item, dict):
item_type = item.get("type", "")
if item_type == "message":
role = item.get("role", "user")
if role == "developer":
role = "system"
messages.append(
{
"role": item.get("role", "user"),
"role": role,
"content": _response_content_to_text(item.get("content")),
}
)
Expand Down Expand Up @@ -576,9 +579,12 @@ def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]:
continue

if isinstance(item, ResponseMessageItem):
role = item.role
if role == "developer":
role = "system"
messages.append(
{
"role": item.role,
"role": role,
"content": _response_content_to_text(item.content),
}
)
Expand Down Expand Up @@ -642,6 +648,19 @@ def _responses_request_to_chat_request(request: ResponsesRequest) -> ChatComplet
},
)

system_messages = [msg for msg in messages if msg.get("role") == "system"]
non_system_messages = [msg for msg in messages if msg.get("role") != "system"]
merged_system_content = "\n\n".join(
str(msg.get("content", "")).strip()
for msg in system_messages
if str(msg.get("content", "")).strip()
)
messages = (
[{"role": "system", "content": merged_system_content}]
if merged_system_content
else []
) + non_system_messages

return ChatCompletionRequest(
model=request.model,
messages=[Message(**msg) for msg in messages],
Expand Down