@@ -3,7 +3,7 @@
"""
Set up this example by starting a vLLM OpenAI-compatible server with tool call
options enabled.
Reasoning models can be used through the Responses API as seen here
https://platform.openai.com/docs/api-reference/responses
For example:
vllm serve Qwen/Qwen3-1.7B --reasoning-parser qwen3 \
99 changes: 96 additions & 3 deletions tests/entrypoints/openai/test_response_api_parsable_context.py
@@ -1,6 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import importlib
import json

import pytest
import pytest_asyncio
@@ -13,12 +15,27 @@

@pytest.fixture(scope="module")
def server():
args = ["--reasoning-parser", "qwen3", "--max_model_len", "5000"]
assert importlib.util.find_spec("gpt_oss") is not None, (
"Harmony tests require gpt_oss package to be installed"
)

args = [
"--reasoning-parser",
"qwen3",
"--max_model_len",
"5000",
"--structured-outputs-config.backend",
"xgrammar",
"--enable-auto-tool-choice",
"--tool-call-parser",
"hermes",
"--tool-server",
"demo",
]
env_dict = dict(
VLLM_ENABLE_RESPONSES_API_STORE="1",
VLLM_USE_EXPERIMENTAL_PARSER_CONTEXT="1",
# uncomment for tool calling
# PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
PYTHON_EXECUTION_BACKEND="dangerously_use_uv",
Collaborator:
oh why was this commented before? did it have issues with ci?

Contributor Author:
i left it there in a previous PR because we didn't have tool calling yet, so it wasn't necessary. There weren't any CI issues.

)

with RemoteOpenAIServer(MODEL_NAME, args, env_dict=env_dict) as remote_server:
@@ -85,3 +102,79 @@ async def test_reasoning_and_function_items(client: OpenAI, model_name: str):
assert response.output[0].type == "reasoning"
assert response.output[1].type == "message"
assert type(response.output[1].content[0].text) is str


def get_horoscope(sign):
return f"{sign}: Next Tuesday you will befriend a baby otter."


def call_function(name, args):
if name == "get_horoscope":
return get_horoscope(**args)
else:
raise ValueError(f"Unknown function: {name}")


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_function_call_first_turn(client: OpenAI, model_name: str):
tools = [
{
"type": "function",
"name": "get_horoscope",
"description": "Get today's horoscope for an astrological sign.",
"parameters": {
"type": "object",
"properties": {
"sign": {"type": "string"},
},
"required": ["sign"],
"additionalProperties": False,
},
"strict": True,
}
]

response = await client.responses.create(
model=model_name,
input="What is the horoscope for Aquarius today?",
tools=tools,
temperature=0.0,
)
assert response is not None
assert response.status == "completed"
assert len(response.output) == 2
assert response.output[0].type == "reasoning"
assert response.output[1].type == "function_call"

function_call = response.output[1]
assert function_call.name == "get_horoscope"
assert function_call.call_id is not None

args = json.loads(function_call.arguments)
assert "sign" in args

    # the multi-turn function call is tested above in
    # test_reasoning_and_function_items


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_mcp_tool_call(client: OpenAI, model_name: str):
response = await client.responses.create(
model=model_name,
input="What is 13 * 24? Use python to calculate the result.",
tools=[{"type": "code_interpreter", "container": {"type": "auto"}}],
temperature=0.0,
)

assert response is not None
assert response.status == "completed"
assert response.output[0].type == "reasoning"
assert response.output[1].type == "mcp_call"
assert type(response.output[1].arguments) is str
assert type(response.output[1].output) is str
assert response.output[2].type == "reasoning"
# make sure the correct math is in the final output
assert response.output[3].type == "message"
assert "312" in response.output[3].content[0].text
92 changes: 88 additions & 4 deletions vllm/entrypoints/context.py
@@ -9,10 +9,16 @@
from contextlib import AsyncExitStack
from typing import TYPE_CHECKING, Union

from openai.types.responses.response_function_tool_call_output_item import (
ResponseFunctionToolCallOutputItem,
)
from openai.types.responses.tool import Mcp
from openai_harmony import Author, Message, Role, StreamState, TextContent

from vllm import envs
from vllm.entrypoints.chat_utils import (
ChatTemplateContentFormatOption,
)
from vllm.entrypoints.harmony_utils import (
get_encoding,
get_streamable_parser_for_assistant,
@@ -22,16 +28,20 @@
get_responses_parser_for_simple_context,
)
from vllm.entrypoints.openai.protocol import (
FunctionCall,
ResponseInputOutputItem,
ResponseRawMessageAndToken,
ResponsesRequest,
)
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.entrypoints.responses_utils import construct_tool_dicts
from vllm.entrypoints.tool import Tool
from vllm.entrypoints.tool_server import ToolServer
from vllm.outputs import RequestOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

if TYPE_CHECKING:
from mcp.client import ClientSession
@@ -221,6 +231,10 @@ def __init__(
tokenizer: AnyTokenizer,
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser] | None,
request: ResponsesRequest,
available_tools: list[str] | None,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
chat_template: str | None,
chat_template_content_format: ChatTemplateContentFormatOption,
):
self.num_prompt_tokens = 0
self.num_output_tokens = 0
@@ -238,12 +252,19 @@ def __init__(
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
self.tool_parser_cls = tool_parser_cls
self.request = request
self.tokenizer = tokenizer

self.available_tools = available_tools or []
self._tool_sessions: dict[str, ClientSession | Tool] = {}
self.called_tools: set[str] = set()

self.tool_dicts = construct_tool_dicts(request.tools, request.tool_choice)
self.chat_template = chat_template
self.chat_template_content_format = chat_template_content_format

def append_output(self, output: RequestOutput) -> None:
self.num_prompt_tokens = len(output.prompt_token_ids or [])
@@ -252,14 +273,50 @@ def append_output(self, output: RequestOutput) -> None:
self.parser.process(output.outputs[0])

def append_tool_output(self, output: list[ResponseInputOutputItem]) -> None:
raise NotImplementedError("Should not be called.")
self.parser.response_messages.extend(output)

def need_builtin_tool_call(self) -> bool:
"""Return true if the last message is a MCP tool call"""
last_message = self.parser.response_messages[-1]
# TODO: figure out which tools are MCP tools
if ( # noqa: SIM103
last_message.type == "function_call"
and last_message.name in ("code_interpreter", "python")
):
return True

return False
Comment on lines 278 to 288
Collaborator:
this format is quite bad lol. let's directly check the condition. also, should we hardcode "code_interpreter" and "python" here? i remember @alecsolder made the changes to centralize all tools to go through the mcp tool type.

if xxxx:
    return True
return False

Contributor Author:
i was thinking to clean up the code in #29989, which will include the browser & container tools, if that's okay? This PR is just to complete the ability to call only the python tool lol
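
For what it's worth, a direct-return version of the check the review asks for could look like the sketch below. This is illustrative, not code from this PR; the hardcoded tool names are carried over from the diff until built-in tools are centralized through the MCP tool type.

def need_builtin_tool_call(self) -> bool:
    """Return True if the last message is an MCP tool call."""
    last_message = self.parser.response_messages[-1]
    # Direct boolean return instead of if / return True / return False.
    return (
        last_message.type == "function_call"
        and last_message.name in ("code_interpreter", "python")
    )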


async def call_python_tool(
self, tool_session: Union["ClientSession", Tool], last_msg: FunctionCall
) -> list[ResponseInputOutputItem]:
self.called_tools.add("python")
if isinstance(tool_session, Tool):
return await tool_session.get_result_parsable_context(self)
args = json.loads(last_msg.arguments)
param = {
"code": args["code"],
}
result = await tool_session.call_tool("python", param)
result_str = result.content[0].text

message = ResponseFunctionToolCallOutputItem(
id=f"fco_{random_uuid()}",
type="function_call_output",
call_id=f"call_{random_uuid()}",
output=result_str,
status="completed",
)

return [message]

async def call_tool(self) -> list[ResponseInputOutputItem]:
raise NotImplementedError("Should not be called.")
if not self.parser.response_messages:
return []
last_msg = self.parser.response_messages[-1]
if last_msg.name == "code_interpreter":
return await self.call_python_tool(self._tool_sessions["python"], last_msg)
return []

def render_for_completion(self):
raise NotImplementedError("Should not be called.")
@@ -271,11 +328,38 @@ async def init_tool_sessions(
request_id: str,
mcp_tools: dict[str, Mcp],
):
pass
if tool_server:
for tool_name in self.available_tools:
if tool_name in self._tool_sessions:
continue

tool_type = _map_tool_name_to_tool_type(tool_name)
headers = (
mcp_tools[tool_type].headers if tool_type in mcp_tools else None
)
tool_session = await exit_stack.enter_async_context(
tool_server.new_session(tool_name, request_id, headers)
)
self._tool_sessions[tool_name] = tool_session
exit_stack.push_async_exit(self.cleanup_session)

async def cleanup_session(self, *args, **kwargs) -> None:
"""Can be used as coro to used in __aexit__"""
raise NotImplementedError("Should not be called.")

async def cleanup_tool_session(tool_session):
if not isinstance(tool_session, Tool):
logger.info(
"Cleaning up tool session for %s", tool_session._client_info
)
with contextlib.suppress(Exception):
await tool_session.call_tool("cleanup_session", {})

await asyncio.gather(
*(
cleanup_tool_session(self._tool_sessions[tool])
for tool in self.called_tools
)
)


class HarmonyContext(ConversationContext):
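The init_tool_sessions/cleanup_session pair above leans on AsyncExitStack: each tool session is entered on the stack, then cleanup_session is pushed as an async exit callback, so it runs first on unwind (LIFO) and can still reach live sessions for a best-effort cleanup of every called tool. A standalone sketch of that pattern, with illustrative names that are not from the PR:

import asyncio
from contextlib import AsyncExitStack


class DemoSession:
    """Stand-in for a tool session (illustrative only)."""

    def __init__(self, name: str) -> None:
        self.name = name

    async def __aenter__(self) -> "DemoSession":
        return self

    async def __aexit__(self, *exc) -> None:
        pass

    async def call_tool(self, tool: str, args: dict) -> str:
        return f"{self.name} ran {tool}"


async def main() -> None:
    sessions: dict[str, DemoSession] = {}
    called: set[str] = set()

    async def cleanup(*exc_details) -> None:
        # Mirrors cleanup_session: best-effort cleanup of every called tool.
        await asyncio.gather(
            *(sessions[t].call_tool("cleanup_session", {}) for t in called)
        )

    async with AsyncExitStack() as stack:
        sessions["python"] = await stack.enter_async_context(DemoSession("python"))
        # LIFO: cleanup runs before the sessions themselves close.
        stack.push_async_exit(cleanup)
        called.add("python")
        print(await sessions["python"].call_tool("python", {"code": "13 * 24"}))


asyncio.run(main())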
34 changes: 34 additions & 0 deletions vllm/entrypoints/openai/parser/responses_parser.py
@@ -3,6 +3,7 @@
import logging
from collections.abc import Callable

from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from openai.types.responses.response_output_message import ResponseOutputMessage
from openai.types.responses.response_output_text import ResponseOutputText
from openai.types.responses.response_reasoning_item import (
@@ -11,8 +12,10 @@
)

from vllm.entrypoints.openai.protocol import ResponseInputOutputItem, ResponsesRequest
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser
from vllm.outputs import CompletionOutput
from vllm.reasoning.abs_reasoning_parsers import ReasoningParser
from vllm.tokenizers.protocol import TokenizerLike
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils import random_uuid

@@ -29,6 +32,7 @@ def __init__(
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls: Callable[[TokenizerLike], ToolParser] | None,
):
self.response_messages: list[ResponseInputOutputItem] = (
# TODO: initial messages may not be properly typed
@@ -39,6 +43,9 @@
self.request = request

self.reasoning_parser_instance = reasoning_parser_cls(tokenizer)
self.tool_parser_instance = None
if tool_parser_cls is not None:
self.tool_parser_instance = tool_parser_cls(tokenizer)

def process(self, output: CompletionOutput) -> "ResponsesParser":
reasoning_content, content = self.reasoning_parser_instance.extract_reasoning(
@@ -59,6 +66,29 @@ def process(self, output: CompletionOutput) -> "ResponsesParser":
)
)

function_calls: list[ResponseFunctionToolCall] = []
if self.tool_parser_instance is not None:
tool_call_info = self.tool_parser_instance.extract_tool_calls(
content if content is not None else "",
request=self.request, # type: ignore
)
if tool_call_info is not None and tool_call_info.tools_called:
# extract_tool_calls() returns a list of tool calls.
function_calls.extend(
ResponseFunctionToolCall(
id=f"fc_{random_uuid()}",
call_id=f"call_{random_uuid()}",
type="function_call",
status="completed",
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)
for tool_call in tool_call_info.tool_calls
)
content = tool_call_info.content
if content and content.strip() == "":
content = None

if content:
self.response_messages.append(
ResponseOutputMessage(
Expand All @@ -76,6 +106,8 @@ def process(self, output: CompletionOutput) -> "ResponsesParser":
],
)
)
if len(function_calls) > 0:
self.response_messages.extend(function_calls)

return self

@@ -86,6 +118,7 @@ def get_responses_parser_for_simple_context(
reasoning_parser_cls: Callable[[AnyTokenizer], ReasoningParser],
response_messages: list[ResponseInputOutputItem],
request: ResponsesRequest,
tool_parser_cls,
) -> ResponsesParser:
"""Factory function to create a ResponsesParser with
optional reasoning parser.
@@ -98,4 +131,5 @@
reasoning_parser_cls=reasoning_parser_cls,
response_messages=response_messages,
request=request,
tool_parser_cls=tool_parser_cls,
)
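
To show how the new tool_parser_cls hook is intended to be fed, here is a sketch of wiring a registered tool parser into the factory. The ToolParserManager lookup mirrors what the serving layer does for --tool-call-parser hermes and is an assumption here, not code from this PR:

from vllm.entrypoints.openai.parser.responses_parser import (
    get_responses_parser_for_simple_context,
)
# Assumed registry API, as used elsewhere in vLLM's OpenAI serving layer.
from vllm.entrypoints.openai.tool_parsers import ToolParserManager


def build_parser(tokenizer, reasoning_parser_cls, request):
    """Illustrative wiring of a tool parser into the responses parser."""
    tool_parser_cls = ToolParserManager.get_tool_parser("hermes")
    return get_responses_parser_for_simple_context(
        tokenizer=tokenizer,
        reasoning_parser_cls=reasoning_parser_cls,
        response_messages=[],
        request=request,
        tool_parser_cls=tool_parser_cls,  # None disables tool-call extraction
    )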