Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 232 additions & 0 deletions tests/test_responses_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the OpenAI-compatible Responses API."""

import platform
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock

import pytest
from fastapi.testclient import TestClient

pytestmark = pytest.mark.skipif(
sys.platform != "darwin" or platform.machine() != "arm64",
reason="Requires Apple Silicon",
)


@pytest.fixture()
def client():
    """Fresh :class:`TestClient` bound to the server's FastAPI app."""
    from vllm_mlx.server import app as fastapi_app

    return TestClient(fastapi_app)


@pytest.fixture(autouse=True)
def server_state():
    """Snapshot the server module's globals, install test values, restore after.

    Autouse so every test runs against a clean ``test-model`` configuration
    with no engine, an empty responses store, and no API key, regardless of
    what a previous test (or the real server startup) left behind.
    """
    import vllm_mlx.server as srv

    # Preserve whatever the module currently holds so the suite is hermetic.
    saved = (srv._engine, srv._model_name, srv._responses_store, srv._api_key)

    srv._engine = None
    srv._model_name = "test-model"
    srv._responses_store = {}
    srv._api_key = None

    try:
        yield
    finally:
        # Restore in one shot, mirroring the order captured above.
        (
            srv._engine,
            srv._model_name,
            srv._responses_store,
            srv._api_key,
        ) = saved


def _mock_engine(*outputs):
engine = MagicMock()
engine.model_name = "test-model"
engine.preserve_native_tool_format = False
engine.chat = AsyncMock(side_effect=list(outputs))
return engine


def _output(text: str, prompt_tokens: int = 7, completion_tokens: int = 3):
return SimpleNamespace(
text=text,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
finish_reason="stop",
)


