From b426f6c2b79fdd8ac195b14930dd14fbcb291940 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 21 Jan 2026 09:40:24 -0800 Subject: [PATCH 01/10] Implement retry logic for agent actions and enhance error handling for unrecoverable terminal states --- debug_gym/agents/base_agent.py | 135 +++++++++------- debug_gym/agents/utils.py | 85 ++++++++++ debug_gym/gym/envs/env.py | 6 +- debug_gym/gym/terminals/terminal.py | 4 + debug_gym/logger.py | 9 +- scripts/run.py | 152 +++++++++++------- tests/gym/envs/test_unrecoverable_terminal.py | 5 +- 7 files changed, 275 insertions(+), 121 deletions(-) diff --git a/debug_gym/agents/base_agent.py b/debug_gym/agents/base_agent.py index c64ef1ce..d32a70a9 100644 --- a/debug_gym/agents/base_agent.py +++ b/debug_gym/agents/base_agent.py @@ -2,12 +2,13 @@ import os import uuid from dataclasses import MISSING, asdict, dataclass, field, fields -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from jinja2 import Environment, FileSystemLoader, Template from debug_gym.agents.history_tracker import HistoryTracker from debug_gym.gym.envs.env import EnvInfo, RepoEnv +from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError from debug_gym.gym.utils import filter_non_utf8 from debug_gym.llms.base import LLM, LLMResponse from debug_gym.llms.utils import trim @@ -293,11 +294,17 @@ def step(self, info: EnvInfo) -> LLMResponse | List[LLMResponse]: return self.llm(messages, info.tools) def execute_action(self, llm_response: LLMResponse | List[LLMResponse]) -> EnvInfo: - next_info = self.env.step( - llm_response.tool, - llm_response.response, - llm_response.reasoning_response, - ) + try: + next_info = self.env.step( + llm_response.tool, + llm_response.response, + llm_response.reasoning_response, + ) + except UnrecoverableTerminalError as e: + # Record the failed step in history if env_info is available + if e.env_info is not None: + self.history.step(e.env_info, llm_response) + raise self.history.step(next_info, llm_response) return next_info @@ -324,6 +331,7 @@ def run( env: RepoEnv, debug: bool = False, reset_env: bool = True, + replay_actions: list = None, ) -> Dict[str, Any]: """Run the agent loop until termination or max steps. @@ -331,81 +339,88 @@ def run( env: The environment to interact with. debug: Whether to drop into debugger after each LLM call. reset_env: Whether to reset the environment (default True). + replay_actions: List of LLMResponse objects to replay before continuing + with new LLM calls. Used for retry after terminal failures. Returns: The trajectory as a JSON-serializable dict. """ info = None step = 0 + replay_actions = replay_actions or [] + replay_index = 0 # assign the env self.env = env - try: - if reset_env: - info = env.reset() - else: - info = env.info - - self.init(info) - - if info.resolved: - self.logger.report_progress( - problem_id=env.task_name, - step=0, - total_steps=self.args.max_steps, - score=info.score, - max_score=info.max_score, - status="resolved", - ) - return self.build_trajectory() + if reset_env: + info = env.reset() + else: + info = env.info - highscore = info.score - should_stop = False - step = 1 + self.init(info) - while not should_stop: - self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n") + if info.resolved: + self.logger.report_progress( + problem_id=env.task_name, + step=0, + total_steps=self.args.max_steps, + score=info.score, + max_score=info.max_score, + status="resolved", + ) + return self.build_trajectory() + highscore = info.score + should_stop = False + step = 1 + + while not should_stop: + self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n") + + # Check if we should replay a previous action or generate a new one + if replay_index < len(replay_actions): + agent_response = replay_actions[replay_index] + replay_index += 1 + # Log replay details similar to replay.py + self.logger.info( + f"[REPLAY] Replaying step {replay_index} from previous attempt:\n" + f" Tool: {agent_response.tool.name}\n" + f" Args: {agent_response.tool.arguments}\n" + f" Reasoning: {agent_response.reasoning_response[:100] if agent_response.reasoning_response else None}...\n" + f" Content: {agent_response.response[:100] if agent_response.response else None}..." + ) + else: agent_response = self.step(info) - info = self.execute_action(agent_response) - if debug: - breakpoint() + info = self.execute_action(agent_response) - should_stop, reason = self.should_stop(step + 1, info) - status = ( - "resolved" - if info.resolved - else ("unresolved" if should_stop else "running") - ) + if debug: + breakpoint() + + should_stop, reason = self.should_stop(step + 1, info) + status = ( + "resolved" + if info.resolved + else ("unresolved" if should_stop else "running") + ) + + highscore = max(highscore, info.score) + msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]" + if should_stop: + msg += f" | Stopping Reason: {reason}" + self.logger.info(msg) + step += 1 - highscore = max(highscore, info.score) - msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]" - if should_stop: - msg += f" | Stopping Reason: {reason}" - self.logger.info(msg) - step += 1 - - self.logger.report_progress( - problem_id=env.task_name, - step=step, - total_steps=self.args.max_steps, - score=info.score, - max_score=info.max_score, - status=status, - ) - return self.build_trajectory() - except Exception as e: self.logger.report_progress( problem_id=env.task_name, step=step, - total_steps=step, - score=getattr(info, "score", 0), - max_score=getattr(info, "max_score", None), - status="error", + total_steps=self.args.max_steps, + score=info.score, + max_score=info.max_score, + status=status, ) - raise e + return self.build_trajectory() def create_agent(config: Dict[str, Any], **kwargs) -> BaseAgent: diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py index 1c5446a2..18205abd 100644 --- a/debug_gym/agents/utils.py +++ b/debug_gym/agents/utils.py @@ -6,6 +6,8 @@ import yaml +from debug_gym.gym.tools.tool import ToolCall +from debug_gym.llms.base import LLMResponse from debug_gym.logger import DebugGymLogger @@ -90,6 +92,15 @@ def load_config(args=None): default=20, help="Maximum number of tasks to display in the progress bar.", ) + parser.add_argument( + "--max-retries", + type=int, + default=3, + help="Maximum number of retries for terminal or timeout failures" + " (e.g., spot instance eviction, container deletion). " + "Steps are replayed, including the step that failed, " + "after which new steps are generated. Default: 3.", + ) parser.add_argument( "-p", "--params", @@ -144,3 +155,77 @@ def save_trajectory(agent, problem_path: Path, logger: DebugGymLogger): json.dump(trajectory, f, indent=4) logger.debug(f"Trajectory saved in {json_file}") + + +def load_trajectory( + problem_path: Path, logger: DebugGymLogger +) -> list[LLMResponse] | None: + """Load a previous trajectory and reconstruct LLMResponse objects for replay. + + Follows the same approach as replay.py for accurate reconstruction of LLMResponse + objects, including token counts and original prompt data from prompt_response_pairs. + + Returns a list of LLMResponse objects that can be passed to agent.execute_action(), + or None if no trajectory exists. + """ + json_file = problem_path / "trajectory.json" + if not json_file.exists(): + return None + + try: + with open(json_file, "r") as f: + trajectory = json.load(f) + except (json.JSONDecodeError, IOError) as e: + logger.warning(f"Failed to load trajectory from {json_file}: {e}") + return None + + log = trajectory.get("log", []) + if not log: + return None + + llm_responses = [] + for step in log: + # Skip step 0 (initial state with no action) + if step.get("step_id") == 0 or step.get("action") is None: + continue + + # Reconstruct ToolCall from saved action + action_data = step.get("action", {}) + if not action_data: + continue + + tool_call = ToolCall( + id=action_data.get("id", ""), + name=action_data.get("name", ""), + arguments=action_data.get("arguments", {}), + ) + + # Extract data from prompt_response_pairs if available (like replay.py does) + prompt_response_pairs = step.get("prompt_response_pairs", []) + if prompt_response_pairs and len(prompt_response_pairs) > 0: + prompt_response = prompt_response_pairs[0] + token_usage = prompt_response.get("token_usage", {}) + llm_response = LLMResponse( + prompt=prompt_response.get("prompt", []), + response=prompt_response.get("response"), + reasoning_response=prompt_response.get("reasoning_response"), + tool=tool_call, + prompt_token_count=token_usage.get("prompt", 0), + response_token_count=token_usage.get("response", 0), + ) + else: + # Fallback to step-level data if prompt_response_pairs not available + llm_response = LLMResponse( + prompt=[], + response=step.get("content"), + reasoning_response=step.get("reasoning"), + tool=tool_call, + ) + llm_responses.append(llm_response) + + if llm_responses: + logger.info( + f"Loaded {len(llm_responses)} steps from previous trajectory for replay" + ) + + return llm_responses if llm_responses else None diff --git a/debug_gym/gym/envs/env.py b/debug_gym/gym/envs/env.py index 7c67dd84..dab34d04 100644 --- a/debug_gym/gym/envs/env.py +++ b/debug_gym/gym/envs/env.py @@ -51,7 +51,7 @@ def __str__(self) -> str: lines.append("") # Observations section - lines.append(f"👁️ Observation:") + lines.append("👁️ Observation:") lines.append(f"```\n{self.step_observation}\n```") lines.append("") @@ -468,7 +468,9 @@ def step( resolved=self.resolved, tools=self.tools, ) - return self.infos + # Attach env_info to exception and re-raise to allow retry logic + e.env_info = self.infos + raise except BaseException as e: error_message = ( f"Error while using tool {triggered_tool.name} " diff --git a/debug_gym/gym/terminals/terminal.py b/debug_gym/gym/terminals/terminal.py index c5c83dc0..3008f665 100644 --- a/debug_gym/gym/terminals/terminal.py +++ b/debug_gym/gym/terminals/terminal.py @@ -14,6 +14,10 @@ class TerminalError(RuntimeError): class UnrecoverableTerminalError(TerminalError): """Raised when the terminal becomes unusable and the episode must stop.""" + def __init__(self, message: str, env_info=None): + super().__init__(message) + self.env_info = env_info + DISABLE_ECHO_COMMAND = "stty -echo" diff --git a/debug_gym/logger.py b/debug_gym/logger.py index 5b2aeb07..c14a6a81 100644 --- a/debug_gym/logger.py +++ b/debug_gym/logger.py @@ -395,7 +395,7 @@ def __init__( expand=True, ) self.total = len(problems) - self.completed = 0 + self._completed_task_ids = set() self._overall_task = self.overall_progress.add_task( "Overall", # Placeholder description, will be set by _refresh total=self.total, @@ -424,6 +424,10 @@ def __init__( ) self._listener_thread.start() + @property + def completed(self) -> int: + return len(self._completed_task_ids) + def advance(self, progress_update: TaskProgress): """Advance the progress for a specific task based on the provided update. Sets task as completed if its status is completed (e.g. early stopping).""" @@ -434,7 +438,8 @@ def advance(self, progress_update: TaskProgress): # Update the task progress self.tasks_progress.advance(progress_update) # Update overall progress completion - self.completed += 1 if progress_update.completed else 0 + if progress_update.completed: + self._completed_task_ids.add(progress_update.problem_id) def close(self): """Stop the listener thread and wait until it exits.""" diff --git a/scripts/run.py b/scripts/run.py index 32fbb730..be0a9597 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -7,9 +7,15 @@ from pathlib import Path from debug_gym.agents.base_agent import AGENT_REGISTRY, create_agent -from debug_gym.agents.utils import load_config, save_patch, save_trajectory +from debug_gym.agents.utils import ( + load_config, + load_trajectory, + save_patch, + save_trajectory, +) from debug_gym.experiment import create_env, dump_experiment_info from debug_gym.gym.envs import load_dataset +from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError from debug_gym.llms.base import LLM from debug_gym.llms.human import Human from debug_gym.logger import DebugGymLogger, load_previous_run_status @@ -41,10 +47,6 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): success = True env = None - # Flag to not report errors from the agent, since they report - # errors themselves and we want to avoid double reporting. - report_progress_error = True - task_path = Path(config["output_path"]) / task_name task_logger = DebugGymLogger( @@ -76,6 +78,8 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): task_logger.debug(f"Skipping {task_name}, already done.") return success + max_retries = args.max_retries + task_logger.report_progress( problem_id=task_name, step=0, @@ -85,42 +89,83 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): status="running", ) - env = create_env(config, task_data, task_logger) - llm = LLM.instantiate(**config.get("llm", {}), logger=task_logger) - agent = create_agent(config.get("agent", {}), llm=llm, logger=task_logger) - - try: - success = agent.run(env, debug=args.debug) - except KeyboardInterrupt: - task_logger.error("Agent run was interrupted by user.") - task_logger.report_progress( - problem_id=task_name, - step=1, - total_steps=1, - score=0, - max_score=None, - status="error", - ) - success = False - raise - except AgentTimeoutException: - task_logger.error( - f"Timeout: Problem `{task_name}` exceeded " - f"the time limit of {args.timeout} seconds." - ) - task_logger.report_progress( - problem_id=task_name, - step=1, - total_steps=1, - score=0, - max_score=None, - status="error", - ) - success = False - raise - except: - report_progress_error = False - raise + # Track actions from previous attempts for replay + replay_actions = None + for attempt in range(max_retries): + try: + # Load actions from previous attempt for replay on retry + if attempt > 0: + task_logger.info(f"Replaying actions from attempt {attempt}") + # Load actions from previous attempt for replay + replay_actions = load_trajectory(task_path, task_logger) + task_logger.report_progress( + problem_id=task_name, + step=0, + total_steps=1, + score=0, + max_score=None, + status="running", + ) + + env = create_env(config, task_data, task_logger) + llm = LLM.instantiate(**config.get("llm", {}), logger=task_logger) + agent = create_agent( + config.get("agent", {}), llm=llm, logger=task_logger + ) + + success = agent.run( + env, + debug=args.debug, + replay_actions=replay_actions, + ) + break # Exit retry loop + except (UnrecoverableTerminalError, AgentTimeoutException) as e: + # Save trajectory before retry so we can replay actions + try: + save_trajectory(agent, task_path, task_logger) + except Exception as save_error: + task_logger.error( + f"Could not save trajectory for replay: {save_error!r}" + ) + raise + + task_logger.warning( + f"Terminal lost (attempt {attempt + 1}/{max_retries}): {e}" + ) + # Close the failed environment + if env is not None: + env.close() + env = None + + if attempt < max_retries - 1: + task_logger.info(f"Retrying task {task_name}...") + else: + task_logger.error( + f"Task {task_name} failed after {max_retries} attempts." + ) + + task_logger.report_progress( + problem_id=task_name, + step=1, + total_steps=1, + score=0, + max_score=None, + status="error", + ) + success = False + raise + except KeyboardInterrupt: + task_logger.error("Agent run was interrupted by user.") + task_logger.report_progress( + problem_id=task_name, + step=1, + total_steps=1, + score=0, + max_score=None, + status="error", + ) + success = False + raise # save trajectory save_trajectory(agent, task_path, task_logger) @@ -141,17 +186,16 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): task_logger.debug( f"Task {task_name} generated an exception: {e!r}. Traceback: {traceback.format_exc()}" ) - if report_progress_error: - task_logger.report_progress( - problem_id=task_name, - step=1, - total_steps=1, - score=0, - max_score=None, - status="error", - ) + task_logger.report_progress( + problem_id=task_name, + step=1, + total_steps=1, + score=0, + max_score=None, + status="error", + ) if args.debug: - raise e + raise success = False finally: @@ -177,7 +221,7 @@ def main(): # Load the dataset based on the information found in the config. if config.get("task_data") is not None: - dataset = {f"custom-task": config["task_data"]} + dataset = {"custom-task": config["task_data"]} else: dataset = load_dataset(config["dataset"], logger=logger) @@ -222,8 +266,6 @@ def main(): for problem in problems: try: run_agent(args, problem, dataset[problem], config) - except AgentTimeoutException: - pass # Handled in run_agent, just continue except (KeyboardInterrupt, Exception) as e: raise e else: @@ -242,8 +284,6 @@ def main(): try: problem = futures[future] future.result() - except AgentTimeoutException: - pass # Handled in run_agent, just continue except (KeyboardInterrupt, Exception) as e: executor.shutdown(wait=True, cancel_futures=True) raise e diff --git a/tests/gym/envs/test_unrecoverable_terminal.py b/tests/gym/envs/test_unrecoverable_terminal.py index 459fdfa4..d2a9e800 100644 --- a/tests/gym/envs/test_unrecoverable_terminal.py +++ b/tests/gym/envs/test_unrecoverable_terminal.py @@ -73,7 +73,10 @@ def fatal_env(): def test_env_terminates_after_unrecoverable_terminal_error(fatal_env): tool_call = ToolCall(id="bash-1", name="bash", arguments={"command": "ls"}) - info = fatal_env.step(tool_call) + with pytest.raises(UnrecoverableTerminalError) as exc_info: + fatal_env.step(tool_call) + + info = exc_info.value.env_info assert info.terminated is True assert fatal_env.terminated is True From 778389ade040affe622c5e41dad8d04fe20b98c7 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 21 Jan 2026 09:50:57 -0800 Subject: [PATCH 02/10] Add tests for load_trajectory function --- tests/agents/test_utils.py | 301 ++++++++++++++++++++++++++++++++++++- 1 file changed, 299 insertions(+), 2 deletions(-) diff --git a/tests/agents/test_utils.py b/tests/agents/test_utils.py index d44639b9..9934ddc6 100644 --- a/tests/agents/test_utils.py +++ b/tests/agents/test_utils.py @@ -1,7 +1,11 @@ +import json import logging -from unittest.mock import patch +from unittest.mock import MagicMock -from debug_gym.agents.utils import load_config +import pytest + +from debug_gym.agents.utils import load_config, load_trajectory +from debug_gym.llms.base import LLMResponse def test_load_config(): @@ -73,3 +77,296 @@ def test_load_config(): assert _config == expected_config assert _args.debug is True assert _args.logging_level == logging.INFO + + +@pytest.fixture +def mock_logger(): + """Create a mock logger for testing.""" + logger = MagicMock() + return logger + + +class TestLoadTrajectory: + """Tests for the load_trajectory function.""" + + def test_load_trajectory_no_file(self, tmp_path, mock_logger): + """Test that load_trajectory returns None when trajectory file doesn't exist.""" + result = load_trajectory(tmp_path, mock_logger) + assert result is None + + def test_load_trajectory_empty_log(self, tmp_path, mock_logger): + """Test that load_trajectory returns None when log is empty.""" + trajectory = {"log": []} + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is None + + def test_load_trajectory_no_log_key(self, tmp_path, mock_logger): + """Test that load_trajectory returns None when log key is missing.""" + trajectory = {"other_key": "value"} + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is None + + def test_load_trajectory_skips_step_zero(self, tmp_path, mock_logger): + """Test that load_trajectory skips step 0 (initial state).""" + trajectory = { + "log": [ + {"step_id": 0, "action": None}, + { + "step_id": 1, + "action": { + "id": "call_1", + "name": "test_tool", + "arguments": {"arg1": "value1"}, + }, + "prompt_response_pairs": [ + { + "prompt": [{"role": "user", "content": "test"}], + "response": "test response", + "reasoning_response": None, + "token_usage": {"prompt": 10, "response": 20}, + } + ], + }, + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + assert result[0].tool.name == "test_tool" + + def test_load_trajectory_with_prompt_response_pairs(self, tmp_path, mock_logger): + """Test load_trajectory with full prompt_response_pairs data.""" + trajectory = { + "log": [ + { + "step_id": 1, + "action": { + "id": "call_123", + "name": "edit_file", + "arguments": {"file": "test.py", "content": "print('hello')"}, + }, + "prompt_response_pairs": [ + { + "prompt": [ + { + "role": "system", + "content": "You are a helpful assistant", + }, + {"role": "user", "content": "Edit the file"}, + ], + "response": "I'll edit the file for you", + "reasoning_response": "Thinking about the edit...", + "token_usage": {"prompt": 100, "response": 50}, + } + ], + } + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + + llm_response = result[0] + assert isinstance(llm_response, LLMResponse) + assert llm_response.response == "I'll edit the file for you" + assert llm_response.reasoning_response == "Thinking about the edit..." + assert llm_response.token_usage.prompt == 100 + assert llm_response.token_usage.response == 50 + assert llm_response.tool.id == "call_123" + assert llm_response.tool.name == "edit_file" + assert llm_response.tool.arguments == { + "file": "test.py", + "content": "print('hello')", + } + assert len(llm_response.prompt) == 2 + + def test_load_trajectory_fallback_without_prompt_response_pairs( + self, tmp_path, mock_logger + ): + """Test load_trajectory fallback when prompt_response_pairs is missing.""" + trajectory = { + "log": [ + { + "step_id": 1, + "action": { + "id": "call_456", + "name": "run_command", + "arguments": {"cmd": "ls"}, + }, + "content": "Running the command", + "reasoning": "Need to list files", + } + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + + llm_response = result[0] + assert llm_response.response == "Running the command" + assert llm_response.reasoning_response == "Need to list files" + assert llm_response.prompt == [] + assert llm_response.tool.name == "run_command" + + def test_load_trajectory_multiple_steps(self, tmp_path, mock_logger): + """Test load_trajectory with multiple steps.""" + trajectory = { + "log": [ + {"step_id": 0, "action": None}, + { + "step_id": 1, + "action": { + "id": "call_1", + "name": "tool_1", + "arguments": {}, + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "response 1", + "token_usage": {"prompt": 10, "response": 5}, + } + ], + }, + { + "step_id": 2, + "action": { + "id": "call_2", + "name": "tool_2", + "arguments": {"key": "value"}, + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "response 2", + "token_usage": {"prompt": 20, "response": 10}, + } + ], + }, + { + "step_id": 3, + "action": { + "id": "call_3", + "name": "tool_3", + "arguments": {}, + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "response 3", + "token_usage": {"prompt": 30, "response": 15}, + } + ], + }, + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 3 + assert result[0].tool.name == "tool_1" + assert result[1].tool.name == "tool_2" + assert result[2].tool.name == "tool_3" + mock_logger.info.assert_called_once() + + def test_load_trajectory_invalid_json(self, tmp_path, mock_logger): + """Test that load_trajectory handles invalid JSON gracefully.""" + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + f.write("not valid json {{{") + + result = load_trajectory(tmp_path, mock_logger) + assert result is None + mock_logger.warning.assert_called_once() + + def test_load_trajectory_skips_steps_without_action(self, tmp_path, mock_logger): + """Test that load_trajectory skips steps with no action data.""" + trajectory = { + "log": [ + { + "step_id": 1, + "action": {}, # Empty action + }, + { + "step_id": 2, + # Missing action key entirely + }, + { + "step_id": 3, + "action": { + "id": "call_valid", + "name": "valid_tool", + "arguments": {}, + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "valid response", + "token_usage": {"prompt": 5, "response": 5}, + } + ], + }, + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + assert result[0].tool.name == "valid_tool" + + def test_load_trajectory_missing_optional_fields(self, tmp_path, mock_logger): + """Test load_trajectory with missing optional fields in action.""" + trajectory = { + "log": [ + { + "step_id": 1, + "action": { + "name": "minimal_tool", + # Missing id and arguments + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "response", + "token_usage": {"prompt": 1, "response": 1}, + } + ], + } + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + assert result[0].tool.id == "" + assert result[0].tool.name == "minimal_tool" + assert result[0].tool.arguments == {} From d4e07add4fbab058647ba7f15313bd2aa6866ec4 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 21 Jan 2026 09:57:27 -0800 Subject: [PATCH 03/10] Add tests for BaseAgent.execute_action --- tests/agents/test_base_agent.py | 127 +++++++++++++++++++++++++++++++- 1 file changed, 125 insertions(+), 2 deletions(-) diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py index 4257cce2..8562e3d0 100644 --- a/tests/agents/test_base_agent.py +++ b/tests/agents/test_base_agent.py @@ -1,16 +1,18 @@ import json -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest from jinja2 import Template from debug_gym.agents.base_agent import ( AGENT_REGISTRY, - AgentArgs, BaseAgent, create_agent, register_agent, ) +from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError +from debug_gym.gym.tools.tool import ToolCall +from debug_gym.llms.base import LLMResponse from debug_gym.llms.human import Human @@ -335,3 +337,124 @@ def test_load_prompt_template_with_custom_loader_root(tmp_path): assert "=== Explorer ===" in rendered assert "Body content." in rendered + + +class TestExecuteAction: + """Tests for BaseAgent.execute_action method.""" + + @pytest.fixture + def agent_with_mocks(self): + """Create a BaseAgent with mocked env and llm.""" + agent = BaseAgent() + agent.env = MagicMock() + agent.llm = MagicMock() + # Initialize history with mock data + mock_info = MagicMock() + mock_info.instructions = "Test instructions" + agent.history.init( + {"role": "system", "content": "system"}, + {"role": "user", "content": "instance"}, + mock_info, + ) + return agent + + @pytest.fixture + def mock_llm_response(self): + """Create a mock LLMResponse.""" + tool_call = ToolCall( + id="call_123", + name="test_tool", + arguments={"arg1": "value1"}, + ) + return LLMResponse( + prompt=[{"role": "user", "content": "test"}], + response="test response", + reasoning_response="test reasoning", + tool=tool_call, + ) + + def test_execute_action_success(self, agent_with_mocks, mock_llm_response): + """Test that execute_action updates history on successful step.""" + agent = agent_with_mocks + mock_env_info = MagicMock() + agent.env.step.return_value = mock_env_info + + # Initial history should have 1 entry (from init) + initial_history_len = len(agent.history) + + result = agent.execute_action(mock_llm_response) + + assert result == mock_env_info + assert len(agent.history) == initial_history_len + 1 + agent.env.step.assert_called_once_with( + mock_llm_response.tool, + mock_llm_response.response, + mock_llm_response.reasoning_response, + ) + + def test_execute_action_unrecoverable_error_with_env_info( + self, agent_with_mocks, mock_llm_response + ): + """Test that history is updated when UnrecoverableTerminalError has env_info.""" + agent = agent_with_mocks + mock_env_info = MagicMock() + mock_env_info.step_observation = MagicMock() + mock_env_info.step_observation.observation = "error observation" + + error = UnrecoverableTerminalError("Terminal died", env_info=mock_env_info) + agent.env.step.side_effect = error + + initial_history_len = len(agent.history) + + with pytest.raises(UnrecoverableTerminalError) as exc_info: + agent.execute_action(mock_llm_response) + + assert exc_info.value is error + # History should be updated with the failed step + assert len(agent.history) == initial_history_len + 1 + # Verify the last llm_response in history is our mock + assert agent.history.llm_responses[-1] == mock_llm_response + # Verify the last env_observation in history is from the error + assert agent.history.env_observations[-1] == mock_env_info + + def test_execute_action_unrecoverable_error_without_env_info( + self, agent_with_mocks, mock_llm_response + ): + """Test that history is NOT updated when UnrecoverableTerminalError has no env_info.""" + agent = agent_with_mocks + + error = UnrecoverableTerminalError("Terminal died", env_info=None) + agent.env.step.side_effect = error + + initial_history_len = len(agent.history) + + with pytest.raises(UnrecoverableTerminalError) as exc_info: + agent.execute_action(mock_llm_response) + + assert exc_info.value is error + # History should NOT be updated since env_info is None + assert len(agent.history) == initial_history_len + + def test_execute_action_history_contains_correct_data( + self, agent_with_mocks, mock_llm_response + ): + """Test that history contains the correct tool call data after execute_action.""" + agent = agent_with_mocks + mock_env_info = MagicMock() + mock_env_info.step_observation = MagicMock() + mock_env_info.step_observation.observation = "tool output" + agent.env.step.return_value = mock_env_info + + agent.execute_action(mock_llm_response) + + # Verify the history contains correct llm_response data + last_llm_response = agent.history.llm_responses[-1] + assert last_llm_response.tool.id == "call_123" + assert last_llm_response.tool.name == "test_tool" + assert last_llm_response.tool.arguments == {"arg1": "value1"} + assert last_llm_response.response == "test response" + assert last_llm_response.reasoning_response == "test reasoning" + + # Verify the history contains correct env_observation + last_env_obs = agent.history.env_observations[-1] + assert last_env_obs == mock_env_info From d714ec33feb388825cc9f752814e580a2e301194 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 21 Jan 2026 10:06:14 -0800 Subject: [PATCH 04/10] Add tests for replay_actions functionality in BaseAgent.run() --- tests/agents/test_base_agent.py | 310 ++++++++++++++++++++++++++++++++ 1 file changed, 310 insertions(+) diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py index 8562e3d0..92762c92 100644 --- a/tests/agents/test_base_agent.py +++ b/tests/agents/test_base_agent.py @@ -458,3 +458,313 @@ def test_execute_action_history_contains_correct_data( # Verify the history contains correct env_observation last_env_obs = agent.history.env_observations[-1] assert last_env_obs == mock_env_info + + +class TestBaseAgentRunReplayActions: + """Tests for replay_actions parameter in BaseAgent.run().""" + + @pytest.fixture + def mock_env(self): + """Create a mock environment.""" + env = MagicMock() + env.task_name = "test_task" + env.tools = [] + env.resolved = False + return env + + @pytest.fixture + def mock_env_info(self): + """Create a mock EnvInfo.""" + info = MagicMock() + info.instructions = "Test instructions" + info.tools = [] + info.resolved = False + info.terminated = False + info.score = 0 + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = "observation" + info.action_tool_call = None + return info + + @pytest.fixture + def mock_llm(self): + """Create a mock LLM.""" + llm = MagicMock() + llm.context_length = 4096 + llm.count_tokens = lambda x: len(str(x)) + llm.define_tools = lambda x: [] + llm.convert_observation_to_message = lambda obs, **kwargs: { + "role": "user", + "content": obs if isinstance(obs, str) else str(obs), + } + llm.convert_response_to_message = lambda resp: { + "role": "assistant", + "content": resp.response, + } + return llm + + def test_run_uses_replay_actions_instead_of_step( + self, mock_env, mock_env_info, mock_llm + ): + """Test that replay_actions are used instead of calling step().""" + agent = BaseAgent(agent_args={"max_steps": 5}) + agent.llm = mock_llm + + # Create replay actions + replay_actions = [ + LLMResponse( + prompt=[], + response="replay response 1", + reasoning_response="replay reasoning 1", + tool=ToolCall(id="replay_1", name="tool_1", arguments={}), + ), + LLMResponse( + prompt=[], + response="replay response 2", + reasoning_response="replay reasoning 2", + tool=ToolCall(id="replay_2", name="tool_2", arguments={}), + ), + ] + + # Set up env to terminate after 2 steps + step_count = {"count": 0} + + def mock_reset(): + return mock_env_info + + def mock_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 2 + info.terminated = step_count["count"] >= 2 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_step + mock_env.info = mock_env_info + + # Mock the step method to track if it's called + agent.step = MagicMock() + + agent.run(mock_env, replay_actions=replay_actions) + + # step() should NOT have been called since we had replay actions + agent.step.assert_not_called() + + # env.step should have been called with replay action tools + assert mock_env.step.call_count == 2 + calls = mock_env.step.call_args_list + assert calls[0][0][0].id == "replay_1" + assert calls[1][0][0].id == "replay_2" + + def test_run_switches_to_step_after_replay_exhausted( + self, mock_env, mock_env_info, mock_llm + ): + """Test that run() switches to calling step() after replay_actions are exhausted.""" + agent = BaseAgent(agent_args={"max_steps": 5}) + agent.llm = mock_llm + + # Create only 1 replay action + replay_actions = [ + LLMResponse( + prompt=[], + response="replay response", + reasoning_response="replay reasoning", + tool=ToolCall(id="replay_1", name="replay_tool", arguments={}), + ), + ] + + step_count = {"count": 0} + + def mock_reset(): + return mock_env_info + + def mock_env_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 3 + info.terminated = step_count["count"] >= 3 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_env_step + mock_env.info = mock_env_info + + # Create new LLM responses for step() calls + new_llm_responses = [ + LLMResponse( + prompt=[], + response="new response 1", + reasoning_response="new reasoning 1", + tool=ToolCall(id="new_1", name="new_tool_1", arguments={}), + ), + LLMResponse( + prompt=[], + response="new response 2", + reasoning_response="new reasoning 2", + tool=ToolCall(id="new_2", name="new_tool_2", arguments={}), + ), + ] + agent.step = MagicMock(side_effect=new_llm_responses) + + agent.run(mock_env, replay_actions=replay_actions) + + # step() should have been called 2 times (after replay action exhausted) + assert agent.step.call_count == 2 + + # env.step should have been called 3 times total + assert mock_env.step.call_count == 3 + calls = mock_env.step.call_args_list + # First call: replay action + assert calls[0][0][0].id == "replay_1" + # Second and third calls: from step() + assert calls[1][0][0].id == "new_1" + assert calls[2][0][0].id == "new_2" + + def test_run_with_empty_replay_actions(self, mock_env, mock_env_info, mock_llm): + """Test that empty replay_actions list behaves normally.""" + agent = BaseAgent(agent_args={"max_steps": 3}) + agent.llm = mock_llm + + step_count = {"count": 0} + + def mock_env_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 2 + info.terminated = step_count["count"] >= 2 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_env_step + mock_env.info = mock_env_info + + llm_responses = [ + LLMResponse( + prompt=[], + response="response 1", + reasoning_response="reasoning 1", + tool=ToolCall(id="call_1", name="tool_1", arguments={}), + ), + LLMResponse( + prompt=[], + response="response 2", + reasoning_response="reasoning 2", + tool=ToolCall(id="call_2", name="tool_2", arguments={}), + ), + ] + agent.step = MagicMock(side_effect=llm_responses) + + # Pass empty list + agent.run(mock_env, replay_actions=[]) + + # step() should be called for all actions + assert agent.step.call_count == 2 + + def test_run_with_none_replay_actions(self, mock_env, mock_env_info, mock_llm): + """Test that None replay_actions behaves normally.""" + agent = BaseAgent(agent_args={"max_steps": 3}) + agent.llm = mock_llm + + step_count = {"count": 0} + + def mock_env_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 1 + info.terminated = step_count["count"] >= 1 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_env_step + mock_env.info = mock_env_info + + llm_response = LLMResponse( + prompt=[], + response="response", + reasoning_response="reasoning", + tool=ToolCall(id="call_1", name="tool_1", arguments={}), + ) + agent.step = MagicMock(return_value=llm_response) + + # Pass None (default) + agent.run(mock_env, replay_actions=None) + + # step() should be called + assert agent.step.call_count == 1 + + def test_replay_actions_order_preserved(self, mock_env, mock_env_info, mock_llm): + """Test that replay actions are executed in the correct order.""" + agent = BaseAgent(agent_args={"max_steps": 10}) + agent.llm = mock_llm + + # Create 5 replay actions with distinct IDs + replay_actions = [ + LLMResponse( + prompt=[], + response=f"response {i}", + reasoning_response=f"reasoning {i}", + tool=ToolCall( + id=f"action_{i}", name=f"tool_{i}", arguments={"order": i} + ), + ) + for i in range(5) + ] + + step_count = {"count": 0} + + def mock_env_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 5 + info.terminated = step_count["count"] >= 5 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_env_step + mock_env.info = mock_env_info + + agent.run(mock_env, replay_actions=replay_actions) + + # Verify all 5 actions were executed in order + assert mock_env.step.call_count == 5 + calls = mock_env.step.call_args_list + for i, call in enumerate(calls): + assert call[0][0].id == f"action_{i}" + assert call[0][0].arguments == {"order": i} From 5730a223b64d58c2bd5acdd89ec7f65df1be0baa Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 21 Jan 2026 10:26:29 -0800 Subject: [PATCH 05/10] Update pre-commit: isort 7.0.0 and black 26.1.0 --- .pre-commit-config.yaml | 4 +-- debug_gym/llms/base.py | 6 ++-- tests/gym/tools/test_grep.py | 54 ++++++++++++------------------------ tests/gym/tools/test_pdb.py | 6 ++-- 4 files changed, 24 insertions(+), 46 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 16df0e87..6b463076 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 7.0.0 hooks: - id: isort args: ["--profile", "black", "--filter-files"] - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 26.1.0 hooks: - id: black args: ["--line-length", "88", "--force-exclude", "data/*"] \ No newline at end of file diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index c2de0ebe..97845f3b 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -211,10 +211,8 @@ def __init__( # Runtime generation kwargs from experiment config (temperature, max_tokens, etc.) self.runtime_generate_kwargs = runtime_generate_kwargs or {} - self.logger.debug( - f"Using {self.model_name} with max context length of { - self.context_length:,} tokens." - ) + self.logger.debug(f"Using {self.model_name} with max context length of { + self.context_length:,} tokens.") @classmethod def instantiate( diff --git a/tests/gym/tools/test_grep.py b/tests/gym/tools/test_grep.py index b594bd6f..9107ba94 100644 --- a/tests/gym/tools/test_grep.py +++ b/tests/gym/tools/test_grep.py @@ -19,8 +19,7 @@ def _setup_grep_test_repo(base_dir): # Python files with various content with (working_dir / "main.py").open("w") as f: - f.write( - """#!/usr/bin/env python3 + f.write("""#!/usr/bin/env python3 import os import sys @@ -38,12 +37,10 @@ def method_with_bug(self): if __name__ == "__main__": hello_world() -""" - ) +""") with (working_dir / "src" / "utils.py").open("w") as f: - f.write( - """import re + f.write("""import re import json def validate_email(email): @@ -61,12 +58,10 @@ def __init__(self): def validate(self, email): return re.match(self.pattern, email) is not None -""" - ) +""") with (working_dir / "tests" / "test_utils.py").open("w") as f: - f.write( - """import pytest + f.write("""import pytest from src.utils import validate_email, EmailValidator def test_validate_email(): @@ -82,33 +77,27 @@ def test_email_validator_class(): def test_broken_function(): # This test needs to be fixed assert False # This should pass -""" - ) +""") # Configuration files with (working_dir / "config.json").open("w") as f: - f.write( - """{ + f.write("""{ "name": "test_project", "version": "1.0.0", "debug": true, "database_url": "sqlite:///test.db" -}""" - ) +}""") with (working_dir / "requirements.txt").open("w") as f: - f.write( - """pytest>=6.0.0 + f.write("""pytest>=6.0.0 requests>=2.25.0 flask>=2.0.0 sqlalchemy>=1.4.0 -""" - ) +""") # Documentation with (working_dir / "README.md").open("w") as f: - f.write( - """# Test Project + f.write("""# Test Project This is a test project for grep functionality. @@ -126,12 +115,10 @@ def test_broken_function(): ```bash pip install -r requirements.txt ``` -""" - ) +""") with (working_dir / "docs" / "api.md").open("w") as f: - f.write( - """# API Documentation + f.write("""# API Documentation ## EmailValidator Class @@ -153,8 +140,7 @@ def test_broken_function(): validator = EmailValidator() result = validator.validate("user@example.com") ``` -""" - ) +""") # Binary file (should be ignored) with (working_dir / "binary.bin").open("wb") as f: @@ -162,19 +148,16 @@ def test_broken_function(): # Log file with (working_dir / "app.log").open("w") as f: - f.write( - """2024-01-01 10:00:00 INFO Starting application + f.write("""2024-01-01 10:00:00 INFO Starting application 2024-01-01 10:00:01 DEBUG Loading configuration 2024-01-01 10:00:02 ERROR Failed to connect to database 2024-01-01 10:00:03 WARNING Retrying connection 2024-01-01 10:00:04 INFO Application started successfully -""" - ) +""") # Hidden files with (working_dir / ".gitignore").open("w") as f: - f.write( - """__pycache__/ + f.write("""__pycache__/ *.pyc *.pyo *.pyd @@ -182,8 +165,7 @@ def test_broken_function(): .venv venv/ env/ -""" - ) +""") return working_dir diff --git a/tests/gym/tools/test_pdb.py b/tests/gym/tools/test_pdb.py index 25d1fbfb..72c57163 100644 --- a/tests/gym/tools/test_pdb.py +++ b/tests/gym/tools/test_pdb.py @@ -852,8 +852,7 @@ def test_pdb_changing_entrypoint(tmp_path, setup_pdb_repo_env): # Create a simple Python script to debug with (wd / "simple_script.py").open("w") as f: - f.write( - """ + f.write(""" def main(): x = 42 print(f"Value is {x}") @@ -861,8 +860,7 @@ def main(): if __name__ == "__main__": main() -""" - ) +""") # Use entrypoint to debug the simple script instead of pytest script_entrypoint = "python -m pdb simple_script.py" From 0a68120f5157f94576da68343de4b5f21d512768 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Wed, 21 Jan 2026 10:32:56 -0800 Subject: [PATCH 06/10] Remove unused mock_reset function from TestBaseAgentRunReplayActions --- tests/agents/test_base_agent.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py index 92762c92..9183e9bb 100644 --- a/tests/agents/test_base_agent.py +++ b/tests/agents/test_base_agent.py @@ -530,9 +530,6 @@ def test_run_uses_replay_actions_instead_of_step( # Set up env to terminate after 2 steps step_count = {"count": 0} - def mock_reset(): - return mock_env_info - def mock_step(tool, response, reasoning): step_count["count"] += 1 info = MagicMock() @@ -584,9 +581,6 @@ def test_run_switches_to_step_after_replay_exhausted( step_count = {"count": 0} - def mock_reset(): - return mock_env_info - def mock_env_step(tool, response, reasoning): step_count["count"] += 1 info = MagicMock() From 8e03eadf3c1e7bbe44dc2b338739e30bccd7892e Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Fri, 23 Jan 2026 08:46:36 -0800 Subject: [PATCH 07/10] Enhance error handling in save_patch and save_trajectory functions; ensure trajectories and patches are saved on errors --- debug_gym/agents/utils.py | 32 +++++++++++++++++++------------- scripts/run.py | 35 ++++++++--------------------------- 2 files changed, 27 insertions(+), 40 deletions(-) diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py index 18205abd..b0070098 100644 --- a/debug_gym/agents/utils.py +++ b/debug_gym/agents/utils.py @@ -138,23 +138,29 @@ def load_config(args=None): def save_patch(env, problem_path: Path, logger: DebugGymLogger): """Persist the current environment patch to disk.""" - problem_path.mkdir(parents=True, exist_ok=True) - patch_path = problem_path / "debug_gym.patch" - with open(patch_path, "w") as f: - f.write(env.patch) - - logger.debug(f"Patch saved in {patch_path}") + try: + problem_path.mkdir(parents=True, exist_ok=True) + patch_path = problem_path / "debug_gym.patch" + with open(patch_path, "w") as f: + f.write(env.patch) + logger.debug(f"Patch saved in {patch_path}") + except Exception as patch_error: + # Terminal may be unavailable (e.g., pod died), log and continue + logger.warning(f"Could not save patch: {patch_error!r}") def save_trajectory(agent, problem_path: Path, logger: DebugGymLogger): """Persist the agent trajectory to disk.""" - problem_path.mkdir(parents=True, exist_ok=True) - trajectory = agent.build_trajectory() - json_file = problem_path / "trajectory.json" - with open(json_file, "w") as f: - json.dump(trajectory, f, indent=4) - - logger.debug(f"Trajectory saved in {json_file}") + try: + problem_path.mkdir(parents=True, exist_ok=True) + trajectory = agent.build_trajectory() + json_file = problem_path / "trajectory.json" + with open(json_file, "w") as f: + json.dump(trajectory, f, indent=4) + logger.debug(f"Trajectory saved in {json_file}") + except Exception as save_error: + logger.error(f"Could not save trajectory for replay: {save_error!r}") + raise def load_trajectory( diff --git a/scripts/run.py b/scripts/run.py index be0a9597..19f89eae 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -120,30 +120,19 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): ) break # Exit retry loop except (UnrecoverableTerminalError, AgentTimeoutException) as e: - # Save trajectory before retry so we can replay actions - try: - save_trajectory(agent, task_path, task_logger) - except Exception as save_error: - task_logger.error( - f"Could not save trajectory for replay: {save_error!r}" - ) - raise - - task_logger.warning( - f"Terminal lost (attempt {attempt + 1}/{max_retries}): {e}" - ) # Close the failed environment if env is not None: env.close() env = None if attempt < max_retries - 1: + # Save trajectory before retry so we can replay actions + save_trajectory(agent, task_path, task_logger) task_logger.info(f"Retrying task {task_name}...") else: task_logger.error( f"Task {task_name} failed after {max_retries} attempts." ) - task_logger.report_progress( problem_id=task_name, step=1, @@ -166,18 +155,6 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): ) success = False raise - - # save trajectory - save_trajectory(agent, task_path, task_logger) - - # optionally apply patch - if config.get("save_patch", True): - try: - save_patch(env, task_path, task_logger) - except Exception as patch_error: - # Terminal may be unavailable (e.g., pod died), log and continue - task_logger.warning(f"Could not save patch: {patch_error!r}") - except Exception as e: task_logger.error( f"Task Error: {task_name} - {e!r}. Run with --very-verbose " @@ -199,10 +176,14 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): success = False finally: - # Close env and cancel any pending alarm - signal.alarm(0) + # Save trajectory and patch, close env and cancel any pending alarm + if agent is not None: + save_trajectory(agent, task_path, task_logger) if env: + if config.get("save_patch", True): # optionally apply patch + save_patch(env, task_path, task_logger) env.close() + signal.alarm(0) return success From 31775084ce154da1519f0c90cf21585609f24e16 Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Fri, 23 Jan 2026 08:50:47 -0800 Subject: [PATCH 08/10] Add tests for save_trajectory function --- tests/agents/test_utils.py | 87 +++++++++++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/tests/agents/test_utils.py b/tests/agents/test_utils.py index 9934ddc6..bbf849e8 100644 --- a/tests/agents/test_utils.py +++ b/tests/agents/test_utils.py @@ -4,7 +4,7 @@ import pytest -from debug_gym.agents.utils import load_config, load_trajectory +from debug_gym.agents.utils import load_config, load_trajectory, save_trajectory from debug_gym.llms.base import LLMResponse @@ -86,6 +86,91 @@ def mock_logger(): return logger +@pytest.fixture +def mock_agent(): + """Create a mock agent with a trajectory for testing.""" + agent = MagicMock() + agent.build_trajectory.return_value = { + "log": [ + {"step_id": 0, "action": None, "content": "initial state"}, + { + "step_id": 1, + "action": { + "id": "call_1", + "name": "view", + "arguments": {"file": "test.py"}, + }, + "content": "viewed file", + "reasoning": "need to see the code", + "prompt_response_pairs": [ + { + "prompt": [{"role": "user", "content": "test"}], + "response": "viewed file", + "reasoning_response": "need to see the code", + "token_usage": {"prompt": 100, "response": 50}, + } + ], + }, + ] + } + return agent + + +class TestSaveTrajectory: + """Tests for the save_trajectory function.""" + + def test_save_trajectory_creates_directory(self, mock_agent, mock_logger, tmp_path): + """Test that save_trajectory creates the directory if it doesn't exist.""" + problem_path = tmp_path / "new_dir" / "problem1" + save_trajectory(mock_agent, problem_path, mock_logger) + + assert problem_path.exists() + assert (problem_path / "trajectory.json").exists() + + def test_save_trajectory_writes_json(self, mock_agent, mock_logger, tmp_path): + """Test that save_trajectory writes valid JSON to disk.""" + problem_path = tmp_path / "problem1" + save_trajectory(mock_agent, problem_path, mock_logger) + + json_file = problem_path / "trajectory.json" + with open(json_file) as f: + data = json.load(f) + + assert "log" in data + assert len(data["log"]) == 2 + + def test_save_trajectory_raises_on_error(self, mock_agent, mock_logger, tmp_path): + """Test that save_trajectory raises and logs error on failure.""" + mock_agent.build_trajectory.side_effect = RuntimeError("test error") + problem_path = tmp_path / "problem1" + + with pytest.raises(RuntimeError): + save_trajectory(mock_agent, problem_path, mock_logger) + + mock_logger.error.assert_called_once() + + def test_save_trajectory_overwrites_existing( + self, mock_agent, mock_logger, tmp_path + ): + """Test that save_trajectory overwrites an existing trajectory file.""" + problem_path = tmp_path / "problem1" + problem_path.mkdir(parents=True) + + # Write initial file + json_file = problem_path / "trajectory.json" + with open(json_file, "w") as f: + json.dump({"log": [{"step_id": 0}]}, f) + + # Save new trajectory + save_trajectory(mock_agent, problem_path, mock_logger) + + with open(json_file) as f: + data = json.load(f) + + # Should have the new trajectory with 2 steps + assert len(data["log"]) == 2 + + class TestLoadTrajectory: """Tests for the load_trajectory function.""" From 42fac59a2dd1ae17e006d200a8594940fe1ad31c Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Fri, 23 Jan 2026 10:49:07 -0800 Subject: [PATCH 09/10] Implement centralized error handling in RepoEnv class to improve error management and logging --- debug_gym/gym/envs/env.py | 82 ++++++++++++++++++++++++--------------- 1 file changed, 51 insertions(+), 31 deletions(-) diff --git a/debug_gym/gym/envs/env.py b/debug_gym/gym/envs/env.py index dab34d04..05c8adee 100644 --- a/debug_gym/gym/envs/env.py +++ b/debug_gym/gym/envs/env.py @@ -415,6 +415,38 @@ def apply_gold_patch(self): f"apply_gold_patch is not implemented for {self.__class__.__name__}." ) + def _handle_error( + self, + exception: BaseException, + message: str, + action_tool_call: ToolCall, + action_content: str | None, + action_reasoning: str | None, + ) -> None: + """Handle errors by setting up the environment state and attaching env_info.""" + self.logger.error(message, exc_info=True) + self.step_observation = Observation("env", message) + self.terminated = True + self.all_observations = [self.step_observation] + self.infos = EnvInfo( + step_observation=self.step_observation, + all_observations=self.all_observations, + eval_observation=( + Observation("env", self.last_eval.output) if self.last_eval else None + ), + current_breakpoints=self.current_breakpoints(), + action_reasoning=action_reasoning, + action_content=action_content, + action_tool_call=action_tool_call, + instructions=self.instructions, + score=self.score, + max_score=self.max_score, + terminated=self.terminated, + resolved=self.resolved, + tools=self.tools, + ) + exception.env_info = self.infos + def step( self, action_tool_call: ToolCall, @@ -433,43 +465,31 @@ def step( try: # tool_kwargs is a dict, so we need to unpack it self.step_observation = triggered_tool(self, **tool_kwargs) - except KeyboardInterrupt: - self.logger.error("Step was interrupted by user.") + except KeyboardInterrupt as e: + error_message = "Step was interrupted by user." + self._handle_error( + e, + error_message, + action_tool_call, + action_content, + action_reasoning, + ) raise except UnrecoverableTerminalError as e: - fatal_message = ( - "Fatal terminal error detected. The remote execution pod is no longer " - "available, so the episode will terminate." + error_message = ( + "Fatal terminal error detected. The remote execution " + "pod is no longer available, so the episode will terminate." ) details = str(e).strip() if details: - fatal_message += f"\n{details}" - self.logger.error(fatal_message, exc_info=True) - self.step_observation = Observation("env", fatal_message) - self.terminated = True - # Return early to avoid overwriting terminated flag from last_eval - self.all_observations = [self.step_observation] - self.infos = EnvInfo( - step_observation=self.step_observation, - all_observations=self.all_observations, - eval_observation=( - Observation("env", self.last_eval.output) - if self.last_eval - else None - ), - current_breakpoints=self.current_breakpoints(), - action_reasoning=action_reasoning, - action_content=action_content, - action_tool_call=action_tool_call, - instructions=self.instructions, - score=self.score, - max_score=self.max_score, - terminated=self.terminated, - resolved=self.resolved, - tools=self.tools, + error_message += f"\n{details}" + self._handle_error( + e, + error_message, + action_tool_call, + action_content, + action_reasoning, ) - # Attach env_info to exception and re-raise to allow retry logic - e.env_info = self.infos raise except BaseException as e: error_message = ( From 583abadcb5050bf6fe362300879557aa80b4547d Mon Sep 17 00:00:00 2001 From: Matheus Pereira Date: Fri, 23 Jan 2026 10:51:32 -0800 Subject: [PATCH 10/10] Rename error handling method to clarify its purpose --- debug_gym/gym/envs/env.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/debug_gym/gym/envs/env.py b/debug_gym/gym/envs/env.py index 05c8adee..aff3e06f 100644 --- a/debug_gym/gym/envs/env.py +++ b/debug_gym/gym/envs/env.py @@ -415,7 +415,7 @@ def apply_gold_patch(self): f"apply_gold_patch is not implemented for {self.__class__.__name__}." ) - def _handle_error( + def _handle_fatal_error( self, exception: BaseException, message: str, @@ -423,7 +423,7 @@ def _handle_error( action_content: str | None, action_reasoning: str | None, ) -> None: - """Handle errors by setting up the environment state and attaching env_info.""" + """Handle fatal errors by setting up the environment state and attaching env_info.""" self.logger.error(message, exc_info=True) self.step_observation = Observation("env", message) self.terminated = True @@ -445,7 +445,9 @@ def _handle_error( resolved=self.resolved, tools=self.tools, ) + # Attach env_info to exception and re-raise to allow retry logic exception.env_info = self.infos + raise def step( self, @@ -467,14 +469,13 @@ def step( self.step_observation = triggered_tool(self, **tool_kwargs) except KeyboardInterrupt as e: error_message = "Step was interrupted by user." - self._handle_error( + self._handle_fatal_error( e, error_message, action_tool_call, action_content, action_reasoning, ) - raise except UnrecoverableTerminalError as e: error_message = ( "Fatal terminal error detected. The remote execution " @@ -483,14 +484,13 @@ def step( details = str(e).strip() if details: error_message += f"\n{details}" - self._handle_error( + self._handle_fatal_error( e, error_message, action_tool_call, action_content, action_reasoning, ) - raise except BaseException as e: error_message = ( f"Error while using tool {triggered_tool.name} "