diff --git a/.env.example b/.env.example index a0a48050..88bd87b4 100644 --- a/.env.example +++ b/.env.example @@ -20,3 +20,6 @@ WEAVIATE_HTTP_PORT="443" # or 8080 for localhost WEAVIATE_GRPC_PORT="443" # or 50051 for localhost WEAVIATE_HTTP_SECURE="true" # set to false for localhost WEAVIATE_GRPC_SECURE="true" # set to false for localhost + +# Optionally, specify E2B.dev API key for Python Code Interpreter +E2B_API_KEY="e2b_..." diff --git a/README.md b/README.md index a31b34fc..6056a421 100644 --- a/README.md +++ b/README.md @@ -92,6 +92,12 @@ uv run --env-file .env \ -m src.2_frameworks.2_multi_agent.planner_worker_gradio ``` +Python Code Interpreter demo, using the OpenAI Agent SDK, E2B for a secure code sandbox, and LangFuse for observability. + +```bash +uv run --env-file .env -m src.2_frameworks.code_interpreter_gradio +``` + ### 3. Evals Synthetic data. diff --git a/pyproject.toml b/pyproject.toml index fbbc6950..4d3be224 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "aiohttp>=3.12.14", "beautifulsoup4>=4.13.4", "datasets>=3.6.0", + "e2b-code-interpreter>=1.5.2", "gradio>=5.35.0", "langfuse>=3.1.3", "lxml>=6.0.0", @@ -23,6 +24,7 @@ dependencies = [ "pydantic-ai-slim[logfire]>=0.3.7", "pytest-asyncio>=0.25.2", "scikit-learn>=1.7.0", + "starlette>=0.47.2", "weaviate-client>=4.15.4", ] diff --git a/src/2_frameworks/code_interpreter_gradio.py b/src/2_frameworks/code_interpreter_gradio.py new file mode 100644 index 00000000..b039d41d --- /dev/null +++ b/src/2_frameworks/code_interpreter_gradio.py @@ -0,0 +1,98 @@ +"""Code Interpreter example. + +Logs traces to LangFuse for observability and evaluation. + +You will need your E2B API Key. 
+""" + +import logging +from pathlib import Path + +import agents +import gradio as gr +from dotenv import load_dotenv +from gradio.components.chatbot import ChatMessage +from openai import AsyncOpenAI + +from src.utils import ( + CodeInterpreter, + oai_agent_stream_to_gradio_messages, + pretty_print, + setup_langfuse_tracer, +) +from src.utils.langfuse.shared_client import langfuse_client + + +load_dotenv(verbose=True) + + +logging.basicConfig(level=logging.INFO) + +CODE_INTERPRETER_INSTRUCTIONS = """\ +The `code_interpreter` tool executes Python commands. \ +Please note that data is not persisted. Each time you invoke this tool, \ +you will need to run import and define all variables from scratch. + +You can access the local filesystem using this tool. \ +Instead of asking the user for file inputs, you should try to find the file \ +using this tool. + +Recommended packages: Pandas, Numpy, SymPy, Scikit-learn. + +You can also run Jupyter-style shell commands (e.g., `!pip freeze`) +but you won't be able to install packages. 
async def _main(question: str, gr_messages: list[ChatMessage]):
    """Answer ``question`` with the code-interpreter agent, streaming updates.

    Parameters
    ----------
    question : str
        Latest user message; used as the agent input and recorded as the
        input of the LangFuse span.
    gr_messages : list[ChatMessage]
        Message list handed in by the chat interface. It is extended in place
        as stream events are converted to Gradio messages.

    Yields
    ------
    list[ChatMessage]
        The accumulated message list after each stream event (skipped while
        the list is still empty), and once more after the run has finished.
    """
    setup_langfuse_tracer()

    # Expose the sandboxed interpreter to the model under a stable tool name.
    interpreter_tool = agents.function_tool(
        code_interpreter.run_code,
        name_override="code_interpreter",
    )
    llm = agents.OpenAIChatCompletionsModel(
        model=AGENT_LLM_NAME, openai_client=async_openai_client
    )
    agent = agents.Agent(
        name="Data Analysis Agent",
        instructions=CODE_INTERPRETER_INSTRUCTIONS,
        tools=[interpreter_tool],
        model=llm,
    )

    with langfuse_client.start_as_current_span(name="Agents-SDK-Trace") as trace:
        trace.update(input=question)

        run = agents.Runner.run_streamed(agent, input=question)
        async for event in run.stream_events():
            gr_messages.extend(oai_agent_stream_to_gradio_messages(event))
            if gr_messages:
                yield gr_messages

        # The final output is only available once the stream is exhausted.
        trace.update(output=run.final_output)

    pretty_print(gr_messages)
    yield gr_messages
