Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
args: ["--profile", "black", "--filter-files"]

- repo: https://github.com/psf/black
rev: 24.4.2
rev: 26.1.0
hooks:
- id: black
args: ["--line-length", "88", "--force-exclude", "data/*"]
5 changes: 3 additions & 2 deletions debug_gym/gym/terminals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ def select_terminal(
if uuid is not None:
extra_labels = {**extra_labels, "uuid": uuid}

if terminal_class is KubernetesTerminal and extra_labels:
# Pass extra_labels to terminals that support it (Kubernetes and Docker)
if terminal_class in (KubernetesTerminal, DockerTerminal) and extra_labels:
config["extra_labels"] = extra_labels

if terminal_class is not KubernetesTerminal:
if terminal_class not in (KubernetesTerminal, DockerTerminal):
config.pop("extra_labels", None)

return terminal_class(
Expand Down
11 changes: 11 additions & 0 deletions debug_gym/gym/terminals/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
registry: str = "",
setup_commands: list[str] | None = None,
command_timeout: int = 300,
extra_labels: dict[str, str] | None = None,
**kwargs,
):
"""
Expand All @@ -48,6 +49,8 @@ def __init__(
terminal_config:
type: docker
command_timeout: 60
extra_labels: Additional labels to add to the container (e.g., {"run-id": "my-run"}).
Useful for identifying containers during cleanup.
**kwargs: Additional arguments (ignored with debug log).
"""
super().__init__(
Expand All @@ -61,6 +64,7 @@ def __init__(
self.registry = registry.rstrip("/") + "/" if registry else ""
self.setup_commands = setup_commands or []
self.command_timeout = command_timeout
self.extra_labels = extra_labels or {}
self._docker_client = None # Lazily initialized
self._container = None

Expand Down Expand Up @@ -238,12 +242,19 @@ def setup_container(self) -> docker.models.containers.Container:

# Generate a unique container name
container_name = f"debug_gym_{uuid.uuid4()}"

# Build labels: always include app=debug-gym for identification
labels = {"app": "debug-gym"}
if self.extra_labels:
labels.update(self.extra_labels)

container = self.docker_client.containers.run(
name=container_name,
image=f"{self.registry}{self.base_image}",
command="sleep infinity", # Keep the container running
working_dir=self.working_dir,
environment=self.env_vars,
labels=labels,
detach=True,
auto_remove=True,
remove=True,
Expand Down
6 changes: 2 additions & 4 deletions debug_gym/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,8 @@ def __init__(
# Runtime generation kwargs from experiment config (temperature, max_tokens, etc.)
self.runtime_generate_kwargs = runtime_generate_kwargs or {}

self.logger.debug(
f"Using {self.model_name} with max context length of {
self.context_length:,} tokens."
)
self.logger.debug(f"Using {self.model_name} with max context length of {
self.context_length:,} tokens.")

@classmethod
def instantiate(
Expand Down
33 changes: 33 additions & 0 deletions tests/gym/terminals/test_terminal.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,3 +188,36 @@ def __init__(self, **kwargs):
"extra_labels": {"foo": "bar"},
"pod_spec_kwargs": {"tolerations": []},
}


def test_select_terminal_docker_extra_labels(monkeypatch):
    """Check that select_terminal hands extra_labels (plus the uuid) to DockerTerminal.

    DockerTerminal is swapped for a stub that records its constructor
    keyword arguments, so no container is started by this test.
    """
    seen_kwargs = {}

    class DummyDocker:
        def __init__(self, **kwargs):
            # Record whatever select_terminal passes to the constructor.
            seen_kwargs.update(kwargs)

    monkeypatch.setattr("debug_gym.gym.terminals.DockerTerminal", DummyDocker)

    def make_config():
        # Fresh, independent copy of the input config for every call.
        return {
            "type": "docker",
            "base_image": "ubuntu:latest",
            "extra_labels": {"run-id": "my-run"},
        }

    config = make_config()
    terminal = select_terminal(config, uuid="1234")

    assert isinstance(terminal, DummyDocker)
    assert seen_kwargs["base_image"] == "ubuntu:latest"
    # The uuid is expected alongside the user-supplied labels.
    assert seen_kwargs["extra_labels"] == {"run-id": "my-run", "uuid": "1234"}
    assert "logger" in seen_kwargs
    # The caller's config dict must be left exactly as it was passed in.
    assert config == make_config()
54 changes: 18 additions & 36 deletions tests/gym/tools/test_grep.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@ def _setup_grep_test_repo(base_dir):

# Python files with various content
with (working_dir / "main.py").open("w") as f:
f.write(
"""#!/usr/bin/env python3
f.write("""#!/usr/bin/env python3
import os
import sys

Expand All @@ -38,12 +37,10 @@ def method_with_bug(self):

if __name__ == "__main__":
hello_world()
"""
)
""")

with (working_dir / "src" / "utils.py").open("w") as f:
f.write(
"""import re
f.write("""import re
import json

def validate_email(email):
Expand All @@ -61,12 +58,10 @@ def __init__(self):

def validate(self, email):
return re.match(self.pattern, email) is not None
"""
)
""")

with (working_dir / "tests" / "test_utils.py").open("w") as f:
f.write(
"""import pytest
f.write("""import pytest
from src.utils import validate_email, EmailValidator

def test_validate_email():
Expand All @@ -82,33 +77,27 @@ def test_email_validator_class():
def test_broken_function():
# This test needs to be fixed
assert False # This should pass
"""
)
""")

# Configuration files
with (working_dir / "config.json").open("w") as f:
f.write(
"""{
f.write("""{
"name": "test_project",
"version": "1.0.0",
"debug": true,
"database_url": "sqlite:///test.db"
}"""
)
}""")

with (working_dir / "requirements.txt").open("w") as f:
f.write(
"""pytest>=6.0.0
f.write("""pytest>=6.0.0
requests>=2.25.0
flask>=2.0.0
sqlalchemy>=1.4.0
"""
)
""")

# Documentation
with (working_dir / "README.md").open("w") as f:
f.write(
"""# Test Project
f.write("""# Test Project

This is a test project for grep functionality.

Expand All @@ -126,12 +115,10 @@ def test_broken_function():
```bash
pip install -r requirements.txt
```
"""
)
""")

with (working_dir / "docs" / "api.md").open("w") as f:
f.write(
"""# API Documentation
f.write("""# API Documentation

## EmailValidator Class

Expand All @@ -153,37 +140,32 @@ def test_broken_function():
validator = EmailValidator()
result = validator.validate("[email protected]")
```
"""
)
""")

# Binary file (should be ignored)
with (working_dir / "binary.bin").open("wb") as f:
f.write(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09")

# Log file
with (working_dir / "app.log").open("w") as f:
f.write(
"""2024-01-01 10:00:00 INFO Starting application
f.write("""2024-01-01 10:00:00 INFO Starting application
2024-01-01 10:00:01 DEBUG Loading configuration
2024-01-01 10:00:02 ERROR Failed to connect to database
2024-01-01 10:00:03 WARNING Retrying connection
2024-01-01 10:00:04 INFO Application started successfully
"""
)
""")

# Hidden files
with (working_dir / ".gitignore").open("w") as f:
f.write(
"""__pycache__/
f.write("""__pycache__/
*.pyc
*.pyo
*.pyd
.env
.venv
venv/
env/
"""
)
""")

return working_dir

Expand Down
6 changes: 2 additions & 4 deletions tests/gym/tools/test_pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,17 +852,15 @@ def test_pdb_changing_entrypoint(tmp_path, setup_pdb_repo_env):

# Create a simple Python script to debug
with (wd / "simple_script.py").open("w") as f:
f.write(
"""
f.write("""
def main():
x = 42
print(f"Value is {x}")
return x

if __name__ == "__main__":
main()
"""
)
""")

# Use entrypoint to debug the simple script instead of pytest
script_entrypoint = "python -m pdb simple_script.py"
Expand Down
Loading