6 changes: 3 additions & 3 deletions examples/docker_repl_example.py
@@ -18,11 +18,11 @@ class MockLM(BaseLM):
     def __init__(self):
         super().__init__(model_name="mock")
 
-    def completion(self, prompt):
+    def completion(self, prompt, model=None, **kwargs):
         return f"Mock: {str(prompt)[:50]}"
 
-    async def acompletion(self, prompt):
-        return self.completion(prompt)
+    async def acompletion(self, prompt, model=None, **kwargs):
+        return self.completion(prompt, model=model, **kwargs)
 
     def get_usage_summary(self):
         return UsageSummary({"mock": ModelUsageSummary(1, 10, 10)})
4 changes: 2 additions & 2 deletions examples/modal_repl_example.py
@@ -16,10 +16,10 @@ class MockLM(BaseLM):
     def __init__(self):
         super().__init__(model_name="mock-model")
 
-    def completion(self, prompt):
+    def completion(self, prompt, model=None, **kwargs):
         return f"Mock response to: {prompt[:50]}"
 
-    async def acompletion(self, prompt):
+    async def acompletion(self, prompt, model=None, **kwargs):
         return self.completion(prompt)
 
     def get_usage_summary(self):
4 changes: 2 additions & 2 deletions rlm/clients/base_lm.py
@@ -15,11 +15,11 @@ def __init__(self, model_name: str, **kwargs):
         self.kwargs = kwargs
 
     @abstractmethod
-    def completion(self, prompt: str | dict[str, Any]) -> str:
+    def completion(self, prompt: str | dict[str, Any], model: str | None = None, **kwargs) -> str:
         raise NotImplementedError
 
     @abstractmethod
-    async def acompletion(self, prompt: str | dict[str, Any]) -> str:
+    async def acompletion(self, prompt: str | dict[str, Any], model: str | None = None, **kwargs) -> str:
         raise NotImplementedError
 
     @abstractmethod
28 changes: 19 additions & 9 deletions rlm/clients/openai.py
@@ -46,7 +46,9 @@ def __init__(
         self.model_output_tokens: dict[str, int] = defaultdict(int)
         self.model_total_tokens: dict[str, int] = defaultdict(int)
 
-    def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
+    def completion(
+        self, prompt: str | list[dict[str, Any]], model: str | None = None, **kwargs
+    ) -> str:
         if isinstance(prompt, str):
             messages = [{"role": "user", "content": prompt}]
         elif isinstance(prompt, list) and all(isinstance(item, dict) for item in prompt):
@@ -58,12 +60,13 @@ def completion(self, prompt: str | list[dict[str, Any]], model: str | None = None) -> str:
         if not model:
             raise ValueError("Model name is required for OpenAI client.")
 
-        response = self.client.chat.completions.create(model=model, messages=messages)
+        api_kwargs = {"model": model, "messages": messages, **kwargs}
+        response = self.client.chat.completions.create(**api_kwargs)
         self._track_cost(response, model)
         return response.choices[0].message.content
 
     async def acompletion(
-        self, prompt: str | list[dict[str, Any]], model: str | None = None
+        self, prompt: str | list[dict[str, Any]], model: str | None = None, **kwargs
     ) -> str:
         if isinstance(prompt, str):
             messages = [{"role": "user", "content": prompt}]
@@ -76,19 +79,26 @@ async def acompletion(
         if not model:
             raise ValueError("Model name is required for OpenAI client.")
 
-        response = await self.async_client.chat.completions.create(model=model, messages=messages)
+        api_kwargs = {"model": model, "messages": messages, **kwargs}
+        response = await self.async_client.chat.completions.create(**api_kwargs)
         self._track_cost(response, model)
         return response.choices[0].message.content
 
     def _track_cost(self, response: openai.ChatCompletion, model: str):
         self.model_call_counts[model] += 1
-        self.model_input_tokens[model] += response.usage.prompt_tokens
-        self.model_output_tokens[model] += response.usage.completion_tokens
-        self.model_total_tokens[model] += response.usage.total_tokens
+
+        usage = getattr(response, "usage", None)
+        prompt_tokens = getattr(usage, "prompt_tokens", 0) if usage else 0
+        completion_tokens = getattr(usage, "completion_tokens", 0) if usage else 0
+        total_tokens = getattr(usage, "total_tokens", 0) if usage else 0
+
+        self.model_input_tokens[model] += prompt_tokens
+        self.model_output_tokens[model] += completion_tokens
+        self.model_total_tokens[model] += total_tokens
 
         # Track last call for handler to read
-        self.last_prompt_tokens = response.usage.prompt_tokens
-        self.last_completion_tokens = response.usage.completion_tokens
+        self.last_prompt_tokens = prompt_tokens
+        self.last_completion_tokens = completion_tokens
 
     def get_usage_summary(self) -> UsageSummary:
         model_summaries = {}
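For reference, a minimal sketch of how the forwarded kwargs reach the OpenAI API after this change. The client class name and constructor used here are assumptions (only the completion signature and the chat.completions.create call appear in this diff); sampling parameters such as temperature and max_tokens ride along in **kwargs and land in api_kwargs.

    # Hypothetical usage sketch; OpenAIClient and its constructor are assumed names.
    from rlm.clients.openai import OpenAIClient  # assumed class name

    client = OpenAIClient()  # assumed constructor
    text = client.completion(
        "Summarize this changelog in one sentence.",
        model="gpt-4o-mini",   # required: the client raises ValueError without a model
        temperature=0.2,       # forwarded via **kwargs into chat.completions.create
        max_tokens=128,
    )
    print(text)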
8 changes: 7 additions & 1 deletion rlm/core/comms_utils.py
@@ -28,6 +28,7 @@ class LMRequest:
     prompt: str | dict[str, Any] | None = None
     prompts: list[str | dict[str, Any]] | None = None
     model: str | None = None
+    kwargs: dict[str, Any] | None = None
 
     @property
     def is_batched(self) -> bool:
@@ -43,6 +44,8 @@ def to_dict(self) -> dict:
             d["prompts"] = self.prompts
         if self.model is not None:
             d["model"] = self.model
+        if self.kwargs is not None:
+            d["kwargs"] = self.kwargs
         return d
 
     @classmethod
@@ -52,6 +55,7 @@ def from_dict(cls, data: dict) -> "LMRequest":
             prompt=data.get("prompt"),
             prompts=data.get("prompts"),
             model=data.get("model"),
+            kwargs=data.get("kwargs"),
        )
 
 
@@ -220,6 +224,7 @@ def send_lm_request_batched(
     address: tuple[str, int],
     prompts: list[str | dict[str, Any]],
     model: str | None = None,
+    kwargs: dict[str, Any] | None = None,
     timeout: int = 300,
 ) -> list[LMResponse]:
     """Send a batched LM request and return a list of typed responses.
@@ -228,13 +233,14 @@ def send_lm_request_batched(
         address: (host, port) tuple of LM Handler server.
         prompts: List of prompts to send.
         model: Optional model name to use.
+        kwargs: Optional kwargs to pass to completion calls.
         timeout: Socket timeout in seconds.
 
     Returns:
         List of LMResponse objects, one per prompt, in the same order.
     """
     try:
-        request = LMRequest(prompts=prompts, model=model)
+        request = LMRequest(prompts=prompts, model=model, kwargs=kwargs)
         response_data = socket_request(address, request.to_dict(), timeout)
         response = LMResponse.from_dict(response_data)
 
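A small sketch of the new field surviving the wire format, using only names visible in this diff (LMRequest, to_dict, from_dict):

    from rlm.core.comms_utils import LMRequest

    req = LMRequest(prompt="What is 2 + 2?", model="gpt-4o-mini", kwargs={"max_tokens": 16})
    wire = req.to_dict()                  # kwargs is only serialized when it is not None
    restored = LMRequest.from_dict(wire)  # and comes back intact on the handler side
    assert restored.kwargs == {"max_tokens": 16}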
13 changes: 9 additions & 4 deletions rlm/core/lm_handler.py
@@ -48,7 +48,8 @@ def _handle_single(self, request: LMRequest, handler: "LMHandler") -> LMResponse:
         client = handler.get_client(request.model)
 
         start_time = time.perf_counter()
-        content = client.completion(request.prompt)
+        kwargs = request.kwargs or {}
+        content = client.completion(request.prompt, model=request.model, **kwargs)
         end_time = time.perf_counter()
 
         usage_summary = client.get_last_usage()
@@ -67,9 +68,13 @@ def _handle_batched(self, request: LMRequest, handler: "LMHandler") -> LMResponse:
         client = handler.get_client(request.model)
 
         start_time = time.perf_counter()
+        kwargs = request.kwargs or {}
 
         async def run_all():
-            tasks = [client.acompletion(prompt) for prompt in request.prompts]
+            tasks = [
+                client.acompletion(prompt, model=request.model, **kwargs)
+                for prompt in request.prompts
+            ]
             return await asyncio.gather(*tasks)
 
         results = asyncio.run(run_all())
@@ -164,9 +169,9 @@ def stop(self):
         self._server = None
         self._thread = None
 
-    def completion(self, prompt: str, model: str | None = None) -> str:
+    def completion(self, prompt: str, model: str | None = None, **kwargs) -> str:
         """Direct completion call (for main process use)."""
-        return self.get_client(model).completion(prompt)
+        return self.get_client(model).completion(prompt, model=model, **kwargs)
 
     def __enter__(self):
         self.start()
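A sketch of the direct (main-process) path after this change, assuming `handler` is an already-configured LMHandler (its constructor is not part of this diff):

    from rlm.core.lm_handler import LMHandler

    def ask(handler: LMHandler) -> str:
        # model picks the client via get_client(); max_tokens rides along in **kwargs
        # and is forwarded to that client's completion call.
        return handler.completion(
            "Answer in one line: what is the capital of France?",
            model="gpt-4o-mini",
            max_tokens=32,
        )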
21 changes: 13 additions & 8 deletions rlm/core/rlm.py
@@ -151,7 +151,7 @@ def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
         return message_history
 
     def completion(
-        self, prompt: str | dict[str, Any], root_prompt: str | None = None
+        self, prompt: str | dict[str, Any], root_prompt: str | None = None, **kwargs
     ) -> RLMChatCompletion:
         """
         Recursive Language Model completion call. This is the main entry point for querying an RLM, and
@@ -163,14 +163,15 @@ def completion(
             prompt: A single string or dictionary of messages to pass as context to the model.
             root_prompt: We allow the RLM's root LM to see a (small) prompt that the user specifies. A common example of this
                 is if the user is asking the RLM to answer a question, we can pass the question as the root prompt.
+            **kwargs: Optional kwargs to pass to root LM completion calls (e.g., max_tokens, temperature).
         Returns:
             A final answer as a string.
         """
         time_start = time.perf_counter()
 
         # If we're at max depth, the RLM is an LM, so we fallback to the regular LM.
         if self.depth >= self.max_depth:
-            return self._fallback_answer(prompt)
+            return self._fallback_answer(prompt, **kwargs)
 
         with self._spawn_completion_context(prompt) as (lm_handler, environment):
             message_history = self._setup_prompt(prompt)
@@ -183,6 +184,7 @@ def completion(
                 prompt=current_prompt,
                 lm_handler=lm_handler,
                 environment=environment,
+                **kwargs,
             )
 
             # Check if RLM is done and has a final answer.
@@ -219,7 +221,7 @@ def completion(
 
         # Default behavior: we run out of iterations, provide one final answer
         time_end = time.perf_counter()
-        final_answer = self._default_answer(message_history, lm_handler)
+        final_answer = self._default_answer(message_history, lm_handler, **kwargs)
         usage = lm_handler.get_usage_summary()
         self.verbose.print_final_answer(final_answer)
         self.verbose.print_summary(self.max_iterations, time_end - time_start, usage.to_dict())
@@ -238,13 +240,14 @@ def _completion_turn(
         prompt: str | dict[str, Any],
         lm_handler: LMHandler,
         environment: BaseEnv,
+        **kwargs,
     ) -> RLMIteration:
         """
         Perform a single iteration of the RLM, including prompting the model
         and code execution + tool execution.
         """
         iter_start = time.perf_counter()
-        response = lm_handler.completion(prompt)
+        response = lm_handler.completion(prompt, **kwargs)
         code_block_strs = find_code_blocks(response)
         code_blocks = []
 
@@ -260,7 +263,9 @@ def _completion_turn(
             iteration_time=iteration_time,
         )
 
-    def _default_answer(self, message_history: list[dict[str, Any]], lm_handler: LMHandler) -> str:
+    def _default_answer(
+        self, message_history: list[dict[str, Any]], lm_handler: LMHandler, **kwargs
+    ) -> str:
         """
         Default behavior if the RLM runs out of iterations and does not find a final answer.
         It will take the message history, and try to generate a final answer from it.
@@ -271,7 +276,7 @@ def _default_answer(self, message_history: list[dict[str, Any]], lm_handler: LMHandler) -> str:
                 "content": "Please provide a final answer to the user's question based on the information provided.",
             }
         ]
-        response = lm_handler.completion(current_prompt)
+        response = lm_handler.completion(current_prompt, **kwargs)
 
         if self.logger:
             self.logger.log(
@@ -285,10 +290,10 @@ def _default_answer(self, message_history: list[dict[str, Any]], lm_handler: LMHandler) -> str:
 
         return response
 
-    def _fallback_answer(self, message: str | dict[str, Any]) -> str:
+    def _fallback_answer(self, message: str | dict[str, Any], **kwargs) -> str:
         """
         Fallback behavior if the RLM is actually at max depth, and should be treated as an LM.
         """
         client: BaseLM = get_client(self.backend, self.backend_kwargs)
-        response = client.completion(message)
+        response = client.completion(message, **kwargs)
         return response
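End to end, the sampling kwargs now thread from RLM.completion through the handler, the default-answer path, and the max-depth fallback. A sketch, with the class name and constructor arguments assumed (only the completion signature comes from this diff):

    from rlm.core.rlm import RLM  # assumed import path for the class in rlm/core/rlm.py

    long_document = "...paste or load a very long context here..."
    rlm = RLM(backend="openai", backend_kwargs={"model_name": "gpt-4o-mini"})  # assumed constructor
    result = rlm.completion(
        long_document,
        root_prompt="What are the key findings?",
        temperature=0.0,   # forwarded to every root-LM completion call
        max_tokens=256,    # including the fallback and default-answer paths
    )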
23 changes: 17 additions & 6 deletions rlm/environments/docker_repl.py
@@ -55,7 +55,9 @@ def _handle_single(self, body: dict) -> dict:
         if not self.lm_handler_address:
             return {"error": "No LM handler configured"}
 
-        request = LMRequest(prompt=body.get("prompt"), model=body.get("model"))
+        request = LMRequest(
+            prompt=body.get("prompt"), model=body.get("model"), kwargs=body.get("kwargs")
+        )
         response = send_lm_request(self.lm_handler_address, request)
 
         if not response.success:
@@ -72,7 +74,10 @@ def _handle_batched(self, body: dict) -> dict:
 
         prompts = body.get("prompts", [])
         responses = send_lm_request_batched(
-            self.lm_handler_address, prompts, model=body.get("model")
+            self.lm_handler_address,
+            prompts,
+            model=body.get("model"),
+            kwargs=body.get("kwargs"),
         )
 
         results = []
@@ -102,17 +107,23 @@ def _build_exec_script(code: str, proxy_port: int) -> str:
 PROXY = "http://host.docker.internal:{proxy_port}"
 STATE = "/workspace/state.dill"
 
-def llm_query(prompt, model=None):
+def llm_query(prompt, model=None, **kwargs):
     try:
-        r = requests.post(f"{{PROXY}}/llm_query", json={{"prompt": prompt, "model": model}}, timeout=300)
+        payload = {{"prompt": prompt, "model": model}}
+        if kwargs:
+            payload["kwargs"] = kwargs
+        r = requests.post(f"{{PROXY}}/llm_query", json=payload, timeout=300)
         d = r.json()
         return d.get("response") or f"Error: {{d.get('error')}}"
     except Exception as e:
         return f"Error: {{e}}"
 
-def llm_query_batched(prompts, model=None):
+def llm_query_batched(prompts, model=None, **kwargs):
     try:
-        r = requests.post(f"{{PROXY}}/llm_query_batched", json={{"prompts": prompts, "model": model}}, timeout=300)
+        payload = {{"prompts": prompts, "model": model}}
+        if kwargs:
+            payload["kwargs"] = kwargs
+        r = requests.post(f"{{PROXY}}/llm_query_batched", json=payload, timeout=300)
         d = r.json()
         return d.get("responses") or [f"Error: {{d.get('error')}}"] * len(prompts)
     except Exception as e:
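Inside the container, code executed by the REPL can now pass sampling options through the injected helpers; they travel to the proxy under a "kwargs" key and end up on the handler's completion call. A sketch of what generated code might do (the helpers exist only inside the exec script above, and the model name is a placeholder):

    # Runs inside the Docker REPL, where llm_query / llm_query_batched are defined
    # by the injected script above.
    summary = llm_query(
        "Summarize the observations so far in two sentences.",
        model="gpt-4o-mini",
        max_tokens=64,      # sent to the proxy as {"kwargs": {"max_tokens": 64}}
    )

    labels = llm_query_batched(
        ["Is chunk 1 relevant?", "Is chunk 2 relevant?"],
        model="gpt-4o-mini",
        temperature=0.0,
    )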
10 changes: 6 additions & 4 deletions rlm/environments/local_repl.py
@@ -166,18 +166,19 @@ def _final_var(self, variable_name: str) -> str:
             return str(self.locals[variable_name])
         return f"Error: Variable '{variable_name}' not found"
 
-    def _llm_query(self, prompt: str, model: str | None = None) -> str:
+    def _llm_query(self, prompt: str, model: str | None = None, **kwargs) -> str:
         """Query the LM via socket connection to the handler.
 
         Args:
             prompt: The prompt to send to the LM.
             model: Optional model name to use (if handler has multiple clients).
+            **kwargs: Optional kwargs to pass to the completion call (e.g., max_tokens, temperature).
         """
         if not self.lm_handler_address:
             return "Error: No LM handler configured"
 
         try:
-            request = LMRequest(prompt=prompt, model=model)
+            request = LMRequest(prompt=prompt, model=model, kwargs=kwargs if kwargs else None)
             response = send_lm_request(self.lm_handler_address, request)
 
             if not response.success:
@@ -192,12 +193,13 @@ def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> list[str]:
         except Exception as e:
             return f"Error: LM query failed - {e}"
 
-    def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> list[str]:
+    def _llm_query_batched(self, prompts: list[str], model: str | None = None, **kwargs) -> list[str]:
         """Query the LM with multiple prompts concurrently.
 
         Args:
             prompts: List of prompts to send to the LM.
             model: Optional model name to use (if handler has multiple clients).
+            **kwargs: Optional kwargs to pass to the completion calls (e.g., max_tokens, temperature).
 
         Returns:
             List of responses in the same order as input prompts.
@@ -206,7 +208,7 @@ def _llm_query_batched(self, prompts: list[str], model: str | None = None) -> list[str]:
             return ["Error: No LM handler configured"] * len(prompts)
 
         try:
-            responses = send_lm_request_batched(self.lm_handler_address, prompts, model=model)
+            responses = send_lm_request_batched(self.lm_handler_address, prompts, model=model, kwargs=kwargs if kwargs else None)
 
             results = []
             for response in responses:
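The batched path in the local REPL bottoms out in send_lm_request_batched, which now carries the kwargs dict alongside the prompts. A sketch using only names from this diff; the handler address is a placeholder:

    from rlm.core.comms_utils import send_lm_request_batched

    responses = send_lm_request_batched(
        ("127.0.0.1", 5555),                 # placeholder (host, port) of the running LM handler
        ["Label: 'win $$$ now'", "Label: 'meeting moved to 3pm'"],
        model="gpt-4o-mini",
        kwargs={"max_tokens": 8, "temperature": 0.0},
    )
    for r in responses:                      # one LMResponse per prompt, in order
        print(r.success)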