Skip to content

[WIP] MCP code executor to execute LLM-generated code flexibly in CodeAct and ProgramOfThought #8467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 62 additions & 13 deletions dspy/predict/code_act.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import asyncio
import atexit
import inspect
import logging
from typing import Callable, Type
Expand All @@ -6,17 +8,44 @@
from dspy.adapters.types.tool import Tool
from dspy.predict.program_of_thought import ProgramOfThought
from dspy.predict.react import ReAct
from dspy.primitives.python_interpreter import PythonInterpreter
from dspy.signatures.signature import Signature, ensure_signature
from dspy.utils.mcp_python_interpreter.client import PythonInterpreterClient as PythonInterpreter

logger = logging.getLogger(__name__)


def run_async(coro):
"""
Run an async coroutine from a synchronous context.
If already in an event loop (e.g., Jupyter), use nest_asyncio to allow nested loops.
"""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None

if loop and loop.is_running():
# If we're in a running event loop (e.g., Jupyter), use asyncio.create_task and run until done
import nest_asyncio

nest_asyncio.apply()
return asyncio.get_event_loop().run_until_complete(coro)
else:
return asyncio.run(coro)


class CodeAct(ReAct, ProgramOfThought):
"""
CodeAct is a module that utilizes the Code Interpreter and predefined tools to solve the problem.
"""

def __init__(self, signature: str | Type[Signature], tools: list[Callable], max_iters: int = 5, interpreter: PythonInterpreter | None = None):
def __init__(
self,
signature: str | Type[Signature],
tools: list[Callable],
max_iters: int = 5,
interpreter: PythonInterpreter | None = None,
):
"""
Initializes the CodeAct class with the specified model, temperature, and max tokens.

Expand All @@ -42,9 +71,7 @@ def factorial(n):
self.history = []

tools = [t if isinstance(t, Tool) else Tool(t) for t in tools]
if any(
not inspect.isfunction(tool.func) for tool in tools
):
if any(not inspect.isfunction(tool.func) for tool in tools):
raise ValueError("CodeAct only accepts functions and not callable objects.")
tools = {tool.name: tool for tool in tools}

Expand All @@ -53,7 +80,13 @@ def factorial(n):
codeact_signature = (
dspy.Signature({**self.signature.input_fields}, "\n".join(instructions))
.append("trajectory", dspy.InputField(), type_=str)
.append("generated_code", dspy.OutputField(desc="Python code that when executed, produces output relevant to answering the question"), type_=str)
.append(
"generated_code",
dspy.OutputField(
desc="Python code that when executed, produces output relevant to answering the question"
),
type_=str,
)
.append("finished", dspy.OutputField(desc="a boolean flag to determine if the process is done"), type_=bool)
)

Expand All @@ -67,6 +100,15 @@ def factorial(n):
self.extractor = dspy.ChainOfThought(extract_signature)
# It will raises exception when dspy cannot find available deno instance by now.
self.interpreter = interpreter or PythonInterpreter()
self.interpreter_initialized = False

# Register shutdown to atexit
atexit.register(self.shutdown)

async def init_interpreter(self):
await self.interpreter.connect_to_server()
await self.interpreter.register_functions([tool.func for tool in self.tools.values()])
self.interpreter_initialized = True

def _build_instructions(self, signature, tools):
instructions = [f"{signature.instructions}\n"] if signature.instructions else []
Expand All @@ -82,15 +124,21 @@ def _build_instructions(self, signature, tools):
"You have access to the Python Standard Library and the following functions:"
)

# for idx, tool in enumerate(tools.values()):
# instructions.append(f"({idx + 1}) {tool}")

for idx, tool in enumerate(tools.values()):
instructions.append(f"({idx + 1}) {tool}")
instructions.append(f"```python\n{inspect.getsource(tool.func)}\n```\n\n")

return instructions

def forward(self, **kwargs):
# Define the tool funcitons in the interpreter
for tool in self.tools.values():
self.interpreter(inspect.getsource(tool.func))
return run_async(self.aforward(**kwargs))

async def aforward(self, **kwargs):
if not self.interpreter_initialized:
await self.init_interpreter()

trajectory = {}
max_iters = kwargs.pop("max_iters", self.max_iters)
Expand All @@ -104,8 +152,7 @@ def forward(self, **kwargs):
continue

trajectory[f"generated_code_{idx}"] = code
output, error = self._execute_code(code)

output, error = await self._aexecute_code(code)
if not error:
trajectory[f"code_output_{idx}"] = output
else:
Expand All @@ -114,6 +161,8 @@ def forward(self, **kwargs):
if code_data.finished:
break

extract = self._call_with_potential_trajectory_truncation(self.extractor, trajectory, **kwargs)
self.interpreter.shutdown()
extract = await self._async_call_with_potential_trajectory_truncation(self.extractor, trajectory, **kwargs)
return dspy.Prediction(trajectory=trajectory, **extract)

def shutdown(self):
run_async(self.interpreter.shutdown())
19 changes: 17 additions & 2 deletions dspy/predict/program_of_thought.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

import dspy
from dspy.primitives.module import Module
from dspy.primitives.python_interpreter import PythonInterpreter
from dspy.signatures.signature import Signature, ensure_signature
from dspy.utils.mcp_python_interpreter.client import PythonInterpreterClient as PythonInterpreter

logger = logging.getLogger(__name__)

Expand All @@ -27,7 +27,9 @@ class ProgramOfThought(Module):
```
"""

