diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4c44ee12..6b463076 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pycqa/isort - rev: 5.13.2 + rev: 7.0.0 hooks: - id: isort args: ["--profile", "black", "--filter-files"] diff --git a/debug_gym/agents/base_agent.py b/debug_gym/agents/base_agent.py index c64ef1ce..d32a70a9 100644 --- a/debug_gym/agents/base_agent.py +++ b/debug_gym/agents/base_agent.py @@ -2,12 +2,13 @@ import os import uuid from dataclasses import MISSING, asdict, dataclass, field, fields -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List from jinja2 import Environment, FileSystemLoader, Template from debug_gym.agents.history_tracker import HistoryTracker from debug_gym.gym.envs.env import EnvInfo, RepoEnv +from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError from debug_gym.gym.utils import filter_non_utf8 from debug_gym.llms.base import LLM, LLMResponse from debug_gym.llms.utils import trim @@ -293,11 +294,17 @@ def step(self, info: EnvInfo) -> LLMResponse | List[LLMResponse]: return self.llm(messages, info.tools) def execute_action(self, llm_response: LLMResponse | List[LLMResponse]) -> EnvInfo: - next_info = self.env.step( - llm_response.tool, - llm_response.response, - llm_response.reasoning_response, - ) + try: + next_info = self.env.step( + llm_response.tool, + llm_response.response, + llm_response.reasoning_response, + ) + except UnrecoverableTerminalError as e: + # Record the failed step in history if env_info is available + if e.env_info is not None: + self.history.step(e.env_info, llm_response) + raise self.history.step(next_info, llm_response) return next_info @@ -324,6 +331,7 @@ def run( env: RepoEnv, debug: bool = False, reset_env: bool = True, + replay_actions: list = None, ) -> Dict[str, Any]: """Run the agent loop until termination or max steps. @@ -331,81 +339,88 @@ def run( env: The environment to interact with. debug: Whether to drop into debugger after each LLM call. reset_env: Whether to reset the environment (default True). + replay_actions: List of LLMResponse objects to replay before continuing + with new LLM calls. Used for retry after terminal failures. Returns: The trajectory as a JSON-serializable dict. 
""" info = None step = 0 + replay_actions = replay_actions or [] + replay_index = 0 # assign the env self.env = env - try: - if reset_env: - info = env.reset() - else: - info = env.info - - self.init(info) - - if info.resolved: - self.logger.report_progress( - problem_id=env.task_name, - step=0, - total_steps=self.args.max_steps, - score=info.score, - max_score=info.max_score, - status="resolved", - ) - return self.build_trajectory() + if reset_env: + info = env.reset() + else: + info = env.info - highscore = info.score - should_stop = False - step = 1 + self.init(info) - while not should_stop: - self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n") + if info.resolved: + self.logger.report_progress( + problem_id=env.task_name, + step=0, + total_steps=self.args.max_steps, + score=info.score, + max_score=info.max_score, + status="resolved", + ) + return self.build_trajectory() + highscore = info.score + should_stop = False + step = 1 + + while not should_stop: + self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n") + + # Check if we should replay a previous action or generate a new one + if replay_index < len(replay_actions): + agent_response = replay_actions[replay_index] + replay_index += 1 + # Log replay details similar to replay.py + self.logger.info( + f"[REPLAY] Replaying step {replay_index} from previous attempt:\n" + f" Tool: {agent_response.tool.name}\n" + f" Args: {agent_response.tool.arguments}\n" + f" Reasoning: {agent_response.reasoning_response[:100] if agent_response.reasoning_response else None}...\n" + f" Content: {agent_response.response[:100] if agent_response.response else None}..." + ) + else: agent_response = self.step(info) - info = self.execute_action(agent_response) - if debug: - breakpoint() + info = self.execute_action(agent_response) - should_stop, reason = self.should_stop(step + 1, info) - status = ( - "resolved" - if info.resolved - else ("unresolved" if should_stop else "running") - ) + if debug: + breakpoint() + + should_stop, reason = self.should_stop(step + 1, info) + status = ( + "resolved" + if info.resolved + else ("unresolved" if should_stop else "running") + ) + + highscore = max(highscore, info.score) + msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]" + if should_stop: + msg += f" | Stopping Reason: {reason}" + self.logger.info(msg) + step += 1 - highscore = max(highscore, info.score) - msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]" - if should_stop: - msg += f" | Stopping Reason: {reason}" - self.logger.info(msg) - step += 1 - - self.logger.report_progress( - problem_id=env.task_name, - step=step, - total_steps=self.args.max_steps, - score=info.score, - max_score=info.max_score, - status=status, - ) - return self.build_trajectory() - except Exception as e: self.logger.report_progress( problem_id=env.task_name, step=step, - total_steps=step, - score=getattr(info, "score", 0), - max_score=getattr(info, "max_score", None), - status="error", + total_steps=self.args.max_steps, + score=info.score, + max_score=info.max_score, + status=status, ) - raise e + return self.build_trajectory() def create_agent(config: Dict[str, Any], **kwargs) -> BaseAgent: diff --git a/debug_gym/agents/utils.py b/debug_gym/agents/utils.py index 1c5446a2..b0070098 100644 --- a/debug_gym/agents/utils.py +++ b/debug_gym/agents/utils.py @@ -6,6 +6,8 @@ import yaml +from debug_gym.gym.tools.tool import ToolCall +from debug_gym.llms.base import 
LLMResponse from debug_gym.logger import DebugGymLogger @@ -90,6 +92,15 @@ def load_config(args=None): default=20, help="Maximum number of tasks to display in the progress bar.", ) + parser.add_argument( + "--max-retries", + type=int, + default=3, + help="Maximum number of retries for terminal or timeout failures" + " (e.g., spot instance eviction, container deletion). " + "Steps are replayed, including the step that failed, " + "after which new steps are generated. Default: 3.", + ) parser.add_argument( "-p", "--params", @@ -127,20 +138,100 @@ def load_config(args=None): def save_patch(env, problem_path: Path, logger: DebugGymLogger): """Persist the current environment patch to disk.""" - problem_path.mkdir(parents=True, exist_ok=True) - patch_path = problem_path / "debug_gym.patch" - with open(patch_path, "w") as f: - f.write(env.patch) - - logger.debug(f"Patch saved in {patch_path}") + try: + problem_path.mkdir(parents=True, exist_ok=True) + patch_path = problem_path / "debug_gym.patch" + with open(patch_path, "w") as f: + f.write(env.patch) + logger.debug(f"Patch saved in {patch_path}") + except Exception as patch_error: + # Terminal may be unavailable (e.g., pod died), log and continue + logger.warning(f"Could not save patch: {patch_error!r}") def save_trajectory(agent, problem_path: Path, logger: DebugGymLogger): """Persist the agent trajectory to disk.""" - problem_path.mkdir(parents=True, exist_ok=True) - trajectory = agent.build_trajectory() + try: + problem_path.mkdir(parents=True, exist_ok=True) + trajectory = agent.build_trajectory() + json_file = problem_path / "trajectory.json" + with open(json_file, "w") as f: + json.dump(trajectory, f, indent=4) + logger.debug(f"Trajectory saved in {json_file}") + except Exception as save_error: + logger.error(f"Could not save trajectory for replay: {save_error!r}") + raise + + +def load_trajectory( + problem_path: Path, logger: DebugGymLogger +) -> list[LLMResponse] | None: + """Load a previous trajectory and reconstruct LLMResponse objects for replay. + + Follows the same approach as replay.py for accurate reconstruction of LLMResponse + objects, including token counts and original prompt data from prompt_response_pairs. + + Returns a list of LLMResponse objects that can be passed to agent.execute_action(), + or None if no trajectory exists. 
+ """ json_file = problem_path / "trajectory.json" - with open(json_file, "w") as f: - json.dump(trajectory, f, indent=4) + if not json_file.exists(): + return None + + try: + with open(json_file, "r") as f: + trajectory = json.load(f) + except (json.JSONDecodeError, IOError) as e: + logger.warning(f"Failed to load trajectory from {json_file}: {e}") + return None + + log = trajectory.get("log", []) + if not log: + return None + + llm_responses = [] + for step in log: + # Skip step 0 (initial state with no action) + if step.get("step_id") == 0 or step.get("action") is None: + continue + + # Reconstruct ToolCall from saved action + action_data = step.get("action", {}) + if not action_data: + continue + + tool_call = ToolCall( + id=action_data.get("id", ""), + name=action_data.get("name", ""), + arguments=action_data.get("arguments", {}), + ) + + # Extract data from prompt_response_pairs if available (like replay.py does) + prompt_response_pairs = step.get("prompt_response_pairs", []) + if prompt_response_pairs and len(prompt_response_pairs) > 0: + prompt_response = prompt_response_pairs[0] + token_usage = prompt_response.get("token_usage", {}) + llm_response = LLMResponse( + prompt=prompt_response.get("prompt", []), + response=prompt_response.get("response"), + reasoning_response=prompt_response.get("reasoning_response"), + tool=tool_call, + prompt_token_count=token_usage.get("prompt", 0), + response_token_count=token_usage.get("response", 0), + ) + else: + # Fallback to step-level data if prompt_response_pairs not available + llm_response = LLMResponse( + prompt=[], + response=step.get("content"), + reasoning_response=step.get("reasoning"), + tool=tool_call, + ) + llm_responses.append(llm_response) + + if llm_responses: + logger.info( + f"Loaded {len(llm_responses)} steps from previous trajectory for replay" + ) - logger.debug(f"Trajectory saved in {json_file}") + return llm_responses if llm_responses else None diff --git a/debug_gym/gym/envs/env.py b/debug_gym/gym/envs/env.py index 7c67dd84..aff3e06f 100644 --- a/debug_gym/gym/envs/env.py +++ b/debug_gym/gym/envs/env.py @@ -51,7 +51,7 @@ def __str__(self) -> str: lines.append("") # Observations section - lines.append(f"👁️ Observation:") + lines.append("👁️ Observation:") lines.append(f"```\n{self.step_observation}\n```") lines.append("") @@ -415,6 +415,40 @@ def apply_gold_patch(self): f"apply_gold_patch is not implemented for {self.__class__.__name__}." 
) + def _handle_fatal_error( + self, + exception: BaseException, + message: str, + action_tool_call: ToolCall, + action_content: str | None, + action_reasoning: str | None, + ) -> None: + """Handle fatal errors by setting up the environment state and attaching env_info.""" + self.logger.error(message, exc_info=True) + self.step_observation = Observation("env", message) + self.terminated = True + self.all_observations = [self.step_observation] + self.infos = EnvInfo( + step_observation=self.step_observation, + all_observations=self.all_observations, + eval_observation=( + Observation("env", self.last_eval.output) if self.last_eval else None + ), + current_breakpoints=self.current_breakpoints(), + action_reasoning=action_reasoning, + action_content=action_content, + action_tool_call=action_tool_call, + instructions=self.instructions, + score=self.score, + max_score=self.max_score, + terminated=self.terminated, + resolved=self.resolved, + tools=self.tools, + ) + # Attach env_info to exception and re-raise to allow retry logic + exception.env_info = self.infos + raise + def step( self, action_tool_call: ToolCall, @@ -433,42 +467,30 @@ def step( try: # tool_kwargs is a dict, so we need to unpack it self.step_observation = triggered_tool(self, **tool_kwargs) - except KeyboardInterrupt: - self.logger.error("Step was interrupted by user.") - raise + except KeyboardInterrupt as e: + error_message = "Step was interrupted by user." + self._handle_fatal_error( + e, + error_message, + action_tool_call, + action_content, + action_reasoning, + ) except UnrecoverableTerminalError as e: - fatal_message = ( - "Fatal terminal error detected. The remote execution pod is no longer " - "available, so the episode will terminate." + error_message = ( + "Fatal terminal error detected. The remote execution " + "pod is no longer available, so the episode will terminate." 
) details = str(e).strip() if details: - fatal_message += f"\n{details}" - self.logger.error(fatal_message, exc_info=True) - self.step_observation = Observation("env", fatal_message) - self.terminated = True - # Return early to avoid overwriting terminated flag from last_eval - self.all_observations = [self.step_observation] - self.infos = EnvInfo( - step_observation=self.step_observation, - all_observations=self.all_observations, - eval_observation=( - Observation("env", self.last_eval.output) - if self.last_eval - else None - ), - current_breakpoints=self.current_breakpoints(), - action_reasoning=action_reasoning, - action_content=action_content, - action_tool_call=action_tool_call, - instructions=self.instructions, - score=self.score, - max_score=self.max_score, - terminated=self.terminated, - resolved=self.resolved, - tools=self.tools, + error_message += f"\n{details}" + self._handle_fatal_error( + e, + error_message, + action_tool_call, + action_content, + action_reasoning, ) - return self.infos except BaseException as e: error_message = ( f"Error while using tool {triggered_tool.name} " diff --git a/debug_gym/gym/terminals/terminal.py b/debug_gym/gym/terminals/terminal.py index c5c83dc0..3008f665 100644 --- a/debug_gym/gym/terminals/terminal.py +++ b/debug_gym/gym/terminals/terminal.py @@ -14,6 +14,10 @@ class TerminalError(RuntimeError): class UnrecoverableTerminalError(TerminalError): """Raised when the terminal becomes unusable and the episode must stop.""" + def __init__(self, message: str, env_info=None): + super().__init__(message) + self.env_info = env_info + DISABLE_ECHO_COMMAND = "stty -echo" diff --git a/debug_gym/logger.py b/debug_gym/logger.py index 5b2aeb07..c14a6a81 100644 --- a/debug_gym/logger.py +++ b/debug_gym/logger.py @@ -395,7 +395,7 @@ def __init__( expand=True, ) self.total = len(problems) - self.completed = 0 + self._completed_task_ids = set() self._overall_task = self.overall_progress.add_task( "Overall", # Placeholder description, will be set by _refresh total=self.total, @@ -424,6 +424,10 @@ def __init__( ) self._listener_thread.start() + @property + def completed(self) -> int: + return len(self._completed_task_ids) + def advance(self, progress_update: TaskProgress): """Advance the progress for a specific task based on the provided update. Sets task as completed if its status is completed (e.g. 
early stopping).""" @@ -434,7 +438,8 @@ def advance(self, progress_update: TaskProgress): # Update the task progress self.tasks_progress.advance(progress_update) # Update overall progress completion - self.completed += 1 if progress_update.completed else 0 + if progress_update.completed: + self._completed_task_ids.add(progress_update.problem_id) def close(self): """Stop the listener thread and wait until it exits.""" diff --git a/scripts/run.py b/scripts/run.py index 32fbb730..19f89eae 100644 --- a/scripts/run.py +++ b/scripts/run.py @@ -7,9 +7,15 @@ from pathlib import Path from debug_gym.agents.base_agent import AGENT_REGISTRY, create_agent -from debug_gym.agents.utils import load_config, save_patch, save_trajectory +from debug_gym.agents.utils import ( + load_config, + load_trajectory, + save_patch, + save_trajectory, +) from debug_gym.experiment import create_env, dump_experiment_info from debug_gym.gym.envs import load_dataset +from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError from debug_gym.llms.base import LLM from debug_gym.llms.human import Human from debug_gym.logger import DebugGymLogger, load_previous_run_status @@ -41,10 +47,6 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): success = True env = None - # Flag to not report errors from the agent, since they report - # errors themselves and we want to avoid double reporting. - report_progress_error = True - task_path = Path(config["output_path"]) / task_name task_logger = DebugGymLogger( @@ -76,6 +78,8 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): task_logger.debug(f"Skipping {task_name}, already done.") return success + max_retries = args.max_retries + task_logger.report_progress( problem_id=task_name, step=0, @@ -85,54 +89,72 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): status="running", ) - env = create_env(config, task_data, task_logger) - llm = LLM.instantiate(**config.get("llm", {}), logger=task_logger) - agent = create_agent(config.get("agent", {}), llm=llm, logger=task_logger) - - try: - success = agent.run(env, debug=args.debug) - except KeyboardInterrupt: - task_logger.error("Agent run was interrupted by user.") - task_logger.report_progress( - problem_id=task_name, - step=1, - total_steps=1, - score=0, - max_score=None, - status="error", - ) - success = False - raise - except AgentTimeoutException: - task_logger.error( - f"Timeout: Problem `{task_name}` exceeded " - f"the time limit of {args.timeout} seconds." 
- ) - task_logger.report_progress( - problem_id=task_name, - step=1, - total_steps=1, - score=0, - max_score=None, - status="error", - ) - success = False - raise - except: - report_progress_error = False - raise - - # save trajectory - save_trajectory(agent, task_path, task_logger) - - # optionally apply patch - if config.get("save_patch", True): + # Track actions from previous attempts for replay + replay_actions = None + for attempt in range(max_retries): try: - save_patch(env, task_path, task_logger) - except Exception as patch_error: - # Terminal may be unavailable (e.g., pod died), log and continue - task_logger.warning(f"Could not save patch: {patch_error!r}") + # Load actions from previous attempt for replay on retry + if attempt > 0: + task_logger.info(f"Replaying actions from attempt {attempt}") + # Load actions from previous attempt for replay + replay_actions = load_trajectory(task_path, task_logger) + task_logger.report_progress( + problem_id=task_name, + step=0, + total_steps=1, + score=0, + max_score=None, + status="running", + ) + + env = create_env(config, task_data, task_logger) + llm = LLM.instantiate(**config.get("llm", {}), logger=task_logger) + agent = create_agent( + config.get("agent", {}), llm=llm, logger=task_logger + ) + success = agent.run( + env, + debug=args.debug, + replay_actions=replay_actions, + ) + break # Exit retry loop + except (UnrecoverableTerminalError, AgentTimeoutException) as e: + # Close the failed environment + if env is not None: + env.close() + env = None + + if attempt < max_retries - 1: + # Save trajectory before retry so we can replay actions + save_trajectory(agent, task_path, task_logger) + task_logger.info(f"Retrying task {task_name}...") + else: + task_logger.error( + f"Task {task_name} failed after {max_retries} attempts." + ) + task_logger.report_progress( + problem_id=task_name, + step=1, + total_steps=1, + score=0, + max_score=None, + status="error", + ) + success = False + raise + except KeyboardInterrupt: + task_logger.error("Agent run was interrupted by user.") + task_logger.report_progress( + problem_id=task_name, + step=1, + total_steps=1, + score=0, + max_score=None, + status="error", + ) + success = False + raise except Exception as e: task_logger.error( f"Task Error: {task_name} - {e!r}. Run with --very-verbose " @@ -141,24 +163,27 @@ def run_agent(args, task_name: str, task_data: dict, config: dict): task_logger.debug( f"Task {task_name} generated an exception: {e!r}. Traceback: {traceback.format_exc()}" ) - if report_progress_error: - task_logger.report_progress( - problem_id=task_name, - step=1, - total_steps=1, - score=0, - max_score=None, - status="error", - ) + task_logger.report_progress( + problem_id=task_name, + step=1, + total_steps=1, + score=0, + max_score=None, + status="error", + ) if args.debug: - raise e + raise success = False finally: - # Close env and cancel any pending alarm - signal.alarm(0) + # Save trajectory and patch, close env and cancel any pending alarm + if agent is not None: + save_trajectory(agent, task_path, task_logger) if env: + if config.get("save_patch", True): # optionally apply patch + save_patch(env, task_path, task_logger) env.close() + signal.alarm(0) return success @@ -177,7 +202,7 @@ def main(): # Load the dataset based on the information found in the config. 
if config.get("task_data") is not None: - dataset = {f"custom-task": config["task_data"]} + dataset = {"custom-task": config["task_data"]} else: dataset = load_dataset(config["dataset"], logger=logger) @@ -222,8 +247,6 @@ def main(): for problem in problems: try: run_agent(args, problem, dataset[problem], config) - except AgentTimeoutException: - pass # Handled in run_agent, just continue except (KeyboardInterrupt, Exception) as e: raise e else: @@ -242,8 +265,6 @@ def main(): try: problem = futures[future] future.result() - except AgentTimeoutException: - pass # Handled in run_agent, just continue except (KeyboardInterrupt, Exception) as e: executor.shutdown(wait=True, cancel_futures=True) raise e diff --git a/tests/agents/test_base_agent.py b/tests/agents/test_base_agent.py index 4257cce2..9183e9bb 100644 --- a/tests/agents/test_base_agent.py +++ b/tests/agents/test_base_agent.py @@ -1,16 +1,18 @@ import json -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, patch import pytest from jinja2 import Template from debug_gym.agents.base_agent import ( AGENT_REGISTRY, - AgentArgs, BaseAgent, create_agent, register_agent, ) +from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError +from debug_gym.gym.tools.tool import ToolCall +from debug_gym.llms.base import LLMResponse from debug_gym.llms.human import Human @@ -335,3 +337,428 @@ def test_load_prompt_template_with_custom_loader_root(tmp_path): assert "=== Explorer ===" in rendered assert "Body content." in rendered + + +class TestExecuteAction: + """Tests for BaseAgent.execute_action method.""" + + @pytest.fixture + def agent_with_mocks(self): + """Create a BaseAgent with mocked env and llm.""" + agent = BaseAgent() + agent.env = MagicMock() + agent.llm = MagicMock() + # Initialize history with mock data + mock_info = MagicMock() + mock_info.instructions = "Test instructions" + agent.history.init( + {"role": "system", "content": "system"}, + {"role": "user", "content": "instance"}, + mock_info, + ) + return agent + + @pytest.fixture + def mock_llm_response(self): + """Create a mock LLMResponse.""" + tool_call = ToolCall( + id="call_123", + name="test_tool", + arguments={"arg1": "value1"}, + ) + return LLMResponse( + prompt=[{"role": "user", "content": "test"}], + response="test response", + reasoning_response="test reasoning", + tool=tool_call, + ) + + def test_execute_action_success(self, agent_with_mocks, mock_llm_response): + """Test that execute_action updates history on successful step.""" + agent = agent_with_mocks + mock_env_info = MagicMock() + agent.env.step.return_value = mock_env_info + + # Initial history should have 1 entry (from init) + initial_history_len = len(agent.history) + + result = agent.execute_action(mock_llm_response) + + assert result == mock_env_info + assert len(agent.history) == initial_history_len + 1 + agent.env.step.assert_called_once_with( + mock_llm_response.tool, + mock_llm_response.response, + mock_llm_response.reasoning_response, + ) + + def test_execute_action_unrecoverable_error_with_env_info( + self, agent_with_mocks, mock_llm_response + ): + """Test that history is updated when UnrecoverableTerminalError has env_info.""" + agent = agent_with_mocks + mock_env_info = MagicMock() + mock_env_info.step_observation = MagicMock() + mock_env_info.step_observation.observation = "error observation" + + error = UnrecoverableTerminalError("Terminal died", env_info=mock_env_info) + agent.env.step.side_effect = error + + initial_history_len = len(agent.history) + + 
with pytest.raises(UnrecoverableTerminalError) as exc_info: + agent.execute_action(mock_llm_response) + + assert exc_info.value is error + # History should be updated with the failed step + assert len(agent.history) == initial_history_len + 1 + # Verify the last llm_response in history is our mock + assert agent.history.llm_responses[-1] == mock_llm_response + # Verify the last env_observation in history is from the error + assert agent.history.env_observations[-1] == mock_env_info + + def test_execute_action_unrecoverable_error_without_env_info( + self, agent_with_mocks, mock_llm_response + ): + """Test that history is NOT updated when UnrecoverableTerminalError has no env_info.""" + agent = agent_with_mocks + + error = UnrecoverableTerminalError("Terminal died", env_info=None) + agent.env.step.side_effect = error + + initial_history_len = len(agent.history) + + with pytest.raises(UnrecoverableTerminalError) as exc_info: + agent.execute_action(mock_llm_response) + + assert exc_info.value is error + # History should NOT be updated since env_info is None + assert len(agent.history) == initial_history_len + + def test_execute_action_history_contains_correct_data( + self, agent_with_mocks, mock_llm_response + ): + """Test that history contains the correct tool call data after execute_action.""" + agent = agent_with_mocks + mock_env_info = MagicMock() + mock_env_info.step_observation = MagicMock() + mock_env_info.step_observation.observation = "tool output" + agent.env.step.return_value = mock_env_info + + agent.execute_action(mock_llm_response) + + # Verify the history contains correct llm_response data + last_llm_response = agent.history.llm_responses[-1] + assert last_llm_response.tool.id == "call_123" + assert last_llm_response.tool.name == "test_tool" + assert last_llm_response.tool.arguments == {"arg1": "value1"} + assert last_llm_response.response == "test response" + assert last_llm_response.reasoning_response == "test reasoning" + + # Verify the history contains correct env_observation + last_env_obs = agent.history.env_observations[-1] + assert last_env_obs == mock_env_info + + +class TestBaseAgentRunReplayActions: + """Tests for replay_actions parameter in BaseAgent.run().""" + + @pytest.fixture + def mock_env(self): + """Create a mock environment.""" + env = MagicMock() + env.task_name = "test_task" + env.tools = [] + env.resolved = False + return env + + @pytest.fixture + def mock_env_info(self): + """Create a mock EnvInfo.""" + info = MagicMock() + info.instructions = "Test instructions" + info.tools = [] + info.resolved = False + info.terminated = False + info.score = 0 + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = "observation" + info.action_tool_call = None + return info + + @pytest.fixture + def mock_llm(self): + """Create a mock LLM.""" + llm = MagicMock() + llm.context_length = 4096 + llm.count_tokens = lambda x: len(str(x)) + llm.define_tools = lambda x: [] + llm.convert_observation_to_message = lambda obs, **kwargs: { + "role": "user", + "content": obs if isinstance(obs, str) else str(obs), + } + llm.convert_response_to_message = lambda resp: { + "role": "assistant", + "content": resp.response, + } + return llm + + def test_run_uses_replay_actions_instead_of_step( + self, mock_env, mock_env_info, mock_llm + ): + """Test that replay_actions are used instead of calling step().""" + agent = BaseAgent(agent_args={"max_steps": 5}) + agent.llm = mock_llm + + # Create replay actions + replay_actions = [ + LLMResponse( + 
prompt=[], + response="replay response 1", + reasoning_response="replay reasoning 1", + tool=ToolCall(id="replay_1", name="tool_1", arguments={}), + ), + LLMResponse( + prompt=[], + response="replay response 2", + reasoning_response="replay reasoning 2", + tool=ToolCall(id="replay_2", name="tool_2", arguments={}), + ), + ] + + # Set up env to terminate after 2 steps + step_count = {"count": 0} + + def mock_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 2 + info.terminated = step_count["count"] >= 2 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_step + mock_env.info = mock_env_info + + # Mock the step method to track if it's called + agent.step = MagicMock() + + agent.run(mock_env, replay_actions=replay_actions) + + # step() should NOT have been called since we had replay actions + agent.step.assert_not_called() + + # env.step should have been called with replay action tools + assert mock_env.step.call_count == 2 + calls = mock_env.step.call_args_list + assert calls[0][0][0].id == "replay_1" + assert calls[1][0][0].id == "replay_2" + + def test_run_switches_to_step_after_replay_exhausted( + self, mock_env, mock_env_info, mock_llm + ): + """Test that run() switches to calling step() after replay_actions are exhausted.""" + agent = BaseAgent(agent_args={"max_steps": 5}) + agent.llm = mock_llm + + # Create only 1 replay action + replay_actions = [ + LLMResponse( + prompt=[], + response="replay response", + reasoning_response="replay reasoning", + tool=ToolCall(id="replay_1", name="replay_tool", arguments={}), + ), + ] + + step_count = {"count": 0} + + def mock_env_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 3 + info.terminated = step_count["count"] >= 3 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_env_step + mock_env.info = mock_env_info + + # Create new LLM responses for step() calls + new_llm_responses = [ + LLMResponse( + prompt=[], + response="new response 1", + reasoning_response="new reasoning 1", + tool=ToolCall(id="new_1", name="new_tool_1", arguments={}), + ), + LLMResponse( + prompt=[], + response="new response 2", + reasoning_response="new reasoning 2", + tool=ToolCall(id="new_2", name="new_tool_2", arguments={}), + ), + ] + agent.step = MagicMock(side_effect=new_llm_responses) + + agent.run(mock_env, replay_actions=replay_actions) + + # step() should have been called 2 times (after replay action exhausted) + assert agent.step.call_count == 2 + + # env.step should have been called 3 times total + assert mock_env.step.call_count == 3 + calls = mock_env.step.call_args_list + # First call: replay action + assert calls[0][0][0].id == "replay_1" + # Second and third calls: from step() + assert calls[1][0][0].id == "new_1" + assert calls[2][0][0].id == "new_2" + + def test_run_with_empty_replay_actions(self, mock_env, mock_env_info, mock_llm): + 
"""Test that empty replay_actions list behaves normally.""" + agent = BaseAgent(agent_args={"max_steps": 3}) + agent.llm = mock_llm + + step_count = {"count": 0} + + def mock_env_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 2 + info.terminated = step_count["count"] >= 2 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_env_step + mock_env.info = mock_env_info + + llm_responses = [ + LLMResponse( + prompt=[], + response="response 1", + reasoning_response="reasoning 1", + tool=ToolCall(id="call_1", name="tool_1", arguments={}), + ), + LLMResponse( + prompt=[], + response="response 2", + reasoning_response="reasoning 2", + tool=ToolCall(id="call_2", name="tool_2", arguments={}), + ), + ] + agent.step = MagicMock(side_effect=llm_responses) + + # Pass empty list + agent.run(mock_env, replay_actions=[]) + + # step() should be called for all actions + assert agent.step.call_count == 2 + + def test_run_with_none_replay_actions(self, mock_env, mock_env_info, mock_llm): + """Test that None replay_actions behaves normally.""" + agent = BaseAgent(agent_args={"max_steps": 3}) + agent.llm = mock_llm + + step_count = {"count": 0} + + def mock_env_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 1 + info.terminated = step_count["count"] >= 1 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_env_step + mock_env.info = mock_env_info + + llm_response = LLMResponse( + prompt=[], + response="response", + reasoning_response="reasoning", + tool=ToolCall(id="call_1", name="tool_1", arguments={}), + ) + agent.step = MagicMock(return_value=llm_response) + + # Pass None (default) + agent.run(mock_env, replay_actions=None) + + # step() should be called + assert agent.step.call_count == 1 + + def test_replay_actions_order_preserved(self, mock_env, mock_env_info, mock_llm): + """Test that replay actions are executed in the correct order.""" + agent = BaseAgent(agent_args={"max_steps": 10}) + agent.llm = mock_llm + + # Create 5 replay actions with distinct IDs + replay_actions = [ + LLMResponse( + prompt=[], + response=f"response {i}", + reasoning_response=f"reasoning {i}", + tool=ToolCall( + id=f"action_{i}", name=f"tool_{i}", arguments={"order": i} + ), + ) + for i in range(5) + ] + + step_count = {"count": 0} + + def mock_env_step(tool, response, reasoning): + step_count["count"] += 1 + info = MagicMock() + info.instructions = "Test" + info.tools = [] + info.resolved = step_count["count"] >= 5 + info.terminated = step_count["count"] >= 5 + info.score = step_count["count"] + info.max_score = 10 + info.step_observation = MagicMock() + info.step_observation.observation = f"observation {step_count['count']}" + info.action_tool_call = tool + return info + + mock_env.reset.return_value = mock_env_info + mock_env.step.side_effect = mock_env_step + mock_env.info = mock_env_info + + 
agent.run(mock_env, replay_actions=replay_actions) + + # Verify all 5 actions were executed in order + assert mock_env.step.call_count == 5 + calls = mock_env.step.call_args_list + for i, call in enumerate(calls): + assert call[0][0].id == f"action_{i}" + assert call[0][0].arguments == {"order": i} diff --git a/tests/agents/test_utils.py b/tests/agents/test_utils.py index d44639b9..bbf849e8 100644 --- a/tests/agents/test_utils.py +++ b/tests/agents/test_utils.py @@ -1,7 +1,11 @@ +import json import logging -from unittest.mock import patch +from unittest.mock import MagicMock -from debug_gym.agents.utils import load_config +import pytest + +from debug_gym.agents.utils import load_config, load_trajectory, save_trajectory +from debug_gym.llms.base import LLMResponse def test_load_config(): @@ -73,3 +77,381 @@ def test_load_config(): assert _config == expected_config assert _args.debug is True assert _args.logging_level == logging.INFO + + +@pytest.fixture +def mock_logger(): + """Create a mock logger for testing.""" + logger = MagicMock() + return logger + + +@pytest.fixture +def mock_agent(): + """Create a mock agent with a trajectory for testing.""" + agent = MagicMock() + agent.build_trajectory.return_value = { + "log": [ + {"step_id": 0, "action": None, "content": "initial state"}, + { + "step_id": 1, + "action": { + "id": "call_1", + "name": "view", + "arguments": {"file": "test.py"}, + }, + "content": "viewed file", + "reasoning": "need to see the code", + "prompt_response_pairs": [ + { + "prompt": [{"role": "user", "content": "test"}], + "response": "viewed file", + "reasoning_response": "need to see the code", + "token_usage": {"prompt": 100, "response": 50}, + } + ], + }, + ] + } + return agent + + +class TestSaveTrajectory: + """Tests for the save_trajectory function.""" + + def test_save_trajectory_creates_directory(self, mock_agent, mock_logger, tmp_path): + """Test that save_trajectory creates the directory if it doesn't exist.""" + problem_path = tmp_path / "new_dir" / "problem1" + save_trajectory(mock_agent, problem_path, mock_logger) + + assert problem_path.exists() + assert (problem_path / "trajectory.json").exists() + + def test_save_trajectory_writes_json(self, mock_agent, mock_logger, tmp_path): + """Test that save_trajectory writes valid JSON to disk.""" + problem_path = tmp_path / "problem1" + save_trajectory(mock_agent, problem_path, mock_logger) + + json_file = problem_path / "trajectory.json" + with open(json_file) as f: + data = json.load(f) + + assert "log" in data + assert len(data["log"]) == 2 + + def test_save_trajectory_raises_on_error(self, mock_agent, mock_logger, tmp_path): + """Test that save_trajectory raises and logs error on failure.""" + mock_agent.build_trajectory.side_effect = RuntimeError("test error") + problem_path = tmp_path / "problem1" + + with pytest.raises(RuntimeError): + save_trajectory(mock_agent, problem_path, mock_logger) + + mock_logger.error.assert_called_once() + + def test_save_trajectory_overwrites_existing( + self, mock_agent, mock_logger, tmp_path + ): + """Test that save_trajectory overwrites an existing trajectory file.""" + problem_path = tmp_path / "problem1" + problem_path.mkdir(parents=True) + + # Write initial file + json_file = problem_path / "trajectory.json" + with open(json_file, "w") as f: + json.dump({"log": [{"step_id": 0}]}, f) + + # Save new trajectory + save_trajectory(mock_agent, problem_path, mock_logger) + + with open(json_file) as f: + data = json.load(f) + + # Should have the new trajectory with 2 steps + 
assert len(data["log"]) == 2 + + +class TestLoadTrajectory: + """Tests for the load_trajectory function.""" + + def test_load_trajectory_no_file(self, tmp_path, mock_logger): + """Test that load_trajectory returns None when trajectory file doesn't exist.""" + result = load_trajectory(tmp_path, mock_logger) + assert result is None + + def test_load_trajectory_empty_log(self, tmp_path, mock_logger): + """Test that load_trajectory returns None when log is empty.""" + trajectory = {"log": []} + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is None + + def test_load_trajectory_no_log_key(self, tmp_path, mock_logger): + """Test that load_trajectory returns None when log key is missing.""" + trajectory = {"other_key": "value"} + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is None + + def test_load_trajectory_skips_step_zero(self, tmp_path, mock_logger): + """Test that load_trajectory skips step 0 (initial state).""" + trajectory = { + "log": [ + {"step_id": 0, "action": None}, + { + "step_id": 1, + "action": { + "id": "call_1", + "name": "test_tool", + "arguments": {"arg1": "value1"}, + }, + "prompt_response_pairs": [ + { + "prompt": [{"role": "user", "content": "test"}], + "response": "test response", + "reasoning_response": None, + "token_usage": {"prompt": 10, "response": 20}, + } + ], + }, + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + assert result[0].tool.name == "test_tool" + + def test_load_trajectory_with_prompt_response_pairs(self, tmp_path, mock_logger): + """Test load_trajectory with full prompt_response_pairs data.""" + trajectory = { + "log": [ + { + "step_id": 1, + "action": { + "id": "call_123", + "name": "edit_file", + "arguments": {"file": "test.py", "content": "print('hello')"}, + }, + "prompt_response_pairs": [ + { + "prompt": [ + { + "role": "system", + "content": "You are a helpful assistant", + }, + {"role": "user", "content": "Edit the file"}, + ], + "response": "I'll edit the file for you", + "reasoning_response": "Thinking about the edit...", + "token_usage": {"prompt": 100, "response": 50}, + } + ], + } + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + + llm_response = result[0] + assert isinstance(llm_response, LLMResponse) + assert llm_response.response == "I'll edit the file for you" + assert llm_response.reasoning_response == "Thinking about the edit..." 
+ assert llm_response.token_usage.prompt == 100 + assert llm_response.token_usage.response == 50 + assert llm_response.tool.id == "call_123" + assert llm_response.tool.name == "edit_file" + assert llm_response.tool.arguments == { + "file": "test.py", + "content": "print('hello')", + } + assert len(llm_response.prompt) == 2 + + def test_load_trajectory_fallback_without_prompt_response_pairs( + self, tmp_path, mock_logger + ): + """Test load_trajectory fallback when prompt_response_pairs is missing.""" + trajectory = { + "log": [ + { + "step_id": 1, + "action": { + "id": "call_456", + "name": "run_command", + "arguments": {"cmd": "ls"}, + }, + "content": "Running the command", + "reasoning": "Need to list files", + } + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + + llm_response = result[0] + assert llm_response.response == "Running the command" + assert llm_response.reasoning_response == "Need to list files" + assert llm_response.prompt == [] + assert llm_response.tool.name == "run_command" + + def test_load_trajectory_multiple_steps(self, tmp_path, mock_logger): + """Test load_trajectory with multiple steps.""" + trajectory = { + "log": [ + {"step_id": 0, "action": None}, + { + "step_id": 1, + "action": { + "id": "call_1", + "name": "tool_1", + "arguments": {}, + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "response 1", + "token_usage": {"prompt": 10, "response": 5}, + } + ], + }, + { + "step_id": 2, + "action": { + "id": "call_2", + "name": "tool_2", + "arguments": {"key": "value"}, + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "response 2", + "token_usage": {"prompt": 20, "response": 10}, + } + ], + }, + { + "step_id": 3, + "action": { + "id": "call_3", + "name": "tool_3", + "arguments": {}, + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "response 3", + "token_usage": {"prompt": 30, "response": 15}, + } + ], + }, + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 3 + assert result[0].tool.name == "tool_1" + assert result[1].tool.name == "tool_2" + assert result[2].tool.name == "tool_3" + mock_logger.info.assert_called_once() + + def test_load_trajectory_invalid_json(self, tmp_path, mock_logger): + """Test that load_trajectory handles invalid JSON gracefully.""" + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + f.write("not valid json {{{") + + result = load_trajectory(tmp_path, mock_logger) + assert result is None + mock_logger.warning.assert_called_once() + + def test_load_trajectory_skips_steps_without_action(self, tmp_path, mock_logger): + """Test that load_trajectory skips steps with no action data.""" + trajectory = { + "log": [ + { + "step_id": 1, + "action": {}, # Empty action + }, + { + "step_id": 2, + # Missing action key entirely + }, + { + "step_id": 3, + "action": { + "id": "call_valid", + "name": "valid_tool", + "arguments": {}, + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "valid response", + "token_usage": {"prompt": 5, "response": 5}, + } + ], + }, + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = 
load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + assert result[0].tool.name == "valid_tool" + + def test_load_trajectory_missing_optional_fields(self, tmp_path, mock_logger): + """Test load_trajectory with missing optional fields in action.""" + trajectory = { + "log": [ + { + "step_id": 1, + "action": { + "name": "minimal_tool", + # Missing id and arguments + }, + "prompt_response_pairs": [ + { + "prompt": [], + "response": "response", + "token_usage": {"prompt": 1, "response": 1}, + } + ], + } + ] + } + trajectory_file = tmp_path / "trajectory.json" + with open(trajectory_file, "w") as f: + json.dump(trajectory, f) + + result = load_trajectory(tmp_path, mock_logger) + assert result is not None + assert len(result) == 1 + assert result[0].tool.id == "" + assert result[0].tool.name == "minimal_tool" + assert result[0].tool.arguments == {} diff --git a/tests/gym/envs/test_unrecoverable_terminal.py b/tests/gym/envs/test_unrecoverable_terminal.py index 459fdfa4..d2a9e800 100644 --- a/tests/gym/envs/test_unrecoverable_terminal.py +++ b/tests/gym/envs/test_unrecoverable_terminal.py @@ -73,7 +73,10 @@ def fatal_env(): def test_env_terminates_after_unrecoverable_terminal_error(fatal_env): tool_call = ToolCall(id="bash-1", name="bash", arguments={"command": "ls"}) - info = fatal_env.step(tool_call) + with pytest.raises(UnrecoverableTerminalError) as exc_info: + fatal_env.step(tool_call) + + info = exc_info.value.env_info assert info.terminated is True assert fatal_env.terminated is True
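The patch above threads retry-with-replay through three layers: UnrecoverableTerminalError now carries the EnvInfo of the failed step, BaseAgent.run() accepts replay_actions and consumes them before asking the LLM for new steps, and scripts/run.py retries each task up to --max-retries times, reloading the saved trajectory between attempts. Below is a minimal sketch of that control flow, assuming only the signatures shown in the diff; run_with_retries, make_env, and make_agent are hypothetical stand-ins for the environment/agent setup code in scripts/run.py, not part of the patch.

from pathlib import Path

from debug_gym.agents.utils import load_trajectory, save_trajectory
from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError


def run_with_retries(make_env, make_agent, task_path: Path, logger, max_retries: int = 3):
    """Simplified sketch of the retry loop added to scripts/run.py (make_env/make_agent are placeholders)."""
    replay_actions = None  # list[LLMResponse] rebuilt from trajectory.json, or None on the first attempt
    for attempt in range(max_retries):
        env = agent = None
        try:
            if attempt > 0:
                # Reload the previous attempt's actions; run() replays them verbatim
                # (including the step that failed) before generating new LLM calls.
                replay_actions = load_trajectory(task_path, logger)
            env = make_env()
            agent = make_agent()
            return agent.run(env, replay_actions=replay_actions)
        except UnrecoverableTerminalError:
            if env is not None:
                env.close()
            if attempt == max_retries - 1:
                raise
            if agent is not None:
                # Persist the partial trajectory so the next attempt can replay it.
                save_trajectory(agent, task_path, logger)
            logger.info(f"Retrying (attempt {attempt + 2} of {max_retries})...")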