2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
+    rev: 7.0.0
     hooks:
       - id: isort
         args: ["--profile", "black", "--filter-files"]
135 changes: 75 additions & 60 deletions debug_gym/agents/base_agent.py
@@ -2,12 +2,13 @@
 import os
 import uuid
 from dataclasses import MISSING, asdict, dataclass, field, fields
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List

 from jinja2 import Environment, FileSystemLoader, Template

 from debug_gym.agents.history_tracker import HistoryTracker
 from debug_gym.gym.envs.env import EnvInfo, RepoEnv
+from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError
 from debug_gym.gym.utils import filter_non_utf8
 from debug_gym.llms.base import LLM, LLMResponse
 from debug_gym.llms.utils import trim
@@ -293,11 +294,17 @@ def step(self, info: EnvInfo) -> LLMResponse | List[LLMResponse]:
         return self.llm(messages, info.tools)

     def execute_action(self, llm_response: LLMResponse | List[LLMResponse]) -> EnvInfo:
-        next_info = self.env.step(
-            llm_response.tool,
-            llm_response.response,
-            llm_response.reasoning_response,
-        )
+        try:
+            next_info = self.env.step(
+                llm_response.tool,
+                llm_response.response,
+                llm_response.reasoning_response,
+            )
+        except UnrecoverableTerminalError as e:
+            # Record the failed step in history if env_info is available
+            if e.env_info is not None:
+                self.history.step(e.env_info, llm_response)
+            raise
         self.history.step(next_info, llm_response)
         return next_info
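For context, not part of this diff: the handler above relies on UnrecoverableTerminalError exposing the EnvInfo observed at the failing step. A minimal sketch of the assumed exception contract (the real class lives in debug_gym/gym/terminals/terminal.py and may differ):

class UnrecoverableTerminalError(Exception):
    """Raised when the terminal backing the environment is lost for good
    (e.g., spot instance eviction, container deletion)."""

    def __init__(self, message: str, env_info=None):
        super().__init__(message)
        # EnvInfo captured at the failing step, if any, so callers can
        # record it in history before re-raising.
        self.env_info = env_info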
@@ -324,88 +331,96 @@ def run(
         env: RepoEnv,
         debug: bool = False,
         reset_env: bool = True,
+        replay_actions: list = None,
     ) -> Dict[str, Any]:
         """Run the agent loop until termination or max steps.

         Args:
             env: The environment to interact with.
             debug: Whether to drop into debugger after each LLM call.
             reset_env: Whether to reset the environment (default True).
+            replay_actions: List of LLMResponse objects to replay before continuing
+                with new LLM calls. Used for retry after terminal failures.

         Returns:
             The trajectory as a JSON-serializable dict.
         """
-        info = None
-        step = 0
+        replay_actions = replay_actions or []
+        replay_index = 0

         # assign the env
         self.env = env

-        try:
-            if reset_env:
-                info = env.reset()
-            else:
-                info = env.info
-
-            self.init(info)
-
-            if info.resolved:
-                self.logger.report_progress(
-                    problem_id=env.task_name,
-                    step=0,
-                    total_steps=self.args.max_steps,
-                    score=info.score,
-                    max_score=info.max_score,
-                    status="resolved",
-                )
-                return self.build_trajectory()
-
-            highscore = info.score
-            should_stop = False
-            step = 1
-
-            while not should_stop:
-                self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n")
-                agent_response = self.step(info)
-                if debug:
-                    breakpoint()
-                info = self.execute_action(agent_response)
-
-                should_stop, reason = self.should_stop(step + 1, info)
-                status = (
-                    "resolved"
-                    if info.resolved
-                    else ("unresolved" if should_stop else "running")
-                )
-
-                highscore = max(highscore, info.score)
-                msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]"
-                if should_stop:
-                    msg += f" | Stopping Reason: {reason}"
-                self.logger.info(msg)
-                step += 1
-
-            self.logger.report_progress(
-                problem_id=env.task_name,
-                step=step,
-                total_steps=self.args.max_steps,
-                score=info.score,
-                max_score=info.max_score,
-                status=status,
-            )
-            return self.build_trajectory()
-        except Exception as e:
-            self.logger.report_progress(
-                problem_id=env.task_name,
-                step=step,
-                total_steps=step,
-                score=getattr(info, "score", 0),
-                max_score=getattr(info, "max_score", None),
-                status="error",
-            )
-            raise e
+        if reset_env:
+            info = env.reset()
+        else:
+            info = env.info
+
+        self.init(info)
+
+        if info.resolved:
+            self.logger.report_progress(
+                problem_id=env.task_name,
+                step=0,
+                total_steps=self.args.max_steps,
+                score=info.score,
+                max_score=info.max_score,
+                status="resolved",
+            )
+            return self.build_trajectory()
+
+        highscore = info.score
+        should_stop = False
+        step = 1
+
+        while not should_stop:
+            self.logger.info(f"\n{'='*20} STEP {step} {'='*20}\n")
+
+            # Check if we should replay a previous action or generate a new one
+            if replay_index < len(replay_actions):
+                agent_response = replay_actions[replay_index]
+                replay_index += 1
+                # Log replay details similar to replay.py
+                self.logger.info(
+                    f"[REPLAY] Replaying step {replay_index} from previous attempt:\n"
+                    f" Tool: {agent_response.tool.name}\n"
+                    f" Args: {agent_response.tool.arguments}\n"
+                    f" Reasoning: {agent_response.reasoning_response[:100] if agent_response.reasoning_response else None}...\n"
+                    f" Content: {agent_response.response[:100] if agent_response.response else None}..."
+                )
+            else:
+                agent_response = self.step(info)
+
+            if debug:
+                breakpoint()
+
+            info = self.execute_action(agent_response)
+
+            should_stop, reason = self.should_stop(step + 1, info)
+            status = (
+                "resolved"
+                if info.resolved
+                else ("unresolved" if should_stop else "running")
+            )
+
+            highscore = max(highscore, info.score)
+            msg = f"[{env.task_name[:10]:<10}] Step {step} | Score: {info.score}/{info.max_score or '-'} [Best: {highscore}]"
+            if should_stop:
+                msg += f" | Stopping Reason: {reason}"
+            self.logger.info(msg)
+            step += 1
+
+        self.logger.report_progress(
+            problem_id=env.task_name,
+            step=step,
+            total_steps=self.args.max_steps,
+            score=info.score,
+            max_score=info.max_score,
+            status=status,
+        )
+        return self.build_trajectory()


 def create_agent(config: Dict[str, Any], **kwargs) -> BaseAgent:
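Taken together with the --max-retries flag added in debug_gym/agents/utils.py below, run(replay_actions=...) supports a retry loop at the runner level: after a terminal failure, previously recorded steps (including the one that failed) are replayed from the saved trajectory, then the agent resumes with fresh LLM calls. A minimal sketch of such a driver, assuming a make_env() factory that rebuilds the environment and its terminal; the runner itself is not part of this diff:

from debug_gym.agents.utils import load_trajectory, save_trajectory
from debug_gym.gym.terminals.terminal import UnrecoverableTerminalError


def run_with_retries(agent, make_env, problem_path, logger, max_retries=3):
    for attempt in range(max_retries + 1):
        # On a retry, replay the steps recorded before the failure
        # (including the failed step) before generating new ones.
        replay = load_trajectory(problem_path, logger) if attempt > 0 else None
        try:
            env = make_env()  # rebuild env/terminal after eviction
            return agent.run(env, replay_actions=replay)
        except UnrecoverableTerminalError:
            save_trajectory(agent, problem_path, logger)
            if attempt == max_retries:
                raise
            logger.warning(f"Terminal lost; retrying ({attempt + 1}/{max_retries})")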
113 changes: 102 additions & 11 deletions debug_gym/agents/utils.py
@@ -6,6 +6,8 @@

 import yaml

+from debug_gym.gym.tools.tool import ToolCall
+from debug_gym.llms.base import LLMResponse
 from debug_gym.logger import DebugGymLogger

@@ -90,6 +92,15 @@ def load_config(args=None):
         default=20,
         help="Maximum number of tasks to display in the progress bar.",
     )
+    parser.add_argument(
+        "--max-retries",
+        type=int,
+        default=3,
+        help="Maximum number of retries for terminal or timeout failures"
+        " (e.g., spot instance eviction, container deletion). "
+        "Steps are replayed, including the step that failed, "
+        "after which new steps are generated. Default: 3.",
+    )
     parser.add_argument(
         "-p",
         "--params",
@@ -127,20 +138,100 @@

 def save_patch(env, problem_path: Path, logger: DebugGymLogger):
     """Persist the current environment patch to disk."""
-    problem_path.mkdir(parents=True, exist_ok=True)
-    patch_path = problem_path / "debug_gym.patch"
-    with open(patch_path, "w") as f:
-        f.write(env.patch)
-
-    logger.debug(f"Patch saved in {patch_path}")
+    try:
+        problem_path.mkdir(parents=True, exist_ok=True)
+        patch_path = problem_path / "debug_gym.patch"
+        with open(patch_path, "w") as f:
+            f.write(env.patch)
+        logger.debug(f"Patch saved in {patch_path}")
+    except Exception as patch_error:
+        # Terminal may be unavailable (e.g., pod died), log and continue
+        logger.warning(f"Could not save patch: {patch_error!r}")


 def save_trajectory(agent, problem_path: Path, logger: DebugGymLogger):
     """Persist the agent trajectory to disk."""
-    problem_path.mkdir(parents=True, exist_ok=True)
-    trajectory = agent.build_trajectory()
-    json_file = problem_path / "trajectory.json"
-    with open(json_file, "w") as f:
-        json.dump(trajectory, f, indent=4)
-
-    logger.debug(f"Trajectory saved in {json_file}")
+    try:
+        problem_path.mkdir(parents=True, exist_ok=True)
+        trajectory = agent.build_trajectory()
+        json_file = problem_path / "trajectory.json"
+        with open(json_file, "w") as f:
+            json.dump(trajectory, f, indent=4)
+        logger.debug(f"Trajectory saved in {json_file}")
+    except Exception as save_error:
+        logger.error(f"Could not save trajectory for replay: {save_error!r}")
+        raise
+
+
+def load_trajectory(
+    problem_path: Path, logger: DebugGymLogger
+) -> list[LLMResponse] | None:
+    """Load a previous trajectory and reconstruct LLMResponse objects for replay.
+
+    Follows the same approach as replay.py for accurate reconstruction of LLMResponse
+    objects, including token counts and original prompt data from prompt_response_pairs.
+
+    Returns a list of LLMResponse objects that can be passed to agent.execute_action(),
+    or None if no trajectory exists.
+    """
+    json_file = problem_path / "trajectory.json"
+    if not json_file.exists():
+        return None
+
+    try:
+        with open(json_file, "r") as f:
+            trajectory = json.load(f)
+    except (json.JSONDecodeError, IOError) as e:
+        logger.warning(f"Failed to load trajectory from {json_file}: {e}")
+        return None
+
+    log = trajectory.get("log", [])
+    if not log:
+        return None
+
+    llm_responses = []
+    for step in log:
+        # Skip step 0 (initial state with no action)
+        if step.get("step_id") == 0 or step.get("action") is None:
+            continue
+
+        # Reconstruct ToolCall from saved action
+        action_data = step.get("action", {})
+        if not action_data:
+            continue
+
+        tool_call = ToolCall(
+            id=action_data.get("id", ""),
+            name=action_data.get("name", ""),
+            arguments=action_data.get("arguments", {}),
+        )
+
+        # Extract data from prompt_response_pairs if available (like replay.py does)
+        prompt_response_pairs = step.get("prompt_response_pairs", [])
+        if prompt_response_pairs and len(prompt_response_pairs) > 0:
+            prompt_response = prompt_response_pairs[0]
+            token_usage = prompt_response.get("token_usage", {})
+            llm_response = LLMResponse(
+                prompt=prompt_response.get("prompt", []),
+                response=prompt_response.get("response"),
+                reasoning_response=prompt_response.get("reasoning_response"),
+                tool=tool_call,
+                prompt_token_count=token_usage.get("prompt", 0),
+                response_token_count=token_usage.get("response", 0),
+            )
+        else:
+            # Fallback to step-level data if prompt_response_pairs not available
+            llm_response = LLMResponse(
+                prompt=[],
+                response=step.get("content"),
+                reasoning_response=step.get("reasoning"),
+                tool=tool_call,
+            )
+        llm_responses.append(llm_response)
+
+    if llm_responses:
+        logger.info(
+            f"Loaded {len(llm_responses)} steps from previous trajectory for replay"
+        )
+
+    return llm_responses if llm_responses else None
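For reference, load_trajectory assumes trajectory.json has roughly the shape below. Field names match what the function reads above; the tool name and arguments are invented for the example:

# Illustrative only; not a real recorded trajectory.
trajectory = {
    "log": [
        {"step_id": 0, "action": None},  # initial state, skipped on replay
        {
            "step_id": 1,
            "action": {
                "id": "call_1",
                "name": "some_tool",
                "arguments": {"arg": "value"},
            },
            "content": "...",  # step-level fallback for the response text
            "reasoning": "...",  # step-level fallback for the reasoning text
            "prompt_response_pairs": [
                {
                    "prompt": [],
                    "response": "...",
                    "reasoning_response": "...",
                    "token_usage": {"prompt": 123, "response": 45},
                }
            ],
        },
    ],
}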