diff --git a/gpt_oss/responses_api/api_server.py b/gpt_oss/responses_api/api_server.py index f3ba26a..8b5eeca 100644 --- a/gpt_oss/responses_api/api_server.py +++ b/gpt_oss/responses_api/api_server.py @@ -71,6 +71,9 @@ DEFAULT_TEMPERATURE = 0.0 +BROWSER_RESERVED_FUNCTIONS = {"browser.search", "browser.open", "browser.find"} + + def get_reasoning_effort( effort: Union[Literal["low", "medium", "high"], ReasoningEffort] ) -> ReasoningEffort: @@ -97,6 +100,28 @@ def is_not_builtin_tool( ) +def resolve_browser_recipient( + recipient: Optional[str], + browser_tool: Optional[SimpleBrowserTool], + user_defined_function_names: set[str], +) -> tuple[Optional[str], bool]: + if browser_tool is None or not recipient: + return (None, False) + + if recipient.startswith("browser."): + return (recipient, False) + + if recipient.startswith("functions."): + potential = recipient[len("functions.") :] + if ( + potential in BROWSER_RESERVED_FUNCTIONS + and potential not in user_defined_function_names + ): + return (potential, True) + + return (None, False) + + def create_api_server( infer_next_token: Callable[[list[int], float], int], encoding: HarmonyEncoding ) -> FastAPI: @@ -157,48 +182,25 @@ def generate_response( browser_tool_index = 0 python_tool_index = 0 reasoning_ids_iter = iter(reasoning_ids or []) + user_defined_function_names = { + name + for tool in (request_body.tools or []) + for name in [getattr(tool, "name", None)] + if getattr(tool, "type", None) == "function" and name + } for entry in entries: entry_dict = entry.to_dict() recipient = entry_dict.get("recipient", "") - if len(recipient) > 0 and is_not_builtin_tool( - recipient, treat_functions_python_as_builtin - ): - call = entry_dict["content"][0] - arguments = call["text"] - name = recipient - - if name.startswith("functions."): - name = name[len("functions.") :] - if function_call_ids and fc_index < len(function_call_ids): - fc_id, call_id = function_call_ids[fc_index] - else: - fc_id, call_id = ( - f"fc_{uuid.uuid4().hex}", - f"call_{uuid.uuid4().hex}", - ) - fc_index += 1 - output.append( - FunctionCallItem( - type="function_call", - name=name, - arguments=arguments, - id=fc_id, - call_id=call_id, - ) - ) - elif ( - len(recipient) > 0 - and recipient.startswith("browser.") - and browser_tool is not None - ): - # Mirror event-based creation of WebSearchCallItems when the browser tool is invoked - name = recipient + browser_recipient, _ = resolve_browser_recipient( + recipient, browser_tool, user_defined_function_names + ) + if browser_recipient is not None and browser_tool is not None: + name = browser_recipient call = entry_dict["content"][0] arguments = call["text"] function_name = name[len("browser.") :] - # Reconstruct a Message for argument parsing tool_msg = ( Message.from_role_and_content(Role.ASSISTANT, arguments) .with_recipient(name) @@ -243,6 +245,33 @@ def generate_response( action=action, ) ) + continue + if len(recipient) > 0 and is_not_builtin_tool( + recipient, treat_functions_python_as_builtin + ): + call = entry_dict["content"][0] + arguments = call["text"] + name = recipient + + if name.startswith("functions."): + name = name[len("functions.") :] + if function_call_ids and fc_index < len(function_call_ids): + fc_id, call_id = function_call_ids[fc_index] + else: + fc_id, call_id = ( + f"fc_{uuid.uuid4().hex}", + f"call_{uuid.uuid4().hex}", + ) + fc_index += 1 + output.append( + FunctionCallItem( + type="function_call", + name=name, + arguments=arguments, + id=fc_id, + call_id=call_id, + ) + ) elif ( len(recipient) > 0 and ( @@ -430,6 +459,19 @@ def __init__( self.reasoning_item_ids: list[str] = [] self.current_reasoning_item_id: Optional[str] = None self.functions_python_as_builtin = functions_python_as_builtin + self.user_defined_function_names = { + name + for tool in (request_body.tools or []) + for name in [getattr(tool, "name", None)] + if getattr(tool, "type", None) == "function" and name + } + + def _resolve_browser_recipient( + self, recipient: Optional[str] + ) -> tuple[Optional[str], bool]: + return resolve_browser_recipient( + recipient, self.browser_tool, self.user_defined_function_names + ) def _send_event(self, event: ResponseEvent): event.sequence_number = self.sequence_number @@ -508,8 +550,11 @@ async def run(self): previous_item = self.parser.messages[-1] if previous_item.recipient is not None: recipient = previous_item.recipient + browser_recipient, _ = self._resolve_browser_recipient( + recipient + ) if ( - not recipient.startswith("browser.") + browser_recipient is None and not ( recipient == "python" or ( @@ -763,14 +808,20 @@ async def run(self): if next_tok in encoding.stop_tokens_for_assistant_actions(): if len(self.parser.messages) > 0: last_message = self.parser.messages[-1] - if ( - self.use_browser_tool - and last_message.recipient is not None - and last_message.recipient.startswith("browser.") - ): - function_name = last_message.recipient[len("browser.") :] + browser_recipient, is_browser_fallback = ( + self._resolve_browser_recipient(last_message.recipient) + ) + if browser_recipient is not None and browser_tool is not None: + message_for_browser = ( + last_message + if not is_browser_fallback + else last_message.with_recipient(browser_recipient) + ) + function_name = browser_recipient[len("browser.") :] action = None - parsed_args = browser_tool.process_arguments(last_message) + parsed_args = browser_tool.process_arguments( + message_for_browser + ) if function_name == "search": action = WebSearchActionSearch( type="search", @@ -820,7 +871,9 @@ async def run(self): async def run_tool(): results = [] - async for msg in browser_tool.process(last_message): + async for msg in browser_tool.process( + message_for_browser + ): results.append(msg) return results diff --git a/gpt_oss/tools/python_docker/docker_tool.py b/gpt_oss/tools/python_docker/docker_tool.py index 91704fa..1e3535a 100644 --- a/gpt_oss/tools/python_docker/docker_tool.py +++ b/gpt_oss/tools/python_docker/docker_tool.py @@ -1,11 +1,15 @@ # Run this before running the tool: # $ docker image pull python:3.11 +import asyncio +import contextlib import io -import tarfile -from typing import Any, AsyncIterator -import tempfile import os +import queue import subprocess +import tarfile +import tempfile +from pathlib import Path +from typing import Any, AsyncIterator import docker from openai_harmony import ( @@ -21,10 +25,17 @@ _docker_client = None -PYTHON_EXECUTION_BACKEND = "docker" +VALID_EXECUTION_BACKENDS = { + "docker", + "dangerously_use_uv", + "dangerously_use_local_jupyter", +} + +_default_backend = os.environ.get("PYTHON_EXECUTION_BACKEND", "docker") +if _default_backend not in VALID_EXECUTION_BACKENDS: + _default_backend = "docker" -if os.environ.get("PYTHON_EXECUTION_BACKEND") == "dangerously_use_uv": - PYTHON_EXECUTION_BACKEND = "dangerously_use_uv" +PYTHON_EXECUTION_BACKEND = _default_backend def call_python_script(script: str) -> str: @@ -87,13 +98,184 @@ def call_python_script_with_uv(script: str) -> str: ) +class LocalJupyterSession: + """Stateful helper that proxies execution through a local Jupyter kernel.""" + + def __init__( + self, + connection_file: str | None = None, + *, + timeout: float = 120.0, + ) -> None: + try: + from jupyter_client import BlockingKernelClient, KernelManager + except ImportError as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "The dangerously_use_local_jupyter backend requires the jupyter_client package to be installed." + ) from exc + + self._default_timeout = timeout + self._owns_kernel = False + self._client: BlockingKernelClient + self._km: KernelManager | None = None + + if connection_file: + connection_path = Path(connection_file).expanduser() + if not connection_path.exists(): + raise FileNotFoundError( + f"Cannot find Jupyter connection file at '{connection_path}'." + ) + client = BlockingKernelClient() + client.load_connection_file(str(connection_path)) + client.start_channels() + # Ensure the connection is ready before executing. + client.wait_for_ready(timeout=self._default_timeout) + self._client = client + else: + km = KernelManager() + km.start_kernel() + client = km.blocking_client() + client.start_channels() + client.wait_for_ready(timeout=self._default_timeout) + self._client = client + self._km = km + self._owns_kernel = True + + def execute(self, code: str, *, timeout: float | None = None) -> str: + """Execute code in the kernel, returning combined stdout/stderr output.""" + + client = self._client + effective_timeout = timeout or self._default_timeout + msg_id = client.execute( + code, + store_history=True, + allow_stdin=False, + stop_on_error=False, + ) + + stdout_parts: list[str] = [] + stderr_parts: list[str] = [] + + while True: + try: + msg = client.get_iopub_msg(timeout=effective_timeout) + except queue.Empty as exc: + raise TimeoutError("Timed out waiting for Jupyter kernel output.") from exc + + if msg.get("parent_header", {}).get("msg_id") != msg_id: + continue + + msg_type = msg.get("msg_type") + content = msg.get("content", {}) + + if msg_type == "stream": + text = content.get("text", "") + if content.get("name") == "stdout": + stdout_parts.append(text) + else: + stderr_parts.append(text) + elif msg_type == "error": + traceback_data = content.get("traceback") + if traceback_data: + stderr_parts.append("\n".join(traceback_data)) + else: + ename = content.get("ename", "") + evalue = content.get("evalue", "") + stderr_parts.append(f"{ename}: {evalue}".strip()) + elif msg_type in {"execute_result", "display_data"}: + data = content.get("data", {}) + text = data.get("text/plain") + if text: + stdout_parts.append(text if text.endswith("\n") else f"{text}\n") + elif msg_type == "status" and content.get("execution_state") == "idle": + break + + # Drain the shell channel to capture final execution status. + while True: + try: + reply = client.get_shell_msg(timeout=effective_timeout) + except queue.Empty as exc: + raise TimeoutError( + "Timed out waiting for Jupyter kernel execution reply." + ) from exc + + if reply.get("parent_header", {}).get("msg_id") != msg_id: + continue + + reply_content = reply.get("content", {}) + if reply_content.get("status") == "error": + traceback_data = reply_content.get("traceback") + if traceback_data: + stderr_parts.append("\n".join(traceback_data)) + else: + ename = reply_content.get("ename", "") + evalue = reply_content.get("evalue", "") + stderr_parts.append(f"{ename}: {evalue}".strip()) + break + + stdout = "".join(stdout_parts) + stderr = "".join(stderr_parts) + + if stderr: + if stdout: + stdout = f"{stdout.rstrip()}\n{stderr}" + else: + stdout = stderr + + if not stdout.strip(): + stdout = ( + "[WARN] No output available. Use print() to output anything to stdout to " + "receive the output" + ) + + return stdout + + def close(self) -> None: + with contextlib.suppress(Exception): + self._client.stop_channels() + + if self._owns_kernel and self._km is not None: + with contextlib.suppress(Exception): + self._km.shutdown_kernel(now=True) + + def __del__(self) -> None: # pragma: no cover - best-effort cleanup + self.close() + class PythonTool(Tool): def __init__( self, name: str = "python", + *, + execution_backend: str | None = None, + local_jupyter_connection_file: str | None = None, + local_jupyter_timeout: float = 60.0, ): assert name == "python" + backend = execution_backend or PYTHON_EXECUTION_BACKEND + if backend not in VALID_EXECUTION_BACKENDS: + raise ValueError( + "execution_backend must be one of: " + + ", ".join(sorted(VALID_EXECUTION_BACKENDS)) + ) + + self._execution_backend = backend + self._local_jupyter_connection_file = ( + local_jupyter_connection_file + or os.environ.get("PYTHON_LOCAL_JUPYTER_CONNECTION_FILE") + ) + self._local_jupyter_timeout = local_jupyter_timeout + + self._jupyter_session: LocalJupyterSession | None = None + self._execution_lock: asyncio.Lock | None = None + + if self._execution_backend == "dangerously_use_local_jupyter": + self._execution_lock = asyncio.Lock() + self._jupyter_session = LocalJupyterSession( + connection_file=self._local_jupyter_connection_file, + timeout=self._local_jupyter_timeout, + ) + @classmethod def get_tool_name(cls) -> str: return "python" @@ -104,9 +286,17 @@ def name(self) -> str: @property def instruction(self) -> str: - return """ + if self._execution_backend == "dangerously_use_local_jupyter": + return """ Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files). +When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. Internet access for this session is UNKNOWN. Depends on the cluster. + """.strip() + + return """ +Use this tool to execute STATELESS Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files). When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you. You have to use print statements to access the output. + +IMPORTANT: Your python environment is not shared between calls. You will have to pass your entire code each time. """.strip() @property @@ -147,12 +337,34 @@ def make_response( async def _process(self, message: Message) -> AsyncIterator[Message]: script = message.content[0].text channel = message.channel - if PYTHON_EXECUTION_BACKEND == "docker": + + if self._execution_backend == "docker": output = call_python_script(script) - elif PYTHON_EXECUTION_BACKEND == "dangerously_use_uv": + elif self._execution_backend == "dangerously_use_uv": output = call_python_script_with_uv(script) + elif self._execution_backend == "dangerously_use_local_jupyter": + assert self._jupyter_session is not None + lock = self._execution_lock + if lock is not None: + async with lock: + try: + output = self._jupyter_session.execute(script) + except TimeoutError as exc: + output = f"[ERROR] {exc}" + else: + try: + output = self._jupyter_session.execute(script) + except TimeoutError as exc: + output = f"[ERROR] {exc}" else: raise ValueError( - f"Invalid PYTHON_EXECUTION_BACKEND: {PYTHON_EXECUTION_BACKEND}" + f"Invalid PYTHON_EXECUTION_BACKEND: {self._execution_backend}" ) yield self._make_response(output, channel=channel) + + def close(self) -> None: + if self._jupyter_session is not None: + self._jupyter_session.close() + + def __del__(self) -> None: # pragma: no cover - best-effort cleanup + self.close() diff --git a/pyproject.toml b/pyproject.toml index da46bd9..d2595a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "uvicorn>=0.35.0", "requests>=2.31.0", "termcolor", + "jupyter-client>=8.6.3", ] readme = "README.md" requires-python = ">=3.12"