debug_gym/agents/utils.py (0 additions, 1 deletion)

@@ -6,7 +6,6 @@
 
 import yaml
 
-from debug_gym.agents.base_agent import AGENT_REGISTRY
 from debug_gym.logger import DebugGymLogger
debug_gym/llms/base.py (3 additions, 11 deletions)

@@ -222,17 +222,16 @@ def instantiate(
         name: str | None = None,
         llm_config_file_path: str | None = None,
         logger: DebugGymLogger | None = None,
-        temperature: float | None = None,
-        max_tokens: int | None = None,
+        **runtime_generate_kwargs,
     ) -> "LLM":
         """Creates an instance of the appropriate LLM class based on the configuration.
 
         Args:
             name: Name of the LLM model (e.g., "gpt-4o", "claude-3.7").
             llm_config_file_path: Optional path to the LLM configuration file.
             logger: Optional DebugGymLogger for logging.
-            temperature: Optional temperature for generation.
-            max_tokens: Optional max tokens for generation.
+            **runtime_generate_kwargs: Additional generation kwargs to pass to the LLM.
+                e.g. temperature, max_tokens, tool_choice ("auto", "required", "none")
 
         Returns:
             An instance of the appropriate LLM class.

@@ -243,13 +242,6 @@ def instantiate(
         if not name:
             return None
 
-        # Build runtime generation kwargs from explicit args
-        runtime_generate_kwargs = {}
-        if temperature is not None:
-            runtime_generate_kwargs["temperature"] = temperature
-        if max_tokens is not None:
-            runtime_generate_kwargs["max_tokens"] = max_tokens
-
         if name == "human":
             from debug_gym.llms import Human
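For illustration, a minimal sketch of the new call path (not part of the diff); the model name, kwarg values, and surrounding setup are placeholders:

    # Any generation kwarg now flows through **runtime_generate_kwargs
    # instead of being limited to explicit temperature/max_tokens arguments.
    llm = LLM.instantiate(
        name="gpt-4o-mini",  # placeholder model name
        temperature=0.5,
        max_tokens=1000,
        tool_choice="auto",
    )
    assert llm.runtime_generate_kwargs == {
        "temperature": 0.5,
        "max_tokens": 1000,
        "tool_choice": "auto",
    }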
debug_gym/llms/huggingface.py (10 additions, 0 deletions)

@@ -2,6 +2,7 @@
 
 from transformers import AutoTokenizer
 
+from debug_gym.llms.base import LLMResponse
 from debug_gym.llms.openai import OpenAILLM

@@ -71,3 +72,12 @@ def tokenize(self, messages: list[dict]) -> list[list[str]]:
             tokens = tokenizer.tokenize(content)
             result.append(tokens)
         return result
+
+    def generate(
+        self, messages, tools, tool_choice="required", **kwargs
+    ) -> LLMResponse:
+        """Override default tool_choice parameter to "required"."""
+        llm_response = super().generate(
+            messages, tools, tool_choice=tool_choice, **kwargs
+        )
+        return llm_response
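For illustration, what this override gives callers, assuming an already-instantiated HuggingFace-backed `llm` plus placeholder `messages` and `tools`:

    # tool_choice now defaults to "required" for Hugging Face models,
    # forcing a tool call unless the caller explicitly overrides it.
    response = llm.generate(messages, tools)                      # tool_choice="required"
    response = llm.generate(messages, tools, tool_choice="auto")  # explicit override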
debug_gym/llms/openai.py (2 additions, 2 deletions)

@@ -283,7 +283,7 @@ def convert_observation_to_message(
             "content": filter_non_utf8(observation),
         }
 
-    def generate(self, messages, tools, **kwargs) -> LLMResponse:
+    def generate(self, messages, tools, tool_choice="auto", **kwargs) -> LLMResponse:
         # set max tokens if not provided
         kwargs["max_tokens"] = kwargs.get("max_tokens", NOT_GIVEN)
         api_call = retry_on_exception(

@@ -296,7 +296,7 @@ def generate(self, messages, tools, **kwargs) -> LLMResponse:
                 model=self.config.model,
                 messages=messages,
                 tools=self.define_tools(tools),
-                tool_choice="auto",
+                tool_choice=tool_choice,
                 **kwargs,
             )
         else:
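Similarly, tool_choice is now a per-call parameter on the OpenAI side rather than a hard-coded "auto"; a sketch with placeholder arguments:

    # "auto" stays the default; any other value is forwarded verbatim
    # to the chat completions request.
    response = llm.generate(messages, tools)                      # tool_choice="auto"
    response = llm.generate(messages, tools, tool_choice="none")  # no tool call this turn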
tests/llms/test_base.py (2 additions, 2 deletions)

@@ -103,10 +103,10 @@ def test_instantiate_llm(mock_open, logger_mock):
     assert llm.runtime_generate_kwargs == {"temperature": 0.5, "max_tokens": 1000}
 
     # Test with **kwargs unpacking (like config)
-    llm_config = {"name": "gpt-4o-mini", "temperature": 0.7}
+    llm_config = {"name": "gpt-4o-mini", "temperature": 0.7, "tool_choice": "auto"}
     llm = LLM.instantiate(**llm_config, logger=logger_mock)
     assert isinstance(llm, OpenAILLM)
-    assert llm.runtime_generate_kwargs == {"temperature": 0.7}
+    assert llm.runtime_generate_kwargs == {"temperature": 0.7, "tool_choice": "auto"}
 
 
 class Tool1(EnvironmentTool):