diff --git a/examples/docker_repl_example.py b/examples/docker_repl_example.py
index cb00adf..f8c1c5d 100644
--- a/examples/docker_repl_example.py
+++ b/examples/docker_repl_example.py
@@ -18,11 +18,11 @@ class MockLM(BaseLM):
     def __init__(self):
         super().__init__(model_name="mock")
 
-    def completion(self, prompt):
+    def completion(self, prompt, model=None, **kwargs):
         return f"Mock: {str(prompt)[:50]}"
 
-    async def acompletion(self, prompt):
-        return self.completion(prompt)
+    async def acompletion(self, prompt, model=None, **kwargs):
+        return self.completion(prompt, model=model, **kwargs)
 
     def get_usage_summary(self):
         return UsageSummary({"mock": ModelUsageSummary(1, 10, 10)})
diff --git a/examples/modal_repl_example.py b/examples/modal_repl_example.py
index 2b28da1..830a129 100644
--- a/examples/modal_repl_example.py
+++ b/examples/modal_repl_example.py
@@ -16,10 +16,10 @@ class MockLM(BaseLM):
     def __init__(self):
         super().__init__(model_name="mock-model")
 
-    def completion(self, prompt):
+    def completion(self, prompt, model=None, **kwargs):
         return f"Mock response to: {prompt[:50]}"
 
-    async def acompletion(self, prompt):
+    async def acompletion(self, prompt, model=None, **kwargs):
         return self.completion(prompt)
 
     def get_usage_summary(self):
diff --git a/rlm/clients/base_lm.py b/rlm/clients/base_lm.py
index d742650..20c8c2b 100644
--- a/rlm/clients/base_lm.py
+++ b/rlm/clients/base_lm.py
@@ -15,11 +15,11 @@ def __init__(self, model_name: str, **kwargs):
         self.kwargs = kwargs
 
     @abstractmethod
-    def completion(self, prompt: str | dict[str, Any]) -> str:
+    def completion(self, prompt: str | dict[str, Any], model: str | None = None, **kwargs) -> str:
         raise NotImplementedError
 
     @abstractmethod
-    async def acompletion(self, prompt: str | dict[str, Any]) -> str:
+    async def acompletion(self, prompt: str | dict[str, Any], model: str | None = None, **kwargs) -> str:
         raise NotImplementedError
 
     @abstractmethod
diff --git a/rlm/clients/openai.py b/rlm/clients/openai.py
index d5fa999..eb10a0a 100644
--- a/rlm/clients/openai.py
+++ b/rlm/clients/openai.py
@@ -46,7 +46,9 @@ def __init__(
         self.model_output_tokens: dict[str, int] = defaultdict(int)
         self.model_total_tokens: dict[str, int] = defaultdict(int)
 
-    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
+    def completion(
+        self, prompt: str | list[dict[str, Any]], model: str | None = None, **kwargs
+    ) -> str:
         if isinstance(prompt, str):
             messages = [{"role": "user", "content": prompt}]
         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
@@ -58,12 +60,13 @@ def completion(self, prompt: str | list[dict[str, Any]], model: str | None = Non
         if not model:
             raise ValueError("Model name is required for OpenAI client.")
 
-        response = self.client.chat.completions.create(model=model, messages=messages)
+        api_kwargs = {"model": model, "messages": messages, **kwargs}
+        response = self.client.chat.completions.create(**api_kwargs)
         self._track_cost(response, model)
         return response.choices[0].message.content
 
     async def acompletion(
-        self, prompt: str | list[dict[str, Any]], model: str | None = None
+        self, prompt: str | list[dict[str, Any]], model: str | None = None, **kwargs
     ) -> str:
         if isinstance(prompt, str):
             messages = [{"role": "user", "content": prompt}]
         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
@@ -76,19 +79,26 @@ async def acompletion(
         if not model:
             raise ValueError("Model name is required for OpenAI client.")
 
-        response = await self.async_client.chat.completions.create(model=model, messages=messages)
+        api_kwargs = {"model": model, "messages": messages, **kwargs}
+        response = await self.async_client.chat.completions.create(**api_kwargs)
         self._track_cost(response, model)
         return response.choices[0].message.content
 
     def _track_cost(self, response: openai.ChatCompletion, model: str):
         self.model_call_counts[model] += 1
-        self.model_input_tokens[model] += response.usage.prompt_tokens
-        self.model_output_tokens[model] += response.usage.completion_tokens
-        self.model_total_tokens[model] += response.usage.total_tokens
+
+        usage = getattr(response, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0) if usage else 0
+        completion_tokens = getattr(usage, "completion_tokens", 0) if usage else 0
+        total_tokens = getattr(usage, "total_tokens", 0) if usage else 0
+
+        self.model_input_tokens[model] += prompt_tokens
+        self.model_output_tokens[model] += completion_tokens
+        self.model_total_tokens[model] += total_tokens
 
         # Track last call for handler to read
-        self.last_prompt_tokens = response.usage.prompt_tokens
-        self.last_completion_tokens = response.usage.completion_tokens
+        self.last_prompt_tokens = prompt_tokens
+        self.last_completion_tokens = completion_tokens
 
     def get_usage_summary(self) -> UsageSummary:
         model_summaries = {}
diff --git a/rlm/core/comms_utils.py b/rlm/core/comms_utils.py
index 07e92ae..c9cfbc1 100644
--- a/rlm/core/comms_utils.py
+++ b/rlm/core/comms_utils.py
@@ -28,6 +28,7 @@ class LMRequest:
     prompt: str | dict[str, Any] | None = None
     prompts: list[str | dict[str, Any]] | None = None
     model: str | None = None
+    kwargs: dict[str, Any] | None = None
 
     @property
     def is_batched(self) -> bool:
@@ -43,6 +44,8 @@ def to_dict(self) -> dict:
             d["prompts"] = self.prompts
         if self.model is not None:
             d["model"] = self.model
+        if self.kwargs is not None:
+            d["kwargs"] = self.kwargs
         return d
 
     @classmethod
@@ -52,6 +55,7 @@ def from_dict(cls, data: dict) -> "LMRequest":
             prompt=data.get("prompt"),
             prompts=data.get("prompts"),
             model=data.get("model"),
+            kwargs=data.get("kwargs"),
         )
 
 
@@ -220,6 +224,7 @@ def send_lm_request_batched(
     address: tuple[str, int],
     prompts: list[str | dict[str, Any]],
     model: str | None = None,
+    kwargs: dict[str, Any] | None = None,
     timeout: int = 300,
 ) -> list[LMResponse]:
     """Send a batched LM request and return a list of typed responses.
@@ -228,13 +233,14 @@ def send_lm_request_batched(
         address: (host, port) tuple of LM Handler server.
         prompts: List of prompts to send.
        model: Optional model name to use.
+        kwargs: Optional kwargs to pass to completion calls.
         timeout: Socket timeout in seconds.
 
     Returns:
         List of LMResponse objects, one per prompt, in the same order.
""" try: - request = LMRequest(prompts=prompts, model=model) + request = LMRequest(prompts=prompts, model=model, kwargs=kwargs) response_data = socket_request(address, request.to_dict(), timeout) response = LMResponse.from_dict(response_data) diff --git a/rlm/core/lm_handler.py b/rlm/core/lm_handler.py index 36d27b3..2dfe68e 100644 --- a/rlm/core/lm_handler.py +++ b/rlm/core/lm_handler.py @@ -48,7 +48,8 @@ def _handle_single(self, request: LMRequest, handler: "LMHandler") -> LMResponse client = handler.get_client(request.model) start_time = time.perf_counter() - content = client.completion(request.prompt) + kwargs = request.kwargs or {} + content = client.completion(request.prompt, model=request.model, **kwargs) end_time = time.perf_counter() usage_summary = client.get_last_usage() @@ -67,9 +68,13 @@ def _handle_batched(self, request: LMRequest, handler: "LMHandler") -> LMRespons client = handler.get_client(request.model) start_time = time.perf_counter() + kwargs = request.kwargs or {} async def run_all(): - tasks = [client.acompletion(prompt) for prompt in request.prompts] + tasks = [ + client.acompletion(prompt, model=request.model, **kwargs) + for prompt in request.prompts + ] return await asyncio.gather(*tasks) results = asyncio.run(run_all()) @@ -164,9 +169,9 @@ def stop(self): self._server = None self._thread = None - def completion(self, prompt: str, model: str | None = None) -> str: + def completion(self, prompt: str, model: str | None = None, **kwargs) -> str: """Direct completion call (for main process use).""" - return self.get_client(model).completion(prompt) + return self.get_client(model).completion(prompt, model=model, **kwargs) def __enter__(self): self.start() diff --git a/rlm/core/rlm.py b/rlm/core/rlm.py index 80a3792..20690b8 100644 --- a/rlm/core/rlm.py +++ b/rlm/core/rlm.py @@ -151,7 +151,7 @@ def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]: return message_history def completion( - self, prompt: str | dict[str, Any], root_prompt: str | None = None + self, prompt: str | dict[str, Any], root_prompt: str | None = None, **kwargs ) -> RLMChatCompletion: """ Recursive Language Model completion call. This is the main entry point for querying an RLM, and @@ -163,6 +163,7 @@ def completion( prompt: A single string or dictionary of messages to pass as context to the model. root_prompt: We allow the RLM's root LM to see a (small) prompt that the user specifies. A common example of this is if the user is asking the RLM to answer a question, we can pass the question as the root prompt. + **kwargs: Optional kwargs to pass to root LM completion calls (e.g., max_tokens, temperature). Returns: A final answer as a string. """ @@ -170,7 +171,7 @@ def completion( # If we're at max depth, the RLM is an LM, so we fallback to the regular LM. if self.depth >= self.max_depth: - return self._fallback_answer(prompt) + return self._fallback_answer(prompt, **kwargs) with self._spawn_completion_context(prompt) as (lm_handler, environment): message_history = self._setup_prompt(prompt) @@ -183,6 +184,7 @@ def completion( prompt=current_prompt, lm_handler=lm_handler, environment=environment, + **kwargs, ) # Check if RLM is done and has a final answer. 
@@ -219,7 +221,7 @@ def completion(
 
         # Default behavior: we run out of iterations, provide one final answer
         time_end = time.perf_counter()
-        final_answer = self._default_answer(message_history, lm_handler)
+        final_answer = self._default_answer(message_history, lm_handler, **kwargs)
         usage = lm_handler.get_usage_summary()
         self.verbose.print_final_answer(final_answer)
         self.verbose.print_summary(self.max_iterations, time_end - time_start, usage.to_dict())
@@ -238,13 +240,14 @@ def _completion_turn(
         prompt: str | dict[str, Any],
         lm_handler: LMHandler,
         environment: BaseEnv,
+        **kwargs,
     ) -> RLMIteration:
         """
         Perform a single iteration of the RLM, including prompting the model and code execution + tool execution.
         """
         iter_start = time.perf_counter()
 
-        response = lm_handler.completion(prompt)
+        response = lm_handler.completion(prompt, **kwargs)
         code_block_strs = find_code_blocks(response)
 
         code_blocks = []
@@ -260,7 +263,9 @@ def _completion_turn(
             iteration_time=iteration_time,
         )
 
-    def _default_answer(self, message_history: list[dict[str, Any]], lm_handler: LMHandler) -> str:
+    def _default_answer(
+        self, message_history: list[dict[str, Any]], lm_handler: LMHandler, **kwargs
+    ) -> str:
         """
         Default behavior if the RLM runs out of iterations and does not find a final answer. It will
         take the message history, and try to generate a final answer from it.
@@ -271,7 +276,7 @@ def _default_answer(self, message_history: list[dict[str, Any]], lm_handler: LMH
                 "content": "Please provide a final answer to the user's question based on the information provided.",
             }
         ]
-        response = lm_handler.completion(current_prompt)
+        response = lm_handler.completion(current_prompt, **kwargs)
 
         if self.logger:
             self.logger.log(
@@ -285,10 +290,10 @@ def _default_answer(self, message_history: list[dict[str, Any]], lm_handler: LMH
         return response
 
-    def _fallback_answer(self, message: str | dict[str, Any]) -> str:
+    def _fallback_answer(self, message: str | dict[str, Any], **kwargs) -> str:
         """
         Fallback behavior if the RLM is actually at max depth, and should be treated as an LM.
""" client: BaseLM = get_client(self.backend, self.backend_kwargs) - response = client.completion(message) + response = client.completion(message, **kwargs) return response diff --git a/rlm/environments/docker_repl.py b/rlm/environments/docker_repl.py index 6dd8c00..3299bb5 100644 --- a/rlm/environments/docker_repl.py +++ b/rlm/environments/docker_repl.py @@ -55,7 +55,9 @@ def _handle_single(self, body: dict) -> dict: if not self.lm_handler_address: return {"error": "No LM handler configured"} - request = LMRequest(prompt=body.get("prompt"), model=body.get("model")) + request = LMRequest( + prompt=body.get("prompt"), model=body.get("model"), kwargs=body.get("kwargs") + ) response = send_lm_request(self.lm_handler_address, request) if not response.success: @@ -72,7 +74,10 @@ def _handle_batched(self, body: dict) -> dict: prompts = body.get("prompts", []) responses = send_lm_request_batched( - self.lm_handler_address, prompts, model=body.get("model") + self.lm_handler_address, + prompts, + model=body.get("model"), + kwargs=body.get("kwargs"), ) results = [] @@ -102,17 +107,23 @@ def _build_exec_script(code: str, proxy_port: int) -> str: PROXY = "http://host.docker.internal:{proxy_port}" STATE = "/workspace/state.dill" -def llm_query(prompt, model=None): +def llm_query(prompt, model=None, **kwargs): try: - r = requests.post(f"{{PROXY}}/llm_query", json={{"prompt": prompt, "model": model}}, timeout=300) + payload = {{"prompt": prompt, "model": model}} + if kwargs: + payload["kwargs"] = kwargs + r = requests.post(f"{{PROXY}}/llm_query", json=payload, timeout=300) d = r.json() return d.get("response") or f"Error: {{d.get('error')}}" except Exception as e: return f"Error: {{e}}" -def llm_query_batched(prompts, model=None): +def llm_query_batched(prompts, model=None, **kwargs): try: - r = requests.post(f"{{PROXY}}/llm_query_batched", json={{"prompts": prompts, "model": model}}, timeout=300) + payload = {{"prompts": prompts, "model": model}} + if kwargs: + payload["kwargs"] = kwargs + r = requests.post(f"{{PROXY}}/llm_query_batched", json=payload, timeout=300) d = r.json() return d.get("responses") or [f"Error: {{d.get('error')}}"] * len(prompts) except Exception as e: diff --git a/rlm/environments/local_repl.py b/rlm/environments/local_repl.py index b818380..60f3bd3 100644 --- a/rlm/environments/local_repl.py +++ b/rlm/environments/local_repl.py @@ -166,18 +166,19 @@ def _final_var(self, variable_name: str) -> str: return str(self.locals[variable_name]) return f"Error: Variable '{variable_name}' not found" - def _llm_query(self, prompt: str, model: str | None = None) -> str: + def _llm_query(self, prompt: str, model: str | None = None, **kwargs) -> str: """Query the LM via socket connection to the handler. Args: prompt: The prompt to send to the LM. model: Optional model name to use (if handler has multiple clients). + **kwargs: Optional kwargs to pass to the completion call (e.g., max_tokens, temperature). 
""" if not self.lm_handler_address: return "Error: No LM handler configured" try: - request = LMRequest(prompt=prompt, model=model) + request = LMRequest(prompt=prompt, model=model, kwargs=kwargs if kwargs else None) response = send_lm_request(self.lm_handler_address, request) if not response.success: @@ -192,12 +193,13 @@ def _llm_query(self, prompt: str, model: str | None = None) -> str: except Exception as e: return f"Error: LM query failed - {e}" - def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> list[str]: + def _llm_query_batched(self, prompts: list[str], model: str | None = None, **kwargs) -> list[str]: """Query the LM with multiple prompts concurrently. Args: prompts: List of prompts to send to the LM. model: Optional model name to use (if handler has multiple clients). + **kwargs: Optional kwargs to pass to the completion calls (e.g., max_tokens, temperature). Returns: List of responses in the same order as input prompts. @@ -206,7 +208,7 @@ def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> li return ["Error: No LM handler configured"] * len(prompts) try: - responses = send_lm_request_batched(self.lm_handler_address, prompts, model=model) + responses = send_lm_request_batched(self.lm_handler_address, prompts, model=model, kwargs=kwargs if kwargs else None) results = [] for response in responses: diff --git a/rlm/environments/modal_repl.py b/rlm/environments/modal_repl.py index 2acfed1..e4d02f4 100644 --- a/rlm/environments/modal_repl.py +++ b/rlm/environments/modal_repl.py @@ -166,12 +166,15 @@ def _build_exec_script(code: str, broker_port: int = 8080) -> str: BROKER_URL = "http://127.0.0.1:{broker_port}" -def llm_query(prompt, model=None): +def llm_query(prompt, model=None, **kwargs): """Query the LM via the broker.""" try: + payload = {{"type": "single", "prompt": prompt, "model": model}} + if kwargs: + payload["kwargs"] = kwargs response = requests.post( f"{{BROKER_URL}}/enqueue", - json={{"type": "single", "prompt": prompt, "model": model}}, + json=payload, timeout=300, ) data = response.json() @@ -182,12 +185,15 @@ def llm_query(prompt, model=None): return f"Error: LM query failed - {{e}}" -def llm_query_batched(prompts, model=None): +def llm_query_batched(prompts, model=None, **kwargs): """Query the LM with multiple prompts.""" try: + payload = {{"type": "batched", "prompts": prompts, "model": model}} + if kwargs: + payload["kwargs"] = kwargs response = requests.post( f"{{BROKER_URL}}/enqueue", - json={{"type": "batched", "prompts": prompts, "model": model}}, + json=payload, timeout=300, ) data = response.json() @@ -405,10 +411,11 @@ def _handle_llm_request(self, req_data: dict) -> dict: """Handle an LLM request from the sandbox.""" req_type = req_data.get("type") model = req_data.get("model") + kwargs = req_data.get("kwargs") if req_type == "single": prompt = req_data.get("prompt") - request = LMRequest(prompt=prompt, model=model) + request = LMRequest(prompt=prompt, model=model, kwargs=kwargs) response = send_lm_request(self.lm_handler_address, request) if not response.success: @@ -422,7 +429,9 @@ def _handle_llm_request(self, req_data: dict) -> dict: elif req_type == "batched": prompts = req_data.get("prompts", []) - responses = send_lm_request_batched(self.lm_handler_address, prompts, model=model) + responses = send_lm_request_batched( + self.lm_handler_address, prompts, model=model, kwargs=kwargs + ) results = [] for resp in responses: diff --git a/tests/clients/_openai.py b/tests/clients/_openai.py new file mode 
index 0000000..f5e1ba6
--- /dev/null
+++ b/tests/clients/_openai.py
@@ -0,0 +1,63 @@
+"""Basic tests for OpenAIClient."""
+
+import os
+
+from dotenv import load_dotenv
+
+from rlm.clients.openai import OpenAIClient
+
+load_dotenv()
+
+
+def test_openai_completion_with_kwargs():
+    """Test that kwargs are passed through to OpenAI API."""
+    api_key = os.environ.get("OPENAI_API_KEY")
+    base_url = os.environ.get("OPENAI_BASE_URL")
+    model_name = "openai/gpt-5.2"
+
+    if not api_key:
+        print("Skipping test: OPENAI_API_KEY not set")
+        return
+
+    client = OpenAIClient(api_key=api_key, model_name=model_name, base_url=base_url)
+    prompt = "What is the capital of France?"
+
+    try:
+        # Test with kwargs
+        sample_kwargs = {"reasoning_effort": "high", "extra_body": {"usage": {"include": True}}}
+        result = client.completion(prompt, **sample_kwargs)
+        print(f"OpenAI response with {sample_kwargs=}:\n{result}")
+        assert result is not None
+        assert len(result) > 0
+    except Exception as e:
+        print(f"OpenAIClient error: {e}")
+        raise
+
+
+def test_openai_completion_without_kwargs():
+    """Test that completion works without kwargs."""
+    api_key = os.environ.get("OPENAI_API_KEY")
+    base_url = os.environ.get("OPENAI_BASE_URL")
+    model_name = "openai/gpt-4.1-mini"
+
+    if not api_key:
+        print("Skipping test: OPENAI_API_KEY not set")
+        return
+
+    client = OpenAIClient(api_key=api_key, model_name=model_name, base_url=base_url)
+    prompt = "What is the capital of France?"
+
+    try:
+        # Test without kwargs
+        result = client.completion(prompt)
+        print(f"OpenAI response without kwargs:\n{result}")
+        assert result is not None
+        assert len(result) > 0
+    except Exception as e:
+        print(f"OpenAIClient error: {e}")
+        raise
+
+
+if __name__ == "__main__":
+    test_openai_completion_with_kwargs()
+    test_openai_completion_without_kwargs()
diff --git a/tests/mock_lm.py b/tests/mock_lm.py
index e4793b1..3719a9c 100644
--- a/tests/mock_lm.py
+++ b/tests/mock_lm.py
@@ -8,10 +8,10 @@ class MockLM(BaseLM):
     def __init__(self):
         super().__init__(model_name="mock-model")
 
-    def completion(self, prompt):
+    def completion(self, prompt, model=None, **kwargs):
         return f"Mock response to: {prompt[:50]}"
 
-    async def acompletion(self, prompt):
+    async def acompletion(self, prompt, model=None, **kwargs):
         return self.completion(prompt)
 
     def get_usage_summary(self):
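Usage sketch (not part of the patch): a minimal example of how the per-call kwargs added in this diff are meant to flow from a caller down to the OpenAI API. The OpenAIClient constructor keywords match the test file above; the model name and sampling parameters below are illustrative assumptions, not values taken from this change.

# Hypothetical usage sketch, assuming an OpenAI-compatible endpoint is configured.
import os

from rlm.clients.openai import OpenAIClient

client = OpenAIClient(
    api_key=os.environ["OPENAI_API_KEY"],  # assumed to be set in the environment
    model_name="gpt-4.1-mini",             # illustrative model name
)

# Extra keyword arguments are forwarded verbatim to chat.completions.create(...),
# so standard OpenAI parameters such as temperature or max_tokens can be set per call.
answer = client.completion(
    "What is the capital of France?",
    model="gpt-4.1-mini",
    temperature=0.2,
    max_tokens=64,
)
print(answer)

# Inside a REPL environment, the same kwargs travel through LMRequest.kwargs, e.g.:
#     llm_query("Summarize this chunk.", model="gpt-4.1-mini", max_tokens=128)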