diff --git a/nemo_skills/inference/generate.py b/nemo_skills/inference/generate.py index b3f0094c1a..a1606d45bf 100644 --- a/nemo_skills/inference/generate.py +++ b/nemo_skills/inference/generate.py @@ -872,6 +872,14 @@ def generate(self): self.run_batch_evaluation() self.postprocess() + # Shutdown tool servers (e.g., persistent python_tool HTTP server). + # Unwrap ParallelThinkingTask if present to reach ToolCallingWrapper. + llm = self.llm + if hasattr(llm, "model"): + llm = llm.model + if hasattr(llm, "shutdown"): + asyncio.run(llm.shutdown()) + GENERATION_TASK_CLASS = GenerationTask diff --git a/nemo_skills/inference/model/tool_call.py b/nemo_skills/inference/model/tool_call.py index 313b88b057..bc192437b8 100644 --- a/nemo_skills/inference/model/tool_call.py +++ b/nemo_skills/inference/model/tool_call.py @@ -109,6 +109,9 @@ async def _execute_tool_calls(self, tool_calls: List, request_id: str, endpoint_ for tool_call, tool_result in zip(tool_calls, tool_results) ] + async def shutdown(self): + await self.tool_manager.shutdown() + async def generate_async( self, prompt: List, diff --git a/nemo_skills/mcp/servers/python_tool.py b/nemo_skills/mcp/servers/python_tool.py index cce3d9c301..0532b4e730 100644 --- a/nemo_skills/mcp/servers/python_tool.py +++ b/nemo_skills/mcp/servers/python_tool.py @@ -14,11 +14,17 @@ import argparse import logging +import os +import socket +import subprocess +import sys +import tempfile +import time from collections import defaultdict from dataclasses import dataclass from typing import Annotated, Any, Dict -from httpx import RemoteProtocolError +import httpx from mcp.server.fastmcp import FastMCP from omegaconf import OmegaConf from pydantic import Field @@ -59,7 +65,7 @@ async def stateful_python_code_exec( output_dict, session_id = await sandbox.execute_code( code, language=language, timeout=timeout, session_id=session_id ) - except RemoteProtocolError: + except httpx.RemoteProtocolError: output_dict = {"process_status": "fail", "stdout": "", "stderr": "Error connecting to sandbox"} session_id = None @@ -68,6 +74,8 @@ async def stateful_python_code_exec( def main(): parser = argparse.ArgumentParser(description="MCP server for executing Python code in a sandbox") + parser.add_argument("--host", default="127.0.0.1", help="Host to bind the HTTP server to") + parser.add_argument("--port", type=int, default=8765, help="Port to bind the HTTP server to") parser.add_argument( "--disable-session-restore", action="store_true", @@ -93,8 +101,7 @@ def main(): sandbox_cfg["disable_session_restore"] = True sandbox = get_sandbox(**sandbox_cfg) - # Initialize and run the server - mcp.run(transport="stdio") + mcp.run(transport="streamable-http", host=args.host, port=args.port) # ============================== @@ -102,31 +109,96 @@ def main(): # ============================== +def _get_free_port(): + """Get a free port by binding to port 0 and letting the OS assign one.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + class PythonTool(MCPClientTool): + """Tool provider that spawns a persistent python_tool HTTP server and connects via MCP Streamable HTTP.""" + def __init__(self) -> None: super().__init__() - # Defaults for stdio Python MCP using explicit client class self.apply_config_updates( { - "client": "nemo_skills.mcp.clients.MCPStdioClient", - "client_params": { - "command": "python", - "args": ["-m", "nemo_skills.mcp.servers.python_tool"], - }, - # hide args from schemas and sanitize at runtime + # placeholder base_url — replaced in post_configure() after server starts + "client": "nemo_skills.mcp.clients.MCPStreamableHttpClient", + "client_params": {"base_url": "http://127.0.0.1:0/mcp"}, "hide_args": {"stateful_python_code_exec": ["session_id", "timeout"]}, - # use explicit Hydra connector built from full context by default "init_hook": "hydra", - # execution-specific default "exec_timeout_s": 10, + "server_host": "127.0.0.1", + "server_port": 0, # 0 = auto-allocate } ) + self._server_process = None + self._config_tmpfile = None self.requests_to_sessions = defaultdict(lambda: None) + def configure(self, overrides=None, context=None): + self._context = context or {} + super().configure(overrides, context) + + def post_configure(self): + port = self._config.get("server_port", 0) + if not port: + port = _get_free_port() + host = self._config.get("server_host", "127.0.0.1") + + sandbox_cfg = self._context.get("sandbox", {}) + self._start_server(host, port, sandbox_cfg) + + # Replace the placeholder client with one pointing at the running server + from nemo_skills.mcp.clients import MCPStreamableHttpClient + + self._client = MCPStreamableHttpClient(base_url=f"http://{host}:{port}/mcp") + self._client._hide_args = self._config.get("hide_args", {}) + self._client._disabled_tools = set(self._config.get("disabled_tools", [])) + self._client._enabled_tools = set(self._config.get("enabled_tools", [])) + + def _start_server(self, host, port, sandbox_cfg): + cfg = OmegaConf.create({"sandbox": sandbox_cfg or {"sandbox_type": "local"}}) + self._config_tmpfile = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) + OmegaConf.save(cfg, self._config_tmpfile) + self._config_tmpfile.close() + + cmd = [ + sys.executable, + "-m", + "nemo_skills.mcp.servers.python_tool", + "--host", + host, + "--port", + str(port), + "--config", + self._config_tmpfile.name, + ] + logger.info(f"Starting python_tool HTTP server: {' '.join(cmd)}") + self._server_process = subprocess.Popen(cmd) + self._wait_for_ready(host, port) + logger.info(f"python_tool HTTP server ready (PID: {self._server_process.pid})") + + def _wait_for_ready(self, host, port, timeout=30): + url = f"http://{host}:{port}/mcp" + deadline = time.time() + timeout + while time.time() < deadline: + if self._server_process.poll() is not None: + raise RuntimeError( + f"python_tool server exited during startup (code {self._server_process.returncode})" + ) + try: + resp = httpx.get(url, timeout=2) + if resp.status_code < 500: + return + except (httpx.ConnectError, httpx.TimeoutException): + pass + time.sleep(0.5) + raise RuntimeError(f"python_tool server not ready after {timeout}s") + async def execute(self, tool_name: str, arguments: Dict[str, Any], extra_args: Dict[str, Any] | None = None): - # Ensure timeout is sent via extra_args (post-sanitize), not in main arguments arguments = dict(arguments) - # TODO: error handling? request_id = extra_args.pop("request_id") merged_extra = dict(extra_args or {}) merged_extra.setdefault("timeout", self._config.get("exec_timeout_s", 10)) @@ -139,7 +211,21 @@ async def execute(self, tool_name: str, arguments: Dict[str, Any], extra_args: D return output async def shutdown(self) -> None: - return None + if self._server_process: + logger.info(f"Terminating python_tool server (PID: {self._server_process.pid})") + self._server_process.terminate() + try: + self._server_process.wait(timeout=5) + except subprocess.TimeoutExpired: + logger.warning("python_tool server did not terminate, killing...") + self._server_process.kill() + self._server_process = None + if self._config_tmpfile: + try: + os.unlink(self._config_tmpfile.name) + except OSError: + pass + self._config_tmpfile = None if __name__ == "__main__":