# Patch summary
#
# vllm_mlx/server.py
#   * _responses_input_to_chat_messages: a message item whose role is
#     "developer" is emitted with role "system" instead. The rewrite is applied
#     in both the plain-dict branch and the ResponseMessageItem branch.
#   * _responses_request_to_chat_request: before the ChatCompletionRequest is
#     built, every message with role "system" is removed from the list, the
#     stripped, non-empty contents are joined with a blank line ("\n\n"), and a
#     single system message holding the joined text is placed in front of the
#     remaining messages. If the joined text is empty, no system message is
#     emitted at all.
#
# tests/test_responses_api.py
#   * test_developer_role_is_normalized_to_system: posts a user message
#     followed by a developer message and expects the developer text to arrive
#     as messages[0] with role "system", ahead of the user message.
#   * test_instructions_and_developer_message_are_merged: posts `instructions`
#     plus a developer message and expects exactly one system message that
#     contains both texts.
#
# NOTE(review): this patch was received with its line breaks and indentation
# collapsed. The layout below is reconstructed from the hunk headers and from
# Python syntax; confirm the leading whitespace of the context lines against
# the target files before applying.
# NOTE(review): str(msg.get("content", "")) only falls back to "" when the key
# is absent; a system message carrying content=None would contribute the
# literal text "None" to the merged prompt. Confirm that
# _response_content_to_text never returns None.
# NOTE(review): system messages are hoisted to the front regardless of where
# they appeared in the input (the first new test depends on this); confirm the
# reordering is intended for multi-turn input.
# NOTE(review): both new tests assign srv._engine and do not restore it within
# the lines shown; presumably a fixture resets it, verify.
diff --git a/tests/test_responses_api.py b/tests/test_responses_api.py
index bd7c3b4..5cb136f 100644
--- a/tests/test_responses_api.py
+++ b/tests/test_responses_api.py
@@ -112,6 +112,56 @@ def test_previous_response_id_reuses_prior_context(self, client):
         assert second_messages[2]["role"] == "user"
         assert second_messages[2]["content"] == "Follow-up prompt"
 
+    def test_developer_role_is_normalized_to_system(self, client):
+        import vllm_mlx.server as srv
+
+        engine = _mock_engine(_output("Ready"))
+        srv._engine = engine
+
+        resp = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "input": [
+                    {"type": "message", "role": "user", "content": "Hi"},
+                    {"type": "message", "role": "developer", "content": "Be terse"},
+                ],
+            },
+        )
+
+        assert resp.status_code == 200
+        messages = engine.chat.call_args.kwargs["messages"]
+        assert messages[0]["role"] == "system"
+        assert messages[0]["content"] == "Be terse"
+        assert messages[1]["role"] == "user"
+        assert messages[1]["content"] == "Hi"
+
+    def test_instructions_and_developer_message_are_merged(self, client):
+        import vllm_mlx.server as srv
+
+        engine = _mock_engine(_output("Ready"))
+        srv._engine = engine
+
+        resp = client.post(
+            "/v1/responses",
+            json={
+                "model": "test-model",
+                "instructions": "System instructions",
+                "input": [
+                    {"type": "message", "role": "developer", "content": "Developer note"},
+                    {"type": "message", "role": "user", "content": "Hi"},
+                ],
+            },
+        )
+
+        assert resp.status_code == 200
+        messages = engine.chat.call_args.kwargs["messages"]
+        assert len([m for m in messages if m["role"] == "system"]) == 1
+        assert messages[0]["role"] == "system"
+        assert "System instructions" in messages[0]["content"]
+        assert "Developer note" in messages[0]["content"]
+        assert messages[1]["role"] == "user"
+
     def test_function_call_output_input_is_mapped_cleanly(self, client):
         import vllm_mlx.server as srv
 
diff --git a/vllm_mlx/server.py b/vllm_mlx/server.py
index 6de2ff9..4ec280d 100644
--- a/vllm_mlx/server.py
+++ b/vllm_mlx/server.py
@@ -538,9 +538,12 @@ def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]:
         if isinstance(item, dict):
             item_type = item.get("type", "")
             if item_type == "message":
+                role = item.get("role", "user")
+                if role == "developer":
+                    role = "system"
                 messages.append(
                     {
-                        "role": item.get("role", "user"),
+                        "role": role,
                         "content": _response_content_to_text(item.get("content")),
                     }
                 )
@@ -576,9 +579,12 @@ def _responses_input_to_chat_messages(request: ResponsesRequest) -> list[dict]:
                 continue
 
         if isinstance(item, ResponseMessageItem):
+            role = item.role
+            if role == "developer":
+                role = "system"
             messages.append(
                 {
-                    "role": item.role,
+                    "role": role,
                     "content": _response_content_to_text(item.content),
                 }
             )
@@ -642,6 +648,19 @@ def _responses_request_to_chat_request(request: ResponsesRequest) -> ChatComplet
             },
         )
 
+    system_messages = [msg for msg in messages if msg.get("role") == "system"]
+    non_system_messages = [msg for msg in messages if msg.get("role") != "system"]
+    merged_system_content = "\n\n".join(
+        str(msg.get("content", "")).strip()
+        for msg in system_messages
+        if str(msg.get("content", "")).strip()
+    )
+    messages = (
+        [{"role": "system", "content": merged_system_content}]
+        if merged_system_content
+        else []
+    ) + non_system_messages
+
     return ChatCompletionRequest(
         model=request.model,
         messages=[Message(**msg) for msg in messages],