class TestResponsesEndpoint:
    """End-to-end tests for ``POST /v1/responses`` with a mocked engine."""

    def test_basic_response(self, client):
        """A plain string input yields a completed response with usage counts."""
        import vllm_mlx.server as srv

        srv._engine = _mock_engine(_output("Hello there"))

        resp = client.post(
            "/v1/responses",
            json={"model": "test-model", "input": "Say hello"},
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["object"] == "response"
        assert body["output_text"] == "Hello there"
        # Output is a message item wrapping an output_text content part.
        assert body["output"][0]["type"] == "message"
        assert body["output"][0]["content"][0]["type"] == "output_text"
        # Token counts come straight from the mocked _output defaults (7/3).
        assert body["usage"]["input_tokens"] == 7
        assert body["usage"]["output_tokens"] == 3

    def test_previous_response_id_reuses_prior_context(self, client):
        """previous_response_id replays the stored conversation before the new input."""
        import vllm_mlx.server as srv

        engine = _mock_engine(_output("First answer"), _output("Second answer"))
        srv._engine = engine

        first = client.post(
            "/v1/responses",
            json={"model": "test-model", "input": "First prompt"},
        )
        first_id = first.json()["id"]

        second = client.post(
            "/v1/responses",
            json={
                "model": "test-model",
                "previous_response_id": first_id,
                "input": "Follow-up prompt",
            },
        )

        assert second.status_code == 200
        # Inspect the second engine.chat call: it must carry the full history
        # (user prompt, assistant answer) followed by the new user turn.
        second_messages = engine.chat.call_args_list[1].kwargs["messages"]
        assert second_messages[0]["role"] == "user"
        assert second_messages[0]["content"] == "First prompt"
        assert second_messages[1]["role"] == "assistant"
        assert second_messages[1]["content"] == "First answer"
        assert second_messages[2]["role"] == "user"
        assert second_messages[2]["content"] == "Follow-up prompt"

    def test_function_call_output_input_is_mapped_cleanly(self, client):
        """Responses-API tool items are flattened into chat-style messages."""
        import vllm_mlx.server as srv

        engine = _mock_engine(_output("Done"))
        srv._engine = engine

        resp = client.post(
            "/v1/responses",
            json={
                "model": "test-model",
                "input": [
                    {"type": "message", "role": "user", "content": "Run it"},
                    {
                        "type": "function_call",
                        "call_id": "call_1",
                        "name": "shell",
                        "arguments": "{\"cmd\":\"pwd\"}",
                    },
                    {
                        "type": "function_call_output",
                        "call_id": "call_1",
                        "output": "/tmp/work",
                    },
                ],
            },
        )

        assert resp.status_code == 200
        messages = engine.chat.call_args.kwargs["messages"]
        # function_call becomes an assistant turn describing the tool call...
        assert messages[1]["role"] == "assistant"
        assert "[Calling tool: shell(" in messages[1]["content"]
        # ...and function_call_output becomes a user turn with the result text.
        assert messages[2]["role"] == "user"
        assert "[Tool Result (call_1)]" in messages[2]["content"]
        assert "/tmp/work" in messages[2]["content"]

    def test_unsupported_tools_and_items_do_not_fail(self, client):
        """Unknown tool types/items degrade gracefully instead of erroring."""
        import vllm_mlx.server as srv

        engine = _mock_engine(_output("Fallback answer"))
        srv._engine = engine

        resp = client.post(
            "/v1/responses",
            json={
                "model": "test-model",
                "input": [
                    {"type": "message", "role": "user", "content": "Answer directly"},
                    {
                        "type": "web_search_call",
                        "status": "completed",
                        "action": {"type": "search", "query": "ignored"},
                    },
                ],
                "tools": [
                    {"type": "web_search_preview"},
                    {"type": "file_search", "vector_store_ids": ["vs_123"]},
                    {
                        "type": "function",
                        "name": "shell",
                        "parameters": {"type": "object", "properties": {}},
                    },
                ],
            },
        )

        assert resp.status_code == 200
        messages = engine.chat.call_args.kwargs["messages"]
        # Unsupported tools are surfaced to the model via an injected system note.
        assert messages[0]["role"] == "system"
        assert "not available on this backend" in messages[0]["content"]
        assert messages[1]["role"] == "user"
        # Only the plain function tool survives into the engine call.
        assert engine.chat.call_args.kwargs["tools"][0]["type"] == "function"

    def test_function_call_response_item(self, client):
        """A <tool_call> generation is parsed into a function_call output item."""
        import vllm_mlx.server as srv

        srv._engine = _mock_engine(
            _output('<tool_call>{"name":"shell","arguments":{"cmd":"pwd"}}</tool_call>')
        )

        resp = client.post(
            "/v1/responses",
            json={
                "model": "test-model",
                "input": "Use a tool",
                "tools": [
                    {
                        "type": "function",
                        "name": "shell",
                        "parameters": {"type": "object", "properties": {}},
                    }
                ],
            },
        )

        assert resp.status_code == 200
        body = resp.json()
        assert body["output"][0]["type"] == "function_call"
        assert body["output"][0]["name"] == "shell"
        # Tool-call-only generations should carry no visible text.
        assert body["output_text"] == ""

    def test_streaming_response_events(self, client):
        """stream=True produces the expected SSE event sequence and text deltas."""
        import vllm_mlx.server as srv

        srv._engine = _mock_engine(_output("Hello stream"))

        with client.stream(
            "POST",
            "/v1/responses",
            json={"model": "test-model", "input": "Hello", "stream": True},
        ) as resp:
            stream_text = "".join(resp.iter_text())

        assert resp.status_code == 200
        # Lifecycle events in order: created -> in_progress -> deltas -> completed.
        assert "event: response.created" in stream_text
        assert "event: response.in_progress" in stream_text
        assert "event: response.output_text.delta" in stream_text
        assert "event: response.completed" in stream_text
        assert "Hello stream" in stream_text
34 changes: 34 additions & 0 deletions vllm_mlx/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@
EmbeddingUsage,
EmbeddingResponse,
)
from .responses_models import (
ResponseTextFormat,
ResponseTextConfig,
ResponseReasoningConfig,
ResponseTextContentPart,
ResponseReasoningTextPart,
ResponseReasoningSummaryTextPart,
ResponseMessageItem,
ResponseReasoningItem,
ResponseFunctionCallItem,
ResponseFunctionCallOutputItem,
ResponseFunctionTool,
ResponsesUsage,
ResponseError,
ResponseIncompleteDetails,
ResponsesRequest,
ResponseObject,
)

from .utils import (
clean_output_text,
Expand Down Expand Up @@ -111,6 +129,22 @@
"EmbeddingData",
"EmbeddingUsage",
"EmbeddingResponse",
"ResponseTextFormat",
"ResponseTextConfig",
"ResponseReasoningConfig",
"ResponseTextContentPart",
"ResponseReasoningTextPart",
"ResponseReasoningSummaryTextPart",
"ResponseMessageItem",
"ResponseReasoningItem",
"ResponseFunctionCallItem",
"ResponseFunctionCallOutputItem",
"ResponseFunctionTool",
"ResponsesUsage",
"ResponseError",
"ResponseIncompleteDetails",
"ResponsesRequest",
"ResponseObject",
# Utils
"clean_output_text",
"is_mllm_model",
Expand Down
Loading