def __init__(self, signature: str | Type[Signature], max_iters: int = 3, interpreter: PythonInterpreter | None = None):
def __init__(
self, signature: str | Type[Signature], max_iters: int = 3, interpreter: PythonInterpreter | None = None
):
"""
Args:
signature: The signature of the module.
Expand Down Expand Up @@ -62,6 +64,9 @@ def __init__(self, signature: str | Type[Signature], max_iters: int = 3, interpr
# It will raises exception when dspy cannot find available deno instance by now.
self.interpreter = interpreter or PythonInterpreter()

async def init_interpreter(self):
await self.interpreter.connect_to_server()

def _generate_signature(self, mode):
signature_dict = dict(self.input_fields)
fields_for_mode = {
Expand Down Expand Up @@ -172,6 +177,16 @@ def _execute_code(self, code):
except Exception as e:
return None, str(e)

async def _aexecute_code(self, code):
if not code:
return None, "Error: Empty code before execution."

try:
output = await self.interpreter.execute(code)
return output, None
except Exception as e:
return None, str(e)

def forward(self, **kwargs):
input_kwargs = {field_name: kwargs[field_name] for field_name in self.input_fields}
code_data = self.code_generate(**input_kwargs)
Expand Down
2 changes: 2 additions & 0 deletions dspy/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from dspy.utils.callback import BaseCallback, with_callbacks
from dspy.utils.dummies import DummyLM, DummyVectorizer, dummy_rm
from dspy.utils.inspect_history import pretty_print_history
from dspy.utils.mcp_python_interpreter.client import PythonInterpreterClient


def download(url):
Expand All @@ -32,4 +33,5 @@ def download(url):
"StatusMessage",
"StatusMessageProvider",
"pretty_print_history",
"PythonInterpreterClient",
]
Empty file.
133 changes: 133 additions & 0 deletions dspy/utils/mcp_python_interpreter/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import inspect
import os
from contextlib import AsyncExitStack
from typing import Callable

from dotenv import load_dotenv
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client

load_dotenv() # load environment variables from .env


class PythonInterpreterClient:
def __init__(self):
# Initialize session and client objects
self.session: ClientSession | None = None
self.exit_stack = AsyncExitStack()

import dspy

class GetCleanFunctionDefinition(dspy.Signature):
"""Get the clean function definition code. Code unrelated to function definition, like main functions is
removed.

A few additional rules:
1. If the function definition relies on imported modules, make sure the import statements are included in the
output clean code.
2. If the function definition relies on custom helper functions, make sure the helper function definitions are
included in the output clean code.
3. If the function definition relies on some global variables, make sure the global variable definitions are
included in the output clean code.
"""

dirty_code: str = dspy.InputField(
description="The code containing the function definitions, which might be dirty."
)
function_names: list[str] = dspy.InputField(
description=(
"The names of the functions that the clean code must be able to define. If it relies on "
"custom helper functions, imported modules or global variables, make sure the relevant code is "
"included in the output clean code."
)
)
clean_code: str = dspy.OutputField(
description="The code only contain the function definitions, without any other code."
)

self.code_cleaner = dspy.ChainOfThought(GetCleanFunctionDefinition)

async def connect_to_server(self):
"""Connect to an MCP server

Args:
server_script_path: Path to the server script (.py or .js)
"""

server_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), # directory of client.py
"./server.py",
)
server_path = os.path.abspath(server_path)
server_params = StdioServerParameters(
command="python",
args=[server_path],
env=None,
)

stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))

await self.session.initialize()

# List available tools
response = await self.session.list_tools()
tools = response.tools
print("\nConnected to server with tools:", [tool.name for tool in tools])

async def call_tool(self, tool_name: str, tool_args: dict):
"""Call a tool"""
response = await self.session.call_tool(tool_name, tool_args)
return response.content

def _get_source_code(self, funcs: list[Callable]):
source_files = set()
for func in funcs:
original_func = inspect.unwrap(func)
path = inspect.getsourcefile(original_func)
if path is None:
raise ValueError("Could not determine source file")
source_files.add(path)

source_code = ""
for path in source_files:
with open(path) as f:
source_code += f.read()
source_code += "\n\n"
return source_code

async def execute(self, code: str):
"""Execute Python code"""
response = await self.session.call_tool("run_python_code", {"code": code})
return response.content

async def register_functions_by_code(self, code: str):
"""Register functions by code"""
await self.session.call_tool("register_functions", {"code": code})

async def register_functions_by_file(self, file_path: str):
"""Register functions by file path"""
with open(file_path) as f:
code = f.read()
await self.session.call_tool("register_functions", {"code": code})

async def register_functions(self, functions: list[dict]):
"""Register functions to the MCP server"""
source_code = self._get_source_code(functions)

import dspy

with dspy.context(lm=dspy.LM("openai/gpt-4o-mini")):
clean_code = self.code_cleaner(
dirty_code=source_code,
function_names=[func.__name__ for func in functions],
).clean_code
if clean_code.startswith("```python"):
clean_code = clean_code[len("```python") : -len("```")]
await self.session.call_tool("register_functions", {"code": clean_code})

async def shutdown(self):
"""Clean up resources"""
await self.session.call_tool("cleanup", {})
await self.exit_stack.aclose()
Loading
Loading