Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions nemo_skills/inference/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,14 @@ def generate(self):
self.run_batch_evaluation()
self.postprocess()

# Shutdown tool servers (e.g., persistent python_tool HTTP server).
# Unwrap ParallelThinkingTask if present to reach ToolCallingWrapper.
llm = self.llm
if hasattr(llm, "model"):
llm = llm.model
if hasattr(llm, "shutdown"):
asyncio.run(llm.shutdown())


GENERATION_TASK_CLASS = GenerationTask

Expand Down
3 changes: 3 additions & 0 deletions nemo_skills/inference/model/tool_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ async def _execute_tool_calls(self, tool_calls: List, request_id: str, endpoint_
for tool_call, tool_result in zip(tool_calls, tool_results)
]

async def shutdown(self):
    """Shut down tool resources by awaiting the tool manager's ``shutdown()``."""
    manager = self.tool_manager
    await manager.shutdown()

async def generate_async(
self,
prompt: List,
Expand Down
118 changes: 102 additions & 16 deletions nemo_skills/mcp/servers/python_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,17 @@

import argparse
import logging
import os
import socket
import subprocess
import sys
import tempfile
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Any, Dict

from httpx import RemoteProtocolError
import httpx
from mcp.server.fastmcp import FastMCP
from omegaconf import OmegaConf
from pydantic import Field
Expand Down Expand Up @@ -59,7 +65,7 @@ async def stateful_python_code_exec(
output_dict, session_id = await sandbox.execute_code(
code, language=language, timeout=timeout, session_id=session_id
)
except RemoteProtocolError:
except httpx.RemoteProtocolError:
output_dict = {"process_status": "fail", "stdout": "", "stderr": "Error connecting to sandbox"}
session_id = None

Expand All @@ -68,6 +74,8 @@ async def stateful_python_code_exec(

def main():
parser = argparse.ArgumentParser(description="MCP server for executing Python code in a sandbox")
parser.add_argument("--host", default="127.0.0.1", help="Host to bind the HTTP server to")
parser.add_argument("--port", type=int, default=8765, help="Port to bind the HTTP server to")
parser.add_argument(
"--disable-session-restore",
action="store_true",
Expand All @@ -93,40 +101,104 @@ def main():
sandbox_cfg["disable_session_restore"] = True

sandbox = get_sandbox(**sandbox_cfg)
# Initialize and run the server
mcp.run(transport="stdio")
mcp.run(transport="streamable-http", host=args.host, port=args.port)


# ==============================
# Module-based tool implementation
# ==============================


def _get_free_port():
    """Return a TCP port number that the OS reports as unused.

    A throwaway IPv4 socket is bound to port 0 so the OS picks the port.
    The socket is closed before returning, so the port is only known to be
    free at the moment of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))
        _, assigned_port = probe.getsockname()
        return assigned_port
    finally:
        probe.close()


class PythonTool(MCPClientTool):
"""Tool provider that spawns a persistent python_tool HTTP server and connects via MCP Streamable HTTP."""

def __init__(self) -> None:
    """Register default configuration for the HTTP-backed Python tool.

    No server is started here: the server process handle and the temporary
    config file handle are initialised to ``None``.
    """
    super().__init__()
    # Each key appears exactly once; earlier stdio "client"/"client_params"
    # entries were dead because the later duplicates overrode them.
    self.apply_config_updates(
        {
            # placeholder base_url — replaced in post_configure() after server starts
            "client": "nemo_skills.mcp.clients.MCPStreamableHttpClient",
            "client_params": {"base_url": "http://127.0.0.1:0/mcp"},
            # hide args from schemas and sanitize at runtime
            "hide_args": {"stateful_python_code_exec": ["session_id", "timeout"]},
            # use explicit Hydra connector built from full context by default
            "init_hook": "hydra",
            # execution-specific default
            "exec_timeout_s": 10,
            "server_host": "127.0.0.1",
            "server_port": 0,  # 0 = auto-allocate
        }
    )
    self._server_process = None
    self._config_tmpfile = None
    # Unknown request ids map to None.
    self.requests_to_sessions = defaultdict(lambda: None)

def configure(self, overrides=None, context=None):
    """Store the configuration context, then delegate to the parent ``configure``.

    A falsy ``context`` is stored as an empty dict; the parent still
    receives the original ``context`` value unchanged.
    """
    self._context = context if context else {}
    super().configure(overrides, context)

def post_configure(self):
    """Launch the python_tool HTTP server and point the MCP client at it.

    The port comes from the ``server_port`` config entry; a falsy value
    (the default ``0``) means a free port is picked automatically. The
    ``sandbox`` entry of the stored context is passed to the server.
    """
    settings = self._config
    port = settings.get("server_port", 0) or _get_free_port()
    host = settings.get("server_host", "127.0.0.1")

    self._start_server(host, port, self._context.get("sandbox", {}))

    # Swap the placeholder client for one bound to the live server address.
    from nemo_skills.mcp.clients import MCPStreamableHttpClient

    client = MCPStreamableHttpClient(base_url=f"http://{host}:{port}/mcp")
    self._client = client
    client._hide_args = settings.get("hide_args", {})
    client._disabled_tools = set(settings.get("disabled_tools", []))
    client._enabled_tools = set(settings.get("enabled_tools", []))

def _start_server(self, host, port, sandbox_cfg):
    """Spawn the python_tool HTTP server subprocess and wait until it responds.

    The sandbox configuration is written to a temporary YAML file that is
    passed to the server via ``--config``. A falsy ``sandbox_cfg`` falls
    back to ``{"sandbox_type": "local"}``.

    Args:
        host: Host passed to the server's ``--host`` option.
        port: Port passed to the server's ``--port`` option.
        sandbox_cfg: Sandbox settings stored under the ``sandbox`` key.

    Raises:
        RuntimeError: Propagated from ``_wait_for_ready`` when the server
            exits early or does not become ready in time. The subprocess
            and the temporary config file are cleaned up before re-raising.
    """
    cfg = OmegaConf.create({"sandbox": sandbox_cfg or {"sandbox_type": "local"}})
    self._config_tmpfile = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False)
    try:
        OmegaConf.save(cfg, self._config_tmpfile)
    finally:
        # Close even if serialisation fails so the handle is not leaked.
        self._config_tmpfile.close()

    cmd = [
        sys.executable,
        "-m",
        "nemo_skills.mcp.servers.python_tool",
        "--host",
        host,
        "--port",
        str(port),
        "--config",
        self._config_tmpfile.name,
    ]
    logger.info(f"Starting python_tool HTTP server: {' '.join(cmd)}")
    self._server_process = subprocess.Popen(cmd)
    try:
        self._wait_for_ready(host, port)
    except BaseException:
        # Startup failed or was interrupted: do not leave an orphaned
        # server process or a stray temp file behind.
        process = self._server_process
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
        self._server_process = None
        try:
            os.unlink(self._config_tmpfile.name)
        except OSError:
            pass
        self._config_tmpfile = None
        raise
    logger.info(f"python_tool HTTP server ready (PID: {self._server_process.pid})")

def _wait_for_ready(self, host, port, timeout=30):
    """Poll the server's ``/mcp`` endpoint until it answers or ``timeout`` elapses.

    Any HTTP response with a status code below 500 counts as ready.

    Args:
        host: Host the server was started on.
        port: Port the server was started on.
        timeout: Maximum number of seconds to wait.

    Raises:
        RuntimeError: If the server process exits while waiting, or if no
            acceptable response arrives before the deadline.
    """
    url = f"http://{host}:{port}/mcp"
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if self._server_process.poll() is not None:
            raise RuntimeError(
                f"python_tool server exited during startup (code {self._server_process.returncode})"
            )
        try:
            resp = httpx.get(url, timeout=2)
            if resp.status_code < 500:
                return
        except (httpx.NetworkError, httpx.TimeoutException, httpx.RemoteProtocolError):
            # The server may refuse, reset, or half-answer connections
            # while it is still booting; keep polling.
            pass
        time.sleep(0.5)
    raise RuntimeError(f"python_tool server not ready after {timeout}s")

async def execute(self, tool_name: str, arguments: Dict[str, Any], extra_args: Dict[str, Any] | None = None):
# Ensure timeout is sent via extra_args (post-sanitize), not in main arguments
arguments = dict(arguments)
# TODO: error handling?
request_id = extra_args.pop("request_id")
merged_extra = dict(extra_args or {})
merged_extra.setdefault("timeout", self._config.get("exec_timeout_s", 10))
Expand All @@ -139,7 +211,21 @@ async def execute(self, tool_name: str, arguments: Dict[str, Any], extra_args: D
return output

async def shutdown(self) -> None:
    """Stop the python_tool server subprocess and delete its temp config file.

    The process is asked to terminate and given 5 seconds to exit; if it
    does not, it is killed and then reaped. Failure to delete the temp file
    is ignored. Calling this with nothing left to clean up is a no-op.
    """
    process = self._server_process
    if process:
        logger.info(f"Terminating python_tool server (PID: {process.pid})")
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            logger.warning("python_tool server did not terminate, killing...")
            process.kill()
            # Reap the killed child so it does not linger as a zombie.
            process.wait()
        self._server_process = None
    if self._config_tmpfile:
        try:
            os.unlink(self._config_tmpfile.name)
        except OSError:
            # Best-effort cleanup: the file may already be gone.
            pass
        self._config_tmpfile = None


if __name__ == "__main__":
Expand Down