112 changes: 101 additions & 11 deletions rlm/core/rlm.py
@@ -13,7 +13,7 @@
     RLMIteration,
     RLMMetadata,
 )
-from rlm.environments import BaseEnv, get_environment
+from rlm.environments import BaseEnv, SupportsPersistence, get_environment
 from rlm.logger import RLMLogger, VerbosePrinter
 from rlm.utils.parsing import (
     find_code_blocks,
@@ -51,6 +51,7 @@ def __init__(
         other_backend_kwargs: list[dict[str, Any]] | None = None,
         logger: RLMLogger | None = None,
         verbose: bool = False,
+        persistent: bool = False,
     ):
         """
         Args:
@@ -66,6 +67,7 @@
             other_backend_kwargs: The kwargs to pass to the other client backends (ordered to match other_backends).
             logger: The logger to use for the RLM.
             verbose: Whether to print verbose output in rich to console.
+            persistent: If True, reuse the environment across completion() calls for multi-turn conversations.
         """
         # Store config for spawning per-completion
         self.backend = backend
@@ -84,6 +86,14 @@ def __init__(
         self.logger = logger
         self.verbose = VerbosePrinter(enabled=verbose)
 
+        # Persistence support
+        self.persistent = persistent
+        self._persistent_env: SupportsPersistence | None = None
+
+        # Validate persistence support at initialization
+        if self.persistent:
+            self._validate_persistent_environment_support()
+
         # Log metadata if logger is provided
         if self.logger or verbose:
             metadata = RLMMetadata(
@@ -108,7 +118,9 @@ def __init__(
     def _spawn_completion_context(self, prompt: str | dict[str, Any]):
         """
         Spawn an LM handler and environment for a single completion call.
-        Cleans up both when the context exits.
+
+        When persistent=True, the environment is reused across calls.
+        When persistent=False (default), a fresh environment is created for each call.
         """
         # Create client and wrap in handler
         client: BaseLM = get_client(self.backend, self.backend_kwargs)
@@ -122,20 +134,32 @@ def _spawn_completion_context(self, prompt: str | dict[str, Any]):
 
         lm_handler.start()
 
-        # Pass handler address to environment so it can make llm_query() calls
-        env_kwargs = self.environment_kwargs.copy()
-        env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
-        env_kwargs["context_payload"] = prompt
+        # Environment: reuse if persistent, otherwise create fresh
+        if self.persistent and self._persistent_env is not None:
+            environment = self._persistent_env
+            # Defensive check: ensure environment supports persistence methods
+            if not self._env_supports_persistence(environment):
+                raise RuntimeError(
+                    f"Persistent environment of type '{type(environment).__name__}' does not "
+                    f"implement required methods (update_handler_address, add_context, get_context_count). "
+                    f"This should have been caught at initialization."
+                )
+            environment.update_handler_address((lm_handler.host, lm_handler.port))
+            environment.add_context(prompt)
+        else:
+            env_kwargs = self.environment_kwargs.copy()
+            env_kwargs["lm_handler_address"] = (lm_handler.host, lm_handler.port)
+            env_kwargs["context_payload"] = prompt
+            environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
 
-        # Initialize the environment
-        environment: BaseEnv = get_environment(self.environment_type, env_kwargs)
+        if self.persistent:
+            self._persistent_env = environment
 
         try:
             yield lm_handler, environment
         finally:
             # Cleanup
             lm_handler.stop()
-            if hasattr(environment, "cleanup"):
+            if not self.persistent and hasattr(environment, "cleanup"):
                 environment.cleanup()
 
     def _setup_prompt(self, prompt: str | dict[str, Any]) -> list[dict[str, Any]]:
@@ -177,7 +201,19 @@ def completion(
 
         for i in range(self.max_iterations):
             # Current prompt = message history + additional prompt suffix
-            current_prompt = message_history + [build_user_prompt(root_prompt, i)]
+            context_count = (
+                environment.get_context_count()
+                if isinstance(environment, SupportsPersistence)
+                else 1
+            )
+            history_count = (
+                environment.get_history_count()
+                if isinstance(environment, SupportsPersistence)
+                else 0
+            )
+            current_prompt = message_history + [
+                build_user_prompt(root_prompt, i, context_count, history_count)
+            ]
 
             iteration: RLMIteration = self._completion_turn(
                 prompt=current_prompt,
@@ -201,6 +237,11 @@ def completion(
             usage = lm_handler.get_usage_summary()
             self.verbose.print_final_answer(final_answer)
             self.verbose.print_summary(i + 1, time_end - time_start, usage.to_dict())
+
+            # Store message history in persistent environment
+            if self.persistent and isinstance(environment, SupportsPersistence):
+                environment.add_history(message_history)
+
             return RLMChatCompletion(
                 root_model=self.backend_kwargs.get("model_name", "unknown")
                 if self.backend_kwargs
@@ -223,6 +264,11 @@ def completion(
         usage = lm_handler.get_usage_summary()
         self.verbose.print_final_answer(final_answer)
         self.verbose.print_summary(self.max_iterations, time_end - time_start, usage.to_dict())
+
+        # Store message history in persistent environment
+        if self.persistent and isinstance(environment, SupportsPersistence):
+            environment.add_history(message_history)
+
         return RLMChatCompletion(
             root_model=self.backend_kwargs.get("model_name", "unknown")
             if self.backend_kwargs
@@ -292,3 +338,47 @@ def _fallback_answer(self, message: str | dict[str, Any]) -> str:
         client: BaseLM = get_client(self.backend, self.backend_kwargs)
         response = client.completion(message)
         return response
+
+    def _validate_persistent_environment_support(self) -> None:
+        """
+        Validate that the configured environment type supports persistent mode.
+
+        Persistent mode requires environments to implement:
+        - update_handler_address(address): Update LM handler address between calls
+        - add_context(payload, index): Add new context for multi-turn conversations
+        - get_context_count(): Return the number of loaded contexts
+
+        Currently only 'local' (LocalREPL) supports these methods.
+
+        Raises:
+            ValueError: If the environment type does not support persistent mode.
+        """
+        # Known environments that support persistence
+        persistent_supported_environments = {"local"}
+
+        if self.environment_type not in persistent_supported_environments:
+            raise ValueError(
+                f"persistent=True is not supported for environment type '{self.environment_type}'. "
+                f"Persistent mode requires environments that implement update_handler_address(), "
+                f"add_context(), and get_context_count(). "
+                f"Supported environments: {sorted(persistent_supported_environments)}"
+            )
+
+    @staticmethod
+    def _env_supports_persistence(env: BaseEnv) -> bool:
+        """Check if an environment instance supports persistent mode methods."""
+        return isinstance(env, SupportsPersistence)
+
+    def close(self) -> None:
+        """Clean up persistent environment. Call when done with multi-turn conversations."""
+        if self._persistent_env is not None:
+            if hasattr(self._persistent_env, "cleanup"):
+                self._persistent_env.cleanup()
+            self._persistent_env = None
+
+    def __enter__(self) -> "RLM":
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
+        self.close()
+        return False
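
The diff above only shows the mechanics; a usage sketch of the intended multi-turn flow follows. This is a hedged example: the import path, backend name, and constructor kwargs are assumptions based on the signatures visible in this diff, not confirmed API.

```python
# Minimal sketch of persistent multi-turn usage (assumed import path and kwargs).
from rlm.core.rlm import RLM

# The context manager guarantees close() runs, cleaning up the persistent
# environment once the conversation is over.
with RLM(
    backend="openai",                              # hypothetical backend name
    backend_kwargs={"model_name": "gpt-4o-mini"},  # hypothetical model
    environment="local",  # only 'local' passes _validate_persistent_environment_support()
    persistent=True,
) as rlm:
    first = rlm.completion("Summarize the attached report.")      # stored as context_0
    second = rlm.completion("Now compare it with last quarter.")  # stored as context_1
    # After each turn, the message history is stored in the environment
    # (history_0, history_1, ...) via add_history().
```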
4 changes: 3 additions & 1 deletion rlm/environments/__init__.py
@@ -1,8 +1,10 @@
 from typing import Any, Literal
 
-from rlm.environments.base_env import BaseEnv
+from rlm.environments.base_env import BaseEnv, SupportsPersistence
 from rlm.environments.local_repl import LocalREPL
 
+__all__ = ["BaseEnv", "LocalREPL", "SupportsPersistence", "get_environment"]
+
 
 def get_environment(
     environment: Literal["local", "modal", "docker"],
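
Since SupportsPersistence is now exported from the package root, call sites can combine it with get_environment(); a small sketch, assuming the kwargs dict is passed positionally as in rlm.py above:

```python
from rlm.environments import SupportsPersistence, get_environment

# 'local' maps to LocalREPL, which implements the persistence protocol.
env = get_environment("local", {"lm_handler_address": ("127.0.0.1", 8000)})
assert isinstance(env, SupportsPersistence)
```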
121 changes: 116 additions & 5 deletions rlm/environments/base_env.py
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Any, Protocol, runtime_checkable
 
 from rlm.core.types import REPLResult
 
@@ -9,7 +10,8 @@ class BaseEnv(ABC):
     where isolated environments are on a separate machine from the LM.
     """
 
-    def __init__(self, **kwargs):
+    def __init__(self, persistent: bool = False, **kwargs):
+        self.persistent = persistent
         self.kwargs = kwargs
 
     @abstractmethod
@@ -31,8 +33,8 @@ class IsolatedEnv(BaseEnv, ABC):
     guaranteeing complete isolation from the LM process.
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, persistent: bool = False, **kwargs):
+        super().__init__(persistent=persistent, **kwargs)
 
     @abstractmethod
     def setup(self):
@@ -54,8 +56,8 @@ class NonIsolatedEnv(BaseEnv, ABC):
     as a subprocess.
     """
 
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
+    def __init__(self, persistent: bool = False, **kwargs):
+        super().__init__(persistent=persistent, **kwargs)
 
     @abstractmethod
     def setup(self):
@@ -68,3 +70,112 @@ def load_context(self, context_payload: dict | list | str):
     @abstractmethod
     def execute_code(self, code: str) -> REPLResult:
         raise NotImplementedError
+
+
+@runtime_checkable
+class SupportsPersistence(Protocol):
+    """Protocol for environments that support persistent multi-turn sessions.
+
+    CHECKING SUPPORT:
+        Use isinstance(env, SupportsPersistence) to check if an environment
+        supports persistence capabilities.
+
+    IMPLEMENTING THIS PROTOCOL:
+        To add persistence to your environment, implement these 5 methods.
+        See tests/test_local_repl_persistent.py for expected behavior.
+
+    VERSIONING BEHAVIOR:
+        Contexts and histories are versioned with numeric suffixes:
+        - Contexts  -> context_0, context_1, context_2, ...
+        - Histories -> history_0, history_1, history_2, ...
+
+    ALIASING BEHAVIOR:
+        The unversioned names always point to index 0:
+        - context -> context_0 (first context)
+        - history -> history_0 (first history)
+
+    EXAMPLE IMPLEMENTATION:
+        See rlm/environments/local_repl.py for a complete reference.
+
+    TESTS:
+        - Unit tests: tests/test_local_repl_persistent.py
+        - Integration tests: tests/test_multi_turn_integration.py
+
+        Run: uv run pytest tests/test_local_repl_persistent.py -v
+    """
+
+    def update_handler_address(self, address: tuple[str, int]) -> None:
+        """Update the LM handler address for nested LLM calls.
+
+        Called by RLM when the handler address changes between completions.
+        Store the address so llm_query() calls from executed code can reach
+        the LM handler.
+
+        Args:
+            address: (host, port) tuple for the LM handler server.
+        """
+        ...
+
+    def add_context(
+        self, context_payload: dict | list | str, context_index: int | None = None
+    ) -> int:
+        """Add a context payload, making it available as context_N in code.
+
+        Versioning:
+            - context_index=None: auto-increment (0, 1, 2, ...)
+            - context_index=N: use specific index N
+
+        Storage:
+            Must store so executed code can access:
+            - context_0, context_1, etc. (versioned)
+            - context (alias to context_0)
+
+        Args:
+            context_payload: The context data (string, dict, or list).
+            context_index: Optional specific index, or None to auto-increment.
+
+        Returns:
+            The index used (for auto-increment, returns the assigned index).
+        """
+        ...
+
+    def get_context_count(self) -> int:
+        """Return the number of contexts added so far.
+
+        Used by RLM to inform the model how many contexts are available.
+        """
+        ...
+
+    def add_history(
+        self, message_history: list[dict[str, Any]], history_index: int | None = None
+    ) -> int:
+        """Add a message history, making it available as history_N in code.
+
+        Versioning:
+            - history_index=None: auto-increment (0, 1, 2, ...)
+            - history_index=N: use specific index N
+
+        Storage:
+            Must store so executed code can access:
+            - history_0, history_1, etc. (versioned)
+            - history (alias to history_0)
+
+        IMPORTANT: Store a deep copy, not a reference. The caller may
+        modify the list after calling this method.
+
+        Args:
+            message_history: List of message dicts (role, content).
+            history_index: Optional specific index, or None to auto-increment.
+
+        Returns:
+            The index used.
+        """
+        ...
+
+    def get_history_count(self) -> int:
+        """Return the number of histories added so far.
+
+        Used by RLM to inform the model how many conversation histories
+        are available.
+        """
+        ...
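
To make the protocol contract concrete, here is a minimal in-memory implementation sketch that passes the isinstance() check (possible because the protocol is @runtime_checkable, which matches on method names). The class name is hypothetical; a real environment would also expose the stored values to executed code as context_N/history_N variables, as LocalREPL does.

```python
import copy
from typing import Any

from rlm.environments import SupportsPersistence


class InMemorySession:  # hypothetical illustration, not part of the library
    def __init__(self) -> None:
        self.address: tuple[str, int] | None = None
        self.contexts: dict[int, Any] = {}
        self.histories: dict[int, list[dict[str, Any]]] = {}

    def update_handler_address(self, address: tuple[str, int]) -> None:
        self.address = address  # routes llm_query() calls from executed code

    def add_context(self, context_payload, context_index: int | None = None) -> int:
        # None means auto-increment; "context" would alias index 0
        idx = len(self.contexts) if context_index is None else context_index
        self.contexts[idx] = context_payload
        return idx

    def get_context_count(self) -> int:
        return len(self.contexts)

    def add_history(self, message_history, history_index: int | None = None) -> int:
        idx = len(self.histories) if history_index is None else history_index
        # Deep copy, per the protocol docstring: the caller may mutate the list.
        self.histories[idx] = copy.deepcopy(message_history)
        return idx

    def get_history_count(self) -> int:
        return len(self.histories)


assert isinstance(InMemorySession(), SupportsPersistence)
```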
13 changes: 9 additions & 4 deletions rlm/environments/docker_repl.py
@@ -180,9 +180,14 @@ def __init__(
         lm_handler_address: tuple[str, int] | None = None,
         context_payload: dict | list | str | None = None,
         setup_code: str | None = None,
+        persistent: bool = False,
         **kwargs,
     ):
-        super().__init__(**kwargs)
+        if persistent:
+            raise NotImplementedError(
+                "Persistent REPLs are currently not supported for environment: DockerREPL"
+            )
+        super().__init__(persistent=persistent, **kwargs)
 
         self.image = image
         self.lm_handler_address = lm_handler_address
@@ -292,13 +297,13 @@ def execute_code(self, code: str) -> REPLResult:
         )
 
     def cleanup(self):
-        if self.container_id:
+        if hasattr(self, "container_id") and self.container_id:
             subprocess.run(["docker", "stop", self.container_id], capture_output=True)
             self.container_id = None
-        if self.proxy_server:
+        if hasattr(self, "proxy_server") and self.proxy_server:
             self.proxy_server.shutdown()
             self.proxy_server = None
-        if os.path.exists(self.temp_dir):
+        if hasattr(self, "temp_dir") and os.path.exists(self.temp_dir):
             import shutil
 
             shutil.rmtree(self.temp_dir, ignore_errors=True)
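
The hasattr() guards pair with the new early raise: if __init__ fails before container_id, proxy_server, or temp_dir are assigned (as it now does for persistent=True), a later cleanup() call on the half-constructed instance should be a no-op rather than an AttributeError. A small sketch of the failure path, with illustrative handling:

```python
from rlm.environments.docker_repl import DockerREPL  # path per this diff

try:
    env = DockerREPL(persistent=True)  # raises before any attribute is set
except NotImplementedError as err:
    print(err)  # "Persistent REPLs are currently not supported ..."
```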