class CodeInterpreterOutput(BaseModel):
    """Captured stdout/stderr lines, and optional error, of one code run."""

    stdout: list[str]
    stderr: list[str]
    error: _CodeInterpreterOutputError | None = None

    def __init__(self, stdout: list[str], stderr: list[str], **kwargs):
        """Flatten multi-line chunks so each list entry holds a single line.

        Parameters
        ----------
        stdout : list[str]
            Raw stdout chunks; each chunk may contain several lines.
        stderr : list[str]
            Raw stderr chunks; each chunk may contain several lines.
        **kwargs
            Remaining fields (e.g. ``error``), forwarded to the model unchanged.
        """
        super().__init__(
            stdout=[line for chunk in stdout for line in chunk.splitlines()],
            stderr=[line for chunk in stderr for line in chunk.splitlines()],
            **kwargs,
        )
class CodeInterpreter:
    """Code Interpreter tool for the agent.

    Every call to ``run_code`` creates a new sandbox, uploads the configured
    local files into it, executes the code, and kills the sandbox afterwards.
    """

    def __init__(
        self,
        local_files: "Sequence[Path | str]| None" = None,
        timeout_seconds: int = 30,
    ):
        """Configure your Code Interpreter session.

        Note that the sandbox is not persistent, and each run_code will
        execute in a fresh sandbox! (e.g., variables need to be re-declared each time.)

        Parameters
        ----------
        local_files : list[pathlib.Path | str] | None
            Optionally, specify a list of local files (as paths)
            to upload to sandbox working directory.
        timeout_seconds : int
            Limit executions to this duration.
        """
        self.timeout_seconds = timeout_seconds
        self.local_files = local_files if local_files else []

    async def run_code(self, code: str) -> str:
        """Run the given Python code in a sandbox environment.

        Parameters
        ----------
        code : str
            Python logic to execute.

        Returns
        -------
        str
            JSON object with `stdout` and `stderr` (lists of output lines)
            and `error` (`name`, `value`, `traceback`; null when the code
            ran without raising).
        """
        sbx = await AsyncSandbox.create(timeout=self.timeout_seconds)

        try:
            # Upload inside the try block: if an upload fails (e.g., a local
            # file is missing), the sandbox must still be killed rather than
            # left running.
            await _upload_files(sbx, self.local_files)

            result = await sbx.run_code(
                code, on_error=lambda error: print(error.traceback)
            )
            response = CodeInterpreterOutput.model_validate_json(result.logs.to_json())

            error = result.error
            if error is not None:
                response.error = _CodeInterpreterOutputError.model_validate_json(
                    error.to_json()
                )

            return response.model_dump_json()
        finally:
            await sbx.kill()
@pytest.mark.asyncio
async def test_jupyter_command():
    """Test running a Jupyter-style shell command (``!pip freeze``)."""
    session = CodeInterpreter(timeout_seconds=15)

    response = await session.run_code("!pip freeze")
    response_typed = CodeInterpreterOutput.model_validate_json(response)

    pretty_print(response_typed)
    # Without assertions this test could only fail if the sandbox call itself
    # raised; an error reported inside the response would pass silently.
    assert response_typed.error is None
    assert response_typed.stdout, "`pip freeze` should list installed packages"
