diff --git a/debug_gym/gym/tools/agent.py b/debug_gym/gym/tools/agent.py
new file mode 100644
index 000000000..a129b1b65
--- /dev/null
+++ b/debug_gym/gym/tools/agent.py
@@ -0,0 +1,81 @@
+import logging
+
+from debug_gym.agents.history_tracker import build_history_prompt
+from debug_gym.gym.entities import Observation
+from debug_gym.gym.tools.tool import EnvironmentTool
+from debug_gym.gym.tools.toolbox import Toolbox
+from debug_gym.gym.utils import is_subdirectory, show_line_number
+from debug_gym.llms.base import LLM
+
+
+@Toolbox.register()
+class AgentTool(EnvironmentTool):
+    name: str = "agent"
+    examples = [
+        """agent(query="") Send a query to this specialized agent.\n"""
+    ]
+    description = (
+        "This tool calls a specialized large language model agent, letting you query it for information and exploit the model's capabilities.\n"
+        + "\n".join(examples)
+    )
+    arguments = {
+        "query": {
+            "type": ["string"],
+            "description": "The request to the agent.",
+        },
+    }
+
+    def __init__(
+        self, history, llm_name=None, llm_config=None, llm_config_file_path=None
+    ):
+        self._history = history
+        self.logger = logging.getLogger("agent_logger")
+        self._llm = LLM.instantiate(
+            llm_name=llm_name,
+            llm_config=llm_config,
+            llm_config_file_path=llm_config_file_path,
+            logger=self.logger,
+        )
+        self.llm_config = self._llm.config
+        super().__init__()
+
+    def use(
+        self,
+        environment,
+        query: str,
+    ) -> Observation:
+        messages = build_history_prompt(
+            self._history.filter_out(actions=[None]),
+            self._llm,
+            False,
+        )
+        messages.append(
+            {
+                "role": "user",
+                "content": query,
+            }
+        )
+        llm_response = self._llm.generate(messages, [], tool_choice="none")
+
+        return Observation(
+            self.name,
+            llm_response.response,
+        )
+
+    def __copy__(self):
+        cls = self.__class__
+        result = cls.__new__(cls)
+        result.__dict__.update(self.__dict__)
+        return result
+
+    def __deepcopy__(self, memo):
+        cls = self.__class__
+        result = cls.__new__(cls)
+
+        result.logger = logging.getLogger("agent_logger")
+        result._llm = LLM.instantiate(
+            llm_config=self.llm_config,  # read from self: result has an empty __dict__ here
+            logger=result.logger,
+        )
+        result._history = self._history.copy()
+        return result
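
Below is a minimal usage sketch for the new tool, assuming a history tracker
instance named history from the agent's run loop (anything exposing
filter_out() and copy() works); the model name and config path are
placeholders, not values from this change:

# Sketch only: "history", the model name, and the YAML path are assumptions.
from debug_gym.gym.tools.agent import AgentTool

tool = AgentTool(
    history=history,                  # must provide filter_out() and copy()
    llm_name="gpt-4o-mini",           # placeholder model name
    llm_config_file_path="llm.yaml",  # placeholder config path
)
# environment is part of the EnvironmentTool interface but unused by use();
# the sub-LLM is invoked with tool_choice="none", so the reply is plain text.
obs = tool.use(environment=None, query="Summarize what we know about the bug")
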
""" logger = logger or DebugGymLogger("debug-gym") - if llm_name == "human": - from debug_gym.llms import Human - - return Human(llm_name, logger=logger) - - llm_config = LLMConfigRegistry.from_file(llm_config_file_path)[llm_name] tags = llm_config.tags if "copilot openai" in tags: @@ -263,9 +255,40 @@ def instantiate( from debug_gym.llms import OpenAILLM klass = OpenAILLM - llm = klass(llm_name, logger=logger, llm_config=llm_config) + + llm = klass(llm_config.model, logger=logger, llm_config=llm_config) return llm + @classmethod + def instantiate( + cls, + llm_name: str | None = None, + llm_config: LLMConfig | None = None, + llm_config_file_path: str | None = None, + logger: DebugGymLogger | None = None, + ) -> "LLM": + """Creates an instance of the appropriate LLM class based on the configuration. + + Args: + llm_name: Name of the LLM model to instantiate. + llm_config_file_path: Optional path to the LLM configuration file. + logger: Optional DebugGymLogger for logging. + + Returns: + An instance of the appropriate LLM class. + """ + logger = logger or DebugGymLogger("debug-gym") + + if llm_name == "human": + from debug_gym.llms import Human + + return Human(llm_name, logger=logger) + + if llm_config is None: + llm_config = LLMConfigRegistry.from_file(llm_config_file_path)[llm_name] + + return cls.instantiate_from_config(llm_config, logger) + @abstractmethod def generate(self, messages, tools, **kwargs) -> LLMResponse: """Generate a response given some messages and return it as an LLMResponse object. diff --git a/debug_gym/llms/openai.py b/debug_gym/llms/openai.py index 0a2d2aa8b..0a60e6604 100644 --- a/debug_gym/llms/openai.py +++ b/debug_gym/llms/openai.py @@ -209,7 +209,7 @@ def generate(self, messages, tools, **kwargs) -> LLMResponse: model=self.config.model, messages=messages, tools=self.define_tools(tools), - tool_choice="auto", + tool_choice=kwargs.pop("tool_choice", "auto"), **kwargs, ) except openai.BadRequestError as e: