diff --git a/agent.py b/agent.py index 8a6341f..43e3764 100644 --- a/agent.py +++ b/agent.py @@ -4,8 +4,7 @@ from openai import OpenAI client = OpenAI( - api_key=os.environ.get("OPENAI_API_KEY"), - base_url=os.environ.get("OPENAI_BASE_URL") + api_key=os.environ.get("OPENAI_API_KEY"), base_url=os.environ.get("OPENAI_BASE_URL") ) tools = [ @@ -52,8 +51,13 @@ def execute_bash(command): - result = subprocess.run(command, shell=True, capture_output=True, text=True) - return result.stdout + result.stderr + try: + result = subprocess.run( + command, shell=True, capture_output=True, text=True, timeout=30 + ) + return result.stdout + result.stderr + except subprocess.TimeoutExpired: + return "Error: Command timed out after 30 seconds" def read_file(path): @@ -67,7 +71,11 @@ def write_file(path, content): return f"Wrote to {path}" -functions = {"execute_bash": execute_bash, "read_file": read_file, "write_file": write_file} +functions = { + "execute_bash": execute_bash, + "read_file": read_file, + "write_file": write_file, +} def run_agent(user_message, max_iterations=5): @@ -93,11 +101,14 @@ def run_agent(user_message, max_iterations=5): result = f"Error: Unknown tool '{name}'" else: result = functions[name](**args) - messages.append({"role": "tool", "tool_call_id": tool_call.id, "content": result}) + messages.append( + {"role": "tool", "tool_call_id": tool_call.id, "content": result} + ) return "Max iterations reached" if __name__ == "__main__": import sys + task = " ".join(sys.argv[1:]) if len(sys.argv) > 1 else "Hello" print(run_agent(task)) diff --git a/tests/test_agent.py b/tests/test_agent.py index 2d8a188..038da08 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,4 +1,5 @@ import importlib.util +import subprocess import sys import types import unittest @@ -47,6 +48,33 @@ class AgentRegressionTests(unittest.TestCase): def setUp(self): self.agent = load_agent_module("agent.py", "nanoagent_agent") + def test_execute_bash_passes_timeout(self): + captured: dict[str, object] = {} + + def fake_run(command, **kwargs): + captured["command"] = command + captured.update(kwargs) + return SimpleNamespace(stdout="ok", stderr="") + + self.agent.subprocess.run = fake_run + + result = self.agent.execute_bash("echo ok") + + self.assertEqual(result, "ok") + self.assertEqual(captured["command"], "echo ok") + self.assertEqual(captured["timeout"], 30) + self.assertTrue(captured["shell"]) + + def test_execute_bash_returns_timeout_error(self): + def fake_run(*args, **kwargs): + raise subprocess.TimeoutExpired(cmd="sleep 31", timeout=30) + + self.agent.subprocess.run = fake_run + + result = self.agent.execute_bash("sleep 31") + + self.assertEqual(result, "Error: Command timed out after 30 seconds") + def test_parse_tool_arguments_reports_invalid_json(self): parsed = self.agent.parse_tool_arguments('{"command":') self.assertIn("_argument_error", parsed)