= "lxml", specifier = ">=6.0.0" }, @@ -73,6 +76,7 @@ requires-dist = [ { name = "pydantic-ai-slim", extras = ["logfire"], specifier = ">=0.3.7" }, { name = "pytest-asyncio", specifier = ">=0.25.2" }, { name = "scikit-learn", specifier = ">=1.7.0" }, + { name = "starlette", specifier = ">=0.47.2" }, { name = "weaviate-client", specifier = ">=4.15.4" }, ] @@ -765,6 +769,38 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 }, ] +[[package]] +name = "e2b" +version = "1.7.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "httpcore" }, + { name = "httpx" }, + { name = "packaging" }, + { name = "protobuf" }, + { name = "python-dateutil" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6c/d5/989b74badfcfe9c1c114e6c4143bedcfcdea23b1be14423bd99cc27b78a1/e2b-1.7.0.tar.gz", hash = "sha256:7783408c2cdf7aee9b088d31759364f2b13b21100cc4e132ba36fd84cfc72e31", size = 57304 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/e8/c856caa5d36e8fa1081c3987942fbff4583b087c7526a6ed9c69dca41f98/e2b-1.7.0-py3-none-any.whl", hash = "sha256:6bd3d935249fcf5684494a97178d4d58446b4ed4018ac09087e4000046e82aab", size = 106254 }, +] + +[[package]] +name = "e2b-code-interpreter" +version = "1.5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "e2b" }, + { name = "httpx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ee/85/b4a1c9427b45818d4c3773ed967ec1fcc2d7b677e096d8051303889adc2d/e2b_code_interpreter-1.5.2.tar.gz", hash = "sha256:3bd6ea70596290e85aaf0a2f19f28bf37a5e73d13086f5e6a0080bb591c5a547", size = 10006 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/f5/4a/7dc5c673c47418e1b38594ab3b022ee20ea1bf3ff8f8aa8273d6ddc99532/e2b_code_interpreter-1.5.2-py3-none-any.whl", hash = "sha256:5c3188d8f25226b28fef4b255447cc6a4c36afb748bdd5180b45be486d5169f3", size = 12873 }, +] + [[package]] name = "eval-type-backport" version = "0.2.2" @@ -785,16 +821,16 @@ wheels = [ [[package]] name = "fastapi" -version = "0.116.0" +version = "0.116.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "starlette" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/20/38/e1da78736143fd885c36213a3ccc493c384ae8fea6a0f0bc272ef42ebea8/fastapi-0.116.0.tar.gz", hash = "sha256:80dc0794627af0390353a6d1171618276616310d37d24faba6648398e57d687a", size = 296518 } +sdist = { url = "https://files.pythonhosted.org/packages/78/d7/6c8b3bfe33eeffa208183ec037fee0cce9f7f024089ab1c5d12ef04bd27c/fastapi-0.116.1.tar.gz", hash = "sha256:ed52cbf946abfd70c5a0dccb24673f0670deeb517a88b3544d03c2a6bf283143", size = 296485 } wheels = [ - { url = "https://files.pythonhosted.org/packages/2f/68/d80347fe2360445b5f58cf290e588a4729746e7501080947e6cdae114b1f/fastapi-0.116.0-py3-none-any.whl", hash = "sha256:fdcc9ed272eaef038952923bef2b735c02372402d1203ee1210af4eea7a78d2b", size = 95625 }, + { url = "https://files.pythonhosted.org/packages/e5/47/d63c60f59a59467fda0f93f46335c9d18526d7071f025cb5b89d5353ea42/fastapi-0.116.1-py3-none-any.whl", hash = "sha256:c46ac7c312df840f0c9e220f7964bada936781bc4e2e6eb71f1c4d7553786565", size = 95631 }, ] [[package]] @@ -3581,14 +3617,15 @@ wheels = [ [[package]] name = "starlette" -version = "0.46.2" +version = "0.47.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } +sdist = { url = "https://files.pythonhosted.org/packages/04/57/d062573f391d062710d4088fa1369428c38d51460ab6fedff920efef932e/starlette-0.47.2.tar.gz", hash = "sha256:6ae9aa5db235e4846decc1e7b79c4f346adf41e9777aebeb49dfd09bbd7023d8", size = 2583948 } wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/0c/9d30a4ebeb6db2b25a841afbb80f6ef9a854fc3b41be131d249a977b4959/starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35", size = 72037 }, + { url = "https://files.pythonhosted.org/packages/f7/1f/b876b1f83aef204198a42dc101613fefccb32258e5428b5f9259677864b4/starlette-0.47.2-py3-none-any.whl", hash = "sha256:c5847e96134e5c5371ee9fac6fdf1a67336d5815e09eb2a01fdb57a351ef915b", size = 72984 }, ] [[package]]