diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 16df0e87..4c44ee12 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: args: ["--profile", "black", "--filter-files"] - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 26.1.0 hooks: - id: black args: ["--line-length", "88", "--force-exclude", "data/*"] \ No newline at end of file diff --git a/debug_gym/gym/terminals/__init__.py b/debug_gym/gym/terminals/__init__.py index 1f34e1d8..547e745a 100644 --- a/debug_gym/gym/terminals/__init__.py +++ b/debug_gym/gym/terminals/__init__.py @@ -41,10 +41,11 @@ def select_terminal( if uuid is not None: extra_labels = {**extra_labels, "uuid": uuid} - if terminal_class is KubernetesTerminal and extra_labels: + # Pass extra_labels to terminals that support it (Kubernetes and Docker) + if terminal_class in (KubernetesTerminal, DockerTerminal) and extra_labels: config["extra_labels"] = extra_labels - if terminal_class is not KubernetesTerminal: + if terminal_class not in (KubernetesTerminal, DockerTerminal): config.pop("extra_labels", None) return terminal_class( diff --git a/debug_gym/gym/terminals/docker.py b/debug_gym/gym/terminals/docker.py index c3d0a6b8..e91e2c47 100644 --- a/debug_gym/gym/terminals/docker.py +++ b/debug_gym/gym/terminals/docker.py @@ -31,6 +31,7 @@ def __init__( registry: str = "", setup_commands: list[str] | None = None, command_timeout: int = 300, + extra_labels: dict[str, str] | None = None, **kwargs, ): """ @@ -48,6 +49,8 @@ def __init__( terminal_config: type: docker command_timeout: 60 + extra_labels: Additional labels to add to the container (e.g., {"run-id": "my-run"}). + Useful for identifying containers during cleanup. **kwargs: Additional arguments (ignored with debug log). """ super().__init__( @@ -61,6 +64,7 @@ def __init__( self.registry = registry.rstrip("/") + "/" if registry else "" self.setup_commands = setup_commands or [] self.command_timeout = command_timeout + self.extra_labels = extra_labels or {} self._docker_client = None # Lazily initialized self._container = None @@ -238,12 +242,19 @@ def setup_container(self) -> docker.models.containers.Container: # Generate a unique container name container_name = f"debug_gym_{uuid.uuid4()}" + + # Build labels: always include app=debug-gym for identification + labels = {"app": "debug-gym"} + if self.extra_labels: + labels.update(self.extra_labels) + container = self.docker_client.containers.run( name=container_name, image=f"{self.registry}{self.base_image}", command="sleep infinity", # Keep the container running working_dir=self.working_dir, environment=self.env_vars, + labels=labels, detach=True, auto_remove=True, remove=True, diff --git a/debug_gym/llms/base.py b/debug_gym/llms/base.py index c2de0ebe..97845f3b 100644 --- a/debug_gym/llms/base.py +++ b/debug_gym/llms/base.py @@ -211,10 +211,8 @@ def __init__( # Runtime generation kwargs from experiment config (temperature, max_tokens, etc.) self.runtime_generate_kwargs = runtime_generate_kwargs or {} - self.logger.debug( - f"Using {self.model_name} with max context length of { - self.context_length:,} tokens." - ) + self.logger.debug(f"Using {self.model_name} with max context length of { + self.context_length:,} tokens.") @classmethod def instantiate( diff --git a/tests/gym/terminals/test_terminal.py b/tests/gym/terminals/test_terminal.py index eb2ac54c..fa452c64 100644 --- a/tests/gym/terminals/test_terminal.py +++ b/tests/gym/terminals/test_terminal.py @@ -188,3 +188,36 @@ def __init__(self, **kwargs): "extra_labels": {"foo": "bar"}, "pod_spec_kwargs": {"tolerations": []}, } + + +def test_select_terminal_docker_extra_labels(monkeypatch): + """Test that extra_labels are passed to DockerTerminal.""" + captured = {} + + class DummyDocker: + def __init__(self, **kwargs): + captured.update(kwargs) + + monkeypatch.setattr( + "debug_gym.gym.terminals.DockerTerminal", + DummyDocker, + ) + + config = { + "type": "docker", + "base_image": "ubuntu:latest", + "extra_labels": {"run-id": "my-run"}, + } + + terminal = select_terminal(config, uuid="1234") + + assert isinstance(terminal, DummyDocker) + assert captured["base_image"] == "ubuntu:latest" + assert captured["extra_labels"] == {"run-id": "my-run", "uuid": "1234"} + assert "logger" in captured + # Original config should not be modified + assert config == { + "type": "docker", + "base_image": "ubuntu:latest", + "extra_labels": {"run-id": "my-run"}, + } diff --git a/tests/gym/tools/test_grep.py b/tests/gym/tools/test_grep.py index b594bd6f..9107ba94 100644 --- a/tests/gym/tools/test_grep.py +++ b/tests/gym/tools/test_grep.py @@ -19,8 +19,7 @@ def _setup_grep_test_repo(base_dir): # Python files with various content with (working_dir / "main.py").open("w") as f: - f.write( - """#!/usr/bin/env python3 + f.write("""#!/usr/bin/env python3 import os import sys @@ -38,12 +37,10 @@ def method_with_bug(self): if __name__ == "__main__": hello_world() -""" - ) +""") with (working_dir / "src" / "utils.py").open("w") as f: - f.write( - """import re + f.write("""import re import json def validate_email(email): @@ -61,12 +58,10 @@ def __init__(self): def validate(self, email): return re.match(self.pattern, email) is not None -""" - ) +""") with (working_dir / "tests" / "test_utils.py").open("w") as f: - f.write( - """import pytest + f.write("""import pytest from src.utils import validate_email, EmailValidator def test_validate_email(): @@ -82,33 +77,27 @@ def test_email_validator_class(): def test_broken_function(): # This test needs to be fixed assert False # This should pass -""" - ) +""") # Configuration files with (working_dir / "config.json").open("w") as f: - f.write( - """{ + f.write("""{ "name": "test_project", "version": "1.0.0", "debug": true, "database_url": "sqlite:///test.db" -}""" - ) +}""") with (working_dir / "requirements.txt").open("w") as f: - f.write( - """pytest>=6.0.0 + f.write("""pytest>=6.0.0 requests>=2.25.0 flask>=2.0.0 sqlalchemy>=1.4.0 -""" - ) +""") # Documentation with (working_dir / "README.md").open("w") as f: - f.write( - """# Test Project + f.write("""# Test Project This is a test project for grep functionality. @@ -126,12 +115,10 @@ def test_broken_function(): ```bash pip install -r requirements.txt ``` -""" - ) +""") with (working_dir / "docs" / "api.md").open("w") as f: - f.write( - """# API Documentation + f.write("""# API Documentation ## EmailValidator Class @@ -153,8 +140,7 @@ def test_broken_function(): validator = EmailValidator() result = validator.validate("user@example.com") ``` -""" - ) +""") # Binary file (should be ignored) with (working_dir / "binary.bin").open("wb") as f: @@ -162,19 +148,16 @@ def test_broken_function(): # Log file with (working_dir / "app.log").open("w") as f: - f.write( - """2024-01-01 10:00:00 INFO Starting application + f.write("""2024-01-01 10:00:00 INFO Starting application 2024-01-01 10:00:01 DEBUG Loading configuration 2024-01-01 10:00:02 ERROR Failed to connect to database 2024-01-01 10:00:03 WARNING Retrying connection 2024-01-01 10:00:04 INFO Application started successfully -""" - ) +""") # Hidden files with (working_dir / ".gitignore").open("w") as f: - f.write( - """__pycache__/ + f.write("""__pycache__/ *.pyc *.pyo *.pyd @@ -182,8 +165,7 @@ def test_broken_function(): .venv venv/ env/ -""" - ) +""") return working_dir diff --git a/tests/gym/tools/test_pdb.py b/tests/gym/tools/test_pdb.py index 25d1fbfb..72c57163 100644 --- a/tests/gym/tools/test_pdb.py +++ b/tests/gym/tools/test_pdb.py @@ -852,8 +852,7 @@ def test_pdb_changing_entrypoint(tmp_path, setup_pdb_repo_env): # Create a simple Python script to debug with (wd / "simple_script.py").open("w") as f: - f.write( - """ + f.write(""" def main(): x = 42 print(f"Value is {x}") @@ -861,8 +860,7 @@ def main(): if __name__ == "__main__": main() -""" - ) +""") # Use entrypoint to debug the simple script instead of pytest script_entrypoint = "python -m pdb simple_script.py"