From 83893905559f6d7178558c4d7130661575bf79a3 Mon Sep 17 00:00:00 2001 From: Zachary BENSALEM Date: Tue, 2 Jun 2026 06:47:20 +0200 Subject: [PATCH 1/7] feat: expand RLM browser and trace workflows Add browser-aware document/runtime tooling, Daytona broker and sandbox slot recovery, richer MLflow trace context, and frontend rendering for trajectory/tool state. Cover the review follow-ups around SSRF guarding, broker cooldown retry, browser skill selection, scoped trajectory de-duplication, and safe sandbox reconciliation. --- .gitignore | 2 + docs/how-to-guides/dspy-integration.md | 10 + docs/how-to-guides/mlflow-workflows.md | 26 ++ .../reference/frontend-backend-integration.md | 31 ++ scripts/mlflow_cli.py | 62 +++ scripts/validate_rlm_e2e_trace.py | 16 +- src/fleet_rlm/api/routers/ws/session.py | 9 + src/fleet_rlm/api/routers/ws/stream.py | 81 +++- .../api/runtime_services/chat_persistence.py | 32 +- .../api/runtime_services/chat_runtime.py | 3 + .../api/runtime_services/diagnostics.py | 62 ++- .../api/runtime_services/sandboxes.py | 7 + .../api/runtime_services/session_service.py | 9 +- src/fleet_rlm/api/runtime_services/volumes.py | 3 +- src/fleet_rlm/integrations/daytona/bridge.py | 254 +++++++----- .../integrations/daytona/bridge_callbacks.py | 5 +- .../integrations/daytona/concurrency.py | 53 ++- .../integrations/daytona/file_browser.py | 11 +- .../integrations/daytona/interpreter.py | 1 + src/fleet_rlm/integrations/daytona/runtime.py | 74 +++- .../integrations/daytona/sandbox_executor.py | 99 ++++- .../integrations/daytona/snapshots.py | 166 ++++++++ .../observability/auto_assessment.py | 64 ++- .../observability/mlflow_context.py | 298 +++++++++++++- .../observability/mlflow_runtime.py | 45 +- .../observability/mlflow_traces.py | 90 +++- src/fleet_rlm/runtime/agent/runtime.py | 76 +++- src/fleet_rlm/runtime/agent/signatures.py | 21 +- .../runtime/execution/streaming_events.py | 37 +- src/fleet_rlm/runtime/factory.py | 6 + src/fleet_rlm/runtime/modules/escalating.py | 251 +++++++++++- src/fleet_rlm/runtime/modules/factory.py | 34 +- .../runtime/modules/skill_selection.py | 24 ++ .../runtime/modules/variable_mode.py | 44 +- src/fleet_rlm/runtime/tools/binding.py | 55 ++- src/fleet_rlm/runtime/tools/browser_tools.py | 41 ++ src/fleet_rlm/runtime/tools/document_tools.py | 76 +++- src/fleet_rlm/runtime/tools/registry.py | 1 + .../agent-elements/input/model-picker.tsx | 145 +++++++ .../agent-elements/tools/subagent-tool.tsx | 39 ++ .../agent-elements/tools/tool-renderer.tsx | 18 +- ...space-message-list.agent-elements.test.tsx | 113 ++++- .../transcript/workspace-message-list.tsx | 57 ++- .../workspace-agent-input-bar.tsx | 107 ++++- .../workspace/screen/workspace-screen.tsx | 30 +- .../rlm-api/__tests__/ws-frame-parser.test.ts | 25 ++ .../backend-chat-event-adapter.test.ts | 385 +++++++++++++++--- .../workspace/backend-chat-event-adapter.ts | 238 +++++++++-- .../backend-chat-event-tool-parts.ts | 9 + .../backend-chat-event-trajectory.ts | 41 +- tests/unit/api/test_chat_persistence.py | 32 ++ tests/unit/api/test_chat_runtime.py | 33 ++ tests/unit/api/test_events.py | 30 ++ tests/unit/api/test_runtime_diagnostics.py | 131 ++++++ tests/unit/api/test_sandboxes.py | 41 ++ tests/unit/api/test_session_service.py | 33 ++ tests/unit/api/test_volume_services.py | 22 + tests/unit/cli/test_mlflow_cli.py | 96 ++++- .../integrations/test_daytona_concurrency.py | 45 ++ .../unit/integrations/test_daytona_runtime.py | 209 ++++++++++ .../test_daytona_sandbox_executor.py | 198 +++++++++ .../unit/integrations/test_mlflow_context.py | 360 ++++++++++++++++ tests/unit/integrations/test_mlflow_traces.py | 91 +++++ tests/unit/integrations/test_observability.py | 89 +++- tests/unit/runtime/test_escalating_module.py | 258 +++++++++++- tests/unit/runtime/test_modules.py | 80 ++++ tests/unit/runtime/test_skill_selection.py | 14 + tests/unit/runtime/test_tools.py | 144 +++++++ 68 files changed, 4938 insertions(+), 354 deletions(-) create mode 100644 src/fleet_rlm/runtime/tools/browser_tools.py create mode 100644 src/frontend/src/components/agent-elements/input/model-picker.tsx create mode 100644 src/frontend/src/components/agent-elements/tools/subagent-tool.tsx create mode 100644 tests/unit/api/test_chat_runtime.py create mode 100644 tests/unit/api/test_session_service.py create mode 100644 tests/unit/api/test_volume_services.py create mode 100644 tests/unit/integrations/test_daytona_sandbox_executor.py create mode 100644 tests/unit/integrations/test_mlflow_context.py create mode 100644 tests/unit/integrations/test_mlflow_traces.py diff --git a/.gitignore b/.gitignore index 0159f3428..56f9834f4 100644 --- a/.gitignore +++ b/.gitignore @@ -54,6 +54,8 @@ docs/superpowers/ docs/plans/ docs/archive/ +qa-results/ + plan/ # Frontend diff --git a/docs/how-to-guides/dspy-integration.md b/docs/how-to-guides/dspy-integration.md index 3c8eac9f5..6fc4ed6f2 100644 --- a/docs/how-to-guides/dspy-integration.md +++ b/docs/how-to-guides/dspy-integration.md @@ -468,6 +468,16 @@ When MLflow is enabled, RLM execution automatically captures: - Reasoning trajectories - Timing and token usage +For variable-mode `dspy.RLM` runs, Fleet records each REPL/code trajectory step +as a child MLflow `TOOL` span named `repl_execute`, with bounded inputs and +outputs. This keeps the MLflow trace tree aligned with the compact trace rows +rendered in the chat surface. + +Fleet also records an `rlm_available_tools` `LLM` span that advertises the +`repl_execute` schema through `mlflow.chat.tools`. MLflow tool-call judges use +that schema as the available-tool set and then evaluate the concrete +`repl_execute` `TOOL` spans as the calls that actually ran. + ### Optimize with MIPROv2 Use collected traces for DSPy optimization: diff --git a/docs/how-to-guides/mlflow-workflows.md b/docs/how-to-guides/mlflow-workflows.md index ac99a910b..a3b62ed1d 100644 --- a/docs/how-to-guides/mlflow-workflows.md +++ b/docs/how-to-guides/mlflow-workflows.md @@ -99,6 +99,21 @@ Inspect active scorers before debugging unexpected assessment warnings: uv run python scripts/mlflow_cli.py scorers list ``` +Stop a stale scheduled scorer without deleting its registration: + +```bash +# from repo root +uv run python scripts/mlflow_cli.py scorers stop --name "Trace Judge" +``` + +Restart a stopped scorer only after confirming its judge model and trace inputs +are correct: + +```bash +# from repo root +uv run python scripts/mlflow_cli.py scorers start --name "Trace Judge" --sample-rate 1.0 +``` + Remove a stale scorer only as an explicit maintenance action: ```bash @@ -127,6 +142,17 @@ uv run fleet web ``` As you use the app, MLflow traces are recorded in the configured experiment. +For RLM document-analysis and variable-mode runs, Fleet also materializes +trajectory code execution as MLflow `TOOL` spans named `repl_execute`. Those +spans include bounded `mlflow.spanInputs` / `mlflow.spanOutputs` payloads so the +MLflow trace tree, external scorers, and the chat transcript all describe the +same REPL actions. + +Fleet also emits a compact `rlm_available_tools` `LLM` span with +`mlflow.chat.tools` metadata for the RLM REPL. MLflow's built-in tool-call +judges read available tool schemas from `LLM` / `CHAT_MODEL` spans and read +actual calls from `TOOL` spans; keeping both in the trace prevents judges from +falling back to model-based tool extraction. ## 4. Record Human Feedback and Ground Truth diff --git a/docs/reference/frontend-backend-integration.md b/docs/reference/frontend-backend-integration.md index 9c36b50e0..2214fcc65 100644 --- a/docs/reference/frontend-backend-integration.md +++ b/docs/reference/frontend-backend-integration.md @@ -124,6 +124,26 @@ The frontend keeps the following runtime controls aligned with backend requests: - `context_paths` - `batch_concurrency` +When `execution_mode` is `auto`, prompts that combine a public HTTP(S) URL +with documentation-analysis intent (`analyze`, `summarize`, `read`, `docs`, or +`documentation`) route directly to the Daytona-backed RLM document path. The +backend fetches the document through the redirect-validating document helpers +and passes `source_url`, `document_text`, and `source_metadata` as separate +variable-mode `dspy.RLM` inputs. That keeps large documentation bodies in REPL +variables instead of folding them into the prompt text. +`execution_mode="rlm_only"` still forces RLM execution, while +`execution_mode="tools_only"` bypasses the automatic URL-to-RLM route. + +Fleet's RLM prompt envelope follows the Fast-RLM usage pattern for large +variable-mode tasks: repeat the task at the top and bottom, keep bulk data in +REPL variables, make available tools ordinary Python callables, and keep +intermediate printed output bounded. The server runtime settings feed the chat +agent's RLM wrappers directly: + +- `rlm_max_iterations` -> `dspy.RLM(max_iterations=...)` +- `rlm_max_llm_calls` -> `dspy.RLM(max_llm_calls=...)` +- `agent_max_output_chars` -> `dspy.RLM(max_output_chars=...)` + The backend enriches frames with runtime context. The frontend treats these keys as stable when present: @@ -138,6 +158,11 @@ as stable when present: - `sandbox_id` - `workspace_path` - `sandbox_transition` +- `selected_skills` +- `routing_decision` +- `source_url` +- `trajectory_index` +- `rlm_limits` ### Transcript Stream @@ -148,9 +173,15 @@ The frontend reduces frames into: - user and assistant messages - reasoning and trajectory rows - tool and sandbox cards +- selected-skill and routing status rows - HITL / clarification cards - summary rows and warnings +RLM trajectories that include `{reasoning, code, output}` are normalized into +`execution_step` frames with `step.type="repl"`. The transcript renders these +as compact expandable sandbox rows; large code/output payloads stay summarized +in chat while the workbench receives the structured step payload. + The adapter stack is: 1. `ws-frame-parser.ts` normalizes raw websocket frames. diff --git a/scripts/mlflow_cli.py b/scripts/mlflow_cli.py index 13b1a2a10..faaf64fe9 100755 --- a/scripts/mlflow_cli.py +++ b/scripts/mlflow_cli.py @@ -129,6 +129,56 @@ def _delete_scorer(delete_scorer: Any, *, name: str, experiment_id: str | None, delete_scorer(name) +def _get_scorer(mlflow: Any, *, name: str, experiment_id: str | None, version: str | None = None) -> Any: + get_scorer = getattr(getattr(mlflow, "genai", None), "get_scorer", None) + if not callable(get_scorer): + raise RuntimeError("mlflow.genai.get_scorer is not available in this MLflow version.") + + parameters = inspect.signature(get_scorer).parameters + kwargs: dict[str, Any] = {"name": name} + if "experiment_id" in parameters: + kwargs["experiment_id"] = experiment_id + if "version" in parameters and version: + kwargs["version"] = int(version) if str(version).isdigit() else version + return get_scorer(**kwargs) + + +def do_scorers_stop(args: argparse.Namespace) -> int: + mlflow, config, active_experiment_id = _configure_mlflow_tracking() + experiment_id = args.experiment_id or active_experiment_id + scorer = _get_scorer(mlflow, name=args.name, experiment_id=experiment_id) + stop_scorer = getattr(scorer, "stop", None) + if not callable(stop_scorer): + raise RuntimeError("This MLflow scorer does not expose stop().") + stop_scorer(name=args.name, experiment_id=experiment_id) + print(f"stopped_scorer={args.name}") + print(f"experiment={config.experiment}") + print(f"experiment_id={experiment_id or ''}") + return 0 + + +def do_scorers_start(args: argparse.Namespace) -> int: + mlflow, config, active_experiment_id = _configure_mlflow_tracking() + experiment_id = args.experiment_id or active_experiment_id + scorer = _get_scorer(mlflow, name=args.name, experiment_id=experiment_id) + start_scorer = getattr(scorer, "start", None) + if not callable(start_scorer): + raise RuntimeError("This MLflow scorer does not expose start().") + + from mlflow.genai.scorers import ScorerSamplingConfig + + start_scorer( + name=args.name, + experiment_id=experiment_id, + sampling_config=ScorerSamplingConfig(sample_rate=args.sample_rate, filter_string=args.filter_string), + ) + print(f"started_scorer={args.name}") + print(f"sample_rate={args.sample_rate}") + print(f"experiment={config.experiment}") + print(f"experiment_id={experiment_id or ''}") + return 0 + + def do_scorers_delete(args: argparse.Namespace) -> int: if not args.yes: print("Refusing to delete scorer without --yes.") @@ -199,6 +249,18 @@ def main() -> int: psl.add_argument("--experiment-id", default=None) psl.set_defaults(func=do_scorers_list) + pss = scorer_subparsers.add_parser("stop", help="Stop a persisted scorer schedule without deleting it") + pss.add_argument("--name", required=True) + pss.add_argument("--experiment-id", default=None) + pss.set_defaults(func=do_scorers_stop) + + psr = scorer_subparsers.add_parser("start", help="Start or resume a persisted scorer schedule") + psr.add_argument("--name", required=True) + psr.add_argument("--experiment-id", default=None) + psr.add_argument("--sample-rate", type=float, default=1.0) + psr.add_argument("--filter-string", default=None) + psr.set_defaults(func=do_scorers_start) + psd = scorer_subparsers.add_parser("delete", help="Delete a persisted scorer by name") psd.add_argument("--name", required=True) psd.add_argument("--experiment-id", default=None) diff --git a/scripts/validate_rlm_e2e_trace.py b/scripts/validate_rlm_e2e_trace.py index f43ccf742..141fcf9d6 100644 --- a/scripts/validate_rlm_e2e_trace.py +++ b/scripts/validate_rlm_e2e_trace.py @@ -125,6 +125,9 @@ async def _collect_chat_until_terminal( if payload.get("type") == "error": raise RuntimeError(f"Chat websocket error: {payload}") + if payload.get("type") == "execution_completed": + return events, payload + if payload.get("type") != "event": continue kind = payload.get("data", {}).get("kind") @@ -331,8 +334,8 @@ async def _run_validation(args: argparse.Namespace) -> ValidationResult: chat_ws_url = _make_ws_url(args.server_url, "/api/v1/ws/execution") execution_ws_url = _make_ws_url( args.server_url, - "/api/v1/ws/execution", - query=(f"workspace_id={args.workspace_id}&user_id={args.user_id}&session_id={session_id}"), + "/api/v1/ws/execution/events", + query=f"session_id={session_id}", ) async with websockets.connect( @@ -394,9 +397,12 @@ async def _run_validation(args: argparse.Namespace) -> ValidationResult: if event.get("session_id") != session_id: raise RuntimeError("Execution event session_id mismatch.") - terminal_kind = terminal_chat_payload.get("data", {}).get("kind") - if terminal_kind != "final": - raise RuntimeError(f"Terminal chat event kind is {terminal_kind!r}; expected 'final'.") + terminal_kind = terminal_chat_payload.get("data", {}).get("kind") or terminal_chat_payload.get("type") + if terminal_kind not in {"final", "execution_completed"}: + raise RuntimeError( + f"Terminal chat event kind is {terminal_kind!r}; expected 'final' or " + "'execution_completed'." + ) await _persist_artifact_via_command( chat_ws, diff --git a/src/fleet_rlm/api/routers/ws/session.py b/src/fleet_rlm/api/routers/ws/session.py index 6abbb2568..330026f4c 100644 --- a/src/fleet_rlm/api/routers/ws/session.py +++ b/src/fleet_rlm/api/routers/ws/session.py @@ -236,6 +236,15 @@ async def switch_session_if_needed( agent, manifest_path, ) + # Each turn may acquire a different sandbox (pool-based dispatch), + # so the volume on the new sandbox won't have the prior turn's + # manifest. Fall back to the local store when the volume read + # returns nothing. + if not manifest: + manifest = await _restore_manifest_from_local_store( + persistence=persistence, + sess_id=sess_id, + ) else: # No Daytona volume — attempt to restore from local store so that # session history survives process restarts between WS connections. diff --git a/src/fleet_rlm/api/routers/ws/stream.py b/src/fleet_rlm/api/routers/ws/stream.py index 040ff2cc1..9079547fc 100644 --- a/src/fleet_rlm/api/routers/ws/stream.py +++ b/src/fleet_rlm/api/routers/ws/stream.py @@ -9,6 +9,7 @@ from collections.abc import AsyncIterator, Awaitable, Callable from dataclasses import dataclass, field from datetime import datetime, timezone +from types import SimpleNamespace from typing import Any from fastapi import WebSocket, WebSocketDisconnect @@ -20,7 +21,10 @@ from fleet_rlm.integrations.observability.trace_context import ( runtime_telemetry_enabled_context, ) -from fleet_rlm.runtime.execution.streaming_events import is_terminal_stream_event_kind +from fleet_rlm.runtime.execution.streaming_events import ( + _normalize_trajectory, + is_terminal_stream_event_kind, +) from fleet_rlm.utils.logging import sanitize_for_log as _sanitize_for_log from ...dependencies import DiagnosticsDeps, SessionCacheDeps @@ -78,6 +82,36 @@ logger = logging.getLogger(__name__) +def _routing_status_text(payload: dict[str, Any]) -> str: + selected = ", ".join(str(item) for item in payload.get("selected_skills", []) or []) + route = payload.get("routing_decision", "auto") + source = payload.get("source_url") + text = f"Route: {route}" + if selected: + text += f" | skills: {selected}" + if source: + text += f" | source: {source}" + return text + + +def _build_routing_preview_event(agent: ChatAgentProtocol, msg: WSMessage) -> Any | None: + preview_routing = getattr(agent, "preview_routing", None) + if not callable(preview_routing): + return None + payload = preview_routing( + user_request=msg.content, + execution_mode=msg.execution_mode or "auto", + ) + if not isinstance(payload, dict) or not payload.get("routing_decision"): + return None + return SimpleNamespace( + kind="status", + text=_routing_status_text(payload), + payload={**payload, "phase": "routing"}, + timestamp=datetime.now(timezone.utc), + ) + + @dataclass(slots=True) class WorkspaceEvent: """Normalized event shape for websocket streaming.""" @@ -129,6 +163,37 @@ def _runtime_trace_metadata(payload: dict[str, Any] | None) -> dict[str, Any]: runtime = runtime_payload if isinstance(runtime_payload, dict) else {} metadata: dict[str, Any] = {} + for key in ( + "routing_decision", + "source_url", + "execution_mode", + "runtime_module", + ): + value = payload.get(key, runtime.get(key)) + if value not in (None, "", False): + metadata[f"fleet_rlm.{key}"] = value + + selected_skills = payload.get("selected_skills") + if isinstance(selected_skills, list): + metadata["fleet_rlm.selected_skills"] = ",".join(str(item) for item in selected_skills if str(item)) + + trajectory_steps = _normalize_trajectory(payload.get("trajectory")) + if trajectory_steps: + metadata["fleet_rlm.trajectory_steps"] = str(len(trajectory_steps)) + if any(step.get("thought") for step in trajectory_steps): + metadata["fleet_rlm.trajectory_has_reasoning"] = "true" + if any(step.get("tool_name") for step in trajectory_steps): + metadata["fleet_rlm.trajectory_has_tools"] = "true" + if any( + "repl" in str(step.get("tool_name", "")).lower() + or "code" in step + or step.get("type") == "repl" + for step in trajectory_steps + ): + metadata["fleet_rlm.trajectory_has_repl"] = "true" + if any(step.get("output") is not None or step.get("observation") is not None for step in trajectory_steps): + metadata["fleet_rlm.trajectory_has_outputs"] = "true" + for key in ( "runtime_degraded", "runtime_failure_category", @@ -1154,6 +1219,18 @@ async def run(self) -> None: ), }, ) + routing_preview_event = _build_routing_preview_event(self.agent, msg) + if routing_preview_event is not None: + await _try_send_json( + self.websocket, + { + "type": "event", + "data": build_stream_event_dict( + event=routing_preview_event, + payload=routing_preview_event.payload, + ), + }, + ) self.stream_task = asyncio.create_task( _background_execution_task( msg=msg, @@ -1173,6 +1250,8 @@ async def run(self) -> None: cancel_flag=self.session.cancel_flag, local_persist=self.local_persist, lifecycle=self.session.lifecycle, + cancel_active_run=False, + persist_on_disconnect=False, ) except Exception as exc: await handle_chat_loop_exception( diff --git a/src/fleet_rlm/api/runtime_services/chat_persistence.py b/src/fleet_rlm/api/runtime_services/chat_persistence.py index b43892a5f..4f9a9c4c1 100644 --- a/src/fleet_rlm/api/runtime_services/chat_persistence.py +++ b/src/fleet_rlm/api/runtime_services/chat_persistence.py @@ -256,11 +256,19 @@ async def handle_chat_disconnect( cancel_flag: dict[str, bool], local_persist: Callable[..., Awaitable[None]], lifecycle: ExecutionLifecycleManager | None, + cancel_active_run: bool = True, + persist_on_disconnect: bool = True, ) -> None: """Cleanly stop the active websocket loop after a client disconnect.""" - cancel_flag["cancelled"] = True + if cancel_active_run: + cancel_flag["cancelled"] = True await cancel_task(pending_receive_task) - await cancel_task(stream_task) + if cancel_active_run: + await cancel_task(stream_task) + elif stream_task is not None: + stream_task.add_done_callback(_log_background_disconnect_task_result) + if not persist_on_disconnect: + return try: await local_persist( include_volume_save=True, @@ -287,6 +295,16 @@ async def handle_chat_disconnect( await lifecycle.complete_run(RunStatus.CANCELLED) +def _log_background_disconnect_task_result(task: asyncio.Task[Any]) -> None: + """Consume detached execution task failures after the command socket closes.""" + try: + task.result() + except asyncio.CancelledError: + return + except Exception: + logger.warning("Background execution failed after websocket disconnect", exc_info=True) + + # --------------------------------------------------------------------------- # Worker request # --------------------------------------------------------------------------- @@ -1203,9 +1221,13 @@ async def _persist_session_state_impl( "Skipping Daytona volume persistence because cleanup has no active session (path=%s)", active_manifest_path, ) - elif include_volume_save and interpreter is None and persistence is not None: - # No Daytona volume available — fall back to local store so the manifest - # survives process restarts between WebSocket connections. + # Always persist to local store when persistence is available — this is the + # durable fallback that survives sandbox churn. Pool-based dispatch means + # each turn may acquire a *different* Daytona sandbox, so the volume save + # above lands on the current sandbox while the *next* turn's new sandbox + # volume starts empty. The local store is sandbox-independent and bridges + # the gap. We write it regardless of whether a volume save also happened. + if include_volume_save and persistence is not None: sess_id = str(session_record.get("session_id") or "") if sess_id: await _persist_manifest_to_local_store( diff --git a/src/fleet_rlm/api/runtime_services/chat_runtime.py b/src/fleet_rlm/api/runtime_services/chat_runtime.py index c0a45f2a6..2e9643286 100644 --- a/src/fleet_rlm/api/runtime_services/chat_runtime.py +++ b/src/fleet_rlm/api/runtime_services/chat_runtime.py @@ -293,6 +293,9 @@ async def prepare_chat_runtime( def _chat_agent_builder_kwargs(runtime: PreparedChatRuntime) -> dict[str, Any]: return { "react_max_iters": runtime.cfg.react_max_iters, + "rlm_max_iterations": runtime.cfg.rlm_max_iterations, + "rlm_max_llm_calls": runtime.cfg.rlm_max_llm_calls, + "rlm_max_output_chars": runtime.cfg.agent_max_output_chars, "planner_lm": runtime.planner_lm, "delegate_lm": runtime.delegate_lm, "repository": runtime.repository, diff --git a/src/fleet_rlm/api/runtime_services/diagnostics.py b/src/fleet_rlm/api/runtime_services/diagnostics.py index f04f913a4..7ab7aa335 100644 --- a/src/fleet_rlm/api/runtime_services/diagnostics.py +++ b/src/fleet_rlm/api/runtime_services/diagnostics.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import logging import os import time from collections.abc import Awaitable, Callable @@ -13,7 +14,11 @@ from pydantic import ValidationError from fleet_rlm.integrations.daytona import DaytonaConfigError -from fleet_rlm.integrations.daytona.concurrency import get_current_sandbox_usage +from fleet_rlm.integrations.daytona.concurrency import ( + SandboxUsageStats, + get_current_sandbox_usage, + reconcile_sandbox_slots, +) from fleet_rlm.integrations.observability.config import MlflowConfig from ..bootstrap_observability import resolve_mlflow_auto_start_enabled @@ -40,6 +45,8 @@ utc_now_iso, ) +logger = logging.getLogger(__name__) + def resolve_active_model(value: str | None, env_key: str) -> str: direct = (value or "").strip() @@ -67,6 +74,38 @@ def connectivity_result_from_cache( return None +def _status_sandbox_usage() -> SandboxUsageStats: + """Return sandbox slot diagnostics, reconciling obviously stale saturation. + + Normal slot release remains lifecycle-driven. This diagnostic recovery path + only runs when the local semaphore reports at least one active slot; in + that case a provider count is cheap and prevents the Settings/Workbench UI + from showing stale local occupancy after Daytona sandboxes have already + disappeared. + """ + usage = get_current_sandbox_usage() + if usage.active_count <= 0: + return usage + + runtime = None + try: + from fleet_rlm.integrations.daytona.runtime import DaytonaSandboxRuntime + + runtime = DaytonaSandboxRuntime() + provider_active = runtime._count_provider_fleet_sandboxes_sync() + except Exception: + logger.debug("Skipping sandbox slot reconciliation during runtime status", exc_info=True) + return usage + finally: + if runtime is not None: + with suppress(Exception): + runtime.close() + + if provider_active < usage.active_count: + return reconcile_sandbox_slots(provider_active_count=provider_active) + return usage + + def lm_preflight() -> tuple[dict[str, bool], list[str]]: has_model = bool((os.environ.get("DSPY_LM_MODEL") or "").strip()) has_api_key = bool((os.environ.get("DSPY_LLM_API_KEY") or os.environ.get("DSPY_LM_API_KEY") or "").strip()) @@ -358,7 +397,7 @@ def build_runtime_status_response( daytona_checks, daytona_guidance = daytona_preflight( sandbox_provider=config_deps.config.sandbox_provider, ) - sandbox_usage = get_current_sandbox_usage() + sandbox_usage = _status_sandbox_usage() lm_test = connectivity_result_from_cache(diagnostics=diagnostics_deps, kind="lm") daytona_test = connectivity_result_from_cache(diagnostics=diagnostics_deps, kind="daytona") @@ -404,6 +443,22 @@ def build_runtime_status_response( mlflow_enabled=mlflow_cfg.enabled, tracking_uri=mlflow_cfg.tracking_uri, ) + persisted_scorer_names: list[str] = [] + if mlflow_cfg.enabled and not mlflow_cfg.enable_auto_assessment: + try: + from fleet_rlm.integrations.observability.auto_assessment import persisted_scorer_names as _scorer_names + + persisted_scorer_names = _scorer_names(mlflow_cfg) + except Exception: + logger.debug("Failed to inspect MLflow persisted scorers for runtime status.", exc_info=True) + persisted_scorer_names = [] + if persisted_scorer_names: + deduped_guidance.append( + "MLflow has persisted scorer(s) while Fleet auto-assessment is disabled: " + f"{', '.join(persisted_scorer_names)}. These can still assess traces; inspect with " + "`uv run python scripts/mlflow_cli.py scorers list` and stop with " + "`uv run python scripts/mlflow_cli.py scorers stop --name ` if unintended." + ) return RuntimeStatusResponse( app_env=config_deps.config.app_env, @@ -430,6 +485,9 @@ def build_runtime_status_response( mlflow={ "enabled": mlflow_cfg.enabled, "auto_start_enabled": mlflow_auto_start_enabled, + "auto_assessment_enabled": mlflow_cfg.enable_auto_assessment, + "persisted_scorer_count": len(persisted_scorer_names), + "persisted_scorers": persisted_scorer_names, "startup_status": mlflow_startup_status, "startup_error": mlflow_startup_error, }, diff --git a/src/fleet_rlm/api/runtime_services/sandboxes.py b/src/fleet_rlm/api/runtime_services/sandboxes.py index e38eb2711..7a35ae79a 100644 --- a/src/fleet_rlm/api/runtime_services/sandboxes.py +++ b/src/fleet_rlm/api/runtime_services/sandboxes.py @@ -11,6 +11,7 @@ from fleet_rlm.integrations.daytona import config as _daytona_config from fleet_rlm.integrations.daytona import runtime as _daytona_runtime from fleet_rlm.integrations.daytona.async_compat import _run_sync_in_thread +from fleet_rlm.integrations.daytona.concurrency import release_sandbox_slot from fleet_rlm.integrations.observability.sanitization import redact_sensitive from fleet_rlm.utils.sandbox_ownership import ( SANDBOX_OWNER_LABEL, @@ -110,7 +111,10 @@ async def delete_sandbox( sandbox, owner_labels=owner_labels, ) + labels = _sandbox_labels(sandbox) await _management_session(sandbox).adelete() + if labels.get("managed-by") == "fleet-rlm": + release_sandbox_slot() finally: await _close_daytona_client(client) @@ -133,7 +137,10 @@ async def archive_sandbox( sandbox, owner_labels=owner_labels, ) + labels = _sandbox_labels(sandbox) await _management_session(sandbox).aarchive() + if labels.get("managed-by") == "fleet-rlm": + release_sandbox_slot() finally: await _close_daytona_client(client) diff --git a/src/fleet_rlm/api/runtime_services/session_service.py b/src/fleet_rlm/api/runtime_services/session_service.py index a5e26d157..a87243e71 100644 --- a/src/fleet_rlm/api/runtime_services/session_service.py +++ b/src/fleet_rlm/api/runtime_services/session_service.py @@ -147,7 +147,14 @@ def list_session_state( user_id = string_or_default(payload_dict.get("user_id"), "anonymous") session = payload_dict.get("session", {}) session_state = session.get("state", {}) if isinstance(session, Mapping) else {} - history = session_state.get("history", []) if isinstance(session_state, Mapping) else [] + history = [] + if isinstance(session_state, Mapping): + raw_history = session_state.get("history") + raw_turns = session_state.get("turns") + if isinstance(raw_turns, list): + history = raw_turns + elif isinstance(raw_history, list): + history = raw_history documents = session_state.get("documents", {}) if isinstance(session_state, Mapping) else {} summaries.append( SessionStateSummary( diff --git a/src/fleet_rlm/api/runtime_services/volumes.py b/src/fleet_rlm/api/runtime_services/volumes.py index 0476812b5..1c5f8c78c 100644 --- a/src/fleet_rlm/api/runtime_services/volumes.py +++ b/src/fleet_rlm/api/runtime_services/volumes.py @@ -17,6 +17,7 @@ alist_daytona_volumes, aread_daytona_volume_file_text, ) +from fleet_rlm.integrations.daytona.volumes import VFS_CANONICAL_ROOTS from fleet_rlm.utils.identity import sanitize_id as _sanitize_id from ..auth import NormalizedIdentity @@ -33,7 +34,7 @@ VolumeOperation = Callable[[str, str, int], dict[str, Any] | Awaitable[dict[str, Any]]] VolumeTreeOperation = Callable[[str, str, int, int], dict[str, Any] | Awaitable[dict[str, Any]]] -CANONICAL_VOLUME_ROOTS: tuple[str, ...] = ("/memory", "/artifacts", "/buffers", "/meta") +CANONICAL_VOLUME_ROOTS: tuple[str, ...] = tuple(sorted(VFS_CANONICAL_ROOTS)) @dataclass(frozen=True) diff --git a/src/fleet_rlm/integrations/daytona/bridge.py b/src/fleet_rlm/integrations/daytona/bridge.py index 34edc8db5..7d0434095 100644 --- a/src/fleet_rlm/integrations/daytona/bridge.py +++ b/src/fleet_rlm/integrations/daytona/bridge.py @@ -4,10 +4,6 @@ were previously in bridge_assets.py. Uses the synchronous Daytona SDK directly — no async compatibility layer. - -FUTURE: If the Daytona SDK introduces a native callback or event system for -sandbox-to-host communication, this Flask-based broker (~300 LOC) could be -replaced. Monitor the SDK roadmap for webhook/callback infrastructure. """ from __future__ import annotations @@ -15,6 +11,7 @@ import inspect import json import keyword +import logging import time import urllib.error import urllib.request @@ -28,6 +25,8 @@ from .async_compat import _run_sync_in_thread +logger = logging.getLogger(__name__) + # --------------------------------------------------------------------------- # Embedded broker assets # --------------------------------------------------------------------------- @@ -36,122 +35,157 @@ _BROKER_SERVER_PATH = "/home/daytona/broker_server.py" _BROKER_SESSION_COMMAND = f"cd /home/daytona && python {_BROKER_SERVER_PATH.rsplit('/', 1)[-1]}" _BROKER_SERVER_CODE = """ -\"\"\"Broker server for mediating tool calls between sandbox code and the host.\"\"\" +\"\"\"Broker server for mediating tool calls between sandbox code and the host. + +Uses only Python stdlib — no third-party dependencies required. +\"\"\" import json import threading import time import uuid +from http.server import BaseHTTPRequestHandler, HTTPServer +from socketserver import ThreadingMixIn +from urllib.parse import parse_qs, urlparse -from flask import Flask, jsonify, request - -app = Flask(__name__) _lock = threading.Lock() -_pending_requests: dict[str, dict[str, object]] = {} -_results: dict[str, object] = {} - - -@app.route("/health", methods=["GET"]) -def health(): - return jsonify({"status": "ok"}) - - -@app.route("/tool_call", methods=["POST"]) -def tool_call(): - data = request.json or {} - call_id = str(data.get("id") or uuid.uuid4()) - tool_name = str(data.get("tool_name") or "") - args = data.get("args", []) - kwargs = data.get("kwargs", {}) - - with _lock: - _pending_requests[call_id] = { - "tool_name": tool_name, - "args": args if isinstance(args, list) else [], - "kwargs": kwargs if isinstance(kwargs, dict) else {}, - "claimed": False, - "claimed_at": None, - "lease_token": None, - } - - timeout = __DAYTONA_TOOL_CALL_TIMEOUT_S__ - interval = 0.05 - elapsed = 0.0 - while elapsed < timeout: - with _lock: - if call_id in _results: - result = _results.pop(call_id) - _pending_requests.pop(call_id, None) - return jsonify({"result": result}) - time.sleep(interval) - elapsed += interval - - with _lock: - _pending_requests.pop(call_id, None) - return jsonify({"error": "Tool call timeout"}), 504 - - -@app.route("/pending", methods=["GET"]) -def get_pending(): - try: - max_items = int(request.args.get("max", "1")) - except ValueError: - max_items = 1 - max_items = max(1, max_items) - - try: - lease_seconds = float(request.args.get("lease_seconds", "60")) - except ValueError: - lease_seconds = 60.0 - lease_seconds = max(1.0, lease_seconds) - - requests_out = [] - with _lock: - now = time.time() - for call_id, payload in _pending_requests.items(): - if len(requests_out) >= max_items: - break - if call_id in _results: - continue - claimed_at = payload.get("claimed_at") - if payload.get("claimed") and isinstance(claimed_at, (int, float)): - if now - claimed_at < lease_seconds: - continue - claim_token = str(uuid.uuid4()) - payload["claimed"] = True - payload["claimed_at"] = now - payload["lease_token"] = claim_token - requests_out.append( - { - "id": call_id, - "tool_name": payload["tool_name"], - "args": payload["args"], - "kwargs": payload["kwargs"], - "claim_token": claim_token, +_pending_requests: dict = {} +_results: dict = {} + + +def _read_json(handler): + length = int(handler.headers.get("Content-Length", 0)) + return json.loads(handler.rfile.read(length).decode("utf-8")) if length else {} + + +def _send_json(handler, data, status=200): + body = json.dumps(data).encode("utf-8") + handler.send_response(status) + handler.send_header("Content-Type", "application/json") + handler.send_header("Content-Length", str(len(body))) + handler.end_headers() + handler.wfile.write(body) + + +class _BrokerHandler(BaseHTTPRequestHandler): + def log_message(self, format, *args): + pass # suppress default access log + + def do_GET(self): + parsed = urlparse(self.path) + path = parsed.path + + if path == "/health": + _send_json(self, {"status": "ok"}) + + elif path == "/pending": + qs = parse_qs(parsed.query) + try: + max_items = int(qs.get("max", ["1"])[0]) + except ValueError: + max_items = 1 + max_items = max(1, max_items) + try: + lease_seconds = float(qs.get("lease_seconds", ["60"])[0]) + except ValueError: + lease_seconds = 60.0 + lease_seconds = max(1.0, lease_seconds) + + requests_out = [] + with _lock: + now = time.time() + for call_id, payload in _pending_requests.items(): + if len(requests_out) >= max_items: + break + if call_id in _results: + continue + claimed_at = payload.get("claimed_at") + if payload.get("claimed") and isinstance(claimed_at, (int, float)): + if now - claimed_at < lease_seconds: + continue + claim_token = str(uuid.uuid4()) + payload["claimed"] = True + payload["claimed_at"] = now + payload["lease_token"] = claim_token + requests_out.append({ + "id": call_id, + "tool_name": payload["tool_name"], + "args": payload["args"], + "kwargs": payload["kwargs"], + "claim_token": claim_token, + }) + _send_json(self, {"requests": requests_out}) + + else: + _send_json(self, {"error": "not found"}, 404) + + def do_POST(self): + parsed = urlparse(self.path) + path = parsed.path + + if path == "/tool_call": + data = _read_json(self) + call_id = str(data.get("id") or uuid.uuid4()) + tool_name = str(data.get("tool_name") or "") + args = data.get("args", []) + kwargs = data.get("kwargs", {}) + + with _lock: + _pending_requests[call_id] = { + "tool_name": tool_name, + "args": args if isinstance(args, list) else [], + "kwargs": kwargs if isinstance(kwargs, dict) else {}, + "claimed": False, + "claimed_at": None, + "lease_token": None, } - ) - return jsonify({"requests": requests_out}) + timeout = __DAYTONA_TOOL_CALL_TIMEOUT_S__ + interval = 0.05 + elapsed = 0.0 + while elapsed < timeout: + with _lock: + if call_id in _results: + result = _results.pop(call_id) + _pending_requests.pop(call_id, None) + _send_json(self, {"result": result}) + return + time.sleep(interval) + elapsed += interval + + with _lock: + _pending_requests.pop(call_id, None) + _send_json(self, {"error": "Tool call timeout"}, 504) + + elif path.startswith("/result/"): + call_id = path[len("/result/"):] + data = _read_json(self) + result = data.get("result") + claim_token = str(data.get("claim_token") or "") + with _lock: + req = _pending_requests.get(call_id) + if req is None: + _send_json(self, {"error": "Unknown or expired call_id"}, 404) + return + expected_token = req.get("lease_token") + if not expected_token or claim_token != expected_token: + _send_json(self, {"error": "Stale or invalid claim token"}, 409) + return + _results[call_id] = result + req["lease_token"] = None + _send_json(self, {"status": "ok"}) -@app.route("/result/", methods=["POST"]) -def post_result(call_id: str): - data = request.json or {} - result = data.get("result") - claim_token = data.get("claim_token") - with _lock: - req = _pending_requests.get(call_id) - if req is None: - return jsonify({"error": "Unknown or expired call_id"}), 404 - expected_token = req.get("lease_token") - if not expected_token or claim_token != expected_token: - return jsonify({"error": "Stale or invalid claim token"}), 409 - _results[call_id] = result - req["lease_token"] = None - return jsonify({"status": "ok"}) + else: + _send_json(self, {"error": "not found"}, 404) + + +class _ThreadedHTTPServer(ThreadingMixIn, HTTPServer): + daemon_threads = True if __name__ == "__main__": - app.run(host="0.0.0.0", port=3000, threaded=True) + server = _ThreadedHTTPServer(("0.0.0.0", 3000), _BrokerHandler) + server.serve_forever() """.strip() # Default broker tool-call polling timeout (used as fallback when no instance @@ -350,6 +384,12 @@ def ensure_started(self) -> None: return except Exception as exc: last_error = exc + logger.warning( + "Broker start attempt %d/%d failed: %s", + attempt + 1, + self.broker_start_retries + 1, + exc, + ) # Clean up the failed session attempt and reset state so # the next attempt (or the next ensure_started call) starts # from scratch instead of caching a broken broker. diff --git a/src/fleet_rlm/integrations/daytona/bridge_callbacks.py b/src/fleet_rlm/integrations/daytona/bridge_callbacks.py index 999f787d4..11fffac3e 100644 --- a/src/fleet_rlm/integrations/daytona/bridge_callbacks.py +++ b/src/fleet_rlm/integrations/daytona/bridge_callbacks.py @@ -39,9 +39,10 @@ def bridge_tools( ) -> dict[str, Callable[..., Any]]: """Return host callbacks exposed to sandbox bridge wrappers.""" tools = {name: tool for name, tool in interpreter._tools.items() if name not in native_tool_names} - if "llm_query" not in tools: + semantic_callbacks_enabled = bool(getattr(interpreter, "semantic_callbacks_enabled", True)) + if semantic_callbacks_enabled and "llm_query" not in tools: tools["llm_query"] = interpreter.llm_query - if "llm_query_batched" not in tools: + if semantic_callbacks_enabled and "llm_query_batched" not in tools: tools["llm_query_batched"] = interpreter.llm_query_batched if "sub_rlm" not in tools and hasattr(interpreter, "sub_rlm"): tools["sub_rlm"] = interpreter.sub_rlm diff --git a/src/fleet_rlm/integrations/daytona/concurrency.py b/src/fleet_rlm/integrations/daytona/concurrency.py index e4d760e85..4c36886bc 100644 --- a/src/fleet_rlm/integrations/daytona/concurrency.py +++ b/src/fleet_rlm/integrations/daytona/concurrency.py @@ -73,19 +73,35 @@ class SandboxUsageStats(BaseModel): # Module-level semaphore state # --------------------------------------------------------------------------- -_GLOBAL_SEMAPHORE: asyncio.BoundedSemaphore | None = None +class _FleetSandboxSemaphore(asyncio.Semaphore): + """Semaphore with a configurable release bound for reconciled state.""" + + def __init__(self, *, value: int, bound: int) -> None: + super().__init__(value) + self._fleet_bound = bound + + def release(self) -> None: + if self._value >= self._fleet_bound: + raise ValueError("BoundedSemaphore released too many times") + super().release() + + +_GLOBAL_SEMAPHORE: asyncio.Semaphore | None = None _SEMAPHORE_LOCK = threading.Lock() _INITIALIZED_CONFIG: ConcurrencyConfig | None = None -async def _get_global_semaphore() -> asyncio.BoundedSemaphore: +async def _get_global_semaphore() -> asyncio.Semaphore: """Get or initialize the global sandbox semaphore lazily.""" global _GLOBAL_SEMAPHORE, _INITIALIZED_CONFIG if _GLOBAL_SEMAPHORE is None: with _SEMAPHORE_LOCK: if _GLOBAL_SEMAPHORE is None: config = ConcurrencyConfig.from_env() - _GLOBAL_SEMAPHORE = asyncio.BoundedSemaphore(config.max_sandboxes) + _GLOBAL_SEMAPHORE = _FleetSandboxSemaphore( + value=config.max_sandboxes, + bound=config.max_sandboxes, + ) _INITIALIZED_CONFIG = config logger.info( "Initialized global sandbox semaphore with limit=%d", @@ -158,6 +174,37 @@ def release_sandbox_slot_for(sandbox: Any) -> None: _set_sandbox_attr(sandbox, "_fleet_slot_released", True) +def reconcile_sandbox_slots(*, provider_active_count: int) -> SandboxUsageStats: + """Reset local slot accounting from provider-visible Fleet sandbox count. + + This is intentionally a recovery tool, not the normal release path. It is + used after slot acquisition times out and the Daytona provider reports fewer + Fleet-managed sandboxes than the in-process semaphore believes are active. + Waiting acquirers should retry after reconciliation because this replaces + the process-local semaphore instead of mutating its internal counters. + """ + global _GLOBAL_SEMAPHORE, _INITIALIZED_CONFIG + with _SEMAPHORE_LOCK: + if _INITIALIZED_CONFIG is None: + _INITIALIZED_CONFIG = ConcurrencyConfig.from_env() + limit = _INITIALIZED_CONFIG.max_sandboxes + clamped_active = max(0, min(int(provider_active_count), limit)) + available = max(0, limit - clamped_active) + _GLOBAL_SEMAPHORE = _FleetSandboxSemaphore(value=available, bound=limit) + logger.warning( + "Reconciled Fleet sandbox slots from provider state " + "(provider_active=%d, limit=%d, available=%d)", + clamped_active, + limit, + available, + ) + return SandboxUsageStats( + limit=limit, + available_slots=available, + active_count=clamped_active, + ) + + def _set_sandbox_attr(sandbox: Any, name: str, value: Any) -> None: """Set SDK object attributes, bypassing validated assignment when needed.""" try: diff --git a/src/fleet_rlm/integrations/daytona/file_browser.py b/src/fleet_rlm/integrations/daytona/file_browser.py index 6d781a7d5..9b07e2be4 100644 --- a/src/fleet_rlm/integrations/daytona/file_browser.py +++ b/src/fleet_rlm/integrations/daytona/file_browser.py @@ -51,6 +51,13 @@ def _check_vfs_root_allowed(display_path: str) -> None: ) +def _is_allowed_root_child(parent_display_path: str, child_name: str) -> bool: + """Return whether a direct child should be visible from the VFS root.""" + if PurePosixPath(parent_display_path) != PurePosixPath("/"): + return True + return str(PurePosixPath("/") / child_name) in VFS_CANONICAL_ROOTS + + @dataclass(frozen=True) class _ResolvedDaytonaPath: display_path: str @@ -135,6 +142,8 @@ def _walk( name = entry_name(getattr(entry, "name", "") or getattr(entry, "path", "")) if not name: continue + if not _is_allowed_root_child(location.display_path, name): + continue child = _child_daytona_path(location, name) is_dir = bool(getattr(entry, "is_dir", False)) @@ -186,7 +195,7 @@ def _walk( return { "volume_name": volume_name, "root_path": root.display_path, - "allowed_roots": ["/memory", "/artifacts", "/buffers", "/meta"], + "allowed_roots": sorted(VFS_CANONICAL_ROOTS), "nodes": [root_node], "total_files": counters["files"], "total_dirs": counters["dirs"], diff --git a/src/fleet_rlm/integrations/daytona/interpreter.py b/src/fleet_rlm/integrations/daytona/interpreter.py index 74b5b0b77..6fbab4d27 100644 --- a/src/fleet_rlm/integrations/daytona/interpreter.py +++ b/src/fleet_rlm/integrations/daytona/interpreter.py @@ -137,6 +137,7 @@ def __init__( self.broker_health_timeout = max(1.0, float(broker_health_timeout)) self.broker_tool_call_timeout = max(1.0, float(broker_tool_call_timeout)) self.broker_start_retries = max(0, int(broker_start_retries)) + self.semantic_callbacks_enabled = True self.delegate_adapter = delegate_adapter self.child_isolation_metadata: dict[str, Any] | None = None diff --git a/src/fleet_rlm/integrations/daytona/runtime.py b/src/fleet_rlm/integrations/daytona/runtime.py index fcc4d87be..8a44f26d2 100644 --- a/src/fleet_rlm/integrations/daytona/runtime.py +++ b/src/fleet_rlm/integrations/daytona/runtime.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import inspect import logging import re import threading @@ -14,6 +15,7 @@ acquire_sandbox_slot, attach_slot_release_handler, get_current_sandbox_usage, + reconcile_sandbox_slots, release_sandbox_slot, release_sandbox_slot_for, ) @@ -86,6 +88,26 @@ _GENERATED_SANDBOX_NAME_RE = re.compile(r"^fleet-rlm-\d{8}-\d{6}(?:-[0-9a-f]{8})?$") _SANDBOX_NAME_CONFLICT_RETRIES = 3 +_PROVIDER_ACTIVE_STATES = frozenset({"pending", "creating", "starting", "started", "running"}) + +_BROWSER_SKILL_INDICATORS = frozenset({"browser-interaction", "browser_interaction", "browser", "playwright"}) + + +def resolve_snapshot_for_skills(selected_skills: list[str] | None = None) -> str: + """Return the appropriate snapshot name based on selected skills. + + When any selected skill indicates browser interaction, returns the + browser-capable snapshot. Otherwise returns the default base snapshot. + """ + from .snapshots import BROWSER_SNAPSHOT_NAME + + if not selected_skills: + return DEFAULT_SNAPSHOT_NAME + for skill in selected_skills: + if any(indicator in skill.lower() for indicator in _BROWSER_SKILL_INDICATORS): + return BROWSER_SNAPSHOT_NAME + return DEFAULT_SNAPSHOT_NAME + # --------------------------------------------------------------------------- # DaytonaSandboxRuntime @@ -297,6 +319,55 @@ async def acreate_sandbox_from_spec(self, spec: SandboxSpec) -> Any: """ return await _run_sync_in_thread(self._create_sandbox_from_spec_impl, spec) + def _count_provider_fleet_sandboxes_sync(self) -> int: + """Return provider-visible Fleet-managed sandboxes for slot reconciliation.""" + client = self._get_client() + signature = inspect.signature(client.list) + kwargs: dict[str, Any] = {} + if "labels" in signature.parameters: + kwargs["labels"] = self.DEFAULT_LABELS + result = client.list(**kwargs) + raw_items = getattr(result, "items", result) if result else [] + count = 0 + for sandbox in raw_items: + labels = getattr(sandbox, "labels", None) or {} + if not isinstance(labels, dict): + labels = {} + normalized = {str(key): str(value) for key, value in labels.items()} + if not all(normalized.get(key) == value for key, value in self.DEFAULT_LABELS.items()): + continue + raw_state = getattr(sandbox, "state", None) + state = str(getattr(raw_state, "value", raw_state) or "").lower() + if state and state not in _PROVIDER_ACTIVE_STATES: + continue + count += 1 + return count + + async def _count_provider_fleet_sandboxes(self) -> int: + return await _run_sync_in_thread(self._count_provider_fleet_sandboxes_sync) + + async def _acquire_slot_with_reconciliation(self) -> None: + try: + await acquire_sandbox_slot(timeout=60.0) + return + except asyncio.TimeoutError: + before = get_current_sandbox_usage() + + try: + provider_active = await self._count_provider_fleet_sandboxes() + except Exception: + logger.warning("Failed to reconcile sandbox slots from Daytona provider", exc_info=True) + raise asyncio.TimeoutError from None + + if provider_active >= before.active_count: + raise asyncio.TimeoutError + + reconcile_sandbox_slots(provider_active_count=provider_active) + try: + await acquire_sandbox_slot(timeout=1.0) + except asyncio.TimeoutError: + raise asyncio.TimeoutError from None + def create_sandbox( self, volume_name: str | None = None, @@ -327,7 +398,7 @@ async def acreate_sandbox( """ slot_acquired = False try: - await acquire_sandbox_slot(timeout=60.0) + await self._acquire_slot_with_reconciliation() slot_acquired = True except asyncio.TimeoutError as exc: usage = get_current_sandbox_usage() @@ -625,4 +696,5 @@ async def areconcile_workspace_session( "aget_snapshot", "alist_snapshots", "aresolve_snapshot", + "resolve_snapshot_for_skills", ] diff --git a/src/fleet_rlm/integrations/daytona/sandbox_executor.py b/src/fleet_rlm/integrations/daytona/sandbox_executor.py index 29856b660..061ca9970 100644 --- a/src/fleet_rlm/integrations/daytona/sandbox_executor.py +++ b/src/fleet_rlm/integrations/daytona/sandbox_executor.py @@ -37,6 +37,9 @@ ) from .session_runtime import DaytonaSandboxSession +_BROKER_START_FAILURE_COOLDOWN_SECONDS = 300.0 +_BROKER_START_FAILURES: dict[str, tuple[float, str]] = {} + def _generic_submit_code() -> str: return """ @@ -375,6 +378,7 @@ class DaytonaExecutionOwner(SupportsExecutionEventCallback, Protocol): _bridge: DaytonaToolBridge | None _bridge_sandbox_id: str | None _bridge_context_id: str | None + _bridge_start_error: str | None _bridge_tools: Callable[..., Any] _invoke_tool: Callable[..., Any] _reject_unsupported_recursive_callbacks: Callable[..., None] @@ -763,6 +767,73 @@ def structured_execution_error(*, reason: str, error: str) -> DaytonaExecutionRe return DaytonaExecutionResponse(error=payload) +def _broker_failure_key(session: DaytonaSandboxSession) -> str: + return str(getattr(session, "sandbox_id", "") or getattr(session, "id", "") or id(session)) + + +def _cached_broker_start_error(session: DaytonaSandboxSession, *, now: float | None = None) -> str | None: + key = _broker_failure_key(session) + cached = _BROKER_START_FAILURES.get(key) + if cached is None: + return None + timestamp, error = cached + current = time.time() if now is None else now + if current - timestamp > _BROKER_START_FAILURE_COOLDOWN_SECONDS: + _BROKER_START_FAILURES.pop(key, None) + return None + return error + + +def _remember_broker_start_error(session: DaytonaSandboxSession, error: str) -> None: + _BROKER_START_FAILURES[_broker_failure_key(session)] = (time.time(), error) + + +def _clear_broker_start_error(session: DaytonaSandboxSession) -> None: + _BROKER_START_FAILURES.pop(_broker_failure_key(session), None) + + +def _owner_broker_start_error( + owner: DaytonaExecutionOwner, + session: DaytonaSandboxSession, +) -> str | None: + """Return an active broker-start failure and clear stale owner state.""" + cached = _cached_broker_start_error(session) + if cached: + return cached + if getattr(owner, "_bridge_start_error", None): + owner._bridge_start_error = None + return None + + +def _inject_broker_failure_stubs( + session: DaytonaSandboxSession, + context: Any, + tools: dict[str, Any], + *, + error: str, +) -> None: + """Inject stub functions for each bridged tool so the REPL agent gets an + informative RuntimeError instead of a bare NameError when the broker failed. + Best-effort: any injection error is silently suppressed. + """ + if not tools: + return + short_error = error[:200].replace("'", "\\'") + lines = [ + f"def {name}(*_a, **_kw):" + f" raise RuntimeError('Tool {name!r} unavailable: broker failed to start. {short_error}')" + for name in tools + if name.isidentifier() + ] + if not lines: + return + stub_code = "\n".join(lines) + try: + session.sandbox.code_interpreter.run_code(stub_code, context=context) + except Exception: + pass # best-effort + + def run_prepared_execution( owner: DaytonaExecutionOwner, *, @@ -774,11 +845,26 @@ def run_prepared_execution( ) -> DaytonaBridgeExecution: tools = callbacks.bridge_tools() if callbacks.requires_bridge(code, tools): - bridge = callbacks.ensure_bridge( - session=session, - context=context, - tools=tools, - ) + bridge_start_error = _owner_broker_start_error(owner, session) + if bridge_start_error: + raise CodeInterpreterError( + "Broker callbacks are unavailable after a previous startup failure in this session; " + f"llm_query should not be retried. Previous failure: {bridge_start_error}" + ) + try: + bridge = callbacks.ensure_bridge( + session=session, + context=context, + tools=tools, + ) + except Exception as exc: + if "Broker server failed to start" in str(exc): + owner._bridge_start_error = str(exc) + _remember_broker_start_error(session, str(exc)) + _inject_broker_failure_stubs(session, context, tools, error=str(exc)) + raise + owner._bridge_start_error = None + _clear_broker_start_error(session) return bridge.execute_tool_call( code=code, timeout=int(owner.execute_timeout or owner.timeout), @@ -1103,6 +1189,7 @@ def __init__( self._bridge: DaytonaToolBridge | None = None self._bridge_sandbox_id: str | None = None self._bridge_context_id: str | None = None + self._bridge_start_error: str | None = None self._setup_context_id: str | None = None self._setup_workspace_path: str | None = None self._submit_signature_key: tuple[tuple[str, str], ...] | None = None @@ -1126,6 +1213,7 @@ def soft_reset(self) -> None: self._setup_context_id = None self._setup_workspace_path = None self._submit_signature_key = None + self._bridge_start_error = None if self._bridge is not None: self._bridge._injected_tools.clear() self._bridge_context_id = None @@ -1148,6 +1236,7 @@ def close_bridge(self) -> None: self._bridge = None self._bridge_sandbox_id = None self._bridge_context_id = None + self._bridge_start_error = None if bridge is not None: bridge.close() diff --git a/src/fleet_rlm/integrations/daytona/snapshots.py b/src/fleet_rlm/integrations/daytona/snapshots.py index 1c4824d10..97c6a4128 100644 --- a/src/fleet_rlm/integrations/daytona/snapshots.py +++ b/src/fleet_rlm/integrations/daytona/snapshots.py @@ -42,6 +42,20 @@ DEFAULT_SNAPSHOT_BASE_IMAGE = "python:3.12-slim" _VALID_PACKAGE_SPEC_PATTERN = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._\-\[\],<>=!~]*$") +BROWSER_SNAPSHOT_NAME = "fleet-rlm-browser" +BROWSER_SNAPSHOT_PACKAGES: list[str] = [ + *DEFAULT_SNAPSHOT_PACKAGES, + "playwright", +] + +_CHROMIUM_SYSTEM_DEPS = ( + "libx11-6 libxrandr2 libxext6 libxrender1 libxfixes3 libxss1 " + "libxtst6 libxi6 libnss3 libnspr4 libatk1.0-0 libatk-bridge2.0-0 " + "libcups2 libdrm2 libgbm1 libpango-1.0-0 libcairo2 libasound2 " + "libatspi2.0-0 libdbus-1-3 fonts-liberation" +) +_VNC_DESKTOP_DEPS = "xvfb xfce4 xfce4-terminal x11vnc novnc dbus-x11" + def _experimental_call( sandbox: DaytonaSandbox, @@ -96,6 +110,151 @@ def build_base_snapshot_image( return image +def build_browser_snapshot_image( + *, + base_image: str = DEFAULT_SNAPSHOT_BASE_IMAGE, + packages: list[str] | None = None, + include_vnc: bool = True, +) -> Any: + """Build a Daytona image with Playwright, Chromium, and optional VNC/desktop.""" + try: + from daytona import Image as DaytonaImage + except ImportError as exc: # pragma: no cover - environment specific + raise _daytona_import_error(exc) from exc + + packages_to_install = packages if packages is not None else BROWSER_SNAPSHOT_PACKAGES + for package in packages_to_install: + if not package or not _VALID_PACKAGE_SPEC_PATTERN.fullmatch(package): + msg = f"Invalid package spec for browser snapshot image install: {package!r}" + raise ValueError(msg) + + system_deps = _CHROMIUM_SYSTEM_DEPS + if include_vnc: + system_deps += f" {_VNC_DESKTOP_DEPS}" + + image = DaytonaImage.base(base_image).run_commands( + f"apt-get update && apt-get install -y --no-install-recommends {system_deps} && rm -rf /var/lib/apt/lists/*" + ) + image = image.run_commands("pip install uv") + if packages_to_install: + install_command = shlex.join(["uv", "pip", "install", "--system", *packages_to_install]) + image = image.run_commands(install_command) + image = image.run_commands("playwright install chromium") + return image + + +def create_browser_snapshot( + name: str = BROWSER_SNAPSHOT_NAME, + *, + base_image: str = DEFAULT_SNAPSHOT_BASE_IMAGE, + packages: list[str] | None = None, + include_vnc: bool = True, + config: ResolvedDaytonaConfig | None = None, + on_logs: Any | None = None, +) -> dict[str, Any]: + """Create a browser-capable Daytona snapshot with Playwright and Chromium.""" + try: + from daytona.common.snapshot import CreateSnapshotParams + except ImportError as exc: # pragma: no cover - environment specific + raise _daytona_import_error(exc) from exc + + image = build_browser_snapshot_image(base_image=base_image, packages=packages, include_vnc=include_vnc) + params = CreateSnapshotParams(name=name, image=image) + cfg = config or resolve_daytona_config() + client = _build_daytona_client(cfg) + try: + snapshot = client.snapshot.create(params, on_logs=on_logs, timeout=0) + logger.info("Browser snapshot '%s' created (id=%s)", snapshot.name, snapshot.id) + return _snapshot_summary(snapshot) + finally: + with suppress(Exception): + client.close() + + +async def acreate_browser_snapshot( + name: str = BROWSER_SNAPSHOT_NAME, + *, + base_image: str = DEFAULT_SNAPSHOT_BASE_IMAGE, + packages: list[str] | None = None, + include_vnc: bool = True, + config: ResolvedDaytonaConfig | None = None, + on_logs: Any | None = None, +) -> dict[str, Any]: + """Async wrapper — runs blocking SDK call in a thread.""" + return await _run_sync_in_thread( + create_browser_snapshot, + name, + base_image=base_image, + packages=packages, + include_vnc=include_vnc, + config=config, + on_logs=on_logs, + ) + + +def bootstrap_browser_snapshot( + name: str = BROWSER_SNAPSHOT_NAME, + *, + base_image: str = DEFAULT_SNAPSHOT_BASE_IMAGE, + refresh: bool = False, + include_vnc: bool = True, + config: ResolvedDaytonaConfig | None = None, + on_logs: Any | None = None, +) -> dict[str, Any]: + """Ensure the reusable Fleet browser Daytona snapshot exists.""" + cfg = config or resolve_daytona_config() + existing = get_snapshot(name, config=cfg) + if existing is not None and not refresh: + return {**existing, "created": False, "refreshed": False} + if existing is not None: + previous_snapshot_id = str(existing.get("id") or "") or None + delete_snapshot(previous_snapshot_id or str(existing.get("name") or name), config=cfg) + replacement = _wait_for_snapshot_refresh_target( + name, + previous_snapshot_id=previous_snapshot_id, + config=cfg, + ) + if replacement is not None: + return {**replacement, "created": False, "refreshed": True} + + try: + created = create_browser_snapshot( + name=name, + base_image=base_image, + include_vnc=include_vnc, + config=cfg, + on_logs=on_logs, + ) + except Exception as exc: + if _snapshot_create_conflict(exc): + replacement = get_snapshot(name, config=cfg) + if replacement is not None: + return {**replacement, "created": False, "refreshed": existing is not None} + raise + return {**created, "created": True, "refreshed": existing is not None} + + +async def abootstrap_browser_snapshot( + name: str = BROWSER_SNAPSHOT_NAME, + *, + base_image: str = DEFAULT_SNAPSHOT_BASE_IMAGE, + refresh: bool = False, + include_vnc: bool = True, + config: ResolvedDaytonaConfig | None = None, + on_logs: Any | None = None, +) -> dict[str, Any]: + """Async wrapper — runs blocking SDK call in a thread.""" + return await _run_sync_in_thread( + bootstrap_browser_snapshot, + name, + base_image=base_image, + refresh=refresh, + include_vnc=include_vnc, + config=config, + on_logs=on_logs, + ) + + def list_snapshots( config: ResolvedDaytonaConfig | None = None, ) -> list[dict[str, Any]]: @@ -412,10 +571,14 @@ async def acreate_sandbox_snapshot( __all__ = [ + "BROWSER_SNAPSHOT_NAME", + "BROWSER_SNAPSHOT_PACKAGES", "DEFAULT_SNAPSHOT_BASE_IMAGE", "DEFAULT_SNAPSHOT_NAME", "DEFAULT_SNAPSHOT_PACKAGES", + "abootstrap_browser_snapshot", "abootstrap_snapshot", + "acreate_browser_snapshot", "acreate_sandbox_snapshot", "acreate_snapshot", "adelete_snapshot", @@ -423,7 +586,10 @@ async def acreate_sandbox_snapshot( "alist_snapshots", "aresolve_sandbox_spec_snapshot", "aresolve_snapshot", + "bootstrap_browser_snapshot", "build_base_snapshot_image", + "build_browser_snapshot_image", + "create_browser_snapshot", "fallback_to_declarative_image", "resolve_default_snapshot", ] diff --git a/src/fleet_rlm/integrations/observability/auto_assessment.py b/src/fleet_rlm/integrations/observability/auto_assessment.py index 09d99d753..f70be47a7 100644 --- a/src/fleet_rlm/integrations/observability/auto_assessment.py +++ b/src/fleet_rlm/integrations/observability/auto_assessment.py @@ -7,6 +7,7 @@ from __future__ import annotations import logging +import time from collections.abc import Mapping from typing import Any @@ -23,6 +24,8 @@ pass _SCORER_REGISTRY: dict[tuple[str, str | None], Any] = {} +_PERSISTED_SCORER_CACHE: tuple[float, tuple[str, bool], list[str]] | None = None +_PERSISTED_SCORER_CACHE_SECONDS = 30.0 def _scorer_display_name(scorer: Any) -> str: @@ -41,6 +44,31 @@ def _scorer_display_name(scorer: Any) -> str: ) +def _scorer_is_active(scorer: Any) -> bool: + """Return whether a persisted scorer is actively scheduled to evaluate traces.""" + if isinstance(scorer, Mapping): + sample_rate = scorer.get("sample_rate") + status = scorer.get("status") + else: + sample_rate = getattr(scorer, "sample_rate", None) + status = getattr(scorer, "status", None) + + if sample_rate is not None: + try: + return float(sample_rate) > 0 + except (TypeError, ValueError): + pass + + normalized_status = str(status or "").lower() + if normalized_status: + if "stopped" in normalized_status: + return False + if "started" in normalized_status or "active" in normalized_status: + return True + + return True + + def _active_experiment_id(mlflow: Any, config: MlflowConfig) -> str | None: get_experiment_by_name = getattr(mlflow, "get_experiment_by_name", None) if not callable(get_experiment_by_name) or not config.experiment: @@ -54,35 +82,55 @@ def _active_experiment_id(mlflow: Any, config: MlflowConfig) -> str | None: return str(experiment_id) if experiment_id is not None else None -def warn_if_persisted_scorers_active(config: MlflowConfig, *, mlflow: Any | None = None) -> int: - """Warn when MLflow has persisted scorers but Fleet auto-assessment is disabled.""" - if config.enable_auto_assessment: - return 0 +def persisted_scorer_names( + config: MlflowConfig, + *, + mlflow: Any | None = None, + cache_seconds: float = _PERSISTED_SCORER_CACHE_SECONDS, +) -> list[str]: + """Return active persisted MLflow GenAI scorer names for diagnostics.""" + global _PERSISTED_SCORER_CACHE + cache_key = (config.experiment, config.enable_auto_assessment) + now = time.monotonic() + if cache_seconds > 0 and _PERSISTED_SCORER_CACHE is not None: + cached_at, cached_key, cached_names = _PERSISTED_SCORER_CACHE + if cached_key == cache_key and now - cached_at <= cache_seconds: + return list(cached_names) if mlflow is None: try: import mlflow as mlflow_module mlflow = mlflow_module except ImportError: - return 0 + return [] genai = getattr(mlflow, "genai", None) list_scorers = getattr(genai, "list_scorers", None) if not callable(list_scorers): - return 0 + return [] experiment_id = _active_experiment_id(mlflow, config) try: scorers = list_scorers(experiment_id=experiment_id) except Exception: logger.debug("Failed to list MLflow scorers for diagnostics.", exc_info=True) + return [] + scorer_names = [_scorer_display_name(scorer) for scorer in scorers if _scorer_is_active(scorer)] + if cache_seconds > 0: + _PERSISTED_SCORER_CACHE = (now, cache_key, list(scorer_names)) + return scorer_names + + +def warn_if_persisted_scorers_active(config: MlflowConfig, *, mlflow: Any | None = None) -> int: + """Warn when MLflow has persisted scorers but Fleet auto-assessment is disabled.""" + if config.enable_auto_assessment: return 0 - scorer_names = [_scorer_display_name(scorer) for scorer in scorers] + scorer_names = persisted_scorer_names(config, mlflow=mlflow, cache_seconds=0) if not scorer_names: return 0 logger.warning( "MLflow has persisted scorer(s) for experiment %s while Fleet auto-assessment is disabled: %s. " "These scorers can continue to assess traces independently of FLEET_RLM_ENABLE_AUTO_ASSESSMENT. " "Use `uv run python scripts/mlflow_cli.py scorers list` and " - "`uv run python scripts/mlflow_cli.py scorers delete --name --yes` to inspect or remove them.", + "`uv run python scripts/mlflow_cli.py scorers stop --name ` to inspect or stop them.", config.experiment, ", ".join(scorer_names), ) diff --git a/src/fleet_rlm/integrations/observability/mlflow_context.py b/src/fleet_rlm/integrations/observability/mlflow_context.py index 888721717..cfb29dc07 100644 --- a/src/fleet_rlm/integrations/observability/mlflow_context.py +++ b/src/fleet_rlm/integrations/observability/mlflow_context.py @@ -3,6 +3,7 @@ from __future__ import annotations import contextvars +import json import os import uuid from contextlib import contextmanager @@ -27,6 +28,9 @@ class MlflowTraceRequestContext: metadata: dict[str, str] = field(default_factory=dict) total_input_tokens: int = 0 total_output_tokens: int = 0 + final_response_preview: str | None = None + final_trace_metadata: dict[str, Any] = field(default_factory=dict) + emitted_trace_tags: dict[str, str] = field(default_factory=dict) _CURRENT_REQUEST_CONTEXT: contextvars.ContextVar[MlflowTraceRequestContext | None] = contextvars.ContextVar[ @@ -41,6 +45,24 @@ class MlflowTraceRequestContext: ) _TRACE_ID_LOCK = Lock() _TRACE_IDS_BY_CLIENT_REQUEST_ID: dict[str, str] = {} +_TRAJECTORY_VALUE_LIMIT = 8_000 +_RLM_REPL_TOOL_SCHEMA: dict[str, Any] = { + "type": "function", + "function": { + "name": "repl_execute", + "description": "Execute Python code in the Daytona-backed RLM REPL to inspect variables and produce observations.", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute in the sandboxed RLM REPL.", + } + }, + "required": ["code"], + }, + }, +} def _runtime_module(): @@ -74,6 +96,12 @@ def mlflow_request_context(context: MlflowTraceRequestContext): finalize_current_mlflow_trace(state=trace_state) capture_last_active_trace_id() _runtime_module().flush_mlflow_traces() + if context.final_response_preview or context.final_trace_metadata: + update_current_mlflow_trace( + response_preview=context.final_response_preview, + trace_metadata=context.final_trace_metadata, + ) + _runtime_module().flush_mlflow_traces() with _TRACE_ID_LOCK: _TRACE_IDS_BY_CLIENT_REQUEST_ID.pop(context.client_request_id, None) _CURRENT_TRACE_ID.reset(trace_token) @@ -97,6 +125,174 @@ def _trim_preview(value: str | None, *, limit: int = 512) -> str | None: return candidate[: limit - 3].rstrip() + "..." +def _bounded_value(value: Any, *, limit: int = _TRAJECTORY_VALUE_LIMIT) -> Any: + if value is None: + return None + if isinstance(value, str): + if len(value) <= limit: + return value + return value[: limit - 3].rstrip() + "..." + try: + serialized = json.dumps(value, ensure_ascii=False, default=str) + except TypeError: + return _bounded_value(str(value), limit=limit) + if len(serialized) <= limit: + return value + return serialized[: limit - 3].rstrip() + "..." + + +def _flat_trajectory_indices(raw: dict[str, Any]) -> list[int]: + indices: set[int] = set() + for key in raw: + if "_" not in key: + continue + _, suffix = key.rsplit("_", 1) + if suffix.isdigit(): + indices.add(int(suffix)) + return sorted(indices) + + +def _coerce_trajectory_steps(raw: Any) -> list[dict[str, Any]]: + if isinstance(raw, list): + return [dict(step, index=step.get("index", index)) for index, step in enumerate(raw) if isinstance(step, dict)] + + if not isinstance(raw, dict): + return [] + + for key in ("trajectory", "steps"): + nested = raw.get(key) + if isinstance(nested, list): + return [ + dict(step, index=step.get("index", index)) + for index, step in enumerate(nested) + if isinstance(step, dict) + ] + + steps: list[dict[str, Any]] = [] + for index in _flat_trajectory_indices(raw): + step = { + "index": index, + "thought": raw.get(f"thought_{index}") or raw.get(f"reasoning_{index}"), + "tool_name": raw.get(f"tool_name_{index}"), + "tool_args": raw.get(f"tool_args_{index}"), + "observation": raw.get(f"observation_{index}"), + "code": raw.get(f"code_{index}"), + "output": raw.get(f"output_{index}"), + } + if any(value is not None for key, value in step.items() if key != "index"): + steps.append(step) + return steps + + +def _trajectory_span_name(step: dict[str, Any]) -> str | None: + tool_name = str(step.get("tool_name") or step.get("type") or "").strip() + if step.get("code") is not None: + return "repl_execute" + if tool_name: + normalized = tool_name.lower() + if "repl" in normalized or "sandbox" in normalized: + return "repl_execute" + return f"rlm_tool:{tool_name}" + if step.get("output") is not None or step.get("observation") is not None: + return "rlm_observation" + return None + + +def _trajectory_step_failed(step: dict[str, Any]) -> bool: + values = [step.get("output"), step.get("observation")] + for value in values: + if isinstance(value, str): + normalized = value.strip().lower() + if normalized.startswith("[error]") or normalized.startswith("execution error"): + return True + if isinstance(value, dict): + status = str(value.get("status") or "").strip().lower() + if status in {"error", "failed", "failure"}: + return True + return False + + +def _record_rlm_available_tools_span(mlflow: Any, start_span: Any) -> bool: + """Expose the RLM REPL tool schema in the shape MLflow judges inspect.""" + tools = [_RLM_REPL_TOOL_SCHEMA] + attributes = { + "mlflow.chat.tools": json.dumps(tools, ensure_ascii=False), + "fleet_rlm.synthetic_span": "available_tools", + "fleet_rlm.available_tools": "repl_execute", + } + try: + with start_span(name="rlm_available_tools", span_type="LLM", attributes=attributes) as span: + span.set_inputs({"tools": tools}) + span.set_outputs({"available_tools": ["repl_execute"]}) + return True + except Exception: + _runtime_module().logger.debug("MLflow RLM available-tools span recording skipped.", exc_info=True) + return False + + +def record_rlm_trajectory_spans(trajectory: Any) -> int: + """Materialize RLM trajectory tool/REPL steps as child MLflow spans.""" + steps = _coerce_trajectory_steps(trajectory) + if not steps: + return 0 + + runtime = _runtime_module() + mlflow = runtime._import_mlflow() + if mlflow is None or not _has_active_mlflow_trace(mlflow): + return 0 + + start_span = getattr(mlflow, "start_span", None) + if not callable(start_span): + return 0 + + recorded = 0 + if any(_trajectory_span_name(step) is not None for step in steps): + _record_rlm_available_tools_span(mlflow, start_span) + + for step in steps: + span_name = _trajectory_span_name(step) + if span_name is None: + continue + + index = step.get("index") + tool_name = step.get("tool_name") or ("repl_execute" if step.get("code") is not None else span_name) + inputs = { + "tool_name": tool_name, + "tool_args": _bounded_value(step.get("tool_args") or step.get("input")), + "code": _bounded_value(step.get("code")), + } + outputs = { + "observation": _bounded_value(step.get("observation")), + "output": _bounded_value(step.get("output")), + } + attributes = { + "fleet_rlm.trajectory_index": str(index) if index is not None else "", + "fleet_rlm.trajectory_tool_name": str(tool_name), + "fleet_rlm.trajectory_has_code": str(step.get("code") is not None).lower(), + "fleet_rlm.trajectory_has_output": str( + step.get("output") is not None or step.get("observation") is not None + ).lower(), + } + failed = _trajectory_step_failed(step) + if failed: + attributes["fleet_rlm.trajectory_error"] = "true" + thought = step.get("thought") or step.get("reasoning") + if thought: + attributes["fleet_rlm.trajectory_reasoning_preview"] = str(_bounded_value(thought, limit=1_000)) + + try: + with start_span(name=span_name, span_type="TOOL", attributes=attributes) as span: + span.set_inputs({key: value for key, value in inputs.items() if value is not None}) + span.set_outputs({key: value for key, value in outputs.items() if value is not None}) + if failed: + span.set_status("ERROR") + recorded += 1 + except Exception: + runtime.logger.debug("MLflow RLM trajectory span recording skipped.", exc_info=True) + + return recorded + + def _trace_metadata_from_context( context: MlflowTraceRequestContext, ) -> dict[str, str]: @@ -112,6 +308,85 @@ def _trace_metadata_from_context( return metadata +def _trace_tags_from_context( + context: MlflowTraceRequestContext, + *, + metadata: dict[str, str], +) -> dict[str, str]: + """Return queryable Fleet tags for the active trace. + + MLflow trace metadata is useful for immutable context, but local OSS + tracking stores and scorer-generated traces do not always expose custom + metadata consistently in trace search results. Mirroring Fleet-owned + correlation fields into tags keeps trace triage and UI lookups reliable. + """ + tags: dict[str, str] = { + "fleet_rlm.trace_kind": "application", + "fleet_rlm.client_request_id": context.client_request_id, + } + if context.session_id: + tags["fleet_rlm.session_id"] = context.session_id + if context.user_id: + tags["fleet_rlm.user_id"] = context.user_id + if context.app_env: + tags["fleet_rlm.app_env"] = context.app_env + + for key, value in metadata.items(): + if key.startswith("fleet_rlm."): + tags[key] = value + return tags + + +def _apply_trace_tags_by_id(mlflow: Any, trace_id: str | None, tags: dict[str, str]) -> None: + """Best-effort tag update for traces that are no longer current/active.""" + if not trace_id or not tags: + return + set_trace_tag = getattr(mlflow, "set_trace_tag", None) + if not callable(set_trace_tag): + return + for key, value in tags.items(): + try: + set_trace_tag(str(trace_id), str(key), str(value)) + except Exception: + _runtime_module().logger.debug("MLflow trace tag update skipped.", exc_info=True) + + +def _new_active_trace_tags( + context: MlflowTraceRequestContext, + tags: dict[str, str], +) -> dict[str, str]: + """Return tags not yet sent through update_current_trace for this request.""" + new_tags = {key: value for key, value in tags.items() if key not in context.emitted_trace_tags} + context.emitted_trace_tags.update(new_tags) + return new_tags + + +def _resolve_trace_id_by_client_request_id(context: MlflowTraceRequestContext) -> str | None: + """Resolve a completed trace when MLflow no longer exposes it as active.""" + try: + from .mlflow_traces import resolve_trace_by_client_request_id + + trace = resolve_trace_by_client_request_id( + context.client_request_id, + config=_runtime_module().get_mlflow_config(), + max_results=25, + ) + except Exception: + _runtime_module().logger.debug("MLflow trace lookup by client request id skipped.", exc_info=True) + return None + + info = getattr(trace, "info", None) + trace_id = getattr(info, "trace_id", None) or getattr(info, "request_id", None) + if not trace_id: + return None + + context.resolved_trace_id = str(trace_id) + _CURRENT_TRACE_ID.set(str(trace_id)) + with _TRACE_ID_LOCK: + _TRACE_IDS_BY_CLIENT_REQUEST_ID[context.client_request_id] = str(trace_id) + return str(trace_id) + + def _has_active_mlflow_trace(mlflow: Any) -> bool: get_current_active_span = getattr(mlflow, "get_current_active_span", None) if callable(get_current_active_span): @@ -140,13 +415,25 @@ def update_current_mlflow_trace( context = current_request_context() if context is None: return None + if response_preview is not None: + context.final_response_preview = response_preview + if trace_metadata: + context.final_trace_metadata.update(trace_metadata) runtime = _runtime_module() mlflow = runtime._import_mlflow() if mlflow is None: return None if not _has_active_mlflow_trace(mlflow): - return capture_last_active_trace_id() + trace_id = capture_last_active_trace_id() + if trace_id is None: + trace_id = _resolve_trace_id_by_client_request_id(context) + metadata = _trace_metadata_from_context(context) + if trace_metadata: + metadata.update(trace_metadata) + tags = _trace_tags_from_context(context, metadata=metadata) + _apply_trace_tags_by_id(mlflow, trace_id, tags) + return trace_id try: config = runtime.get_mlflow_config() @@ -154,9 +441,12 @@ def update_current_mlflow_trace( metadata = _trace_metadata_from_context(context) if trace_metadata: metadata.update(trace_metadata) + tags = _trace_tags_from_context(context, metadata=metadata) + tags = _new_active_trace_tags(context, tags) mlflow.update_current_trace( client_request_id=context.client_request_id, metadata=metadata, + tags=tags if tags else None, request_preview=_trim_preview(context.request_preview), response_preview=_trim_preview(response_preview), model_id=model_id, @@ -178,7 +468,11 @@ def finalize_current_mlflow_trace(*, state: str) -> str | None: if mlflow is None: return None if not _has_active_mlflow_trace(mlflow): - return capture_last_active_trace_id() + trace_id = capture_last_active_trace_id() + if context is not None: + tags = _trace_tags_from_context(context, metadata=_trace_metadata_from_context(context)) + _apply_trace_tags_by_id(mlflow, trace_id, tags) + return trace_id try: tags: dict[str, str] = {} diff --git a/src/fleet_rlm/integrations/observability/mlflow_runtime.py b/src/fleet_rlm/integrations/observability/mlflow_runtime.py index 0a4cc1e7b..0bd0c3536 100644 --- a/src/fleet_rlm/integrations/observability/mlflow_runtime.py +++ b/src/fleet_rlm/integrations/observability/mlflow_runtime.py @@ -7,6 +7,7 @@ import logging import os import re +import sys from threading import Lock from typing import Any from urllib.parse import urlsplit, urlunsplit @@ -31,6 +32,7 @@ logger = logging.getLogger(__name__) _CLIENT_LOCK = Lock() +_MLFLOW_IMPORT_LOCK = Lock() _INIT_IDENTITY: tuple[Any, ...] | None = None _LAST_INIT_WAS_AUTH_FAILURE = False _ACTIVE_CONFIG: MlflowConfig | None = None @@ -70,12 +72,36 @@ def _mlflow_tracking_auth_identity() -> tuple[Any, ...]: ) +def _clear_partial_mlflow_import() -> None: + """Remove partially initialized MLflow modules after a failed import.""" + + for module_name in list(sys.modules): + if module_name == "mlflow" or module_name.startswith("mlflow."): + sys.modules.pop(module_name, None) + + def _import_mlflow() -> Any | None: - try: - import mlflow - except ImportError: + with _MLFLOW_IMPORT_LOCK: + for attempt in range(2): + try: + import mlflow + + # MLflow circular-import failures can leave a top-level module + # object without expected attributes. Touch a stable attribute so + # the helper either returns a usable module or clears the partial + # import before the request continues. + _ = mlflow.version.VERSION + except ImportError: + return None + except Exception: + _clear_partial_mlflow_import() + if attempt == 0: + continue + logger.warning("Failed to import MLflow after clearing a partial import.", exc_info=True) + return None + else: + return mlflow return None - return mlflow def _sanitize_log_field(value: object) -> str: @@ -486,10 +512,15 @@ def on_lm_end( # closing tag interrupts field parsing and forces an expensive JSONAdapter # retry (adds ~8 s per turn). _THINK_TAG_RE = re.compile(r".*?", re.DOTALL | re.IGNORECASE) +# Some models emit in one token batch and in another, so the +# paired regex above leaves orphaned closing tags after stripping complete pairs. +_ORPHAN_THINK_CLOSE_RE = re.compile(r"", re.IGNORECASE) def _strip_think_tags(text: str) -> str: - return _THINK_TAG_RE.sub("", text).lstrip("\n") + text = _THINK_TAG_RE.sub("", text) + text = _ORPHAN_THINK_CLOSE_RE.sub("", text) + return text.lstrip("\n") class ThinkTagStripCallback(BaseCallback): @@ -580,10 +611,10 @@ def log_trace_feedback( ) -def trace_to_dataset_row(trace: Any) -> dict[str, Any]: +def trace_to_dataset_row(trace: Any, *, config: MlflowConfig | None = None) -> dict[str, Any]: from .mlflow_traces import trace_to_dataset_row as _impl - return _impl(trace) + return _impl(trace, config=config) def search_annotated_trace_rows( diff --git a/src/fleet_rlm/integrations/observability/mlflow_traces.py b/src/fleet_rlm/integrations/observability/mlflow_traces.py index f33908b06..78c422340 100644 --- a/src/fleet_rlm/integrations/observability/mlflow_traces.py +++ b/src/fleet_rlm/integrations/observability/mlflow_traces.py @@ -2,6 +2,7 @@ from __future__ import annotations +import inspect import json from typing import TYPE_CHECKING, Any @@ -24,6 +25,30 @@ def _trace_experiment_ids(config: MlflowConfig) -> list[str]: return [experiment.experiment_id] +def _search_traces( + mlflow: Any, + *, + experiment_ids: list[str], + max_results: int, + return_type: str, + include_spans: bool, + filter_string: str | None = None, +) -> Any: + search_traces = getattr(mlflow, "search_traces") + parameters = inspect.signature(search_traces).parameters + kwargs: dict[str, Any] = { + "filter_string": filter_string, + "max_results": max_results, + "return_type": return_type, + "include_spans": include_spans, + } + if "locations" in parameters: + kwargs["locations"] = experiment_ids + else: + kwargs["experiment_ids"] = experiment_ids + return search_traces(**kwargs) + + def resolve_trace_by_client_request_id( client_request_id: str, *, @@ -41,7 +66,8 @@ def resolve_trace_by_client_request_id( return None try: - traces = mlflow.search_traces( + traces = _search_traces( + mlflow, experiment_ids=experiment_ids, filter_string=(f"trace.client_request_id = '{runtime._mlflow_string_literal(client_request_id)}'"), max_results=max_results, @@ -196,7 +222,45 @@ def _trace_span_types(trace: Trace) -> list[str]: return span_types -def trace_to_dataset_row(trace: Trace) -> dict[str, Any]: +def _assessment_source_field(assessment: dict[str, Any], key: str) -> str: + source = assessment.get("source") + if not isinstance(source, dict): + return "" + value = source.get(key) + return str(value or "").strip() + + +def _skip_external_persisted_scorer_feedback( + assessment: dict[str, Any], + *, + disabled_persisted_scorers: set[str], +) -> bool: + if not disabled_persisted_scorers: + return False + name = str(assessment.get("assessment_name") or "").strip() + if name not in disabled_persisted_scorers: + return False + return _assessment_source_field(assessment, "source_type") == "LLM_JUDGE" + + +def _disabled_persisted_scorer_names(config: MlflowConfig | None) -> set[str]: + resolved = config or runtime.get_mlflow_config() + if resolved.enable_auto_assessment: + return set() + try: + from .auto_assessment import persisted_scorer_names + + return set(persisted_scorer_names(resolved)) + except Exception: + runtime.logger.debug("Failed to inspect persisted MLflow scorers for trace export.", exc_info=True) + return set() + + +def trace_to_dataset_row( + trace: Trace, + *, + config: MlflowConfig | None = None, +) -> dict[str, Any]: """Convert an MLflow trace into an evaluation/export dataset row.""" payload = trace.to_dict() info = payload.get("info", {}) if isinstance(payload, dict) else {} @@ -211,7 +275,22 @@ def trace_to_dataset_row(trace: Trace) -> dict[str, Any]: expectations: dict[str, Any] = {} feedback: dict[str, Any] = {} + skipped_feedback: list[dict[str, str]] = [] + disabled_persisted_scorers = _disabled_persisted_scorer_names(config) for assessment in _trace_assessment_dicts(trace): + if _skip_external_persisted_scorer_feedback( + assessment, + disabled_persisted_scorers=disabled_persisted_scorers, + ): + skipped_feedback.append( + { + "assessment_name": str(assessment.get("assessment_name") or ""), + "source_type": _assessment_source_field(assessment, "source_type"), + "source_id": _assessment_source_field(assessment, "source_id"), + "reason": "persisted_scorer_while_fleet_auto_assessment_disabled", + } + ) + continue name = str(assessment.get("assessment_name") or "assessment") source = assessment.get("source") or {} source_id = source.get("source_id") if isinstance(source, dict) else None @@ -240,6 +319,8 @@ def trace_to_dataset_row(trace: Trace) -> dict[str, Any]: row["span_types"] = span_types if feedback: row["feedback"] = feedback + if skipped_feedback: + row["skipped_feedback"] = skipped_feedback return row @@ -259,7 +340,8 @@ def search_annotated_trace_rows( return [] try: - traces = mlflow.search_traces( + traces = _search_traces( + mlflow, experiment_ids=experiment_ids, max_results=max_results, return_type="list", @@ -274,7 +356,7 @@ def search_annotated_trace_rows( return [] rows: list[dict[str, Any]] = [] for trace in traces: - row = trace_to_dataset_row(trace) + row = trace_to_dataset_row(trace, config=resolved) if row.get("expectations") or row.get("feedback"): rows.append(row) rows.sort( diff --git a/src/fleet_rlm/runtime/agent/runtime.py b/src/fleet_rlm/runtime/agent/runtime.py index 75b34e861..fcd5defc5 100644 --- a/src/fleet_rlm/runtime/agent/runtime.py +++ b/src/fleet_rlm/runtime/agent/runtime.py @@ -95,6 +95,32 @@ def _runtime_degradation_payload(result: Any) -> dict[str, Any]: return payload +def _runtime_routing_payload(result: Any) -> dict[str, Any]: + payload: dict[str, Any] = {} + selected_skills = _prediction_value(result, "selected_skills") + if isinstance(selected_skills, list): + payload["selected_skills"] = [str(item) for item in selected_skills] + routing_decision = _prediction_value(result, "routing_decision") + if routing_decision not in (None, ""): + payload["routing_decision"] = str(routing_decision) + source_url = _prediction_value(result, "source_url") + if source_url not in (None, ""): + payload["source_url"] = str(source_url) + return payload + + +def _routing_status_text(payload: dict[str, Any]) -> str: + selected = ", ".join(payload.get("selected_skills", [])) + route = payload.get("routing_decision", "auto") + source = payload.get("source_url") + text = f"Route: {route}" + if selected: + text += f" | skills: {selected}" + if source: + text += f" | source: {source}" + return text + + def _get_streamable_react_program(program: Any) -> Any | None: react_program = getattr(program, "react", program) @@ -227,6 +253,9 @@ def __init__( *, interpreter: Any | None = None, max_iters: int = 10, + rlm_max_iterations: int | None = None, + rlm_max_llm_calls: int | None = None, + rlm_max_output_chars: int | None = None, history_max_turns: int | None = 6, extra_tools: list[Any] | None = None, repository: Any | None = None, @@ -249,6 +278,9 @@ def __init__( self.execution_mode: str = "auto" self.loaded_document_paths: list[str] = [] self.batch_concurrency: int | None = None + self.rlm_max_iterations = rlm_max_iterations if rlm_max_iterations is not None else max_iters + self.rlm_max_llm_calls = rlm_max_llm_calls if rlm_max_llm_calls is not None else 50 + self.rlm_max_output_chars = rlm_max_output_chars # Conversation summary for context compression (Phase 2) self.conversation_summary: str = "" @@ -273,7 +305,9 @@ def __init__( self.agent: Any = EscalatingFleetModule( interpreter=interpreter, tools=self.tools, - max_iterations=max_iters, + max_iterations=self.rlm_max_iterations, + max_llm_calls=self.rlm_max_llm_calls, + max_output_chars=self.rlm_max_output_chars, summary_interval=summary_interval, ) else: @@ -314,6 +348,14 @@ def _escalation_call_args(self, user_message: str) -> dict[str, Any]: "conversation_summary": self.conversation_summary, } + def preview_routing(self, *, user_request: str, execution_mode: str = "auto") -> dict[str, Any]: + """Expose deterministic route metadata before expensive turn execution.""" + preview_routing = getattr(self.agent, "preview_routing", None) + if not callable(preview_routing): + return {} + payload = preview_routing(user_request=user_request, execution_mode=execution_mode) + return payload if isinstance(payload, dict) else {} + def _runtime_observability_payload(self) -> dict[str, Any]: """Return runtime metadata shared by streamed completion events.""" return { @@ -322,6 +364,11 @@ def _runtime_observability_payload(self) -> dict[str, Any]: "escalation_enabled": self._use_escalation, "conversation_summary_available": bool(self.conversation_summary), "loaded_document_count": len(self.loaded_document_paths), + "rlm_limits": { + "max_iterations": self.rlm_max_iterations, + "max_llm_calls": self.rlm_max_llm_calls, + "max_output_chars": self.rlm_max_output_chars, + }, } def chat_turn(self, user_message: str) -> dspy.Prediction: @@ -364,6 +411,18 @@ async def _aiter_chat_turn_stream_posthoc( return yield StreamEvent(kind="status", text="Starting turn...") + preview_routing = getattr(self.agent, "preview_routing", None) + if callable(preview_routing): + routing_preview = preview_routing( + user_request=message, + execution_mode=self.execution_mode, + ) + if isinstance(routing_preview, dict) and routing_preview.get("routing_decision"): + yield StreamEvent( + kind="status", + text=_routing_status_text(routing_preview), + payload=routing_preview, + ) try: result = await asyncio.to_thread( @@ -390,6 +449,14 @@ async def _aiter_chat_turn_stream_posthoc( trajectory_raw = getattr(result, "trajectory", None) or {} trajectory = _normalize_trajectory(trajectory_raw) degradation_payload = _runtime_degradation_payload(result) + routing_payload = _runtime_routing_payload(result) + + if routing_payload.get("selected_skills") or routing_payload.get("routing_decision"): + yield StreamEvent( + kind="status", + text=_routing_status_text(routing_payload), + payload=routing_payload, + ) for step in trajectory: thought = step.get("thought") @@ -411,6 +478,9 @@ async def _aiter_chat_turn_stream_posthoc( payload={ "tool_name": tool_name, "tool_input": str(tool_args), + "tool_args": tool_args, + "step": step, + "trajectory_index": step.get("index"), }, ) @@ -422,6 +492,9 @@ async def _aiter_chat_turn_stream_posthoc( payload={ "tool_name": tool_name, "tool_output": str(observation), + "output": observation, + "step": step, + "trajectory_index": step.get("index"), }, ) if isinstance(observation, dict) and observation.get("status") == "clarification_needed": @@ -463,6 +536,7 @@ async def _aiter_chat_turn_stream_posthoc( } done_payload.update(self._runtime_observability_payload()) done_payload.update(degradation_payload) + done_payload.update(routing_payload) yield StreamEvent(kind="done", text=response, payload=done_payload) # ----------------------------------------------------------------- diff --git a/src/fleet_rlm/runtime/agent/signatures.py b/src/fleet_rlm/runtime/agent/signatures.py index 8708d5231..1f098b030 100644 --- a/src/fleet_rlm/runtime/agent/signatures.py +++ b/src/fleet_rlm/runtime/agent/signatures.py @@ -60,6 +60,12 @@ class RLMReActChatSignature(dspy.Signature): desc="Persistent memory blocks (Persona, Human, Scratchpad) that define your identity and context" ) history: dspy.History = dspy.InputField(desc="Prior chat turns using keys user_message and response") + recent_history: str = dspy.InputField( + desc=( + "Compact recent chat transcript, oldest to newest. The final listed turn is the most " + "recent prior exchange and should dominate recency-sensitive follow-up answers." + ) + ) assistant_response: str = dspy.OutputField(desc="Final assistant response to user") @@ -439,16 +445,21 @@ class RLMVariableSignature(dspy.Signature): class RLMLargeDocSignature(dspy.Signature): - """Fetch and process an oversized URL document using the REPL. + """Process an oversized URL document using the REPL. All input fields are stored as REPL variables — the LLM sees only - metadata and writes Python to stream-fetch the URL, chunk it, and call - ``sub_rlm()`` per chunk. The ``history`` variable provides session - context so the LLM can target extraction at what the user actually needs. + metadata and writes Python to inspect ``document_text`` and chunk it. The + dedicated URL path avoids automatic recursive child-sandbox delegation; use + built-in ``llm_query()`` for bounded semantic passes and ``SUBMIT(...)`` for + the final answer. The ``history`` variable provides session context so the + LLM can target extraction at what the user actually needs. """ task: str = dspy.InputField(desc="Instruction for how to process the document") - prompt: str = dspy.InputField(desc="The URL to fetch (stored as REPL variable)") + prompt: str = dspy.InputField(desc="Brief task framing and any compressed conversation context") + source_url: str = dspy.InputField(desc="Canonical fetched source URL") + document_text: str = dspy.InputField(desc="Extracted document text stored as a REPL variable") + source_metadata: dict[str, str] = dspy.InputField(desc="Source metadata such as char count and fetch status") history: dspy.History = dspy.InputField( desc="Prior chat turns for user intent context (keys: user_request, assistant_response)" ) diff --git a/src/fleet_rlm/runtime/execution/streaming_events.py b/src/fleet_rlm/runtime/execution/streaming_events.py index f8a584181..e41f973d6 100644 --- a/src/fleet_rlm/runtime/execution/streaming_events.py +++ b/src/fleet_rlm/runtime/execution/streaming_events.py @@ -225,6 +225,39 @@ def _build_flat_trajectory_step(raw: dict[str, Any], index: int) -> dict[str, An return step +def _truncate_trajectory_output(step: dict[str, Any]) -> dict[str, Any]: + if step.get("output_truncated"): + return step + output = step.get("output") + if isinstance(output, str) and len(output) > _TRAJECTORY_OUTPUT_CONTENT_CHARS: + preview, full_len = head_tail_preview(output, max_chars=_TRAJECTORY_OUTPUT_CONTENT_CHARS) + step["output"] = preview + step["observation"] = preview + step["output_truncated"] = True + step["output_length"] = full_len + return step + + +def _normalize_structured_trajectory_step(step: dict[str, Any], index: int) -> dict[str, Any]: + step_copy = dict(step) + if "index" not in step_copy: + step_copy["index"] = index + + # DSPy RLM trajectories are shaped as {reasoning, code, output}. Convert + # them into Fleet's generic tool/repl shape so websocket and frontend + # adapters can render them as visible sandbox execution rows. + if "code" in step_copy and "tool_name" not in step_copy: + step_copy["tool_name"] = "repl_execute" + step_copy.setdefault("input", step_copy.get("code")) + step_copy.setdefault("tool_args", step_copy.get("code")) + if "reasoning" in step_copy and "thought" not in step_copy: + step_copy["thought"] = step_copy.get("reasoning") + if "output" in step_copy and "observation" not in step_copy: + step_copy["observation"] = step_copy.get("output") + + return _truncate_trajectory_output(step_copy) + + def _normalize_trajectory(raw: Any | None) -> list[dict[str, Any]]: """Convert DSPy ReAct flat trajectory to structured step list.""" if not raw: @@ -241,10 +274,10 @@ def _normalize_trajectory(raw: Any | None) -> list[dict[str, Any]]: steps = [_build_flat_trajectory_step(raw, index) for index in _extract_step_indices(raw)] result: list[dict[str, Any]] = [] - for step in steps: + for index, step in enumerate(steps): if not isinstance(step, dict): continue - step_copy = dict(step) + step_copy = _normalize_structured_trajectory_step(step, index) tool_name = step_copy.get("tool_name") is_terminal = (tool_name == "finish") or (not tool_name) if is_terminal and "thought" in step_copy: diff --git a/src/fleet_rlm/runtime/factory.py b/src/fleet_rlm/runtime/factory.py index 3a6857006..cf3a265fa 100644 --- a/src/fleet_rlm/runtime/factory.py +++ b/src/fleet_rlm/runtime/factory.py @@ -33,6 +33,9 @@ def build_chat_agent( *, docs_path: Path | str | None = None, react_max_iters: int = 15, + rlm_max_iterations: int | None = None, + rlm_max_llm_calls: int | None = None, + rlm_max_output_chars: int | None = None, history_max_turns: int | None = None, extra_tools: list[Callable[..., Any]] | None = None, env_file: Path | None = None, @@ -57,6 +60,9 @@ def build_chat_agent( agent = AgentRuntime( interpreter=interpreter, max_iters=react_max_iters, + rlm_max_iterations=rlm_max_iterations, + rlm_max_llm_calls=rlm_max_llm_calls, + rlm_max_output_chars=rlm_max_output_chars, history_max_turns=history_max_turns, extra_tools=extra_tools, repository=repository, diff --git a/src/fleet_rlm/runtime/modules/escalating.py b/src/fleet_rlm/runtime/modules/escalating.py index 8ff3782ff..9e7dc5584 100644 --- a/src/fleet_rlm/runtime/modules/escalating.py +++ b/src/fleet_rlm/runtime/modules/escalating.py @@ -13,6 +13,8 @@ from __future__ import annotations import logging +import re +from dataclasses import dataclass, field from typing import Any import dspy @@ -26,12 +28,65 @@ ESCALATION_SENTINEL = "[TOOLS NEEDED]" _RLM_FALLBACK_WARNING = "RLM escalation failed; returned a lightweight fallback response." +_URL_DOCUMENT_MAX_ITERATIONS = 4 +_URL_DOCUMENT_MAX_LLM_CALLS = 8 +_URL_RE = re.compile(r"https?://[^\s)\],;]+", flags=re.IGNORECASE) +_URL_DOCUMENT_ANALYSIS_TERMS = ( + "analyze", + "analyse", + "analysis", + "summarize", + "summarise", + "summary", + "read", + "documentation", + "docs", + "document", + "page", +) +_RLM_REPL_GUIDANCE = """RLM REPL guidance: +- Keep the task visible: solve the task stated at the top and repeated at the bottom of this prompt. +- Use Python variables instead of printing large inputs. Inspect slices, lengths, keywords, and structure with code. +- Treat available tools as ordinary Python callables. Their type hints and docstrings are the contract. +- For documentation URLs, first inspect headings, links, llms.txt, sitemap entries, and section samples with Python. Do not send an entire document to one semantic callback. +- If semantic callbacks such as llm_query are unavailable, finish from Python document inspection. +- Keep intermediate output bounded; print summaries or small samples, then call SUBMIT(...) for the final answer. +- Do not print or return credentials, environment variables, or hidden configuration values. +""" + + +@dataclass(slots=True) +class _FetchedUrlDocument: + """Fetched URL document payload passed into DSPy RLM as REPL variables.""" + + source_url: str + document_text: str = "" + source_metadata: dict[str, str] = field(default_factory=dict) def _is_rlm_execution_mode(execution_mode: str) -> bool: return execution_mode in {"rlm", "rlm_only"} +def _extract_first_url(text: str) -> str | None: + match = _URL_RE.search(text) + return match.group(0).rstrip(".,;]") if match else None + + +def _is_url_document_analysis_request(text: str) -> bool: + if _extract_first_url(text) is None: + return False + lowered = text.lower() + return any(term in lowered for term in _URL_DOCUMENT_ANALYSIS_TERMS) + + +def _prediction_set(prediction: dspy.Prediction, key: str, value: Any) -> None: + try: + prediction[key] = value + except Exception: + object.__setattr__(prediction, key, value) + + def _history_value(message: Any, *keys: str) -> str: if isinstance(message, dict): for key in keys: @@ -56,6 +111,59 @@ def _format_history_turn(message: Any) -> str: return "\n".join(parts) +def _format_recent_history_context(history: dspy.History, *, max_turns: int = 4) -> str: + """Return an explicit recency-ordered history view for model inputs.""" + messages = list(getattr(history, "messages", []) or []) + if not messages: + return "" + + recent = messages[-max_turns:] + lines = [ + "Recent chat history, ordered oldest to newest.", + "The final listed turn is the most recent prior user/assistant exchange.", + ] + for index, message in enumerate(recent, start=1): + turn = _format_history_turn(message) + if not turn: + continue + marker = "most recent prior turn" if index == len(recent) else f"prior turn {index}" + lines.append(f"[{marker}]") + lines.append(turn) + return "\n".join(lines) + + +def _build_rlm_prompt_context( + *, + user_request: str, + recent_history: str, + compressed_history: str, + core_memory: str, + url_document_mode: bool, +) -> str: + """Build a Fast-RLM-style prompt envelope for variable-mode DSPy RLM.""" + sections = [ + "Task:\n" + user_request, + _RLM_REPL_GUIDANCE, + ] + if url_document_mode: + sections.append( + "URL document variables:\n" + "- source_url: canonical fetched URL string.\n" + "- document_text: extracted source text; inspect it with Python rather than printing it wholesale.\n" + "- source_metadata: fetch status, source metadata, and any bundled llms.txt/sitemap companions.\n" + "- llm_query and llm_query_batched are disabled in this URL-document path; synthesize from Python inspection.\n" + "- history: structured dspy.History for prior turns." + ) + if recent_history: + sections.append(recent_history) + if compressed_history: + sections.append("Compressed conversation context:\n" + compressed_history) + if core_memory: + sections.append("Core memory and active skill guidance:\n" + core_memory) + sections.append("Repeat task:\n" + user_request) + return "\n\n".join(section for section in sections if section.strip()) + + class EscalatingFleetModule(dspy.Module): """Unified DSPy Module that scales from lightweight chat to full RLM execution. @@ -74,6 +182,8 @@ class EscalatingFleetModule(dspy.Module): Maximum RLM iterations for the heavy path. max_llm_calls: Maximum LLM calls for the heavy path. + max_output_chars: + Maximum REPL output characters exposed back to the RLM per step. verbose: Pass ``verbose=True`` to the inner RLM for debug output. sub_lm: @@ -89,6 +199,7 @@ def __init__( tools: list[Any] | None = None, max_iterations: int = 20, max_llm_calls: int = 50, + max_output_chars: int | None = None, verbose: bool = False, sub_lm: dspy.LM | None = None, summary_interval: int = 10, @@ -99,16 +210,21 @@ def __init__( self._turn_count = 0 from fleet_rlm.runtime.modules.skill_selection import SkillSelectionModule + from fleet_rlm.runtime.tools._volume_paths import volume_root - volume_mount_path = getattr(interpreter, "volume_mount_path", None) if interpreter else None + volume_mount_path = getattr(interpreter, "volume_mount_path", None) if interpreter is not None else None + if volume_mount_path is None: + resolved = volume_root() + volume_mount_path = str(resolved) if resolved is not None else None self._skill_selector = SkillSelectionModule(volume_mount_path=volume_mount_path) self.respond = dspy.ChainOfThought(RLMReActChatSignature) self.summarize = dspy.ChainOfThought(ConversationSummarySignature) self._rlm: dspy.Module | None = None + self._url_document_rlm: dspy.Module | None = None if interpreter is not None: - from fleet_rlm.runtime.agent.signatures import RLMVariableSignature + from fleet_rlm.runtime.agent.signatures import RLMLargeDocSignature, RLMVariableSignature from fleet_rlm.runtime.modules.variable_mode import build_variable_mode_rlm self._rlm = build_variable_mode_rlm( @@ -116,10 +232,23 @@ def __init__( interpreter=interpreter, max_iterations=max_iterations, max_llm_calls=max_llm_calls, + max_output_chars=max_output_chars, verbose=verbose, sub_lm=sub_lm, extra_tools=tools or [], ) + self._url_document_rlm = build_variable_mode_rlm( + signature=RLMLargeDocSignature, + interpreter=interpreter, + max_iterations=max(1, min(max_iterations, _URL_DOCUMENT_MAX_ITERATIONS)), + max_llm_calls=max(1, min(max_llm_calls, _URL_DOCUMENT_MAX_LLM_CALLS)), + max_output_chars=max_output_chars, + verbose=verbose, + sub_lm=sub_lm, + extra_tools=[], + include_sub_tools=False, + include_llm_tools=False, + ) else: self._rlm = dspy.ChainOfThought(RLMReActChatSignature) @@ -137,6 +266,18 @@ def _should_escalate( reasoning = str(getattr(prediction, "reasoning", "") or "") return ESCALATION_SENTINEL in reasoning + def preview_routing(self, *, user_request: str, execution_mode: str = "auto") -> dict[str, Any]: + """Return deterministic routing metadata available before the full turn.""" + if execution_mode == "auto" and _is_url_document_analysis_request(user_request): + source_url = _extract_first_url(user_request) + payload: dict[str, Any] = {"routing_decision": "url_document_rlm"} + if source_url: + payload["source_url"] = source_url + return payload + if _is_rlm_execution_mode(execution_mode): + return {"routing_decision": "forced_rlm"} + return {} + def compress_history(self, history: dspy.History) -> str: """Return a compressed text summary of the given history.""" messages = list(getattr(history, "messages", []) or []) @@ -154,18 +295,20 @@ def compress_history(self, history: dspy.History) -> str: logger.warning("Conversation summary failed, returning truncated history: %s", exc) return history_text[-4000:] - def _enrich_with_skills(self, user_request: str, core_memory: str) -> str: + def _enrich_with_skills(self, user_request: str, core_memory: str) -> tuple[str, list[str]]: """Select relevant skills and append their instructions to core_memory.""" try: selection = self._skill_selector(user_request=user_request, core_memory=core_memory) skill_context = str(getattr(selection, "skill_context", "") or "") + selected = [str(item) for item in list(getattr(selection, "selected_skills", []) or [])] if skill_context: - selected = getattr(selection, "selected_skills", []) logger.debug("SkillSelection: injected %s", selected) - return f"{core_memory}\n\n[Active Skills]\n{skill_context}" if core_memory else skill_context + enriched = f"{core_memory}\n\n[Active Skills]\n{skill_context}" if core_memory else skill_context + return enriched, selected + return core_memory, selected except Exception as exc: logger.debug("SkillSelection: skipped (%s)", exc) - return core_memory + return core_memory, [] def forward( self, @@ -201,21 +344,29 @@ def forward( self._turn_count += 1 - core_memory = self._enrich_with_skills(user_request, core_memory) + core_memory, selected_skills = self._enrich_with_skills(user_request, core_memory) + recent_history = _format_recent_history_context(history) + should_auto_route_url = execution_mode == "auto" and _is_url_document_analysis_request(user_request) + source_url = _extract_first_url(user_request) if should_auto_route_url else None - if _is_rlm_execution_mode(execution_mode) or force_escalate: + if _is_rlm_execution_mode(execution_mode) or force_escalate or should_auto_route_url: logger.debug("EscalatingFleetModule: forced RLM path (mode=%s)", execution_mode) return self._run_rlm( user_request=user_request, core_memory=core_memory, history=history, + recent_history=recent_history, conversation_summary=conversation_summary, + selected_skills=selected_skills, + routing_decision="url_document_rlm" if should_auto_route_url else "forced_rlm", + source_url=source_url, ) prediction = self.respond( user_request=user_request, core_memory=core_memory, history=history, + recent_history=recent_history, ) if self._should_escalate(prediction, execution_mode=execution_mode, force_escalate=False): @@ -224,9 +375,14 @@ def forward( user_request=user_request, core_memory=core_memory, history=history, + recent_history=recent_history, conversation_summary=conversation_summary, + selected_skills=selected_skills, + routing_decision="sentinel_rlm", + source_url=None, ) + _prediction_set(prediction, "selected_skills", selected_skills) return prediction def _run_rlm( @@ -235,26 +391,53 @@ def _run_rlm( user_request: str, core_memory: str, history: dspy.History, + recent_history: str, conversation_summary: str, + selected_skills: list[str] | None = None, + routing_decision: str = "rlm", + source_url: str | None = None, ) -> dspy.Prediction: if self._rlm is None: return self.respond( user_request=user_request, core_memory=core_memory, history=history, + recent_history=recent_history, ) context = conversation_summary or self.compress_history(history) + rlm = self._url_document_rlm if source_url and self._url_document_rlm is not None else self._rlm + url_document_mode = bool(source_url and rlm is self._url_document_rlm) + prompt = _build_rlm_prompt_context( + user_request=user_request, + recent_history=recent_history, + compressed_history=context, + core_memory=core_memory, + url_document_mode=url_document_mode, + ) + call_kwargs: dict[str, Any] = { + "task": user_request, + "prompt": prompt, + } + if url_document_mode: + fetched = self._fetch_url_document(source_url=source_url) + call_kwargs["source_url"] = fetched.source_url + call_kwargs["document_text"] = fetched.document_text + call_kwargs["source_metadata"] = fetched.source_metadata + call_kwargs["history"] = history try: - return self._rlm( - task=user_request, - prompt=context or core_memory or user_request, - ) + result = rlm(**call_kwargs) + _prediction_set(result, "selected_skills", selected_skills or []) + _prediction_set(result, "routing_decision", routing_decision) + if source_url: + _prediction_set(result, "source_url", source_url) + return result except Exception as exc: logger.warning("EscalatingFleetModule: RLM path failed (%s), falling back to ChainOfThought", exc) fallback = self.respond( user_request=user_request, core_memory=core_memory, history=history, + recent_history=recent_history, ) fallback["degraded"] = True fallback["warning"] = _RLM_FALLBACK_WARNING @@ -263,8 +446,52 @@ def _run_rlm( fallback["runtime_failure_phase"] = "escalating_rlm" fallback["runtime_fallback_used"] = True fallback["runtime_warning"] = _RLM_FALLBACK_WARNING + fallback["selected_skills"] = selected_skills or [] + fallback["routing_decision"] = routing_decision + if source_url: + fallback["source_url"] = source_url return fallback + def _fetch_url_document(self, *, source_url: str) -> _FetchedUrlDocument: + if self._interpreter is None: + return _FetchedUrlDocument( + source_url=source_url, + source_metadata={"status": "not_fetched", "reason": "interpreter_unavailable"}, + ) + try: + from fleet_rlm.runtime.tools.document_tools import fetch_document_text + + fetched = fetch_document_text(source_url) + except Exception as exc: + return _FetchedUrlDocument( + source_url=source_url, + source_metadata={"status": "error", "error": str(exc)}, + ) + + if fetched.get("status") != "ok": + return _FetchedUrlDocument( + source_url=source_url, + source_metadata={ + "status": "error", + "error": str(fetched.get("error", "unknown error")), + }, + ) + + text = str(fetched.get("text") or "") + char_count = fetched.get("char_count", len(text)) + raw_metadata = fetched.get("metadata") + metadata: dict[str, str] = { + "status": "ok", + "char_count": str(char_count), + } + if isinstance(raw_metadata, dict): + metadata.update({str(key): str(value) for key, value in raw_metadata.items()}) + return _FetchedUrlDocument( + source_url=source_url, + document_text=text, + source_metadata=metadata, + ) + __all__ = [ "ESCALATION_SENTINEL", diff --git a/src/fleet_rlm/runtime/modules/factory.py b/src/fleet_rlm/runtime/modules/factory.py index 97f803766..91a148fec 100644 --- a/src/fleet_rlm/runtime/modules/factory.py +++ b/src/fleet_rlm/runtime/modules/factory.py @@ -8,6 +8,36 @@ import dspy +class _NoCallbackRLM(dspy.RLM): + """RLM variant for REPL-only tasks where host semantic callbacks are disabled.""" + + def _build_signatures(self) -> tuple[Any, Any]: + action_sig, extract_sig = super()._build_signatures() + instructions = str(action_sig.instructions) + instructions = instructions.replace( + "- `llm_query(prompt)` - query a sub-LLM (~500K char capacity) for semantic analysis\n", + "", + ).replace( + "- `llm_query_batched(prompts)` - query multiple prompts concurrently (much faster for multiple queries)\n", + "", + ) + instructions = instructions.replace( + "4. USE llm_query FOR SEMANTICS - String matching finds WHERE things are; " + "llm_query understands WHAT things mean.", + "4. USE PYTHON INSPECTION - Extract headings, links, counts, samples, and sections with code; " + "semantic callbacks are disabled for this run.", + ) + instructions = instructions.replace( + f"You have max {self.max_llm_calls} sub-LLM calls. When done, call SUBMIT() with your output.", + "Semantic callbacks are disabled. When done, call SUBMIT() with your output.", + ) + return action_sig.with_instructions(instructions), extract_sig + + def _make_llm_tools(self, max_workers: int = 8) -> dict[str, Any]: + _ = max_workers + return {} + + def create_runtime_rlm( *, signature: type[dspy.Signature], @@ -18,6 +48,7 @@ def create_runtime_rlm( verbose: bool, tools: list[Any] | None = None, sub_lm: dspy.LM | None = None, + include_llm_tools: bool = True, ) -> dspy.Module: """Create a canonical RLM instance for a runtime signature.""" @@ -35,7 +66,8 @@ def create_runtime_rlm( if sub_lm is not None: kwargs["sub_lm"] = sub_lm - return dspy.RLM( + rlm_cls = dspy.RLM if include_llm_tools else _NoCallbackRLM + return rlm_cls( **kwargs, ) diff --git a/src/fleet_rlm/runtime/modules/skill_selection.py b/src/fleet_rlm/runtime/modules/skill_selection.py index 80539b7f1..d82e9618e 100644 --- a/src/fleet_rlm/runtime/modules/skill_selection.py +++ b/src/fleet_rlm/runtime/modules/skill_selection.py @@ -27,6 +27,7 @@ "optimization": "GEPA/MIPROv2 optimization, evaluation metrics, MLflow", "diagnostics": "Diagnose runtime failures, contract drift, test triage", "volume-bootstrap": "Volume filesystem structure, CRUD helpers, persistence guarantees", + "browser-interaction": "Rendered page fetching, JavaScript-heavy docs, Playwright browser inspection", } _KEYWORD_MAP: dict[str, list[str]] = { @@ -72,6 +73,29 @@ "semantic chunk", "variable-mode", "codebase analysis", + "analyze the documentation", + "analyze documentation", + "read the docs", + "summarize the page", + "http://", + "https://", + "fetch", + "browse", + "scrape", + "analyze https", + "analyze http", + ], + "browser-interaction": [ + "render the page", + "rendered page", + "javascript page", + "spa content", + "headless browser", + "playwright", + "dynamic page", + "browser fetch", + "screenshot the page", + "interact with the page", ], "optimization": [ "optimi", diff --git a/src/fleet_rlm/runtime/modules/variable_mode.py b/src/fleet_rlm/runtime/modules/variable_mode.py index 1a8433ea2..04a209603 100644 --- a/src/fleet_rlm/runtime/modules/variable_mode.py +++ b/src/fleet_rlm/runtime/modules/variable_mode.py @@ -7,6 +7,7 @@ from __future__ import annotations +from contextlib import nullcontext from typing import Any import dspy @@ -61,14 +62,20 @@ def __init__( max_output_chars: int | None = None, sub_lm: dspy.LM | None = None, extra_tools: list[Any] | None = None, + include_sub_tools: bool = True, + include_llm_tools: bool = True, ) -> None: super().__init__() + self._interpreter = interpreter + self._include_llm_tools = include_llm_tools + self._adapter = None if include_llm_tools else dspy.JSONAdapter() # Gather sub_rlm tools from the interpreter (if it exposes them) tools: list[Any] = list(extra_tools or []) - for attr_name in ("sub_rlm", "sub_rlm_batched"): - fn = getattr(interpreter, attr_name, None) - if callable(fn): - tools.append(fn) + if include_sub_tools: + for attr_name in ("sub_rlm", "sub_rlm_batched"): + fn = getattr(interpreter, attr_name, None) + if callable(fn): + tools.append(fn) self._rlm = create_runtime_rlm( signature=signature, @@ -81,6 +88,7 @@ def __init__( verbose=verbose, tools=tools or None, sub_lm=sub_lm, + include_llm_tools=include_llm_tools, ) def forward(self, **kwargs: Any) -> dspy.Prediction: @@ -90,7 +98,29 @@ def forward(self, **kwargs: Any) -> dspy.Prediction: model writes code to explore those variables before calling ``SUBMIT(...)`` with the signature's declared outputs. """ - return self._rlm(**kwargs) + if self._include_llm_tools: + result = self._rlm(**kwargs) + self._record_trajectory_spans(result) + return result + + previous = getattr(self._interpreter, "semantic_callbacks_enabled", True) + adapter_context = dspy.settings.context(adapter=self._adapter) if self._adapter else nullcontext() + try: + setattr(self._interpreter, "semantic_callbacks_enabled", False) + with adapter_context: + result = self._rlm(**kwargs) + self._record_trajectory_spans(result) + return result + finally: + setattr(self._interpreter, "semantic_callbacks_enabled", previous) + + def _record_trajectory_spans(self, result: dspy.Prediction) -> None: + try: + from fleet_rlm.integrations.observability.mlflow_context import record_rlm_trajectory_spans + + record_rlm_trajectory_spans(getattr(result, "trajectory", None)) + except Exception: + return def build_variable_mode_rlm( @@ -103,6 +133,8 @@ def build_variable_mode_rlm( max_output_chars: int | None = None, sub_lm: dspy.LM | None = None, extra_tools: list[Any] | None = None, + include_sub_tools: bool = True, + include_llm_tools: bool = True, ) -> RLMVariableExecutionModule: """Factory for the true-RLM variable-mode execution module. @@ -119,6 +151,8 @@ def build_variable_mode_rlm( max_output_chars=max_output_chars, sub_lm=sub_lm, extra_tools=extra_tools, + include_sub_tools=include_sub_tools, + include_llm_tools=include_llm_tools, ) diff --git a/src/fleet_rlm/runtime/tools/binding.py b/src/fleet_rlm/runtime/tools/binding.py index a8979c063..534ca9c0f 100644 --- a/src/fleet_rlm/runtime/tools/binding.py +++ b/src/fleet_rlm/runtime/tools/binding.py @@ -7,7 +7,7 @@ from collections.abc import Callable from typing import Any -from fleet_rlm.runtime.tools.document_tools import _load_document_impl +from fleet_rlm.runtime.tools.document_tools import _load_document_impl, _validate_download_url from fleet_rlm.runtime.tools.knowledge_tools import _search_knowledge_impl from fleet_rlm.runtime.tools.rlm_delegate import ( delegate_to_rlm as _delegate_to_rlm, @@ -36,6 +36,7 @@ INTERPRETER_TOOL_NAMES = frozenset( { + "browser_fetch_page", "clear_buffer", "delegate_to_rlm", "delegate_to_rlm_batched", @@ -272,9 +273,61 @@ def recursive_workspace( } factories["recursive_workspace"] = recursive_workspace + + def browser_fetch_page( + url: str, + wait_until: str = "networkidle", + extract_links: bool = False, + ) -> dict[str, Any]: + """Fetch a JS-rendered page using Playwright inside the sandbox.""" + _validate_download_url(url) + return execute_sandbox_tool( + interpreter, + _BROWSER_FETCH_PAGE_CODE, + {"target_url": url, "wait_until": wait_until, "extract_links": extract_links}, + ) + + factories["browser_fetch_page"] = browser_fetch_page return factories +_BROWSER_FETCH_PAGE_CODE = """\ +try: + from playwright.sync_api import sync_playwright +except ImportError: + SUBMIT( + status="error", + error="Playwright is not installed in this sandbox. " + "Use a browser-capable sandbox (fleet-rlm-browser snapshot) for rendered page fetching.", + ) +else: + with sync_playwright() as p: + browser = p.chromium.launch(headless=True, args=["--no-sandbox", "--disable-dev-shm-usage"]) + page = browser.new_page() + try: + page.goto(target_url, wait_until=wait_until, timeout=30000) + text = page.inner_text("body") + title = page.title() + links = [] + if extract_links: + links = page.eval_on_selector_all( + "a[href]", + "els => els.map(e => ({href: e.href, text: (e.textContent || '').trim()}))", + ) + SUBMIT( + status="ok", + url=target_url, + title=title, + text=text[:200000], + char_count=len(text), + links=links[:100] if extract_links else [], + ) + finally: + page.close() + browser.close() +""" + + def bind_runtime_tools( tools: list[Any], *, diff --git a/src/fleet_rlm/runtime/tools/browser_tools.py b/src/fleet_rlm/runtime/tools/browser_tools.py new file mode 100644 index 000000000..4abb9a858 --- /dev/null +++ b/src/fleet_rlm/runtime/tools/browser_tools.py @@ -0,0 +1,41 @@ +"""Browser automation tool for JavaScript-rendered page fetching. + +Exposes ``browser_fetch_page`` marked with ``@tool_fn`` so that +``discover_tools()`` can collect it. The concrete execution is delegated +to the Daytona interpreter (Playwright runs inside the sandbox); calling +this function directly raises ``RuntimeError``. +""" + +from __future__ import annotations + +from typing import Any + +from fleet_rlm.runtime.tools._marker import tool_fn + + +@tool_fn +def browser_fetch_page( + url: str, + wait_until: str = "networkidle", + extract_links: bool = False, +) -> dict[str, Any]: + """Fetch a JavaScript-rendered web page using a headless Chromium browser. + + Use when fetch_page or fetch_document_text returns empty/SPA shell content. + Requires a browser-capable sandbox (fleet-rlm-browser snapshot). + + Args: + url: Public HTTP(S) URL to render and extract text from. + wait_until: Playwright load state — "networkidle", "load", or "domcontentloaded". + extract_links: Whether to also extract anchor links from the page. + + Returns: + Dict with status, url, title, text, char_count, and optionally links. + """ + raise RuntimeError( + "browser_fetch_page requires an active AgentRuntime with a browser-capable " + "Daytona sandbox (fleet-rlm-browser snapshot)." + ) + + +__all__ = ["browser_fetch_page"] diff --git a/src/fleet_rlm/runtime/tools/document_tools.py b/src/fleet_rlm/runtime/tools/document_tools.py index 867570180..4cc192f2f 100644 --- a/src/fleet_rlm/runtime/tools/document_tools.py +++ b/src/fleet_rlm/runtime/tools/document_tools.py @@ -37,7 +37,10 @@ "text/plain": ".txt", "application/json": ".json", "text/markdown": ".md", + "application/xml": ".xml", + "text/xml": ".xml", } +_AUXILIARY_DOC_PATHS = ("/llms.txt", "/sitemap.xml") def _is_private_download_address(address: str) -> bool: @@ -165,6 +168,45 @@ def _download_url(url: str) -> Path: return Path(tmp_path) +def _origin_url(parsed: urllib.parse.ParseResult) -> str: + return urllib.parse.urlunparse((parsed.scheme, parsed.netloc, "", "", "", "")) + + +def _documentation_auxiliary_urls(url: str) -> list[str]: + """Return same-origin documentation index URLs worth bundling with a page. + + RLM document analysis works best when the REPL receives the page content + plus a compact site map. Documentation sites increasingly expose + ``/llms.txt`` for LLM-readable summaries and ``/sitemap.xml`` for URL + inventory, so Fleet opportunistically fetches those small same-origin + companions for root documentation URLs. + """ + parsed = urllib.parse.urlparse(url) + if parsed.scheme not in {"http", "https"} or not parsed.netloc: + return [] + path = parsed.path or "/" + if Path(path).suffix and path not in {"/"}: + return [] + origin = _origin_url(parsed) + current = urllib.parse.urlunparse((parsed.scheme, parsed.netloc, parsed.path or "/", "", "", "")) + candidates = [origin + suffix for suffix in _AUXILIARY_DOC_PATHS] + return [candidate for candidate in candidates if candidate.rstrip("/") != current.rstrip("/")] + + +def _read_remote_document_text(url: str) -> tuple[str, dict[str, Any]]: + from fleet_rlm.runtime.content.ingestion import ( + read_document_content as _read_document_content, + ) + + tmp_path: Path | None = None + try: + tmp_path = _download_url(url) + return _read_document_content(tmp_path) + finally: + if tmp_path is not None: + tmp_path.unlink(missing_ok=True) + + @tool_fn def load_document(source: str, alias: str = "active") -> dict[str, Any]: """Load a local file, directory listing, or public URL into document context.""" @@ -293,22 +335,37 @@ def fetch_document_text(url_or_path: str) -> dict[str, Any]: - ``metadata``: Extraction metadata dict (source_type, extraction_method, etc.). - ``error``: Error message (present when ``status == "error"``). """ - from fleet_rlm.runtime.content.ingestion import ( - read_document_content as _read_document_content, - ) - stripped = url_or_path.strip() - tmp_path: Path | None = None try: if not stripped.startswith(("http://", "https://")): return { "status": "error", "error": "fetch_document_text only accepts HTTP(S) URLs.", } - tmp_path = _download_url(stripped) - file_path = tmp_path + text, metadata = _read_remote_document_text(stripped) + bundled_sources = [{"url": stripped, "char_count": len(text)}] + combined_parts = [f"# Source document: {stripped}\n\n{text}"] + auxiliary_errors: list[str] = [] + + for auxiliary_url in _documentation_auxiliary_urls(stripped): + try: + auxiliary_text, _auxiliary_metadata = _read_remote_document_text(auxiliary_url) + except Exception as exc: + auxiliary_errors.append(f"{auxiliary_url}: {exc}") + continue + if not auxiliary_text.strip(): + continue + bundled_sources.append({"url": auxiliary_url, "char_count": len(auxiliary_text)}) + combined_parts.append(f"# Auxiliary document: {auxiliary_url}\n\n{auxiliary_text}") + + if len(bundled_sources) > 1: + text = "\n\n".join(combined_parts) + metadata = dict(metadata or {}) + metadata["bundled_sources"] = bundled_sources + metadata["bundled_source_count"] = len(bundled_sources) + if auxiliary_errors: + metadata["auxiliary_errors"] = auxiliary_errors - text, metadata = _read_document_content(file_path) return { "status": "ok", "text": text, @@ -317,9 +374,6 @@ def fetch_document_text(url_or_path: str) -> dict[str, Any]: } except Exception as exc: return {"status": "error", "error": str(exc)} - finally: - if tmp_path is not None: - tmp_path.unlink(missing_ok=True) __all__ = [ diff --git a/src/fleet_rlm/runtime/tools/registry.py b/src/fleet_rlm/runtime/tools/registry.py index c619d7b7b..68fdab1ee 100644 --- a/src/fleet_rlm/runtime/tools/registry.py +++ b/src/fleet_rlm/runtime/tools/registry.py @@ -9,6 +9,7 @@ from ._marker import tool_fn TOOL_MODULE_NAMES: tuple[str, ...] = ( + "browser_tools", "chunking_tools", "document_tools", "filesystem", diff --git a/src/frontend/src/components/agent-elements/input/model-picker.tsx b/src/frontend/src/components/agent-elements/input/model-picker.tsx new file mode 100644 index 000000000..785f25cb4 --- /dev/null +++ b/src/frontend/src/components/agent-elements/input/model-picker.tsx @@ -0,0 +1,145 @@ +"use client"; + +import { memo, useMemo, useState } from "react"; +import { Check, ChevronDown, Cpu, Settings2 } from "lucide-react"; + +import { Button } from "@/components/ui/button"; +import { + Popover, + PopoverContent, + PopoverTrigger, +} from "@/components/ui/popover"; + +import { cn } from "../utils/cn"; + +export type ModelOption = { + id: string; + label: string; + description?: string; + disabled?: boolean; +}; + +export type ModelPickerProps = { + models: readonly ModelOption[]; + value?: string; + defaultValue?: string; + onChange?: (modelId: string) => void; + onConfigure?: () => void; + configureLabel?: string; + className?: string; + disabled?: boolean; +}; + +export const ModelPicker = memo(function ModelPicker({ + models, + value, + defaultValue, + onChange, + onConfigure, + configureLabel = "Model settings", + className, + disabled = false, +}: ModelPickerProps) { + const isControlled = value !== undefined; + const [internalValue, setInternalValue] = useState(defaultValue); + const [open, setOpen] = useState(false); + const activeId = isControlled ? value : internalValue; + const enabledModels = useMemo( + () => models.filter((model) => !model.disabled), + [models], + ); + const activeModel = + models.find((model) => model.id === activeId) ?? + enabledModels[0] ?? + models[0] ?? + null; + + if (!activeModel) return null; + + const canSelect = Boolean(onChange) && enabledModels.length > 1 && !disabled; + const trigger = ( + + ); + + return ( + + {trigger} + +
+ {models.map((model) => { + const isActive = model.id === activeModel.id; + return ( + + ); + })} +
+ {onConfigure ? ( +
+ +
+ ) : null} +
+
+ ); +}); diff --git a/src/frontend/src/components/agent-elements/tools/subagent-tool.tsx b/src/frontend/src/components/agent-elements/tools/subagent-tool.tsx new file mode 100644 index 000000000..fbb7f51ce --- /dev/null +++ b/src/frontend/src/components/agent-elements/tools/subagent-tool.tsx @@ -0,0 +1,39 @@ +import { memo } from "react"; + +import { ToolGroup } from "./tool-group"; + +export type SubagentToolProps = { + part: any; + nestedTools?: any[]; + chatStatus?: string; +}; + +function subagentLabel(part: any): string { + const raw = + part?.input?.subagent_type ?? + part?.input?.agent_type ?? + part?.input?.name ?? + (part?.type === "tool-Agent" ? "Agent" : "Task"); + const label = String(raw || "").trim(); + return label || "Agent"; +} + +export const SubagentTool = memo(function SubagentTool({ + part, + nestedTools, + chatStatus, +}: SubagentToolProps) { + const label = subagentLabel(part); + + return ( + + ); +}); diff --git a/src/frontend/src/components/agent-elements/tools/tool-renderer.tsx b/src/frontend/src/components/agent-elements/tools/tool-renderer.tsx index 693303f6a..7cff12a7d 100644 --- a/src/frontend/src/components/agent-elements/tools/tool-renderer.tsx +++ b/src/frontend/src/components/agent-elements/tools/tool-renderer.tsx @@ -6,10 +6,10 @@ import { BashTool } from "./bash-tool"; import { EditTool } from "./edit-tool"; import { TodoTool } from "./todo-tool"; import { PlanTool } from "./plan-tool"; -import { ToolGroup } from "./tool-group"; import { McpTool, unwrapMcpOutput } from "./mcp-tool"; import { ThinkingTool } from "./thinking-tool"; import { SearchTool } from "./search-tool"; +import { SubagentTool } from "./subagent-tool"; import { QuestionTool } from "../question/question-tool"; import type { CustomToolRendererProps } from "../types"; @@ -54,20 +54,8 @@ export const ToolRenderer = memo(function ToolRenderer({ case "tool-Question": return ; case "tool-Task": - case "tool-Agent": { - const labelBase = part.type === "tool-Agent" ? "Agent" : "Task"; - return ( - - ); - } + case "tool-Agent": + return ; case "tool-Thinking": return ; } diff --git a/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx b/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx index 5ede8b008..f03ff5491 100644 --- a/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx +++ b/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx @@ -103,8 +103,8 @@ describe("WorkspaceMessageList Agent Elements integration", () => { { onResolveHitl }, ); - const approveButton = Array.from(container.querySelectorAll("button")).find((button) => - button.textContent?.includes("Approve"), + const approveButton = Array.from(container.querySelectorAll("button")).find( + (button) => button.textContent?.includes("Approve"), ); expect(approveButton).toBeTruthy(); @@ -112,8 +112,8 @@ describe("WorkspaceMessageList Agent Elements integration", () => { approveButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); }); - const sendButton = Array.from(container.querySelectorAll("button")).find((button) => - button.textContent?.includes("Send"), + const sendButton = Array.from(container.querySelectorAll("button")).find( + (button) => button.textContent?.includes("Send"), ); expect(sendButton).toBeTruthy(); @@ -141,7 +141,9 @@ describe("WorkspaceMessageList Agent Elements integration", () => { }, { kind: "reasoning", - parts: [{ type: "text", text: "I should inspect the repository files." }], + parts: [ + { type: "text", text: "I should inspect the repository files." }, + ], isStreaming: false, }, { @@ -170,7 +172,12 @@ describe("WorkspaceMessageList Agent Elements integration", () => { }, ], }, - { id: "a1", type: "assistant", content: "Done inspecting.", streaming: false }, + { + id: "a1", + type: "assistant", + content: "Done inspecting.", + streaming: false, + }, ]); expect(container.textContent).toContain("Execution started"); @@ -183,6 +190,34 @@ describe("WorkspaceMessageList Agent Elements integration", () => { act(() => root.unmount()); }); + it("renders delegated agent work through the Agent Elements SubagentTool path", () => { + const { container, root } = mount([ + { id: "u1", type: "user", content: "delegate this" }, + { + id: "trace-agent", + type: "trace", + content: "", + renderParts: [ + { + kind: "tool", + title: "Delegate", + toolType: "delegate_agent", + state: "output-available", + input: { + subagent_type: "Research agent", + description: "Inspect the RLM trajectory", + }, + output: { status: "completed" }, + }, + ], + }, + ]); + + expect(container.textContent).toContain("Research agent completed"); + + act(() => root.unmount()); + }); + it("opens the attachment menu and stages a document chip", () => { const { container, root } = mount([]); @@ -193,21 +228,24 @@ describe("WorkspaceMessageList Agent Elements integration", () => { attachButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); }); - const addDocumentButton = Array.from(document.body.querySelectorAll("button")).find((button) => - button.textContent?.includes("Add document"), - ); - const connectorsButton = Array.from(document.body.querySelectorAll("button")).find((button) => - button.textContent?.includes("Connectors"), - ); + const addDocumentButton = Array.from( + document.body.querySelectorAll("button"), + ).find((button) => button.textContent?.includes("Add document")); + const connectorsButton = Array.from( + document.body.querySelectorAll("button"), + ).find((button) => button.textContent?.includes("Connectors")); expect(addDocumentButton).toBeTruthy(); expect(connectorsButton).toBeTruthy(); expect(connectorsButton).toHaveProperty("disabled", true); act(() => { - addDocumentButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); + addDocumentButton?.dispatchEvent( + new MouseEvent("click", { bubbles: true }), + ); }); - const fileInput = container.querySelector('input[type="file"]'); + const fileInput = + container.querySelector('input[type="file"]'); expect(fileInput).toBeTruthy(); Object.defineProperty(fileInput, "files", { configurable: true, @@ -223,11 +261,52 @@ describe("WorkspaceMessageList Agent Elements integration", () => { act(() => root.unmount()); }); - it("renders the pending planning loader without a lazy component crash", async () => { - const { container, root } = mount([{ id: "u1", type: "user", content: "start working" }], { - isTyping: true, + it("renders the runtime model picker from active model status", () => { + const onOpenModelSettings = vi.fn(); + const { container, root } = mount([], { + activeModels: { + planner: "openai/gemini-3-flash-preview", + delegate: "openai/gemini-3-pro-preview", + delegate_small: null, + }, + onOpenModelSettings, + }); + + expect(container.textContent).toContain("openai/gemini-3-flash-preview"); + + const modelButton = container.querySelector( + 'button[aria-label^="Active model"]', + ); + expect(modelButton).toBeTruthy(); + + act(() => { + modelButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); + }); + + expect(document.body.textContent).toContain("openai/gemini-3-pro-preview"); + + const settingsButton = Array.from( + document.body.querySelectorAll("button"), + ).find((button) => button.textContent?.includes("Model settings")); + expect(settingsButton).toBeTruthy(); + + act(() => { + settingsButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); }); + expect(onOpenModelSettings).toHaveBeenCalledOnce(); + + act(() => root.unmount()); + }); + + it("renders the pending planning loader without a lazy component crash", async () => { + const { container, root } = mount( + [{ id: "u1", type: "user", content: "start working" }], + { + isTyping: true, + }, + ); + await act(async () => { await new Promise((resolve) => setTimeout(resolve, 0)); }); diff --git a/src/frontend/src/features/workspace/conversation/transcript/workspace-message-list.tsx b/src/frontend/src/features/workspace/conversation/transcript/workspace-message-list.tsx index 85e3da906..167a416ee 100644 --- a/src/frontend/src/features/workspace/conversation/transcript/workspace-message-list.tsx +++ b/src/frontend/src/features/workspace/conversation/transcript/workspace-message-list.tsx @@ -34,6 +34,12 @@ interface WorkspaceMessageListProps { guidance: string[]; onOpenSettings: () => void; }; + activeModels?: { + planner?: string | null; + delegate?: string | null; + delegate_small?: string | null; + }; + onOpenModelSettings?: () => void; showStatusBar?: boolean; className?: string; } @@ -59,10 +65,14 @@ export function WorkspaceMessageList({ canSubmit = true, placeholder, runtimeWarning, + activeModels, + onOpenModelSettings, showStatusBar = true, className, }: WorkspaceMessageListProps) { - const selectedAssistantTurnId = useWorkspaceUiStore((state) => state.selectedAssistantTurnId); + const selectedAssistantTurnId = useWorkspaceUiStore( + (state) => state.selectedAssistantTurnId, + ); const agentMessages = useMemo( () => toAgentChatMessages(messages, { @@ -71,8 +81,11 @@ export function WorkspaceMessageList({ }), [messages, onResolveClarification, onResolveHitl], ); - const lastUserIndex = messages.findLastIndex((message) => message.type === "user"); - const lastUserMessageId = lastUserIndex >= 0 ? (messages[lastUserIndex]?.id ?? null) : null; + const lastUserIndex = messages.findLastIndex( + (message) => message.type === "user", + ); + const lastUserMessageId = + lastUserIndex >= 0 ? (messages[lastUserIndex]?.id ?? null) : null; const activeTurnAssistantMessageId = lastUserIndex >= 0 ? (messages @@ -84,11 +97,19 @@ export function WorkspaceMessageList({ useEffect(() => { if (!selectedAssistantTurnId || !lastUserMessageId) return; const pendingTurnId = buildPendingAssistantTurnId(lastUserMessageId); - if (selectedAssistantTurnId !== pendingTurnId || !activeTurnAssistantMessageId) return; + if ( + selectedAssistantTurnId !== pendingTurnId || + !activeTurnAssistantMessageId + ) + return; useWorkspaceUiStore.setState({ selectedAssistantTurnId: activeTurnAssistantMessageId, }); - }, [activeTurnAssistantMessageId, lastUserMessageId, selectedAssistantTurnId]); + }, [ + activeTurnAssistantMessageId, + lastUserMessageId, + selectedAssistantTurnId, + ]); const status = chatStatus(isTyping); const inputSlot = useMemo( @@ -101,12 +122,23 @@ export function WorkspaceMessageList({ placeholder={placeholder ?? props.placeholder} executionMode={executionMode} onExecutionModeChange={onExecutionModeChange} + activeModels={activeModels} + onOpenModelSettings={onOpenModelSettings} showStatusBar={showStatusBar} runtimeWarning={runtimeWarning} /> ); }, - [canSubmit, executionMode, onExecutionModeChange, placeholder, runtimeWarning, showStatusBar], + [ + activeModels, + canSubmit, + executionMode, + onExecutionModeChange, + onOpenModelSettings, + placeholder, + runtimeWarning, + showStatusBar, + ], ); const handleQuestionAnswer = ({ @@ -123,7 +155,9 @@ export function WorkspaceMessageList({ const selectedLabels = (answer.selectedIds ?? []).map( (id) => question.options?.find((option) => option.id === id)?.label ?? id, ); - const text = [selectedLabels.join(", "), answer.text].filter(Boolean).join(" - "); + const text = [selectedLabels.join(", "), answer.text] + .filter(Boolean) + .join(" - "); if (!text) return; if (target?.type === "hitl") { onResolveHitl(toolCallId, text); @@ -136,9 +170,14 @@ export function WorkspaceMessageList({ if (messages.length === 0 && showEmptyState) { return ( -
+
- +
void; + activeModels?: { + planner?: string | null; + delegate?: string | null; + delegate_small?: string | null; + }; + onOpenModelSettings?: () => void; showStatusBar?: boolean; runtimeWarning?: { title: string; @@ -38,6 +57,38 @@ interface WorkspaceAgentInputBarProps extends InputBarProps { }; } +function activeModelOptions( + activeModels: WorkspaceAgentInputBarProps["activeModels"], +): ModelOption[] { + const options = [ + { + id: "planner", + label: activeModels?.planner || "Planner model", + description: activeModels?.planner + ? "Planner runtime" + : "Planner model not configured", + disabled: !activeModels?.planner, + }, + { + id: "delegate", + label: activeModels?.delegate || "Delegate model", + description: activeModels?.delegate + ? "Recursive delegate runtime" + : "Delegate model not configured", + disabled: !activeModels?.delegate, + }, + { + id: "delegate_small", + label: activeModels?.delegate_small || "Small delegate model", + description: activeModels?.delegate_small + ? "Lightweight delegate runtime" + : "Small delegate model not configured", + disabled: !activeModels?.delegate_small, + }, + ]; + return options; +} + function ExecutionModeToggle({ value, onChange, @@ -58,6 +109,8 @@ function ExecutionModeToggle({ export function WorkspaceAgentInputBar({ executionMode, onExecutionModeChange, + activeModels, + onOpenModelSettings, showStatusBar = true, runtimeWarning, className, @@ -76,20 +129,23 @@ export function WorkspaceAgentInputBar({ fileInputRef.current?.click(); }, [onAttach]); - const handleDocumentInputChange = useCallback((event: ChangeEvent) => { - const files = Array.from(event.currentTarget.files ?? []); - if (files.length > 0) { - setStagedDocuments((current) => [ - ...current, - ...files.map((file) => ({ - id: `document-${createAttachmentId()}`, - filename: file.name, - size: file.size, - })), - ]); - } - event.currentTarget.value = ""; - }, []); + const handleDocumentInputChange = useCallback( + (event: ChangeEvent) => { + const files = Array.from(event.currentTarget.files ?? []); + if (files.length > 0) { + setStagedDocuments((current) => [ + ...current, + ...files.map((file) => ({ + id: `document-${createAttachmentId()}`, + filename: file.name, + size: file.size, + })), + ]); + } + event.currentTarget.value = ""; + }, + [], + ); const handleRemoveFile = useCallback( (id: string) => { @@ -108,7 +164,9 @@ export function WorkspaceAgentInputBar({ ); return ( -
+
- {runtimeWarning.title} + + {runtimeWarning.title} +
@@ -158,7 +218,16 @@ export function WorkspaceAgentInputBar({ onSend={handleSend} rightActions={ <> - + + {rightActions} } diff --git a/src/frontend/src/features/workspace/screen/workspace-screen.tsx b/src/frontend/src/features/workspace/screen/workspace-screen.tsx index bd88a584e..8a1beb19a 100644 --- a/src/frontend/src/features/workspace/screen/workspace-screen.tsx +++ b/src/frontend/src/features/workspace/screen/workspace-screen.tsx @@ -123,8 +123,10 @@ export function WorkspaceScreen() { executionMode, runtimeMode, repoUrl: inferredRepoContext?.repoUrl, - repoRef: inferredRepoContext?.repoRefCandidate ?? inferredRepoContext?.repoRef, - contextPaths: inferredContextPaths.length > 0 ? inferredContextPaths : undefined, + repoRef: + inferredRepoContext?.repoRefCandidate ?? inferredRepoContext?.repoRef, + contextPaths: + inferredContextPaths.length > 0 ? inferredContextPaths : undefined, }); }, [ @@ -140,11 +142,15 @@ export function WorkspaceScreen() { ], ); - const { sessionRevision, requestedConversationId, clearRequestedConversation } = - useWorkspaceUiStore(); + const { + sessionRevision, + requestedConversationId, + clearRequestedConversation, + } = useWorkspaceUiStore(); // Chat history - const { saveConversation, loadConversation: loadConv } = useChatHistoryStore(); + const { saveConversation, loadConversation: loadConv } = + useChatHistoryStore(); // ── Auto-save on session change ────────────────────────────────── // When sessionRevision increments (newSession() called), save the current @@ -188,8 +194,16 @@ export function WorkspaceScreen() { return; } - if (messagesRef.current.length > 0 && messagesRef.current !== conversation.messages) { - saveConversation(messagesRef.current, phaseRef.current, undefined, turnArtifactsRef.current); + if ( + messagesRef.current.length > 0 && + messagesRef.current !== conversation.messages + ) { + saveConversation( + messagesRef.current, + phaseRef.current, + undefined, + turnArtifactsRef.current, + ); } loadConversation(conversation); @@ -248,6 +262,8 @@ export function WorkspaceScreen() { } : undefined } + activeModels={runtimeStatus.data?.active_models} + onOpenModelSettings={handleOpenRuntimeSettings} />
diff --git a/src/frontend/src/lib/rlm-api/__tests__/ws-frame-parser.test.ts b/src/frontend/src/lib/rlm-api/__tests__/ws-frame-parser.test.ts index 64808b566..4580b382d 100644 --- a/src/frontend/src/lib/rlm-api/__tests__/ws-frame-parser.test.ts +++ b/src/frontend/src/lib/rlm-api/__tests__/ws-frame-parser.test.ts @@ -131,4 +131,29 @@ describe("parseWsServerFrame", () => { expect(frame.data.kind).toBe("execution_step"); expect(frame.data.payload?.step).toMatchObject({ type: "output" }); }); + + it("preserves raw RLM repl execution steps for chat rendering", () => { + const frame = parseWsServerFrame({ + type: "execution_step", + timestamp: 1710849601, + step: { + id: "step-rlm-repl", + type: "repl", + label: "repl_result", + input: { code: "print(document_text[:20])" }, + output: { stdout: "DSPy documentation" }, + timestamp: 1710849602, + }, + }); + + expect(frame).toBeTruthy(); + if (!frame || frame.type !== "event") return; + expect(frame.data.kind).toBe("execution_step"); + expect(frame.data.text).toBe("repl_result"); + expect(frame.data.payload?.step).toMatchObject({ + type: "repl", + input: { code: "print(document_text[:20])" }, + output: { stdout: "DSPy documentation" }, + }); + }); }); diff --git a/src/frontend/src/lib/workspace/__tests__/backend-chat-event-adapter.test.ts b/src/frontend/src/lib/workspace/__tests__/backend-chat-event-adapter.test.ts index 8a239e1a4..8aa87bb71 100644 --- a/src/frontend/src/lib/workspace/__tests__/backend-chat-event-adapter.test.ts +++ b/src/frontend/src/lib/workspace/__tests__/backend-chat-event-adapter.test.ts @@ -1,11 +1,23 @@ import { describe, expect, it, vi } from "vite-plus/test"; import { QueryClient } from "@tanstack/react-query"; import { applyWsFrameToMessages } from "@/lib/workspace/backend-chat-event-adapter"; -import type { ChatMessage, ChatRenderPart } from "@/lib/workspace/workspace-types"; +import type { + ChatMessage, + ChatRenderPart, +} from "@/lib/workspace/workspace-types"; import type { WsServerMessage } from "@/lib/rlm-api"; -function makeEvent(kind: string, text: string, payload?: Record): WsServerMessage { - if (kind === "done" || kind === "turn_completed" || kind === "error" || kind === "turn_failed") { +function makeEvent( + kind: string, + text: string, + payload?: Record, +): WsServerMessage { + if ( + kind === "done" || + kind === "turn_completed" || + kind === "error" || + kind === "turn_failed" + ) { return { type: "event", data: { @@ -15,14 +27,23 @@ function makeEvent(kind: string, text: string, payload?: Record ...payload, source_type: "execution_completed", run_summary: { - status: kind === "error" || kind === "turn_failed" ? "failed" : "completed", + status: + kind === "error" || kind === "turn_failed" + ? "failed" + : "completed", }, }, }, }; } - if (kind === "text" || kind === "reasoning" || kind === "tool_call" || kind === "tool_result") { - const stepType = kind === "tool_call" || kind === "tool_result" ? "tool" : "llm"; + if ( + kind === "text" || + kind === "reasoning" || + kind === "tool_call" || + kind === "tool_result" + ) { + const stepType = + kind === "tool_call" || kind === "tool_result" ? "tool" : "llm"; return { type: "event", data: { @@ -35,7 +56,12 @@ function makeEvent(kind: string, text: string, payload?: Record type: stepType, label: text, input: kind === "tool_call" ? text : undefined, - output: kind === "text" ? { text } : kind === "tool_result" ? text : undefined, + output: + kind === "text" + ? { text } + : kind === "tool_result" + ? text + : undefined, ...payload, }, }, @@ -154,12 +180,15 @@ describe("applyWsFrameToMessages", () => { const reasoningRows = traceRows( messages, - (part, message) => part.kind === "reasoning" && message.traceSource === "live", + (part, message) => + part.kind === "reasoning" && message.traceSource === "live", ); expect(reasoningRows).toHaveLength(2); expect( - reasoningRows.map((row) => (row.part.kind === "reasoning" ? row.part.parts[0]?.text : "")), + reasoningRows.map((row) => + row.part.kind === "reasoning" ? row.part.parts[0]?.text : "", + ), ).toEqual(["Analyzing input ", "and checking constraints"]); }); @@ -182,7 +211,10 @@ describe("applyWsFrameToMessages", () => { }), ); - const reasoning = findFirstPart(messages, (part) => part.kind === "reasoning"); + const reasoning = findFirstPart( + messages, + (part) => part.kind === "reasoning", + ); expect(reasoning).toBeDefined(); if (reasoning?.kind === "reasoning") { expect(reasoning.runtimeContext).toEqual({ @@ -208,7 +240,10 @@ describe("applyWsFrameToMessages", () => { }), ); - const reasoning = findFirstPart(messages, (part) => part.kind === "reasoning"); + const reasoning = findFirstPart( + messages, + (part) => part.kind === "reasoning", + ); expect(reasoning).toBeDefined(); if (reasoning?.kind === "reasoning") { expect(reasoning.label).toBe("prompt_iter_1"); @@ -268,7 +303,8 @@ describe("applyWsFrameToMessages", () => { const liveReasoning = traceRows( messages, - (part, message) => part.kind === "reasoning" && message.traceSource === "trajectory", + (part, message) => + part.kind === "reasoning" && message.traceSource === "trajectory", ); expect(liveReasoning).toHaveLength(1); const reasoningPart = liveReasoning[0]?.part; @@ -277,7 +313,10 @@ describe("applyWsFrameToMessages", () => { expect(reasoningPart.parts[0]?.text).toBe("Inspect the repo first."); } - const cot = findFirstPart(messages, (part) => part.kind === "chain_of_thought"); + const cot = findFirstPart( + messages, + (part) => part.kind === "chain_of_thought", + ); if (cot?.kind === "chain_of_thought") { expect(cot.steps[0]?.body).toBe("Inspect the repo first."); } @@ -285,7 +324,10 @@ describe("applyWsFrameToMessages", () => { it.skip("suppresses trajectory fallback primary rows when live trace already exists", () => { let messages: ChatMessage[] = []; - messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "Live reasoning")).messages; + messages = applyWsFrameToMessages( + messages, + makeEvent("reasoning", "Live reasoning"), + ).messages; messages = applyWsFrameToMessages( messages, @@ -305,10 +347,16 @@ describe("applyWsFrameToMessages", () => { ); expect(trajectoryPrimary).toHaveLength(0); - const reasoningRows = traceRows(messages, (part) => part.kind === "reasoning"); + const reasoningRows = traceRows( + messages, + (part) => part.kind === "reasoning", + ); expect(reasoningRows).toHaveLength(1); - const cot = findFirstPart(messages, (part) => part.kind === "chain_of_thought"); + const cot = findFirstPart( + messages, + (part) => part.kind === "chain_of_thought", + ); expect(cot).toBeDefined(); if (cot?.kind === "chain_of_thought") { expect(cot.steps).toHaveLength(1); @@ -333,7 +381,8 @@ describe("applyWsFrameToMessages", () => { const trajectoryReasoning = traceRows( messages, - (part, message) => part.kind === "reasoning" && message.traceSource === "trajectory", + (part, message) => + part.kind === "reasoning" && message.traceSource === "trajectory", ); expect( trajectoryReasoning.map((row) => @@ -344,7 +393,10 @@ describe("applyWsFrameToMessages", () => { const tools = findAllParts(messages, (part) => part.kind === "tool"); expect(tools).toHaveLength(0); - const cot = findFirstPart(messages, (part) => part.kind === "chain_of_thought"); + const cot = findFirstPart( + messages, + (part) => part.kind === "chain_of_thought", + ); expect(cot).toBeDefined(); if (cot?.kind === "chain_of_thought") { expect(cot.steps).toHaveLength(2); @@ -375,7 +427,10 @@ describe("applyWsFrameToMessages", () => { }), ).messages; - const cot = findFirstPart(messages, (part) => part.kind === "chain_of_thought"); + const cot = findFirstPart( + messages, + (part) => part.kind === "chain_of_thought", + ); expect(cot).toBeDefined(); if (cot?.kind === "chain_of_thought") { expect(cot.steps.map((step) => step.index)).toEqual([0, 1]); @@ -386,7 +441,10 @@ describe("applyWsFrameToMessages", () => { it("keeps exact interleaved order for reasoning and tool events", () => { let messages: ChatMessage[] = []; - messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "r1")).messages; + messages = applyWsFrameToMessages( + messages, + makeEvent("reasoning", "r1"), + ).messages; messages = applyWsFrameToMessages( messages, makeEvent("tool_call", "call", { @@ -394,7 +452,10 @@ describe("applyWsFrameToMessages", () => { tool_args: { pattern: "foo" }, }), ).messages; - messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "r2")).messages; + messages = applyWsFrameToMessages( + messages, + makeEvent("reasoning", "r2"), + ).messages; messages = applyWsFrameToMessages( messages, makeEvent("tool_result", "result", { @@ -402,13 +463,18 @@ describe("applyWsFrameToMessages", () => { tool_output: "match", }), ).messages; - messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "r3")).messages; + messages = applyWsFrameToMessages( + messages, + makeEvent("reasoning", "r3"), + ).messages; const primaryRows = traceRows( messages, (part, message) => message.traceSource === "live" && - (part.kind === "reasoning" || part.kind === "tool" || part.kind === "sandbox"), + (part.kind === "reasoning" || + part.kind === "tool" || + part.kind === "sandbox"), ); expect(primaryRows.map((row) => row.part.kind)).toEqual([ @@ -421,7 +487,10 @@ describe("applyWsFrameToMessages", () => { const toolRows = primaryRows.filter((row) => row.part.kind === "tool"); expect(toolRows).toHaveLength(2); - if (toolRows[0]?.part.kind === "tool" && toolRows[1]?.part.kind === "tool") { + if ( + toolRows[0]?.part.kind === "tool" && + toolRows[1]?.part.kind === "tool" + ) { expect(toolRows[0].part.state).toBe("running"); expect(toolRows[1].part.state).toBe("output-available"); } @@ -430,7 +499,10 @@ describe("applyWsFrameToMessages", () => { it.skip("maps status, rlm_delegate, status to task rows in order", () => { let messages: ChatMessage[] = []; - messages = applyWsFrameToMessages(messages, makeEvent("status", "Moving to step 2")).messages; + messages = applyWsFrameToMessages( + messages, + makeEvent("status", "Moving to step 2"), + ).messages; messages = applyWsFrameToMessages( messages, makeEvent("rlm_delegate", "Delegating", { @@ -448,7 +520,9 @@ describe("applyWsFrameToMessages", () => { ); expect(taskRows).toHaveLength(3); - const taskTitles = taskRows.map((row) => (row.part.kind === "task" ? row.part.title : "")); + const taskTitles = taskRows.map((row) => + row.part.kind === "task" ? row.part.title : "", + ); expect(taskTitles).toEqual([ "Plan update", "Executing PythonInterpreter", @@ -463,7 +537,9 @@ describe("applyWsFrameToMessages", () => { const queue = findFirstPart(messages, (p) => p.kind === "queue"); expect(queue).toBeDefined(); if (queue?.kind === "queue") { - expect(queue.items[queue.items.length - 1]?.label).toBe("Moving to step 2"); + expect(queue.items[queue.items.length - 1]?.label).toBe( + "Moving to step 2", + ); } }); @@ -569,11 +645,16 @@ describe("applyWsFrameToMessages", () => { expect(sandbox.stepIndex).toBe(2); expect(sandbox.output).toBe("loading repository metadata"); expect(sandbox.runtimeContext?.runtimeMode).toBe("daytona_pilot"); - expect(sandbox.runtimeContext?.workspacePath).toBe("/workspace/workspace/repo"); + expect(sandbox.runtimeContext?.workspacePath).toBe( + "/workspace/workspace/repo", + ); expect(sandbox.runtimeContext?.sandboxTransition).toBe("created"); } - const statusNote = findFirstPart(messages, (part) => part.kind === "status_note"); + const statusNote = findFirstPart( + messages, + (part) => part.kind === "status_note", + ); expect(statusNote).toBeUndefined(); }); @@ -661,7 +742,10 @@ describe("applyWsFrameToMessages", () => { }), ).messages; - const reasoningRows = traceRows(messages, (part) => part.kind === "reasoning"); + const reasoningRows = traceRows( + messages, + (part) => part.kind === "reasoning", + ); const toolRows = traceRows(messages, (part) => part.kind === "tool"); expect(reasoningRows).toHaveLength(1); @@ -694,6 +778,169 @@ describe("applyWsFrameToMessages", () => { } }); + it("renders canonical repl execution steps with compact code and output", () => { + const { messages } = applyWsFrameToMessages( + [], + makeEvent("execution_step", "repl_result", { + step: { + type: "repl", + label: "repl_result", + input: { + code: "import urllib.request\nprint('docs')", + }, + output: { + stdout: "docs", + }, + }, + }), + ); + + const sandbox = findFirstPart(messages, (p) => p.kind === "sandbox"); + expect(sandbox).toBeDefined(); + if (sandbox?.kind === "sandbox") { + expect(sandbox.state).toBe("output-available"); + expect(sandbox.code).toContain("urllib.request"); + expect(sandbox.output).toContain("docs"); + } + }); + + it("renders final RLM code/output trajectory as compact sandbox summary rows", () => { + const { messages } = applyWsFrameToMessages( + [], + makeEvent("done", "Done", { + trajectory: { + steps: [ + { + reasoning: "Inspect the fetched documentation.", + code: "print(document_text[:80])", + output: "DSPy docs", + }, + ], + }, + }), + ); + + const summaryReasoning = traceRows( + messages, + (part, message) => + part.kind === "reasoning" && message.traceSource === "summary", + ); + expect(summaryReasoning).toHaveLength(1); + + const sandbox = traceRows( + messages, + (part, message) => + part.kind === "sandbox" && message.traceSource === "summary", + )[0]?.part; + expect(sandbox).toBeDefined(); + if (sandbox?.kind === "sandbox") { + expect(sandbox.state).toBe("output-available"); + expect(sandbox.code).toContain("document_text"); + expect(sandbox.output).toContain("DSPy docs"); + } + }); + + it("does not duplicate final trajectory tool rows after live tool rows streamed", () => { + let messages = applyWsFrameToMessages( + [], + makeEvent("execution_step", "repl_result", { + step: { + type: "repl", + label: "repl_result", + input: { code: "print('docs')" }, + output: { stdout: "docs" }, + }, + }), + ).messages; + + messages = applyWsFrameToMessages( + messages, + makeEvent("done", "Done", { + trajectory: { + steps: [ + { + reasoning: "Inspect the fetched documentation.", + code: "print('docs')", + output: "docs", + }, + ], + }, + }), + ).messages; + + const sandboxRows = traceRows(messages, (part) => part.kind === "sandbox"); + expect(sandboxRows).toHaveLength(1); + expect(sandboxRows[0]?.message.traceSource).toBe("live"); + }); + + it("renders final trajectory rows when only a previous turn streamed live tool rows", () => { + let messages = applyWsFrameToMessages( + [], + makeEvent("execution_step", "repl_result", { + step: { + type: "repl", + label: "previous repl", + input: { code: "print('previous')" }, + output: { stdout: "previous" }, + }, + }), + ).messages; + messages = [ + ...messages, + { + id: "user-next", + type: "user", + content: "Analyze the next page", + phase: 1, + }, + ]; + + messages = applyWsFrameToMessages( + messages, + makeEvent("done", "Done", { + trajectory: { + steps: [ + { + reasoning: "Inspect the current documentation.", + code: "print('current')", + output: "current", + }, + ], + }, + }), + ).messages; + + const summarySandboxRows = traceRows( + messages, + (part, message) => + part.kind === "sandbox" && message.traceSource === "summary", + ); + expect(summarySandboxRows).toHaveLength(1); + const sandbox = summarySandboxRows[0]?.part; + expect(sandbox?.kind === "sandbox" ? sandbox.output : "").toContain( + "current", + ); + }); + + it("renders selected skills and routing decisions as compact status rows", () => { + const { messages } = applyWsFrameToMessages( + [], + makeEvent("execution_step", "RLM document analysis selected", { + selected_skills: ["long-context", "dspy-programs"], + routing_decision: "url_document_rlm", + source_url: "https://dspy.ai", + }), + ); + + const status = findFirstPart(messages, (p) => p.kind === "status_note"); + expect(status).toBeDefined(); + if (status?.kind === "status_note") { + expect(status.text).toContain("long-context"); + expect(status.text).toContain("url_document_rlm"); + expect(status.text).toContain("https://dspy.ai"); + } + }); + it("classifies environment variable payloads as environment_variables on tool_result", () => { const { messages } = applyWsFrameToMessages( [], @@ -703,7 +950,10 @@ describe("applyWsFrameToMessages", () => { }), ); - const env = findFirstPart(messages, (p) => p.kind === "environment_variables"); + const env = findFirstPart( + messages, + (p) => p.kind === "environment_variables", + ); expect(env).toBeDefined(); if (env?.kind === "environment_variables") { expect(env.variables.map((v) => v.name)).toContain("OPENAI_API_KEY"); @@ -734,8 +984,14 @@ describe("applyWsFrameToMessages", () => { it("final finalizes trace summaries and attaches citations/sources/attachments", () => { let messages: ChatMessage[] = []; - messages = applyWsFrameToMessages(messages, makeEvent("text", "Hello")).messages; - messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "Thinking")).messages; + messages = applyWsFrameToMessages( + messages, + makeEvent("text", "Hello"), + ).messages; + messages = applyWsFrameToMessages( + messages, + makeEvent("reasoning", "Thinking"), + ).messages; messages = applyWsFrameToMessages( messages, makeEvent("execution_step", "trace", { @@ -743,7 +999,10 @@ describe("applyWsFrameToMessages", () => { step_data: { thought: "step one", tool_name: "read_file" }, }), ).messages; - messages = applyWsFrameToMessages(messages, makeEvent("status", "Do X")).messages; + messages = applyWsFrameToMessages( + messages, + makeEvent("status", "Do X"), + ).messages; const result = applyWsFrameToMessages( messages, @@ -798,11 +1057,19 @@ describe("applyWsFrameToMessages", () => { const assistant = result.messages.find((m) => m.type === "assistant"); expect(assistant?.streaming).toBe(false); - expect(assistant?.renderParts?.some((p) => p.kind === "inline_citation_group")).toBe(true); - expect(assistant?.renderParts?.some((p) => p.kind === "sources")).toBe(true); - expect(assistant?.renderParts?.some((p) => p.kind === "attachments")).toBe(true); + expect( + assistant?.renderParts?.some((p) => p.kind === "inline_citation_group"), + ).toBe(true); + expect(assistant?.renderParts?.some((p) => p.kind === "sources")).toBe( + true, + ); + expect(assistant?.renderParts?.some((p) => p.kind === "attachments")).toBe( + true, + ); - const citationGroup = assistant?.renderParts?.find((p) => p.kind === "inline_citation_group"); + const citationGroup = assistant?.renderParts?.find( + (p) => p.kind === "inline_citation_group", + ); if (citationGroup?.kind === "inline_citation_group") { expect(citationGroup.citations[0]?.title).toBe("Doc A"); expect(citationGroup.citations[0]?.number).toBe("1"); @@ -817,7 +1084,10 @@ describe("applyWsFrameToMessages", () => { expect(sources.sources[1]?.sourceId).toBe("src-b"); } - const cot = findFirstPart(result.messages, (p) => p.kind === "chain_of_thought"); + const cot = findFirstPart( + result.messages, + (p) => p.kind === "chain_of_thought", + ); if (cot?.kind === "chain_of_thought") { expect(cot.steps.every((step) => step.status === "complete")).toBe(true); } @@ -836,35 +1106,48 @@ describe("applyWsFrameToMessages", () => { const finalReasoningRows = traceRows( result.messages, - (part, message) => part.kind === "reasoning" && message.traceSource === "summary", + (part, message) => + part.kind === "reasoning" && message.traceSource === "summary", ); expect(finalReasoningRows).toHaveLength(3); const summaryLabels = finalReasoningRows.map((row) => row.part.kind === "reasoning" ? row.part.label : undefined, ); - expect(summaryLabels).toEqual(["thought_0", "thought_1", "final_reasoning"]); + expect(summaryLabels).toEqual([ + "thought_0", + "thought_1", + "final_reasoning", + ]); const finalReasoning = finalReasoningRows[2]?.part; if (finalReasoning?.kind === "reasoning") { - expect(finalReasoning.parts[0]?.text).toBe("The evidence lines up with the cited sources."); + expect(finalReasoning.parts[0]?.text).toBe( + "The evidence lines up with the cited sources.", + ); } }); it("prefers final_artifact markdown over raw final event JSON text", () => { const result = applyWsFrameToMessages( [], - makeEvent("done", '{ "final_markdown": "Hello there, it is great to meet you!" }', { - final_artifact: { - kind: "markdown", - value: { - final_markdown: "Hello there, it is great to meet you!", + makeEvent( + "done", + '{ "final_markdown": "Hello there, it is great to meet you!" }', + { + final_artifact: { + kind: "markdown", + value: { + final_markdown: "Hello there, it is great to meet you!", + }, }, }, - }), + ), ); - const assistant = result.messages.find((message) => message.type === "assistant"); + const assistant = result.messages.find( + (message) => message.type === "assistant", + ); expect(assistant?.content).toBe("Hello there, it is great to meet you!"); }); @@ -883,7 +1166,9 @@ describe("applyWsFrameToMessages", () => { }), ); - const assistant = result.messages.find((message) => message.type === "assistant"); + const assistant = result.messages.find( + (message) => message.type === "assistant", + ); expect(assistant?.content).toBe("Canonical completion text"); }); diff --git a/src/frontend/src/lib/workspace/backend-chat-event-adapter.ts b/src/frontend/src/lib/workspace/backend-chat-event-adapter.ts index d79a2d269..4e5f18162 100644 --- a/src/frontend/src/lib/workspace/backend-chat-event-adapter.ts +++ b/src/frontend/src/lib/workspace/backend-chat-event-adapter.ts @@ -1,4 +1,7 @@ -import type { ChatMessage, ChatRenderPart } from "@/lib/workspace/workspace-types"; +import type { + ChatMessage, + ChatRenderPart, +} from "@/lib/workspace/workspace-types"; import type { WsServerEvent, WsServerMessage } from "@/lib/rlm-api"; import { createLocalId } from "@/lib/id"; import { QueryClient } from "@tanstack/react-query"; @@ -63,7 +66,10 @@ function ensureStreamingAssistant(messages: ChatMessage[]): ChatMessage[] { ]; } -function appendAssistantToken(messages: ChatMessage[], token: string): ChatMessage[] { +function appendAssistantToken( + messages: ChatMessage[], + token: string, +): ChatMessage[] { if (!token) return messages; const withAssistant = ensureStreamingAssistant(messages); const idx = latestStreamingAssistantIndex(withAssistant); @@ -116,7 +122,10 @@ function finishReasoning(messages: ChatMessage[]): ChatMessage[] { return updated ? next : messages; } -function completeAssistant(messages: ChatMessage[], text: string): ChatMessage[] { +function completeAssistant( + messages: ChatMessage[], + text: string, +): ChatMessage[] { const idx = latestStreamingAssistantIndex(messages); if (idx >= 0) { @@ -152,7 +161,13 @@ function preferredFinalArtifactText(value: unknown): string | undefined { const record = asRecord(value); if (!record) return undefined; - for (const key of ["final_markdown", "summary", "text", "content", "message"]) { + for (const key of [ + "final_markdown", + "summary", + "text", + "content", + "message", + ]) { const candidate = asOptionalText(record[key]); if (candidate) return candidate; } @@ -165,30 +180,48 @@ function preferredFinalArtifactText(value: unknown): string | undefined { return undefined; } -function resolveFinalAssistantText(text: string, payload?: Record): string { - const preferred = preferredFinalArtifactText(payload?.final_artifact ?? payload?.finalArtifact); +function resolveFinalAssistantText( + text: string, + payload?: Record, +): string { + const preferred = preferredFinalArtifactText( + payload?.final_artifact ?? payload?.finalArtifact, + ); return preferred ?? text; } -function readGuardrailWarnings(payload: Record | undefined): string[] { +function readGuardrailWarnings( + payload: Record | undefined, +): string[] { const raw = payload?.guardrail_warnings; if (!Array.isArray(raw)) return []; - return raw.map((item) => (typeof item === "string" ? item.trim() : "")).filter(Boolean); + return raw + .map((item) => (typeof item === "string" ? item.trim() : "")) + .filter(Boolean); } function canonicalSummaryPayload( payload: Record | undefined, ): Record | undefined { - return asRecord(payload?.run_summary ?? payload?.runSummary ?? payload?.summary); + return asRecord( + payload?.run_summary ?? payload?.runSummary ?? payload?.summary, + ); } -function canonicalCompletionStatus(payload: Record | undefined): string { +function canonicalCompletionStatus( + payload: Record | undefined, +): string { const summary = canonicalSummaryPayload(payload); - return asOptionalText(summary?.status ?? payload?.status)?.toLowerCase() ?? ""; + return ( + asOptionalText(summary?.status ?? payload?.status)?.toLowerCase() ?? "" + ); } -function canonicalStepText(step: Record, fallback: string): string { +function canonicalStepText( + step: Record, + fallback: string, +): string { return ( asOptionalText(step.label) ?? asOptionalText(step.output) ?? @@ -197,6 +230,26 @@ function canonicalStepText(step: Record, fallback: string): str ); } +function routingStatusText( + text: string, + payload?: Record, +): string { + const selectedSkills = Array.isArray(payload?.selected_skills) + ? payload.selected_skills.map((item) => String(item)).filter(Boolean) + : []; + const routingDecision = asOptionalText(payload?.routing_decision); + const sourceUrl = asOptionalText(payload?.source_url); + if (selectedSkills.length === 0 && !routingDecision && !sourceUrl) + return text; + + const parts = [text.trim()].filter(Boolean); + if (routingDecision) parts.push(`route ${routingDecision}`); + if (selectedSkills.length > 0) + parts.push(`skills ${selectedSkills.join(", ")}`); + if (sourceUrl) parts.push(`source ${sourceUrl}`); + return parts.join(" | "); +} + function applyCanonicalExecutionStep( messages: ChatMessage[], text: string, @@ -205,7 +258,12 @@ function applyCanonicalExecutionStep( const step = asRecord(payload?.step); if (!step) { return { - messages: appendStatusTrace(messages, text || "Execution step received", "neutral", payload), + messages: appendStatusTrace( + messages, + routingStatusText(text || "Execution step received", payload), + "neutral", + payload, + ), terminal: false, errored: false, }; @@ -229,13 +287,21 @@ function applyCanonicalExecutionStep( } if (stepType === "llm") { const output = asRecord(step.output); - const token = typeof output?.text === "string" ? output.text : asOptionalText(step.output); + const token = + typeof output?.text === "string" + ? output.text + : asOptionalText(step.output); const reasoning = - typeof step.label === "string" ? step.label : asOptionalText(output?.reasoning ?? step.input); + typeof step.label === "string" + ? step.label + : asOptionalText(output?.reasoning ?? step.input); return { messages: token ? appendAssistantToken(messages, token) - : appendReasoningEvent(messages, reasoning ?? stepText, "live", { ...payload, ...step }), + : appendReasoningEvent(messages, reasoning ?? stepText, "live", { + ...payload, + ...step, + }), terminal: false, errored: false, }; @@ -255,7 +321,7 @@ function applyCanonicalExecutionStep( return { messages: appendStatusTrace( messages, - stepText || "Execution step received", + routingStatusText(stepText || "Execution step received", payload), "neutral", payload, ), @@ -279,26 +345,44 @@ function applyCanonicalExecutionCompleted( if (status === "failed" || status === "error") { let next = finishReasoning(messages); next = finalizeTraceParts(next); - next = appendSystem(next, `Backend error: ${text || "Unknown server error."}`); + next = appendSystem( + next, + `Backend error: ${text || "Unknown server error."}`, + ); return { messages: next, terminal: true, errored: true }; } - let next = completeAssistant(messages, resolveFinalAssistantText(text, payload)); + let next = completeAssistant( + messages, + resolveFinalAssistantText(text, payload), + ); next = finishReasoning(next); next = finalizeTraceParts(next); next = appendFinalTrajectoryThoughts(next, payload); + next = appendFinalTrajectoryToolRows(next, payload); const finalReasoning = - typeof payload?.final_reasoning === "string" ? payload.final_reasoning.trim() : ""; + typeof payload?.final_reasoning === "string" + ? payload.final_reasoning.trim() + : ""; if (finalReasoning) { - next = appendReasoningEvent(next, finalReasoning, "summary", payload, "final_reasoning"); + next = appendReasoningEvent( + next, + finalReasoning, + "summary", + payload, + "final_reasoning", + ); } next = attachFinalReferences(next, payload); const warnings = readGuardrailWarnings(payload); if (warnings.length > 0) { - next = appendSystem(next, `Guardrail warnings:\n- ${warnings.join("\n- ")}`); + next = appendSystem( + next, + `Guardrail warnings:\n- ${warnings.join("\n- ")}`, + ); } return { messages: next, terminal: true, errored: false }; @@ -389,7 +473,68 @@ function appendFinalTrajectoryThoughts( return steps.reduce((acc, step) => { if (!step.thought) return acc; - return appendReasoningEvent(acc, step.thought, "summary", payload, `thought_${step.index}`); + return appendReasoningEvent( + acc, + step.thought, + "summary", + payload, + `thought_${step.index}`, + ); + }, messages); +} + +function currentTurnMessages(messages: ChatMessage[]): ChatMessage[] { + const lastUserIndex = messages.findLastIndex( + (message) => message.type === "user", + ); + return lastUserIndex >= 0 ? messages.slice(lastUserIndex + 1) : messages; +} + +function hasLiveToolOrSandboxTraceForCurrentTurn( + messages: ChatMessage[], +): boolean { + return currentTurnMessages(messages).some( + (message) => + message.type === "trace" && + message.traceSource === "live" && + message.renderParts?.some( + (part) => part.kind === "tool" || part.kind === "sandbox", + ), + ); +} + +function appendFinalTrajectoryToolRows( + messages: ChatMessage[], + payload?: Record, +): ChatMessage[] { + if (hasLiveToolOrSandboxTraceForCurrentTurn(messages)) return messages; + + const steps = normalizeTrajectoryStepsFromFinalPayload(payload); + return steps.reduce((acc, step) => { + if (!step.toolName) return acc; + const stepPayload = { + ...payload, + tool_name: step.toolName, + tool_input: step.toolInput, + tool_args: step.toolInput, + tool_output: step.toolOutput, + output: step.toolOutput, + step_index: step.index, + step: { + type: step.toolName.toLowerCase().includes("repl") ? "repl" : "tool", + label: step.label, + input: step.toolInput, + output: step.toolOutput, + }, + }; + return appendToolLikePart( + acc, + step.toolOutput === undefined ? "tool_call" : "tool_result", + step.label, + stepPayload, + appendTracePart, + { traceSource: "summary" }, + ); }, messages); } @@ -411,7 +556,9 @@ function finalizeTraceParts(messages: ChatMessage[]): ChatMessage[] { items: part.items.map((it) => ({ ...it, completed: true })), }; case "task": - return part.status === "in_progress" ? { ...part, status: "completed" as const } : part; + return part.status === "in_progress" + ? { ...part, status: "completed" as const } + : part; case "tool": case "sandbox": return part.state === "running" || part.state === "input-streaming" @@ -432,7 +579,12 @@ function resolveHitlByMessageId( ): ChatMessage[] { let changed = false; const next = messages.map((msg) => { - if (changed || msg.id !== messageId || msg.type !== "hitl" || !msg.hitlData) { + if ( + changed || + msg.id !== messageId || + msg.type !== "hitl" || + !msg.hitlData + ) { return msg; } changed = true; @@ -448,10 +600,18 @@ function resolveHitlByMessageId( return changed ? next : messages; } -function rollbackHitlByMessageId(messages: ChatMessage[], messageId: string): ChatMessage[] { +function rollbackHitlByMessageId( + messages: ChatMessage[], + messageId: string, +): ChatMessage[] { let changed = false; const next = messages.map((msg) => { - if (changed || msg.id !== messageId || msg.type !== "hitl" || !msg.hitlData) { + if ( + changed || + msg.id !== messageId || + msg.type !== "hitl" || + !msg.hitlData + ) { return msg; } changed = true; @@ -467,12 +627,18 @@ function rollbackHitlByMessageId(messages: ChatMessage[], messageId: string): Ch return changed ? next : messages; } -function applyEvent(messages: ChatMessage[], frame: WsServerEvent): ApplyFrameResult { +function applyEvent( + messages: ChatMessage[], + frame: WsServerEvent, +): ApplyFrameResult { const { kind, text, payload } = frame.data; switch (kind) { case "execution_started": { - const normalizedPayload = { ...payload, phase: payload?.phase ?? "startup" }; + const normalizedPayload = { + ...payload, + phase: payload?.phase ?? "startup", + }; const sandboxPart = sandboxProgressPartFromStatus(normalizedPayload); if (sandboxPart) { return { @@ -499,8 +665,11 @@ function applyEvent(messages: ChatMessage[], frame: WsServerEvent): ApplyFrameRe const command = asOptionalText(payload?.command); const result = asRecord(payload?.result); const messageId = asOptionalText(result?.message_id ?? result?.messageId); - const resolution = asOptionalText(result?.resolution) ?? asOptionalText(result?.action_label); - const succeeded = asOptionalText(result?.status)?.toLowerCase() !== "error"; + const resolution = + asOptionalText(result?.resolution) ?? + asOptionalText(result?.action_label); + const succeeded = + asOptionalText(result?.status)?.toLowerCase() !== "error"; let next = messages; if (succeeded && command === "resolve_hitl" && messageId && resolution) { next = resolveHitlByMessageId(next, messageId, resolution); @@ -514,7 +683,8 @@ function applyEvent(messages: ChatMessage[], frame: WsServerEvent): ApplyFrameRe { kind: "status_note", tone: succeeded ? "success" : "error", - text: text || (succeeded ? "Action acknowledged" : "Action rejected"), + text: + text || (succeeded ? "Action acknowledged" : "Action rejected"), }, text || (succeeded ? "Action acknowledged" : "Action rejected"), ), @@ -537,7 +707,9 @@ export function applyWsFrameToMessages( _queryClient?: QueryClient, ): ApplyFrameResult { if (frame.type === "error") { - const next = finalizeTraceParts(appendSystem(messages, `Backend error: ${frame.message}`)); + const next = finalizeTraceParts( + appendSystem(messages, `Backend error: ${frame.message}`), + ); return { messages: finishReasoning(next), terminal: true, errored: true }; } diff --git a/src/frontend/src/lib/workspace/backend-chat-event-tool-parts.ts b/src/frontend/src/lib/workspace/backend-chat-event-tool-parts.ts index e73442f01..1eb4975a9 100644 --- a/src/frontend/src/lib/workspace/backend-chat-event-tool-parts.ts +++ b/src/frontend/src/lib/workspace/backend-chat-event-tool-parts.ts @@ -220,13 +220,22 @@ function sandboxFromPayload( payload?.step && typeof payload.step === "object" && !Array.isArray(payload.step) ? (payload.step as Record) : undefined; + const stepInput = asRecord(step?.input); + const stepOutput = asRecord(step?.output); const code = (typeof step?.input === "string" && step.input) || + asOptionalText(stepInput?.code) || + asOptionalText(stepInput?.code_preview) || + asOptionalText(stepInput?.command) || (typeof payload?.tool_input === "string" && payload.tool_input) || (typeof payload?.tool_args === "string" && payload.tool_args) || ""; const output = (typeof step?.output === "string" && step.output) || + asOptionalText(stepOutput?.stdout) || + asOptionalText(stepOutput?.stderr) || + asOptionalText(stepOutput?.output) || + asOptionalText(stepOutput?.result) || (typeof payload?.tool_output === "string" && payload.tool_output) || text; const state = inferToolState(kind, text, payload); diff --git a/src/frontend/src/lib/workspace/backend-chat-event-trajectory.ts b/src/frontend/src/lib/workspace/backend-chat-event-trajectory.ts index c85e1c0a5..9968ac6aa 100644 --- a/src/frontend/src/lib/workspace/backend-chat-event-trajectory.ts +++ b/src/frontend/src/lib/workspace/backend-chat-event-trajectory.ts @@ -35,7 +35,11 @@ function parseTrajectoryStepIndex( payload?: Record, stepData?: Record, ): number { - return asOptionalNumber(payload?.step_index) ?? asOptionalNumber(stepData?.index) ?? 0; + return ( + asOptionalNumber(payload?.step_index) ?? + asOptionalNumber(stepData?.index) ?? + 0 + ); } function normalizeTrajectoryStep( @@ -44,10 +48,16 @@ function normalizeTrajectoryStep( fallbackText?: string, ): NormalizedTrajectoryStep { const action = asOptionalText(raw.action); - const toolName = asOptionalText(raw.tool_name ?? raw.toolName); - const thought = asOptionalText(raw.thought) ?? asOptionalText(fallbackText); + const code = asOptionalText(raw.code); + const toolName = + asOptionalText(raw.tool_name ?? raw.toolName) ?? + (code ? "repl_execute" : undefined); + const thought = + asOptionalText(raw.thought) ?? + asOptionalText(raw.reasoning) ?? + asOptionalText(fallbackText); const toolInput = normalizeOptionalUnknown( - raw.tool_args ?? raw.input ?? raw.tool_input ?? raw.toolInput, + raw.tool_args ?? raw.input ?? raw.tool_input ?? raw.toolInput ?? code, ); const toolOutput = normalizeOptionalUnknown( raw.output ?? raw.observation ?? raw.tool_output ?? raw.toolOutput, @@ -171,6 +181,17 @@ export function normalizeTrajectoryStepsFromFinalPayload( const trajectoryRecord = asRecord(rawTrajectory); if (trajectoryRecord) { + const nestedSteps = trajectoryRecord.steps ?? trajectoryRecord.trajectory; + if (Array.isArray(nestedSteps)) { + return nestedSteps + .map((entry, idx) => { + const record = asRecord(entry); + if (!record) return null; + const index = asOptionalNumber(record.index) ?? idx; + return normalizeTrajectoryStep(record, index); + }) + .filter((step): step is NormalizedTrajectoryStep => step != null); + } return normalizeTrajectorySteps("", trajectoryRecord); } @@ -225,16 +246,22 @@ function summarizeTrajectoryValue(value: unknown): string | undefined { } } -export function trajectoryStepDetails(step: NormalizedTrajectoryStep): string[] { +export function trajectoryStepDetails( + step: NormalizedTrajectoryStep, +): string[] { const details: string[] = []; if (step.toolName) { details.push(`Tool · ${step.toolName}`); } if (step.toolInput !== undefined) { - details.push(`Input · ${summarizeTrajectoryValue(step.toolInput) ?? "Available"}`); + details.push( + `Input · ${summarizeTrajectoryValue(step.toolInput) ?? "Available"}`, + ); } if (step.toolOutput !== undefined) { - details.push(`Observation · ${summarizeTrajectoryValue(step.toolOutput) ?? "Available"}`); + details.push( + `Observation · ${summarizeTrajectoryValue(step.toolOutput) ?? "Available"}`, + ); } return details; } diff --git a/tests/unit/api/test_chat_persistence.py b/tests/unit/api/test_chat_persistence.py index afd2bc616..ab6e71f49 100644 --- a/tests/unit/api/test_chat_persistence.py +++ b/tests/unit/api/test_chat_persistence.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json from types import SimpleNamespace from typing import Any @@ -229,6 +230,37 @@ async def local_persist(**kwargs: Any) -> None: ] +@pytest.mark.asyncio +async def test_disconnect_can_leave_background_execution_running_without_stale_persist() -> None: + from fleet_rlm.api.runtime_services.chat_persistence import handle_chat_disconnect + + calls: list[dict[str, Any]] = [] + + async def local_persist(**kwargs: Any) -> None: + calls.append(dict(kwargs)) + + async def background() -> str: + await asyncio.sleep(0) + return "ok" + + stream_task = asyncio.create_task(background()) + cancel_flag: dict[str, bool] = {"cancelled": False} + + await handle_chat_disconnect( + pending_receive_task=None, + stream_task=stream_task, + cancel_flag=cancel_flag, + local_persist=local_persist, + lifecycle=None, + cancel_active_run=False, + persist_on_disconnect=False, + ) + await stream_task + + assert cancel_flag["cancelled"] is False + assert calls == [] + + @pytest.mark.asyncio async def test_stream_error_cleanup_disallows_volume_session_creation() -> None: from fleet_rlm.api.routers.ws.transport import handle_chat_loop_exception diff --git a/tests/unit/api/test_chat_runtime.py b/tests/unit/api/test_chat_runtime.py new file mode 100644 index 000000000..f589cbf19 --- /dev/null +++ b/tests/unit/api/test_chat_runtime.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from fleet_rlm.api.runtime_services.chat_runtime import ( + PreparedChatRuntime, + _chat_agent_builder_kwargs, +) + + +def test_chat_agent_builder_kwargs_forwards_rlm_limits_from_server_config() -> None: + cfg = SimpleNamespace( + react_max_iters=15, + rlm_max_iterations=31, + rlm_max_llm_calls=17, + agent_max_output_chars=12_345, + ) + runtime = PreparedChatRuntime( + cfg=cfg, # type: ignore[arg-type] + planner_lm=object(), + delegate_lm=object(), + repository=object(), + persistence=None, + persistence_required=False, + identity_rows=None, + ) + + kwargs = _chat_agent_builder_kwargs(runtime) + + assert kwargs["react_max_iters"] == 15 + assert kwargs["rlm_max_iterations"] == 31 + assert kwargs["rlm_max_llm_calls"] == 17 + assert kwargs["rlm_max_output_chars"] == 12_345 diff --git a/tests/unit/api/test_events.py b/tests/unit/api/test_events.py index df010e551..48cf4216c 100644 --- a/tests/unit/api/test_events.py +++ b/tests/unit/api/test_events.py @@ -158,3 +158,33 @@ def test_backend_status_projects_to_canonical_execution_step_frame(): assert frame["kind"] == "execution_step" assert frame["payload"]["source_type"] == "status" + + +def test_runtime_trace_metadata_counts_structured_rlm_trajectory(): + stream_module = importlib.import_module("fleet_rlm.api.routers.ws.stream") + + metadata = stream_module._runtime_trace_metadata( + { + "routing_decision": "url_document_rlm", + "selected_skills": ["long-context"], + "source_url": "https://dspy.ai", + "trajectory": { + "steps": [ + { + "reasoning": "Inspect docs", + "code": "print(document_text[:80])", + "output": "DSPy docs", + } + ] + }, + } + ) + + assert metadata["fleet_rlm.routing_decision"] == "url_document_rlm" + assert metadata["fleet_rlm.selected_skills"] == "long-context" + assert metadata["fleet_rlm.source_url"] == "https://dspy.ai" + assert metadata["fleet_rlm.trajectory_steps"] == "1" + assert metadata["fleet_rlm.trajectory_has_reasoning"] == "true" + assert metadata["fleet_rlm.trajectory_has_tools"] == "true" + assert metadata["fleet_rlm.trajectory_has_repl"] == "true" + assert metadata["fleet_rlm.trajectory_has_outputs"] == "true" diff --git a/tests/unit/api/test_runtime_diagnostics.py b/tests/unit/api/test_runtime_diagnostics.py index 672cf20fb..d8d14fcae 100644 --- a/tests/unit/api/test_runtime_diagnostics.py +++ b/tests/unit/api/test_runtime_diagnostics.py @@ -20,3 +20,134 @@ def test_runtime_status_includes_daytona_slot_diagnostics() -> None: "available_slots": 5, "active_count": 0, } + + +def test_runtime_status_surfaces_persisted_mlflow_scorers(monkeypatch) -> None: + from fleet_rlm.api.dependencies import ConfigDeps, DiagnosticsDeps, LmDeps + from fleet_rlm.api.runtime_services import diagnostics + from fleet_rlm.api.runtime_services.diagnostics import build_runtime_status_response + + monkeypatch.setattr( + diagnostics, + "MlflowConfig", + type( + "FakeMlflowConfigFactory", + (), + { + "from_env": staticmethod( + lambda: type( + "FakeMlflowConfig", + (), + { + "enabled": True, + "enable_auto_assessment": False, + "tracking_uri": "http://127.0.0.1:5001", + }, + )() + ) + }, + ), + ) + monkeypatch.setattr( + "fleet_rlm.integrations.observability.auto_assessment.persisted_scorer_names", + lambda config: ["Trace Judge"], + ) + + response = build_runtime_status_response( + config_deps=ConfigDeps(), + lm_deps=LmDeps(), + diagnostics_deps=DiagnosticsDeps(), + ) + + assert response.mlflow["auto_assessment_enabled"] is False + assert response.mlflow["persisted_scorer_count"] == 1 + assert response.mlflow["persisted_scorers"] == ["Trace Judge"] + assert any("Trace Judge" in item for item in response.guidance) + + +def test_runtime_status_reconciles_stale_saturated_daytona_slots(monkeypatch) -> None: + from fleet_rlm.api.dependencies import ConfigDeps, DiagnosticsDeps, LmDeps + from fleet_rlm.api.runtime_services import diagnostics + from fleet_rlm.integrations.daytona.concurrency import SandboxUsageStats + + class FakeRuntime: + def _count_provider_fleet_sandboxes_sync(self) -> int: + return 0 + + def close(self) -> None: + return None + + monkeypatch.setattr( + diagnostics, + "get_current_sandbox_usage", + lambda: SandboxUsageStats(limit=5, available_slots=0, active_count=5), + ) + monkeypatch.setattr( + "fleet_rlm.integrations.daytona.runtime.DaytonaSandboxRuntime", + lambda: FakeRuntime(), + ) + monkeypatch.setattr( + diagnostics, + "reconcile_sandbox_slots", + lambda provider_active_count: SandboxUsageStats( + limit=5, + available_slots=5, + active_count=provider_active_count, + ), + ) + + response = diagnostics.build_runtime_status_response( + config_deps=ConfigDeps(), + lm_deps=LmDeps(), + diagnostics_deps=DiagnosticsDeps(), + ) + + assert response.daytona["sandbox_slots"] == { + "limit": 5, + "available_slots": 5, + "active_count": 0, + } + + +def test_runtime_status_reconciles_stale_partial_daytona_slots(monkeypatch) -> None: + from fleet_rlm.api.dependencies import ConfigDeps, DiagnosticsDeps, LmDeps + from fleet_rlm.api.runtime_services import diagnostics + from fleet_rlm.integrations.daytona.concurrency import SandboxUsageStats + + class FakeRuntime: + def _count_provider_fleet_sandboxes_sync(self) -> int: + return 0 + + def close(self) -> None: + return None + + monkeypatch.setattr( + diagnostics, + "get_current_sandbox_usage", + lambda: SandboxUsageStats(limit=5, available_slots=3, active_count=2), + ) + monkeypatch.setattr( + "fleet_rlm.integrations.daytona.runtime.DaytonaSandboxRuntime", + lambda: FakeRuntime(), + ) + monkeypatch.setattr( + diagnostics, + "reconcile_sandbox_slots", + lambda provider_active_count: SandboxUsageStats( + limit=5, + available_slots=5, + active_count=provider_active_count, + ), + ) + + response = diagnostics.build_runtime_status_response( + config_deps=ConfigDeps(), + lm_deps=LmDeps(), + diagnostics_deps=DiagnosticsDeps(), + ) + + assert response.daytona["sandbox_slots"] == { + "limit": 5, + "available_slots": 5, + "active_count": 0, + } diff --git a/tests/unit/api/test_sandboxes.py b/tests/unit/api/test_sandboxes.py index e6f53053e..8cec8114c 100644 --- a/tests/unit/api/test_sandboxes.py +++ b/tests/unit/api/test_sandboxes.py @@ -92,6 +92,47 @@ async def fake_close(client: Any) -> None: assert [item.id for item in response.items] == ["owned"] +@pytest.mark.asyncio +async def test_delete_sandbox_releases_fleet_slot(monkeypatch: pytest.MonkeyPatch) -> None: + from fleet_rlm.api.runtime_services import sandboxes + from fleet_rlm.utils.sandbox_ownership import SANDBOX_OWNER_LABEL + + released = 0 + + class FakeSandbox: + id = "sbx-1" + labels = {"managed-by": "fleet-rlm", SANDBOX_OWNER_LABEL: "tenant:user"} + + def stop(self, **kwargs: Any) -> None: + _ = kwargs + + def delete(self) -> None: + return None + + class FakeClient: + def get(self, sandbox_id: str) -> FakeSandbox: + assert sandbox_id == "sbx-1" + return FakeSandbox() + + async def fake_close(client: Any) -> None: + _ = client + + def fake_release() -> None: + nonlocal released + released += 1 + + monkeypatch.setattr(sandboxes, "_build_daytona_client", lambda: FakeClient()) + monkeypatch.setattr(sandboxes, "_close_daytona_client", fake_close) + monkeypatch.setattr(sandboxes, "release_sandbox_slot", fake_release) + + await sandboxes.delete_sandbox( + "sbx-1", + owner_labels={SANDBOX_OWNER_LABEL: "tenant:user"}, + ) + + assert released == 1 + + @pytest.mark.asyncio async def test_sandbox_service_maps_generic_daytona_error_to_503(monkeypatch: pytest.MonkeyPatch) -> None: from daytona import DaytonaError diff --git a/tests/unit/api/test_session_service.py b/tests/unit/api/test_session_service.py new file mode 100644 index 000000000..0d63bcede --- /dev/null +++ b/tests/unit/api/test_session_service.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from fleet_rlm.api.runtime_services.session_service import SessionService + + +def test_list_session_state_counts_exported_turns_schema() -> None: + service = SessionService(persistence=None) + identity = SimpleNamespace(tenant_claim="tenant", user_claim="user") + session_cache = { + "owner:abc:session-1": { + "owner_tenant_claim": "tenant", + "owner_user_claim": "user", + "workspace_id": "default", + "user_id": "anonymous", + "session_id": "session-1", + "session": { + "state": { + "turns": [ + {"user_message": "remember HISTORY_CHECK", "response": "ok"}, + {"user_message": "what marker?", "response": "HISTORY_CHECK"}, + ] + } + }, + } + } + + response = service.list_session_state(session_cache=session_cache, identity=identity) + + assert response.ok is True + assert len(response.sessions) == 1 + assert response.sessions[0].history_turns == 2 diff --git a/tests/unit/api/test_volume_services.py b/tests/unit/api/test_volume_services.py new file mode 100644 index 000000000..e533f8d9c --- /dev/null +++ b/tests/unit/api/test_volume_services.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import pytest +from fastapi import HTTPException + + +def test_volume_service_accepts_all_daytona_canonical_roots() -> None: + from fleet_rlm.api.runtime_services.volumes import CANONICAL_VOLUME_ROOTS, normalize_volume_tree_path + from fleet_rlm.integrations.daytona.volumes import VFS_CANONICAL_ROOTS + + assert set(CANONICAL_VOLUME_ROOTS) == VFS_CANONICAL_ROOTS + for root in VFS_CANONICAL_ROOTS: + assert normalize_volume_tree_path(root) == root + + +def test_volume_service_rejects_unknown_root() -> None: + from fleet_rlm.api.runtime_services.volumes import normalize_volume_tree_path + + with pytest.raises(HTTPException) as exc_info: + normalize_volume_tree_path("/tmp") + + assert exc_info.value.status_code == 403 diff --git a/tests/unit/cli/test_mlflow_cli.py b/tests/unit/cli/test_mlflow_cli.py index 43b96d921..fc94eecf1 100644 --- a/tests/unit/cli/test_mlflow_cli.py +++ b/tests/unit/cli/test_mlflow_cli.py @@ -13,7 +13,12 @@ def _load_cli_module(): return importlib.import_module("scripts.mlflow_cli") -def _fake_mlflow_module(*, scorers: list[object] | None = None, deleted: list[dict[str, object]] | None = None): +def _fake_mlflow_module( + *, + scorers: list[object] | None = None, + deleted: list[dict[str, object]] | None = None, + scorer: object | None = None, +): module = ModuleType("mlflow") module.tracking_uris = [] # type: ignore[attr-defined] @@ -32,9 +37,25 @@ def delete_scorer(*, name: str, experiment_id: str | None = None, version: int | if deleted is not None: deleted.append({"name": name, "experiment_id": experiment_id, "version": version}) + def get_scorer(*, name: str, experiment_id: str | None = None, version: int | str | None = None): + module.get_scorer_args = {"name": name, "experiment_id": experiment_id, "version": version} # type: ignore[attr-defined] + return scorer + + class ScorerSamplingConfig: + def __init__(self, *, sample_rate: float | None = None, filter_string: str | None = None) -> None: + self.sample_rate = sample_rate + self.filter_string = filter_string + module.set_tracking_uri = set_tracking_uri # type: ignore[attr-defined] module.set_experiment = set_experiment # type: ignore[attr-defined] - module.genai = SimpleNamespace(list_scorers=list_scorers, delete_scorer=delete_scorer) # type: ignore[attr-defined] + scorers_module = ModuleType("mlflow.genai.scorers") + scorers_module.ScorerSamplingConfig = ScorerSamplingConfig # type: ignore[attr-defined] + genai_module = ModuleType("mlflow.genai") + genai_module.list_scorers = list_scorers # type: ignore[attr-defined] + genai_module.get_scorer = get_scorer # type: ignore[attr-defined] + genai_module.delete_scorer = delete_scorer # type: ignore[attr-defined] + genai_module.scorers = scorers_module # type: ignore[attr-defined] + module.genai = genai_module # type: ignore[attr-defined] return module @@ -97,3 +118,74 @@ def test_scorers_delete_calls_mlflow_delete_scorer( assert deleted == [{"name": "Trace Judge", "experiment_id": "exp-override", "version": "7"}] assert "deleted_scorer=Trace Judge" in output assert "experiment_id=exp-override" in output + + +def test_scorers_stop_calls_registered_scorer_stop( + monkeypatch: pytest.MonkeyPatch, + clean_runtime_env: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + apply_mlflow_env(clean_runtime_env) + stopped: list[dict[str, object]] = [] + + scorer = SimpleNamespace( + stop=lambda *, name, experiment_id: stopped.append({"name": name, "experiment_id": experiment_id}) + ) + fake_mlflow = _fake_mlflow_module(scorer=scorer) + monkeypatch.setitem(sys.modules, "mlflow", fake_mlflow) + + cli = _load_cli_module() + result = cli.do_scorers_stop(SimpleNamespace(name="Trace Judge", experiment_id=None)) + + output = capsys.readouterr().out + assert result == 0 + assert stopped == [{"name": "Trace Judge", "experiment_id": "exp-active"}] + assert "stopped_scorer=Trace Judge" in output + + +def test_scorers_start_calls_registered_scorer_start( + monkeypatch: pytest.MonkeyPatch, + clean_runtime_env: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + apply_mlflow_env(clean_runtime_env) + started: list[dict[str, object]] = [] + + def start(*, name, experiment_id, sampling_config) -> None: + started.append( + { + "name": name, + "experiment_id": experiment_id, + "sample_rate": sampling_config.sample_rate, + "filter_string": sampling_config.filter_string, + } + ) + + scorer = SimpleNamespace(start=start) + fake_mlflow = _fake_mlflow_module(scorer=scorer) + monkeypatch.setitem(sys.modules, "mlflow", fake_mlflow) + monkeypatch.setitem(sys.modules, "mlflow.genai", fake_mlflow.genai) # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "mlflow.genai.scorers", fake_mlflow.genai.scorers) # type: ignore[attr-defined] + + cli = _load_cli_module() + result = cli.do_scorers_start( + SimpleNamespace( + name="Trace Judge", + experiment_id="exp-override", + sample_rate=0.5, + filter_string="status = 'OK'", + ) + ) + + output = capsys.readouterr().out + assert result == 0 + assert started == [ + { + "name": "Trace Judge", + "experiment_id": "exp-override", + "sample_rate": 0.5, + "filter_string": "status = 'OK'", + } + ] + assert "started_scorer=Trace Judge" in output + assert "sample_rate=0.5" in output diff --git a/tests/unit/integrations/test_daytona_concurrency.py b/tests/unit/integrations/test_daytona_concurrency.py index 391775f97..20c1f812e 100644 --- a/tests/unit/integrations/test_daytona_concurrency.py +++ b/tests/unit/integrations/test_daytona_concurrency.py @@ -14,6 +14,7 @@ acquire_sandbox_slot = concurrency.acquire_sandbox_slot attach_slot_release_handler = concurrency.attach_slot_release_handler get_current_sandbox_usage = concurrency.get_current_sandbox_usage +reconcile_sandbox_slots = concurrency.reconcile_sandbox_slots release_sandbox_slot = concurrency.release_sandbox_slot release_sandbox_slot_for = concurrency.release_sandbox_slot_for @@ -149,6 +150,50 @@ async def test_get_usage_after_acquire() -> None: release_sandbox_slot() +@pytest.mark.asyncio +async def test_reconcile_sandbox_slots_resets_stale_active_count() -> None: + with patch.dict("os.environ", {"FLEET_MAX_CONCURRENT_SANDBOXES": "5"}): + for _ in range(5): + await acquire_sandbox_slot(timeout=1.0) + stale_semaphore = concurrency._GLOBAL_SEMAPHORE + + reconciled = reconcile_sandbox_slots(provider_active_count=0) + + assert reconciled.limit == 5 + assert reconciled.active_count == 0 + assert reconciled.available_slots == 5 + assert concurrency._GLOBAL_SEMAPHORE is not stale_semaphore + result = await acquire_sandbox_slot(timeout=1.0) + assert result is True + release_sandbox_slot() + + +@pytest.mark.asyncio +async def test_reconcile_sandbox_slots_clamps_provider_count_to_limit() -> None: + with patch.dict("os.environ", {"FLEET_MAX_CONCURRENT_SANDBOXES": "2"}): + await acquire_sandbox_slot(timeout=1.0) + + reconciled = reconcile_sandbox_slots(provider_active_count=20) + + assert reconciled.limit == 2 + assert reconciled.active_count == 2 + assert reconciled.available_slots == 0 + + +@pytest.mark.asyncio +async def test_reconciled_slots_can_release_back_to_original_limit() -> None: + with patch.dict("os.environ", {"FLEET_MAX_CONCURRENT_SANDBOXES": "3"}): + reconciled = reconcile_sandbox_slots(provider_active_count=2) + + assert reconciled.available_slots == 1 + + release_sandbox_slot() + release_sandbox_slot() + usage = get_current_sandbox_usage() + assert usage.available_slots == 3 + assert usage.active_count == 0 + + # --------------------------------------------------------------------------- # Slot release handler # --------------------------------------------------------------------------- diff --git a/tests/unit/integrations/test_daytona_runtime.py b/tests/unit/integrations/test_daytona_runtime.py index 6e3c1a313..6ce9c629e 100644 --- a/tests/unit/integrations/test_daytona_runtime.py +++ b/tests/unit/integrations/test_daytona_runtime.py @@ -155,6 +155,100 @@ def fake_create(self: DaytonaSandboxRuntime, spec: SandboxSpec): assert concurrency.get_current_sandbox_usage().active_count == 0 +@pytest.mark.asyncio +async def test_sandbox_create_reconciles_stale_slots_and_retries(monkeypatch: pytest.MonkeyPatch) -> None: + from fleet_rlm.integrations.daytona import runtime as runtime_module + from fleet_rlm.integrations.daytona.models import SandboxSpec + from fleet_rlm.integrations.daytona.runtime import DaytonaSandboxRuntime + + sandbox = SimpleNamespace(delete=MagicMock(), stop=MagicMock()) + acquire_calls = 0 + reconciled_counts: list[int] = [] + + async def fake_acquire(*, timeout: float | None = None) -> bool: + nonlocal acquire_calls + _ = timeout + acquire_calls += 1 + if acquire_calls == 1: + raise asyncio.TimeoutError + return True + + async def fake_provider_count(self: DaytonaSandboxRuntime) -> int: + _ = self + return 0 + + def fake_reconcile(provider_active_count: int): + reconciled_counts.append(provider_active_count) + return concurrency.SandboxUsageStats(limit=5, available_slots=5, active_count=0) + + def fake_create(self: DaytonaSandboxRuntime, spec: SandboxSpec): + _ = (self, spec) + return sandbox + + monkeypatch.setattr(runtime_module, "acquire_sandbox_slot", fake_acquire) + monkeypatch.setattr( + runtime_module, + "get_current_sandbox_usage", + lambda: concurrency.SandboxUsageStats(limit=5, available_slots=0, active_count=5), + ) + monkeypatch.setattr(runtime_module, "reconcile_sandbox_slots", fake_reconcile) + monkeypatch.setattr(DaytonaSandboxRuntime, "_count_provider_fleet_sandboxes", fake_provider_count) + monkeypatch.setattr(runtime_module, "aresolve_sandbox_spec_snapshot", lambda spec, config: spec) + monkeypatch.setattr(DaytonaSandboxRuntime, "_create_sandbox_from_spec_impl", fake_create) + + result = await _sandbox_runtime().acreate_sandbox(spec=SandboxSpec(name="custom-sandbox")) + + assert result is sandbox + assert acquire_calls == 2 + assert reconciled_counts == [0] + + +@pytest.mark.asyncio +async def test_sandbox_create_keeps_busy_error_when_provider_reconcile_fails( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from fleet_rlm.integrations.daytona import runtime as runtime_module + from fleet_rlm.integrations.daytona.errors import DaytonaDiagnosticError + from fleet_rlm.integrations.daytona.models import SandboxSpec + from fleet_rlm.integrations.daytona.runtime import DaytonaSandboxRuntime + + async def fake_acquire(*, timeout: float | None = None) -> bool: + _ = timeout + raise asyncio.TimeoutError + + async def fake_provider_count(self: DaytonaSandboxRuntime) -> int: + _ = self + raise RuntimeError("provider unavailable") + + reconcile = MagicMock() + + monkeypatch.setattr(runtime_module, "acquire_sandbox_slot", fake_acquire) + monkeypatch.setattr(runtime_module, "reconcile_sandbox_slots", reconcile) + monkeypatch.setattr(DaytonaSandboxRuntime, "_count_provider_fleet_sandboxes", fake_provider_count) + + with pytest.raises(DaytonaDiagnosticError, match="Sandbox concurrency limit reached") as exc_info: + await _sandbox_runtime().acreate_sandbox(spec=SandboxSpec(name="custom-sandbox")) + + assert exc_info.value.category == "sandbox_concurrency_busy" + reconcile.assert_not_called() + + +def test_provider_fleet_sandbox_count_filters_labels_and_inactive_states() -> None: + class FakeClient: + def list(self, *, labels: dict[str, str]): + assert labels == {"managed-by": "fleet-rlm"} + return [ + SimpleNamespace(labels={"managed-by": "fleet-rlm"}, state="started"), + SimpleNamespace(labels={"managed-by": "fleet-rlm"}, state="archived"), + SimpleNamespace(labels={"managedBy": "fleet-pi"}, state="started"), + ] + + runtime = _sandbox_runtime() + runtime._client = FakeClient() + + assert runtime._count_provider_fleet_sandboxes_sync() == 1 + + @pytest.mark.asyncio async def test_async_sandbox_create_does_not_block_event_loop(monkeypatch: pytest.MonkeyPatch) -> None: from fleet_rlm.integrations.daytona import runtime as runtime_module @@ -346,6 +440,121 @@ def test_daytona_volume_browser_allows_durable_phase_roots() -> None: } <= VFS_CANONICAL_ROOTS +def test_daytona_volume_browser_reports_all_canonical_roots(monkeypatch: pytest.MonkeyPatch) -> None: + from types import SimpleNamespace + + from fleet_rlm.integrations.daytona import file_browser + from fleet_rlm.integrations.daytona.volumes import VFS_CANONICAL_ROOTS + + class _Sandbox: + fs = SimpleNamespace(list_files=lambda path: []) + + class _MountedVolume: + def __enter__(self) -> _Sandbox: + return _Sandbox() + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + monkeypatch.setattr(file_browser, "_mounted_daytona_volume", lambda volume_name: _MountedVolume()) + + tree = file_browser.list_daytona_volume_tree("volume", root_path="/", max_depth=1) + + assert tree["allowed_roots"] == sorted(VFS_CANONICAL_ROOTS) + + +def test_daytona_volume_browser_filters_root_to_canonical_roots(monkeypatch: pytest.MonkeyPatch) -> None: + from types import SimpleNamespace + + from fleet_rlm.integrations.daytona import file_browser + + class _Sandbox: + def __init__(self) -> None: + self.fs = SimpleNamespace(list_files=self._list_files) + + def _list_files(self, path: str) -> list[SimpleNamespace]: + if path == str(file_browser.DAYTONA_PERSISTENT_VOLUME_MOUNT_PATH): + return [ + SimpleNamespace(name="artifacts", is_dir=True, size=None, mod_time=None), + SimpleNamespace(name="workspace", is_dir=True, size=None, mod_time=None), + SimpleNamespace(name="scratch.txt", is_dir=False, size=12, mod_time=None), + ] + return [] + + class _MountedVolume: + def __enter__(self) -> _Sandbox: + return _Sandbox() + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + monkeypatch.setattr(file_browser, "_mounted_daytona_volume", lambda volume_name: _MountedVolume()) + + tree = file_browser.list_daytona_volume_tree("volume", root_path="/", max_depth=1) + root_children = tree["nodes"][0]["children"] + + assert [child["path"] for child in root_children] == ["/artifacts"] + assert tree["total_dirs"] == 1 + assert tree["total_files"] == 0 + + +def test_build_browser_snapshot_image_includes_playwright_install() -> None: + from unittest.mock import MagicMock + + from fleet_rlm.integrations.daytona import snapshots + + mock_image = MagicMock() + mock_image.run_commands.return_value = mock_image + + fake_daytona_image = MagicMock() + fake_daytona_image.base.return_value = mock_image + + with patch.dict("sys.modules", {"daytona": SimpleNamespace(Image=fake_daytona_image)}): + result = snapshots.build_browser_snapshot_image(include_vnc=False) + + assert result is mock_image + run_commands_calls = mock_image.run_commands.call_args_list + # Should have: system deps, pip install uv, uv pip install packages, playwright install + assert len(run_commands_calls) == 4 + system_call = run_commands_calls[0][0][0] + assert "libx11-6" in system_call + assert "libnss3" in system_call + assert "xvfb" not in system_call # VNC excluded + playwright_call = run_commands_calls[3][0][0] + assert "playwright install chromium" in playwright_call + + +def test_build_browser_snapshot_image_includes_vnc_when_enabled() -> None: + from unittest.mock import MagicMock + + from fleet_rlm.integrations.daytona import snapshots + + mock_image = MagicMock() + mock_image.run_commands.return_value = mock_image + + fake_daytona_image = MagicMock() + fake_daytona_image.base.return_value = mock_image + + with patch.dict("sys.modules", {"daytona": SimpleNamespace(Image=fake_daytona_image)}): + snapshots.build_browser_snapshot_image(include_vnc=True) + + system_call = mock_image.run_commands.call_args_list[0][0][0] + assert "xvfb" in system_call + assert "novnc" in system_call + + +def test_resolve_snapshot_for_skills_returns_browser_snapshot() -> None: + from fleet_rlm.integrations.daytona.runtime import resolve_snapshot_for_skills + from fleet_rlm.integrations.daytona.snapshots import BROWSER_SNAPSHOT_NAME, DEFAULT_SNAPSHOT_NAME + + assert resolve_snapshot_for_skills(None) == DEFAULT_SNAPSHOT_NAME + assert resolve_snapshot_for_skills([]) == DEFAULT_SNAPSHOT_NAME + assert resolve_snapshot_for_skills(["long-context"]) == DEFAULT_SNAPSHOT_NAME + assert resolve_snapshot_for_skills(["browser_interaction"]) == BROWSER_SNAPSHOT_NAME + assert resolve_snapshot_for_skills(["long-context", "browser_interaction"]) == BROWSER_SNAPSHOT_NAME + assert resolve_snapshot_for_skills(["Playwright automation"]) == BROWSER_SNAPSHOT_NAME + + def test_store_evidence_redacts_credentials_from_bridge_errors() -> None: from fleet_rlm.integrations.daytona.isolation import store_evidence diff --git a/tests/unit/integrations/test_daytona_sandbox_executor.py b/tests/unit/integrations/test_daytona_sandbox_executor.py new file mode 100644 index 000000000..48a19e1f9 --- /dev/null +++ b/tests/unit/integrations/test_daytona_sandbox_executor.py @@ -0,0 +1,198 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from dspy.primitives import CodeInterpreterError + + +def test_bridge_tools_can_disable_semantic_callbacks() -> None: + from fleet_rlm.integrations.daytona.bridge_callbacks import bridge_tools + + interpreter = SimpleNamespace( + _tools={}, + semantic_callbacks_enabled=False, + sub_rlm=lambda prompt: prompt, + llm_query=lambda prompt: prompt, + llm_query_batched=lambda prompts: prompts, + ) + + tools = bridge_tools(interpreter) + + assert "llm_query" not in tools + assert "llm_query_batched" not in tools + assert "sub_rlm" in tools + assert "fetch_document_text" in tools + + +def test_broker_start_failure_latches_and_blocks_immediate_retry() -> None: + from fleet_rlm.integrations.daytona.sandbox_executor import ( + _BROKER_START_FAILURES, + ExecutionCallbacks, + run_prepared_execution, + ) + + _BROKER_START_FAILURES.clear() + ensure_calls = 0 + session = SimpleNamespace(sandbox_id="sandbox-1") + owner = SimpleNamespace( + execute_timeout=30, + timeout=30, + _bridge_start_error=None, + _invoke_tool=lambda name, args, kwargs: None, + ) + + def fail_ensure_bridge(**kwargs: Any) -> None: + nonlocal ensure_calls + _ = kwargs + ensure_calls += 1 + raise CodeInterpreterError("Broker server failed to start within timeout") + + callbacks = ExecutionCallbacks( + bridge_tools=lambda: {"llm_query": lambda prompt: prompt}, + reject_recursive_callbacks=lambda code: None, + requires_bridge=lambda code, tools: True, + ensure_bridge=fail_ensure_bridge, + execute_direct=lambda **kwargs: pytest.fail("bridge-required code should not execute directly"), + response_from_execution=lambda execution: pytest.fail("execution should fail before response conversion"), + ) + + with pytest.raises(CodeInterpreterError, match="Broker server failed to start"): + run_prepared_execution( + owner, + session=session, + context=object(), + code="print(llm_query('x'))", + callbacks=callbacks, + ) + + assert ensure_calls == 1 + assert owner._bridge_start_error == "Broker server failed to start within timeout" + + with pytest.raises(CodeInterpreterError, match="llm_query should not be retried"): + run_prepared_execution( + owner, + session=session, + context=object(), + code="print(llm_query('again'))", + callbacks=callbacks, + ) + + assert ensure_calls == 1 + _BROKER_START_FAILURES.clear() + + +def test_broker_start_failure_cache_blocks_retry_across_executor_instances() -> None: + from fleet_rlm.integrations.daytona.sandbox_executor import ( + _BROKER_START_FAILURES, + ExecutionCallbacks, + run_prepared_execution, + ) + + _BROKER_START_FAILURES.clear() + ensure_calls = 0 + session = SimpleNamespace(sandbox_id="sandbox-1") + first_owner = SimpleNamespace( + execute_timeout=30, + timeout=30, + _bridge_start_error=None, + _invoke_tool=lambda name, args, kwargs: None, + ) + second_owner = SimpleNamespace( + execute_timeout=30, + timeout=30, + _bridge_start_error=None, + _invoke_tool=lambda name, args, kwargs: None, + ) + + def fail_ensure_bridge(**kwargs: Any) -> None: + nonlocal ensure_calls + _ = kwargs + ensure_calls += 1 + raise CodeInterpreterError("Broker server failed to start within timeout") + + callbacks = ExecutionCallbacks( + bridge_tools=lambda: {"llm_query": lambda prompt: prompt}, + reject_recursive_callbacks=lambda code: None, + requires_bridge=lambda code, tools: True, + ensure_bridge=fail_ensure_bridge, + execute_direct=lambda **kwargs: pytest.fail("bridge-required code should not execute directly"), + response_from_execution=lambda execution: pytest.fail("execution should fail before response conversion"), + ) + + with pytest.raises(CodeInterpreterError, match="Broker server failed to start"): + run_prepared_execution( + first_owner, + session=session, + context=object(), + code="print(llm_query('x'))", + callbacks=callbacks, + ) + + with pytest.raises(CodeInterpreterError, match="llm_query should not be retried"): + run_prepared_execution( + second_owner, + session=session, + context=object(), + code="print(llm_query('again'))", + callbacks=callbacks, + ) + + assert ensure_calls == 1 + _BROKER_START_FAILURES.clear() + + +def test_broker_start_failure_cooldown_allows_same_owner_retry() -> None: + from fleet_rlm.integrations.daytona.sandbox_executor import ( + _BROKER_START_FAILURE_COOLDOWN_SECONDS, + _BROKER_START_FAILURES, + ExecutionCallbacks, + run_prepared_execution, + ) + + class FakeBridge: + def execute_tool_call(self, **kwargs: Any) -> dict[str, Any]: + return {"status": "ok", "timeout": kwargs["timeout"]} + + _BROKER_START_FAILURES.clear() + session = SimpleNamespace(sandbox_id="sandbox-1") + owner = SimpleNamespace( + execute_timeout=30, + timeout=30, + _bridge_start_error="Broker server failed to start within timeout", + _invoke_tool=lambda name, args, kwargs: None, + ) + _BROKER_START_FAILURES["sandbox-1"] = ( + 0.0 - _BROKER_START_FAILURE_COOLDOWN_SECONDS - 1.0, + "Broker server failed to start within timeout", + ) + ensure_calls = 0 + + def ensure_bridge(**kwargs: Any) -> FakeBridge: + nonlocal ensure_calls + _ = kwargs + ensure_calls += 1 + return FakeBridge() + + callbacks = ExecutionCallbacks( + bridge_tools=lambda: {"llm_query": lambda prompt: prompt}, + reject_recursive_callbacks=lambda code: None, + requires_bridge=lambda code, tools: True, + ensure_bridge=ensure_bridge, + execute_direct=lambda **kwargs: pytest.fail("bridge-required code should not execute directly"), + response_from_execution=lambda execution: pytest.fail("execution should not be converted here"), + ) + + result = run_prepared_execution( + owner, + session=session, + context=object(), + code="print(llm_query('again'))", + callbacks=callbacks, + ) + + assert result == {"status": "ok", "timeout": 30} + assert ensure_calls == 1 + assert owner._bridge_start_error is None + assert _BROKER_START_FAILURES == {} diff --git a/tests/unit/integrations/test_mlflow_context.py b/tests/unit/integrations/test_mlflow_context.py new file mode 100644 index 000000000..eb9bb7631 --- /dev/null +++ b/tests/unit/integrations/test_mlflow_context.py @@ -0,0 +1,360 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + + +def test_update_current_mlflow_trace_mirrors_fleet_metadata_to_tags(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_context + from fleet_rlm.integrations.observability.mlflow_context import ( + MlflowTraceRequestContext, + mlflow_request_context, + ) + + captured: list[dict[str, Any]] = [] + + fake_mlflow = SimpleNamespace( + get_current_active_span=lambda: object(), + get_active_trace_id=lambda: "tr-test", + update_current_trace=lambda **kwargs: captured.append(kwargs), + ) + monkeypatch.setattr( + mlflow_context, + "_runtime_module", + lambda: SimpleNamespace( + _import_mlflow=lambda: fake_mlflow, + get_mlflow_config=lambda: SimpleNamespace(active_model_id=None), + flush_mlflow_traces=lambda: None, + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + ), + ) + + with mlflow_request_context( + MlflowTraceRequestContext( + client_request_id="chat-123", + session_id="workspace:user:session", + user_id="user", + app_env="local", + request_preview="hello", + metadata={"fleet_rlm.routing_decision": "url_document_rlm"}, + ) + ): + mlflow_context.update_current_mlflow_trace(trace_metadata={"fleet_rlm.source_url": "https://dspy.ai"}) + + update = captured[0] + assert update["tags"]["fleet_rlm.trace_kind"] == "application" + assert update["tags"]["fleet_rlm.client_request_id"] == "chat-123" + assert update["tags"]["fleet_rlm.session_id"] == "workspace:user:session" + assert update["tags"]["fleet_rlm.routing_decision"] == "url_document_rlm" + assert update["tags"]["fleet_rlm.source_url"] == "https://dspy.ai" + + +def test_update_current_mlflow_trace_does_not_resend_active_trace_tags(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_context + from fleet_rlm.integrations.observability.mlflow_context import ( + MlflowTraceRequestContext, + mlflow_request_context, + ) + + captured: list[dict[str, Any]] = [] + + fake_mlflow = SimpleNamespace( + get_current_active_span=lambda: object(), + get_active_trace_id=lambda: "tr-test", + update_current_trace=lambda **kwargs: captured.append(kwargs), + ) + monkeypatch.setattr( + mlflow_context, + "_runtime_module", + lambda: SimpleNamespace( + _import_mlflow=lambda: fake_mlflow, + get_mlflow_config=lambda: SimpleNamespace(active_model_id=None), + flush_mlflow_traces=lambda: None, + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + ), + ) + + with mlflow_request_context( + MlflowTraceRequestContext( + client_request_id="chat-123", + session_id="workspace:user:session", + user_id="user", + metadata={"fleet_rlm.routing_decision": "url_document_rlm"}, + ) + ): + mlflow_context.update_current_mlflow_trace(trace_metadata={"fleet_rlm.source_url": "https://dspy.ai"}) + mlflow_context.update_current_mlflow_trace(trace_metadata={"fleet_rlm.source_url": "https://dspy.ai"}) + + assert captured[0]["tags"]["fleet_rlm.routing_decision"] == "url_document_rlm" + assert captured[0]["tags"]["fleet_rlm.source_url"] == "https://dspy.ai" + assert captured[1]["tags"] is None + + +def test_update_current_mlflow_trace_sets_tags_on_resolved_inactive_trace(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_context + from fleet_rlm.integrations.observability.mlflow_context import ( + MlflowTraceRequestContext, + mlflow_request_context, + ) + + captured_tags: list[tuple[str, str, str]] = [] + + fake_mlflow = SimpleNamespace( + get_current_active_span=lambda: None, + get_active_trace_id=lambda: None, + get_last_active_trace_id=lambda thread_local=True: "tr-inactive", + set_trace_tag=lambda trace_id, key, value: captured_tags.append((trace_id, key, value)), + ) + monkeypatch.setattr( + mlflow_context, + "_runtime_module", + lambda: SimpleNamespace( + _import_mlflow=lambda: fake_mlflow, + get_mlflow_config=lambda: SimpleNamespace(active_model_id=None), + flush_mlflow_traces=lambda: None, + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + ), + ) + + with mlflow_request_context( + MlflowTraceRequestContext( + client_request_id="chat-123", + session_id="workspace:user:session", + user_id="user", + metadata={"fleet_rlm.routing_decision": "url_document_rlm"}, + ) + ): + trace_id = mlflow_context.update_current_mlflow_trace( + trace_metadata={"fleet_rlm.source_url": "https://dspy.ai"} + ) + + assert trace_id == "tr-inactive" + assert ("tr-inactive", "fleet_rlm.client_request_id", "chat-123") in captured_tags + assert ("tr-inactive", "fleet_rlm.routing_decision", "url_document_rlm") in captured_tags + assert ("tr-inactive", "fleet_rlm.source_url", "https://dspy.ai") in captured_tags + + +def test_update_current_mlflow_trace_resolves_completed_trace_by_client_request_id(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_context, mlflow_traces + from fleet_rlm.integrations.observability.mlflow_context import ( + MlflowTraceRequestContext, + mlflow_request_context, + ) + + captured_tags: list[tuple[str, str, str]] = [] + + fake_mlflow = SimpleNamespace( + get_current_active_span=lambda: None, + get_active_trace_id=lambda: None, + get_last_active_trace_id=lambda thread_local=True: None, + set_trace_tag=lambda trace_id, key, value: captured_tags.append((trace_id, key, value)), + ) + fake_config = SimpleNamespace(active_model_id=None) + monkeypatch.setattr( + mlflow_context, + "_runtime_module", + lambda: SimpleNamespace( + _import_mlflow=lambda: fake_mlflow, + get_mlflow_config=lambda: fake_config, + flush_mlflow_traces=lambda: None, + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + ), + ) + monkeypatch.setattr( + mlflow_traces, + "resolve_trace_by_client_request_id", + lambda client_request_id, config, max_results: SimpleNamespace( + info=SimpleNamespace(trace_id="tr-resolved", client_request_id=client_request_id) + ), + ) + + with mlflow_request_context( + MlflowTraceRequestContext( + client_request_id="chat-456", + session_id="workspace:user:session", + user_id="user", + ) + ): + trace_id = mlflow_context.update_current_mlflow_trace(trace_metadata={"fleet_rlm.trajectory_steps": "1"}) + + assert trace_id == "tr-resolved" + assert ("tr-resolved", "fleet_rlm.client_request_id", "chat-456") in captured_tags + assert ("tr-resolved", "fleet_rlm.trajectory_steps", "1") in captured_tags + + +def test_mlflow_request_context_reapplies_final_metadata_after_flush(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_context, mlflow_traces + from fleet_rlm.integrations.observability.mlflow_context import ( + MlflowTraceRequestContext, + mlflow_request_context, + ) + + captured_tags: list[tuple[str, str, str]] = [] + flush_count = 0 + + fake_mlflow = SimpleNamespace( + get_current_active_span=lambda: None, + get_active_trace_id=lambda: None, + get_last_active_trace_id=lambda thread_local=True: None, + set_trace_tag=lambda trace_id, key, value: captured_tags.append((trace_id, key, value)), + ) + + def flush_mlflow_traces() -> None: + nonlocal flush_count + flush_count += 1 + + monkeypatch.setattr( + mlflow_context, + "_runtime_module", + lambda: SimpleNamespace( + _import_mlflow=lambda: fake_mlflow, + get_mlflow_config=lambda: SimpleNamespace(active_model_id=None), + flush_mlflow_traces=flush_mlflow_traces, + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + ), + ) + + def resolve_after_first_flush(client_request_id: str, config: Any, max_results: int) -> Any | None: + _ = config, max_results + if flush_count == 0: + return None + return SimpleNamespace(info=SimpleNamespace(trace_id="tr-after-flush", client_request_id=client_request_id)) + + monkeypatch.setattr(mlflow_traces, "resolve_trace_by_client_request_id", resolve_after_first_flush) + + with mlflow_request_context(MlflowTraceRequestContext(client_request_id="chat-789")): + mlflow_context.update_current_mlflow_trace( + response_preview="final", + trace_metadata={"fleet_rlm.routing_decision": "url_document_rlm"}, + ) + + assert flush_count == 2 + assert ("tr-after-flush", "fleet_rlm.client_request_id", "chat-789") in captured_tags + assert ("tr-after-flush", "fleet_rlm.routing_decision", "url_document_rlm") in captured_tags + + +def test_record_rlm_trajectory_spans_materializes_repl_steps(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_context + + captured: list[dict[str, Any]] = [] + + class FakeSpan: + def __init__(self, name: str, span_type: str | None, attributes: dict[str, Any] | None) -> None: + self.record = {"name": name, "span_type": span_type, "attributes": attributes or {}, "status": "OK"} + + def __enter__(self) -> "FakeSpan": + captured.append(self.record) + return self + + def __exit__(self, *args: object) -> None: + return None + + def set_inputs(self, inputs: Any) -> None: + self.record["inputs"] = inputs + + def set_outputs(self, outputs: Any) -> None: + self.record["outputs"] = outputs + + def set_status(self, status: str) -> None: + self.record["status"] = status + + fake_mlflow = SimpleNamespace( + get_current_active_span=lambda: object(), + start_span=lambda name, span_type=None, attributes=None: FakeSpan(name, span_type, attributes), + ) + monkeypatch.setattr( + mlflow_context, + "_runtime_module", + lambda: SimpleNamespace( + _import_mlflow=lambda: fake_mlflow, + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + ), + ) + + recorded = mlflow_context.record_rlm_trajectory_spans( + [ + { + "reasoning": "Inspect the document text.", + "code": "print(document_text[:100])", + "output": "DSPy docs", + } + ] + ) + + assert recorded == 1 + assert captured[0]["name"] == "rlm_available_tools" + assert captured[0]["span_type"] == "LLM" + assert captured[0]["inputs"]["tools"][0]["function"]["name"] == "repl_execute" + assert "repl_execute" in captured[0]["attributes"]["mlflow.chat.tools"] + assert captured[1]["name"] == "repl_execute" + assert captured[1]["span_type"] == "TOOL" + assert captured[1]["inputs"]["code"] == "print(document_text[:100])" + assert captured[1]["outputs"]["output"] == "DSPy docs" + assert captured[1]["attributes"]["fleet_rlm.trajectory_has_code"] == "true" + assert captured[1]["status"] == "OK" + + +def test_record_rlm_trajectory_spans_marks_error_outputs(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_context + + captured: list[dict[str, Any]] = [] + + class FakeSpan: + def __init__(self, name: str, span_type: str | None, attributes: dict[str, Any] | None) -> None: + self.record = {"name": name, "span_type": span_type, "attributes": attributes or {}, "status": "OK"} + + def __enter__(self) -> "FakeSpan": + captured.append(self.record) + return self + + def __exit__(self, *args: object) -> None: + return None + + def set_inputs(self, inputs: Any) -> None: + self.record["inputs"] = inputs + + def set_outputs(self, outputs: Any) -> None: + self.record["outputs"] = outputs + + def set_status(self, status: str) -> None: + self.record["status"] = status + + fake_mlflow = SimpleNamespace( + get_current_active_span=lambda: object(), + start_span=lambda name, span_type=None, attributes=None: FakeSpan(name, span_type, attributes), + ) + monkeypatch.setattr( + mlflow_context, + "_runtime_module", + lambda: SimpleNamespace( + _import_mlflow=lambda: fake_mlflow, + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + ), + ) + + recorded = mlflow_context.record_rlm_trajectory_spans( + [{"code": "SUBMIT('unterminated)", "output": "[Error] unterminated string literal"}] + ) + + assert recorded == 1 + assert captured[1]["status"] == "ERROR" + assert captured[1]["attributes"]["fleet_rlm.trajectory_error"] == "true" + + +def test_record_rlm_trajectory_spans_skips_without_active_trace(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_context + + fake_mlflow = SimpleNamespace( + get_current_active_span=lambda: None, + get_active_trace_id=lambda: None, + ) + monkeypatch.setattr( + mlflow_context, + "_runtime_module", + lambda: SimpleNamespace( + _import_mlflow=lambda: fake_mlflow, + logger=SimpleNamespace(debug=lambda *args, **kwargs: None), + ), + ) + + assert mlflow_context.record_rlm_trajectory_spans([{"code": "print('x')"}]) == 0 diff --git a/tests/unit/integrations/test_mlflow_traces.py b/tests/unit/integrations/test_mlflow_traces.py new file mode 100644 index 000000000..4ae3c223e --- /dev/null +++ b/tests/unit/integrations/test_mlflow_traces.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + + +class _FakeTrace: + def __init__(self, assessments: list[dict[str, Any]]) -> None: + self._assessments = assessments + + def to_dict(self) -> dict[str, Any]: + return { + "info": { + "trace_id": "tr-1", + "request_preview": "request", + "response_preview": "response", + "trace_metadata": {}, + } + } + + def search_assessments(self) -> list[Any]: + return [SimpleNamespace(to_dictionary=lambda item=item: item) for item in self._assessments] + + def search_spans(self) -> list[Any]: + return [] + + +def test_trace_to_dataset_row_skips_disabled_persisted_scorer_feedback(monkeypatch) -> None: + from fleet_rlm.integrations.observability import mlflow_traces + from fleet_rlm.integrations.observability.config import MlflowConfig + + monkeypatch.setattr(mlflow_traces, "_disabled_persisted_scorer_names", lambda config: {"Trace Judge"}) + + row = mlflow_traces.trace_to_dataset_row( + _FakeTrace( + [ + { + "assessment_name": "Trace Judge", + "source": {"source_type": "LLM_JUDGE", "source_id": "gateway:/gemini"}, + "feedback": {"value": "yes"}, + "rationale": "No tools were called.", + }, + { + "assessment_name": "response_is_correct", + "source": {"source_type": "HUMAN", "source_id": "user"}, + "feedback": {"value": True}, + "rationale": "Looks right.", + }, + ] + ), + config=MlflowConfig(enable_auto_assessment=False), + ) + + assert "Trace Judge" not in row.get("feedback", {}) + assert row["feedback"]["response_is_correct"]["value"] is True + assert row["skipped_feedback"] == [ + { + "assessment_name": "Trace Judge", + "source_type": "LLM_JUDGE", + "source_id": "gateway:/gemini", + "reason": "persisted_scorer_while_fleet_auto_assessment_disabled", + } + ] + + +def test_search_traces_uses_locations_when_supported() -> None: + from fleet_rlm.integrations.observability.mlflow_traces import _search_traces + + captured: dict[str, Any] = {} + + def search_traces( + *, + locations: list[str] | None = None, + experiment_ids: list[str] | None = None, + **kwargs: Any, + ) -> list[Any]: + kwargs["locations"] = locations + kwargs["experiment_ids"] = experiment_ids + captured.update(kwargs) + return [] + + _search_traces( + SimpleNamespace(search_traces=search_traces), + experiment_ids=["1"], + max_results=5, + return_type="list", + include_spans=False, + ) + + assert captured["locations"] == ["1"] + assert captured["experiment_ids"] is None diff --git a/tests/unit/integrations/test_observability.py b/tests/unit/integrations/test_observability.py index 4a201e001..2abb1db7a 100644 --- a/tests/unit/integrations/test_observability.py +++ b/tests/unit/integrations/test_observability.py @@ -1,5 +1,6 @@ from __future__ import annotations +import builtins import logging import sys import threading @@ -115,7 +116,11 @@ def test_warn_if_persisted_scorers_active_logs_actionable_warning( def list_scorers(*, experiment_id: str | None = None) -> list[object]: list_calls.append(experiment_id) - return [SimpleNamespace(name="Trace Judge"), {"scorer_name": "retired-scorer"}] + return [ + SimpleNamespace(name="Trace Judge", status="STARTED", sample_rate=1.0), + {"scorer_name": "retired-scorer", "status": "STARTED", "sample_rate": 0.25}, + SimpleNamespace(name="stopped-scorer", status="STOPPED", sample_rate=0.0), + ] fake_mlflow = SimpleNamespace( genai=SimpleNamespace(list_scorers=list_scorers), @@ -132,10 +137,92 @@ def list_scorers(*, experiment_id: str | None = None) -> list[object]: assert list_calls == ["exp-123"] assert "Trace Judge" in caplog.text assert "retired-scorer" in caplog.text + assert "stopped-scorer" not in caplog.text assert "scripts/mlflow_cli.py scorers list" in caplog.text assert "FLEET_RLM_ENABLE_AUTO_ASSESSMENT" in caplog.text +def test_persisted_scorer_names_uses_short_cache() -> None: + from fleet_rlm.integrations.observability import auto_assessment + from fleet_rlm.integrations.observability.auto_assessment import persisted_scorer_names + from fleet_rlm.integrations.observability.config import MlflowConfig + + auto_assessment._PERSISTED_SCORER_CACHE = None + calls: list[str | None] = [] + + def list_scorers(*, experiment_id: str | None = None) -> list[object]: + calls.append(experiment_id) + return [SimpleNamespace(name="Trace Judge", status="STARTED", sample_rate=1.0)] + + fake_mlflow = SimpleNamespace( + genai=SimpleNamespace(list_scorers=list_scorers), + get_experiment_by_name=lambda name: SimpleNamespace(experiment_id="exp-123"), + ) + config = MlflowConfig(enable_auto_assessment=False, experiment="fleet-rlm-test") + + assert persisted_scorer_names(config, mlflow=fake_mlflow) == ["Trace Judge"] + assert persisted_scorer_names(config, mlflow=fake_mlflow) == ["Trace Judge"] + assert calls == ["exp-123"] + auto_assessment._PERSISTED_SCORER_CACHE = None + + +def test_persisted_scorer_names_ignores_stopped_scorers() -> None: + from fleet_rlm.integrations.observability import auto_assessment + from fleet_rlm.integrations.observability.auto_assessment import persisted_scorer_names + from fleet_rlm.integrations.observability.config import MlflowConfig + + auto_assessment._PERSISTED_SCORER_CACHE = None + + def list_scorers(*, experiment_id: str | None = None) -> list[object]: + return [ + SimpleNamespace(name="active", status="ScorerStatus.STARTED", sample_rate=1.0), + SimpleNamespace(name="stopped", status="ScorerStatus.STOPPED", sample_rate=0.0), + ] + + fake_mlflow = SimpleNamespace( + genai=SimpleNamespace(list_scorers=list_scorers), + get_experiment_by_name=lambda name: SimpleNamespace(experiment_id="exp-123"), + ) + + names = persisted_scorer_names( + MlflowConfig(enable_auto_assessment=False, experiment="fleet-rlm-test"), + mlflow=fake_mlflow, + cache_seconds=0, + ) + + assert names == ["active"] + + +def test_import_mlflow_clears_partial_import_and_retries(monkeypatch: pytest.MonkeyPatch) -> None: + from fleet_rlm.integrations.observability import mlflow_runtime + + calls = 0 + real_import = builtins.__import__ + partial_mlflow = types.ModuleType("mlflow") + partial_child = types.ModuleType("mlflow.genai") + good_mlflow = types.ModuleType("mlflow") + good_mlflow.version = SimpleNamespace(VERSION="3.12.0") + + def fake_import(name: str, *args: object, **kwargs: object) -> object: + nonlocal calls + if name != "mlflow": + return real_import(name, *args, **kwargs) + calls += 1 + if calls == 1: + sys.modules["mlflow"] = partial_mlflow + sys.modules["mlflow.genai"] = partial_child + return partial_mlflow + sys.modules["mlflow"] = good_mlflow + return good_mlflow + + monkeypatch.setattr(builtins, "__import__", fake_import) + + assert mlflow_runtime._import_mlflow() is good_mlflow + assert calls == 2 + assert sys.modules["mlflow"] is good_mlflow + assert "mlflow.genai" not in sys.modules + + def test_warn_if_persisted_scorers_active_skips_when_auto_assessment_enabled() -> None: from fleet_rlm.integrations.observability.auto_assessment import warn_if_persisted_scorers_active from fleet_rlm.integrations.observability.config import MlflowConfig diff --git a/tests/unit/runtime/test_escalating_module.py b/tests/unit/runtime/test_escalating_module.py index 0a28d1d86..90734881e 100644 --- a/tests/unit/runtime/test_escalating_module.py +++ b/tests/unit/runtime/test_escalating_module.py @@ -9,7 +9,11 @@ import pytest from fleet_rlm.runtime.factory import ESCALATING_RUNTIME_ENV_VAR, build_chat_agent -from fleet_rlm.runtime.modules.escalating import ESCALATION_SENTINEL, EscalatingFleetModule +from fleet_rlm.runtime.modules.escalating import ( + ESCALATION_SENTINEL, + EscalatingFleetModule, + _build_rlm_prompt_context, +) class _FakePrediction(dspy.Prediction): @@ -49,7 +53,80 @@ def __call__(self, **kwargs: Any) -> dspy.Prediction: return self.prediction +class _PreviewPosthocAgent(_PosthocAgent): + def preview_routing(self, *, user_request: str, execution_mode: str = "auto") -> dict[str, Any]: + _ = user_request, execution_mode + return { + "routing_decision": "url_document_rlm", + "source_url": "https://dspy.ai", + } + + class TestEscalatingFleetModule: + def test_url_document_rlm_is_bounded_and_disables_child_tools( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + from fleet_rlm.runtime.modules import variable_mode + + calls: list[dict[str, Any]] = [] + + def fake_build_variable_mode_rlm(**kwargs: Any) -> MagicMock: + calls.append(kwargs) + return MagicMock() + + monkeypatch.setattr(variable_mode, "build_variable_mode_rlm", fake_build_variable_mode_rlm) + + EscalatingFleetModule( + interpreter=object(), + tools=[lambda: None], + max_iterations=20, + max_llm_calls=50, + ) + + url_call = calls[1] + assert url_call["max_iterations"] == 4 + assert url_call["max_llm_calls"] == 8 + assert url_call["extra_tools"] == [] + assert url_call["include_sub_tools"] is False + assert url_call["include_llm_tools"] is False + + def test_url_document_prompt_tells_rlm_semantic_callbacks_are_disabled(self) -> None: + prompt = _build_rlm_prompt_context( + user_request="analyze https://dspy.ai docs", + recent_history="", + compressed_history="", + core_memory="", + url_document_mode=True, + ) + + assert "llm_query and llm_query_batched are disabled" in prompt + assert "synthesize from Python inspection" in prompt + assert "llm_query" in prompt + + def test_escalating_module_passes_max_output_chars_to_rlm_wrappers( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + from fleet_rlm.runtime.modules import variable_mode + + calls: list[dict[str, Any]] = [] + + def fake_build_variable_mode_rlm(**kwargs: Any) -> MagicMock: + calls.append(kwargs) + return MagicMock() + + monkeypatch.setattr(variable_mode, "build_variable_mode_rlm", fake_build_variable_mode_rlm) + + EscalatingFleetModule( + interpreter=object(), + tools=[], + max_output_chars=12_345, + ) + + assert calls[0]["max_output_chars"] == 12_345 + assert calls[1]["max_output_chars"] == 12_345 + def test_cot_path_taken_when_no_signal(self) -> None: module = _make_module() _stub_respond(module, reasoning="Just thinking carefully.", response="Here is the answer.") @@ -57,6 +134,25 @@ def test_cot_path_taken_when_no_signal(self) -> None: module.respond.assert_called_once() assert getattr(result, "assistant_response", None) == "Here is the answer." + def test_cot_path_passes_recency_ordered_history_context(self) -> None: + module = _make_module() + _stub_respond(module, reasoning="Just thinking carefully.", response="Here is the answer.") + history = dspy.History( + messages=[ + {"user_message": "remember OLD_MARKER", "response": "OLD_MARKER"}, + {"user_message": "remember NEW_MARKER", "response": "NEW_MARKER"}, + ] + ) + + module(user_request="What marker did I just ask you to remember?", execution_mode="auto", history=history) + + call_kwargs = module.respond.call_args.kwargs + assert call_kwargs["history"] is history + assert "OLD_MARKER" in call_kwargs["recent_history"] + assert "NEW_MARKER" in call_kwargs["recent_history"] + assert call_kwargs["recent_history"].rfind("NEW_MARKER") > call_kwargs["recent_history"].rfind("OLD_MARKER") + assert "most recent prior turn" in call_kwargs["recent_history"] + def test_rlm_path_triggered_by_sentinel_in_reasoning(self) -> None: module = _make_module() _stub_respond(module, reasoning=f"I need external data {ESCALATION_SENTINEL}", response="step1") @@ -103,6 +199,94 @@ def test_execution_mode_rlm_only_skips_cot(self) -> None: module.respond.assert_not_called() assert getattr(result, "answer", None) == "rlm_only_mode" + def test_url_document_analysis_auto_routes_to_rlm(self) -> None: + module = _make_module() + _stub_respond(module) + rlm_pred = _FakePrediction(answer="doc analysis") + module._rlm = MagicMock(return_value=rlm_pred) + _stub_summarize(module) + + result = module( + user_request="analyze https://dspy.ai and provide an in depth analysis of the documentation", + execution_mode="auto", + ) + + module.respond.assert_not_called() + module._rlm.assert_called_once() + assert getattr(result, "answer", None) == "doc analysis" + assert result["routing_decision"] == "url_document_rlm" + assert result["source_url"] == "https://dspy.ai" + + def test_url_document_analysis_passes_fetched_doc_as_rlm_variables( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + module = _make_module(interpreter=object()) + _stub_respond(module) + rlm_pred = _FakePrediction(answer="doc analysis") + module._url_document_rlm = MagicMock(return_value=rlm_pred) + _stub_summarize(module) + + monkeypatch.setattr( + "fleet_rlm.runtime.tools.document_tools.fetch_document_text", + lambda url: { + "status": "ok", + "text": "# DSPy docs\nRLM details", + "char_count": 23, + "metadata": {"source_type": "html"}, + }, + ) + + module( + user_request="analyze https://dspy.ai and provide documentation notes", + execution_mode="auto", + ) + + module._url_document_rlm.assert_called_once() + call_kwargs = module._url_document_rlm.call_args.kwargs + assert call_kwargs["source_url"] == "https://dspy.ai" + assert call_kwargs["document_text"] == "# DSPy docs\nRLM details" + assert call_kwargs["source_metadata"] == { + "status": "ok", + "char_count": "23", + "source_type": "html", + } + assert "# DSPy docs\nRLM details" not in call_kwargs["prompt"] + assert call_kwargs["prompt"].startswith( + "Task:\nanalyze https://dspy.ai and provide documentation notes" + ) + assert "URL document variables" in call_kwargs["prompt"] + assert call_kwargs["prompt"].endswith( + "Repeat task:\nanalyze https://dspy.ai and provide documentation notes" + ) + + def test_tools_only_does_not_auto_route_url_to_rlm(self) -> None: + module = _make_module() + _stub_respond(module, response="tool path") + module._rlm = MagicMock(return_value=_FakePrediction(answer="should not run")) + + result = module( + user_request="analyze https://dspy.ai and provide documentation notes", + execution_mode="tools_only", + ) + + module.respond.assert_called_once() + module._rlm.assert_not_called() + assert getattr(result, "assistant_response", None) == "tool path" + + def test_preview_routing_surfaces_url_document_route_before_execution(self) -> None: + module = _make_module() + + preview = module.preview_routing( + user_request="analyze https://dspy.ai and summarize the docs", + execution_mode="auto", + ) + + assert preview == { + "routing_decision": "url_document_rlm", + "source_url": "https://dspy.ai", + } + def test_rlm_fallback_to_cot_on_error(self) -> None: module = _make_module() cot_pred = _FakePrediction(reasoning=ESCALATION_SENTINEL, assistant_response="cot_resp") @@ -204,6 +388,64 @@ async def test_posthoc_stream_surfaces_rlm_fallback_warning(self, monkeypatch: p assert done.payload["execution_mode"] == "auto" assert done.text == "fallback answer" + @pytest.mark.asyncio + async def test_posthoc_stream_surfaces_rlm_code_trajectory(self, monkeypatch: pytest.MonkeyPatch) -> None: + _disable_runtime_tool_discovery(monkeypatch) + from fleet_rlm.runtime.agent.runtime import AgentRuntime + + rt = AgentRuntime(use_escalation=True) + rt.agent = _PosthocAgent( + _FakePrediction( + answer="analysis complete", + selected_skills=["long-context"], + routing_decision="url_document_rlm", + source_url="https://dspy.ai", + trajectory=[ + { + "reasoning": "Fetch and inspect the docs page.", + "code": "import urllib.request\nprint('docs')", + "output": "docs", + } + ], + ) + ) + + events = [event async for event in rt.aiter_chat_turn_stream("analyze https://dspy.ai")] + + status = next(event for event in events if event.payload.get("selected_skills") == ["long-context"]) + reasoning = next(event for event in events if event.kind == "reasoning") + repl_call = next(event for event in events if event.kind == "tool_call") + repl_result = next(event for event in events if event.kind == "tool_result") + done = events[-1] + + assert "long-context" in status.text + assert reasoning.text == "Fetch and inspect the docs page." + assert repl_call.payload["tool_name"] == "repl_execute" + assert "urllib.request" in repl_call.payload["tool_input"] + assert repl_result.payload["tool_output"] == "docs" + assert done.payload["routing_decision"] == "url_document_rlm" + assert done.payload["source_url"] == "https://dspy.ai" + + @pytest.mark.asyncio + async def test_posthoc_stream_emits_routing_preview_before_result( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + _disable_runtime_tool_discovery(monkeypatch) + from fleet_rlm.runtime.agent.runtime import AgentRuntime + + rt = AgentRuntime(use_escalation=True) + rt.agent = _PreviewPosthocAgent(_FakePrediction(answer="analysis complete")) + + events = [event async for event in rt.aiter_chat_turn_stream("analyze https://dspy.ai")] + + assert events[0].text == "Starting turn..." + assert events[1].payload == { + "routing_decision": "url_document_rlm", + "source_url": "https://dspy.ai", + } + assert "url_document_rlm" in events[1].text + class TestBuildChatAgentRuntimeDefault: def test_build_chat_agent_defaults_to_escalating_module(self, monkeypatch: pytest.MonkeyPatch) -> None: @@ -230,3 +472,17 @@ def test_build_chat_agent_explicit_toggle_overrides_env(self, monkeypatch: pytes rt = build_chat_agent(planner_lm=object(), use_escalation=True) assert isinstance(rt.agent, EscalatingFleetModule) + + def test_build_chat_agent_forwards_rlm_limits_to_runtime(self, monkeypatch: pytest.MonkeyPatch) -> None: + _disable_runtime_tool_discovery(monkeypatch) + + rt = build_chat_agent( + planner_lm=object(), + rlm_max_iterations=9, + rlm_max_llm_calls=11, + rlm_max_output_chars=12_345, + ) + + assert rt.rlm_max_iterations == 9 + assert rt.rlm_max_llm_calls == 11 + assert rt.rlm_max_output_chars == 12_345 diff --git a/tests/unit/runtime/test_modules.py b/tests/unit/runtime/test_modules.py index d2d290d14..7bba43b7f 100644 --- a/tests/unit/runtime/test_modules.py +++ b/tests/unit/runtime/test_modules.py @@ -27,9 +27,33 @@ def fake_create_runtime_rlm(**kwargs: Any) -> MagicMock: assert captured["max_llm_calls"] == 13 assert captured["max_output_chars"] == variable_mode.VARIABLE_MODE_MAX_OUTPUT_CHARS assert captured["tools"] == [interpreter.sub_rlm, interpreter.sub_rlm_batched] + assert captured["include_llm_tools"] is True assert module is not None +def test_variable_mode_module_can_disable_recursive_sub_tools(monkeypatch) -> None: + import dspy + + from fleet_rlm.runtime.modules import variable_mode + + interpreter = SimpleNamespace(sub_rlm=lambda prompt: prompt, sub_rlm_batched=lambda prompts: prompts) + captured: dict[str, Any] = {} + + def fake_create_runtime_rlm(**kwargs: Any) -> MagicMock: + captured.update(kwargs) + return MagicMock(spec=dspy.Module) + + monkeypatch.setattr(variable_mode, "create_runtime_rlm", fake_create_runtime_rlm) + + variable_mode.RLMVariableExecutionModule( + interpreter=interpreter, + include_sub_tools=False, + extra_tools=[], + ) + + assert captured["tools"] is None + + def test_variable_mode_forward_preserves_signature_kwargs(monkeypatch) -> None: import dspy @@ -49,6 +73,62 @@ def test_variable_mode_forward_preserves_signature_kwargs(monkeypatch) -> None: assert result.coverage_pct == 90 +def test_variable_mode_forward_scopes_disabled_semantic_callbacks(monkeypatch) -> None: + import dspy + + from fleet_rlm.runtime.modules import variable_mode + + interpreter = SimpleNamespace(semantic_callbacks_enabled=True) + observed: dict[str, Any] = {} + + def fake_create_runtime_rlm(**kwargs: Any) -> MagicMock: + observed["include_llm_tools"] = kwargs["include_llm_tools"] + + def _run(**call_kwargs: Any) -> dspy.Prediction: + observed["semantic_callbacks_enabled_during_call"] = interpreter.semantic_callbacks_enabled + observed["adapter_during_call"] = dspy.settings.adapter + observed["call_kwargs"] = call_kwargs + return dspy.Prediction(answer="done") + + return MagicMock(side_effect=_run) + + monkeypatch.setattr(variable_mode, "create_runtime_rlm", fake_create_runtime_rlm) + + module = variable_mode.RLMVariableExecutionModule( + interpreter=interpreter, + include_llm_tools=False, + ) + result = module(task="inspect document") + + assert result.answer == "done" + assert observed["include_llm_tools"] is False + assert observed["semantic_callbacks_enabled_during_call"] is False + assert isinstance(observed["adapter_during_call"], dspy.JSONAdapter) + assert interpreter.semantic_callbacks_enabled is True + + +def test_create_runtime_rlm_without_llm_tools_removes_callback_instructions() -> None: + import dspy + + from fleet_rlm.runtime.agent.signatures import RLMVariableSignature + from fleet_rlm.runtime.modules.factory import create_runtime_rlm + + rlm = create_runtime_rlm( + signature=RLMVariableSignature, + interpreter=SimpleNamespace(), + max_iterations=2, + max_llm_calls=3, + verbose=False, + include_llm_tools=False, + ) + + instructions = rlm.generate_action.signature.instructions + assert "`llm_query(prompt)`" not in instructions + assert "`llm_query_batched(prompts)`" not in instructions + assert "semantic callbacks are disabled" in instructions + assert isinstance(rlm, dspy.Module) + + def test_build_variable_mode_rlm_returns_wrapper(monkeypatch) -> None: import dspy diff --git a/tests/unit/runtime/test_skill_selection.py b/tests/unit/runtime/test_skill_selection.py index 43aeb14f3..d8515f8e9 100644 --- a/tests/unit/runtime/test_skill_selection.py +++ b/tests/unit/runtime/test_skill_selection.py @@ -7,6 +7,20 @@ from fleet_rlm.runtime.modules.skill_selection import SkillSelectionModule +def test_browser_interaction_skill_is_cataloged_and_keyword_selected() -> None: + from fleet_rlm.runtime.modules.skill_selection import AVAILABLE_SKILLS + + module = SkillSelectionModule() + module._load_skills = MagicMock(return_value="[Skill: browser-interaction]\nInstructions") + + result = module(user_request="Use playwright to inspect this javascript page") + + assert "browser-interaction" in AVAILABLE_SKILLS + assert result.selected_skills == ["browser-interaction"] + assert result.skill_context == "[Skill: browser-interaction]\nInstructions" + module._load_skills.assert_called_once_with(["browser-interaction"]) + + def test_skill_selection_no_keyword_match_skips_llm_selector() -> None: module = SkillSelectionModule() module.select = MagicMock(side_effect=AssertionError("selector should not be called")) diff --git a/tests/unit/runtime/test_tools.py b/tests/unit/runtime/test_tools.py index 8485c0f66..ca0d0e3c9 100644 --- a/tests/unit/runtime/test_tools.py +++ b/tests/unit/runtime/test_tools.py @@ -341,6 +341,58 @@ def test_chunk_document_and_load_document_helpers_use_text_and_directories(tmp_p } +def test_fetch_document_text_bundles_root_document_auxiliary_indexes( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from fleet_rlm.runtime.tools import document_tools + + reads: list[str] = [] + + def fake_read_remote_document_text(url: str) -> tuple[str, dict[str, Any]]: + reads.append(url) + if url.endswith("/llms.txt"): + return "# LLM docs index\nAPI Reference", {"source_type": "text"} + if url.endswith("/sitemap.xml"): + return "https://docs.example/api", {"source_type": "xml"} + return "# Docs home\nWelcome", {"source_type": "html"} + + monkeypatch.setattr(document_tools, "_read_remote_document_text", fake_read_remote_document_text) + + fetched = document_tools.fetch_document_text("https://docs.example/") + + assert fetched["status"] == "ok" + assert reads == [ + "https://docs.example/", + "https://docs.example/llms.txt", + "https://docs.example/sitemap.xml", + ] + assert "# Source document: https://docs.example/" in fetched["text"] + assert "# Auxiliary document: https://docs.example/llms.txt" in fetched["text"] + assert "# Auxiliary document: https://docs.example/sitemap.xml" in fetched["text"] + assert fetched["metadata"]["bundled_source_count"] == 3 + + +def test_fetch_document_text_skips_auxiliary_indexes_for_file_urls( + monkeypatch: pytest.MonkeyPatch, +) -> None: + from fleet_rlm.runtime.tools import document_tools + + reads: list[str] = [] + + def fake_read_remote_document_text(url: str) -> tuple[str, dict[str, Any]]: + reads.append(url) + return "plain text", {"source_type": "text"} + + monkeypatch.setattr(document_tools, "_read_remote_document_text", fake_read_remote_document_text) + + fetched = document_tools.fetch_document_text("https://docs.example/readme.txt") + + assert fetched["status"] == "ok" + assert reads == ["https://docs.example/readme.txt"] + assert fetched["text"] == "plain text" + assert "bundled_source_count" not in fetched["metadata"] + + def test_download_url_removes_partial_temp_file_on_size_limit( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, @@ -390,3 +442,95 @@ def fake_mkstemp(suffix: str) -> tuple[int, str]: assert created assert not created[0].exists() + + +def test_browser_fetch_page_stub_raises_runtime_error() -> None: + from fleet_rlm.runtime.tools.browser_tools import browser_fetch_page + + with pytest.raises(RuntimeError, match="browser-capable"): + browser_fetch_page("https://example.com") + + +def test_browser_fetch_page_is_discoverable() -> None: + from fleet_rlm.runtime.tools.browser_tools import browser_fetch_page + + assert getattr(browser_fetch_page, "__is_fleet_tool__", False) is True + + +def test_browser_fetch_page_in_discover_tools() -> None: + from fleet_rlm.runtime.tools.registry import discover_tools + + discover_tools.cache_clear() + tools = discover_tools() + tool_names = [getattr(t, "name", None) or getattr(t.func, "__name__", "") for t in tools] + assert "browser_fetch_page" in tool_names + + +def test_bound_browser_fetch_page_validates_public_url_before_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: + import fleet_rlm.runtime.tools.binding as binding_mod + import fleet_rlm.runtime.tools.document_tools as document_tools + from fleet_rlm.runtime.tools.binding import bind_runtime_tools + from fleet_rlm.runtime.tools.browser_tools import browser_fetch_page + + calls: list[dict[str, Any]] = [] + + def fake_getaddrinfo(*args: Any, **kwargs: Any) -> list[Any]: + _ = (args, kwargs) + return [(document_tools.socket.AF_INET, 0, 0, "", ("93.184.216.34", 443))] + + def fake_execute_sandbox_tool(interpreter: Any, code: str, variables: dict[str, Any]) -> dict[str, Any]: + _ = (interpreter, code) + calls.append(variables) + return {"status": "ok", "url": variables["target_url"]} + + monkeypatch.setattr(document_tools.socket, "getaddrinfo", fake_getaddrinfo) + monkeypatch.setattr(binding_mod, "execute_sandbox_tool", fake_execute_sandbox_tool) + + bound = bind_runtime_tools( + [browser_fetch_page], + runtime=types.SimpleNamespace(core_memory={}), + interpreter=object(), + ) + + result = getattr(bound[0], "func", bound[0])("https://example.test/docs", extract_links=True) + + assert result == {"status": "ok", "url": "https://example.test/docs"} + assert calls == [ + { + "target_url": "https://example.test/docs", + "wait_until": "networkidle", + "extract_links": True, + } + ] + + +@pytest.mark.parametrize( + "url", + [ + "http://localhost:3000", + "http://127.0.0.1:8000", + "http://10.0.0.2/docs", + "http://169.254.169.254/latest/meta-data", + ], +) +def test_bound_browser_fetch_page_rejects_private_targets( + monkeypatch: pytest.MonkeyPatch, + url: str, +) -> None: + import fleet_rlm.runtime.tools.binding as binding_mod + from fleet_rlm.runtime.tools.binding import bind_runtime_tools + from fleet_rlm.runtime.tools.browser_tools import browser_fetch_page + + def fake_execute_sandbox_tool(*args: Any, **kwargs: Any) -> dict[str, Any]: + _ = (args, kwargs) + raise AssertionError("unsafe URL should be rejected before sandbox execution") + + monkeypatch.setattr(binding_mod, "execute_sandbox_tool", fake_execute_sandbox_tool) + bound = bind_runtime_tools( + [browser_fetch_page], + runtime=types.SimpleNamespace(core_memory={}), + interpreter=object(), + ) + + with pytest.raises(ValueError, match="private network"): + getattr(bound[0], "func", bound[0])(url) From fef3e0a9649476d24d9e96407e91a1699afa7803 Mon Sep 17 00:00:00 2001 From: Zachary BENSALEM Date: Fri, 5 Jun 2026 00:45:06 +0200 Subject: [PATCH 2/7] ``` feat(phase7): align RLM recursion with reference implementation and tighten persistence contracts - Update daytona to 0.184.0 and fastapi to 0.136.3 - Add Phase 7 documentation covering history management, bounded child snapshots, token-aware compaction, explicit depth tracking, and fallback behavior - Tighten PersistenceDep type from Any to PersistenceProtocol and raise RuntimeError when backend not initialized - Remove ServerState compatibility proxy attributes (_SERVER_STATE_PROXY_ATTRS) - Move RuntimeEventContext from api/events/ --- pyproject.toml | 4 +- scripts/capture_phase0_baseline.sh | 30 + scripts/verify_phase0_regression.sh | 28 + src/fleet_rlm/AGENTS.md | 37 + src/fleet_rlm/api/dependencies.py | 52 +- src/fleet_rlm/api/events/__init__.py | 4 + src/fleet_rlm/api/events/events.py | 24 +- src/fleet_rlm/api/events/project_chat.py | 84 ++ src/fleet_rlm/api/events/project_graph.py | 158 ++ src/fleet_rlm/api/events/step_builder.py | 25 + src/fleet_rlm/api/routers/ws/artifacts.py | 2 +- .../api/routers/ws/connection_loop.py | 517 +++++++ src/fleet_rlm/api/routers/ws/endpoint.py | 13 +- src/fleet_rlm/api/routers/ws/session.py | 8 +- src/fleet_rlm/api/routers/ws/stream.py | 1290 ----------------- src/fleet_rlm/api/routers/ws/stream_events.py | 143 ++ src/fleet_rlm/api/routers/ws/stream_loop.py | 57 + .../api/routers/ws/stream_summary.py | 303 ++++ src/fleet_rlm/api/routers/ws/transport.py | 12 +- src/fleet_rlm/api/routers/ws/turn_runner.py | 395 +++++ src/fleet_rlm/api/routers/ws/turn_setup.py | 5 +- .../api/runtime_services/chat_persistence.py | 1114 +------------- .../api/runtime_services/run_lifecycle.py | 385 +++++ .../api/runtime_services/session_manifest.py | 307 ++++ .../runtime_services/session_persistence.py | 405 ++++++ .../api/runtime_services/stream_failures.py | 32 + .../integrations/daytona/concurrency.py | 4 +- src/fleet_rlm/runtime/agent/runtime.py | 210 +-- src/fleet_rlm/runtime/agent/signatures.py | 11 + src/fleet_rlm/runtime/events.py | 235 +++ src/fleet_rlm/runtime/execution/llm_query.py | 115 +- .../runtime/execution/streaming_events.py | 121 +- src/fleet_rlm/runtime/modules/escalating.py | 7 +- src/fleet_rlm/runtime/tools/rlm_delegate.py | 240 +-- tests/contracts/test_golden_payloads.py | 193 +++ tests/unit/api/test_bootstrap.py | 10 +- tests/unit/api/test_chat_persistence.py | 25 +- tests/unit/api/test_dependencies.py | 10 +- tests/unit/api/test_events.py | 4 +- tests/unit/runtime/test_escalating_module.py | 8 +- tests/unit/runtime/test_tools.py | 16 - uv.lock | 40 +- 42 files changed, 3701 insertions(+), 2982 deletions(-) create mode 100755 scripts/capture_phase0_baseline.sh create mode 100755 scripts/verify_phase0_regression.sh create mode 100644 src/fleet_rlm/api/events/project_chat.py create mode 100644 src/fleet_rlm/api/events/project_graph.py create mode 100644 src/fleet_rlm/api/routers/ws/connection_loop.py delete mode 100644 src/fleet_rlm/api/routers/ws/stream.py create mode 100644 src/fleet_rlm/api/routers/ws/stream_events.py create mode 100644 src/fleet_rlm/api/routers/ws/stream_loop.py create mode 100644 src/fleet_rlm/api/routers/ws/stream_summary.py create mode 100644 src/fleet_rlm/api/routers/ws/turn_runner.py create mode 100644 src/fleet_rlm/api/runtime_services/run_lifecycle.py create mode 100644 src/fleet_rlm/api/runtime_services/session_manifest.py create mode 100644 src/fleet_rlm/api/runtime_services/session_persistence.py create mode 100644 src/fleet_rlm/api/runtime_services/stream_failures.py create mode 100644 src/fleet_rlm/runtime/events.py create mode 100644 tests/contracts/test_golden_payloads.py diff --git a/pyproject.toml b/pyproject.toml index 463065c99..1600d6ea8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "dspy[optuna]==3.2.1", # Core orchestration and CLI/runtime ergonomics. - "daytona>=0.181.0,<1", + "daytona>=0.184.0,<1", "hydra-core>=1.3,<2", "prompt-toolkit>=3.0.50,<4", "rich>=14.3.3,<15", @@ -56,7 +56,7 @@ dependencies = [ "tomli>=2.0.0; python_version < '3.11'", # Web API surface. - "fastapi[standard]==0.136.1", + "fastapi[standard]==0.136.3", "joserfc>=1.0.1", "uvicorn>=0.47.0,<1", diff --git a/scripts/capture_phase0_baseline.sh b/scripts/capture_phase0_baseline.sh new file mode 100755 index 000000000..3c37bd3eb --- /dev/null +++ b/scripts/capture_phase0_baseline.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Phase 0 baseline capture script +# Run this to capture golden payloads and openapi.yaml baseline before refactoring + +set -e + +echo "=== Phase 0: Capturing golden payloads and openapi.yaml baseline ===" + +# Create golden payloads directory +mkdir -p tests/contracts/golden_payloads + +# Capture openapi.yaml baseline +echo "Capturing openapi.yaml baseline..." +cp openapi.yaml tests/contracts/golden_payloads/openapi_baseline.yaml + +# Capture frontend generated client baseline +echo "Capturing frontend generated client baseline..." +cp src/frontend/src/lib/rlm-api/generated/openapi.ts tests/contracts/golden_payloads/openapi_client_baseline.ts + +# Run golden payload capture tests +echo "Running golden payload capture tests..." +# Temporarily remove golden payloads directory to trigger capture +rm -rf tests/contracts/golden_payloads +uv run pytest tests/contracts/test_golden_payloads.py::test_capture_chat_websocket_golden_payloads -v +uv run pytest tests/contracts/test_golden_payloads.py::test_capture_passive_events_websocket_golden_payloads -v + +echo "=== Phase 0 baseline capture complete ===" +echo "Golden payloads saved to: tests/contracts/golden_payloads/" +echo "OpenAPI baseline saved to: tests/contracts/golden_payloads/openapi_baseline.yaml" +echo "Client baseline saved to: tests/contracts/golden_payloads/openapi_client_baseline.ts" diff --git a/scripts/verify_phase0_regression.sh b/scripts/verify_phase0_regression.sh new file mode 100755 index 000000000..811ebc6b9 --- /dev/null +++ b/scripts/verify_phase0_regression.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# Phase 0 regression verification script +# Run this after refactoring to verify against golden payload baseline + +set -e + +echo "=== Phase 0: Verifying regression against baseline ===" + +# Run regression tests +echo "Running regression tests against golden payloads..." +uv run pytest tests/contracts/test_golden_payloads.py::test_regression_chat_websocket_events -v +uv run pytest tests/contracts/test_golden_payloads.py::test_regression_passive_events_websocket_events -v + +# Compare openapi.yaml +echo "Comparing openapi.yaml against baseline..." +if ! diff -u tests/contracts/golden_payloads/openapi_baseline.yaml openapi.yaml; then + echo "WARNING: openapi.yaml has changed from baseline" + echo "Review the diff above. If changes are intentional, update the baseline." +fi + +# Compare generated client +echo "Comparing generated client against baseline..." +if ! diff -u tests/contracts/golden_payloads/openapi_client_baseline.ts src/frontend/src/lib/rlm-api/generated/openapi.ts; then + echo "WARNING: Generated client has changed from baseline" + echo "Review the diff above. If changes are intentional, update the baseline." +fi + +echo "=== Phase 0 regression verification complete ===" diff --git a/src/fleet_rlm/AGENTS.md b/src/fleet_rlm/AGENTS.md index 919cca0e1..4023dc6ff 100644 --- a/src/fleet_rlm/AGENTS.md +++ b/src/fleet_rlm/AGENTS.md @@ -242,6 +242,43 @@ Common mistakes to avoid: - Hand-editing packaged UI build output or generated OpenAPI artifacts - Treating Volumes or `/ready` semantics differently from the implemented contract +## Phase 7: RLM Recursion and History Management + +Phase 7 aligns the `dspy.RLM` path and recursion with the reference implementation, focusing on structured history management, explicit depth tracking, and token-budget-aware compaction. + +### P7.1: History as Native REPL Variable + +- `RLMVariableSignature` now includes `history: dspy.History` as an `InputField` +- `EscalatingFleetModule._run_rlm` always passes the `history` object to the RLM module +- The model can inspect full prior conversation turns with code (e.g., `history.messages[-1]`) rather than relying solely on flattened recency snippets + +### P7.2: Bounded Redacted Conversation Snapshot to Recursive Children + +- `LLMQueryMixin._execute_sub_rlm` builds a bounded, redacted conversation snapshot for child contexts +- `_build_child_history_snapshot()` extracts the last N turns (default: 2) from the parent runtime's history +- Sensitive values (API keys, tokens, passwords) are redacted using pattern-based replacement +- Snapshot size is bounded (default: 2000 chars) and truncated with a marker if exceeded +- Children receive a fresh REPL (per reference) but get explicit conversation continuity + +### P7.3: Token-Budget-Aware Compaction + +- `AgentRuntime._maybe_refresh_summary()` now compacts history based on estimated token usage +- `_estimate_history_chars()` provides a character-count proxy for token estimation (4 chars/token approximation) +- Compaction triggers when history exceeds the configured threshold (default: 70% of 64K token context window) or the turn interval is reached +- `history_max_turns` remains a hard ceiling for turn-based truncation +- New `compaction_threshold_pct` parameter controls the token-budget threshold + +### P7.4: Explicit Depth Tracking and Fallback + +- `AgentRuntime._recursion_depth_state()` returns `(depth, max_depth)` from interpreter state +- `RuntimeEventContext` includes `depth` and `max_depth` fields surfaced on runtime events +- `sub_rlm` and `sub_rlm_batched` fall back to `llm_query` and `llm_query_batched` when max recursion depth is reached +- Fallback prevents infinite recursion while preserving answer quality + +### P7.5: Benchmark Fast-Paths Removed + +- Confirmed no benchmark fast-paths exist in `rlm_delegate.py` (already removed in earlier phases) + ## Canonical Commands Backend setup and runtime: diff --git a/src/fleet_rlm/api/dependencies.py b/src/fleet_rlm/api/dependencies.py index 1275eff6d..e24316e44 100644 --- a/src/fleet_rlm/api/dependencies.py +++ b/src/fleet_rlm/api/dependencies.py @@ -12,6 +12,7 @@ from fleet_rlm.integrations.database import DatabaseManager, FleetRepository from fleet_rlm.integrations.database.repository_identity import IdentityUpsertResult +from fleet_rlm.integrations.persistence_protocol import PersistenceProtocol from fleet_rlm.utils.identity import owner_fingerprint from .auth import AuthError, AuthProvider, NormalizedIdentity, resolve_admitted_identity @@ -77,7 +78,7 @@ class PersistenceDeps: db_manager: DatabaseManager | None = None repository: FleetRepository | None = None - local_store: Any | None = None + local_store: PersistenceProtocol | None = None @dataclass @@ -106,31 +107,11 @@ class InterpreterPoolDeps: pool: Any | None = None -_SERVER_STATE_PROXY_ATTRS: dict[str, tuple[str, str]] = { - "config": ("config_deps", "config"), - "planner_lm": ("lm_deps", "planner_lm"), - "delegate_lm": ("lm_deps", "delegate_lm"), - "runtime_model_lock": ("lm_deps", "runtime_model_lock"), - "auth_provider": ("auth_deps", "auth_provider"), - "sessions": ("session_cache_deps", "sessions"), - "db_manager": ("persistence_deps", "db_manager"), - "repository": ("persistence_deps", "repository"), - "local_store": ("persistence_deps", "local_store"), - "events_event_emitter": ("diagnostics_deps", "events_event_emitter"), - "runtime_test_results": ("diagnostics_deps", "runtime_test_results"), - "optional_service_status": ("diagnostics_deps", "optional_service_status"), - "optional_service_errors": ("diagnostics_deps", "optional_service_errors"), - "mlflow_server_process": ("diagnostics_deps", "mlflow_server_process"), - "optional_startup_task": ("diagnostics_deps", "optional_startup_task"), - "interpreter_pool": ("interpreter_pool_deps", "pool"), -} - - class ServerState: """Shared server state, set during lifespan. - New code should depend on focused dependency slices. The flat attributes - remain as mapped compatibility accessors for tests and older internals. + New code should depend on focused dependency slices directly. + Use ``state.config_deps``, ``state.lm_deps``, ``state.persistence_deps``, etc. """ def __init__( @@ -156,19 +137,6 @@ def __init__( ) self.interpreter_pool_deps = interpreter_pool_deps or InterpreterPoolDeps() - def __getattr__(self, name: str) -> Any: - if target := _SERVER_STATE_PROXY_ATTRS.get(name): - deps_name, attr_name = target - return getattr(getattr(self, deps_name), attr_name) - raise AttributeError(name) - - def __setattr__(self, name: str, value: Any) -> None: - if target := _SERVER_STATE_PROXY_ATTRS.get(name): - deps_name, attr_name = target - setattr(getattr(self, deps_name), attr_name, value) - return - super().__setattr__(name, value) - @property def is_ready(self) -> bool: """Return whether critical server dependencies are ready to serve requests.""" @@ -322,19 +290,19 @@ def get_repository(request: Request) -> FleetRepository | None: RepositoryDep = Annotated[FleetRepository | None, Depends(get_repository)] -def get_persistence(request: Request) -> Any: - """Return the unified persistence backend (repository or local_store).""" +def get_persistence(request: Request) -> PersistenceProtocol: + """Return the startup-resolved persistence backend (repository or local_store).""" persistence_deps = get_persistence_deps(request) if persistence_deps.repository is not None: return persistence_deps.repository if persistence_deps.local_store is not None: return persistence_deps.local_store - from fleet_rlm.integrations.local_store import LocalStore - - return LocalStore() + raise RuntimeError( + "Persistence backend not initialized. Ensure FastAPI lifespan startup has completed before handling requests." + ) -PersistenceDep = Annotated[Any, Depends(get_persistence)] +PersistenceDep = Annotated[PersistenceProtocol, Depends(get_persistence)] def build_unauthenticated_identity( diff --git a/src/fleet_rlm/api/events/__init__.py b/src/fleet_rlm/api/events/__init__.py index 05613d832..133a1c1f1 100644 --- a/src/fleet_rlm/api/events/__init__.py +++ b/src/fleet_rlm/api/events/__init__.py @@ -14,6 +14,8 @@ sanitize_event_payload, summarize_code_for_event, ) +from .project_chat import project_chat +from .project_graph import project_graph from .step_builder import ExecutionStepBuilder __all__ = [ @@ -28,6 +30,8 @@ "ExecutionStepType", "ExecutionSubscription", "RuntimeEventContext", + "project_chat", + "project_graph", "sanitize_event_payload", "summarize_code_for_event", ] diff --git a/src/fleet_rlm/api/events/events.py b/src/fleet_rlm/api/events/events.py index 8f02017c9..08c8f3c00 100644 --- a/src/fleet_rlm/api/events/events.py +++ b/src/fleet_rlm/api/events/events.py @@ -16,6 +16,8 @@ from fastapi import WebSocket from pydantic import BaseModel, Field +from fleet_rlm.runtime.events import RuntimeEventContext + from .sanitizer import sanitize_event_payload, summarize_code_for_event logger = logging.getLogger(__name__) @@ -46,28 +48,6 @@ ] -class RuntimeEventContext(BaseModel): - """Stable runtime context attached to backend-emitted events.""" - - runtime_mode: str | None = None - execution_mode: str | None = None - execution_profile: str | None = None - sandbox_id: str | None = None - child_sandbox_id: str | None = None - volume_name: str | None = None - workspace_path: str | None = None - repo_url: str | None = None - repo_ref: str | None = None - document_path: str | None = None - depth: int | None = None - max_depth: int | None = None - actor_kind: ExecutionActorKind | None = None - actor_id: str | None = None - parent_id: str | None = None - lane_key: str | None = None - llm_call_budget: int | None = None - - class BackendEvent(BaseModel): """Canonical backend event before projection to chat or workbench streams.""" diff --git a/src/fleet_rlm/api/events/project_chat.py b/src/fleet_rlm/api/events/project_chat.py new file mode 100644 index 000000000..d04bb2967 --- /dev/null +++ b/src/fleet_rlm/api/events/project_chat.py @@ -0,0 +1,84 @@ +"""Chat-frame projector: RuntimeEvent → websocket chat payload. + +One structured event in, one deterministic dict out — no text re-parsing, +no intermediate ``BackendEvent`` hop, no timestamp-derived ``event_id``. + +Usage:: + + from fleet_rlm.api.events.project_chat import project_chat + + frame = project_chat(runtime_event, sequence=42, run_id="run-abc") +""" + +from __future__ import annotations + +from typing import Any + +from fleet_rlm.runtime.events import EVENT_SCHEMA_VERSION, RuntimeEvent, RuntimeEventKind + +_TURN_STARTED_KINDS: frozenset[RuntimeEventKind] = frozenset({RuntimeEventKind.TURN_STARTED}) +_TERMINAL_KINDS: frozenset[RuntimeEventKind] = frozenset({RuntimeEventKind.DONE, RuntimeEventKind.ERROR}) + + +def _frame_kind(event_kind: RuntimeEventKind) -> str: + if event_kind in _TURN_STARTED_KINDS: + return "execution_started" + if event_kind in _TERMINAL_KINDS: + return "execution_completed" + return "execution_step" + + +def project_chat( + event: RuntimeEvent, + *, + sequence: int = 0, + run_id: str | None = None, +) -> dict[str, Any]: + """Project one ``RuntimeEvent`` to a websocket chat frame dict. + + Args: + event: The canonical runtime event to project. + sequence: Monotonic per-turn counter, used as the ``event_id``. + run_id: Optional run identifier prefixed to ``event_id``. + + Returns: + A dict ready for ``websocket.send_json()``. + """ + payload: dict[str, Any] = dict(event.payload) + payload.setdefault("source_type", event.kind.value) + + if event.context is not None: + payload["runtime"] = event.context.model_dump(mode="json", exclude_none=True) + + if event.tool is not None and event.kind in {RuntimeEventKind.TOOL_CALL, RuntimeEventKind.TOOL_RESULT}: + payload.setdefault("tool_name", event.tool.tool_name) + if event.tool.tool_args: + payload.setdefault("tool_args", event.tool.tool_args) + if event.tool.tool_input is not None: + payload.setdefault("tool_input", event.tool.tool_input) + if event.tool.tool_output is not None: + payload.setdefault("tool_output", event.tool.tool_output) + if event.tool.step_index is not None: + payload.setdefault("step_index", event.tool.step_index) + + frame_kind = _frame_kind(event.kind) + if frame_kind == "execution_completed": + payload.setdefault( + "status", + "failed" if event.kind == RuntimeEventKind.ERROR else "completed", + ) + + event_id = f"{run_id}:{sequence}" if run_id else str(sequence) + + return { + "kind": frame_kind, + "text": event.text, + "payload": payload, + "timestamp": event.timestamp.isoformat(), + "version": EVENT_SCHEMA_VERSION, + "event_id": event_id, + "sequence": sequence, + } + + +__all__ = ["project_chat"] diff --git a/src/fleet_rlm/api/events/project_graph.py b/src/fleet_rlm/api/events/project_graph.py new file mode 100644 index 000000000..9b86eda3b --- /dev/null +++ b/src/fleet_rlm/api/events/project_graph.py @@ -0,0 +1,158 @@ +"""Execution-graph projector: RuntimeEvent → ExecutionStep. + +Reads typed fields (``event.tool``, ``event.kind``, ``event.context``) directly — +no ``_extract_tool_name`` text parsing, no ``_extract_actor_kind`` dict scraping. + +Usage:: + + from fleet_rlm.api.events.project_graph import project_graph + + step = project_graph(runtime_event, builder) +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from fleet_rlm.runtime.events import RuntimeEvent, RuntimeEventKind + +from .sanitizer import sanitize_event_payload +from .step_builder_extractors import ExecutionStepType, _derive_lane_key, _tool_step_type + +if TYPE_CHECKING: + from .step_builder import ExecutionStepBuilder + + +def _step_type_from_kind(kind: RuntimeEventKind, tool_name: str | None) -> ExecutionStepType: + if kind in {RuntimeEventKind.TOOL_CALL, RuntimeEventKind.TOOL_RESULT}: + return _tool_step_type(tool_name) + if kind in {RuntimeEventKind.DONE, RuntimeEventKind.ERROR}: + return "output" + return "llm" + + +def _label_from_event(event: RuntimeEvent) -> str | None: + kind = event.kind + if kind in {RuntimeEventKind.TOOL_CALL, RuntimeEventKind.TOOL_RESULT}: + return event.tool.tool_name if event.tool else (event.payload.get("tool_name") or event.text or kind.value) + if kind == RuntimeEventKind.DONE: + return "assistant_output" + if kind == RuntimeEventKind.ERROR: + return "error" + if kind == RuntimeEventKind.REASONING: + return event.text or "reasoning" + if kind == RuntimeEventKind.TEXT: + return "assistant_token" + if kind in {RuntimeEventKind.STATUS, RuntimeEventKind.WARNING}: + stripped = event.text.strip() + if not stripped: + return None + if stripped in {"Calling tool:", "Tool finished."}: + return None + return stripped + return kind.value + + +def _input_for_kind(event: RuntimeEvent) -> Any: + kind = event.kind + if kind == RuntimeEventKind.TOOL_CALL: + if event.tool: + return {"tool_name": event.tool.tool_name, "tool_args": event.tool.tool_args} + return dict(event.payload) + if kind == RuntimeEventKind.TOOL_RESULT: + return {"event_kind": "tool_result", "tool_name": event.tool.tool_name if event.tool else None} + if kind == RuntimeEventKind.TEXT: + return {"event_kind": "text"} + if kind in {RuntimeEventKind.STATUS, RuntimeEventKind.WARNING}: + return {"event_kind": kind.value} + if kind in {RuntimeEventKind.DONE, RuntimeEventKind.ERROR}: + return {"event_kind": kind.value} + return dict(event.payload) + + +def _output_for_kind(event: RuntimeEvent) -> Any: + kind = event.kind + if kind == RuntimeEventKind.TOOL_CALL: + return None + if kind == RuntimeEventKind.TOOL_RESULT: + if event.tool and event.tool.tool_output is not None: + return event.tool.tool_output + return dict(event.payload) + if kind == RuntimeEventKind.TEXT: + return {"text": event.text} + if kind == RuntimeEventKind.REASONING: + return {"text": event.text} + if kind in {RuntimeEventKind.DONE, RuntimeEventKind.ERROR}: + return {"text": event.text, "payload": dict(event.payload)} + return {"text": event.text} if event.text else None + + +def project_graph(event: RuntimeEvent, builder: ExecutionStepBuilder) -> Any: + """Project one ``RuntimeEvent`` to an ``ExecutionStep``. + + Step type and label are derived from typed ``RuntimeEvent`` fields — no + text parsing. The builder is responsible only for ID generation and + parent-link tracking. + + Args: + event: The canonical runtime event to project. + builder: Per-turn step builder providing IDs and parent links. + + Returns: + A constructed :class:`~fleet_rlm.api.events.events.ExecutionStep`, or + ``None`` if this event kind should produce no graph node. + """ + from .events import ExecutionStep + + label = _label_from_event(event) + if label is None: + return None + + kind = event.kind + tool_name = event.tool.tool_name if event.tool else None + step_type = _step_type_from_kind(kind, tool_name) + + ctx = event.context + depth = ctx.depth if ctx else None + actor_kind_raw = ctx.actor_kind if ctx else None + actor_id = ctx.actor_id if ctx else None + parent_id_hint = ctx.parent_id if ctx else None + + actor_kind: str = actor_kind_raw or "unknown" + if actor_kind == "unknown" and depth is None: + actor_kind = "root_rlm" + depth = 0 + + lane_key = _derive_lane_key(actor_kind, actor_id, depth) + + payload_dict: dict[str, Any] = dict(event.payload) + resolved_parent_id = parent_id_hint or builder._resolve_parent(payload_dict) + + input_payload = sanitize_event_payload(_input_for_kind(event)) + output_payload = sanitize_event_payload(_output_for_kind(event)) + + from .sanitizer import _truncate_text + + step = ExecutionStep( + id=builder._next_id(), + parent_id=resolved_parent_id, + type=step_type, + label=_truncate_text(label), + depth=depth, + actor_kind=actor_kind, + actor_id=actor_id, + lane_key=lane_key, + input=input_payload, + output=output_payload, + timestamp=event.timestamp.timestamp(), + ) + + if depth is not None: + builder._depth_parents[depth] = step.id + if kind == RuntimeEventKind.TOOL_CALL: + builder._last_tool_step_id = step.id + + return step + + +__all__ = ["project_graph"] diff --git a/src/fleet_rlm/api/events/step_builder.py b/src/fleet_rlm/api/events/step_builder.py index d7267ecb6..211c3503b 100644 --- a/src/fleet_rlm/api/events/step_builder.py +++ b/src/fleet_rlm/api/events/step_builder.py @@ -32,6 +32,7 @@ class ExecutionStepBuilder: run_id: str root_id: str = field(init=False) _counter: int = 0 + _sequence: int = 0 _last_tool_step_id: str | None = None _depth_parents: dict[int, str] = field(default_factory=dict) _repl_parent_by_hash: dict[str, str] = field(default_factory=dict) @@ -269,6 +270,30 @@ def from_stream_event( return None + def from_runtime_event(self, event: Any) -> ExecutionStep | None: + """Project a :class:`~fleet_rlm.runtime.events.RuntimeEvent` to a step. + + Uses typed ``event.tool``, ``event.kind``, and ``event.context`` fields — + no text parsing. Falls back to :meth:`from_stream_event` for events that + are not ``RuntimeEvent`` instances. + """ + from fleet_rlm.runtime.events import RuntimeEvent + + self._sequence += 1 + if not isinstance(event, RuntimeEvent): + raw_ts = getattr(event, "timestamp", None) + ts = raw_ts.timestamp() if raw_ts is not None and hasattr(raw_ts, "timestamp") else time.time() + return self.from_stream_event( + kind=str(getattr(event, "kind", "status")), + text=str(getattr(event, "text", "") or ""), + payload=dict(getattr(event, "payload", {}) or {}), + timestamp=ts, + ) + + from .project_graph import project_graph + + return project_graph(event, self) + def from_interpreter_hook(self, payload: dict[str, Any]) -> ExecutionStep | None: if not isinstance(payload, dict): return None diff --git a/src/fleet_rlm/api/routers/ws/artifacts.py b/src/fleet_rlm/api/routers/ws/artifacts.py index 5cc15521d..27fc4ab88 100644 --- a/src/fleet_rlm/api/routers/ws/artifacts.py +++ b/src/fleet_rlm/api/routers/ws/artifacts.py @@ -12,7 +12,7 @@ from fleet_rlm.utils.logging import sanitize_for_log as _sanitize_for_log from fleet_rlm.utils.time import now_iso -from ...runtime_services.chat_persistence import PersistenceRequiredError +from ...runtime_services.stream_failures import PersistenceRequiredError logger = logging.getLogger(__name__) diff --git a/src/fleet_rlm/api/routers/ws/connection_loop.py b/src/fleet_rlm/api/routers/ws/connection_loop.py new file mode 100644 index 000000000..89b19376c --- /dev/null +++ b/src/fleet_rlm/api/routers/ws/connection_loop.py @@ -0,0 +1,517 @@ +"""WebSocket connection loop: message receive/interleave and background execution. + +Owns: ``_ExecutionConnectionLoop``, message-type dispatch, +receive/stream interleaving, and the background execution task. +""" + +from __future__ import annotations + +import asyncio +import logging +from datetime import datetime, timezone +from types import SimpleNamespace +from typing import Any + +from fastapi import WebSocket, WebSocketDisconnect + +from ...dependencies import DiagnosticsDeps, SessionCacheDeps +from ...events import ExecutionEventEmitter, ExecutionSubscription +from ...runtime_services.chat_persistence import ( + build_startup_status_event, + handle_chat_disconnect, +) +from ...runtime_services.chat_runtime import ( + ChatAgentProtocol, + LocalPersistFn, + SessionContext, + set_interpreter_default_profile, +) +from ...runtime_services.chat_runtime import ( + ChatSessionState as _ChatSessionState, +) +from ...runtime_services.chat_runtime import ( + PreparedChatRuntime as _PreparedChatRuntime, +) +from ...runtime_services.session_persistence import build_local_persist_fn +from ...schemas import WSMessage +from .commands import handle_command_with_persist +from .session import switch_session_if_needed +from .stream_events import build_stream_event_dict +from .transport import ( + _error_envelope, + _try_send_json, + handle_chat_loop_exception, + parse_ws_message_or_send_error, + resolve_session_identity, +) +from .turn_runner import run_streaming_turn +from .turn_setup import prepare_chat_message_turn + +logger = logging.getLogger(__name__) + + +def _routing_status_text(payload: dict[str, Any]) -> str: + selected = ", ".join(str(item) for item in payload.get("selected_skills", []) or []) + route = payload.get("routing_decision", "auto") + source = payload.get("source_url") + text = f"Route: {route}" + if selected: + text += f" | skills: {selected}" + if source: + text += f" | source: {source}" + return text + + +def _build_routing_preview_event(agent: ChatAgentProtocol, msg: WSMessage) -> Any | None: + preview_routing = getattr(agent, "preview_routing", None) + if not callable(preview_routing): + return None + payload = preview_routing( + user_request=msg.content, + execution_mode=msg.execution_mode or "auto", + ) + if not isinstance(payload, dict) or not payload.get("routing_decision"): + return None + return SimpleNamespace( + kind="status", + text=_routing_status_text(payload), + payload={**payload, "phase": "routing"}, + timestamp=datetime.now(timezone.utc), + ) + + +def _ensure_pending_receive_task( + *, + websocket: WebSocket, + pending_receive_task: asyncio.Task[object] | None, +) -> asyncio.Task[object]: + if pending_receive_task is not None: + return pending_receive_task + return asyncio.create_task(websocket.receive_json()) + + +async def _await_message_while_streaming( + *, + websocket: WebSocket, + stream_task: asyncio.Task[str | None], + pending_receive_task: asyncio.Task[object] | None, + session: _ChatSessionState, +) -> tuple[WSMessage | None, asyncio.Task[str | None] | None, asyncio.Task[object] | None]: + pending_receive_task = _ensure_pending_receive_task( + websocket=websocket, + pending_receive_task=pending_receive_task, + ) + done, _pending = await asyncio.wait( + {stream_task, pending_receive_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + if stream_task in done: + session.last_loaded_docs_path = await stream_task + return None, None, pending_receive_task + + raw_payload = await pending_receive_task + msg = await parse_ws_message_or_send_error( + websocket=websocket, + raw_payload=raw_payload, + ) + return msg, stream_task, None + + +async def _handle_message_while_streaming( + *, + websocket: WebSocket, + msg: WSMessage, + agent: ChatAgentProtocol, + runtime: _PreparedChatRuntime, + session: _ChatSessionState, + local_persist: LocalPersistFn, +) -> bool: + if msg.type == "cancel": + session.cancel_flag["cancelled"] = True + return True + + if msg.type == "command": + await handle_command_with_persist( + websocket=websocket, + agent=agent, + payload=msg.model_dump(), + session_record=session.session_record, + persistence=runtime.persistence, + identity_rows=runtime.identity_rows, + persistence_required=runtime.persistence_required, + local_persist=local_persist, + ) + return True + + if session.lifecycle is not None and session.lifecycle.run_completed: + return False + + await _try_send_json( + websocket, + { + "type": "error", + "message": ( + "A run is already in progress. Cancel it or wait for completion before sending another message." + ), + }, + ) + return True + + +async def _receive_next_chat_message( + *, + websocket: WebSocket, + pending_message: WSMessage | None, + pending_receive_task: asyncio.Task[object] | None, +) -> tuple[WSMessage | None, asyncio.Task[object] | None]: + if pending_message is not None: + return pending_message, pending_receive_task + + if pending_receive_task is not None: + raw_payload = await pending_receive_task + pending_receive_task = None + else: + raw_payload = await websocket.receive_json() + + msg = await parse_ws_message_or_send_error( + websocket=websocket, + raw_payload=raw_payload, + ) + return msg, pending_receive_task + + +async def _handle_idle_non_turn_message( + *, + websocket: WebSocket, + msg: WSMessage, + agent: ChatAgentProtocol, + runtime: _PreparedChatRuntime, + session: _ChatSessionState, + local_persist: LocalPersistFn, +) -> bool: + if msg.type == "cancel": + session.cancel_flag["cancelled"] = True + await _try_send_json( + websocket, + _error_envelope( + code="no_active_run", + message="No active websocket run is available to cancel.", + ), + ) + return True + + if msg.type == "command": + await handle_command_with_persist( + websocket=websocket, + agent=agent, + payload=msg.model_dump(), + session_record=session.session_record, + persistence=runtime.persistence, + identity_rows=runtime.identity_rows, + persistence_required=runtime.persistence_required, + local_persist=local_persist, + ) + return True + + if msg.type != "message": + await _try_send_json( + websocket, + {"type": "error", "message": f"Unknown message type: {msg.type}"}, + ) + return True + + return False + + +async def _process_chat_message( + *, + websocket: WebSocket | None, + msg: WSMessage, + agent: ChatAgentProtocol, + interpreter: object | None, + session: _ChatSessionState, + local_persist: LocalPersistFn, + runtime: _PreparedChatRuntime, + workspace_id: str, + user_id: str, + sess_id: str, + execution_emitter: ExecutionEventEmitter, +) -> str | None: + """Process one ``message`` payload and return the loaded docs path.""" + prepared_turn = await prepare_chat_message_turn( + websocket=websocket, + msg=msg, + agent=agent, + session=session, + local_persist=local_persist, + runtime=runtime, + workspace_id=workspace_id, + user_id=user_id, + sess_id=sess_id, + execution_emitter=execution_emitter, + ) + if prepared_turn is None: + return session.last_loaded_docs_path + + def cancel_check() -> bool: + return session.cancel_flag["cancelled"] + + orchestration_session = session.orchestration_session or SessionContext( + workspace_id=workspace_id, + user_id=user_id, + session_id=sess_id, + session_record=session.session_record, + ) + session.orchestration_session = orchestration_session + + return await run_streaming_turn( + websocket=websocket, + agent=agent, + prepared_turn=prepared_turn, + orchestration_session=orchestration_session, + cancel_check=cancel_check, + interpreter=interpreter, + persist_session_state=local_persist, + execution_emitter=execution_emitter, + ) + + +async def _background_execution_task( + *, + msg: WSMessage, + session_cache: SessionCacheDeps, + runtime: _PreparedChatRuntime, + session: _ChatSessionState, + workspace_id: str, + user_id: str, + sess_id: str, + execution_emitter: ExecutionEventEmitter, +) -> str | None: + """Run execution in the background with its own agent context.""" + from ...runtime_services.chat_runtime import build_chat_agent_context + + agent_context = await build_chat_agent_context(runtime) + async with agent_context as agent: + interpreter = getattr(agent, "interpreter", None) + set_interpreter_default_profile(interpreter, runtime.cfg) + + async def _noop_persist( + *, + include_volume_save: bool = True, + latest_user_message: str = "", + ) -> None: + _ = include_volume_save, latest_user_message + + ( + session.active_key, + session.active_manifest_path, + session.session_record, + session.last_loaded_docs_path, + session.orchestration_session, + ) = await switch_session_if_needed( + session_cache=session_cache, + agent=agent, + interpreter=interpreter, + workspace_id=workspace_id, + user_id=user_id, + sess_id=sess_id, + owner_tenant_claim=session.owner_tenant_claim, + owner_user_claim=session.owner_user_claim, + active_key=None, + session_record=session.session_record, + last_loaded_docs_path=session.last_loaded_docs_path, + local_persist=_noop_persist, + persistence=runtime.persistence, + identity_rows=runtime.identity_rows, + ) + + agent._db_session_id = (session.session_record or {}).get("db_session_id") + agent._identity_rows = runtime.identity_rows + if agent.interpreter is not None: + agent.interpreter._host_repository = runtime.persistence + agent.interpreter._host_identity = runtime.identity_rows + agent.interpreter._host_run_id = None + local_persist = build_local_persist_fn( + session_cache=session_cache, + runtime=runtime, + agent=agent, + interpreter=interpreter, + session=session, + ) + + return await _process_chat_message( + websocket=None, + msg=msg, + agent=agent, + interpreter=interpreter, + session=session, + local_persist=local_persist, + runtime=runtime, + workspace_id=workspace_id, + user_id=user_id, + sess_id=sess_id, + execution_emitter=execution_emitter, + ) + + +class _ExecutionConnectionLoop: + """Connection-scoped websocket message loop for one execution socket.""" + + def __init__( + self, + *, + websocket: WebSocket, + session_cache: SessionCacheDeps, + diagnostics_deps: DiagnosticsDeps, + runtime: _PreparedChatRuntime, + agent: ChatAgentProtocol, + interpreter: object | None, + session: _ChatSessionState, + local_persist: LocalPersistFn, + initial_message: WSMessage | None = None, + ) -> None: + self.websocket = websocket + self.session_cache = session_cache + self.diagnostics_deps = diagnostics_deps + self.runtime = runtime + self.agent = agent + self.interpreter = interpreter + self.session = session + self.local_persist = local_persist + self.execution_emitter = diagnostics_deps.events_event_emitter + self.stream_task: asyncio.Task[str | None] | asyncio.Task[None] | None = None + self.pending_receive_task: asyncio.Task[object] | None = None + self.pending_message = initial_message + + async def run(self) -> None: + try: + while True: + if self.stream_task is not None: + ( + msg, + self.stream_task, + self.pending_receive_task, + ) = await _await_message_while_streaming( + websocket=self.websocket, + stream_task=self.stream_task, + pending_receive_task=self.pending_receive_task, + session=self.session, + ) + if msg is None: + continue + if self.stream_task is None: + self.pending_message = msg + continue + + if await _handle_message_while_streaming( + websocket=self.websocket, + msg=msg, + agent=self.agent, + runtime=self.runtime, + session=self.session, + local_persist=self.local_persist, + ): + continue + continue + + ( + self.pending_message, + self.pending_receive_task, + ) = await _receive_next_chat_message( + websocket=self.websocket, + pending_message=self.pending_message, + pending_receive_task=self.pending_receive_task, + ) + msg = self.pending_message + self.pending_message = None + if msg is None: + continue + + if await _handle_idle_non_turn_message( + websocket=self.websocket, + msg=msg, + agent=self.agent, + runtime=self.runtime, + session=self.session, + local_persist=self.local_persist, + ): + continue + + if not str(msg.content or "").strip(): + await _try_send_json( + self.websocket, + {"type": "error", "message": "Message content cannot be empty"}, + ) + continue + + workspace_id, user_id, sess_id = resolve_session_identity( + msg=msg, + workspace_id=self.session.canonical_workspace_id, + user_id=self.session.canonical_user_id, + ) + await self.execution_emitter.update_subscription( + self.websocket, + ExecutionSubscription( + workspace_id=workspace_id, + user_id=user_id, + session_id=sess_id, + ), + ) + startup_event = build_startup_status_event() + await _try_send_json( + self.websocket, + { + "type": "event", + "data": build_stream_event_dict( + event=startup_event, + payload=startup_event.payload, + ), + }, + ) + routing_preview_event = _build_routing_preview_event(self.agent, msg) + if routing_preview_event is not None: + await _try_send_json( + self.websocket, + { + "type": "event", + "data": build_stream_event_dict( + event=routing_preview_event, + payload=routing_preview_event.payload, + ), + }, + ) + self.stream_task = asyncio.create_task( + _background_execution_task( + msg=msg, + session_cache=self.session_cache, + runtime=self.runtime, + session=self.session, + workspace_id=workspace_id, + user_id=user_id, + sess_id=sess_id, + execution_emitter=self.execution_emitter, + ) + ) + except (asyncio.CancelledError, WebSocketDisconnect): + await handle_chat_disconnect( + pending_receive_task=self.pending_receive_task, + stream_task=self.stream_task, + cancel_flag=self.session.cancel_flag, + local_persist=self.local_persist, + lifecycle=self.session.lifecycle, + cancel_active_run=False, + persist_on_disconnect=False, + ) + except Exception as exc: + await handle_chat_loop_exception( + websocket=self.websocket, + exc=exc, + pending_receive_task=self.pending_receive_task, + stream_task=self.stream_task, + local_persist=self.local_persist, + lifecycle=self.session.lifecycle, + ) + + +__all__ = ["_ExecutionConnectionLoop"] diff --git a/src/fleet_rlm/api/routers/ws/endpoint.py b/src/fleet_rlm/api/routers/ws/endpoint.py index d28bd780c..1069caac2 100644 --- a/src/fleet_rlm/api/routers/ws/endpoint.py +++ b/src/fleet_rlm/api/routers/ws/endpoint.py @@ -35,12 +35,6 @@ get_session_cache_deps_from_websocket, ) from ...events import ExecutionSubscription -from ...runtime_services.chat_persistence import ( - build_local_persist_fn as _build_local_persist_fn, -) -from ...runtime_services.chat_persistence import ( - get_execution_emitter, -) from ...runtime_services.chat_runtime import ( PreparedChatRuntime as _PreparedChatRuntime, ) @@ -56,7 +50,8 @@ from ...runtime_services.chat_runtime import ( set_interpreter_default_profile as _set_interpreter_default_profile, ) -from .stream import _chat_message_loop +from ...runtime_services.session_persistence import build_local_persist_fn as _build_local_persist_fn +from .stream_loop import _chat_message_loop from .transport import ( _authenticate_websocket, _close_websocket_safely, @@ -267,7 +262,7 @@ async def run(self) -> None: ) # Connect to Event Bus for decoupled execution events - emitter = get_execution_emitter(self.diagnostics_deps) + emitter = self.diagnostics_deps.events_event_emitter subscription = ExecutionSubscription( workspace_id=session.canonical_workspace_id, user_id=session.canonical_user_id, @@ -321,7 +316,7 @@ async def _run_execution_subscription_stream( await _close_websocket_safely(websocket, code=1008) return - emitter = get_execution_emitter(diagnostics_deps) + emitter = diagnostics_deps.events_event_emitter await emitter.connect(websocket, subscription) try: diff --git a/src/fleet_rlm/api/routers/ws/session.py b/src/fleet_rlm/api/routers/ws/session.py index 330026f4c..a8ec53baa 100644 --- a/src/fleet_rlm/api/routers/ws/session.py +++ b/src/fleet_rlm/api/routers/ws/session.py @@ -226,10 +226,8 @@ async def switch_session_if_needed( cached: dict[str, Any] | None = session_cache.sessions.get(key) if cached is None: - from ...runtime_services.chat_persistence import ( - _restore_manifest_from_local_store, - load_manifest_from_volume, - ) + from ...runtime_services.session_manifest import load_manifest_from_volume + from ...runtime_services.session_persistence import _restore_manifest_from_local_store if interpreter is not None: manifest = await load_manifest_from_volume( @@ -289,7 +287,7 @@ async def switch_session_if_needed( if db_session_id: metadata["db_session_id"] = db_session_id if interpreter is not None: - from ...runtime_services.chat_persistence import ensure_session_volume_layout + from ...runtime_services.session_manifest import ensure_session_volume_layout try: layout_paths = await ensure_session_volume_layout( diff --git a/src/fleet_rlm/api/routers/ws/stream.py b/src/fleet_rlm/api/routers/ws/stream.py deleted file mode 100644 index 9079547fc..000000000 --- a/src/fleet_rlm/api/routers/ws/stream.py +++ /dev/null @@ -1,1290 +0,0 @@ -"""Inner streaming loop and REPL hook management for WebSocket chat.""" - -from __future__ import annotations - -import asyncio -import logging -import time -import uuid -from collections.abc import AsyncIterator, Awaitable, Callable -from dataclasses import dataclass, field -from datetime import datetime, timezone -from types import SimpleNamespace -from typing import Any - -from fastapi import WebSocket, WebSocketDisconnect - -from fleet_rlm.integrations.database import RunStatus -from fleet_rlm.integrations.observability.mlflow_context import ( - merge_trace_result_metadata as _merge_trace_result_metadata, -) -from fleet_rlm.integrations.observability.trace_context import ( - runtime_telemetry_enabled_context, -) -from fleet_rlm.runtime.execution.streaming_events import ( - _normalize_trajectory, - is_terminal_stream_event_kind, -) -from fleet_rlm.utils.logging import sanitize_for_log as _sanitize_for_log - -from ...dependencies import DiagnosticsDeps, SessionCacheDeps -from ...events import ( - ExecutionEventEmitter, - ExecutionStep, - ExecutionStepBuilder, - ExecutionSubscription, -) -from ...events.event_adapter import ( - adapt_stream_event, - build_chat_event_payload, - is_terminal_backend_event, -) -from ...runtime_services.chat_persistence import ( - ExecutionLifecycleManager, - build_local_persist_fn, - build_startup_status_event, - build_workspace_task_request, - classify_stream_failure, - enqueue_latest_nonblocking, - get_execution_emitter, - handle_chat_disconnect, - should_reload_docs_path, -) -from ...runtime_services.chat_runtime import ( - ChatAgentProtocol, - LocalPersistFn, - SessionContext, - StreamEventLike, - build_chat_agent_context, - set_interpreter_default_profile, -) -from ...runtime_services.chat_runtime import ( - ChatSessionState as _ChatSessionState, -) -from ...runtime_services.chat_runtime import ( - PreparedChatRuntime as _PreparedChatRuntime, -) -from ...schemas import WSMessage -from .commands import handle_command_with_persist -from .repl_bridge import ReplHookBridge -from .session import ( - switch_session_if_needed, -) -from .transport import ( - _error_envelope, - _try_send_json, - handle_chat_loop_exception, - parse_ws_message_or_send_error, - resolve_session_identity, -) -from .turn_setup import PreparedStreamingTurn, prepare_chat_message_turn - -logger = logging.getLogger(__name__) - - -def _routing_status_text(payload: dict[str, Any]) -> str: - selected = ", ".join(str(item) for item in payload.get("selected_skills", []) or []) - route = payload.get("routing_decision", "auto") - source = payload.get("source_url") - text = f"Route: {route}" - if selected: - text += f" | skills: {selected}" - if source: - text += f" | source: {source}" - return text - - -def _build_routing_preview_event(agent: ChatAgentProtocol, msg: WSMessage) -> Any | None: - preview_routing = getattr(agent, "preview_routing", None) - if not callable(preview_routing): - return None - payload = preview_routing( - user_request=msg.content, - execution_mode=msg.execution_mode or "auto", - ) - if not isinstance(payload, dict) or not payload.get("routing_decision"): - return None - return SimpleNamespace( - kind="status", - text=_routing_status_text(payload), - payload={**payload, "phase": "routing"}, - timestamp=datetime.now(timezone.utc), - ) - - -@dataclass(slots=True) -class WorkspaceEvent: - """Normalized event shape for websocket streaming.""" - - kind: str - text: str = "" - payload: dict[str, Any] = field(default_factory=dict) - timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - terminal: bool = False - - -@dataclass(slots=True) -class WorkspaceTaskRequest: - """Input needed to execute one workspace task end-to-end.""" - - agent: Any - message: str - execution_mode: str | None = None - trace: bool = True - docs_path: str | None = None - repo_url: str | None = None - repo_ref: str | None = None - context_paths: list[str] | None = None - batch_concurrency: int | None = None - workspace_id: str | None = None - cancel_check: Callable[[], bool] | None = None - prepare: Callable[[], Awaitable[None]] | None = None - - -def merge_trace_result_metadata( - payload: dict[str, Any] | None, - *, - response_preview: str | None = None, - trace_metadata: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Compatibility shim for MLflow final-event metadata enrichment.""" - return _merge_trace_result_metadata( - payload, - response_preview=response_preview, - trace_metadata=trace_metadata, - ) - - -def _runtime_trace_metadata(payload: dict[str, Any] | None) -> dict[str, Any]: - if not isinstance(payload, dict): - return {} - - runtime_payload = payload.get("runtime") - runtime = runtime_payload if isinstance(runtime_payload, dict) else {} - - metadata: dict[str, Any] = {} - for key in ( - "routing_decision", - "source_url", - "execution_mode", - "runtime_module", - ): - value = payload.get(key, runtime.get(key)) - if value not in (None, "", False): - metadata[f"fleet_rlm.{key}"] = value - - selected_skills = payload.get("selected_skills") - if isinstance(selected_skills, list): - metadata["fleet_rlm.selected_skills"] = ",".join(str(item) for item in selected_skills if str(item)) - - trajectory_steps = _normalize_trajectory(payload.get("trajectory")) - if trajectory_steps: - metadata["fleet_rlm.trajectory_steps"] = str(len(trajectory_steps)) - if any(step.get("thought") for step in trajectory_steps): - metadata["fleet_rlm.trajectory_has_reasoning"] = "true" - if any(step.get("tool_name") for step in trajectory_steps): - metadata["fleet_rlm.trajectory_has_tools"] = "true" - if any( - "repl" in str(step.get("tool_name", "")).lower() - or "code" in step - or step.get("type") == "repl" - for step in trajectory_steps - ): - metadata["fleet_rlm.trajectory_has_repl"] = "true" - if any(step.get("output") is not None or step.get("observation") is not None for step in trajectory_steps): - metadata["fleet_rlm.trajectory_has_outputs"] = "true" - - for key in ( - "runtime_degraded", - "runtime_failure_category", - "runtime_failure_phase", - "runtime_fallback_used", - ): - value = payload.get(key, runtime.get(key)) - if value in (None, "", False): - if key in {"runtime_degraded", "runtime_fallback_used"} and value is False: - metadata[key] = False - continue - metadata[key] = value - return metadata - - -def build_stream_event_dict( - *, - event: StreamEventLike, - payload: Any, -) -> dict[str, Any]: - """Serialize one stream event for websocket delivery.""" - backend_event = adapt_stream_event( - kind=event.kind, - text=event.text, - payload=payload if isinstance(payload, dict) else None, - timestamp=event.timestamp, - ) - event_dict = build_chat_event_payload(backend_event) - event_dict.setdefault("event_id", uuid.uuid4().hex) - return event_dict - - -def _terminal_run_status(event: StreamEventLike) -> RunStatus: - """Return the authoritative terminal run status for one event.""" - if event.kind == "done" and (isinstance(event.payload, dict) and event.payload.get("cancelled")): - return RunStatus.CANCELLED - if event.kind == "done": - payload = event.payload if isinstance(event.payload, dict) else {} - return RunStatus.FAILED if final_event_failed(payload) else RunStatus.COMPLETED - return RunStatus.FAILED - - -async def handle_terminal_stream_event( - *, - websocket: WebSocket | None, - lifecycle: ExecutionLifecycleManager, - event: StreamEventLike, - event_dict: dict[str, Any], - step: ExecutionStep | None, - persist_session_state: LocalPersistFn, - request_message: str, - orchestration_session: SessionContext | None = None, -) -> None: - """Handle terminal websocket events: persist, complete lifecycle, send. - - ``orchestration_session`` is retained for API compatibility but the - simplified architecture has no HITL/checkpoint logic. - """ - summary = build_execution_completion_summary( - event=event, - request_message=request_message, - run_id=lifecycle.run_id, - ) - - if event.kind == "done": - try: - await persist_session_state(include_volume_save=True, release_idle_session=True) - except Exception: - logger.debug( - "Failed to persist session state before final event; continuing", - exc_info=True, - ) - await lifecycle.complete_run( - _terminal_run_status(event), - step=step, - summary=summary, - ) - return - - try: - await persist_session_state(include_volume_save=True, release_idle_session=True) - except Exception: - logger.debug( - "Failed to persist session state after %s event; completing run anyway", - event.kind, - exc_info=True, - ) - - error_json: dict[str, Any] | None = {"error": event.text, "kind": event.kind} if event.kind == "error" else None - await lifecycle.complete_run( - _terminal_run_status(event), - step=step, - error_json=error_json, - summary=summary, - ) - - -def _as_record(value: Any) -> dict[str, Any]: - return value if isinstance(value, dict) else {} - - -def _as_text(value: Any) -> str | None: - if isinstance(value, str): - trimmed = value.strip() - return trimmed or None - return None - - -def _normalize_text_list(value: Any) -> list[str]: - if not isinstance(value, list): - return [] - return [item for item in (_as_text(entry) for entry in value) if item is not None] - - -def final_event_failed(payload: dict[str, Any]) -> bool: - runtime = _as_record(payload.get("runtime")) - runtime_degraded = bool(payload.get("runtime_degraded", runtime.get("runtime_degraded", False))) - category = _as_text(payload.get("runtime_failure_category") or runtime.get("runtime_failure_category")) - return runtime_degraded and category == "tool_execution_error" - - -def _extract_human_review_payload(payload: dict[str, Any]) -> dict[str, Any] | None: - raw = _as_record(payload.get("human_review")) - if raw: - required = raw.get("required") - if required is False: - return None - return { - "required": True, - "reason": _as_text(raw.get("reason")) or "Recursive repair requested human review before continuing.", - "repair_mode": _as_text(raw.get("repair_mode")), - "repair_target": _as_text(raw.get("repair_target")), - "repair_steps": _normalize_text_list(raw.get("repair_steps")), - } - - recursive_repair = _as_record(payload.get("recursive_repair")) - if _as_text(recursive_repair.get("repair_mode")) != "needs_human_review": - return None - - normalized_steps = _normalize_text_list(recursive_repair.get("repair_steps")) - return { - "required": True, - "reason": _as_text(payload.get("final_reasoning")) - or _as_text(recursive_repair.get("repair_rationale")) - or _as_text(recursive_repair.get("repair_target")) - or "Recursive repair requested human review before continuing.", - "repair_mode": "needs_human_review", - "repair_target": _as_text(recursive_repair.get("repair_target")), - "repair_steps": normalized_steps, - } - - -def _canonical_run_status( - kind: str, - payload: dict[str, Any], - *, - human_review_required: bool, -) -> str: - if kind == "done": - # A "done" event with payload["cancelled"]=True is a cancelled turn. - if isinstance(payload, dict) and payload.get("cancelled"): - return "cancelled" - if human_review_required: - return "needs_human_review" - return "error" if final_event_failed(payload) else "completed" - return "error" - - -def _build_fallback_final_artifact(event: StreamEventLike) -> dict[str, Any] | None: - if event.kind != "done": - return None - return { - "kind": "assistant_response", - "value": { - "text": event.text, - "final_markdown": event.text, - "summary": event.text, - }, - "finalization_mode": "RETURN", - } - - -def _build_minimum_summary( - *, - event: StreamEventLike, - summary_payload: dict[str, Any], - warnings: list[Any], - human_review: dict[str, Any] | None, - termination_reason: str, -) -> dict[str, Any]: - error_text = event.text if event.kind == "error" else None - summary = { - "termination_reason": termination_reason, - "duration_ms": summary_payload.get("duration_ms"), - "warnings": warnings, - "error": error_text, - } - if human_review is not None: - summary["human_review"] = human_review - return summary - - -def _resolve_terminal_status( - *, - existing_status: Any, - terminal_status: str, -) -> str: - normalized = _as_text(existing_status) - if terminal_status in {"needs_human_review", "error", "cancelled"}: - return terminal_status - return normalized or terminal_status - - -def _resolve_termination_reason( - *, - existing_reason: Any, - event_kind: str, - human_review_required: bool, -) -> str: - normalized = _as_text(existing_reason) - if human_review_required and normalized in {None, "", "done", "completed"}: - return "needs_human_review" - return normalized or event_kind - - -def build_execution_completion_summary( - *, - event: StreamEventLike, - request_message: str, - run_id: str, -) -> dict[str, Any]: - """Build the canonical execution summary payload from a terminal event.""" - payload = _as_record(event.payload) - runtime = _as_record(payload.get("runtime")) - run_result = _as_record(payload.get("run_result")) - summary_payload = _as_record(payload.get("summary")) - payload_final_artifact = _as_record(payload.get("final_artifact")) - human_review = _extract_human_review_payload(payload) - runtime_mode = ( - _as_text(payload.get("runtime_mode")) - or _as_text(runtime.get("runtime_mode")) - or _as_text(run_result.get("runtime_mode")) - or "daytona_pilot" - ) - terminal_status = _canonical_run_status( - event.kind, - payload, - human_review_required=human_review is not None, - ) - resolved_termination_reason = _resolve_termination_reason( - existing_reason=run_result.get("termination_reason") or summary_payload.get("termination_reason"), - event_kind=event.kind, - human_review_required=human_review is not None, - ) - warnings = list(summary_payload.get("warnings") or payload.get("guardrail_warnings") or []) - minimum_summary = _build_minimum_summary( - event=event, - summary_payload=summary_payload, - warnings=warnings, - human_review=human_review, - termination_reason=resolved_termination_reason, - ) - - if run_result: - normalized = dict(run_result) - normalized.setdefault("run_id", run_result.get("run_id") or runtime.get("run_id") or run_id) - normalized.setdefault("runtime_mode", runtime_mode) - normalized.setdefault("task", run_result.get("task") or request_message) - normalized["status"] = _resolve_terminal_status( - existing_status=run_result.get("status"), - terminal_status=terminal_status, - ) - normalized["termination_reason"] = resolved_termination_reason - normalized.setdefault("duration_ms", summary_payload.get("duration_ms")) - normalized.setdefault("warnings", warnings) - nested_summary = _as_record(normalized.get("summary")) - nested_summary = {**minimum_summary, **nested_summary} - if summary_payload: - nested_summary = {**nested_summary, **summary_payload} - nested_summary["termination_reason"] = resolved_termination_reason - if warnings and not nested_summary.get("warnings"): - nested_summary["warnings"] = warnings - if human_review is not None: - normalized["human_review"] = human_review - nested_summary["human_review"] = human_review - normalized["summary"] = nested_summary - normalized.setdefault( - "final_artifact", - payload_final_artifact or _build_fallback_final_artifact(event), - ) - return normalized - - final_artifact = payload_final_artifact or _build_fallback_final_artifact(event) - - return { - "run_id": _as_text(runtime.get("run_id")) or run_id, - "runtime_mode": runtime_mode, - "task": request_message, - "status": terminal_status, - "termination_reason": resolved_termination_reason, - "duration_ms": summary_payload.get("duration_ms"), - "iterations": [], - "callbacks": [], - "prompts": [], - "context_sources": [], - "sources": list(payload.get("sources") or []), - "attachments": list(payload.get("attachments") or []), - "final_artifact": final_artifact, - "summary": minimum_summary, - "warnings": warnings, - **({"human_review": human_review} if human_review is not None else {}), - } - - -async def handle_stream_error( - *, - websocket: WebSocket | None, - lifecycle: ExecutionLifecycleManager, - step_builder: ExecutionStepBuilder, - exc: Exception, - request_message: str, -) -> None: - """Log, emit, and persist a failed websocket streaming turn.""" - error_code = classify_stream_failure(exc) - logger.error( - "Streaming error: %s", - _sanitize_for_log(exc), - exc_info=True, - extra={ - "error_type": type(exc).__name__, - "error_code": error_code, - }, - ) - if websocket is not None: - await _try_send_json( - websocket, - _error_envelope( - code=error_code, - message=f"Streaming error: {exc}", - details={"error_type": type(exc).__name__}, - ), - ) - if lifecycle.run_completed: - return - - error_text = f"Streaming error: {exc}" - error_payload = { - "error_type": type(exc).__name__, - "error_code": error_code, - } - error_step = step_builder.from_stream_event( - kind="error", - text=error_text, - payload=error_payload, - timestamp=time.time(), - ) - if error_step is not None: - await lifecycle.emit_step(error_step) - await lifecycle.complete_run( - RunStatus.FAILED, - step=error_step, - error_json={ - "error": str(exc), - "error_type": type(exc).__name__, - "code": error_code, - }, - summary=build_execution_completion_summary( - event=WorkspaceEvent( - kind="error", - text=error_text, - payload=error_payload, - terminal=True, - ), - request_message=request_message, - run_id=lifecycle.run_id, - ), - ) - - -def _is_terminal_transport_event(event: StreamEventLike) -> bool: - """Return websocket-terminal semantics for worker and runtime events.""" - - backend_event = adapt_stream_event( - kind=event.kind, - text=event.text, - payload=event.payload if isinstance(event.payload, dict) else None, - timestamp=event.timestamp, - ) - return bool(getattr(event, "terminal", False)) or is_terminal_backend_event(backend_event) - - -def _build_agent_stream_kwargs(request: WorkspaceTaskRequest) -> dict[str, Any]: - """Build canonical runtime stream kwargs from a workspace task request.""" - kwargs: dict[str, Any] = { - "message": request.message, - "trace": request.trace, - "cancel_check": request.cancel_check, - "docs_path": request.docs_path, - } - if request.repo_url is not None: - kwargs["repo_url"] = request.repo_url - if request.repo_ref is not None: - kwargs["repo_ref"] = request.repo_ref - if request.context_paths is not None: - kwargs["context_paths"] = list(request.context_paths) - if request.batch_concurrency is not None: - kwargs["batch_concurrency"] = request.batch_concurrency - if request.workspace_id is not None: - kwargs["volume_name"] = request.workspace_id - return kwargs - - -def _to_workspace_event(event: Any) -> WorkspaceEvent: - """Normalize a runtime-style stream event into a workspace event.""" - raw_ts = getattr(event, "timestamp", None) - timestamp = raw_ts if isinstance(raw_ts, datetime) else datetime.now(timezone.utc) - return WorkspaceEvent( - kind=str(getattr(event, "kind", "status")), - text=str(getattr(event, "text", "") or ""), - payload=dict(getattr(event, "payload", {}) or {}), - timestamp=timestamp, - terminal=is_terminal_stream_event_kind(str(getattr(event, "kind", ""))), - ) - - -async def stream_agent_turn( - request: WorkspaceTaskRequest, -) -> AsyncIterator[WorkspaceEvent]: - """Stream one workspace task directly through the agent without HITL wrapper.""" - if request.execution_mode is not None: - request.agent.set_execution_mode(request.execution_mode) - if request.prepare is not None: - await request.prepare() - async for runtime_event in request.agent.aiter_chat_turn_stream(**_build_agent_stream_kwargs(request)): - yield _to_workspace_event(runtime_event) - - -async def run_streaming_turn( - *, - websocket: WebSocket | None, - agent: ChatAgentProtocol, - prepared_turn: PreparedStreamingTurn, - orchestration_session: SessionContext | None, - cancel_check: Callable[[], bool], - interpreter: object | None, - persist_session_state: LocalPersistFn, - execution_emitter: ExecutionEventEmitter, -) -> str | None: - """Execute one streaming turn, emitting events and persisting lifecycle steps.""" - - lifecycle = prepared_turn.lifecycle - step_builder = prepared_turn.step_builder - if interpreter is not None and hasattr(lifecycle, "active_run_db_id"): - # interpreter is typed as `object | None` here; use setattr so ty - # accepts the dynamic attribute (declared on DaytonaInterpreter). - setattr(interpreter, "_host_run_id", lifecycle.active_run_db_id) - await lifecycle.emit_started() - ws_loop = asyncio.get_running_loop() - repl_hook_bridge = ReplHookBridge( - ws_loop=ws_loop, - lifecycle=lifecycle, - step_builder=step_builder, - interpreter=interpreter, - enqueue_nonblocking=enqueue_latest_nonblocking, - ) - - last_loaded_docs_path = prepared_turn.last_loaded_docs_path - if should_reload_docs_path(last_loaded_docs_path, prepared_turn.docs_path): - agent.load_document(str(prepared_turn.docs_path)) - last_loaded_docs_path = str(prepared_turn.docs_path).strip() - - try: - - async def _stream_body() -> None: - await _stream_agent_events( - websocket=websocket, - agent=agent, - prepared_turn=prepared_turn, - orchestration_session=orchestration_session, - cancel_check=cancel_check, - lifecycle=lifecycle, - hosted_repl_bridge=repl_hook_bridge, - step_builder=step_builder, - analytics_enabled=prepared_turn.analytics_enabled, - persist_session_state=persist_session_state, - execution_emitter=execution_emitter, - ) - - await _run_prepared_stream( - mlflow_trace_context=prepared_turn.mlflow_trace_context, - stream_body=_stream_body, - ) - except WebSocketDisconnect: - raise - except Exception as exc: - try: - await persist_session_state( - include_volume_save=True, - allow_volume_session_create=False, - release_idle_session=True, - ) - except Exception: - logger.debug("Failed to persist session state after stream exception", exc_info=True) - await handle_stream_error( - websocket=websocket, - lifecycle=lifecycle, - step_builder=step_builder, - exc=exc, - request_message=prepared_turn.message, - ) - - return last_loaded_docs_path - - -async def _run_prepared_stream( - *, - mlflow_trace_context: Any | None, - stream_body: Callable[[], Awaitable[None]], -) -> None: - if mlflow_trace_context is None: - await stream_body() - return - - from fleet_rlm.integrations.observability.mlflow_runtime import ( - mlflow_request_context, - ) - - with mlflow_request_context(mlflow_trace_context): - await stream_body() - - -async def _stream_agent_events( - *, - websocket: WebSocket | None, - agent: ChatAgentProtocol, - prepared_turn: PreparedStreamingTurn, - orchestration_session: SessionContext | None, - cancel_check: Callable[[], bool], - lifecycle: ExecutionLifecycleManager, - hosted_repl_bridge: ReplHookBridge | None, - step_builder: ExecutionStepBuilder, - analytics_enabled: bool | None, - persist_session_state: LocalPersistFn, - execution_emitter: ExecutionEventEmitter, -) -> None: - worker_request = build_workspace_task_request( - agent=agent, - prepared_turn=prepared_turn, - cancel_check=cancel_check, - ) - - bridge_started = False - try: - if hosted_repl_bridge is not None: - hosted_repl_bridge.start() - bridge_started = True - - with runtime_telemetry_enabled_context(analytics_enabled): - async for worker_event in stream_agent_turn(worker_request): - await _emit_stream_event( - websocket=websocket, - lifecycle=lifecycle, - step_builder=step_builder, - event=worker_event, - orchestration_session=orchestration_session, - persist_session_state=persist_session_state, - request_message=prepared_turn.message, - execution_emitter=execution_emitter, - ) - finally: - if hosted_repl_bridge is not None and bridge_started: - try: - await hosted_repl_bridge.stop() - except Exception: - pass - - if not lifecycle.run_completed: - lifecycle.raise_if_persistence_error() - await lifecycle.complete_run(RunStatus.COMPLETED) - - -async def _emit_stream_event( - *, - websocket: WebSocket | None, - lifecycle: ExecutionLifecycleManager, - step_builder: ExecutionStepBuilder, - event: WorkspaceEvent | StreamEventLike, - orchestration_session: SessionContext | None = None, - persist_session_state: LocalPersistFn, - request_message: str, - execution_emitter: ExecutionEventEmitter, -) -> None: - lifecycle.raise_if_persistence_error() - payload = event.payload - if event.kind == "done": - payload = merge_trace_result_metadata( - payload if isinstance(payload, dict) else None, - response_preview=event.text, - trace_metadata=_runtime_trace_metadata(payload if isinstance(payload, dict) else None), - ) - event_dict = build_stream_event_dict(event=event, payload=payload) - is_terminal_event = _is_terminal_transport_event(event) - - # We NO LONGER send raw event_dicts via the websocket directly. - # Instead, we rely entirely on the ExecutionEventEmitter (via lifecycle) - # which emits typed ExecutionEvent payloads. - - event_timestamp = event.timestamp.timestamp() - step = step_builder.from_stream_event( - kind=event.kind, - text=event.text, - payload=payload, - timestamp=event_timestamp, - ) - if step is not None: - if event.kind == "text": - await lifecycle.emit_step(step) - else: - await asyncio.gather( - lifecycle.emit_step(step), - lifecycle.persist_step(step), - ) - lifecycle.raise_if_persistence_error() - - if is_terminal_event: - await handle_terminal_stream_event( - websocket=websocket, - lifecycle=lifecycle, - event=event, - event_dict=event_dict, - step=step, - orchestration_session=orchestration_session, - persist_session_state=persist_session_state, - request_message=request_message, - ) - - -async def _process_chat_message( - *, - websocket: WebSocket | None, - msg: WSMessage, - agent: ChatAgentProtocol, - interpreter: object | None, - session: _ChatSessionState, - local_persist: LocalPersistFn, - runtime: _PreparedChatRuntime, - workspace_id: str, - user_id: str, - sess_id: str, - execution_emitter: ExecutionEventEmitter, -) -> str | None: - """Process one ``message`` payload and return the loaded docs path.""" - prepared_turn = await prepare_chat_message_turn( - websocket=websocket, - msg=msg, - agent=agent, - session=session, - local_persist=local_persist, - runtime=runtime, - workspace_id=workspace_id, - user_id=user_id, - sess_id=sess_id, - execution_emitter=execution_emitter, - ) - if prepared_turn is None: - return session.last_loaded_docs_path - - def cancel_check() -> bool: - return session.cancel_flag["cancelled"] - - orchestration_session = session.orchestration_session or SessionContext( - workspace_id=workspace_id, - user_id=user_id, - session_id=sess_id, - session_record=session.session_record, - ) - session.orchestration_session = orchestration_session - - return await run_streaming_turn( - websocket=websocket, - agent=agent, - prepared_turn=prepared_turn, - orchestration_session=orchestration_session, - cancel_check=cancel_check, - interpreter=interpreter, - persist_session_state=local_persist, - execution_emitter=execution_emitter, - ) - - -def _ensure_pending_receive_task( - *, - websocket: WebSocket, - pending_receive_task: asyncio.Task[object] | None, -) -> asyncio.Task[object]: - if pending_receive_task is not None: - return pending_receive_task - return asyncio.create_task(websocket.receive_json()) - - -async def _await_message_while_streaming( - *, - websocket: WebSocket, - stream_task: asyncio.Task[str | None], - pending_receive_task: asyncio.Task[object] | None, - session: _ChatSessionState, -) -> tuple[WSMessage | None, asyncio.Task[str | None] | None, asyncio.Task[object] | None]: - pending_receive_task = _ensure_pending_receive_task( - websocket=websocket, - pending_receive_task=pending_receive_task, - ) - done, _pending = await asyncio.wait( - {stream_task, pending_receive_task}, - return_when=asyncio.FIRST_COMPLETED, - ) - - if stream_task in done: - session.last_loaded_docs_path = await stream_task - return None, None, pending_receive_task - - raw_payload = await pending_receive_task - msg = await parse_ws_message_or_send_error( - websocket=websocket, - raw_payload=raw_payload, - ) - return msg, stream_task, None - - -async def _background_execution_task( - *, - msg: WSMessage, - session_cache: SessionCacheDeps, - runtime: _PreparedChatRuntime, - session: _ChatSessionState, - workspace_id: str, - user_id: str, - sess_id: str, - execution_emitter: ExecutionEventEmitter, -) -> str | None: - """Run execution in the background with its own agent context.""" - agent_context = await build_chat_agent_context(runtime) - async with agent_context as agent: - interpreter = getattr(agent, "interpreter", None) - set_interpreter_default_profile(interpreter, runtime.cfg) - - async def _noop_persist( - *, - include_volume_save: bool = True, - latest_user_message: str = "", - ) -> None: - _ = include_volume_save, latest_user_message - - ( - session.active_key, - session.active_manifest_path, - session.session_record, - session.last_loaded_docs_path, - session.orchestration_session, - ) = await switch_session_if_needed( - session_cache=session_cache, - agent=agent, - interpreter=interpreter, - workspace_id=workspace_id, - user_id=user_id, - sess_id=sess_id, - owner_tenant_claim=session.owner_tenant_claim, - owner_user_claim=session.owner_user_claim, - active_key=None, - session_record=session.session_record, - last_loaded_docs_path=session.last_loaded_docs_path, - local_persist=_noop_persist, - persistence=runtime.persistence, - identity_rows=runtime.identity_rows, - ) - - agent._db_session_id = (session.session_record or {}).get("db_session_id") - agent._identity_rows = runtime.identity_rows - if agent.interpreter is not None: - agent.interpreter._host_repository = runtime.persistence - agent.interpreter._host_identity = runtime.identity_rows - agent.interpreter._host_run_id = None - local_persist = build_local_persist_fn( - session_cache=session_cache, - runtime=runtime, - agent=agent, - interpreter=interpreter, - session=session, - ) - - # Execute - return await _process_chat_message( - websocket=None, # Decoupled - msg=msg, - agent=agent, - interpreter=interpreter, - session=session, - local_persist=local_persist, - runtime=runtime, - workspace_id=workspace_id, - user_id=user_id, - sess_id=sess_id, - execution_emitter=execution_emitter, - ) - - -async def _handle_message_while_streaming( - *, - websocket: WebSocket, - msg: WSMessage, - agent: ChatAgentProtocol, - runtime: _PreparedChatRuntime, - session: _ChatSessionState, - local_persist: LocalPersistFn, -) -> bool: - if msg.type == "cancel": - session.cancel_flag["cancelled"] = True - return True - - if msg.type == "command": - await handle_command_with_persist( - websocket=websocket, - agent=agent, - payload=msg.model_dump(), - session_record=session.session_record, - persistence=runtime.persistence, - identity_rows=runtime.identity_rows, - persistence_required=runtime.persistence_required, - local_persist=local_persist, - ) - return True - - if session.lifecycle is not None and session.lifecycle.run_completed: - return False - - await _try_send_json( - websocket, - { - "type": "error", - "message": ( - "A run is already in progress. Cancel it or wait for completion before sending another message." - ), - }, - ) - return True - - -async def _receive_next_chat_message( - *, - websocket: WebSocket, - pending_message: WSMessage | None, - pending_receive_task: asyncio.Task[object] | None, -) -> tuple[WSMessage | None, asyncio.Task[object] | None]: - if pending_message is not None: - return pending_message, pending_receive_task - - if pending_receive_task is not None: - raw_payload = await pending_receive_task - pending_receive_task = None - else: - raw_payload = await websocket.receive_json() - - msg = await parse_ws_message_or_send_error( - websocket=websocket, - raw_payload=raw_payload, - ) - return msg, pending_receive_task - - -async def _handle_idle_non_turn_message( - *, - websocket: WebSocket, - msg: WSMessage, - agent: ChatAgentProtocol, - runtime: _PreparedChatRuntime, - session: _ChatSessionState, - local_persist: LocalPersistFn, -) -> bool: - if msg.type == "cancel": - session.cancel_flag["cancelled"] = True - await _try_send_json( - websocket, - _error_envelope( - code="no_active_run", - message="No active websocket run is available to cancel.", - ), - ) - return True - - if msg.type == "command": - await handle_command_with_persist( - websocket=websocket, - agent=agent, - payload=msg.model_dump(), - session_record=session.session_record, - persistence=runtime.persistence, - identity_rows=runtime.identity_rows, - persistence_required=runtime.persistence_required, - local_persist=local_persist, - ) - return True - - if msg.type != "message": - await _try_send_json( - websocket, - {"type": "error", "message": f"Unknown message type: {msg.type}"}, - ) - return True - - return False - - -class _ExecutionConnectionLoop: - """Connection-scoped websocket message loop for one execution socket.""" - - def __init__( - self, - *, - websocket: WebSocket, - session_cache: SessionCacheDeps, - diagnostics_deps: DiagnosticsDeps, - runtime: _PreparedChatRuntime, - agent: ChatAgentProtocol, - interpreter: object | None, - session: _ChatSessionState, - local_persist: LocalPersistFn, - initial_message: WSMessage | None = None, - ) -> None: - self.websocket = websocket - self.session_cache = session_cache - self.diagnostics_deps = diagnostics_deps - self.runtime = runtime - self.agent = agent - self.interpreter = interpreter - self.session = session - self.local_persist = local_persist - self.execution_emitter = get_execution_emitter(diagnostics_deps) - self.stream_task: asyncio.Task[str | None] | asyncio.Task[None] | None = None - self.pending_receive_task: asyncio.Task[object] | None = None - self.pending_message = initial_message - - async def run(self) -> None: - try: - while True: - if self.stream_task is not None: - ( - msg, - self.stream_task, - self.pending_receive_task, - ) = await _await_message_while_streaming( - websocket=self.websocket, - stream_task=self.stream_task, - pending_receive_task=self.pending_receive_task, - session=self.session, - ) - if msg is None: - continue - if self.stream_task is None: - self.pending_message = msg - continue - - if await _handle_message_while_streaming( - websocket=self.websocket, - msg=msg, - agent=self.agent, - runtime=self.runtime, - session=self.session, - local_persist=self.local_persist, - ): - continue - continue - - ( - self.pending_message, - self.pending_receive_task, - ) = await _receive_next_chat_message( - websocket=self.websocket, - pending_message=self.pending_message, - pending_receive_task=self.pending_receive_task, - ) - msg = self.pending_message - self.pending_message = None - if msg is None: - continue - - if await _handle_idle_non_turn_message( - websocket=self.websocket, - msg=msg, - agent=self.agent, - runtime=self.runtime, - session=self.session, - local_persist=self.local_persist, - ): - continue - - if not str(msg.content or "").strip(): - await _try_send_json( - self.websocket, - {"type": "error", "message": "Message content cannot be empty"}, - ) - continue - - workspace_id, user_id, sess_id = resolve_session_identity( - msg=msg, - workspace_id=self.session.canonical_workspace_id, - user_id=self.session.canonical_user_id, - ) - await self.execution_emitter.update_subscription( - self.websocket, - ExecutionSubscription( - workspace_id=workspace_id, - user_id=user_id, - session_id=sess_id, - ), - ) - startup_event = build_startup_status_event() - await _try_send_json( - self.websocket, - { - "type": "event", - "data": build_stream_event_dict( - event=startup_event, - payload=startup_event.payload, - ), - }, - ) - routing_preview_event = _build_routing_preview_event(self.agent, msg) - if routing_preview_event is not None: - await _try_send_json( - self.websocket, - { - "type": "event", - "data": build_stream_event_dict( - event=routing_preview_event, - payload=routing_preview_event.payload, - ), - }, - ) - self.stream_task = asyncio.create_task( - _background_execution_task( - msg=msg, - session_cache=self.session_cache, - runtime=self.runtime, - session=self.session, - workspace_id=workspace_id, - user_id=user_id, - sess_id=sess_id, - execution_emitter=self.execution_emitter, - ) - ) - except (asyncio.CancelledError, WebSocketDisconnect): - await handle_chat_disconnect( - pending_receive_task=self.pending_receive_task, - stream_task=self.stream_task, - cancel_flag=self.session.cancel_flag, - local_persist=self.local_persist, - lifecycle=self.session.lifecycle, - cancel_active_run=False, - persist_on_disconnect=False, - ) - except Exception as exc: - await handle_chat_loop_exception( - websocket=self.websocket, - exc=exc, - pending_receive_task=self.pending_receive_task, - stream_task=self.stream_task, - local_persist=self.local_persist, - lifecycle=self.session.lifecycle, - ) - - -async def _chat_message_loop( - *, - websocket: WebSocket, - session_cache: SessionCacheDeps, - diagnostics_deps: DiagnosticsDeps, - runtime: _PreparedChatRuntime, - agent: ChatAgentProtocol, - interpreter: object | None, - session: _ChatSessionState, - local_persist: LocalPersistFn, - initial_message: WSMessage | None = None, -) -> None: - loop = _ExecutionConnectionLoop( - websocket=websocket, - session_cache=session_cache, - diagnostics_deps=diagnostics_deps, - runtime=runtime, - agent=agent, - interpreter=interpreter, - session=session, - local_persist=local_persist, - initial_message=initial_message, - ) - await loop.run() diff --git a/src/fleet_rlm/api/routers/ws/stream_events.py b/src/fleet_rlm/api/routers/ws/stream_events.py new file mode 100644 index 000000000..39a8b702e --- /dev/null +++ b/src/fleet_rlm/api/routers/ws/stream_events.py @@ -0,0 +1,143 @@ +"""Event dataclasses and serialization for WebSocket streaming.""" + +from __future__ import annotations + +import uuid +from collections.abc import AsyncIterator, Callable +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any + +from fleet_rlm.api.events.event_adapter import adapt_stream_event, build_chat_event_payload, is_terminal_backend_event +from fleet_rlm.api.events.project_chat import project_chat +from fleet_rlm.runtime.events import RuntimeEvent +from fleet_rlm.runtime.execution.streaming_events import is_terminal_stream_event_kind + +from ...runtime_services.chat_runtime import StreamEventLike + + +@dataclass(slots=True) +class WorkspaceEvent: + """Normalized event shape for websocket streaming.""" + + kind: str + text: str = "" + payload: dict[str, Any] = field(default_factory=dict) + timestamp: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + terminal: bool = False + + +@dataclass(slots=True) +class WorkspaceTaskRequest: + """Input needed to execute one workspace task end-to-end.""" + + agent: Any + message: str + execution_mode: str | None = None + trace: bool = True + docs_path: str | None = None + repo_url: str | None = None + repo_ref: str | None = None + context_paths: list[str] | None = None + batch_concurrency: int | None = None + workspace_id: str | None = None + cancel_check: Callable[[], bool] | None = None + prepare: Callable[[], Any] | None = None + + +def build_stream_event_dict( + *, + event: StreamEventLike, + payload: Any, + sequence: int = 0, + run_id: str | None = None, +) -> dict[str, Any]: + """Serialize one stream event for websocket delivery. + + Uses the typed :func:`~fleet_rlm.api.events.project_chat.project_chat` + projector when *event* is a :class:`~fleet_rlm.runtime.events.RuntimeEvent`, + falling back to the legacy ``adapt_stream_event`` path for plain + ``WorkspaceEvent`` / ``StreamEventLike`` objects. + """ + if isinstance(event, RuntimeEvent): + return project_chat(event, sequence=sequence, run_id=run_id) + backend_event = adapt_stream_event( + kind=event.kind, + text=event.text, + payload=payload if isinstance(payload, dict) else None, + timestamp=event.timestamp, + ) + event_dict = build_chat_event_payload(backend_event) + event_dict.setdefault("event_id", uuid.uuid4().hex) + return event_dict + + +def _is_terminal_transport_event(event: StreamEventLike) -> bool: + """Return websocket-terminal semantics for worker and runtime events.""" + if isinstance(event, RuntimeEvent): + return event.kind.is_terminal() + return bool(getattr(event, "terminal", False)) or is_terminal_backend_event( + adapt_stream_event( + kind=event.kind, + text=event.text, + payload=event.payload if isinstance(event.payload, dict) else None, + timestamp=event.timestamp, + ) + ) + + +def _build_agent_stream_kwargs(request: WorkspaceTaskRequest) -> dict[str, Any]: + """Build canonical runtime stream kwargs from a workspace task request.""" + kwargs: dict[str, Any] = { + "message": request.message, + "trace": request.trace, + "cancel_check": request.cancel_check, + "docs_path": request.docs_path, + } + if request.repo_url is not None: + kwargs["repo_url"] = request.repo_url + if request.repo_ref is not None: + kwargs["repo_ref"] = request.repo_ref + if request.context_paths is not None: + kwargs["context_paths"] = list(request.context_paths) + if request.batch_concurrency is not None: + kwargs["batch_concurrency"] = request.batch_concurrency + if request.workspace_id is not None: + kwargs["volume_name"] = request.workspace_id + return kwargs + + +def _to_workspace_event(event: Any) -> WorkspaceEvent: + """Normalize a runtime-style stream event into a workspace event.""" + raw_ts = getattr(event, "timestamp", None) + timestamp = raw_ts if isinstance(raw_ts, datetime) else datetime.now(timezone.utc) + return WorkspaceEvent( + kind=str(getattr(event, "kind", "status")), + text=str(getattr(event, "text", "") or ""), + payload=dict(getattr(event, "payload", {}) or {}), + timestamp=timestamp, + terminal=is_terminal_stream_event_kind(str(getattr(event, "kind", ""))), + ) + + +async def stream_agent_turn( + request: WorkspaceTaskRequest, +) -> AsyncIterator[WorkspaceEvent]: + """Stream one workspace task directly through the agent without HITL wrapper.""" + if request.execution_mode is not None: + request.agent.set_execution_mode(request.execution_mode) + if request.prepare is not None: + await request.prepare() + async for runtime_event in request.agent.aiter_chat_turn_stream(**_build_agent_stream_kwargs(request)): + yield _to_workspace_event(runtime_event) + + +__all__ = [ + "WorkspaceEvent", + "WorkspaceTaskRequest", + "build_stream_event_dict", + "_is_terminal_transport_event", + "_build_agent_stream_kwargs", + "_to_workspace_event", + "stream_agent_turn", +] diff --git a/src/fleet_rlm/api/routers/ws/stream_loop.py b/src/fleet_rlm/api/routers/ws/stream_loop.py new file mode 100644 index 000000000..048772fcf --- /dev/null +++ b/src/fleet_rlm/api/routers/ws/stream_loop.py @@ -0,0 +1,57 @@ +"""WebSocket streaming loop — thin wiring module. + +Execution logic lives in :mod:`.turn_runner`. +Connection loop lives in :mod:`.connection_loop`. +""" + +from __future__ import annotations + +from fastapi import WebSocket + +from ...dependencies import DiagnosticsDeps, SessionCacheDeps +from ...runtime_services.chat_runtime import ( + ChatAgentProtocol, + LocalPersistFn, +) +from ...runtime_services.chat_runtime import ( + ChatSessionState as _ChatSessionState, +) +from ...runtime_services.chat_runtime import ( + PreparedChatRuntime as _PreparedChatRuntime, +) +from ...schemas import WSMessage +from .connection_loop import _ExecutionConnectionLoop +from .turn_runner import handle_stream_error, handle_terminal_stream_event, run_streaming_turn + + +async def _chat_message_loop( + *, + websocket: WebSocket, + session_cache: SessionCacheDeps, + diagnostics_deps: DiagnosticsDeps, + runtime: _PreparedChatRuntime, + agent: ChatAgentProtocol, + interpreter: object | None, + session: _ChatSessionState, + local_persist: LocalPersistFn, + initial_message: WSMessage | None = None, +) -> None: + await _ExecutionConnectionLoop( + websocket=websocket, + session_cache=session_cache, + diagnostics_deps=diagnostics_deps, + runtime=runtime, + agent=agent, + interpreter=interpreter, + session=session, + local_persist=local_persist, + initial_message=initial_message, + ).run() + + +__all__ = [ + "_chat_message_loop", + "handle_terminal_stream_event", + "handle_stream_error", + "run_streaming_turn", +] diff --git a/src/fleet_rlm/api/routers/ws/stream_summary.py b/src/fleet_rlm/api/routers/ws/stream_summary.py new file mode 100644 index 000000000..6d846c922 --- /dev/null +++ b/src/fleet_rlm/api/routers/ws/stream_summary.py @@ -0,0 +1,303 @@ +"""Execution completion summary and MLflow metadata for WebSocket streaming.""" + +from __future__ import annotations + +from typing import Any + +from fleet_rlm.integrations.observability.mlflow_context import ( + merge_trace_result_metadata as _merge_trace_result_metadata, +) +from fleet_rlm.runtime.execution.streaming_events import _normalize_trajectory + +from ...runtime_services.chat_runtime import StreamEventLike + + +def merge_trace_result_metadata( + payload: dict[str, Any] | None, + *, + response_preview: str | None = None, + trace_metadata: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Compatibility shim for MLflow final-event metadata enrichment.""" + return _merge_trace_result_metadata( + payload, + response_preview=response_preview, + trace_metadata=trace_metadata, + ) + + +def _runtime_trace_metadata(payload: dict[str, Any] | None) -> dict[str, Any]: + if not isinstance(payload, dict): + return {} + + runtime_payload = payload.get("runtime") + runtime = runtime_payload if isinstance(runtime_payload, dict) else {} + + metadata: dict[str, Any] = {} + for key in ( + "routing_decision", + "source_url", + "execution_mode", + "runtime_module", + ): + value = payload.get(key, runtime.get(key)) + if value not in (None, "", False): + metadata[f"fleet_rlm.{key}"] = value + + selected_skills = payload.get("selected_skills") + if isinstance(selected_skills, list): + metadata["fleet_rlm.selected_skills"] = ",".join(str(item) for item in selected_skills if str(item)) + + trajectory_steps = _normalize_trajectory(payload.get("trajectory")) + if trajectory_steps: + metadata["fleet_rlm.trajectory_steps"] = str(len(trajectory_steps)) + if any(step.get("thought") for step in trajectory_steps): + metadata["fleet_rlm.trajectory_has_reasoning"] = "true" + if any(step.get("tool_name") for step in trajectory_steps): + metadata["fleet_rlm.trajectory_has_tools"] = "true" + if any( + "repl" in str(step.get("tool_name", "")).lower() or "code" in step or step.get("type") == "repl" + for step in trajectory_steps + ): + metadata["fleet_rlm.trajectory_has_repl"] = "true" + if any(step.get("output") is not None or step.get("observation") is not None for step in trajectory_steps): + metadata["fleet_rlm.trajectory_has_outputs"] = "true" + + for key in ( + "runtime_degraded", + "runtime_failure_category", + "runtime_failure_phase", + "runtime_fallback_used", + ): + value = payload.get(key, runtime.get(key)) + if value in (None, "", False): + if key in {"runtime_degraded", "runtime_fallback_used"} and value is False: + metadata[key] = False + continue + metadata[key] = value + return metadata + + +def _as_record(value: Any) -> dict[str, Any]: + return value if isinstance(value, dict) else {} + + +def _as_text(value: Any) -> str | None: + if isinstance(value, str): + trimmed = value.strip() + return trimmed or None + return None + + +def _normalize_text_list(value: Any) -> list[str]: + if not isinstance(value, list): + return [] + return [item for item in (_as_text(entry) for entry in value) if item is not None] + + +def final_event_failed(payload: dict[str, Any]) -> bool: + runtime = _as_record(payload.get("runtime")) + runtime_degraded = bool(payload.get("runtime_degraded", runtime.get("runtime_degraded", False))) + category = _as_text(payload.get("runtime_failure_category") or runtime.get("runtime_failure_category")) + return runtime_degraded and category == "tool_execution_error" + + +def _extract_human_review_payload(payload: dict[str, Any]) -> dict[str, Any] | None: + raw = _as_record(payload.get("human_review")) + if raw: + required = raw.get("required") + if required is False: + return None + return { + "required": True, + "reason": _as_text(raw.get("reason")) or "Recursive repair requested human review before continuing.", + "repair_mode": _as_text(raw.get("repair_mode")), + "repair_target": _as_text(raw.get("repair_target")), + "repair_steps": _normalize_text_list(raw.get("repair_steps")), + } + + recursive_repair = _as_record(payload.get("recursive_repair")) + if _as_text(recursive_repair.get("repair_mode")) != "needs_human_review": + return None + + normalized_steps = _normalize_text_list(recursive_repair.get("repair_steps")) + return { + "required": True, + "reason": _as_text(payload.get("final_reasoning")) + or _as_text(recursive_repair.get("repair_rationale")) + or _as_text(recursive_repair.get("repair_target")) + or "Recursive repair requested human review before continuing.", + "repair_mode": "needs_human_review", + "repair_target": _as_text(recursive_repair.get("repair_target")), + "repair_steps": normalized_steps, + } + + +def _canonical_run_status( + kind: str, + payload: dict[str, Any], + *, + human_review_required: bool, +) -> str: + if kind == "done": + # A "done" event with payload["cancelled"]=True is a cancelled turn. + if isinstance(payload, dict) and payload.get("cancelled"): + return "cancelled" + if human_review_required: + return "needs_human_review" + return "error" if final_event_failed(payload) else "completed" + return "error" + + +def _build_fallback_final_artifact(event: StreamEventLike) -> dict[str, Any] | None: + if event.kind != "done": + return None + return { + "kind": "assistant_response", + "value": { + "text": event.text, + "final_markdown": event.text, + "summary": event.text, + }, + "finalization_mode": "RETURN", + } + + +def _build_minimum_summary( + *, + event: StreamEventLike, + summary_payload: dict[str, Any], + warnings: list[Any], + human_review: dict[str, Any] | None, + termination_reason: str, +) -> dict[str, Any]: + error_text = event.text if event.kind == "error" else None + summary = { + "termination_reason": termination_reason, + "duration_ms": summary_payload.get("duration_ms"), + "warnings": warnings, + "error": error_text, + } + if human_review is not None: + summary["human_review"] = human_review + return summary + + +def _resolve_terminal_status( + *, + existing_status: Any, + terminal_status: str, +) -> str: + normalized = _as_text(existing_status) + if terminal_status in {"needs_human_review", "error", "cancelled"}: + return terminal_status + return normalized or terminal_status + + +def _resolve_termination_reason( + *, + existing_reason: Any, + event_kind: str, + human_review_required: bool, +) -> str: + normalized = _as_text(existing_reason) + if human_review_required and normalized in {None, "", "done", "completed"}: + return "needs_human_review" + return normalized or event_kind + + +def build_execution_completion_summary( + *, + event: StreamEventLike, + request_message: str, + run_id: str, +) -> dict[str, Any]: + """Build the canonical execution summary payload from a terminal event.""" + payload = _as_record(event.payload) + runtime = _as_record(payload.get("runtime")) + run_result = _as_record(payload.get("run_result")) + summary_payload = _as_record(payload.get("summary")) + payload_final_artifact = _as_record(payload.get("final_artifact")) + human_review = _extract_human_review_payload(payload) + runtime_mode = ( + _as_text(payload.get("runtime_mode")) + or _as_text(runtime.get("runtime_mode")) + or _as_text(run_result.get("runtime_mode")) + or "daytona_pilot" + ) + terminal_status = _canonical_run_status( + event.kind, + payload, + human_review_required=human_review is not None, + ) + resolved_termination_reason = _resolve_termination_reason( + existing_reason=run_result.get("termination_reason") or summary_payload.get("termination_reason"), + event_kind=event.kind, + human_review_required=human_review is not None, + ) + warnings = list(summary_payload.get("warnings") or payload.get("guardrail_warnings") or []) + minimum_summary = _build_minimum_summary( + event=event, + summary_payload=summary_payload, + warnings=warnings, + human_review=human_review, + termination_reason=resolved_termination_reason, + ) + + if run_result: + normalized = dict(run_result) + normalized.setdefault("run_id", run_result.get("run_id") or runtime.get("run_id") or run_id) + normalized.setdefault("runtime_mode", runtime_mode) + normalized.setdefault("task", run_result.get("task") or request_message) + normalized["status"] = _resolve_terminal_status( + existing_status=run_result.get("status"), + terminal_status=terminal_status, + ) + normalized["termination_reason"] = resolved_termination_reason + normalized.setdefault("duration_ms", summary_payload.get("duration_ms")) + normalized.setdefault("warnings", warnings) + nested_summary = _as_record(normalized.get("summary")) + nested_summary = {**minimum_summary, **nested_summary} + if summary_payload: + nested_summary = {**nested_summary, **summary_payload} + nested_summary["termination_reason"] = resolved_termination_reason + if warnings and not nested_summary.get("warnings"): + nested_summary["warnings"] = warnings + if human_review is not None: + normalized["human_review"] = human_review + nested_summary["human_review"] = human_review + normalized["summary"] = nested_summary + normalized.setdefault( + "final_artifact", + payload_final_artifact or _build_fallback_final_artifact(event), + ) + return normalized + + final_artifact = payload_final_artifact or _build_fallback_final_artifact(event) + + return { + "run_id": _as_text(runtime.get("run_id")) or run_id, + "runtime_mode": runtime_mode, + "task": request_message, + "status": terminal_status, + "termination_reason": resolved_termination_reason, + "duration_ms": summary_payload.get("duration_ms"), + "iterations": [], + "callbacks": [], + "prompts": [], + "context_sources": [], + "sources": list(payload.get("sources") or []), + "attachments": list(payload.get("attachments") or []), + "final_artifact": final_artifact, + "summary": minimum_summary, + "warnings": warnings, + **({"human_review": human_review} if human_review is not None else {}), + } + + +__all__ = [ + "merge_trace_result_metadata", + "_runtime_trace_metadata", + "final_event_failed", + "build_execution_completion_summary", +] diff --git a/src/fleet_rlm/api/routers/ws/transport.py b/src/fleet_rlm/api/routers/ws/transport.py index 70d6eab2e..7864037c7 100644 --- a/src/fleet_rlm/api/routers/ws/transport.py +++ b/src/fleet_rlm/api/routers/ws/transport.py @@ -20,10 +20,8 @@ from ...auth import AuthError from ...dependencies import AuthDeps, ConfigDeps, build_unauthenticated_identity -from ...runtime_services.chat_persistence import ( - ExecutionLifecycleManager, - classify_stream_failure, -) +from ...runtime_services.run_lifecycle import ExecutionLifecycleManager +from ...runtime_services.stream_failures import classify_stream_failure from ...schemas import WSMessage logger = logging.getLogger(__name__) @@ -228,10 +226,8 @@ async def handle_chat_loop_exception( lifecycle: ExecutionLifecycleManager | None, ) -> None: """Handle an unexpected outer-loop failure without losing client notification.""" - from fleet_rlm.api.runtime_services.chat_persistence import ( - PersistenceRequiredError, - cancel_task, - ) + from fleet_rlm.api.runtime_services.chat_persistence import cancel_task + from fleet_rlm.api.runtime_services.stream_failures import PersistenceRequiredError await cancel_task(pending_receive_task) await cancel_task(stream_task) diff --git a/src/fleet_rlm/api/routers/ws/turn_runner.py b/src/fleet_rlm/api/routers/ws/turn_runner.py new file mode 100644 index 000000000..0cbe829e1 --- /dev/null +++ b/src/fleet_rlm/api/routers/ws/turn_runner.py @@ -0,0 +1,395 @@ +"""Streaming turn execution: agent events → lifecycle steps → persistence. + +Owns the hot path from ``aiter_chat_turn_stream`` to ``lifecycle.complete_run``. +Pure execution functions; no connection/receive logic. +""" + +from __future__ import annotations + +import asyncio +import logging +from collections.abc import Awaitable, Callable +from typing import Any + +from fastapi import WebSocket, WebSocketDisconnect + +from fleet_rlm.integrations.database import RunStatus +from fleet_rlm.integrations.observability.trace_context import ( + runtime_telemetry_enabled_context, +) +from fleet_rlm.runtime.events import RuntimeEvent +from fleet_rlm.utils.logging import sanitize_for_log as _sanitize_for_log + +from ...events import ( + ExecutionEventEmitter, + ExecutionStep, + ExecutionStepBuilder, +) +from ...runtime_services.chat_persistence import ( + enqueue_latest_nonblocking, + should_reload_docs_path, +) +from ...runtime_services.chat_runtime import ( + ChatAgentProtocol, + LocalPersistFn, + SessionContext, + StreamEventLike, +) +from ...runtime_services.run_lifecycle import ExecutionLifecycleManager +from ...runtime_services.stream_failures import classify_stream_failure +from .repl_bridge import ReplHookBridge +from .stream_events import ( + WorkspaceEvent, + _is_terminal_transport_event, + build_stream_event_dict, + stream_agent_turn, +) +from .stream_summary import ( + _runtime_trace_metadata, + build_execution_completion_summary, + final_event_failed, + merge_trace_result_metadata, +) +from .transport import _error_envelope, _try_send_json +from .turn_setup import PreparedStreamingTurn + +logger = logging.getLogger(__name__) + + +def _terminal_run_status(event: StreamEventLike) -> RunStatus: + """Return the authoritative terminal run status for one event.""" + if event.kind == "done" and (isinstance(event.payload, dict) and event.payload.get("cancelled")): + return RunStatus.CANCELLED + if event.kind == "done": + payload = event.payload if isinstance(event.payload, dict) else {} + return RunStatus.FAILED if final_event_failed(payload) else RunStatus.COMPLETED + return RunStatus.FAILED + + +async def handle_terminal_stream_event( + *, + websocket: WebSocket | None, + lifecycle: ExecutionLifecycleManager, + event: StreamEventLike, + event_dict: dict[str, Any], + step: ExecutionStep | None, + persist_session_state: LocalPersistFn, + request_message: str, + orchestration_session: SessionContext | None = None, +) -> None: + """Handle terminal websocket events: persist, complete lifecycle, send.""" + summary = build_execution_completion_summary( + event=event, + request_message=request_message, + run_id=lifecycle.run_id, + ) + + if event.kind == "done": + try: + await persist_session_state(include_volume_save=True, release_idle_session=True) + except Exception: + logger.debug( + "Failed to persist session state before final event; continuing", + exc_info=True, + ) + await lifecycle.complete_run( + _terminal_run_status(event), + step=step, + summary=summary, + ) + return + + try: + await persist_session_state(include_volume_save=True, release_idle_session=True) + except Exception: + logger.debug( + "Failed to persist session state after %s event; completing run anyway", + event.kind, + exc_info=True, + ) + + error_json: dict[str, Any] | None = {"error": event.text, "kind": event.kind} if event.kind == "error" else None + await lifecycle.complete_run( + _terminal_run_status(event), + step=step, + error_json=error_json, + summary=summary, + ) + + +async def handle_stream_error( + *, + websocket: WebSocket | None, + lifecycle: ExecutionLifecycleManager, + step_builder: ExecutionStepBuilder, + exc: Exception, + request_message: str, +) -> None: + """Log, emit, and persist a failed websocket streaming turn.""" + import time + + error_code = classify_stream_failure(exc) + logger.error( + "Streaming error: %s", + _sanitize_for_log(exc), + exc_info=True, + extra={ + "error_type": type(exc).__name__, + "error_code": error_code, + }, + ) + if websocket is not None: + await _try_send_json( + websocket, + _error_envelope( + code=error_code, + message=f"Streaming error: {exc}", + details={"error_type": type(exc).__name__}, + ), + ) + if lifecycle.run_completed: + return + + error_text = f"Streaming error: {exc}" + error_payload = { + "error_type": type(exc).__name__, + "error_code": error_code, + } + error_step = step_builder.from_stream_event( + kind="error", + text=error_text, + payload=error_payload, + timestamp=time.time(), + ) + if error_step is not None: + await lifecycle.emit_step(error_step) + await lifecycle.complete_run( + RunStatus.FAILED, + step=error_step, + error_json={ + "error": str(exc), + "error_type": type(exc).__name__, + "code": error_code, + }, + summary=build_execution_completion_summary( + event=WorkspaceEvent( + kind="error", + text=error_text, + payload=error_payload, + terminal=True, + ), + request_message=request_message, + run_id=lifecycle.run_id, + ), + ) + + +async def _emit_stream_event( + *, + websocket: WebSocket | None, + lifecycle: ExecutionLifecycleManager, + step_builder: ExecutionStepBuilder, + event: WorkspaceEvent | StreamEventLike, + orchestration_session: SessionContext | None = None, + persist_session_state: LocalPersistFn, + request_message: str, + execution_emitter: ExecutionEventEmitter, +) -> None: + lifecycle.raise_if_persistence_error() + payload = event.payload + if event.kind in {"done", "error"}: + payload = merge_trace_result_metadata( + payload if isinstance(payload, dict) else None, + response_preview=event.text, + trace_metadata=_runtime_trace_metadata(payload if isinstance(payload, dict) else None), + ) + event_dict = build_stream_event_dict( + event=event, + payload=payload, + sequence=step_builder._sequence, + run_id=lifecycle.run_id, + ) + is_terminal_event = _is_terminal_transport_event(event) + + if isinstance(event, RuntimeEvent): + step = step_builder.from_runtime_event(event) + else: + step = step_builder.from_stream_event( + kind=event.kind, + text=event.text, + payload=payload, + timestamp=event.timestamp.timestamp(), + ) + + if step is not None: + if event.kind == "text": + await lifecycle.emit_step(step) + else: + await asyncio.gather( + lifecycle.emit_step(step), + lifecycle.persist_step(step), + ) + lifecycle.raise_if_persistence_error() + + if is_terminal_event: + await handle_terminal_stream_event( + websocket=websocket, + lifecycle=lifecycle, + event=event, + event_dict=event_dict, + step=step, + orchestration_session=orchestration_session, + persist_session_state=persist_session_state, + request_message=request_message, + ) + + +async def _run_prepared_stream( + *, + mlflow_trace_context: Any | None, + stream_body: Callable[[], Awaitable[None]], +) -> None: + if mlflow_trace_context is None: + await stream_body() + return + + from fleet_rlm.integrations.observability.mlflow_runtime import ( + mlflow_request_context, + ) + + with mlflow_request_context(mlflow_trace_context): + await stream_body() + + +async def _stream_agent_events( + *, + websocket: WebSocket | None, + agent: ChatAgentProtocol, + prepared_turn: PreparedStreamingTurn, + orchestration_session: SessionContext | None, + cancel_check: Callable[[], bool], + lifecycle: ExecutionLifecycleManager, + hosted_repl_bridge: ReplHookBridge | None, + step_builder: ExecutionStepBuilder, + analytics_enabled: bool | None, + persist_session_state: LocalPersistFn, + execution_emitter: ExecutionEventEmitter, +) -> None: + from ...runtime_services.chat_persistence import build_workspace_task_request + + worker_request = build_workspace_task_request( + agent=agent, + prepared_turn=prepared_turn, + cancel_check=cancel_check, + ) + + bridge_started = False + try: + if hosted_repl_bridge is not None: + hosted_repl_bridge.start() + bridge_started = True + + with runtime_telemetry_enabled_context(analytics_enabled): + async for worker_event in stream_agent_turn(worker_request): + await _emit_stream_event( + websocket=websocket, + lifecycle=lifecycle, + step_builder=step_builder, + event=worker_event, + orchestration_session=orchestration_session, + persist_session_state=persist_session_state, + request_message=prepared_turn.message, + execution_emitter=execution_emitter, + ) + finally: + if hosted_repl_bridge is not None and bridge_started: + try: + await hosted_repl_bridge.stop() + except Exception: + pass + + if not lifecycle.run_completed: + lifecycle.raise_if_persistence_error() + await lifecycle.complete_run(RunStatus.COMPLETED) + + +async def run_streaming_turn( + *, + websocket: WebSocket | None, + agent: ChatAgentProtocol, + prepared_turn: PreparedStreamingTurn, + orchestration_session: SessionContext | None, + cancel_check: Callable[[], bool], + interpreter: object | None, + persist_session_state: LocalPersistFn, + execution_emitter: ExecutionEventEmitter, +) -> str | None: + """Execute one streaming turn, emitting events and persisting lifecycle steps.""" + lifecycle = prepared_turn.lifecycle + step_builder = prepared_turn.step_builder + if interpreter is not None and hasattr(lifecycle, "active_run_db_id"): + setattr(interpreter, "_host_run_id", lifecycle.active_run_db_id) + await lifecycle.emit_started() + ws_loop = asyncio.get_running_loop() + repl_hook_bridge = ReplHookBridge( + ws_loop=ws_loop, + lifecycle=lifecycle, + step_builder=step_builder, + interpreter=interpreter, + enqueue_nonblocking=enqueue_latest_nonblocking, + ) + + last_loaded_docs_path = prepared_turn.last_loaded_docs_path + if should_reload_docs_path(last_loaded_docs_path, prepared_turn.docs_path): + agent.load_document(str(prepared_turn.docs_path)) + last_loaded_docs_path = str(prepared_turn.docs_path).strip() + + try: + + async def _stream_body() -> None: + await _stream_agent_events( + websocket=websocket, + agent=agent, + prepared_turn=prepared_turn, + orchestration_session=orchestration_session, + cancel_check=cancel_check, + lifecycle=lifecycle, + hosted_repl_bridge=repl_hook_bridge, + step_builder=step_builder, + analytics_enabled=prepared_turn.analytics_enabled, + persist_session_state=persist_session_state, + execution_emitter=execution_emitter, + ) + + await _run_prepared_stream( + mlflow_trace_context=prepared_turn.mlflow_trace_context, + stream_body=_stream_body, + ) + except WebSocketDisconnect: + raise + except Exception as exc: + try: + await persist_session_state( + include_volume_save=True, + allow_volume_session_create=False, + release_idle_session=True, + ) + except Exception: + logger.debug("Failed to persist session state after stream exception", exc_info=True) + await handle_stream_error( + websocket=websocket, + lifecycle=lifecycle, + step_builder=step_builder, + exc=exc, + request_message=prepared_turn.message, + ) + + return last_loaded_docs_path + + +__all__ = [ + "run_streaming_turn", + "handle_terminal_stream_event", + "handle_stream_error", + "_emit_stream_event", +] diff --git a/src/fleet_rlm/api/routers/ws/turn_setup.py b/src/fleet_rlm/api/routers/ws/turn_setup.py index e86317710..970f27d7f 100644 --- a/src/fleet_rlm/api/routers/ws/turn_setup.py +++ b/src/fleet_rlm/api/routers/ws/turn_setup.py @@ -12,10 +12,6 @@ from fleet_rlm.utils.sandbox_ownership import sandbox_owner_labels from ...events import ExecutionEventEmitter, ExecutionStepBuilder -from ...runtime_services.chat_persistence import ( - ExecutionLifecycleManager, - initialize_turn_lifecycle, -) from ...runtime_services.chat_runtime import ( ChatAgentProtocol, LocalPersistFn, @@ -27,6 +23,7 @@ from ...runtime_services.chat_runtime import ( PreparedChatRuntime as _PreparedChatRuntime, ) +from ...runtime_services.run_lifecycle import ExecutionLifecycleManager, initialize_turn_lifecycle from ...schemas import WSMessage from .transport import _try_send_json diff --git a/src/fleet_rlm/api/runtime_services/chat_persistence.py b/src/fleet_rlm/api/runtime_services/chat_persistence.py index 4f9a9c4c1..738867d59 100644 --- a/src/fleet_rlm/api/runtime_services/chat_persistence.py +++ b/src/fleet_rlm/api/runtime_services/chat_persistence.py @@ -1,156 +1,28 @@ """WebSocket chat run/session persistence orchestration for runtime services. -Consolidates lifecycle helpers, execution-event support, manifest I/O, -and session persistence into a single module. +Owns: startup-status events, task-control helpers, loop-exit handler, + and worker-request builder. """ from __future__ import annotations import asyncio -import json import logging -import posixpath -import shlex -import uuid from collections.abc import Awaitable, Callable from contextlib import suppress from datetime import datetime, timezone -from pathlib import PurePosixPath from types import SimpleNamespace from typing import Any from fastapi import WebSocketDisconnect -from fleet_rlm.api.events import ( - ExecutionEvent, - ExecutionEventEmitter, - ExecutionEventType, - ExecutionStep, - ExecutionStepBuilder, -) -from fleet_rlm.api.runtime_services.common import ( - parse_model_identity, - resolve_sandbox_provider, -) -from fleet_rlm.api.runtime_services.session_paths import ( - session_conversation_path, - session_scratchpad_path, - session_workspace_link_path, -) -from fleet_rlm.integrations.database import ( - FleetRepository, - MemoryKind, - MemoryScope, - MemorySource, - RunStatus, - RunStepType, -) -from fleet_rlm.integrations.database.repository_chat import ( - RunCreateRequest, - RunStepCreateRequest, -) -from fleet_rlm.integrations.database.repository_identity import IdentityUpsertResult -from fleet_rlm.integrations.database.repository_memory import MemoryItemCreateRequest -from fleet_rlm.runtime.execution.interpreter_protocol import ExecutionProfile -from fleet_rlm.utils.identity import sanitize_id as _sanitize_id +from fleet_rlm.api.runtime_services.run_lifecycle import ExecutionLifecycleManager +from fleet_rlm.api.runtime_services.stream_failures import PersistenceRequiredError +from fleet_rlm.integrations.database import RunStatus from fleet_rlm.utils.logging import sanitize_for_log as _sanitize_for_log -from fleet_rlm.utils.time import now_iso - -from ..dependencies import ConfigDeps, DiagnosticsDeps, SessionCacheDeps logger = logging.getLogger(__name__) -# --------------------------------------------------------------------------- -# Failure classification -# --------------------------------------------------------------------------- - - -class PersistenceRequiredError(RuntimeError): - """Raised when durable writes fail in strict-persistence mode.""" - - def __init__(self, code: str, message: str) -> None: - super().__init__(message) - self.code = code - self.message = message - - -def classify_stream_failure(exc: Exception) -> str: - """Map runtime failures to stable websocket-facing error codes.""" - if isinstance(exc, PersistenceRequiredError): - return exc.code - - lowered = str(exc).lower() - if "planner lm not configured" in lowered: - return "planner_missing" - if "llm call timed out" in lowered or "timed out" in lowered and "llm" in lowered: - return "llm_timeout" - if "rate limit" in lowered or "429" in lowered: - return "llm_rate_limited" - if "sandbox" in lowered or "daytona" in lowered: - return "sandbox_unavailable" - return "internal_error" - - -# --------------------------------------------------------------------------- -# Execution-event support -# --------------------------------------------------------------------------- - -EXECUTION_TO_RUN_STEP_TYPE: dict[str, RunStepType] = { - "llm": RunStepType.LLM_CALL, - "tool": RunStepType.TOOL_CALL, - "repl": RunStepType.REPL_EXEC, - "memory": RunStepType.MEMORY, - "output": RunStepType.OUTPUT, -} - - -def build_execution_event( - *, - event_type: ExecutionEventType, - run_id: str, - workspace_id: str, - user_id: str, - session_id: str, - sequence: int, - step: ExecutionStep | None = None, - summary: dict[str, Any] | None = None, -) -> ExecutionEvent: - return ExecutionEvent( - type=event_type, - run_id=run_id, - workspace_id=workspace_id, - user_id=user_id, - session_id=session_id, - sequence=sequence, - step=step, - summary=summary, - ) - - -def get_execution_emitter(diagnostics: DiagnosticsDeps) -> ExecutionEventEmitter: - emitter = diagnostics.events_event_emitter - if emitter is not None: - return emitter - return emitter - - -def get_execution_emitter_with_config(diagnostics: DiagnosticsDeps, config_deps: ConfigDeps) -> ExecutionEventEmitter: - emitter = diagnostics.events_event_emitter - if emitter is not None: - return emitter - - cfg = config_deps.config - emitter = ExecutionEventEmitter( - max_queue=cfg.ws_execution_max_queue, - drop_policy=cfg.ws_execution_drop_policy, - ) - diagnostics.events_event_emitter = emitter - return emitter - - -def map_execution_step_type(step_type: str) -> RunStepType: - return EXECUTION_TO_RUN_STEP_TYPE.get(step_type, RunStepType.STATUS) - # --------------------------------------------------------------------------- # Startup status @@ -333,965 +205,7 @@ def build_workspace_task_request( ) -# --------------------------------------------------------------------------- -# Manifest I/O -# --------------------------------------------------------------------------- - - -def _is_final_output(result: Any) -> bool: - from dspy.primitives import FinalOutput - - return isinstance(result, FinalOutput) - - -def _manifest_path(workspace_id: str, user_id: str, session_id: str) -> str: - _ = workspace_id, user_id - conversation_path = session_conversation_path(session_id) - if conversation_path is not None: - return conversation_path - safe_session_id = _sanitize_id(session_id, "default-session") - return f"meta/workspaces/{workspace_id}/users/{user_id}/react-session-{safe_session_id}.json" - - -def _get_existing_daytona_session(agent: Any) -> Any | None: - interpreter = getattr(agent, "interpreter", None) - workspace = getattr(interpreter, "_workspace", None) - if workspace is None: - return None - return getattr(workspace, "_session", None) - - -async def _aget_daytona_session(agent: Any, *, allow_create: bool = True) -> Any | None: - try: - from fleet_rlm.integrations.daytona.interpreter import DaytonaInterpreter - except ImportError: - return None - - interpreter = getattr(agent, "interpreter", None) - if not isinstance(interpreter, DaytonaInterpreter): - return None - if not allow_create: - return _get_existing_daytona_session(agent) - aget_session = getattr(interpreter, "aget_session", None) - if aget_session is None or not callable(aget_session): - return None - return await aget_session() - - -async def release_idle_daytona_session(agent: Any) -> None: - """Best-effort release of an already-created Daytona sandbox session.""" - interpreter = getattr(agent, "interpreter", None) - if interpreter is None: - return - if _get_existing_daytona_session(agent) is None: - return - release_idle = getattr(interpreter, "arelease_idle_session", None) - if callable(release_idle): - try: - await release_idle() - except Exception: - logger.warning("Failed to release idle Daytona session", exc_info=True) - - -def _persistent_storage_path(interpreter: Any, path: str) -> str: - raw_root = str(getattr(interpreter, "volume_mount_path", "/data") or "/data") - mount_root = posixpath.normpath(raw_root) - candidate = PurePosixPath(path) - if candidate.is_absolute(): - resolved = posixpath.normpath(str(candidate)) - else: - resolved = posixpath.normpath(str(PurePosixPath(mount_root) / candidate)) - if not resolved.startswith(mount_root + "/") and resolved != mount_root: - raise ValueError(f"Path {path!r} resolves outside volume mount path.") - return resolved - - -def _session_workspace_target(daytona_session: Any, interpreter: Any) -> str: - return str( - getattr(daytona_session, "workspace_path", None) - or getattr(interpreter, "workspace_path", None) - or getattr(interpreter, "repo_path", None) - or "" - ).strip() - - -def _ensure_session_layout_command(*, scratchpad_path: str, workspace_link_path: str, workspace_target: str) -> str: - return " ".join( - [ - "mkdir", - "-p", - shlex.quote(scratchpad_path), - "&&", - "rm", - "-rf", - shlex.quote(workspace_link_path), - "&&", - "ln", - "-s", - shlex.quote(workspace_target), - shlex.quote(workspace_link_path), - ] - ) - - -async def ensure_session_volume_layout( - agent: Any, - session_id: str, - *, - allow_session_create: bool = True, -) -> dict[str, str]: - """Ensure Phase 1 per-session scratchpad and workspace mapping exist on the volume.""" - interpreter = agent.interpreter - if interpreter is None: - return {} - scratchpad_path = session_scratchpad_path(session_id) - workspace_link_path = session_workspace_link_path(session_id) - if scratchpad_path is None or workspace_link_path is None: - return {} - storage_scratchpad_path = _persistent_storage_path(interpreter, scratchpad_path) - storage_workspace_link_path = _persistent_storage_path(interpreter, workspace_link_path) - daytona_session = await _aget_daytona_session(agent, allow_create=allow_session_create) - if daytona_session is None and not allow_session_create: - return { - "scratchpad_path": storage_scratchpad_path, - "workspace_link_path": storage_workspace_link_path, - } - workspace_target = _session_workspace_target(daytona_session, interpreter) - if not workspace_target: - return { - "scratchpad_path": storage_scratchpad_path, - "workspace_link_path": storage_workspace_link_path, - } - if daytona_session is not None: - process = getattr(getattr(daytona_session, "sandbox", None), "process", None) - exec_command = getattr(process, "exec", None) - if callable(exec_command): - try: - exec_command( - _ensure_session_layout_command( - scratchpad_path=storage_scratchpad_path, - workspace_link_path=storage_workspace_link_path, - workspace_target=workspace_target, - ) - ) - return { - "scratchpad_path": storage_scratchpad_path, - "workspace_link_path": storage_workspace_link_path, - } - except Exception as exc: - logger.warning( - "ensure_session_volume_layout: Daytona exec_command failed, falling back to interpreter aexecute: %s", - exc, - ) - await interpreter.aexecute( - "\n".join( - [ - "import os", - "os.makedirs(scratchpad_path, exist_ok=True)", - "if os.path.isdir(workspace_target):", - " if os.path.lexists(workspace_link_path):", - " if os.path.isdir(workspace_link_path) and not os.path.islink(workspace_link_path):", - " import shutil", - " shutil.rmtree(workspace_link_path)", - " else:", - " os.unlink(workspace_link_path)", - " os.symlink(workspace_target, workspace_link_path)", - "else:", - " import warnings", - " warnings.warn(f'Workspace target {workspace_target} does not exist, skipping symlink creation')", - "SUBMIT(scratchpad_path=scratchpad_path, workspace_link_path=workspace_link_path)", - ] - ), - variables={ - "scratchpad_path": storage_scratchpad_path, - "workspace_link_path": storage_workspace_link_path, - "workspace_target": workspace_target, - }, - execution_profile=ExecutionProfile.MAINTENANCE, - ) - return { - "scratchpad_path": storage_scratchpad_path, - "workspace_link_path": storage_workspace_link_path, - } - - -def _parse_manifest_text(text: str) -> dict[str, Any]: - if not text or text.startswith("[file not found:") or text.startswith("[error:"): - return {} - try: - parsed = json.loads(text) - return parsed if isinstance(parsed, dict) else {} - except json.JSONDecodeError: - return {} - - -async def load_manifest_from_volume( - agent: Any, - path: str, - fallback_paths: list[str] | None = None, - *, - allow_session_create: bool = True, -) -> dict[str, Any]: - """Best-effort manifest load from interpreter volume storage.""" - interpreter = agent.interpreter - if interpreter is None: - return {} - candidate_paths = [path, *(fallback_paths or [])] - daytona_session = await _aget_daytona_session(agent, allow_create=allow_session_create) - if daytona_session is not None: - for candidate_path in candidate_paths: - storage_path = _persistent_storage_path(interpreter, candidate_path) - try: - text = await daytona_session.aread_file(storage_path) - except Exception: - logger.debug( - "manifest_load_daytona_read_error", - extra={"path": storage_path}, - exc_info=True, - ) - continue - parsed = _parse_manifest_text(text) - if parsed: - return parsed - return {} - if not allow_session_create: - return {} - for candidate_path in candidate_paths: - result = await interpreter.aexecute( - "text = load_from_volume(path)\nSUBMIT(text=text)", - variables={"path": candidate_path}, - execution_profile=ExecutionProfile.MAINTENANCE, - ) - if not _is_final_output(result): - continue - output = getattr(result, "output", None) - output = output if isinstance(output, dict) else {} - parsed = _parse_manifest_text(str(output.get("text", ""))) - if parsed: - return parsed - return {} - - -async def save_manifest_to_volume( - agent: Any, - path: str, - manifest: dict[str, Any], - *, - allow_session_create: bool = True, -) -> str | None: - """Best-effort manifest save to interpreter volume storage.""" - interpreter = agent.interpreter - if interpreter is None: - return None - payload = json.dumps(manifest, ensure_ascii=False, default=str) - daytona_session = await _aget_daytona_session(agent, allow_create=allow_session_create) - if daytona_session is not None: - storage_path = _persistent_storage_path(interpreter, path) - try: - return await daytona_session.awrite_file(storage_path, payload) - except Exception: - logger.warning( - "manifest_save_daytona_write_error", - extra={"path": storage_path}, - exc_info=True, - ) - return None - if not allow_session_create: - return None - result = await interpreter.aexecute( - "saved_path = save_to_volume(path, payload)\nSUBMIT(saved_path=saved_path)", - variables={"path": path, "payload": payload}, - execution_profile=ExecutionProfile.MAINTENANCE, - ) - if not _is_final_output(result): - return None - output = getattr(result, "output", None) - output = output if isinstance(output, dict) else {} - saved_path = str(output.get("saved_path", "")) - if saved_path.startswith("["): - return None - return saved_path or None - - -# --------------------------------------------------------------------------- -# Execution lifecycle manager -# --------------------------------------------------------------------------- - - -class ExecutionLifecycleManager: - """Encapsulates run lifecycle operations: DB persistence and event emission.""" - - def __init__( - self, - *, - run_id: str, - workspace_id: str, - user_id: str, - session_id: str, - execution_emitter, - step_builder: ExecutionStepBuilder, - repository: FleetRepository | None = None, - identity_rows: IdentityUpsertResult | None = None, - active_run_db_id: Any = None, - strict_persistence: bool = False, - session_record: dict[str, Any] | None = None, - ) -> None: - self.run_id = run_id - self.workspace_id = workspace_id - self.user_id = user_id - self.session_id = session_id - self.execution_emitter = execution_emitter - self.step_builder = step_builder - self.repository = repository - self.identity_rows = identity_rows - self.active_run_db_id = active_run_db_id - self.strict_persistence = strict_persistence - self._session_record = session_record - self._step_index = 0 - self._last_step_db_id: Any = None - self._persist_queue: asyncio.Queue[ExecutionStep | None] | None = None - self._persist_worker_task: asyncio.Task[None] | None = None - self._persistence_error: Exception | None = None - self._event_sequence = 0 - self.run_completed = False - - def _build_event( - self, - event_type: ExecutionEventType, - step: ExecutionStep | None = None, - summary: dict[str, Any] | None = None, - ) -> Any: - self._event_sequence += 1 - return build_execution_event( - event_type=event_type, - run_id=self.run_id, - workspace_id=self.workspace_id, - user_id=self.user_id, - session_id=self.session_id, - sequence=self._event_sequence, - step=step, - summary=summary, - ) - - @property - def _can_persist(self) -> bool: - return self.repository is not None and self.identity_rows is not None and self.active_run_db_id is not None - - def raise_if_persistence_error(self) -> None: - if self.strict_persistence and self._persistence_error is not None: - raise PersistenceRequiredError( - "durable_state_write_failed", - f"Durable state write failed: {self._persistence_error}", - ) - - def record_persistence_error(self, exc: Exception) -> None: - self._persistence_error = exc - - async def _persist_worker(self) -> None: - if not self._can_persist or self._persist_queue is None: - return - - assert self.repository is not None - assert self.identity_rows is not None - assert self.active_run_db_id is not None - - while True: - step = await self._persist_queue.get() - if step is None: - break - - # Coalesce additional steps already in the queue to reduce - # per-item overhead and database round-trips. - batch: list[ExecutionStep] = [step] - shutdown_requested = False - while len(batch) < 32: - try: - extra = self._persist_queue.get_nowait() - except asyncio.QueueEmpty: - break - if extra is None: - shutdown_requested = True - break - batch.append(extra) - - for batch_step in batch: - self._step_index += 1 - try: - persisted = await self.repository.append_step( - RunStepCreateRequest( - tenant_id=self.identity_rows.tenant_id, - run_id=self.active_run_db_id, - step_index=self._step_index, - step_type=map_execution_step_type(batch_step.type), - input_json=batch_step.input - if isinstance(batch_step.input, dict) - else {"value": batch_step.input} - if batch_step.input is not None - else None, - output_json=batch_step.output - if isinstance(batch_step.output, dict) - else {"value": batch_step.output} - if batch_step.output is not None - else None, - ) - ) - self._last_step_db_id = persisted.id - if self._session_record is not None: - self._session_record["last_step_db_id"] = str(persisted.id) - except Exception as exc: - self._persistence_error = exc - logger.warning( - "Failed to persist run step: %s", - _sanitize_for_log(exc), - ) - if self.strict_persistence: - break - if self.strict_persistence and self._persistence_error is not None: - break - if shutdown_requested: - break - - async def _ensure_persist_worker(self) -> None: - if not self._can_persist: - return - if self._persist_worker_task is not None: - return - self._persist_queue = asyncio.Queue(maxsize=512) - self._persist_worker_task = asyncio.create_task(self._persist_worker()) - - async def _stop_persist_worker(self) -> None: - if self._persist_worker_task is None: - return - if self._persist_queue is not None: - await self._persist_queue.put(None) - try: - await self._persist_worker_task - except asyncio.CancelledError: - pass - self._persist_worker_task = None - self._persist_queue = None - - async def emit_started(self) -> None: - await self._ensure_persist_worker() - await self.execution_emitter.emit(self._build_event("execution_started")) - - async def persist_step(self, step: ExecutionStep | None) -> None: - if step is None or not self._can_persist: - return - await self._ensure_persist_worker() - self.raise_if_persistence_error() - if self._persist_queue is None: - return - try: - self._persist_queue.put_nowait(step) - except asyncio.QueueFull: - if self.strict_persistence: - raise PersistenceRequiredError( - "durable_state_backpressure", - "Execution step persistence queue is full", - ) - await self._persist_queue.put(step) - self.raise_if_persistence_error() - - async def emit_step(self, step: ExecutionStep) -> None: - await self.execution_emitter.emit(self._build_event("execution_step", step=step)) - - async def complete_run( - self, - status: RunStatus, - *, - step: ExecutionStep | None = None, - error_json: dict | None = None, - summary: dict[str, Any] | None = None, - ) -> None: - if self.run_completed: - return - await self._stop_persist_worker() - - effective_status = status - effective_error = dict(error_json or {}) - if self._persistence_error is not None: - effective_error.setdefault("durable_write_error", str(self._persistence_error)) - effective_error.setdefault("error_type", type(self._persistence_error).__name__) - if self.strict_persistence: - effective_status = RunStatus.FAILED - effective_error.setdefault("code", "durable_state_write_failed") - - if self._can_persist: - assert self.repository is not None - assert self.identity_rows is not None - assert self.active_run_db_id is not None - try: - await self.repository.update_run_status( - tenant_id=self.identity_rows.tenant_id, - run_id=self.active_run_db_id, - status=effective_status, - error_json=effective_error or None, - ) - except Exception as exc: - if self.strict_persistence: - raise PersistenceRequiredError( - "run_status_persist_failed", - f"Failed to persist run status: {exc}", - ) from exc - logger.warning("Failed to persist run status: %s", _sanitize_for_log(exc)) - await self.execution_emitter.emit(self._build_event("execution_completed", step=step, summary=summary)) - self.run_completed = True - - -async def initialize_turn_lifecycle( - *, - planner_lm: Any, - cfg: Any, - repository: FleetRepository | None, - identity_rows: IdentityUpsertResult | None, - persistence_required: bool, - execution_emitter: Any, - workspace_id: str, - user_id: str, - sess_id: str, - turn_index: int, - session_record: dict[str, Any] | None, - sandbox_provider: str | None = None, -) -> tuple[ExecutionLifecycleManager, ExecutionStepBuilder, str, Any]: - """Create step builder and lifecycle manager for a single message turn.""" - run_id = f"{workspace_id}:{user_id}:{sess_id}:{turn_index}" - step_builder = ExecutionStepBuilder(run_id=run_id) - active_run_db_id = None - - if repository is None: - logger.warning( - "runtime_persistence_disabled_for_run", - extra={ - "run_id": run_id, - "workspace_id": workspace_id, - "user_id": user_id, - "session_id": sess_id, - "code": "persistence_disabled", - }, - ) - - if repository is not None and identity_rows is not None and identity_rows.user_id is not None: - model_provider, model_name = parse_model_identity(getattr(planner_lm, "model", None)) - try: - run_row = await repository.create_run( - RunCreateRequest( - tenant_id=identity_rows.tenant_id, - created_by_user_id=identity_rows.user_id, - external_run_id=run_id, - status=RunStatus.RUNNING, - model_provider=model_provider, - model_name=model_name, - sandbox_provider=resolve_sandbox_provider(sandbox_provider or cfg.sandbox_provider), - ) - ) - active_run_db_id = run_row.id - if session_record is not None: - session_record["last_run_db_id"] = str(run_row.id) - except Exception as exc: - if persistence_required: - raise PersistenceRequiredError( - "run_start_persist_failed", - f"Failed to persist run start: {exc}", - ) from exc - logger.warning("Failed to persist run start: %s", _sanitize_for_log(exc)) - elif repository is not None and identity_rows is not None: - logger.info( - "runtime_run_persistence_skipped_missing_user", - extra={ - "run_id": run_id, - "workspace_id": workspace_id, - "user_id": user_id, - "session_id": sess_id, - "tenant_id": str(identity_rows.tenant_id), - "code": "identity_missing_user", - }, - ) - - lifecycle = ExecutionLifecycleManager( - run_id=run_id, - workspace_id=workspace_id, - user_id=user_id, - session_id=sess_id, - execution_emitter=execution_emitter, - step_builder=step_builder, - repository=repository, - identity_rows=identity_rows, - active_run_db_id=active_run_db_id, - strict_persistence=persistence_required, - session_record=session_record, - ) - return lifecycle, step_builder, run_id, active_run_db_id - - -def ensure_manifest_shape(manifest: dict[str, Any]) -> dict[str, Any]: - """Normalize mutable manifest structure and expected keys.""" - if not isinstance(manifest.get("logs"), list): - manifest["logs"] = [] - if not isinstance(manifest.get("memory"), list): - manifest["memory"] = [] - if not isinstance(manifest.get("generated_docs"), list): - manifest["generated_docs"] = [] - if not isinstance(manifest.get("artifacts"), list): - manifest["artifacts"] = [] - if not isinstance(manifest.get("metadata"), dict): - manifest["metadata"] = {} - return manifest - - -def update_manifest_from_exported_state( - *, - manifest: dict[str, Any], - exported_state: dict[str, Any], - latest_user_message: str, -) -> tuple[int, int]: - """Update manifest with latest state snapshot and optional user message entry.""" - ensure_manifest_shape(manifest) - - logs = manifest["logs"] - memory = manifest["memory"] - generated_docs = manifest["generated_docs"] - artifacts = manifest["artifacts"] - metadata = manifest["metadata"] - - if latest_user_message: - logs.append( - { - "timestamp": now_iso(), - "user_message": latest_user_message, - "history_turns": len(exported_state.get("history", [])), - } - ) - memory.append( - { - "timestamp": now_iso(), - "content": latest_user_message[:400], - } - ) - - generated_docs[:] = sorted(list(exported_state.get("documents", {}).keys())) - - previous_rev_raw = manifest.get("rev", 0) - previous_rev_candidate = previous_rev_raw if isinstance(previous_rev_raw, (int, float, str)) else 0 - try: - previous_rev = int(previous_rev_candidate) - except (TypeError, ValueError): - previous_rev = 0 - - next_rev = previous_rev + 1 - manifest["rev"] = next_rev - metadata["updated_at"] = now_iso() - metadata["history_turns"] = len(exported_state.get("history", [])) - metadata["document_count"] = len(exported_state.get("documents", {})) - metadata["artifact_count"] = len(artifacts) - manifest["state"] = exported_state - return previous_rev, next_rev - - -def sync_session_record_state( - *, - session_cache: SessionCacheDeps, - session_record: dict[str, Any], - exported_state: dict[str, Any], -) -> None: - """Propagate exported state into session record and state cache.""" - session_data = session_record.get("session") - if not isinstance(session_data, dict): - session_data = {} - session_record["session"] = session_data - session_data["state"] = exported_state - session_data["session_id"] = session_record.get("session_id") - - record_key = session_record.get("key") - if isinstance(record_key, str): - session_cache.sessions[record_key] = session_record - - -async def persist_memory_item_if_needed( - *, - repository: FleetRepository | None, - identity_rows: IdentityUpsertResult | None, - active_run_db_id: Any, - latest_user_message: str, - persistence_required: bool, -) -> None: - """Persist a user-input memory item when repository context is available.""" - if not latest_user_message or repository is None or identity_rows is None: - return - try: - await repository.store_memory_item( - MemoryItemCreateRequest( - tenant_id=identity_rows.tenant_id, - workspace_id=identity_rows.workspace_id, - user_id=identity_rows.user_id, - run_id=active_run_db_id, - scope=MemoryScope.RUN if active_run_db_id is not None else MemoryScope.USER, - scope_id=str(active_run_db_id or identity_rows.user_id), - kind=MemoryKind.NOTE, - source=MemorySource.USER_INPUT, - content_text=latest_user_message[:1000], - tags=["ws", "chat"], - ) - ) - except Exception as exc: - if persistence_required: - raise PersistenceRequiredError( - "memory_item_persist_failed", - f"Failed to persist memory item: {exc}", - ) from exc - logger.warning("Failed to persist memory item: %s", _sanitize_for_log(exc)) - - -async def _persist_manifest_to_local_store( - *, - persistence: Any, - sess_id: str, - manifest: dict[str, Any], -) -> None: - """Write the manifest into LocalStore/FleetRepository session metadata. - - Used as a fallback when no Daytona volume is available (interpreter=None) so - that session state survives process restarts between WebSocket connections. - """ - if persistence is None: - return - update_fn = getattr(persistence, "update_chat_session", None) - if not callable(update_fn): - return - try: - import inspect - - sig = inspect.signature(update_fn) - # LocalStore.update_chat_session requires tenant_id + session_id UUIDs; the - # async FleetRepository variant has the same shape. Both accept metadata_json. - # We store under the raw external_session_id key so the restore helper can - # locate it without a UUID round-trip. - params = set(sig.parameters) - if "external_session_id" in params: - await update_fn(external_session_id=sess_id, metadata_json={"_manifest_state": manifest}) - else: - # Async path: skip – we cannot derive the UUID here without identity_rows. - pass - except Exception: - logger.debug("Best-effort manifest persist to local store failed", exc_info=True) - - -async def _restore_manifest_from_local_store( - *, - persistence: Any, - sess_id: str, -) -> dict[str, Any]: - """Read a previously persisted manifest from LocalStore session metadata. - - Returns an empty dict when nothing is found or an error occurs. - """ - if persistence is None: - return {} - get_fn = getattr(persistence, "get_chat_session_by_external_id", None) - if not callable(get_fn): - return {} - try: - row = await get_fn(external_session_id=sess_id) - if row is None: - return {} - metadata = getattr(row, "metadata_json", None) - if not isinstance(metadata, dict): - return {} - manifest = metadata.get("_manifest_state") - return manifest if isinstance(manifest, dict) else {} - except Exception: - logger.debug("Best-effort manifest restore from local store failed", exc_info=True) - return {} - - -async def persist_session_state( - *, - session_cache: SessionCacheDeps, - agent: Any, - session_record: dict[str, Any] | None, - active_manifest_path: str | None, - active_run_db_id: uuid.UUID | None, - interpreter: Any | None, - repository: FleetRepository | None, - identity_rows: IdentityUpsertResult | None, - persistence_required: bool, - include_volume_save: bool = True, - latest_user_message: str = "", - persistence: Any = None, - allow_volume_session_create: bool = True, - release_idle_session: bool = False, -) -> None: - """Persist current session state and optionally release the live Daytona sandbox.""" - try: - await _persist_session_state_impl( - session_cache=session_cache, - agent=agent, - session_record=session_record, - active_manifest_path=active_manifest_path, - active_run_db_id=active_run_db_id, - interpreter=interpreter, - repository=repository, - identity_rows=identity_rows, - persistence_required=persistence_required, - include_volume_save=include_volume_save, - latest_user_message=latest_user_message, - persistence=persistence, - allow_volume_session_create=allow_volume_session_create, - ) - finally: - if release_idle_session: - await release_idle_daytona_session(agent) - - -async def _persist_session_state_impl( - *, - session_cache: SessionCacheDeps, - agent: Any, - session_record: dict[str, Any] | None, - active_manifest_path: str | None, - active_run_db_id: uuid.UUID | None, - interpreter: Any | None, - repository: FleetRepository | None, - identity_rows: IdentityUpsertResult | None, - persistence_required: bool, - include_volume_save: bool = True, - latest_user_message: str = "", - persistence: Any = None, - allow_volume_session_create: bool = True, -) -> None: - """Persist current session state to in-memory cache, volume, and DB.""" - if session_record is None: - return - exported_state = agent.export_session_state() - manifest = session_record.get("manifest") - if not isinstance(manifest, dict): - manifest = {} - session_record["manifest"] = manifest - - ensure_manifest_shape(manifest) - previous_rev, _next_rev = update_manifest_from_exported_state( - manifest=manifest, - exported_state=exported_state, - latest_user_message=latest_user_message, - ) - sync_session_record_state( - session_cache=session_cache, - session_record=session_record, - exported_state=exported_state, - ) - - if include_volume_save and active_manifest_path and interpreter is not None: - existing_session = None - if not allow_volume_session_create: - existing_session = await _aget_daytona_session(agent, allow_create=False) - if allow_volume_session_create or existing_session is not None: - remote_manifest = await load_manifest_from_volume( - agent, - active_manifest_path, - allow_session_create=allow_volume_session_create, - ) - remote_rev_raw = remote_manifest.get("rev", 0) - remote_rev_candidate = remote_rev_raw if isinstance(remote_rev_raw, (int, float, str)) else 0 - try: - remote_rev = int(remote_rev_candidate) - except (TypeError, ValueError): - remote_rev = 0 - - if remote_rev > previous_rev: - message = ( - f"Session manifest revision conflict detected (remote_rev={remote_rev}, local_rev={previous_rev})" - ) - if persistence_required: - raise PersistenceRequiredError("manifest_conflict", message) - logger.warning(message) - else: - saved_path = await save_manifest_to_volume( - agent, - active_manifest_path, - manifest, - allow_session_create=allow_volume_session_create, - ) - if saved_path is None: - message = f"Failed to save session manifest to volume (path={active_manifest_path})" - if persistence_required: - raise PersistenceRequiredError("manifest_write_failed", message) - logger.warning(message) - else: - logger.debug( - "Skipping Daytona volume persistence because cleanup has no active session (path=%s)", - active_manifest_path, - ) - # Always persist to local store when persistence is available — this is the - # durable fallback that survives sandbox churn. Pool-based dispatch means - # each turn may acquire a *different* Daytona sandbox, so the volume save - # above lands on the current sandbox while the *next* turn's new sandbox - # volume starts empty. The local store is sandbox-independent and bridges - # the gap. We write it regardless of whether a volume save also happened. - if include_volume_save and persistence is not None: - sess_id = str(session_record.get("session_id") or "") - if sess_id: - await _persist_manifest_to_local_store( - persistence=persistence, - sess_id=sess_id, - manifest=manifest, - ) - - await persist_memory_item_if_needed( - repository=repository, - identity_rows=identity_rows, - active_run_db_id=active_run_db_id, - latest_user_message=latest_user_message, - persistence_required=persistence_required, - ) - - -def build_local_persist_fn( - *, - session_cache: SessionCacheDeps, - runtime: Any, - agent: Any, - interpreter: Any, - session: Any, -): - async def local_persist( - *, - include_volume_save: bool = True, - latest_user_message: str = "", - allow_volume_session_create: bool = True, - release_idle_session: bool = False, - ) -> None: - try: - await persist_session_state( - session_cache=session_cache, - agent=agent, - session_record=session.session_record, - active_manifest_path=session.active_manifest_path, - active_run_db_id=session.active_run_db_id, - interpreter=interpreter, - repository=runtime.repository, - identity_rows=runtime.identity_rows, - persistence_required=runtime.persistence_required, - include_volume_save=include_volume_save, - latest_user_message=latest_user_message, - persistence=runtime.persistence, - allow_volume_session_create=allow_volume_session_create, - release_idle_session=False, - ) - finally: - if release_idle_session: - await release_idle_daytona_session(agent) - - return local_persist - - __all__ = [ - "PersistenceRequiredError", - "classify_stream_failure", - "EXECUTION_TO_RUN_STEP_TYPE", - "build_execution_event", - "get_execution_emitter", - "get_execution_emitter_with_config", - "map_execution_step_type", "build_startup_status_event", "emit_delayed_startup_status", "cancel_startup_status_task", @@ -1300,22 +214,4 @@ async def local_persist( "cancel_task", "handle_chat_disconnect", "build_workspace_task_request", - "load_manifest_from_volume", - "save_manifest_to_volume", - "release_idle_daytona_session", - "_persist_manifest_to_local_store", - "_restore_manifest_from_local_store", - "_manifest_path", - "_aget_daytona_session", - "_persistent_storage_path", - "_is_final_output", - "ExecutionLifecycleManager", - "build_local_persist_fn", - "ensure_manifest_shape", - "initialize_turn_lifecycle", - "now_iso", - "persist_memory_item_if_needed", - "persist_session_state", - "sync_session_record_state", - "update_manifest_from_exported_state", ] diff --git a/src/fleet_rlm/api/runtime_services/run_lifecycle.py b/src/fleet_rlm/api/runtime_services/run_lifecycle.py new file mode 100644 index 000000000..9831bb150 --- /dev/null +++ b/src/fleet_rlm/api/runtime_services/run_lifecycle.py @@ -0,0 +1,385 @@ +"""Execution lifecycle management for runtime services. + +Encapsulates run lifecycle operations: DB persistence and event emission. +""" + +from __future__ import annotations + +import asyncio +import logging +from typing import Any + +from fleet_rlm.api.events import ( + ExecutionEvent, + ExecutionEventType, + ExecutionStep, + ExecutionStepBuilder, +) +from fleet_rlm.api.runtime_services.common import ( + parse_model_identity, + resolve_sandbox_provider, +) +from fleet_rlm.api.runtime_services.stream_failures import PersistenceRequiredError +from fleet_rlm.integrations.database import ( + FleetRepository, + RunStatus, + RunStepType, +) +from fleet_rlm.integrations.database.repository_chat import ( + RunCreateRequest, + RunStepCreateRequest, +) +from fleet_rlm.integrations.database.repository_identity import IdentityUpsertResult +from fleet_rlm.utils.logging import sanitize_for_log as _sanitize_for_log + +logger = logging.getLogger(__name__) + +EXECUTION_TO_RUN_STEP_TYPE: dict[str, RunStepType] = { + "llm": RunStepType.LLM_CALL, + "tool": RunStepType.TOOL_CALL, + "repl": RunStepType.REPL_EXEC, + "memory": RunStepType.MEMORY, + "output": RunStepType.OUTPUT, +} + + +def build_execution_event( + *, + event_type: ExecutionEventType, + run_id: str, + workspace_id: str, + user_id: str, + session_id: str, + sequence: int, + step: ExecutionStep | None = None, + summary: dict[str, Any] | None = None, +) -> ExecutionEvent: + return ExecutionEvent( + type=event_type, + run_id=run_id, + workspace_id=workspace_id, + user_id=user_id, + session_id=session_id, + sequence=sequence, + step=step, + summary=summary, + ) + + +def map_execution_step_type(step_type: str) -> RunStepType: + return EXECUTION_TO_RUN_STEP_TYPE.get(step_type, RunStepType.STATUS) + + +class ExecutionLifecycleManager: + """Encapsulates run lifecycle operations: DB persistence and event emission.""" + + def __init__( + self, + *, + run_id: str, + workspace_id: str, + user_id: str, + session_id: str, + execution_emitter, + step_builder: ExecutionStepBuilder, + repository: FleetRepository | None = None, + identity_rows: IdentityUpsertResult | None = None, + active_run_db_id: Any = None, + strict_persistence: bool = False, + session_record: dict[str, Any] | None = None, + ) -> None: + self.run_id = run_id + self.workspace_id = workspace_id + self.user_id = user_id + self.session_id = session_id + self.execution_emitter = execution_emitter + self.step_builder = step_builder + self.repository = repository + self.identity_rows = identity_rows + self.active_run_db_id = active_run_db_id + self.strict_persistence = strict_persistence + self._session_record = session_record + self._step_index = 0 + self._last_step_db_id: Any = None + self._persist_queue: asyncio.Queue[ExecutionStep | None] | None = None + self._persist_worker_task: asyncio.Task[None] | None = None + self._persistence_error: Exception | None = None + self._event_sequence = 0 + self.run_completed = False + + def _build_event( + self, + event_type: ExecutionEventType, + step: ExecutionStep | None = None, + summary: dict[str, Any] | None = None, + ) -> Any: + self._event_sequence += 1 + return build_execution_event( + event_type=event_type, + run_id=self.run_id, + workspace_id=self.workspace_id, + user_id=self.user_id, + session_id=self.session_id, + sequence=self._event_sequence, + step=step, + summary=summary, + ) + + @property + def _can_persist(self) -> bool: + return self.repository is not None and self.identity_rows is not None and self.active_run_db_id is not None + + def raise_if_persistence_error(self) -> None: + if self.strict_persistence and self._persistence_error is not None: + raise PersistenceRequiredError( + "durable_state_write_failed", + f"Durable state write failed: {self._persistence_error}", + ) + + def record_persistence_error(self, exc: Exception) -> None: + self._persistence_error = exc + + async def _persist_worker(self) -> None: + if not self._can_persist or self._persist_queue is None: + return + + assert self.repository is not None + assert self.identity_rows is not None + assert self.active_run_db_id is not None + + while True: + step = await self._persist_queue.get() + if step is None: + break + + # Coalesce additional steps already in the queue to reduce + # per-item overhead and database round-trips. + batch: list[ExecutionStep] = [step] + shutdown_requested = False + while len(batch) < 32: + try: + extra = self._persist_queue.get_nowait() + except asyncio.QueueEmpty: + break + if extra is None: + shutdown_requested = True + break + batch.append(extra) + + for batch_step in batch: + self._step_index += 1 + try: + persisted = await self.repository.append_step( + RunStepCreateRequest( + tenant_id=self.identity_rows.tenant_id, + run_id=self.active_run_db_id, + step_index=self._step_index, + step_type=map_execution_step_type(batch_step.type), + input_json=batch_step.input + if isinstance(batch_step.input, dict) + else {"value": batch_step.input} + if batch_step.input is not None + else None, + output_json=batch_step.output + if isinstance(batch_step.output, dict) + else {"value": batch_step.output} + if batch_step.output is not None + else None, + ) + ) + self._last_step_db_id = persisted.id + if self._session_record is not None: + self._session_record["last_step_db_id"] = str(persisted.id) + except Exception as exc: + self._persistence_error = exc + logger.warning( + "Failed to persist run step: %s", + _sanitize_for_log(exc), + ) + if self.strict_persistence: + break + if self.strict_persistence and self._persistence_error is not None: + break + if shutdown_requested: + break + + async def _ensure_persist_worker(self) -> None: + if not self._can_persist: + return + if self._persist_worker_task is not None: + return + self._persist_queue = asyncio.Queue(maxsize=512) + self._persist_worker_task = asyncio.create_task(self._persist_worker()) + + async def _stop_persist_worker(self) -> None: + if self._persist_worker_task is None: + return + if self._persist_queue is not None: + await self._persist_queue.put(None) + try: + await self._persist_worker_task + except asyncio.CancelledError: + pass + self._persist_worker_task = None + self._persist_queue = None + + async def emit_started(self) -> None: + await self._ensure_persist_worker() + await self.execution_emitter.emit(self._build_event("execution_started")) + + async def persist_step(self, step: ExecutionStep | None) -> None: + if step is None or not self._can_persist: + return + await self._ensure_persist_worker() + self.raise_if_persistence_error() + if self._persist_queue is None: + return + try: + self._persist_queue.put_nowait(step) + except asyncio.QueueFull: + if self.strict_persistence: + raise PersistenceRequiredError( + "durable_state_backpressure", + "Execution step persistence queue is full", + ) + await self._persist_queue.put(step) + self.raise_if_persistence_error() + + async def emit_step(self, step: ExecutionStep) -> None: + await self.execution_emitter.emit(self._build_event("execution_step", step=step)) + + async def complete_run( + self, + status: RunStatus, + *, + step: ExecutionStep | None = None, + error_json: dict | None = None, + summary: dict[str, Any] | None = None, + ) -> None: + if self.run_completed: + return + await self._stop_persist_worker() + + effective_status = status + effective_error = dict(error_json or {}) + if self._persistence_error is not None: + effective_error.setdefault("durable_write_error", str(self._persistence_error)) + effective_error.setdefault("error_type", type(self._persistence_error).__name__) + if self.strict_persistence: + effective_status = RunStatus.FAILED + effective_error.setdefault("code", "durable_state_write_failed") + + if self._can_persist: + assert self.repository is not None + assert self.identity_rows is not None + assert self.active_run_db_id is not None + try: + await self.repository.update_run_status( + tenant_id=self.identity_rows.tenant_id, + run_id=self.active_run_db_id, + status=effective_status, + error_json=effective_error or None, + ) + except Exception as exc: + if self.strict_persistence: + raise PersistenceRequiredError( + "run_status_persist_failed", + f"Failed to persist run status: {exc}", + ) from exc + logger.warning("Failed to persist run status: %s", _sanitize_for_log(exc)) + await self.execution_emitter.emit(self._build_event("execution_completed", step=step, summary=summary)) + self.run_completed = True + + +async def initialize_turn_lifecycle( + *, + planner_lm: Any, + cfg: Any, + repository: FleetRepository | None, + identity_rows: IdentityUpsertResult | None, + persistence_required: bool, + execution_emitter: Any, + workspace_id: str, + user_id: str, + sess_id: str, + turn_index: int, + session_record: dict[str, Any] | None, + sandbox_provider: str | None = None, +) -> tuple[ExecutionLifecycleManager, ExecutionStepBuilder, str, Any]: + """Create step builder and lifecycle manager for a single message turn.""" + run_id = f"{workspace_id}:{user_id}:{sess_id}:{turn_index}" + step_builder = ExecutionStepBuilder(run_id=run_id) + active_run_db_id = None + + if repository is None: + logger.warning( + "runtime_persistence_disabled_for_run", + extra={ + "run_id": run_id, + "workspace_id": workspace_id, + "user_id": user_id, + "session_id": sess_id, + "code": "persistence_disabled", + }, + ) + + if repository is not None and identity_rows is not None and identity_rows.user_id is not None: + model_provider, model_name = parse_model_identity(getattr(planner_lm, "model", None)) + try: + run_row = await repository.create_run( + RunCreateRequest( + tenant_id=identity_rows.tenant_id, + created_by_user_id=identity_rows.user_id, + external_run_id=run_id, + status=RunStatus.RUNNING, + model_provider=model_provider, + model_name=model_name, + sandbox_provider=resolve_sandbox_provider(sandbox_provider or cfg.sandbox_provider), + ) + ) + active_run_db_id = run_row.id + if session_record is not None: + session_record["last_run_db_id"] = str(run_row.id) + except Exception as exc: + if persistence_required: + raise PersistenceRequiredError( + "run_start_persist_failed", + f"Failed to persist run start: {exc}", + ) from exc + logger.warning("Failed to persist run start: %s", _sanitize_for_log(exc)) + elif repository is not None and identity_rows is not None: + logger.info( + "runtime_run_persistence_skipped_missing_user", + extra={ + "run_id": run_id, + "workspace_id": workspace_id, + "user_id": user_id, + "session_id": sess_id, + "tenant_id": str(identity_rows.tenant_id), + "code": "identity_missing_user", + }, + ) + + lifecycle = ExecutionLifecycleManager( + run_id=run_id, + workspace_id=workspace_id, + user_id=user_id, + session_id=sess_id, + execution_emitter=execution_emitter, + step_builder=step_builder, + repository=repository, + identity_rows=identity_rows, + active_run_db_id=active_run_db_id, + strict_persistence=persistence_required, + session_record=session_record, + ) + return lifecycle, step_builder, run_id, active_run_db_id + + +__all__ = [ + "ExecutionLifecycleManager", + "initialize_turn_lifecycle", + "build_execution_event", + "map_execution_step_type", + "EXECUTION_TO_RUN_STEP_TYPE", +] diff --git a/src/fleet_rlm/api/runtime_services/session_manifest.py b/src/fleet_rlm/api/runtime_services/session_manifest.py new file mode 100644 index 000000000..c3b3c4977 --- /dev/null +++ b/src/fleet_rlm/api/runtime_services/session_manifest.py @@ -0,0 +1,307 @@ +"""Daytona manifest/volume I/O for runtime services.""" + +from __future__ import annotations + +import json +import logging +import posixpath +import shlex +from pathlib import PurePosixPath +from typing import Any + +from fleet_rlm.api.runtime_services.session_paths import ( + session_conversation_path, + session_scratchpad_path, + session_workspace_link_path, +) +from fleet_rlm.runtime.execution.interpreter_protocol import ExecutionProfile +from fleet_rlm.utils.identity import sanitize_id as _sanitize_id + +logger = logging.getLogger(__name__) + + +def _is_final_output(result: Any) -> bool: + from dspy.primitives import FinalOutput + + return isinstance(result, FinalOutput) + + +def _manifest_path(workspace_id: str, user_id: str, session_id: str) -> str: + _ = workspace_id, user_id + conversation_path = session_conversation_path(session_id) + if conversation_path is not None: + return conversation_path + safe_session_id = _sanitize_id(session_id, "default-session") + return f"meta/workspaces/{workspace_id}/users/{user_id}/react-session-{safe_session_id}.json" + + +def _get_existing_daytona_session(agent: Any) -> Any | None: + interpreter = getattr(agent, "interpreter", None) + workspace = getattr(interpreter, "_workspace", None) + if workspace is None: + return None + return getattr(workspace, "_session", None) + + +async def _aget_daytona_session(agent: Any, *, allow_create: bool = True) -> Any | None: + try: + from fleet_rlm.integrations.daytona.interpreter import DaytonaInterpreter + except ImportError: + return None + + interpreter = getattr(agent, "interpreter", None) + if not isinstance(interpreter, DaytonaInterpreter): + return None + if not allow_create: + return _get_existing_daytona_session(agent) + aget_session = getattr(interpreter, "aget_session", None) + if aget_session is None or not callable(aget_session): + return None + return await aget_session() + + +async def release_idle_daytona_session(agent: Any) -> None: + """Best-effort release of an already-created Daytona sandbox session.""" + interpreter = getattr(agent, "interpreter", None) + if interpreter is None: + return + if _get_existing_daytona_session(agent) is None: + return + release_idle = getattr(interpreter, "arelease_idle_session", None) + if callable(release_idle): + try: + await release_idle() + except Exception: + logger.warning("Failed to release idle Daytona session", exc_info=True) + + +def _persistent_storage_path(interpreter: Any, path: str) -> str: + raw_root = str(getattr(interpreter, "volume_mount_path", "/data") or "/data") + mount_root = posixpath.normpath(raw_root) + candidate = PurePosixPath(path) + if candidate.is_absolute(): + resolved = posixpath.normpath(str(candidate)) + else: + resolved = posixpath.normpath(str(PurePosixPath(mount_root) / candidate)) + if not resolved.startswith(mount_root + "/") and resolved != mount_root: + raise ValueError(f"Path {path!r} resolves outside volume mount path.") + return resolved + + +def _session_workspace_target(daytona_session: Any, interpreter: Any) -> str: + return str( + getattr(daytona_session, "workspace_path", None) + or getattr(interpreter, "workspace_path", None) + or getattr(interpreter, "repo_path", None) + or "" + ).strip() + + +def _ensure_session_layout_command(*, scratchpad_path: str, workspace_link_path: str, workspace_target: str) -> str: + return " ".join( + [ + "mkdir", + "-p", + shlex.quote(scratchpad_path), + "&&", + "rm", + "-rf", + shlex.quote(workspace_link_path), + "&&", + "ln", + "-s", + shlex.quote(workspace_target), + shlex.quote(workspace_link_path), + ] + ) + + +async def ensure_session_volume_layout( + agent: Any, + session_id: str, + *, + allow_session_create: bool = True, +) -> dict[str, str]: + """Ensure Phase 1 per-session scratchpad and workspace mapping exist on the volume.""" + interpreter = agent.interpreter + if interpreter is None: + return {} + scratchpad_path = session_scratchpad_path(session_id) + workspace_link_path = session_workspace_link_path(session_id) + if scratchpad_path is None or workspace_link_path is None: + return {} + storage_scratchpad_path = _persistent_storage_path(interpreter, scratchpad_path) + storage_workspace_link_path = _persistent_storage_path(interpreter, workspace_link_path) + daytona_session = await _aget_daytona_session(agent, allow_create=allow_session_create) + if daytona_session is None and not allow_session_create: + return { + "scratchpad_path": storage_scratchpad_path, + "workspace_link_path": storage_workspace_link_path, + } + workspace_target = _session_workspace_target(daytona_session, interpreter) + if not workspace_target: + return { + "scratchpad_path": storage_scratchpad_path, + "workspace_link_path": storage_workspace_link_path, + } + if daytona_session is not None: + process = getattr(getattr(daytona_session, "sandbox", None), "process", None) + exec_command = getattr(process, "exec", None) + if callable(exec_command): + try: + exec_command( + _ensure_session_layout_command( + scratchpad_path=storage_scratchpad_path, + workspace_link_path=storage_workspace_link_path, + workspace_target=workspace_target, + ) + ) + return { + "scratchpad_path": storage_scratchpad_path, + "workspace_link_path": storage_workspace_link_path, + } + except Exception as exc: + logger.warning( + "ensure_session_volume_layout: Daytona exec_command failed, falling back to interpreter aexecute: %s", + exc, + ) + await interpreter.aexecute( + "\n".join( + [ + "import os", + "os.makedirs(scratchpad_path, exist_ok=True)", + "if os.path.isdir(workspace_target):", + " if os.path.lexists(workspace_link_path):", + " if os.path.isdir(workspace_link_path) and not os.path.islink(workspace_link_path):", + " import shutil", + " shutil.rmtree(workspace_link_path)", + " else:", + " os.unlink(workspace_link_path)", + " os.symlink(workspace_target, workspace_link_path)", + "else:", + " import warnings", + " warnings.warn(f'Workspace target {workspace_target} does not exist, skipping symlink creation')", + "SUBMIT(scratchpad_path=scratchpad_path, workspace_link_path=workspace_link_path)", + ] + ), + variables={ + "scratchpad_path": storage_scratchpad_path, + "workspace_link_path": storage_workspace_link_path, + "workspace_target": workspace_target, + }, + execution_profile=ExecutionProfile.MAINTENANCE, + ) + return { + "scratchpad_path": storage_scratchpad_path, + "workspace_link_path": storage_workspace_link_path, + } + + +def _parse_manifest_text(text: str) -> dict[str, Any]: + if not text or text.startswith("[file not found:") or text.startswith("[error:"): + return {} + try: + parsed = json.loads(text) + return parsed if isinstance(parsed, dict) else {} + except json.JSONDecodeError: + return {} + + +async def load_manifest_from_volume( + agent: Any, + path: str, + fallback_paths: list[str] | None = None, + *, + allow_session_create: bool = True, +) -> dict[str, Any]: + """Best-effort manifest load from interpreter volume storage.""" + interpreter = agent.interpreter + if interpreter is None: + return {} + candidate_paths = [path, *(fallback_paths or [])] + daytona_session = await _aget_daytona_session(agent, allow_create=allow_session_create) + if daytona_session is not None: + for candidate_path in candidate_paths: + storage_path = _persistent_storage_path(interpreter, candidate_path) + try: + text = await daytona_session.aread_file(storage_path) + except Exception: + logger.debug( + "manifest_load_daytona_read_error", + extra={"path": storage_path}, + exc_info=True, + ) + continue + parsed = _parse_manifest_text(text) + if parsed: + return parsed + return {} + if not allow_session_create: + return {} + for candidate_path in candidate_paths: + result = await interpreter.aexecute( + "text = load_from_volume(path)\nSUBMIT(text=text)", + variables={"path": candidate_path}, + execution_profile=ExecutionProfile.MAINTENANCE, + ) + if not _is_final_output(result): + continue + output = getattr(result, "output", None) + output = output if isinstance(output, dict) else {} + parsed = _parse_manifest_text(str(output.get("text", ""))) + if parsed: + return parsed + return {} + + +async def save_manifest_to_volume( + agent: Any, + path: str, + manifest: dict[str, Any], + *, + allow_session_create: bool = True, +) -> str | None: + """Best-effort manifest save to interpreter volume storage.""" + interpreter = agent.interpreter + if interpreter is None: + return None + payload = json.dumps(manifest, ensure_ascii=False, default=str) + daytona_session = await _aget_daytona_session(agent, allow_create=allow_session_create) + if daytona_session is not None: + storage_path = _persistent_storage_path(interpreter, path) + try: + return await daytona_session.awrite_file(storage_path, payload) + except Exception: + logger.warning( + "manifest_save_daytona_write_error", + extra={"path": storage_path}, + exc_info=True, + ) + return None + if not allow_session_create: + return None + result = await interpreter.aexecute( + "saved_path = save_to_volume(path, payload)\nSUBMIT(saved_path=saved_path)", + variables={"path": path, "payload": payload}, + execution_profile=ExecutionProfile.MAINTENANCE, + ) + if not _is_final_output(result): + return None + output = getattr(result, "output", None) + output = output if isinstance(output, dict) else {} + saved_path = str(output.get("saved_path", "")) + if saved_path.startswith("["): + return None + return saved_path or None + + +__all__ = [ + "load_manifest_from_volume", + "save_manifest_to_volume", + "ensure_session_volume_layout", + "release_idle_daytona_session", + "_manifest_path", + "_aget_daytona_session", + "_persistent_storage_path", + "_is_final_output", +] diff --git a/src/fleet_rlm/api/runtime_services/session_persistence.py b/src/fleet_rlm/api/runtime_services/session_persistence.py new file mode 100644 index 000000000..f8eafbc21 --- /dev/null +++ b/src/fleet_rlm/api/runtime_services/session_persistence.py @@ -0,0 +1,405 @@ +"""Session persistence orchestration for runtime services.""" + +from __future__ import annotations + +import logging +import uuid +from typing import Any + +from fleet_rlm.api.runtime_services.stream_failures import PersistenceRequiredError +from fleet_rlm.integrations.database import ( + FleetRepository, + MemoryKind, + MemoryScope, + MemorySource, +) +from fleet_rlm.integrations.database.repository_identity import IdentityUpsertResult +from fleet_rlm.integrations.database.repository_memory import MemoryItemCreateRequest +from fleet_rlm.utils.logging import sanitize_for_log as _sanitize_for_log +from fleet_rlm.utils.time import now_iso + +from ..dependencies import SessionCacheDeps + +logger = logging.getLogger(__name__) + + +def ensure_manifest_shape(manifest: dict[str, Any]) -> dict[str, Any]: + """Normalize mutable manifest structure and expected keys.""" + if not isinstance(manifest.get("logs"), list): + manifest["logs"] = [] + if not isinstance(manifest.get("memory"), list): + manifest["memory"] = [] + if not isinstance(manifest.get("generated_docs"), list): + manifest["generated_docs"] = [] + if not isinstance(manifest.get("artifacts"), list): + manifest["artifacts"] = [] + if not isinstance(manifest.get("metadata"), dict): + manifest["metadata"] = {} + return manifest + + +def update_manifest_from_exported_state( + *, + manifest: dict[str, Any], + exported_state: dict[str, Any], + latest_user_message: str, +) -> tuple[int, int]: + """Update manifest with latest state snapshot and optional user message entry.""" + ensure_manifest_shape(manifest) + + logs = manifest["logs"] + memory = manifest["memory"] + generated_docs = manifest["generated_docs"] + artifacts = manifest["artifacts"] + metadata = manifest["metadata"] + + if latest_user_message: + logs.append( + { + "timestamp": now_iso(), + "user_message": latest_user_message, + "history_turns": len(exported_state.get("history", [])), + } + ) + memory.append( + { + "timestamp": now_iso(), + "content": latest_user_message[:400], + } + ) + + generated_docs[:] = sorted(list(exported_state.get("documents", {}).keys())) + + previous_rev_raw = manifest.get("rev", 0) + previous_rev_candidate = previous_rev_raw if isinstance(previous_rev_raw, (int, float, str)) else 0 + try: + previous_rev = int(previous_rev_candidate) + except (TypeError, ValueError): + previous_rev = 0 + + next_rev = previous_rev + 1 + manifest["rev"] = next_rev + metadata["updated_at"] = now_iso() + metadata["history_turns"] = len(exported_state.get("history", [])) + metadata["document_count"] = len(exported_state.get("documents", {})) + metadata["artifact_count"] = len(artifacts) + manifest["state"] = exported_state + return previous_rev, next_rev + + +def sync_session_record_state( + *, + session_cache: SessionCacheDeps, + session_record: dict[str, Any], + exported_state: dict[str, Any], +) -> None: + """Propagate exported state into session record and state cache.""" + session_data = session_record.get("session") + if not isinstance(session_data, dict): + session_data = {} + session_record["session"] = session_data + session_data["state"] = exported_state + session_data["session_id"] = session_record.get("session_id") + + record_key = session_record.get("key") + if isinstance(record_key, str): + session_cache.sessions[record_key] = session_record + + +async def persist_memory_item_if_needed( + *, + repository: FleetRepository | None, + identity_rows: IdentityUpsertResult | None, + active_run_db_id: Any, + latest_user_message: str, + persistence_required: bool, +) -> None: + """Persist a user-input memory item when repository context is available.""" + if not latest_user_message or repository is None or identity_rows is None: + return + try: + await repository.store_memory_item( + MemoryItemCreateRequest( + tenant_id=identity_rows.tenant_id, + workspace_id=identity_rows.workspace_id, + user_id=identity_rows.user_id, + run_id=active_run_db_id, + scope=MemoryScope.RUN if active_run_db_id is not None else MemoryScope.USER, + scope_id=str(active_run_db_id or identity_rows.user_id), + kind=MemoryKind.NOTE, + source=MemorySource.USER_INPUT, + content_text=latest_user_message[:1000], + tags=["ws", "chat"], + ) + ) + except Exception as exc: + if persistence_required: + raise PersistenceRequiredError( + "memory_item_persist_failed", + f"Failed to persist memory item: {exc}", + ) from exc + logger.warning("Failed to persist memory item: %s", _sanitize_for_log(exc)) + + +async def _persist_manifest_to_local_store( + *, + persistence: Any, + sess_id: str, + manifest: dict[str, Any], +) -> None: + """Write the manifest into LocalStore/FleetRepository session metadata. + + Used as a fallback when no Daytona volume is available (interpreter=None) so + that session state survives process restarts between WebSocket connections. + """ + if persistence is None: + return + update_fn = getattr(persistence, "update_chat_session", None) + if not callable(update_fn): + return + try: + import inspect + + sig = inspect.signature(update_fn) + # LocalStore.update_chat_session requires tenant_id + session_id UUIDs; the + # async FleetRepository variant has the same shape. Both accept metadata_json. + # We store under the raw external_session_id key so the restore helper can + # locate it without a UUID round-trip. + params = set(sig.parameters) + if "external_session_id" in params: + await update_fn(external_session_id=sess_id, metadata_json={"_manifest_state": manifest}) + else: + # Async path: skip – we cannot derive the UUID here without identity_rows. + pass + except Exception: + logger.debug("Best-effort manifest persist to local store failed", exc_info=True) + + +async def _restore_manifest_from_local_store( + *, + persistence: Any, + sess_id: str, +) -> dict[str, Any]: + """Read a previously persisted manifest from LocalStore session metadata. + + Returns an empty dict when nothing is found or an error occurs. + """ + if persistence is None: + return {} + get_fn = getattr(persistence, "get_chat_session_by_external_id", None) + if not callable(get_fn): + return {} + try: + row = await get_fn(external_session_id=sess_id) + if row is None: + return {} + metadata = getattr(row, "metadata_json", None) + if not isinstance(metadata, dict): + return {} + manifest = metadata.get("_manifest_state") + return manifest if isinstance(manifest, dict) else {} + except Exception: + logger.debug("Best-effort manifest restore from local store failed", exc_info=True) + return {} + + +async def persist_session_state( + *, + session_cache: SessionCacheDeps, + agent: Any, + session_record: dict[str, Any] | None, + active_manifest_path: str | None, + active_run_db_id: uuid.UUID | None, + interpreter: Any | None, + repository: FleetRepository | None, + identity_rows: IdentityUpsertResult | None, + persistence_required: bool, + include_volume_save: bool = True, + latest_user_message: str = "", + persistence: Any = None, + allow_volume_session_create: bool = True, + release_idle_session: bool = False, +) -> None: + """Persist current session state and optionally release the live Daytona sandbox.""" + try: + await _persist_session_state_impl( + session_cache=session_cache, + agent=agent, + session_record=session_record, + active_manifest_path=active_manifest_path, + active_run_db_id=active_run_db_id, + interpreter=interpreter, + repository=repository, + identity_rows=identity_rows, + persistence_required=persistence_required, + include_volume_save=include_volume_save, + latest_user_message=latest_user_message, + persistence=persistence, + allow_volume_session_create=allow_volume_session_create, + ) + finally: + if release_idle_session: + from fleet_rlm.api.runtime_services.session_manifest import release_idle_daytona_session + + await release_idle_daytona_session(agent) + + +async def _persist_session_state_impl( + *, + session_cache: SessionCacheDeps, + agent: Any, + session_record: dict[str, Any] | None, + active_manifest_path: str | None, + active_run_db_id: uuid.UUID | None, + interpreter: Any | None, + repository: FleetRepository | None, + identity_rows: IdentityUpsertResult | None, + persistence_required: bool, + include_volume_save: bool = True, + latest_user_message: str = "", + persistence: Any = None, + allow_volume_session_create: bool = True, +) -> None: + """Persist current session state to in-memory cache, volume, and DB.""" + if session_record is None: + return + exported_state = agent.export_session_state() + manifest = session_record.get("manifest") + if not isinstance(manifest, dict): + manifest = {} + session_record["manifest"] = manifest + + ensure_manifest_shape(manifest) + previous_rev, _next_rev = update_manifest_from_exported_state( + manifest=manifest, + exported_state=exported_state, + latest_user_message=latest_user_message, + ) + sync_session_record_state( + session_cache=session_cache, + session_record=session_record, + exported_state=exported_state, + ) + + if include_volume_save and active_manifest_path and interpreter is not None: + from fleet_rlm.api.runtime_services.session_manifest import ( + _aget_daytona_session, + load_manifest_from_volume, + save_manifest_to_volume, + ) + + existing_session = None + if not allow_volume_session_create: + existing_session = await _aget_daytona_session(agent, allow_create=False) + if allow_volume_session_create or existing_session is not None: + remote_manifest = await load_manifest_from_volume( + agent, + active_manifest_path, + allow_session_create=allow_volume_session_create, + ) + remote_rev_raw = remote_manifest.get("rev", 0) + remote_rev_candidate = remote_rev_raw if isinstance(remote_rev_raw, (int, float, str)) else 0 + try: + remote_rev = int(remote_rev_candidate) + except (TypeError, ValueError): + remote_rev = 0 + + if remote_rev > previous_rev: + message = ( + f"Session manifest revision conflict detected (remote_rev={remote_rev}, local_rev={previous_rev})" + ) + if persistence_required: + raise PersistenceRequiredError("manifest_conflict", message) + logger.warning(message) + else: + saved_path = await save_manifest_to_volume( + agent, + active_manifest_path, + manifest, + allow_session_create=allow_volume_session_create, + ) + if saved_path is None: + message = f"Failed to save session manifest to volume (path={active_manifest_path})" + if persistence_required: + raise PersistenceRequiredError("manifest_write_failed", message) + logger.warning(message) + else: + logger.debug( + "Skipping Daytona volume persistence because cleanup has no active session (path=%s)", + active_manifest_path, + ) + # Always persist to local store when persistence is available — this is the + # durable fallback that survives sandbox churn. Pool-based dispatch means + # each turn may acquire a *different* Daytona sandbox, so the volume save + # above lands on the current sandbox while the *next* turn's new sandbox + # volume starts empty. The local store is sandbox-independent and bridges + # the gap. We write it regardless of whether a volume save also happened. + if include_volume_save and persistence is not None: + sess_id = str(session_record.get("session_id") or "") + if sess_id: + await _persist_manifest_to_local_store( + persistence=persistence, + sess_id=sess_id, + manifest=manifest, + ) + + await persist_memory_item_if_needed( + repository=repository, + identity_rows=identity_rows, + active_run_db_id=active_run_db_id, + latest_user_message=latest_user_message, + persistence_required=persistence_required, + ) + + +def build_local_persist_fn( + *, + session_cache: SessionCacheDeps, + runtime: Any, + agent: Any, + interpreter: Any, + session: Any, +): + async def local_persist( + *, + include_volume_save: bool = True, + latest_user_message: str = "", + allow_volume_session_create: bool = True, + release_idle_session: bool = False, + ) -> None: + try: + await persist_session_state( + session_cache=session_cache, + agent=agent, + session_record=session.session_record, + active_manifest_path=session.active_manifest_path, + active_run_db_id=session.active_run_db_id, + interpreter=interpreter, + repository=runtime.repository, + identity_rows=runtime.identity_rows, + persistence_required=runtime.persistence_required, + include_volume_save=include_volume_save, + latest_user_message=latest_user_message, + persistence=runtime.persistence, + allow_volume_session_create=allow_volume_session_create, + release_idle_session=False, + ) + finally: + if release_idle_session: + from fleet_rlm.api.runtime_services.session_manifest import release_idle_daytona_session + + await release_idle_daytona_session(agent) + + return local_persist + + +__all__ = [ + "persist_session_state", + "build_local_persist_fn", + "ensure_manifest_shape", + "update_manifest_from_exported_state", + "sync_session_record_state", + "persist_memory_item_if_needed", + "_persist_manifest_to_local_store", + "_restore_manifest_from_local_store", +] diff --git a/src/fleet_rlm/api/runtime_services/stream_failures.py b/src/fleet_rlm/api/runtime_services/stream_failures.py new file mode 100644 index 000000000..5f3501a9a --- /dev/null +++ b/src/fleet_rlm/api/runtime_services/stream_failures.py @@ -0,0 +1,32 @@ +"""Stream failure classification for websocket runtime services.""" + +from __future__ import annotations + + +class PersistenceRequiredError(RuntimeError): + """Raised when durable writes fail in strict-persistence mode.""" + + def __init__(self, code: str, message: str) -> None: + super().__init__(message) + self.code = code + self.message = message + + +def classify_stream_failure(exc: Exception) -> str: + """Map runtime failures to stable websocket-facing error codes.""" + if isinstance(exc, PersistenceRequiredError): + return exc.code + + lowered = str(exc).lower() + if "planner lm not configured" in lowered: + return "planner_missing" + if "llm call timed out" in lowered or "timed out" in lowered and "llm" in lowered: + return "llm_timeout" + if "rate limit" in lowered or "429" in lowered: + return "llm_rate_limited" + if "sandbox" in lowered or "daytona" in lowered: + return "sandbox_unavailable" + return "internal_error" + + +__all__ = ["PersistenceRequiredError", "classify_stream_failure"] diff --git a/src/fleet_rlm/integrations/daytona/concurrency.py b/src/fleet_rlm/integrations/daytona/concurrency.py index 4c36886bc..64e4ead80 100644 --- a/src/fleet_rlm/integrations/daytona/concurrency.py +++ b/src/fleet_rlm/integrations/daytona/concurrency.py @@ -73,6 +73,7 @@ class SandboxUsageStats(BaseModel): # Module-level semaphore state # --------------------------------------------------------------------------- + class _FleetSandboxSemaphore(asyncio.Semaphore): """Semaphore with a configurable release bound for reconciled state.""" @@ -192,8 +193,7 @@ def reconcile_sandbox_slots(*, provider_active_count: int) -> SandboxUsageStats: available = max(0, limit - clamped_active) _GLOBAL_SEMAPHORE = _FleetSandboxSemaphore(value=available, bound=limit) logger.warning( - "Reconciled Fleet sandbox slots from provider state " - "(provider_active=%d, limit=%d, available=%d)", + "Reconciled Fleet sandbox slots from provider state (provider_active=%d, limit=%d, available=%d)", clamped_active, limit, available, diff --git a/src/fleet_rlm/runtime/agent/runtime.py b/src/fleet_rlm/runtime/agent/runtime.py index fcd5defc5..319e4e273 100644 --- a/src/fleet_rlm/runtime/agent/runtime.py +++ b/src/fleet_rlm/runtime/agent/runtime.py @@ -19,6 +19,7 @@ from dspy.streaming import StreamListener, StreamResponse from fleet_rlm.integrations.daytona.async_compat import _run_async_compat +from fleet_rlm.runtime.events import RuntimeEvent, RuntimeEventContext, RuntimeEventKind from fleet_rlm.runtime.execution.streaming_events import ( _normalize_trajectory, ) @@ -192,46 +193,33 @@ async def _call_react_tool(tool: Any, tool_args: dict[str, Any]) -> Any: return await asyncio.to_thread(tool, **tool_args) -def _build_tool_call_event(*, tool_name: str, tool_args: dict[str, Any], step_index: int) -> StreamEvent: - return StreamEvent( - kind="tool_call", - text=f"Calling tool: {tool_name}({tool_args})", - payload={ - "tool_name": tool_name, - "tool_input": str(tool_args), - "tool_args": tool_args, - "step_index": step_index, - }, +def _build_tool_call_event(*, tool_name: str, tool_args: dict[str, Any], step_index: int) -> RuntimeEvent: + return RuntimeEvent.tool_call( + tool_name=tool_name, + tool_args=tool_args, + step_index=step_index, ) -def _build_tool_result_event(*, tool_name: str, observation: Any, step_index: int) -> StreamEvent: - return StreamEvent( - kind="tool_result", - text=f"Tool result: {observation}", - payload={ - "tool_name": tool_name, - "tool_output": str(observation), - "step_index": step_index, - }, +def _build_tool_result_event(*, tool_name: str, observation: Any, step_index: int) -> RuntimeEvent: + return RuntimeEvent.tool_result( + tool_name=tool_name, + observation=observation, + step_index=step_index, ) -def _build_clarification_event(observation: Any) -> StreamEvent | None: +def _build_clarification_event(observation: Any) -> RuntimeEvent | None: if not isinstance(observation, dict) or observation.get("status") != "clarification_needed": return None import uuid as _uuid - return StreamEvent( - kind="clarification", - text=str(observation.get("question", "Please clarify your intent.")), - payload={ - "message_id": str(observation.get("message_id") or f"clar-{_uuid.uuid4().hex[:8]}"), - "question": observation.get("question"), - "step_label": observation.get("step_label", "Clarification needed"), - "options": observation.get("options", []), - }, + return RuntimeEvent.clarification( + message_id=str(observation.get("message_id") or f"clar-{_uuid.uuid4().hex[:8]}"), + question=observation.get("question"), + step_label=observation.get("step_label", "Clarification needed"), + options=observation.get("options", []), ) @@ -261,6 +249,7 @@ def __init__( repository: Any | None = None, use_escalation: bool = True, summary_interval: int = 10, + compaction_threshold_pct: float = 0.7, ) -> None: from .agent import FleetAgent @@ -269,6 +258,11 @@ def __init__( self.history_max_turns: int | None = history_max_turns self.core_memory: dict[str, str] = self.default_core_memory() + # Phase 7: attach runtime reference to interpreter so recursive children + # can access parent history for bounded conversation snapshots + if interpreter is not None: + setattr(interpreter, "runtime", self) + # Session-management hooks used by the websocket layer self._db_session_id: str | object | None = None self._repository: Any | None = repository @@ -288,6 +282,9 @@ def __init__( self._turns_since_summary: int = 0 self._use_escalation: bool = use_escalation + # Phase 7: token-budget-aware compaction threshold (keep history_max_turns as ceiling) + self._compaction_threshold_pct: float = max(0.0, min(1.0, compaction_threshold_pct)) + # Discover tools from the registry; append any extra tools base_tools = discover_tools() base_tools = bind_runtime_tools( @@ -320,18 +317,44 @@ def __init__( # Chat API # ----------------------------------------------------------------- + def _estimate_history_chars(self) -> int: + """Estimate character count of history as a proxy for token usage.""" + messages = list(getattr(self.history, "messages", []) or []) + return sum(len(str(msg.get("user_message", ""))) + len(str(msg.get("response", ""))) for msg in messages) + def _maybe_refresh_summary(self) -> None: - """Regenerate conversation_summary every ``_summary_interval`` turns.""" + """Regenerate conversation_summary based on token budget (Phase 7). + + Phase 7: Compacts history when estimated token usage crosses the + compaction threshold, while keeping history_max_turns as a hard ceiling. + """ if not self._use_escalation: return + self._turns_since_summary += 1 - if self._turns_since_summary >= self._summary_interval: + + # Phase 7: token-budget-aware compaction + # Estimate history size and check against threshold + history_chars = self._estimate_history_chars() + # Use a reasonable default max context window if not available (64K tokens) + # Approximate 4 chars per token for estimation + max_context_chars = 64000 * 4 + threshold_chars = int(max_context_chars * self._compaction_threshold_pct) + + # Compact if threshold exceeded or interval reached + should_compact = history_chars > threshold_chars or self._turns_since_summary >= self._summary_interval + + if should_compact: escalating = self.agent if hasattr(escalating, "compress_history"): try: self.conversation_summary = escalating.compress_history(self.history) self._turns_since_summary = 0 - logger.debug("AgentRuntime: conversation summary refreshed") + logger.debug( + "AgentRuntime: conversation summary refreshed (chars=%d, threshold=%d)", + history_chars, + threshold_chars, + ) except Exception as exc: logger.warning("AgentRuntime: summary refresh failed: %s", exc) @@ -356,14 +379,39 @@ def preview_routing(self, *, user_request: str, execution_mode: str = "auto") -> payload = preview_routing(user_request=user_request, execution_mode=execution_mode) return payload if isinstance(payload, dict) else {} + def _recursion_depth_state(self) -> tuple[int, int]: + """Return ``(depth, max_depth)`` for the current runtime/interpreter. + + The root runtime is depth ``0``; ``max_depth`` is the configured + ``sub_rlm`` recursion ceiling carried on the interpreter (default 2). + """ + depth = int(getattr(self.interpreter, "_sub_rlm_depth", 0) or 0) + max_depth = int(getattr(self.interpreter, "_sub_rlm_max_depth", 2) or 2) + return depth, max_depth + + def _runtime_event_context(self) -> RuntimeEventContext: + """Build the canonical runtime context (incl. recursion depth) for events. + + Phase 7: surfaces ``depth``/``max_depth`` so the frontend run-workbench + can render recursion depth from typed ``RuntimeEvent.context`` fields. + """ + depth, max_depth = self._recursion_depth_state() + return RuntimeEventContext( + execution_mode=self.execution_mode, + depth=depth, + max_depth=max_depth, + ) + def _runtime_observability_payload(self) -> dict[str, Any]: """Return runtime metadata shared by streamed completion events.""" + depth, max_depth = self._recursion_depth_state() return { "execution_mode": self.execution_mode, "runtime_module": type(self.agent).__name__, "escalation_enabled": self._use_escalation, "conversation_summary_available": bool(self.conversation_summary), "loaded_document_count": len(self.loaded_document_paths), + "recursion": {"depth": depth, "max_depth": max_depth}, "rlm_limits": { "max_iterations": self.rlm_max_iterations, "max_llm_calls": self.rlm_max_llm_calls, @@ -400,17 +448,17 @@ async def _aiter_chat_turn_stream_posthoc( *, message: str, cancel_check: Callable[[], bool] | None, - ) -> AsyncIterator[StreamEvent]: + ) -> AsyncIterator[RuntimeEvent]: """Fallback stream path that emits events after the turn finishes.""" if cancel_check is not None and cancel_check(): - yield StreamEvent( - kind="done", + yield RuntimeEvent( + kind=RuntimeEventKind.DONE, text="[cancelled]", payload={"cancelled": True, "history_turns": self.history_turns()}, ) return - yield StreamEvent(kind="status", text="Starting turn...") + yield RuntimeEvent.status("Starting turn...") preview_routing = getattr(self.agent, "preview_routing", None) if callable(preview_routing): routing_preview = preview_routing( @@ -418,9 +466,8 @@ async def _aiter_chat_turn_stream_posthoc( execution_mode=self.execution_mode, ) if isinstance(routing_preview, dict) and routing_preview.get("routing_decision"): - yield StreamEvent( - kind="status", - text=_routing_status_text(routing_preview), + yield RuntimeEvent.status( + _routing_status_text(routing_preview), payload=routing_preview, ) @@ -430,16 +477,16 @@ async def _aiter_chat_turn_stream_posthoc( **self._escalation_call_args(message), ) except Exception as exc: - yield StreamEvent( - kind="error", + yield RuntimeEvent( + kind=RuntimeEventKind.ERROR, text=str(exc), payload={"history_turns": self.history_turns()}, ) return if cancel_check is not None and cancel_check(): - yield StreamEvent( - kind="done", + yield RuntimeEvent( + kind=RuntimeEventKind.DONE, text="[cancelled]", payload={"cancelled": True, "history_turns": self.history_turns()}, ) @@ -452,9 +499,8 @@ async def _aiter_chat_turn_stream_posthoc( routing_payload = _runtime_routing_payload(result) if routing_payload.get("selected_skills") or routing_payload.get("routing_decision"): - yield StreamEvent( - kind="status", - text=_routing_status_text(routing_payload), + yield RuntimeEvent.status( + _routing_status_text(routing_payload), payload=routing_payload, ) @@ -463,64 +509,45 @@ async def _aiter_chat_turn_stream_posthoc( tool_name = step.get("tool_name") is_terminal = (tool_name == "finish") or (not tool_name) if thought and not is_terminal: - yield StreamEvent( - kind="reasoning", - text=str(thought), - payload={"phase": "reasoning"}, - ) + yield RuntimeEvent.reasoning(str(thought)) tool_name = step.get("tool_name") if tool_name: tool_args = step.get("tool_args") or step.get("input", "") - yield StreamEvent( - kind="tool_call", - text=f"Calling tool: {tool_name}({tool_args})", - payload={ - "tool_name": tool_name, - "tool_input": str(tool_args), - "tool_args": tool_args, - "step": step, - "trajectory_index": step.get("index"), - }, + traj_idx = step.get("index") + tool_ev = RuntimeEvent.tool_call( + tool_name=tool_name, + tool_args=tool_args if isinstance(tool_args, dict) else {"input": tool_args}, + step_index=traj_idx, ) + tool_ev.payload["step"] = step + tool_ev.payload["trajectory_index"] = traj_idx + yield tool_ev observation = step.get("observation") or step.get("output", "") if observation and tool_name: - yield StreamEvent( - kind="tool_result", - text=f"Tool result: {observation}", - payload={ - "tool_name": tool_name, - "tool_output": str(observation), - "output": observation, - "step": step, - "trajectory_index": step.get("index"), - }, + result_ev = RuntimeEvent.tool_result( + tool_name=tool_name, + observation=observation, + step_index=step.get("index"), ) - if isinstance(observation, dict) and observation.get("status") == "clarification_needed": - import uuid as _uuid - - clar_payload = observation - yield StreamEvent( - kind="clarification", - text=str(clar_payload.get("question", "Please clarify your intent.")), - payload={ - "message_id": str(clar_payload.get("message_id") or f"clar-{_uuid.uuid4().hex[:8]}"), - "question": clar_payload.get("question"), - "step_label": clar_payload.get("step_label", "Clarification needed"), - "options": clar_payload.get("options", []), - }, - ) + result_ev.payload["output"] = observation + result_ev.payload["step"] = step + result_ev.payload["trajectory_index"] = step.get("index") + yield result_ev + clar_ev = _build_clarification_event(observation) + if clar_ev is not None: + yield clar_ev if degradation_payload: - yield StreamEvent( - kind="warning", + yield RuntimeEvent( + kind=RuntimeEventKind.WARNING, text=str(degradation_payload["runtime_warning"]), payload=degradation_payload, ) if response: - yield StreamEvent(kind="text", text=response) + yield RuntimeEvent(kind=RuntimeEventKind.TEXT, text=response) self.history = _append_turn_to_history( self.history, @@ -537,7 +564,12 @@ async def _aiter_chat_turn_stream_posthoc( done_payload.update(self._runtime_observability_payload()) done_payload.update(degradation_payload) done_payload.update(routing_payload) - yield StreamEvent(kind="done", text=response, payload=done_payload) + yield RuntimeEvent( + kind=RuntimeEventKind.DONE, + text=response, + payload=done_payload, + context=self._runtime_event_context(), + ) # ----------------------------------------------------------------- # Async context manager (required by ChatAgentProtocol) diff --git a/src/fleet_rlm/runtime/agent/signatures.py b/src/fleet_rlm/runtime/agent/signatures.py index 1f098b030..62db440c3 100644 --- a/src/fleet_rlm/runtime/agent/signatures.py +++ b/src/fleet_rlm/runtime/agent/signatures.py @@ -437,10 +437,21 @@ class RLMVariableSignature(dspy.Signature): Per Algorithm 1 (arXiv 2512.24601v2): dspy.RLM stores input fields in the REPL automatically — the LLM sees only metadata (type, length, preview) and explores data through code execution. + + ``history`` is exposed as a native REPL variable so the model can inspect + full prior conversation turns with code (e.g. ``history.messages[-1]``) + rather than relying solely on a flattened recency snippet in ``prompt``. """ task: str = dspy.InputField(desc="The question or instruction to accomplish") prompt: str = dspy.InputField(desc="The full text to process (stored as REPL variable, not in LLM context)") + history: dspy.History = dspy.InputField( + desc=( + "Prior chat turns stored as a REPL variable (each message has keys " + "user_message and response). Inspect with Python for full conversation " + "continuity; the prompt only carries a short recency hint." + ) + ) answer: str = dspy.OutputField(desc="Final answer (call SUBMIT(answer=...) in REPL)") diff --git a/src/fleet_rlm/runtime/events.py b/src/fleet_rlm/runtime/events.py new file mode 100644 index 000000000..0dd451076 --- /dev/null +++ b/src/fleet_rlm/runtime/events.py @@ -0,0 +1,235 @@ +"""Canonical runtime event model — single source of truth for all streaming events. + +All structured event data is carried forward from the point of construction; +no re-parsing of rendered text downstream. + +Usage:: + + from fleet_rlm.runtime.events import RuntimeEvent, RuntimeEventKind + + event = RuntimeEvent( + kind=RuntimeEventKind.TOOL_CALL, + text="Calling repl_execute(...)", + tool=RuntimeToolInfo(tool_name="repl_execute", tool_args={"code": "..."}), + ) +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from enum import Enum +from typing import Any + +from pydantic import BaseModel, Field + +EVENT_SCHEMA_VERSION: int = 3 + + +class RuntimeEventKind(str, Enum): + """All event kinds emitted by the runtime streaming pipeline.""" + + STATUS = "status" + TEXT = "text" + REASONING = "reasoning" + TOOL_CALL = "tool_call" + TOOL_RESULT = "tool_result" + WARNING = "warning" + ERROR = "error" + DONE = "done" + CLARIFICATION = "clarification" + TURN_STARTED = "turn_started" + SANDBOX_EXEC = "sandbox_exec" + RLM_DELEGATE = "rlm_delegate" + + @classmethod + def terminal_kinds(cls) -> frozenset[RuntimeEventKind]: + return frozenset({cls.DONE, cls.ERROR}) + + def is_terminal(self) -> bool: + return self in self.terminal_kinds() + + +class RuntimeToolInfo(BaseModel): + """Structured tool call / result data — never re-parsed from display text.""" + + tool_name: str + tool_args: dict[str, Any] = Field(default_factory=dict) + tool_input: str | None = None + tool_output: Any | None = None + step_index: int | None = None + + +class RuntimeActorContext(BaseModel): + """Who is producing this event (root agent, delegate, sub-agent).""" + + actor_kind: str | None = None + actor_id: str | None = None + parent_id: str | None = None + depth: int | None = None + max_depth: int | None = None + + +class RuntimeEventContext(BaseModel): + """Stable runtime environment context attached to backend-emitted events. + + This is the single canonical definition consumed by both the runtime event + factories and the API projection layer (``api/events/event_adapter.py``). + """ + + runtime_mode: str | None = None + execution_mode: str | None = None + execution_profile: str | None = None + sandbox_id: str | None = None + child_sandbox_id: str | None = None + volume_name: str | None = None + workspace_path: str | None = None + repo_url: str | None = None + repo_ref: str | None = None + document_path: str | None = None + depth: int | None = None + max_depth: int | None = None + actor_kind: str | None = None + actor_id: str | None = None + parent_id: str | None = None + lane_key: str | None = None + llm_call_budget: int | None = None + + +class RuntimeEvent(BaseModel): + """Canonical runtime event — one structured object, no downstream re-parsing. + + Satisfies :class:`~fleet_rlm.api.runtime_services.chat_runtime.StreamEventLike` + structurally (``kind``, ``text``, ``payload``, ``timestamp`` are all present). + """ + + kind: RuntimeEventKind + text: str = "" + payload: dict[str, Any] = Field(default_factory=dict) + tool: RuntimeToolInfo | None = None + actor: RuntimeActorContext | None = None + context: RuntimeEventContext | None = None + timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + schema_version: int = EVENT_SCHEMA_VERSION + + @classmethod + def tool_call( + cls, + *, + tool_name: str, + tool_args: dict[str, Any], + step_index: int | None = None, + actor: RuntimeActorContext | None = None, + context: RuntimeEventContext | None = None, + ) -> RuntimeEvent: + """Factory for a structured tool-call event.""" + tool_input = f"{tool_name}({tool_args})" + return cls( + kind=RuntimeEventKind.TOOL_CALL, + text=f"Calling tool: {tool_input}", + payload={ + "tool_name": tool_name, + "tool_input": str(tool_args), + "tool_args": tool_args, + "step_index": step_index, + }, + tool=RuntimeToolInfo( + tool_name=tool_name, + tool_args=tool_args, + tool_input=str(tool_args), + step_index=step_index, + ), + actor=actor, + context=context, + ) + + @classmethod + def tool_result( + cls, + *, + tool_name: str, + observation: Any, + step_index: int | None = None, + actor: RuntimeActorContext | None = None, + context: RuntimeEventContext | None = None, + ) -> RuntimeEvent: + """Factory for a structured tool-result event.""" + return cls( + kind=RuntimeEventKind.TOOL_RESULT, + text=f"Tool result: {observation}", + payload={ + "tool_name": tool_name, + "tool_output": str(observation), + "step_index": step_index, + }, + tool=RuntimeToolInfo( + tool_name=tool_name, + tool_output=observation, + step_index=step_index, + ), + actor=actor, + context=context, + ) + + @classmethod + def clarification( + cls, + *, + message_id: str, + question: str | None, + step_label: str = "Clarification needed", + options: list[Any] | None = None, + actor: RuntimeActorContext | None = None, + ) -> RuntimeEvent: + """Factory for a clarification event.""" + return cls( + kind=RuntimeEventKind.CLARIFICATION, + text=str(question or "Please clarify your intent."), + payload={ + "message_id": message_id, + "question": question, + "step_label": step_label, + "options": options or [], + }, + actor=actor, + ) + + @classmethod + def status( + cls, + text: str, + *, + payload: dict[str, Any] | None = None, + actor: RuntimeActorContext | None = None, + context: RuntimeEventContext | None = None, + ) -> RuntimeEvent: + return cls( + kind=RuntimeEventKind.STATUS, + text=text, + payload=payload or {}, + actor=actor, + context=context, + ) + + @classmethod + def reasoning( + cls, + text: str, + *, + actor: RuntimeActorContext | None = None, + ) -> RuntimeEvent: + return cls( + kind=RuntimeEventKind.REASONING, + text=text, + payload={"phase": "reasoning"}, + actor=actor, + ) + + +__all__ = [ + "EVENT_SCHEMA_VERSION", + "RuntimeEvent", + "RuntimeEventKind", + "RuntimeToolInfo", + "RuntimeActorContext", + "RuntimeEventContext", +] diff --git a/src/fleet_rlm/runtime/execution/llm_query.py b/src/fleet_rlm/runtime/execution/llm_query.py index 2b875d16d..8f1212573 100644 --- a/src/fleet_rlm/runtime/execution/llm_query.py +++ b/src/fleet_rlm/runtime/execution/llm_query.py @@ -9,6 +9,7 @@ import contextvars import logging +import re import threading from concurrent.futures import ( ThreadPoolExecutor, @@ -34,6 +35,81 @@ _LLM_BATCH_EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="llm_batch") _SUB_RLM_BATCH_EXECUTOR = ThreadPoolExecutor(max_workers=4, thread_name_prefix="sub_rlm_batch") +# Phase 7: child conversation snapshot configuration +_CHILD_HISTORY_MAX_TURNS = 2 +_CHILD_HISTORY_MAX_CHARS = 2000 + +# Simple redaction patterns for child history snapshots +_SENSITIVE_PATTERNS = ( + (re.compile(r"sk-[A-Za-z0-9_-]{8,}"), "sk-***REDACTED***"), + (re.compile(r"(Authorization\s*:\s*Bearer\s+)[^\s]+", re.IGNORECASE), r"\1***REDACTED***"), + ( + re.compile( + r"((?:api[_-]?key|token|secret|password)\s*[=:]\s*)(?:\"[^\"]*\"|'[^']*'|[^\s,}\]]+)", re.IGNORECASE + ), + r"\1***REDACTED***", + ), +) + + +def _redact_sensitive(text: str) -> str: + """Replace API keys and other sensitive tokens with redaction markers.""" + redacted = text + for pattern, replacement in _SENSITIVE_PATTERNS: + redacted = pattern.sub(replacement, redacted) + return redacted + + +def _build_child_history_snapshot(interpreter: Any) -> str: + """Build a bounded, redacted conversation snapshot for recursive children. + + Phase 7: Children receive a fresh REPL (per reference) but need explicit + conversation continuity. This function extracts the last N turns from the + parent runtime's history, redacts sensitive values, and bounds the size. + + Args: + interpreter: The parent interpreter with a runtime reference. + + Returns: + A bounded, redacted conversation snapshot string. + """ + runtime = getattr(interpreter, "runtime", None) + if runtime is None: + return "" + + history = getattr(runtime, "history", None) + if history is None: + return "" + + messages = list(getattr(history, "messages", []) or []) + if not messages: + return "" + + # Take the last N turns + recent_messages = messages[-_CHILD_HISTORY_MAX_TURNS:] if len(messages) > _CHILD_HISTORY_MAX_TURNS else messages + + # Format each turn + turn_parts = [] + for msg in recent_messages: + if isinstance(msg, dict): + user_msg = str(msg.get("user_message", "")) + response = str(msg.get("response", "")) + if user_msg or response: + turn_parts.append(f"User: {user_msg}\nAssistant: {response}") + + snapshot = "\n\n".join(turn_parts) + if not snapshot: + return "" + + # Redact sensitive values + snapshot = _redact_sensitive(snapshot) + + # Truncate to max chars + if len(snapshot) > _CHILD_HISTORY_MAX_CHARS: + snapshot = snapshot[:_CHILD_HISTORY_MAX_CHARS] + "...[truncated]" + + return snapshot + class LLMQueryMixin: """Mixin providing LLM query tools for recursive LLM calls. @@ -245,22 +321,32 @@ def sub_rlm(self, prompt: str, context: str = "") -> str: Each call spawns a child dspy.RLM with its own REPL, sharing the parent's LLM budget. + Phase 7: When max recursion depth is reached, this falls back to + a single LLM call (llm_query) instead of raising an error. This + ensures graceful degradation when the recursion ceiling is hit. + Args: prompt: Task for the child RLM to solve. context: Optional supporting context string. Returns: - The child RLM's answer as a string. + The child RLM's answer as a string (or llm_query fallback at max depth). Raises: - RuntimeError: If recursion depth exceeded or LLM budget exhausted. + RuntimeError: If LLM budget exhausted. """ if not prompt: raise ValueError("sub_rlm prompt cannot be empty") if self._sub_rlm_depth >= self._sub_rlm_max_depth: - raise RuntimeError( - f"sub_rlm max recursion depth ({self._sub_rlm_max_depth}) reached. Cannot recurse further." + # Phase 7: fallback to llm_query at max depth instead of raising + logger.info( + "sub_rlm max recursion depth (%s) reached; falling back to llm_query", + self._sub_rlm_max_depth, ) + full_prompt = prompt + if context: + full_prompt = f"{context}\n\n{prompt}" + return self.llm_query(full_prompt) return self._execute_sub_rlm(prompt, context) def sub_rlm_batched(self, prompts: list[str], context: str = "") -> list[str]: @@ -269,6 +355,9 @@ def sub_rlm_batched(self, prompts: list[str], context: str = "") -> list[str]: Equivalent to ``[sub_rlm(p, context) for p in prompts]`` but runs concurrently using a thread pool. + Phase 7: When max recursion depth is reached, this falls back to + llm_query_batched instead of raising an error. + Args: prompts: List of task prompts for child RLMs. context: Shared context string for all children. @@ -277,12 +366,18 @@ def sub_rlm_batched(self, prompts: list[str], context: str = "") -> list[str]: List of answer strings, one per prompt (same order). Raises: - RuntimeError: On depth/budget violations or child failures. + RuntimeError: On budget violations or child failures. """ if not prompts: return [] if self._sub_rlm_depth >= self._sub_rlm_max_depth: - raise RuntimeError(f"sub_rlm max recursion depth ({self._sub_rlm_max_depth}) reached.") + # Phase 7: fallback to llm_query_batched at max depth + logger.info( + "sub_rlm_batched max recursion depth (%s) reached; falling back to llm_query_batched", + self._sub_rlm_max_depth, + ) + full_prompts = [f"{context}\n\n{p}" if context else p for p in prompts] + return self.llm_query_batched(full_prompts) leases = self._sub_rlm_budget_leases(len(prompts)) results: dict[int, str] = {} @@ -333,6 +428,12 @@ def _execute_sub_rlm( self._install_child_budget_lease(child, child_budget) max_iterations = max(1, min(getattr(self, "rlm_max_iterations", 30), child_budget)) + # Phase 7: build bounded conversation snapshot for child context + history_snapshot = _build_child_history_snapshot(self) + full_context = context + if history_snapshot: + full_context = f"{history_snapshot}\n\n{context}" if context else history_snapshot + child_module = build_recursive_subquery_rlm( interpreter=child, max_iterations=max_iterations, @@ -348,7 +449,7 @@ def _execute_sub_rlm( sandbox_id = getattr(session, "sandbox_id", None) if isinstance(metadata, dict) and sandbox_id: metadata.setdefault("child_sandbox_id", sandbox_id) - prediction = child_module(prompt=prompt, context=context or "") + prediction = child_module(prompt=prompt, context=full_context or "") return _validated_child_answer(prediction) except Exception as exc: logger.warning("sub_rlm child failed: %s", exc, exc_info=True) diff --git a/src/fleet_rlm/runtime/execution/streaming_events.py b/src/fleet_rlm/runtime/execution/streaming_events.py index e41f973d6..2fbbcd93b 100644 --- a/src/fleet_rlm/runtime/execution/streaming_events.py +++ b/src/fleet_rlm/runtime/execution/streaming_events.py @@ -1,26 +1,23 @@ -"""Event construction, status parsing, citation handling, and payload building +"""Citation handling, trajectory normalisation, and terminal-event helpers for the RLM ReAct chat agent streaming pipeline. -Event construction logic kept separate from stream orchestration in -:mod:`fleet_rlm.api.routers.ws.stream`. +Event *construction* lives in :mod:`fleet_rlm.runtime.events`. +This module retains only final-payload assembly (citations/attachments/sources) +and the DSPy ``ReActStatusProvider`` status hook. """ from __future__ import annotations import json -import re -from typing import Any, Literal +from typing import Any from urllib.parse import urlparse import dspy from dspy.streaming.messages import StatusMessageProvider -from fleet_rlm.runtime.schemas import StreamEvent +from fleet_rlm.runtime.events import EVENT_SCHEMA_VERSION from fleet_rlm.utils.preview import head_tail_preview -# Pre-compiled regexes for hot-path status parsing -_CALLING_TOOL_RE = re.compile(r"^Calling tool:\s*(.+)$") - # Soft content cap for trajectory step outputs crossing the websocket boundary. # Individual steps can carry multi-KB observations (grep hits, long file reads) # that bloat chat payloads without improving UX. See DSPy 3.2.0 PR #9282. @@ -42,83 +39,41 @@ def is_terminal_stream_event_kind(kind: str) -> bool: # ═══════════════════════════════════════════════════════════════════════ -# Status parsing and tool/HITL event helpers (was streaming_status.py) +# DSPy status hook (structured data embedded in RuntimeEvent at build time) # ═══════════════════════════════════════════════════════════════════════ -ToolEventKind = Literal["tool_call"] - - -def parse_tool_call_status(message: str) -> str | None: - match = _CALLING_TOOL_RE.match(message.strip()) - if not match: - return None - return f"tool call: {match.group(1).strip()}" - - -def parse_tool_call_payload(message: str) -> dict[str, Any] | None: - match = _CALLING_TOOL_RE.match(message.strip()) - if not match: - return None - - raw_call = match.group(1).strip() - tool_name = raw_call.split("(", 1)[0].strip() if raw_call else "" - args_snippet = "" - if "(" in raw_call: - args_snippet = raw_call.split("(", 1)[1].rsplit(")", 1)[0].strip() - - payload: dict[str, Any] = {"raw_status": message, "raw_call": raw_call} - if tool_name: - payload["tool_name"] = tool_name - if args_snippet: - payload["tool_args"] = args_snippet - payload["tool_input"] = args_snippet - return payload - - -def parse_tool_result_status(message: str) -> str | None: - stripped = message.strip() - if stripped == "Tool finished.": - return "tool result: finished" - if stripped.startswith("Tool result:"): - return "tool result: completed" - return None - - -def parse_tool_result_payload(message: str, *, tool_name: str | None) -> dict[str, Any] | None: - stripped = message.strip() - if stripped != "Tool finished." and not stripped.startswith("Tool result:"): - return None - - payload: dict[str, Any] = {"raw_status": message} - if tool_name: - payload["tool_name"] = tool_name - if stripped.startswith("Tool result:"): - result_text = stripped.removeprefix("Tool result:").strip() - if result_text: - payload["tool_output"] = result_text - return payload - class ReActStatusProvider(StatusMessageProvider): - """Concise status messaging for streamed ReAct sessions.""" + """Concise status hook for streamed ReAct sessions. - def tool_start_status_message(self, instance: Any, inputs: dict[str, Any]): + Returns human-readable status strings required by the DSPy + ``StatusMessageProvider`` interface. Structured tool/actor data + is carried in :class:`~fleet_rlm.runtime.events.RuntimeEvent` objects + built by the event factories in :mod:`fleet_rlm.runtime.events`. + """ + + def tool_start_status_message(self, instance: Any, inputs: dict[str, Any]) -> str: return f"Calling tool: {instance.name}" - def tool_end_status_message(self, outputs: Any): + def tool_end_status_message(self, outputs: Any) -> str | None: return "Tool finished." - def module_start_status_message(self, instance: Any, inputs: dict[str, Any]): + def module_start_status_message(self, instance: Any, inputs: dict[str, Any]) -> str | None: return f"Running module: {instance.__class__.__name__}" - def module_end_status_message(self, outputs: Any): + def module_end_status_message(self, outputs: Any) -> str | None: return None +# ═══════════════════════════════════════════════════════════════════════ +# HITL event helper +# ═══════════════════════════════════════════════════════════════════════ + + def try_parse_hitl_request( tool_name: str | None, payload: dict[str, Any], -) -> StreamEvent | None: +) -> dict[str, Any] | None: if not tool_name: return None @@ -138,40 +93,34 @@ def try_parse_hitl_request( if data and isinstance(data, dict): questions = data.get("questions", []) if questions: - return StreamEvent( - kind="status", - text="The agent has some questions for you.", - payload={ + return { + "kind": "status", + "text": "The agent has some questions for you.", + "payload": { "options": questions, "source": "clarification_questions", "requires_response": True, }, - ) + } if tool_name == "memory_action_intent": if data and isinstance(data, dict) and data.get("requires_confirmation"): - return StreamEvent( - kind="status", - text="This memory action requires confirmation.", - payload={ + return { + "kind": "status", + "text": "This memory action requires confirmation.", + "payload": { "action": data.get("intent"), "source": "memory_action_intent", "requires_response": True, }, - ) + } return None -def classify_tool_event_kind(tool_name: str | None) -> ToolEventKind: - return "tool_call" - - # ═══════════════════════════════════════════════════════════════════════ # Citation and final-payload helpers (was streaming_citations.py) # ═══════════════════════════════════════════════════════════════════════ - -STREAM_EVENT_SCHEMA_VERSION = 2 _ALLOWED_EXTERNAL_URL_SCHEMES = frozenset({"http", "https"}) @@ -625,7 +574,7 @@ def _build_final_payload( ] payload: dict[str, Any] = { - "schema_version": STREAM_EVENT_SCHEMA_VERSION, + "schema_version": EVENT_SCHEMA_VERSION, "trajectory": trajectory, "history_turns": history_turns, "guardrail_warnings": guardrail_warnings, diff --git a/src/fleet_rlm/runtime/modules/escalating.py b/src/fleet_rlm/runtime/modules/escalating.py index 9e7dc5584..5ba986254 100644 --- a/src/fleet_rlm/runtime/modules/escalating.py +++ b/src/fleet_rlm/runtime/modules/escalating.py @@ -417,13 +417,18 @@ def _run_rlm( call_kwargs: dict[str, Any] = { "task": user_request, "prompt": prompt, + # Phase 7: expose structured history as a native REPL variable on + # the heavy RLM path (both RLMVariableSignature and + # RLMLargeDocSignature declare a ``history`` input field), so the + # model can inspect full prior turns with code rather than relying + # solely on the flattened recency snippet embedded in ``prompt``. + "history": history, } if url_document_mode: fetched = self._fetch_url_document(source_url=source_url) call_kwargs["source_url"] = fetched.source_url call_kwargs["document_text"] = fetched.document_text call_kwargs["source_metadata"] = fetched.source_metadata - call_kwargs["history"] = history try: result = rlm(**call_kwargs) _prediction_set(result, "selected_skills", selected_skills or []) diff --git a/src/fleet_rlm/runtime/tools/rlm_delegate.py b/src/fleet_rlm/runtime/tools/rlm_delegate.py index 90545230c..d603d9ac2 100644 --- a/src/fleet_rlm/runtime/tools/rlm_delegate.py +++ b/src/fleet_rlm/runtime/tools/rlm_delegate.py @@ -222,30 +222,6 @@ def _run_delegate_child( llm_budget: int, ) -> dict[str, Any]: """Build, run, validate, and clean up one delegated child RLM.""" - # Fast-path: solve sentiment-classification tasks locally when context - # is structured JSON reviews. These tasks have deterministic rules - # (contains positive/negative sentiment words) and can be computed - # directly without the full child sandbox + RLM round-trip. - local_answer = _try_solve_classification_locally(query, context) - if local_answer is not None: - logger.info("delegate_to_rlm: classification task solved locally: %s", local_answer) - return {"status": "ok", "answer": local_answer} - - # Fast-path: solve log-line extraction tasks locally via regex scan. - # These tasks ask "how many log lines have level '' AND service ''?" - # on plain-text log data — pure string matching, no LLM needed. - extraction_answer = _try_solve_extraction_locally(query, context) - if extraction_answer is not None: - return {"status": "ok", "answer": extraction_answer} - - # Fast-path: solve category-counting tasks locally via JSON scan. - # These tasks ask "How many items are in the '' category?" - # on JSON product lists — pure field-match counting, no LLM needed. - counting_answer = _try_solve_counting_locally(query, context) - if counting_answer is not None: - logger.info("delegate_to_rlm: counting task solved locally: %s", counting_answer) - return {"status": "ok", "answer": counting_answer} - child = None started_at = time.time() try: @@ -285,8 +261,7 @@ def _run_delegate_child( "delegate_to_rlm: running child RLM with isolation=%s", getattr(child, "child_isolation_metadata", {}), ) - effective_query = _augment_classification_query(query) - prediction = _run_with_delegate_adapter(rlm, interpreter, prompt=effective_query, context=resolved_context) + prediction = _run_with_delegate_adapter(rlm, interpreter, prompt=query, context=resolved_context) raw_answer = getattr(prediction, "answer", None) answer = "" if raw_answer is None else str(raw_answer) @@ -437,176 +412,6 @@ def _is_broker_failure(value: Any) -> bool: return contains_marker(value, _BROKER_ERROR_MARKER) -# Sentiment classification word sets. The OOLONG benchmark task says -# "contains words LIKE ..." giving 6 examples per polarity; the ground -# truth uses these extended sets which include all synonyms present in the -# generated review data. -_POSITIVE_SENTIMENT_WORDS: frozenset[str] = frozenset( - { - "excellent", - "great", - "wonderful", - "fantastic", - "love", - "amazing", - "delighted", - "impressed", - "outstanding", - "perfect", - "superb", - "thrilled", - } -) -_NEGATIVE_SENTIMENT_WORDS: frozenset[str] = frozenset( - { - "terrible", - "awful", - "horrible", - "worst", - "hate", - "disappointing", - "broken", - "frustrated", - "angry", - "useless", - "regret", - "defective", - } -) - -# Regex to detect classification-style queries asking for sentiment counts -_CLASSIFICATION_QUERY_RE = re.compile( - r"classify each review as positive.*negative.*neutral", - re.IGNORECASE, -) - -# Regex to detect log-line extraction queries: "how many log lines have level '' -# AND service ''?" This pattern covers all 10 OOLONG extraction tasks. -_EXTRACTION_QUERY_RE = re.compile( - r"how many log lines have level '(\w[\w-]*)' and service '([\w-]+)'", - re.IGNORECASE, -) - -# Regex to detect counting-style queries: "how many items are in the '' category" -_COUNTING_QUERY_RE = re.compile( - r"how many items are in the '(\w+)' category", - re.IGNORECASE, -) - - -def _try_solve_classification_locally(query: str, context: str) -> str | None: - """Attempt to solve a sentiment-classification task via direct computation. - - For classification tasks where the context is a JSON list of reviews and - the query asks to classify each as positive/negative/neutral based on - sentiment words, we compute the answer directly by checking word presence. - - Returns the formatted "positive=N negative=M neutral=K" string if - solvable, None otherwise. - """ - if not _CLASSIFICATION_QUERY_RE.search(query): - return None - - import json as _json - - # Parse context as JSON list of review objects - try: - data = _json.loads(context.strip()) - except (ValueError, TypeError): - return None - - if not isinstance(data, list) or not data: - return None - - # Verify structure: items should have 'text' field - if not isinstance(data[0], dict) or "text" not in data[0]: - return None - - pos_count = 0 - neg_count = 0 - neu_count = 0 - for item in data: - if not isinstance(item, dict): - continue - text_lower = str(item.get("text", "")).lower() - words = set(re.findall(r"\w+", text_lower)) - has_positive = bool(words & _POSITIVE_SENTIMENT_WORDS) - has_negative = bool(words & _NEGATIVE_SENTIMENT_WORDS) - if has_positive and not has_negative: - pos_count += 1 - elif has_negative and not has_positive: - neg_count += 1 - else: - neu_count += 1 - - return f"positive={pos_count} negative={neg_count} neutral={neu_count}" - - -def _try_solve_extraction_locally(query: str, context: str) -> str | None: - """Attempt to solve a log-line extraction task via direct regex scan. - - For extraction tasks where the context is plain-text log data (one line per - log entry in 'timestamp [LEVEL] service: message' format) and the query asks - "How many log lines have level '' AND service ''?", we count - matching lines directly without any LLM round-trip. - - Returns the count as a plain string (e.g. "4"), or None if the query does - not match the expected pattern. - """ - match = _EXTRACTION_QUERY_RE.search(query) - if match is None: - return None - - level = match.group(1) - service = match.group(2) - - level_token = f"[{level}]".lower() - service_token = f" {service}:".lower() - count = sum(1 for line in context.splitlines() if level_token in line.lower() and service_token in line.lower()) - logger.info( - "delegate_to_rlm: extraction task solved locally: level=%s service=%s count=%d", - level, - service, - count, - ) - return str(count) - - -def _try_solve_counting_locally(query: str, context: str) -> str | None: - """Attempt to solve a category-counting task via direct computation. - - For counting tasks where the context is a JSON list of product items with - a 'category' field and the query asks "How many items are in the '' - category?", we compute the count directly without spinning up a child - sandbox or LLM. - - Returns str(count) if solvable, None otherwise. - """ - match = _COUNTING_QUERY_RE.search(query) - if not match: - return None - - import json as _json - - category = match.group(1) - - # Parse context as JSON list of product objects - try: - data = _json.loads(context.strip()) - except (ValueError, TypeError): - return None - - if not isinstance(data, list) or not data: - return None - - # Verify structure: items should have 'category' field - if not isinstance(data[0], dict) or "category" not in data[0]: - return None - - count = sum(1 for item in data if isinstance(item, dict) and item.get("category", "").lower() == category.lower()) - return str(count) - - def _resolve_delegate_sub_lm(child: Any, parent: Any) -> Any | None: """Resolve the sub_lm for a delegate child RLM. @@ -667,49 +472,6 @@ def _ensure_dspy_lm_configured(sub_lm: Any) -> None: logger.info("delegate_to_rlm: configured dspy.settings.lm from resolved sub_lm") -def _augment_classification_query(query: str) -> str: - """Reinforce output format for classification-style queries. - - Classification tasks expect a specific key=value format (e.g. - "positive=86 negative=66 neutral=57"). When the query already - specifies such a format, append an explicit instruction to the child - RLM ensuring it returns ONLY the formatted string via SUBMIT(). - - This does NOT alter extraction queries (single-number answers) or - queries that don't mention a key=value output pattern. - """ - # Detect classification pattern: query mentions "key=N" format with - # multiple categories separated by spaces - _KV_PATTERN = re.compile( - r"\b(\w+=\s*[A-Z])\b.*\b(\w+=\s*[A-Z])\b", - re.IGNORECASE, - ) - # More specific: looks for patterns like "positive=N negative=M neutral=K" - # or "category1=N category2=M" in the query's format instruction - _MULTI_KV_FORMAT = re.compile( - r"(\w+)=([A-Z_]\w*)\s+(\w+)=([A-Z_]\w*)", - re.IGNORECASE, - ) - if not _MULTI_KV_FORMAT.search(query): - return query - - # Extract the category names from the format pattern - format_match = _MULTI_KV_FORMAT.findall(query) - if not format_match: - return query - - # Build format reinforcement suffix - suffix = ( - "\n\nCRITICAL OUTPUT FORMAT: Your SUBMIT(answer=...) must contain ONLY " - "the counts in the exact format shown above (e.g. key1=N key2=M key3=K). " - "Do NOT include any explanation, prose, or extra text in the answer. " - "Do NOT wrap in quotes or add punctuation beyond the key=value pairs. " - "The answer string must match the pattern: word=number word=number ... " - "with single spaces between pairs." - ) - return query + suffix - - def _resolve_delegate_context( *, child: Any, diff --git a/tests/contracts/test_golden_payloads.py b/tests/contracts/test_golden_payloads.py new file mode 100644 index 000000000..88716f3e5 --- /dev/null +++ b/tests/contracts/test_golden_payloads.py @@ -0,0 +1,193 @@ +"""Golden-payload tests for Phase 0 safety net. + +Captures every event kind emitted on both websockets for a representative turn: +- chat + tool + repl + delegate + done/error + cancelled + +This serves as a regression oracle before any event contract refactoring. +""" + +from __future__ import annotations + +import json +from pathlib import Path +from typing import Any + +import pytest + +# Event kinds we expect to capture during a representative turn +EXPECTED_EVENT_KINDS = { + "turn_started", + "status", + "reasoning", + "tool_call", + "tool_result", + "sandbox_exec", + "rlm_delegate", + "text", + "done", + "error", + "warning", +} + +# Output directory for golden payloads +GOLDEN_PAYLOADS_DIR = Path(__file__).parent / "golden_payloads" + + +def _collect_all_events(websocket: Any, timeout_seconds: int = 30) -> list[dict[str, Any]]: + """Collect all events from a websocket until done/error or timeout.""" + events = [] + import time + + start = time.time() + while time.time() - start < timeout_seconds: + try: + payload = websocket.receive_json(timeout=1.0) + events.append(payload) + if payload.get("kind") in ("done", "error"): + break + except Exception: + # Timeout or disconnect - stop collecting + break + return events + + +@pytest.mark.skipif( + not GOLDEN_PAYLOADS_DIR.exists(), + reason="Golden payloads directory exists - run once to capture baseline", +) +def test_capture_chat_websocket_golden_payloads(no_db_client, auth_headers: dict[str, str]) -> None: + """Capture all events from /api/v1/ws/execution for a representative turn.""" + # Setup: ensure LM is configured + no_db_client.app.state.lm_deps.planner_lm = object() + + GOLDEN_PAYLOADS_DIR.mkdir(exist_ok=True) + + with no_db_client.websocket_connect("/api/v1/ws/execution") as websocket: + # Send a simple message that should trigger multiple event kinds + websocket.send_json( + { + "type": "message", + "content": "What is 2 + 2?", + "session_id": "golden-test-session", + } + ) + + # Collect all events + events = _collect_all_events(websocket) + + # Store as golden payload + output_file = GOLDEN_PAYLOADS_DIR / "chat_websocket_events.json" + with open(output_file, "w") as f: + json.dump(events, f, indent=2) + + # Verify we captured expected event kinds + captured_kinds = {event.get("kind") for event in events} + assert captured_kinds & EXPECTED_EVENT_KINDS, f"Missing expected event kinds. Captured: {captured_kinds}" + + +@pytest.mark.skipif( + not GOLDEN_PAYLOADS_DIR.exists(), + reason="Golden payloads directory exists - run once to capture baseline", +) +def test_capture_passive_events_websocket_golden_payloads(no_db_client, auth_headers: dict[str, str]) -> None: + """Capture all events from /api/v1/ws/execution/events for a representative turn.""" + # Setup: ensure LM is configured + no_db_client.app.state.lm_deps.planner_lm = object() + + GOLDEN_PAYLOADS_DIR.mkdir(exist_ok=True) + + # First, run a turn on the chat websocket to generate events + with no_db_client.websocket_connect("/api/v1/ws/execution") as chat_ws: + chat_ws.send_json( + { + "type": "message", + "content": "What is 2 + 2?", + "session_id": "golden-test-session-passive", + } + ) + _collect_all_events(chat_ws) + + # Then connect to passive events stream and capture + with no_db_client.websocket_connect( + "/api/v1/ws/execution/events?session_id=golden-test-session-passive" + ) as passive_ws: + events = _collect_all_events(passive_ws) + + # Store as golden payload + output_file = GOLDEN_PAYLOADS_DIR / "passive_events_websocket_events.json" + with open(output_file, "w") as f: + json.dump(events, f, indent=2) + + # Verify we captured execution events + captured_kinds = {event.get("kind") for event in events} + assert captured_kinds & {"execution_started", "execution_step", "execution_completed"} + + +@pytest.mark.skipif( + GOLDEN_PAYLOADS_DIR.exists(), + reason="Golden payloads already captured - use for regression testing", +) +def test_regression_chat_websocket_events(no_db_client, auth_headers: dict[str, str]) -> None: + """Regression test: compare current events against golden payload.""" + no_db_client.app.state.lm_deps.planner_lm = object() + + golden_file = GOLDEN_PAYLOADS_DIR / "chat_websocket_events.json" + assert golden_file.exists(), "Run golden payload capture first" + + with open(golden_file) as f: + golden_events = json.load(f) + + with no_db_client.websocket_connect("/api/v1/ws/execution") as websocket: + websocket.send_json( + { + "type": "message", + "content": "What is 2 + 2?", + "session_id": "regression-test-session", + } + ) + + current_events = _collect_all_events(websocket) + + # Compare event kinds (structure may change, but kinds should match) + golden_kinds = {event.get("kind") for event in golden_events} + current_kinds = {event.get("kind") for event in current_events} + + assert golden_kinds == current_kinds, f"Event kinds changed. Golden: {golden_kinds}, Current: {current_kinds}" + + +@pytest.mark.skipif( + GOLDEN_PAYLOADS_DIR.exists(), + reason="Golden payloads already captured - use for regression testing", +) +def test_regression_passive_events_websocket_events(no_db_client, auth_headers: dict[str, str]) -> None: + """Regression test: compare passive events against golden payload.""" + no_db_client.app.state.lm_deps.planner_lm = object() + + golden_file = GOLDEN_PAYLOADS_DIR / "passive_events_websocket_events.json" + assert golden_file.exists(), "Run golden payload capture first" + + with open(golden_file) as f: + golden_events = json.load(f) + + # Run a turn + with no_db_client.websocket_connect("/api/v1/ws/execution") as chat_ws: + chat_ws.send_json( + { + "type": "message", + "content": "What is 2 + 2?", + "session_id": "regression-test-session-passive", + } + ) + _collect_all_events(chat_ws) + + # Capture passive events + with no_db_client.websocket_connect( + "/api/v1/ws/execution/events?session_id=regression-test-session-passive" + ) as passive_ws: + current_events = _collect_all_events(passive_ws) + + # Compare event kinds + golden_kinds = {event.get("kind") for event in golden_events} + current_kinds = {event.get("kind") for event in current_events} + + assert golden_kinds == current_kinds, f"Event kinds changed. Golden: {golden_kinds}, Current: {current_kinds}" diff --git a/tests/unit/api/test_bootstrap.py b/tests/unit/api/test_bootstrap.py index 41dc2342e..023904f60 100644 --- a/tests/unit/api/test_bootstrap.py +++ b/tests/unit/api/test_bootstrap.py @@ -19,13 +19,13 @@ def test_build_server_state_creates_ready_compatible_state(clean_runtime_env): state = bootstrap_module.build_server_state(cfg) - assert state.config is cfg - assert isinstance(state.auth_provider, auth_module.DevAuthProvider) - assert state.sessions == {} - assert state.optional_service_status["planner_lm"] == "pending" + assert state.config_deps.config is cfg + assert isinstance(state.auth_deps.auth_provider, auth_module.DevAuthProvider) + assert state.session_cache_deps.sessions == {} + assert state.diagnostics_deps.optional_service_status["planner_lm"] == "pending" assert state.is_ready is False - state.planner_lm = object() + state.lm_deps.planner_lm = object() assert state.is_ready is True diff --git a/tests/unit/api/test_chat_persistence.py b/tests/unit/api/test_chat_persistence.py index ab6e71f49..c9cd61d80 100644 --- a/tests/unit/api/test_chat_persistence.py +++ b/tests/unit/api/test_chat_persistence.py @@ -10,7 +10,7 @@ @pytest.mark.asyncio async def test_load_manifest_from_volume_reads_current_conversation_path(monkeypatch: pytest.MonkeyPatch) -> None: - from fleet_rlm.api.runtime_services.chat_persistence import load_manifest_from_volume + from fleet_rlm.api.runtime_services.session_manifest import load_manifest_from_volume from fleet_rlm.integrations.daytona.interpreter import DaytonaInterpreter reads: list[str] = [] @@ -40,7 +40,7 @@ async def fake_get_session(self): @pytest.mark.asyncio async def test_save_manifest_to_volume_writes_phase_one_conversation_path(monkeypatch: pytest.MonkeyPatch) -> None: - from fleet_rlm.api.runtime_services.chat_persistence import save_manifest_to_volume + from fleet_rlm.api.runtime_services.session_manifest import save_manifest_to_volume from fleet_rlm.integrations.daytona.interpreter import DaytonaInterpreter writes: list[tuple[str, str]] = [] @@ -70,7 +70,7 @@ async def fake_get_session(self): async def test_manifest_volume_io_does_not_create_daytona_session_when_disallowed( monkeypatch: pytest.MonkeyPatch, ) -> None: - from fleet_rlm.api.runtime_services.chat_persistence import load_manifest_from_volume, save_manifest_to_volume + from fleet_rlm.api.runtime_services.session_manifest import load_manifest_from_volume, save_manifest_to_volume from fleet_rlm.integrations.daytona.interpreter import DaytonaInterpreter async def forbidden_get_session(self): @@ -108,7 +108,7 @@ async def test_persist_session_state_skips_volume_without_creating_cleanup_sessi monkeypatch: pytest.MonkeyPatch, ) -> None: from fleet_rlm.api.dependencies import SessionCacheDeps - from fleet_rlm.api.runtime_services.chat_persistence import persist_session_state + from fleet_rlm.api.runtime_services.session_persistence import persist_session_state from fleet_rlm.integrations.daytona.interpreter import DaytonaInterpreter async def forbidden_get_session(self): @@ -157,7 +157,7 @@ async def forbidden_execute(*args: object, **kwargs: object): async def test_manifest_volume_io_uses_existing_daytona_session_without_creating( monkeypatch: pytest.MonkeyPatch, ) -> None: - from fleet_rlm.api.runtime_services.chat_persistence import load_manifest_from_volume, save_manifest_to_volume + from fleet_rlm.api.runtime_services.session_manifest import load_manifest_from_volume, save_manifest_to_volume from fleet_rlm.integrations.daytona.interpreter import DaytonaInterpreter class FakeDaytonaSession: @@ -299,7 +299,7 @@ async def test_persist_session_state_releases_idle_daytona_session_after_save( monkeypatch: pytest.MonkeyPatch, ) -> None: from fleet_rlm.api.dependencies import SessionCacheDeps - from fleet_rlm.api.runtime_services.chat_persistence import persist_session_state + from fleet_rlm.api.runtime_services.session_persistence import persist_session_state class FakeSession: async def aread_file(self, path: str) -> str: @@ -339,7 +339,7 @@ async def fake_get_daytona_session(agent: Any, allow_create: bool = True) -> Any return session monkeypatch.setattr( - "fleet_rlm.api.runtime_services.chat_persistence._aget_daytona_session", + "fleet_rlm.api.runtime_services.session_manifest._aget_daytona_session", fake_get_daytona_session, ) @@ -366,7 +366,8 @@ async def test_persist_session_state_releases_idle_daytona_session_after_save_fa monkeypatch: pytest.MonkeyPatch, ) -> None: from fleet_rlm.api.dependencies import SessionCacheDeps - from fleet_rlm.api.runtime_services.chat_persistence import PersistenceRequiredError, persist_session_state + from fleet_rlm.api.runtime_services.session_persistence import persist_session_state + from fleet_rlm.api.runtime_services.stream_failures import PersistenceRequiredError class FakeSession: async def aread_file(self, path: str) -> str: @@ -402,7 +403,7 @@ async def fake_get_daytona_session(agent: Any, allow_create: bool = True) -> Any return session monkeypatch.setattr( - "fleet_rlm.api.runtime_services.chat_persistence._aget_daytona_session", + "fleet_rlm.api.runtime_services.session_manifest._aget_daytona_session", fake_get_daytona_session, ) @@ -429,7 +430,7 @@ async def fake_get_daytona_session(agent: Any, allow_create: bool = True) -> Any async def test_ensure_session_volume_layout_creates_scratchpad_and_workspace_link( monkeypatch: pytest.MonkeyPatch, ) -> None: - from fleet_rlm.api.runtime_services.chat_persistence import ensure_session_volume_layout + from fleet_rlm.api.runtime_services.session_manifest import ensure_session_volume_layout from fleet_rlm.integrations.daytona.interpreter import DaytonaInterpreter commands: list[str] = [] @@ -477,11 +478,11 @@ async def fake_load_manifest(*args: Any, **kwargs: Any) -> dict[str, Any]: return {"rev": 0, "state": {}} monkeypatch.setattr( - "fleet_rlm.api.runtime_services.chat_persistence.ensure_session_volume_layout", + "fleet_rlm.api.runtime_services.session_manifest.ensure_session_volume_layout", fake_layout, ) monkeypatch.setattr( - "fleet_rlm.api.runtime_services.chat_persistence.load_manifest_from_volume", + "fleet_rlm.api.runtime_services.session_manifest.load_manifest_from_volume", fake_load_manifest, ) diff --git a/tests/unit/api/test_dependencies.py b/tests/unit/api/test_dependencies.py index d4e0f5c6f..ebfc4761e 100644 --- a/tests/unit/api/test_dependencies.py +++ b/tests/unit/api/test_dependencies.py @@ -51,13 +51,13 @@ def test_compose_server_state_preserves_dependency_slices(clean_runtime_env): diagnostics_deps, ) - assert state.config is cfg - assert state.sessions == {"owner:abc:__default__": {"history": []}} - assert state.auth_provider is auth_deps.auth_provider - assert state.local_store is persistence_deps.local_store + assert state.config_deps.config is cfg + assert state.session_cache_deps.sessions == {"owner:abc:__default__": {"history": []}} + assert state.auth_deps.auth_provider is auth_deps.auth_provider + assert state.persistence_deps.local_store is persistence_deps.local_store assert state.is_ready is False - state.planner_lm = object() + state.lm_deps.planner_lm = object() assert state.is_ready is True diff --git a/tests/unit/api/test_events.py b/tests/unit/api/test_events.py index 48cf4216c..a06db0445 100644 --- a/tests/unit/api/test_events.py +++ b/tests/unit/api/test_events.py @@ -133,7 +133,7 @@ def test_summarize_code_for_event_returns_stable_preview(monkeypatch): def test_startup_status_projects_to_canonical_execution_started_frame(): persistence_module = importlib.import_module("fleet_rlm.api.runtime_services.chat_persistence") - stream_module = importlib.import_module("fleet_rlm.api.routers.ws.stream") + stream_module = importlib.import_module("fleet_rlm.api.routers.ws.stream_events") event = persistence_module.build_startup_status_event() frame = stream_module.build_stream_event_dict(event=event, payload=event.payload) @@ -161,7 +161,7 @@ def test_backend_status_projects_to_canonical_execution_step_frame(): def test_runtime_trace_metadata_counts_structured_rlm_trajectory(): - stream_module = importlib.import_module("fleet_rlm.api.routers.ws.stream") + stream_module = importlib.import_module("fleet_rlm.api.routers.ws.stream_summary") metadata = stream_module._runtime_trace_metadata( { diff --git a/tests/unit/runtime/test_escalating_module.py b/tests/unit/runtime/test_escalating_module.py index 90734881e..30e50b8fe 100644 --- a/tests/unit/runtime/test_escalating_module.py +++ b/tests/unit/runtime/test_escalating_module.py @@ -252,13 +252,9 @@ def test_url_document_analysis_passes_fetched_doc_as_rlm_variables( "source_type": "html", } assert "# DSPy docs\nRLM details" not in call_kwargs["prompt"] - assert call_kwargs["prompt"].startswith( - "Task:\nanalyze https://dspy.ai and provide documentation notes" - ) + assert call_kwargs["prompt"].startswith("Task:\nanalyze https://dspy.ai and provide documentation notes") assert "URL document variables" in call_kwargs["prompt"] - assert call_kwargs["prompt"].endswith( - "Repeat task:\nanalyze https://dspy.ai and provide documentation notes" - ) + assert call_kwargs["prompt"].endswith("Repeat task:\nanalyze https://dspy.ai and provide documentation notes") def test_tools_only_does_not_auto_route_url_to_rlm(self) -> None: module = _make_module() diff --git a/tests/unit/runtime/test_tools.py b/tests/unit/runtime/test_tools.py index ca0d0e3c9..36ef16e7e 100644 --- a/tests/unit/runtime/test_tools.py +++ b/tests/unit/runtime/test_tools.py @@ -124,22 +124,6 @@ def write(self, chunk: bytes) -> None: assert events == ["close:42", f"unlink:{tmp_path}"] -def test_log_extraction_fast_path_is_case_insensitive() -> None: - from fleet_rlm.runtime.tools.rlm_delegate import _try_solve_extraction_locally - - context = "\n".join( - [ - "2026-01-01T00:00:00Z [INFO] Api: first", - "2026-01-01T00:00:01Z [info] api: second", - "2026-01-01T00:00:02Z [ERROR] api: ignored", - ] - ) - - answer = _try_solve_extraction_locally("How many log lines have level 'info' AND service 'api'?", context) - - assert answer == "2" - - def test_bind_runtime_tools_binds_memory_tools_and_skips_interpreter_only_without_interpreter() -> None: from fleet_rlm.runtime.tools.binding import bind_runtime_tools from fleet_rlm.runtime.tools.rlm_delegate import delegate_to_rlm diff --git a/uv.lock b/uv.lock index 87b3a0cc3..2e07aca34 100644 --- a/uv.lock +++ b/uv.lock @@ -818,7 +818,7 @@ wheels = [ [[package]] name = "daytona" -version = "0.181.0" +version = "0.184.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiofiles" }, @@ -842,14 +842,14 @@ dependencies = [ { name = "urllib3" }, { name = "wsproto" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ea/65/7fccf3cb10f56e280c6dabbc60d95a2be9b8bfbadb60e81b23b37cb68a9e/daytona-0.181.0.tar.gz", hash = "sha256:e3ccc19fa953999491c566e777399759de3d3d0307f8623b07072f873d80dc1c", size = 141485, upload-time = "2026-05-25T12:08:35.149Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ac/6d/482ed38e868e5d6904961bf8c0317a94981bdd2b165f2611acb4f7094357/daytona-0.184.0.tar.gz", hash = "sha256:38a8de4a2daf34cb5940590db4d0f816ee8354d9919ca8ab1051df4a38f2107c", size = 141946, upload-time = "2026-06-03T14:46:56.231Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/de/bc/9df0fcfe90539b9dd5282f10a79dcad49a5a9ea52701063df75d342b2336/daytona-0.181.0-py3-none-any.whl", hash = "sha256:dd8de608b13e347f520b14f50096b2720472eb52400d61679463f868cb4b1b06", size = 172024, upload-time = "2026-05-25T12:08:33.229Z" }, + { url = "https://files.pythonhosted.org/packages/10/da/5ff417bbf0ca7d480ee88c310372fec3ab039a66878c27ecab97203d00c0/daytona-0.184.0-py3-none-any.whl", hash = "sha256:630df33c62aebae2ae34679eff79198806014567850c98dcd8d45accdcc5437a", size = 172586, upload-time = "2026-06-03T14:46:54.467Z" }, ] [[package]] name = "daytona-api-client" -version = "0.181.0" +version = "0.184.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, @@ -857,14 +857,14 @@ dependencies = [ { name = "typing-extensions" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/5e/749630746844995d18091e9ea6e754acaa6a48226055c7417a610f9ec7b6/daytona_api_client-0.181.0.tar.gz", hash = "sha256:1960a536dc3e0c0569a2c4d67804df29fd90c78458cd527889d5d203013bb872", size = 145263, upload-time = "2026-05-25T12:07:25.506Z" } +sdist = { url = "https://files.pythonhosted.org/packages/73/63/cdb3ea6dd35fd2091a7e1f988ce1d3194d9bec1accfa6dbbf63ad424e671/daytona_api_client-0.184.0.tar.gz", hash = "sha256:3cdb111cf2a21be13cc648e444162c094d45f80d1eeb638b671fc854be307c23", size = 148346, upload-time = "2026-06-03T14:46:09.888Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b6/32/a66885d9d5b93790558dbbada8c2eac96a7fc93016686704468035cbfde8/daytona_api_client-0.181.0-py3-none-any.whl", hash = "sha256:46a8119c76c2e14f704d4f019eff8aa71f21be2deb58a1aa3bc53485a3aa50f9", size = 402306, upload-time = "2026-05-25T12:07:23.742Z" }, + { url = "https://files.pythonhosted.org/packages/88/65/44993d1bb3cb1f5f294ac97bb3246346c14bf7ba3e79cdd4e92c7df0c6df/daytona_api_client-0.184.0-py3-none-any.whl", hash = "sha256:8fec5062d757fb533e82e9bf295fc1625ae294b4a154c24081bbd2bf30bcb5dd", size = 407600, upload-time = "2026-06-03T14:46:08.086Z" }, ] [[package]] name = "daytona-api-client-async" -version = "0.181.0" +version = "0.184.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -873,14 +873,14 @@ dependencies = [ { name = "python-dateutil" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/54/66/a08b5a5574bf63ff3f8ef81ab6cd838aa3e1cf54a61c484e7ae3000f449d/daytona_api_client_async-0.181.0.tar.gz", hash = "sha256:10ec1a623c5232203f202e0f0c9ef2e0720d4a90fba21c0acb28af28a3857fbc", size = 145706, upload-time = "2026-05-25T12:07:53.582Z" } +sdist = { url = "https://files.pythonhosted.org/packages/df/2f/caba484cbac693d269b417df73237f117882e2e28d1af4c3878e568bf148/daytona_api_client_async-0.184.0.tar.gz", hash = "sha256:a3cb0ebe5eec2c824b7848d4189a035a13cd84de1e0083c3f0c033f5445f3b71", size = 149131, upload-time = "2026-06-03T14:46:19.64Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/ea/17d2ce0b9bd8eb74f7e8782be89e5c0997cc21f54a72e1ea90220e6a79eb/daytona_api_client_async-0.181.0-py3-none-any.whl", hash = "sha256:0e4e20ca9ce3ac559545e2fd52084748df2485f19660e8b73b665795db03c428", size = 405527, upload-time = "2026-05-25T12:07:51.971Z" }, + { url = "https://files.pythonhosted.org/packages/8e/fc/c6e304db0bb382a0842cda913474ba42060acbf4f6cc91d8708f7997e337/daytona_api_client_async-0.184.0-py3-none-any.whl", hash = "sha256:19f268e6fb4d4190e7a4ea2ad3f4600bf14c25ab8f4997390d8d07d4a281de61", size = 410889, upload-time = "2026-06-03T14:46:18.051Z" }, ] [[package]] name = "daytona-toolbox-api-client" -version = "0.181.0" +version = "0.184.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, @@ -888,14 +888,14 @@ dependencies = [ { name = "typing-extensions" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6d/61/c50f1be20c05587ef31710d9895367d5ff1a30e586794623f55abbd14bac/daytona_toolbox_api_client-0.181.0.tar.gz", hash = "sha256:4a0f0512f58ec9e417db679761c149bf255d8a028ee116bde454e89299f11ccf", size = 77741, upload-time = "2026-05-25T12:07:45.186Z" } +sdist = { url = "https://files.pythonhosted.org/packages/26/8f/78b0c06e91065e573ce1f2f4c6b7941066be939de46128a7e97a30b2e4a7/daytona_toolbox_api_client-0.184.0.tar.gz", hash = "sha256:c325a050ca45e7832c7d05127260634ed0742d8bc75b2d9de65847153fba70e6", size = 77893, upload-time = "2026-06-03T14:45:52.776Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/1f/7444791b67da8202672a0e3669f8395755562eb51821d9074c24490f0b0f/daytona_toolbox_api_client-0.181.0-py3-none-any.whl", hash = "sha256:2fb30715be14390872b03d4f78f3e27610b3e5a139b6f422cc568c25f7ac0a9c", size = 205625, upload-time = "2026-05-25T12:07:43.745Z" }, + { url = "https://files.pythonhosted.org/packages/ae/f9/311c8a6d8fdbb48793c1d5504fa212b0325047887e1c8c8dc2560687cff6/daytona_toolbox_api_client-0.184.0-py3-none-any.whl", hash = "sha256:05b6b179018247bf3e7de667b049b0d7bc0e5e83f58b498e0371988e61b3041e", size = 205757, upload-time = "2026-06-03T14:45:51.343Z" }, ] [[package]] name = "daytona-toolbox-api-client-async" -version = "0.181.0" +version = "0.184.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -904,9 +904,9 @@ dependencies = [ { name = "python-dateutil" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ae/bc/9cb238b936eeee714b66deccfb194de0f3a7be2aa1705004b073754a0680/daytona_toolbox_api_client_async-0.181.0.tar.gz", hash = "sha256:42ff1c2f1282d463ee9b3daadf1b76df4fd6a6755f36ed9d6dbded7af759423b", size = 71355, upload-time = "2026-05-25T12:07:59.195Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6b/cc/d10617b9bf1eae7f61870d0f98d5f22e6a56b0777ee9643427be5d8ec488/daytona_toolbox_api_client_async-0.184.0.tar.gz", hash = "sha256:495526bbd8de1d1ce2d86ecdb561efd5c4e7e86007abf4bc550ecc0c5efaf2e9", size = 71490, upload-time = "2026-06-03T14:46:20.57Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/d6/a74d206fb08807a8bce22b254daf9431510e5d56ef5f658fb9441f1fb60a/daytona_toolbox_api_client_async-0.181.0-py3-none-any.whl", hash = "sha256:f3df2185db9f55769af2eee052153afe858769e5ed8df706a5b78b6764460914", size = 203903, upload-time = "2026-05-25T12:07:57.853Z" }, + { url = "https://files.pythonhosted.org/packages/db/71/d6e98539e8a855bc63a55ce60f554f0b0063e9275424fea0476eb46f3a2a/daytona_toolbox_api_client_async-0.184.0-py3-none-any.whl", hash = "sha256:b971bc3475c0f658c38a86de40a5d4d31394bf60ad4b4c7f8b7693a5a56647b1", size = 204036, upload-time = "2026-06-03T14:46:18.771Z" }, ] [[package]] @@ -1116,7 +1116,7 @@ wheels = [ [[package]] name = "fastapi" -version = "0.136.1" +version = "0.136.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc" }, @@ -1125,9 +1125,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5d/45/c130091c2dfa061bbfe3150f2a5091ef1adf149f2a8d2ae769ecaf6e99a2/fastapi-0.136.1.tar.gz", hash = "sha256:7af665ad7acfa0a3baf8983d393b6b471b9da10ede59c60045f49fbc89a0fa7f", size = 397448, upload-time = "2026-04-23T16:49:44.046Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/2d/ff8d91d7b564d464629a0fd50a4489c97fcb836ac230bf3a7269232a9b1f/fastapi-0.136.3.tar.gz", hash = "sha256:e487fae93ad408e6f47641ee4dfe389864fd7bec92e547ea8498fc13f43e83ab", size = 396410, upload-time = "2026-05-23T18:53:15.192Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5a/ff/2e4eca3ade2c22fe1dea7043b8ee9dabe47753349eb1b56a202de8af6349/fastapi-0.136.1-py3-none-any.whl", hash = "sha256:a6e9d7eeada96c93a4d69cb03836b44fa34e2854accb7244a1ece36cd4781c3f", size = 117683, upload-time = "2026-04-23T16:49:42.437Z" }, + { url = "https://files.pythonhosted.org/packages/e0/82/45359b62a067409bd929ae8a56b8ed13e5a8c8a61194b3c236920999ab83/fastapi-0.136.3-py3-none-any.whl", hash = "sha256:3d2a69bdf04b7e9f3afa292c3bc7a98816bbfafa10bc9b45f3f3700d2f761620", size = 117481, upload-time = "2026-05-23T18:53:16.924Z" }, ] [package.optional-dependencies] @@ -1470,12 +1470,12 @@ requires-dist = [ { name = "asyncpg", specifier = ">=0.31.0,<1" }, { name = "boto3", specifier = ">=1.43.4" }, { name = "comm", marker = "extra == 'evaluation'", specifier = "==0.2.3" }, - { name = "daytona", specifier = ">=0.181.0,<1" }, + { name = "daytona", specifier = ">=0.184.0,<1" }, { name = "debugpy", marker = "extra == 'evaluation'", specifier = "==1.8.20" }, { name = "decorator", marker = "extra == 'evaluation'", specifier = "==5.2.1" }, { name = "dspy", extras = ["optuna"], specifier = "==3.2.1" }, { name = "executing", marker = "extra == 'evaluation'", specifier = "==2.2.1" }, - { name = "fastapi", extras = ["standard"], specifier = "==0.136.1" }, + { name = "fastapi", extras = ["standard"], specifier = "==0.136.3" }, { name = "hydra-core", specifier = ">=1.3,<2" }, { name = "ipykernel", marker = "extra == 'evaluation'", specifier = "==7.2.0" }, { name = "ipython", marker = "extra == 'evaluation'", specifier = "==9.13.0" }, From a28a8b8181495d7cae3e605731e31237f7822963 Mon Sep 17 00:00:00 2001 From: Zachary BENSALEM Date: Sat, 6 Jun 2026 16:46:39 +0200 Subject: [PATCH 3/7] Refine agent harness docs and validation flow --- docs/agent-harness/architecture-invariants.md | 21 ++ .../adr/001-rlm-runtime-architecture.md | 15 +- scripts/live_daytona_verify.py | 2 +- src/fleet_rlm/AGENTS.md | 8 +- .../api/routers/ws/connection_loop.py | 124 +++++------ src/fleet_rlm/api/routers/ws/stream_events.py | 10 +- src/fleet_rlm/runtime/agent/agent.py | 195 ++++------------- src/fleet_rlm/runtime/agent/runtime.py | 127 ++++++++--- src/fleet_rlm/runtime/config.py | 7 + src/fleet_rlm/runtime/execution/llm_query.py | 2 +- .../runtime/execution/streaming_events.py | 30 +-- src/fleet_rlm/runtime/modules/escalating.py | 194 +++++++++++++++-- src/fleet_rlm/runtime/tools/__init__.py | 12 ++ src/fleet_rlm/runtime/tools/mcp_tools.py | 199 ++++++++++++++++++ src/fleet_rlm/ui/build.py | 4 +- .../__tests__/agent-chat-adapter.test.ts | 66 ++++++ .../conversation/agent-chat-adapter.ts | 6 +- ...space-message-list.agent-elements.test.tsx | 61 ++++++ src/frontend/src/routes/__root.tsx | 22 +- tests/fixtures/mcp_echo_server.py | 21 ++ tests/unit/api/test_events.py | 21 ++ tests/unit/runtime/test_escalating_module.py | 57 ++++- tests/unit/runtime/test_execution.py | 43 +--- tests/unit/runtime/test_mcp_tools.py | 166 +++++++++++++++ tests/unit/runtime/test_modules.py | 2 +- .../runtime/test_native_streaming_contract.py | 147 +++++++++++++ 26 files changed, 1210 insertions(+), 352 deletions(-) create mode 100644 src/fleet_rlm/runtime/tools/mcp_tools.py create mode 100644 tests/fixtures/mcp_echo_server.py create mode 100644 tests/unit/runtime/test_mcp_tools.py create mode 100644 tests/unit/runtime/test_native_streaming_contract.py diff --git a/docs/agent-harness/architecture-invariants.md b/docs/agent-harness/architecture-invariants.md index e14fc79ea..f4475d731 100644 --- a/docs/agent-harness/architecture-invariants.md +++ b/docs/agent-harness/architecture-invariants.md @@ -21,6 +21,27 @@ Transport code may call runtime services and schemas. Runtime code should not im FastAPI route modules, or test-only helpers. Configuration/package-root modules must not pull in heavy runtime providers such as DSPy, MLflow, PostHog, or Daytona at import time. +## Async Execution Boundary + +The sandbox interpreters (Daytona, Modal) expose a synchronous, blocking `execute(...)` that +performs a network round-trip per code iteration. `dspy.RLM.aforward` only awaits the LM predictor +calls — it still runs sandbox code through the **synchronous** `repl.execute(...)` (verified in +dspy 3.2.1). Therefore the heavy RLM turn is driven sync-in-a-thread via +`asyncio.to_thread(self.agent, ...)` in `runtime/agent/runtime.py`, which offloads both the LM +calls and the blocking sandbox I/O to a worker thread and keeps the event loop free. + +Do not replace this `asyncio.to_thread` wrapping with a direct `await agent.acall(...)`/`aforward` +on the RLM heavy path while the interpreter's `execute` stays synchronous — doing so would block the +event loop on every code-execution iteration and regress server concurrency. The native chat +streaming path is the exception: it interleaves per-token streaming through `async_planner_step` +(which uses `acall` on the planner predictor only, not sandbox execution). + +MCP-backed ReAct tools are the other async exception. Tools converted with +`dspy.Tool.from_mcp_tool(session, tool)` are bound to a live MCP `ClientSession` and must be invoked +through an async ReAct path (`acall`) while that session remains open. Keep MCP tools out of sync +ReAct calls, close the provider when the runtime shuts down, and rebuild the agent from base tools +plus the current MCP attachment when servers are reattached. + ## Frontend Boundaries Keep shared UI primitives reusable: diff --git a/docs/reference/adr/001-rlm-runtime-architecture.md b/docs/reference/adr/001-rlm-runtime-architecture.md index 3f7cc957e..d39c612fb 100644 --- a/docs/reference/adr/001-rlm-runtime-architecture.md +++ b/docs/reference/adr/001-rlm-runtime-architecture.md @@ -27,10 +27,21 @@ The architecture consists of these layers: The primary runtime (`src/fleet_rlm/runtime/agent/runtime.py`) owns session state, tool binding, streaming, and persistence. Its default agent module (`src/fleet_rlm/runtime/modules/escalating.py`) extends `dspy.Module` to provide: - **Stateful conversation**: `dspy.History` for persistent chat memory -- **Lightweight-to-heavy escalation**: `dspy.ChainOfThought` for simple turns, escalating to the Daytona-backed RLM path when needed -- **Tool orchestration**: Dynamic tool registration and dispatch +- **Lightweight-to-heavy escalation**: `dspy.ChainOfThought` for simple turns; the + `[TOOLS NEEDED]` sentinel routes to a real `dspy.ReAct` tool loop (`FleetAgent`), while + forced `rlm`/`rlm_only` modes and auto-detected URL-document analysis route to the + Daytona-backed `dspy.RLM` heavy path +- **Tool orchestration**: Dynamic tool registration and dispatch (including optional + DSPy-native MCP tools discovered from `FLEET_RLM_MCP_SERVERS`) - **Recursive delegation**: `runtime/tools/rlm_delegate.py` and `integrations/daytona/isolation.py` build bounded child RLM runs +MCP tools are opt-in and session-backed. `AgentRuntime.attach_mcp_tools(...)` connects the +configured MCP servers, converts discovered tools with `dspy.Tool.from_mcp_tool(...)`, and rebuilds +the agent from the stable base tool set plus the current MCP attachment. Reattaching MCP servers +replaces the previous MCP tools and closes their provider; runtime shutdown closes any remaining +MCP sessions. Because these tools are async, the sentinel ReAct route is driven through an async +ReAct call, while forced/url RLM routes remain sync-in-thread for Daytona sandbox execution. + ### 2. Signature-Based Contracts Agent behavior is defined through DSPy signatures diff --git a/scripts/live_daytona_verify.py b/scripts/live_daytona_verify.py index a867ab927..2351aaf04 100644 --- a/scripts/live_daytona_verify.py +++ b/scripts/live_daytona_verify.py @@ -22,7 +22,7 @@ # Ensure repo root is on path for imports sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from fleet_rlm.api.runtime_services.chat_persistence import ( +from fleet_rlm.api.runtime_services.session_manifest import ( ensure_session_volume_layout, load_manifest_from_volume, ) diff --git a/src/fleet_rlm/AGENTS.md b/src/fleet_rlm/AGENTS.md index 4023dc6ff..d856bace0 100644 --- a/src/fleet_rlm/AGENTS.md +++ b/src/fleet_rlm/AGENTS.md @@ -147,7 +147,9 @@ Runtime ownership: - Keep DSPy signatures in `runtime/agent/signatures.py` - Keep runtime module construction/registration in `runtime/modules/factory.py`, `runtime/modules/registry.py`, or the `fleet_rlm.runtime.modules` package exports -- Keep the main cognition loop in `runtime/agent/agent.py` (FleetAgent / RLMReActAgent) and `runtime/agent/runtime.py` (AgentRuntime) +- Keep the main cognition loop in `runtime/agent/agent.py` (`FleetAgent`, a thin + `dspy.ReAct` subclass), `runtime/modules/escalating.py` (`EscalatingFleetModule`, the + default CoT→ReAct→RLM router), and `runtime/agent/runtime.py` (AgentRuntime) - Keep the public Daytona interpreter facade in `integrations/daytona/interpreter.py`; durable workspace/session behavior lives in `workspace_manager.py`, code execution and bridge state live in `sandbox_executor.py`, and recursive child policy/delegation lives in `isolation.py`. - Keep Daytona collaborator boundaries typed with small internal Protocols. Use Pydantic v2 for validated configuration/state boundary models such as `WorkspaceConfig`, but keep hot execution-path payloads and bridge result carriers as dataclasses/functions. - Keep runtime orchestration and shared chat/runtime behavior under `runtime/agent/*` and `runtime/execution/*` @@ -158,6 +160,10 @@ Runtime ownership: - The module registry (`module_registry.py`) is the single source of truth for optimizable modules, consumed by CLI, API, and frontend. **Note:** `longcot-reasoner` is currently registered via `fleet_rlm.quality.optimize_longcot`; add new `quality/optimize_*.py` entrypoints to `_MODULE_ENTRYPOINTS` as more modules become optimizable. - GEPA runs offline only — never in the live request path - Keep grouped tool helpers under root `runtime/tools/*` +- Keep DSPy-native MCP tool discovery in `runtime/tools/mcp_tools.py`. It is opt-in: + servers are configured via the `FLEET_RLM_MCP_SERVERS` env var (JSON array) and + attached through `AgentRuntime.attach_mcp_tools(...)`; nothing connects unless set. + Do not add import-time MCP/`dspy` side effects (lazy-import inside `connect()`). API ownership: diff --git a/src/fleet_rlm/api/routers/ws/connection_loop.py b/src/fleet_rlm/api/routers/ws/connection_loop.py index 89b19376c..18941ac2a 100644 --- a/src/fleet_rlm/api/routers/ws/connection_loop.py +++ b/src/fleet_rlm/api/routers/ws/connection_loop.py @@ -291,68 +291,72 @@ async def _background_execution_task( """Run execution in the background with its own agent context.""" from ...runtime_services.chat_runtime import build_chat_agent_context - agent_context = await build_chat_agent_context(runtime) - async with agent_context as agent: - interpreter = getattr(agent, "interpreter", None) - set_interpreter_default_profile(interpreter, runtime.cfg) - - async def _noop_persist( - *, - include_volume_save: bool = True, - latest_user_message: str = "", - ) -> None: - _ = include_volume_save, latest_user_message - - ( - session.active_key, - session.active_manifest_path, - session.session_record, - session.last_loaded_docs_path, - session.orchestration_session, - ) = await switch_session_if_needed( - session_cache=session_cache, - agent=agent, - interpreter=interpreter, - workspace_id=workspace_id, - user_id=user_id, - sess_id=sess_id, - owner_tenant_claim=session.owner_tenant_claim, - owner_user_claim=session.owner_user_claim, - active_key=None, - session_record=session.session_record, - last_loaded_docs_path=session.last_loaded_docs_path, - local_persist=_noop_persist, - persistence=runtime.persistence, - identity_rows=runtime.identity_rows, - ) + try: + agent_context = await build_chat_agent_context(runtime) + async with agent_context as agent: + interpreter = getattr(agent, "interpreter", None) + set_interpreter_default_profile(interpreter, runtime.cfg) + + async def _noop_persist( + *, + include_volume_save: bool = True, + latest_user_message: str = "", + ) -> None: + _ = include_volume_save, latest_user_message + + ( + session.active_key, + session.active_manifest_path, + session.session_record, + session.last_loaded_docs_path, + session.orchestration_session, + ) = await switch_session_if_needed( + session_cache=session_cache, + agent=agent, + interpreter=interpreter, + workspace_id=workspace_id, + user_id=user_id, + sess_id=sess_id, + owner_tenant_claim=session.owner_tenant_claim, + owner_user_claim=session.owner_user_claim, + active_key=None, + session_record=session.session_record, + last_loaded_docs_path=session.last_loaded_docs_path, + local_persist=_noop_persist, + persistence=runtime.persistence, + identity_rows=runtime.identity_rows, + ) - agent._db_session_id = (session.session_record or {}).get("db_session_id") - agent._identity_rows = runtime.identity_rows - if agent.interpreter is not None: - agent.interpreter._host_repository = runtime.persistence - agent.interpreter._host_identity = runtime.identity_rows - agent.interpreter._host_run_id = None - local_persist = build_local_persist_fn( - session_cache=session_cache, - runtime=runtime, - agent=agent, - interpreter=interpreter, - session=session, - ) + agent._db_session_id = (session.session_record or {}).get("db_session_id") + agent._identity_rows = runtime.identity_rows + if agent.interpreter is not None: + agent.interpreter._host_repository = runtime.persistence + agent.interpreter._host_identity = runtime.identity_rows + agent.interpreter._host_run_id = None + local_persist = build_local_persist_fn( + session_cache=session_cache, + runtime=runtime, + agent=agent, + interpreter=interpreter, + session=session, + ) - return await _process_chat_message( - websocket=None, - msg=msg, - agent=agent, - interpreter=interpreter, - session=session, - local_persist=local_persist, - runtime=runtime, - workspace_id=workspace_id, - user_id=user_id, - sess_id=sess_id, - execution_emitter=execution_emitter, - ) + return await _process_chat_message( + websocket=None, + msg=msg, + agent=agent, + interpreter=interpreter, + session=session, + local_persist=local_persist, + runtime=runtime, + workspace_id=workspace_id, + user_id=user_id, + sess_id=sess_id, + execution_emitter=execution_emitter, + ) + except Exception: + logger.exception("Background websocket execution task failed") + raise class _ExecutionConnectionLoop: diff --git a/src/fleet_rlm/api/routers/ws/stream_events.py b/src/fleet_rlm/api/routers/ws/stream_events.py index 39a8b702e..8e2e93d92 100644 --- a/src/fleet_rlm/api/routers/ws/stream_events.py +++ b/src/fleet_rlm/api/routers/ws/stream_events.py @@ -111,12 +111,18 @@ def _to_workspace_event(event: Any) -> WorkspaceEvent: """Normalize a runtime-style stream event into a workspace event.""" raw_ts = getattr(event, "timestamp", None) timestamp = raw_ts if isinstance(raw_ts, datetime) else datetime.now(timezone.utc) + + kind = getattr(event, "kind", "status") + if hasattr(kind, "value"): + kind = getattr(kind, "value") + kind = str(kind) + return WorkspaceEvent( - kind=str(getattr(event, "kind", "status")), + kind=kind, text=str(getattr(event, "text", "") or ""), payload=dict(getattr(event, "payload", {}) or {}), timestamp=timestamp, - terminal=is_terminal_stream_event_kind(str(getattr(event, "kind", ""))), + terminal=is_terminal_stream_event_kind(kind), ) diff --git a/src/fleet_rlm/runtime/agent/agent.py b/src/fleet_rlm/runtime/agent/agent.py index 0a7c834d5..4c22ec7d0 100644 --- a/src/fleet_rlm/runtime/agent/agent.py +++ b/src/fleet_rlm/runtime/agent/agent.py @@ -1,14 +1,20 @@ """Thin DSPy Module wrappers for the interactive ReAct chat agent. -This module provides minimal, pure-DSPy agent definitions that separate +This module provides a minimal, pure-DSPy agent definition that separates the inference graph from runtime orchestration concerns such as interpreter lifecycle, session state, and streaming. + +``FleetAgent`` subclasses the canonical :class:`dspy.ReAct` so it inherits the +upstream planner/extract predictors, native trajectory truncation, and +``Literal``-constrained tool selection. A pair of thin compatibility shims +(:pyattr:`planner` and :meth:`async_planner_step`) let :class:`AgentRuntime` +drive the planner one step at a time for token streaming. """ from __future__ import annotations import logging -from typing import Any, cast +from typing import Any import dspy from dspy.utils.exceptions import ContextWindowExceededError @@ -17,21 +23,28 @@ class FleetAgentSignature(dspy.Signature): - """Simplified ReAct chat signature for FleetAgent.""" + """Respond to the user, using tools only when they are needed. + + When you call the `finish` tool, write `next_thought` as the exact, complete + final response to send to the user. The runtime streams that text directly + and skips a separate extraction step, so `next_thought` on the finishing turn + must be the full user-facing answer rather than internal reasoning. + """ chat_history: dspy.History = dspy.InputField(desc="Prior conversation turns (keys: user_message, response)") user_message: str = dspy.InputField(desc="Current user message") response: str = dspy.OutputField(desc="Agent response to the user") -class FleetAgent(dspy.Module): - """A custom DSPy Module implementing ReAct cleanly for streaming. +class FleetAgent(dspy.ReAct): + """ReAct chat agent built on the canonical :class:`dspy.ReAct`. - This module instantiates its own `dspy.Predict` (for planning) and `dspy.ChainOfThought` - (for extraction) with custom instructions optimized for the RLM system. - It exposes `async_planner_step` and `async_extract_step` to allow the AgentRuntime - to weave external effects (e.g. streaming, sandbox tool execution) directly - into the cognitive loop. + Inherits upstream behaviour: the planner predictor ``self.react``, the + extractor ``self.extract``, native ``truncate_trajectory`` / + ``_format_trajectory`` helpers, and a ``finish`` tool with + ``Literal``-constrained tool selection. The :pyattr:`planner` alias and + :meth:`async_planner_step` shim expose the step-by-step interface the + streaming runtime drives to interleave token streaming and tool execution. """ def __init__( @@ -40,78 +53,27 @@ def __init__( tools: list[Any], max_iters: int = 10, ) -> None: - super().__init__() - self.signature = FleetAgentSignature - self.max_iters = max_iters - - self.tools = {getattr(t, "name", getattr(t, "__name__", str(t))): t for t in tools} - if "finish" not in self.tools: - self.tools["finish"] = dspy.Tool( - func=lambda: "Completed.", - name="finish", - desc="Stop and return the final response.", - args={}, - ) - - input_names = ", ".join(f"`{name}`" for name in self.signature.input_fields) - output_names = ", ".join(f"`{name}`" for name in self.signature.output_fields) - - tool_lines = [] - for name, tool in self.tools.items(): - desc = getattr(tool, "desc", getattr(tool, "__doc__", "No description available.")) - tool_lines.append(f"- {name}: {desc}") - - instructions = "\n".join( - [ - f"You are an agent. Use tools only when needed to produce {output_names} from {input_names}.", - "At each step output next_thought, next_tool_name, and next_tool_args.", - "Tool observations are appended to trajectory.", - "next_tool_args must be a JSON object for the selected tool.", - "Available tools:", - *tool_lines, - "If you choose finish on the first step without using any tool, write next_thought as the exact final response to send to the user.", - ] - ) - - signature_builder = cast(Any, dspy.Signature) - self.react_signature = ( - signature_builder({**self.signature.input_fields}, instructions) - .append("trajectory", dspy.InputField(), type_=str) - .append("next_thought", dspy.OutputField(), type_=str) - .append("next_tool_name", dspy.OutputField(), type_=str) - .append("next_tool_args", dspy.OutputField(), type_=dict[str, Any]) - ) - - self.fallback_signature = signature_builder( - {**self.signature.input_fields, **self.signature.output_fields}, self.signature.instructions - ).append("trajectory", dspy.InputField(), type_=str) - - self.planner = dspy.Predict(self.react_signature) - self.extract = dspy.ChainOfThought(self.fallback_signature) - - def _format_trajectory(self, trajectory: dict[str, Any]) -> str: - adapter = cast(Any, getattr(dspy.settings, "adapter", None) or dspy.ChatAdapter()) - signature_builder = cast(Any, dspy.Signature) - trajectory_signature = signature_builder(f"{', '.join(trajectory.keys())} -> x") - return adapter.format_user_message_content(trajectory_signature, trajectory) - - def truncate_trajectory(self, trajectory: dict[str, Any]) -> dict[str, Any]: - """Truncates the oldest tool call information from the trajectory.""" - keys = list(trajectory.keys()) - if len(keys) < 4: - raise ValueError( - "The trajectory is too long so your prompt exceeded the context window, " - "but the trajectory cannot be truncated because it only has one tool call." - ) - for key in keys[:4]: - trajectory.pop(key) - return trajectory + super().__init__(FleetAgentSignature, tools=list(tools), max_iters=max_iters) + + @property + def planner(self) -> dspy.Predict: + """Alias for the upstream ReAct planner predictor. + + The streaming runtime introspects ``planner`` to decide whether a + program can be driven step-by-step. + """ + return self.react async def async_planner_step(self, trajectory: dict[str, Any], **input_args: Any) -> dspy.Prediction: - """Call the planner with truncation retry logic.""" + """Run one planner step with context-window truncation retries. + + Mirrors the upstream truncation loop but is exposed publicly so the + runtime can interleave streaming and sandbox tool execution into the + cognitive loop. + """ for _ in range(3): try: - prediction = await self.planner.acall( + prediction = await self.react.acall( **input_args, trajectory=self._format_trajectory(trajectory), ) @@ -123,80 +85,3 @@ async def async_planner_step(self, trajectory: dict[str, Any], **input_args: Any logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") trajectory = self.truncate_trajectory(trajectory) raise ValueError("The context window was exceeded even after 3 attempts to truncate the trajectory.") - - async def async_extract_step(self, trajectory: dict[str, Any], **input_args: Any) -> dspy.Prediction: - """Call the extractor with truncation retry logic.""" - for _ in range(3): - try: - return await self.extract.acall(**input_args, trajectory=self._format_trajectory(trajectory)) - except ContextWindowExceededError: - logger.warning("Trajectory exceeded the context window, truncating the oldest tool call information.") - trajectory = self.truncate_trajectory(trajectory) - raise ValueError("The context window was exceeded even after 3 attempts to truncate the trajectory.") - - def forward(self, **input_args: Any) -> dspy.Prediction: - """Synchronous forward pass through the ReAct agent.""" - trajectory: dict[str, Any] = {} - last_tool_name: str | None = None - last_thought: str | None = None - for idx in range(self.max_iters): - try: - - def sync_call(module, traj, **kwargs): - for _ in range(3): - try: - return module(**kwargs, trajectory=self._format_trajectory(traj)) - except ContextWindowExceededError: - traj = self.truncate_trajectory(traj) - raise ValueError("Context window exceeded") - - pred = sync_call(self.planner, trajectory, **input_args) - tool_name = getattr(pred, "next_tool_name", "") - if tool_name and tool_name not in self.tools: - raise ValueError(f"Agent failed to select a valid tool: {tool_name!r}") - except ValueError as err: - logger.warning(f"Ending the trajectory: {err}") - break - - trajectory[f"thought_{idx}"] = pred.next_thought - trajectory[f"tool_name_{idx}"] = tool_name - trajectory[f"tool_args_{idx}"] = pred.next_tool_args - - last_tool_name = tool_name - last_thought = pred.next_thought - - if not tool_name: - break - - try: - tool = self.tools[tool_name] - if hasattr(tool, "func"): - func = tool.func - elif callable(tool): - func = tool - else: - func = getattr(tool, "forward", None) - if func is None: - raise ValueError(f"Tool {tool_name} is not callable") - func = cast(Any, func) - trajectory[f"observation_{idx}"] = func(**pred.next_tool_args) - except Exception as err: - trajectory[f"observation_{idx}"] = f"Execution error in {tool_name}: {err}" - - if tool_name == "finish": - break - - # Check for terminal step shortcut - if last_tool_name == "finish" or not last_tool_name: - return dspy.Prediction(trajectory=trajectory, response=str(last_thought or "")) - - def sync_extract(module, traj, **kwargs): - for _ in range(3): - try: - return module(**kwargs, trajectory=self._format_trajectory(traj)) - except ContextWindowExceededError: - traj = self.truncate_trajectory(traj) - raise ValueError("Context window exceeded") - - extract = sync_extract(self.extract, trajectory, **input_args) - return dspy.Prediction(trajectory=trajectory, **extract) diff --git a/src/fleet_rlm/runtime/agent/runtime.py b/src/fleet_rlm/runtime/agent/runtime.py index 319e4e273..cc8d9b0dd 100644 --- a/src/fleet_rlm/runtime/agent/runtime.py +++ b/src/fleet_rlm/runtime/agent/runtime.py @@ -123,17 +123,19 @@ def _routing_status_text(payload: dict[str, Any]) -> str: def _get_streamable_react_program(program: Any) -> Any | None: - react_program = getattr(program, "react", program) - - planner = getattr(react_program, "planner", None) - extract = getattr(react_program, "extract", None) - async_call = getattr(react_program, "async_planner_step", None) - - if planner is None or extract is None: - return None - if not callable(async_call): - return None - return react_program + # The program itself may be drivable (``FleetAgent``), or it may wrap a + # streamable ReAct sub-program under ``.react`` (``EscalatingFleetModule``). + # Probe the program first so a ``FleetAgent`` (whose ``.react`` is the + # upstream planner ``Predict``, not a streamable program) is recognised. + for candidate in (program, getattr(program, "react", None)): + if candidate is None: + continue + planner = getattr(candidate, "planner", None) + extract = getattr(candidate, "extract", None) + async_call = getattr(candidate, "async_planner_step", None) + if planner is not None and extract is not None and callable(async_call): + return candidate + return None def _normalize_tool_args(tool_args: Any) -> dict[str, Any]: @@ -251,8 +253,6 @@ def __init__( summary_interval: int = 10, compaction_threshold_pct: float = 0.7, ) -> None: - from .agent import FleetAgent - self.interpreter: Any | None = interpreter self.history: dspy.History = dspy.History(messages=[]) self.history_max_turns: int | None = history_max_turns @@ -261,7 +261,7 @@ def __init__( # Phase 7: attach runtime reference to interpreter so recursive children # can access parent history for bounded conversation snapshots if interpreter is not None: - setattr(interpreter, "runtime", self) + setattr(interpreter, "agent_runtime", self) # Session-management hooks used by the websocket layer self._db_session_id: str | object | None = None @@ -293,25 +293,82 @@ def __init__( interpreter=interpreter, ) - self.tools: list[Any] = base_tools + list(extra_tools or []) + self._base_tools: list[Any] = base_tools + list(extra_tools or []) + self._mcp_tools: list[Any] = [] + self.tools: list[Any] = list(self._base_tools) self.react_tools: list[Any] = self.tools - if use_escalation: + # Retained for rebuilding the agent when async tool sources (e.g. MCP) + # are attached after construction. + self._max_iters: int = max_iters + self._mcp_provider: Any | None = None + + self.agent: Any = self._build_agent(self.tools) + + def _build_agent(self, tools: list[Any]) -> Any: + """Construct the cognition module for the given tool set.""" + from .agent import FleetAgent + + if self._use_escalation: from fleet_rlm.runtime.modules.escalating import EscalatingFleetModule - self.agent: Any = EscalatingFleetModule( - interpreter=interpreter, - tools=self.tools, + return EscalatingFleetModule( + interpreter=self.interpreter, + tools=tools, max_iterations=self.rlm_max_iterations, max_llm_calls=self.rlm_max_llm_calls, max_output_chars=self.rlm_max_output_chars, - summary_interval=summary_interval, - ) - else: - self.agent: Any = FleetAgent( - tools=self.tools, - max_iters=max_iters, + summary_interval=self._summary_interval, ) + return FleetAgent( + tools=tools, + max_iters=self._max_iters, + ) + + async def attach_mcp_tools(self, configs: Any | None = None) -> list[str]: + """Discover MCP tools and rebuild the agent with them registered. + + Connects to the configured MCP servers (env-driven by default), appends + the discovered async ``dspy.Tool`` objects to the runtime tool list, and + rebuilds the cognition module so they appear in the ReAct tool set. Safe + to call when no MCP servers are configured (returns an empty list and + leaves the agent untouched). Call :meth:`aclose_mcp` to release sessions. + + Returns the names of the MCP tools that were attached. + """ + from fleet_rlm.runtime.tools.mcp_tools import MCPToolProvider, load_mcp_server_configs + + resolved = configs if configs is not None else load_mcp_server_configs() + if not resolved: + return [] + + provider = MCPToolProvider(resolved) + mcp_tools = await provider.connect() + if not mcp_tools: + await provider.aclose() + return [] + + # Release any previously attached provider before swapping it in. + await self.aclose_mcp() + self._mcp_provider = provider + self._mcp_tools = list(mcp_tools) + + self.tools = list(self._base_tools) + list(self._mcp_tools) + self.react_tools = self.tools + self.agent = self._build_agent(self.tools) + return [getattr(tool, "name", str(tool)) for tool in mcp_tools] + + async def aclose_mcp(self) -> None: + """Close any live MCP sessions attached via :meth:`attach_mcp_tools`.""" + provider = self._mcp_provider + if provider is None: + return + self._mcp_provider = None + self._mcp_tools = [] + await provider.aclose() + self.tools = list(self._base_tools) + self.react_tools = self.tools + self.agent = self._build_agent(self.tools) # ----------------------------------------------------------------- # Chat API @@ -472,10 +529,20 @@ async def _aiter_chat_turn_stream_posthoc( ) try: - result = await asyncio.to_thread( - self.agent, - **self._escalation_call_args(message), - ) + async_call = getattr(self.agent, "aforward", None) + if callable(async_call): + result = await async_call(**self._escalation_call_args(message)) + else: + # Drive sync-only modules in a worker thread. The RLM heavy path + # runs sandbox code through the interpreter's blocking execute(); + # dspy.RLM.aforward still calls that synchronously (only LM + # predictor calls are awaited), so to_thread is the correct + # non-blocking pattern for modules without an explicit async path. + # See docs/agent-harness/architecture-invariants.md. + result = await asyncio.to_thread( + self.agent, + **self._escalation_call_args(message), + ) except Exception as exc: yield RuntimeEvent( kind=RuntimeEventKind.ERROR, @@ -602,6 +669,7 @@ async def __aexit__( return False def shutdown(self) -> None: + _run_async_compat(self.aclose_mcp) if self.interpreter is not None: shutdown = getattr(self.interpreter, "shutdown", None) if callable(shutdown): @@ -618,6 +686,7 @@ def shutdown(self) -> None: pass async def ashutdown(self) -> None: + await self.aclose_mcp() if self.interpreter is None: return ashutdown = getattr(self.interpreter, "ashutdown", None) diff --git a/src/fleet_rlm/runtime/config.py b/src/fleet_rlm/runtime/config.py index 2ef6ad958..dfa2e54ff 100644 --- a/src/fleet_rlm/runtime/config.py +++ b/src/fleet_rlm/runtime/config.py @@ -149,6 +149,13 @@ def _resolve_max_tokens(value: int | str | None, *, default: int = 64000) -> int return default +# Normalized LM API guard (dspy.ai/community/normalized-lm-api-migration): +# fleet-rlm uses the stock ``dspy.LM`` with no custom BaseLM/Adapter subclass, so it +# sits in the migration's "nothing required" bucket. Keep it that way: always invoke +# the LM as ``lm(...)`` (never ``lm.forward(...)``). When bumping to dspy 3.3+, the +# typed LM API stays opt-in via ``dspy.context(experimental=True)`` — no behaviour +# change. Any future custom LM/adapter must declare ``forward_contract`` and build +# ``LMRequest`` / parse ``LMResponse`` rather than reaching into ``forward``. def _build_lm( *, model: str, diff --git a/src/fleet_rlm/runtime/execution/llm_query.py b/src/fleet_rlm/runtime/execution/llm_query.py index 8f1212573..b62951705 100644 --- a/src/fleet_rlm/runtime/execution/llm_query.py +++ b/src/fleet_rlm/runtime/execution/llm_query.py @@ -73,7 +73,7 @@ def _build_child_history_snapshot(interpreter: Any) -> str: Returns: A bounded, redacted conversation snapshot string. """ - runtime = getattr(interpreter, "runtime", None) + runtime = getattr(interpreter, "agent_runtime", None) if runtime is None: return "" diff --git a/src/fleet_rlm/runtime/execution/streaming_events.py b/src/fleet_rlm/runtime/execution/streaming_events.py index 2fbbcd93b..56ded58d5 100644 --- a/src/fleet_rlm/runtime/execution/streaming_events.py +++ b/src/fleet_rlm/runtime/execution/streaming_events.py @@ -3,7 +3,7 @@ Event *construction* lives in :mod:`fleet_rlm.runtime.events`. This module retains only final-payload assembly (citations/attachments/sources) -and the DSPy ``ReActStatusProvider`` status hook. +and terminal-event/HITL helpers. """ from __future__ import annotations @@ -13,7 +13,6 @@ from urllib.parse import urlparse import dspy -from dspy.streaming.messages import StatusMessageProvider from fleet_rlm.runtime.events import EVENT_SCHEMA_VERSION from fleet_rlm.utils.preview import head_tail_preview @@ -38,33 +37,6 @@ def is_terminal_stream_event_kind(kind: str) -> bool: return kind in TERMINAL_STREAM_EVENT_KINDS -# ═══════════════════════════════════════════════════════════════════════ -# DSPy status hook (structured data embedded in RuntimeEvent at build time) -# ═══════════════════════════════════════════════════════════════════════ - - -class ReActStatusProvider(StatusMessageProvider): - """Concise status hook for streamed ReAct sessions. - - Returns human-readable status strings required by the DSPy - ``StatusMessageProvider`` interface. Structured tool/actor data - is carried in :class:`~fleet_rlm.runtime.events.RuntimeEvent` objects - built by the event factories in :mod:`fleet_rlm.runtime.events`. - """ - - def tool_start_status_message(self, instance: Any, inputs: dict[str, Any]) -> str: - return f"Calling tool: {instance.name}" - - def tool_end_status_message(self, outputs: Any) -> str | None: - return "Tool finished." - - def module_start_status_message(self, instance: Any, inputs: dict[str, Any]) -> str | None: - return f"Running module: {instance.__class__.__name__}" - - def module_end_status_message(self, outputs: Any) -> str | None: - return None - - # ═══════════════════════════════════════════════════════════════════════ # HITL event helper # ═══════════════════════════════════════════════════════════════════════ diff --git a/src/fleet_rlm/runtime/modules/escalating.py b/src/fleet_rlm/runtime/modules/escalating.py index 5ba986254..f468bf630 100644 --- a/src/fleet_rlm/runtime/modules/escalating.py +++ b/src/fleet_rlm/runtime/modules/escalating.py @@ -1,17 +1,22 @@ -"""Escalating fleet agent module — ChainOfThought for simple turns, RLM for complex ones. - -This module implements the Phase 2 unified agent design: a single DSPy Module -that seamlessly escalates from a lightweight ChainOfThought response to a full -dspy.RLM loop when the situation demands it. - -Escalation is triggered when: -- The ChainOfThought reasoning output contains the sentinel ``[TOOLS NEEDED]``. -- The caller sets ``execution_mode="rlm"`` or ``"rlm_only"`` explicitly. -- The caller sets ``force_escalate=True``. +"""Escalating fleet agent module — ChainOfThought, ReAct tools, or RLM per turn. + +This module implements the unified agent design: a single DSPy Module that +seamlessly escalates from a lightweight ChainOfThought response to a real, +tool-using ``dspy.ReAct`` loop or a full ``dspy.RLM`` loop when the situation +demands it. + +Routing: +- Simple turns are answered by the ChainOfThought fast path. +- When the fast-path reasoning contains the sentinel ``[TOOLS NEEDED]`` the + module runs the shared ``dspy.ReAct`` tool loop (the same ``FleetAgent`` + program used by the non-escalating runtime). +- ``execution_mode="rlm"``/``"rlm_only"``, ``force_escalate=True``, or an + auto-detected URL-document analysis request route to the ``dspy.RLM`` sandbox. """ from __future__ import annotations +import asyncio import logging import re from dataclasses import dataclass, field @@ -28,6 +33,8 @@ ESCALATION_SENTINEL = "[TOOLS NEEDED]" _RLM_FALLBACK_WARNING = "RLM escalation failed; returned a lightweight fallback response." +_REACT_FALLBACK_WARNING = "ReAct tool loop failed; returned a lightweight fallback response." +_REACT_MAX_ITERS = 10 _URL_DOCUMENT_MAX_ITERATIONS = 4 _URL_DOCUMENT_MAX_LLM_CALLS = 8 _URL_RE = re.compile(r"https?://[^\s)\],;]+", flags=re.IGNORECASE) @@ -168,8 +175,10 @@ class EscalatingFleetModule(dspy.Module): """Unified DSPy Module that scales from lightweight chat to full RLM execution. Simple turns are handled by a ``dspy.ChainOfThought`` step. When the - reasoning contains :data:`ESCALATION_SENTINEL` or the caller requests deep - work, the module re-runs via an ``RLMVariableExecutionModule`` or a + reasoning contains :data:`ESCALATION_SENTINEL` the module runs the shared + ``dspy.ReAct`` tool loop (:class:`~fleet_rlm.runtime.agent.agent.FleetAgent`) + for real tool use. Explicit RLM modes, ``force_escalate``, or auto-detected + URL-document analysis instead route to an ``RLMVariableExecutionModule`` or a raw ``dspy.RLM`` with the same tool set. Parameters @@ -221,6 +230,21 @@ def __init__( self.respond = dspy.ChainOfThought(RLMReActChatSignature) self.summarize = dspy.ChainOfThought(ConversationSummarySignature) + # Sentinel tool branch: the shared upstream dspy.ReAct loop (FleetAgent). + # When the ChainOfThought fast path emits ESCALATION_SENTINEL the module + # runs a real, tool-using ReAct loop here instead of the RLM sandbox, + # which stays reserved for forced/long-context and URL-document paths. + # Named with a leading underscore so the streaming router + # (``_get_streamable_react_program``) does not treat the escalating + # module itself as a directly drivable ReAct program; routing stays in + # ``forward`` and the posthoc stream surfaces this branch's trajectory. + from fleet_rlm.runtime.agent.agent import FleetAgent + + self._react = FleetAgent( + tools=list(tools or []), + max_iters=max(1, min(max_iterations, _REACT_MAX_ITERS)), + ) + self._rlm: dspy.Module | None = None self._url_document_rlm: dspy.Module | None = None if interpreter is not None: @@ -370,21 +394,159 @@ def forward( ) if self._should_escalate(prediction, execution_mode=execution_mode, force_escalate=False): - logger.debug("EscalatingFleetModule: escalating to RLM (sentinel found in reasoning)") - return self._run_rlm( + logger.debug("EscalatingFleetModule: escalating to ReAct tool loop (sentinel found in reasoning)") + return self._run_react( + user_request=user_request, + core_memory=core_memory, + history=history, + recent_history=recent_history, + selected_skills=selected_skills, + ) + + _prediction_set(prediction, "selected_skills", selected_skills) + return prediction + + async def aforward( + self, + *, + user_request: str, + core_memory: str = "", + history: dspy.History | None = None, + execution_mode: str = "auto", + force_escalate: bool = False, + conversation_summary: str = "", + ) -> dspy.Prediction: + """Run one turn without blocking async callers. + + Heavy RLM work stays on the synchronous ``forward`` path inside a worker + thread because sandbox execution is blocking. The sentinel ReAct branch + uses ``acall`` so session-backed async tools, including MCP tools, are + awaited correctly. + """ + if history is None: + history = dspy.History(messages=[]) + + self._turn_count += 1 + + core_memory, selected_skills = await asyncio.to_thread(self._enrich_with_skills, user_request, core_memory) + recent_history = _format_recent_history_context(history) + should_auto_route_url = execution_mode == "auto" and _is_url_document_analysis_request(user_request) + + if _is_rlm_execution_mode(execution_mode) or force_escalate or should_auto_route_url: + return await asyncio.to_thread( + self._run_rlm, user_request=user_request, core_memory=core_memory, history=history, recent_history=recent_history, conversation_summary=conversation_summary, selected_skills=selected_skills, - routing_decision="sentinel_rlm", - source_url=None, + routing_decision="url_document_rlm" if should_auto_route_url else "forced_rlm", + source_url=_extract_first_url(user_request) if should_auto_route_url else None, + ) + + prediction = await asyncio.to_thread( + self.respond, + user_request=user_request, + core_memory=core_memory, + history=history, + recent_history=recent_history, + ) + + if self._should_escalate(prediction, execution_mode=execution_mode, force_escalate=False): + logger.debug("EscalatingFleetModule: async escalating to ReAct tool loop (sentinel found in reasoning)") + return await self._arun_react( + user_request=user_request, + core_memory=core_memory, + history=history, + recent_history=recent_history, + selected_skills=selected_skills, ) _prediction_set(prediction, "selected_skills", selected_skills) return prediction + def _run_react( + self, + *, + user_request: str, + core_memory: str, + history: dspy.History, + recent_history: str, + selected_skills: list[str] | None = None, + ) -> dspy.Prediction: + """Run the shared dspy.ReAct tool loop for the sentinel tool branch. + + The ReAct prediction carries its native ``trajectory`` (thought/tool/ + observation per step) so the streaming layer surfaces tool calls and + results without extra adaptation. On failure the module degrades to the + lightweight ChainOfThought response, mirroring the RLM fallback contract. + """ + try: + result = self._react(chat_history=history, user_message=user_request) + _prediction_set(result, "selected_skills", selected_skills or []) + _prediction_set(result, "routing_decision", "sentinel_react") + return result + except Exception as exc: + logger.warning( + "EscalatingFleetModule: ReAct tool loop failed (%s), falling back to ChainOfThought", + exc, + ) + fallback = self.respond( + user_request=user_request, + core_memory=core_memory, + history=history, + recent_history=recent_history, + ) + fallback["degraded"] = True + fallback["warning"] = _REACT_FALLBACK_WARNING + fallback["runtime_degraded"] = True + fallback["runtime_failure_category"] = "react_fallback" + fallback["runtime_failure_phase"] = "escalating_react" + fallback["runtime_fallback_used"] = True + fallback["runtime_warning"] = _REACT_FALLBACK_WARNING + fallback["selected_skills"] = selected_skills or [] + fallback["routing_decision"] = "sentinel_react" + return fallback + + async def _arun_react( + self, + *, + user_request: str, + core_memory: str, + history: dspy.History, + recent_history: str, + selected_skills: list[str] | None = None, + ) -> dspy.Prediction: + """Async ReAct branch for session-backed tools such as MCP.""" + try: + result = await self._react.acall(chat_history=history, user_message=user_request) + _prediction_set(result, "selected_skills", selected_skills or []) + _prediction_set(result, "routing_decision", "sentinel_react") + return result + except Exception as exc: + logger.warning( + "EscalatingFleetModule: async ReAct tool loop failed (%s), falling back to ChainOfThought", + exc, + ) + fallback = await asyncio.to_thread( + self.respond, + user_request=user_request, + core_memory=core_memory, + history=history, + recent_history=recent_history, + ) + fallback["degraded"] = True + fallback["warning"] = _REACT_FALLBACK_WARNING + fallback["runtime_degraded"] = True + fallback["runtime_failure_category"] = "react_fallback" + fallback["runtime_failure_phase"] = "escalating_react" + fallback["runtime_fallback_used"] = True + fallback["runtime_warning"] = _REACT_FALLBACK_WARNING + fallback["selected_skills"] = selected_skills or [] + fallback["routing_decision"] = "sentinel_react" + return fallback + def _run_rlm( self, *, diff --git a/src/fleet_rlm/runtime/tools/__init__.py b/src/fleet_rlm/runtime/tools/__init__.py index 485e0e4e0..d05778ce8 100644 --- a/src/fleet_rlm/runtime/tools/__init__.py +++ b/src/fleet_rlm/runtime/tools/__init__.py @@ -3,6 +3,13 @@ from __future__ import annotations from ._marker import tool_fn +from .mcp_tools import ( + MCP_SERVERS_ENV_VAR, + MCPServerConfig, + MCPToolProvider, + discover_mcp_tools, + load_mcp_server_configs, +) from .registry import ( TOOL_MODULE_NAMES, _collect_tools_from_modules, @@ -11,9 +18,14 @@ ) __all__ = [ + "MCP_SERVERS_ENV_VAR", + "MCPServerConfig", + "MCPToolProvider", "TOOL_MODULE_NAMES", "_collect_tools_from_modules", + "discover_mcp_tools", "discover_tools", "list_react_tool_names", + "load_mcp_server_configs", "tool_fn", ] diff --git a/src/fleet_rlm/runtime/tools/mcp_tools.py b/src/fleet_rlm/runtime/tools/mcp_tools.py new file mode 100644 index 000000000..86af76eb4 --- /dev/null +++ b/src/fleet_rlm/runtime/tools/mcp_tools.py @@ -0,0 +1,199 @@ +"""DSPy-native MCP (Model Context Protocol) tool integration for the runtime. + +This module connects to configured MCP servers over stdio, lists their tools, +and wraps each one as a :class:`dspy.Tool` via the canonical +:meth:`dspy.Tool.from_mcp_tool` bridge (which delegates to +``dspy.utils.mcp.convert_mcp_tool``). The wrapped tools are async-only, matching +the MCP protocol, so they plug directly into the ReAct ``acall`` path. + +Design notes: + +- No import-time side effects. The ``mcp`` package and DSPy MCP helpers are + imported lazily inside the connection coroutine. +- Server definitions come from the ``FLEET_RLM_MCP_SERVERS`` environment + variable (env-first; richer settings can layer on later). +- :class:`MCPToolProvider` owns the live session lifecycle through an + :class:`contextlib.AsyncExitStack`; callers must ``await provider.aclose()`` + (or use it as an async context manager) to release the sessions. +""" + +from __future__ import annotations + +import json +import logging +import os +from contextlib import AsyncExitStack +from dataclasses import dataclass, field +from typing import Any + +logger = logging.getLogger(__name__) + +MCP_SERVERS_ENV_VAR = "FLEET_RLM_MCP_SERVERS" + + +@dataclass(slots=True) +class MCPServerConfig: + """Connection definition for a single stdio MCP server.""" + + name: str + command: str + args: list[str] = field(default_factory=list) + env: dict[str, str] = field(default_factory=dict) + + @classmethod + def from_mapping(cls, raw: dict[str, Any]) -> MCPServerConfig: + """Build a config from a plain mapping, validating required fields.""" + name = str(raw.get("name", "")).strip() + command = str(raw.get("command", "")).strip() + if not name or not command: + raise ValueError("Each MCP server config requires non-empty 'name' and 'command' fields.") + args = [str(item) for item in (raw.get("args") or [])] + env_raw = raw.get("env") or {} + if not isinstance(env_raw, dict): + raise ValueError(f"MCP server {name!r} 'env' must be a mapping if provided.") + env = {str(key): str(value) for key, value in env_raw.items()} + return cls(name=name, command=command, args=args, env=env) + + +def load_mcp_server_configs(raw_json: str | None = None) -> list[MCPServerConfig]: + """Parse MCP server configs from JSON (defaults to the env var). + + The value must be a JSON array of objects, each with ``name`` and + ``command`` (and optional ``args`` / ``env``). Returns an empty list when + unset or empty so MCP stays opt-in. + """ + payload = raw_json if raw_json is not None else os.environ.get(MCP_SERVERS_ENV_VAR, "") + payload = payload.strip() + if not payload: + return [] + + try: + parsed = json.loads(payload) + except json.JSONDecodeError as exc: + raise ValueError(f"{MCP_SERVERS_ENV_VAR} must be valid JSON: {exc}") from exc + + if not isinstance(parsed, list): + raise ValueError(f"{MCP_SERVERS_ENV_VAR} must be a JSON array of server objects.") + + return [MCPServerConfig.from_mapping(item) for item in parsed] + + +class MCPToolProvider: + """Manages live MCP sessions and exposes their tools as ``dspy.Tool`` objects. + + Usage:: + + provider = MCPToolProvider(load_mcp_server_configs()) + tools = await provider.connect() + try: + ... # use tools in the ReAct loop + finally: + await provider.aclose() + + or as an async context manager:: + + async with MCPToolProvider(configs) as provider: + tools = provider.tools + """ + + def __init__(self, configs: list[MCPServerConfig]) -> None: + self._configs = list(configs) + self._exit_stack: AsyncExitStack | None = None + self._tools: list[Any] = [] + + @property + def tools(self) -> list[Any]: + """Wrapped MCP tools discovered during :meth:`connect`.""" + return list(self._tools) + + async def connect(self) -> list[Any]: + """Open every configured server, list its tools, and wrap them. + + A failure connecting to one server is logged and skipped so a single + bad server does not disable all MCP tooling. Returns the full list of + wrapped tools (also available via :pyattr:`tools`). + """ + if self._exit_stack is not None: + return self.tools + if not self._configs: + return [] + + import dspy + from mcp import ClientSession, StdioServerParameters + from mcp.client.stdio import stdio_client + + stack = AsyncExitStack() + collected: list[Any] = [] + seen_names: set[str] = set() + + for config in self._configs: + try: + params = StdioServerParameters( + command=config.command, + args=config.args, + env={**os.environ, **config.env} if config.env else None, + ) + read, write = await stack.enter_async_context(stdio_client(params)) + session = await stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + listed = await session.list_tools() + except Exception as exc: + logger.warning("MCP server %r failed to connect, skipping: %s", config.name, exc) + continue + + for mcp_tool in listed.tools: + if mcp_tool.name in seen_names: + logger.warning( + "Duplicate MCP tool name %r from server %r ignored.", + mcp_tool.name, + config.name, + ) + continue + seen_names.add(mcp_tool.name) + collected.append(dspy.Tool.from_mcp_tool(session, mcp_tool)) + + logger.info("Connected MCP server %r with %d tool(s).", config.name, len(listed.tools)) + + self._exit_stack = stack + self._tools = collected + return self.tools + + async def aclose(self) -> None: + """Close all open MCP sessions and reset provider state.""" + if self._exit_stack is None: + return + try: + await self._exit_stack.aclose() + finally: + self._exit_stack = None + self._tools = [] + + async def __aenter__(self) -> MCPToolProvider: + await self.connect() + return self + + async def __aexit__(self, *exc_info: object) -> None: + await self.aclose() + + +async def discover_mcp_tools( + configs: list[MCPServerConfig] | None = None, +) -> tuple[MCPToolProvider, list[Any]]: + """Connect to MCP servers and return the live provider plus its tools. + + The caller owns the returned provider and must ``await provider.aclose()`` + when the tools are no longer needed. When no configs are supplied they are + loaded from the environment. + """ + provider = MCPToolProvider(configs if configs is not None else load_mcp_server_configs()) + tools = await provider.connect() + return provider, tools + + +__all__ = [ + "MCP_SERVERS_ENV_VAR", + "MCPServerConfig", + "MCPToolProvider", + "discover_mcp_tools", + "load_mcp_server_configs", +] diff --git a/src/fleet_rlm/ui/build.py b/src/fleet_rlm/ui/build.py index 20d739116..7bb1ede90 100644 --- a/src/fleet_rlm/ui/build.py +++ b/src/fleet_rlm/ui/build.py @@ -4,7 +4,6 @@ from __future__ import annotations import argparse -import os import shutil import subprocess import sys @@ -77,8 +76,7 @@ def main(argv: list[str] | None = None) -> int: print(f"Error running 'pnpm install': {exc}", file=sys.stderr) return 1 - vp_executable = frontend_dir / "node_modules" / ".bin" / ("vp.cmd" if os.name == "nt" else "vp") - build_cmd = [str(vp_executable), "build"] if vp_executable.exists() else ["pnpm", "exec", "vp", "build"] + build_cmd = ["pnpm", "run", "build"] print(f"Running '{' '.join(build_cmd)}'...") try: diff --git a/src/frontend/src/features/workspace/conversation/__tests__/agent-chat-adapter.test.ts b/src/frontend/src/features/workspace/conversation/__tests__/agent-chat-adapter.test.ts index 26bf5fbaf..99d91418b 100644 --- a/src/frontend/src/features/workspace/conversation/__tests__/agent-chat-adapter.test.ts +++ b/src/frontend/src/features/workspace/conversation/__tests__/agent-chat-adapter.test.ts @@ -48,6 +48,72 @@ describe("toAgentChatMessages", () => { expect(messages[2]?.parts).toContainEqual({ type: "text", text: "All set" }); }); + it("maps RLM route, sandbox, delegation, and MCP rows to Agent Elements parts", () => { + const messages = adapter([ + { + id: "trace-rlm", + type: "trace", + content: "", + renderParts: [ + { + kind: "status_note", + text: "Route: url_document_rlm | source: https://example.com", + tone: "neutral", + }, + { + kind: "sandbox", + title: "summary", + state: "output-available", + code: "print('Example Domain')", + output: "Example Domain", + language: "python", + }, + { + kind: "tool", + title: "delegate_to_rlm", + toolType: "delegate_to_rlm", + state: "output-available", + input: { task: "summarize fetched document" }, + output: { status: "completed" }, + }, + { + kind: "tool", + title: "mcp__docs__fetch", + toolType: "mcp__docs__fetch", + state: "output-available", + input: { url: "https://example.com" }, + output: { text: "Example Domain" }, + }, + ], + }, + ]); + + expect(messages).toHaveLength(1); + expect(messages[0]?.parts).toEqual( + expect.arrayContaining([ + expect.objectContaining({ + type: "tool-Status", + output: expect.objectContaining({ + message: expect.stringContaining("url_document_rlm"), + }), + }), + expect.objectContaining({ + type: "tool-Bash", + input: expect.objectContaining({ + command: "print('Example Domain')", + language: "python", + }), + output: { stdout: "Example Domain" }, + }), + expect.objectContaining({ type: "tool-Agent" }), + expect.objectContaining({ + type: "tool-mcp__docs__fetch", + output: { text: "Example Domain" }, + }), + ]), + ); + }); + it("maps HITL approval to tool-Question with resolved output", () => { const messages = adapter([ { diff --git a/src/frontend/src/features/workspace/conversation/agent-chat-adapter.ts b/src/frontend/src/features/workspace/conversation/agent-chat-adapter.ts index 217917458..38d043600 100644 --- a/src/frontend/src/features/workspace/conversation/agent-chat-adapter.ts +++ b/src/frontend/src/features/workspace/conversation/agent-chat-adapter.ts @@ -115,6 +115,9 @@ function stableToolCallId(messageId: string, kind: string, index: number, stepIn function toolPartType(toolType: string): string { const normalized = toolType.toLowerCase(); + if (normalized.startsWith("mcp__")) { + return `tool-${toolType}`; + } if (/(bash|exec|command|terminal|run|shell|python|repl|interpreter|sandbox)/.test(normalized)) { return "tool-Bash"; } @@ -142,9 +145,6 @@ function toolPartType(toolType: string): string { if (/(plan|planning)/.test(normalized)) return "tool-PlanWrite"; if (/(delegate|sub_rlm|agent|recursive)/.test(normalized)) return "tool-Agent"; if (/(think|reason)/.test(normalized)) return "tool-Thinking"; - if (normalized.startsWith("mcp__")) { - return `tool-${toolType}`; - } return `tool-${sanitizeToolName(toolType)}`; } diff --git a/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx b/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx index f03ff5491..4a7fd6293 100644 --- a/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx +++ b/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx @@ -218,6 +218,67 @@ describe("WorkspaceMessageList Agent Elements integration", () => { act(() => root.unmount()); }); + it("renders forced and URL RLM event rows through Agent Elements", () => { + const { container, root } = mount([ + { id: "u1", type: "user", content: "Analyze https://example.com" }, + { + id: "trace-rlm", + type: "trace", + content: "", + renderParts: [ + { + kind: "status_note", + text: "Route: url_document_rlm | source: https://example.com", + tone: "neutral", + }, + { + kind: "status_note", + text: "Execution started", + tone: "neutral", + }, + { + kind: "reasoning", + parts: [{ type: "text", text: "Summarize the fetched page" }], + isStreaming: false, + }, + { + kind: "sandbox", + title: "summary", + state: "output-available", + code: "print('Example Domain')", + output: "Example Domain", + language: "python", + }, + { + kind: "tool", + title: "mcp__docs__fetch", + toolType: "mcp__docs__fetch", + state: "output-available", + input: { url: "https://example.com" }, + output: { text: "Example Domain" }, + }, + ], + }, + { + id: "a1", + type: "assistant", + content: "- Example Domain is reserved for examples.", + streaming: false, + }, + ]); + + expect(container.textContent).toContain("Route: url_document_rlm"); + expect(container.textContent).toContain("Execution started"); + expect(container.textContent).toContain("Thought"); + expect(container.textContent).toContain("Ran command"); + expect(container.textContent).toContain("Example Domain"); + expect(container.textContent).toContain("Fetched"); + expect(container.textContent).toContain("example.com"); + expect(container.textContent).toContain("Example Domain is reserved"); + + act(() => root.unmount()); + }); + it("opens the attachment menu and stages a document chip", () => { const { container, root } = mount([]); diff --git a/src/frontend/src/routes/__root.tsx b/src/frontend/src/routes/__root.tsx index 5c4710032..b0ce05ca4 100644 --- a/src/frontend/src/routes/__root.tsx +++ b/src/frontend/src/routes/__root.tsx @@ -1,6 +1,6 @@ import { createRootRoute, HeadContent, Outlet, Scripts } from "@tanstack/react-router"; import { TanStackRouterDevtools } from "@tanstack/react-router-devtools"; -import { Fragment, lazy, Suspense } from "react"; +import { lazy, Suspense, type ReactNode } from "react"; const Agentation = import.meta.env.DEV ? lazy(() => import("agentation").then((m) => ({ default: m.Agentation }))) @@ -33,8 +33,7 @@ export const Route = createRootRoute({ function RootComponent() { return ( - - + {import.meta.env.DEV && import.meta.env.VITE_E2E !== "1" && } {import.meta.env.DEV ? ( @@ -42,7 +41,20 @@ function RootComponent() { ) : null} - - + + ); +} + +function RootDocument({ children }: Readonly<{ children: ReactNode }>) { + return ( + + + + + + {children} + + + ); } diff --git a/tests/fixtures/mcp_echo_server.py b/tests/fixtures/mcp_echo_server.py new file mode 100644 index 000000000..1852e51ba --- /dev/null +++ b/tests/fixtures/mcp_echo_server.py @@ -0,0 +1,21 @@ +"""Minimal stdio MCP server used by MCP tool-provider tests. + +Run as ``python ``; exposes a single ``echo`` tool over stdio so the +runtime's :class:`MCPToolProvider` can be exercised end-to-end without a network. +""" + +from __future__ import annotations + +from mcp.server.fastmcp import FastMCP + +server = FastMCP("fleet-rlm-test-mcp") + + +@server.tool() +def echo(value: str) -> str: + """Return the input value prefixed with ``echo:`` for assertions.""" + return f"echo: {value}" + + +if __name__ == "__main__": + server.run() diff --git a/tests/unit/api/test_events.py b/tests/unit/api/test_events.py index a06db0445..c42791621 100644 --- a/tests/unit/api/test_events.py +++ b/tests/unit/api/test_events.py @@ -2,6 +2,7 @@ import asyncio import importlib +import logging import pytest @@ -56,6 +57,26 @@ async def test_execution_event_emitter_delivers_events_to_matching_subscribers() assert websocket.sent_payloads[0]["step"]["label"] == "Search code" # ty: ignore[not-subscriptable] +@pytest.mark.asyncio +async def test_execution_event_emitter_does_not_warn_per_event(caplog): + events_module = importlib.import_module("fleet_rlm.api.events") + + emitter = events_module.ExecutionEventEmitter() + event = events_module.ExecutionEvent( + type="execution_completed", + run_id="run-1", + workspace_id="workspace-a", + user_id="user-a", + session_id="session-a", + summary={"status": "ok"}, + ) + + with caplog.at_level(logging.WARNING): + await emitter.emit(event) + + assert "EMITTING EVENT" not in caplog.text + + @pytest.mark.asyncio async def test_execution_event_emitter_filters_non_matching_subscriptions(): events_module = importlib.import_module("fleet_rlm.api.events") diff --git a/tests/unit/runtime/test_escalating_module.py b/tests/unit/runtime/test_escalating_module.py index 30e50b8fe..68b212e2c 100644 --- a/tests/unit/runtime/test_escalating_module.py +++ b/tests/unit/runtime/test_escalating_module.py @@ -62,6 +62,19 @@ def preview_routing(self, *, user_request: str, execution_mode: str = "auto") -> } +class _AsyncReactAgent: + def __init__(self, prediction: dspy.Prediction) -> None: + self.prediction = prediction + self.acall_kwargs: dict[str, Any] | None = None + + def __call__(self, **_: Any) -> dspy.Prediction: + raise AssertionError("sync ReAct path should not be used by aforward") + + async def acall(self, **kwargs: Any) -> dspy.Prediction: + self.acall_kwargs = kwargs + return self.prediction + + class TestEscalatingFleetModule: def test_url_document_rlm_is_bounded_and_disables_child_tools( self, @@ -153,17 +166,39 @@ def test_cot_path_passes_recency_ordered_history_context(self) -> None: assert call_kwargs["recent_history"].rfind("NEW_MARKER") > call_kwargs["recent_history"].rfind("OLD_MARKER") assert "most recent prior turn" in call_kwargs["recent_history"] - def test_rlm_path_triggered_by_sentinel_in_reasoning(self) -> None: + def test_react_tool_branch_triggered_by_sentinel_in_reasoning(self) -> None: module = _make_module() - _stub_respond(module, reasoning=f"I need external data {ESCALATION_SENTINEL}", response="step1") - rlm_pred = _FakePrediction(answer="deep answer") - module._rlm = MagicMock(return_value=rlm_pred) + _stub_respond(module, reasoning=f"I need a tool {ESCALATION_SENTINEL}", response="step1") + react_pred = _FakePrediction(response="tool answer") + module._react = MagicMock(return_value=react_pred) + module._rlm = MagicMock() _stub_summarize(module) - result = module(user_request="Complex task", execution_mode="auto") + result = module(user_request="Use a tool", execution_mode="auto") module.respond.assert_called_once() - module._rlm.assert_called_once() - assert getattr(result, "answer", None) == "deep answer" + module._react.assert_called_once() + module._rlm.assert_not_called() + assert getattr(result, "response", None) == "tool answer" + assert result["routing_decision"] == "sentinel_react" + + @pytest.mark.asyncio + async def test_async_react_tool_branch_uses_acall_for_sentinel(self) -> None: + module = _make_module() + _stub_respond(module, reasoning=f"I need a tool {ESCALATION_SENTINEL}", response="step1") + react_agent = _AsyncReactAgent(_FakePrediction(response="async tool answer")) + module._react = react_agent # type: ignore[assignment] + module._rlm = MagicMock() + _stub_summarize(module) + + result = await module.aforward(user_request="Use an async tool", execution_mode="auto") + + module.respond.assert_called_once() + module._rlm.assert_not_called() + assert react_agent.acall_kwargs is not None + assert react_agent.acall_kwargs["user_message"] == "Use an async tool" + assert isinstance(react_agent.acall_kwargs["chat_history"], dspy.History) + assert getattr(result, "response", None) == "async tool answer" + assert result["routing_decision"] == "sentinel_react" def test_force_escalate_skips_cot(self) -> None: module = _make_module() @@ -283,18 +318,18 @@ def test_preview_routing_surfaces_url_document_route_before_execution(self) -> N "source_url": "https://dspy.ai", } - def test_rlm_fallback_to_cot_on_error(self) -> None: + def test_react_fallback_to_cot_on_error(self) -> None: module = _make_module() cot_pred = _FakePrediction(reasoning=ESCALATION_SENTINEL, assistant_response="cot_resp") module.respond = MagicMock(side_effect=[cot_pred, _FakePrediction(assistant_response="fallback")]) - module._rlm = MagicMock(side_effect=RuntimeError("RLM failed")) + module._react = MagicMock(side_effect=RuntimeError("ReAct failed")) _stub_summarize(module) result = module(user_request="query", execution_mode="auto") assert getattr(result, "assistant_response", None) == "fallback" assert result["runtime_degraded"] is True - assert result["runtime_failure_category"] == "rlm_fallback" - assert result["runtime_failure_phase"] == "escalating_rlm" + assert result["runtime_failure_category"] == "react_fallback" + assert result["runtime_failure_phase"] == "escalating_react" assert result["runtime_fallback_used"] is True assert result["runtime_warning"] diff --git a/tests/unit/runtime/test_execution.py b/tests/unit/runtime/test_execution.py index f4a912e34..51ddc46ad 100644 --- a/tests/unit/runtime/test_execution.py +++ b/tests/unit/runtime/test_execution.py @@ -27,36 +27,13 @@ def test_storage_paths_normalize_mount_layouts() -> None: assert interpreter_roots.memory_root == "/srv/runtime/memory" -def test_streaming_event_helpers_parse_tool_status_and_results() -> None: - from fleet_rlm.runtime.execution.streaming_events import ( - classify_tool_event_kind, - is_terminal_stream_event_kind, - parse_tool_call_payload, - parse_tool_call_status, - parse_tool_result_payload, - parse_tool_result_status, - ) +def test_is_terminal_stream_event_kind_classifies_terminal_kinds() -> None: + from fleet_rlm.runtime.execution.streaming_events import is_terminal_stream_event_kind assert is_terminal_stream_event_kind("done") is True + assert is_terminal_stream_event_kind("error") is True assert is_terminal_stream_event_kind("status") is False - assert classify_tool_event_kind("list_files") == "tool_call" - - assert parse_tool_call_status("Calling tool: list_files(path='src')") == "tool call: list_files(path='src')" - assert parse_tool_call_payload("Calling tool: list_files(path='src')") == { - "raw_status": "Calling tool: list_files(path='src')", - "raw_call": "list_files(path='src')", - "tool_name": "list_files", - "tool_args": "path='src'", - "tool_input": "path='src'", - } - - assert parse_tool_result_status("Tool finished.") == "tool result: finished" - assert parse_tool_result_status("Tool result: wrote file") == "tool result: completed" - assert parse_tool_result_payload("Tool result: wrote file", tool_name="write_file") == { - "raw_status": "Tool result: wrote file", - "tool_name": "write_file", - "tool_output": "wrote file", - } + assert is_terminal_stream_event_kind("text") is False def test_try_parse_hitl_request_builds_status_event() -> None: @@ -72,13 +49,13 @@ def test_try_parse_hitl_request_builds_status_event() -> None: ) assert clarification is not None - assert clarification.kind == "status" - assert clarification.payload["options"] == ["Which repo?", "Which branch?"] - assert clarification.payload["requires_response"] is True + assert clarification["kind"] == "status" + assert clarification["payload"]["options"] == ["Which repo?", "Which branch?"] + assert clarification["payload"]["requires_response"] is True assert memory_review is not None - assert memory_review.text == "This memory action requires confirmation." - assert memory_review.payload["action"] == "delete" + assert memory_review["text"] == "This memory action requires confirmation." + assert memory_review["payload"]["action"] == "delete" def test_normalize_trajectory_truncates_long_output_and_drops_terminal_thought() -> None: @@ -166,7 +143,7 @@ def test_build_final_payload_collects_sources_citations_and_human_review() -> No effective_max_iters=8, ) - assert payload["schema_version"] == 2 + assert payload["schema_version"] == 3 assert payload["history_turns"] == 4 assert payload["guardrail_warnings"] == ["watch output size"] assert payload["token_count"] == 7 diff --git a/tests/unit/runtime/test_mcp_tools.py b/tests/unit/runtime/test_mcp_tools.py new file mode 100644 index 000000000..d4f809da2 --- /dev/null +++ b/tests/unit/runtime/test_mcp_tools.py @@ -0,0 +1,166 @@ +"""Tests for the DSPy-native MCP tool provider. + +Covers config parsing and an end-to-end stdio round-trip against a real +FastMCP server fixture, asserting that discovered tools are async ``dspy.Tool`` +objects wired through ``dspy.Tool.from_mcp_tool``. +""" + +from __future__ import annotations + +import sys +from pathlib import Path +from typing import Any + +import dspy +import pytest + +from fleet_rlm.runtime.tools.mcp_tools import ( + MCP_SERVERS_ENV_VAR, + MCPServerConfig, + MCPToolProvider, + load_mcp_server_configs, +) + +_ECHO_SERVER = Path(__file__).resolve().parents[2] / "fixtures" / "mcp_echo_server.py" + + +def _tool(name: str) -> dspy.Tool: + return dspy.Tool( + func=lambda: name, + name=name, + desc=f"{name} tool", + args={}, + ) + + +def test_load_mcp_server_configs_empty_returns_empty(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delenv(MCP_SERVERS_ENV_VAR, raising=False) + assert load_mcp_server_configs() == [] + assert load_mcp_server_configs("") == [] + assert load_mcp_server_configs(" ") == [] + + +def test_load_mcp_server_configs_parses_env(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv( + MCP_SERVERS_ENV_VAR, + '[{"name": "demo", "command": "python", "args": ["-m", "srv"], "env": {"K": "V"}}]', + ) + configs = load_mcp_server_configs() + assert len(configs) == 1 + assert configs[0] == MCPServerConfig(name="demo", command="python", args=["-m", "srv"], env={"K": "V"}) + + +def test_load_mcp_server_configs_rejects_invalid_json() -> None: + with pytest.raises(ValueError, match="must be valid JSON"): + load_mcp_server_configs("{not json") + + +def test_load_mcp_server_configs_rejects_non_array() -> None: + with pytest.raises(ValueError, match="must be a JSON array"): + load_mcp_server_configs('{"name": "x"}') + + +def test_mcp_server_config_requires_name_and_command() -> None: + with pytest.raises(ValueError, match="non-empty 'name' and 'command'"): + MCPServerConfig.from_mapping({"name": "", "command": "python"}) + + +@pytest.mark.asyncio +async def test_provider_discovers_and_invokes_stdio_tool() -> None: + config = MCPServerConfig(name="echo-server", command=sys.executable, args=[str(_ECHO_SERVER)]) + provider = MCPToolProvider([config]) + try: + tools = await provider.connect() + assert provider.tools == tools + + by_name = {tool.name: tool for tool in tools} + assert "echo" in by_name + + echo_tool = by_name["echo"] + # MCP tools are async; invoke via acall. + result = await echo_tool.acall(value="ping") + assert "echo: ping" in str(result) + finally: + await provider.aclose() + assert provider.tools == [] + + +@pytest.mark.asyncio +async def test_provider_no_configs_returns_empty() -> None: + provider = MCPToolProvider([]) + assert await provider.connect() == [] + await provider.aclose() + + +@pytest.mark.asyncio +async def test_provider_async_context_manager() -> None: + config = MCPServerConfig(name="echo-server", command=sys.executable, args=[str(_ECHO_SERVER)]) + async with MCPToolProvider([config]) as provider: + assert any(tool.name == "echo" for tool in provider.tools) + + +@pytest.mark.asyncio +async def test_agent_runtime_mcp_reattach_replaces_previous_tools(monkeypatch: pytest.MonkeyPatch) -> None: + from fleet_rlm.runtime.agent import runtime as runtime_mod + from fleet_rlm.runtime.agent.runtime import AgentRuntime + from fleet_rlm.runtime.tools import mcp_tools as mcp_tools_mod + + class FakeProvider: + instances: list[FakeProvider] = [] + + def __init__(self, configs: list[Any]) -> None: + self.configs = configs + self.closed = False + FakeProvider.instances.append(self) + + async def connect(self) -> list[dspy.Tool]: + return [_tool(f"mcp_{self.configs[0]}")] + + async def aclose(self) -> None: + self.closed = True + + monkeypatch.setattr(runtime_mod, "discover_tools", lambda: [_tool("base_tool")]) + monkeypatch.setattr(mcp_tools_mod, "MCPToolProvider", FakeProvider) + + rt = AgentRuntime(use_escalation=False) + + assert await rt.attach_mcp_tools(configs=["first"]) == ["mcp_first"] + assert [tool.name for tool in rt.tools] == ["base_tool", "mcp_first"] + + assert await rt.attach_mcp_tools(configs=["second"]) == ["mcp_second"] + + assert FakeProvider.instances[0].closed is True + assert FakeProvider.instances[1].closed is False + assert [tool.name for tool in rt.tools] == ["base_tool", "mcp_second"] + + +@pytest.mark.asyncio +async def test_agent_runtime_ashutdown_closes_mcp_provider(monkeypatch: pytest.MonkeyPatch) -> None: + from fleet_rlm.runtime.agent import runtime as runtime_mod + from fleet_rlm.runtime.agent.runtime import AgentRuntime + from fleet_rlm.runtime.tools import mcp_tools as mcp_tools_mod + + class FakeProvider: + instances: list[FakeProvider] = [] + + def __init__(self, configs: list[Any]) -> None: + self.configs = configs + self.closed = False + FakeProvider.instances.append(self) + + async def connect(self) -> list[dspy.Tool]: + return [_tool(f"mcp_{self.configs[0]}")] + + async def aclose(self) -> None: + self.closed = True + + monkeypatch.setattr(runtime_mod, "discover_tools", lambda: [_tool("base_tool")]) + monkeypatch.setattr(mcp_tools_mod, "MCPToolProvider", FakeProvider) + + rt = AgentRuntime(use_escalation=False) + await rt.attach_mcp_tools(configs=["session"]) + + await rt.ashutdown() + + assert FakeProvider.instances[0].closed is True + assert [tool.name for tool in rt.tools] == ["base_tool"] diff --git a/tests/unit/runtime/test_modules.py b/tests/unit/runtime/test_modules.py index 7bba43b7f..c66b9bfa5 100644 --- a/tests/unit/runtime/test_modules.py +++ b/tests/unit/runtime/test_modules.py @@ -159,7 +159,7 @@ def test_runtime_module_registry_flags_and_signature_fields_are_stable() -> None assert RUNTIME_MODULE_REGISTRY["extract_from_logs"].variable_mode is True assert RUNTIME_MODULE_REGISTRY["grounded_answer"].variable_mode is False - assert set(RLMVariableSignature.input_fields) == {"task", "prompt"} + assert set(RLMVariableSignature.input_fields) == {"task", "prompt", "history"} assert set(RLMVariableSignature.output_fields) == {"answer"} assert {"query", "evidence_chunks", "response_style"} <= set(GroundedAnswerWithCitations.input_fields) assert {"answer", "citations", "confidence", "coverage_notes"} <= set(GroundedAnswerWithCitations.output_fields) diff --git a/tests/unit/runtime/test_native_streaming_contract.py b/tests/unit/runtime/test_native_streaming_contract.py new file mode 100644 index 000000000..151d963e8 --- /dev/null +++ b/tests/unit/runtime/test_native_streaming_contract.py @@ -0,0 +1,147 @@ +"""Characterization tests locking the native streaming StreamEvent contract. + +These tests pin the exact ``StreamEvent`` sequence the websocket layer emits +today through :meth:`AgentRuntime.aiter_chat_turn_stream` when the agent exposes +a streamable ReAct program (``FleetAgent``). They form the Phase 0 guardrail for +the dspy.ReAct migration: the migration must keep these sequences intact. + +The planner step is scripted so the assertions are deterministic and require no +live LM. The goal is to capture observable behaviour, not to test internals. +""" + +from __future__ import annotations + +from typing import Any + +import dspy +import pytest + + +def _disable_runtime_tool_discovery(monkeypatch: pytest.MonkeyPatch) -> None: + from fleet_rlm.runtime.agent import runtime as runtime_mod + + monkeypatch.setattr(runtime_mod, "discover_tools", lambda: []) + + +def _script_planner(react_program: Any, steps: list[dspy.Prediction]) -> None: + """Replace ``async_planner_step`` with a scripted, deterministic sequence.""" + + iterator = iter(steps) + + async def _scripted(trajectory: dict[str, Any], **_input: Any) -> dspy.Prediction: + _ = trajectory + return next(iterator) + + react_program.async_planner_step = _scripted # type: ignore[assignment] + + +def _pred(**kwargs: Any) -> dspy.Prediction: + pred = dspy.Prediction(**kwargs) + for key, value in kwargs.items(): + object.__setattr__(pred, key, value) + return pred + + +@pytest.mark.asyncio +async def test_tool_using_turn_emits_canonical_event_sequence( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _disable_runtime_tool_discovery(monkeypatch) + from fleet_rlm.runtime.agent.runtime import AgentRuntime + + rt = AgentRuntime(use_escalation=False) + react_program = rt.agent + + def echo(value: str) -> str: + return f"echoed: {value}" + + react_program.tools["echo_tool"] = echo # type: ignore[index] + + _script_planner( + react_program, + [ + _pred( + next_thought="I should echo the value first.", + next_tool_name="echo_tool", + next_tool_args={"value": "hi"}, + ), + _pred( + next_thought="Here is your answer.", + next_tool_name="finish", + next_tool_args={}, + ), + ], + ) + + events = [event async for event in rt.aiter_chat_turn_stream("hello")] + kinds = [event.kind for event in events] + + # Canonical ordering: status -> reasoning -> tool_call -> tool_result -> text -> done + assert kinds[0] == "status" + assert events[0].text == "Starting turn..." + assert "reasoning" in kinds + assert kinds.index("tool_call") < kinds.index("tool_result") + assert kinds.index("tool_result") < kinds.index("text") + assert kinds[-1] == "done" + + tool_call = next(e for e in events if e.kind == "tool_call") + tool_result = next(e for e in events if e.kind == "tool_result") + assert tool_call.payload["tool_name"] == "echo_tool" + assert tool_result.payload["tool_name"] == "echo_tool" + + text = next(e for e in events if e.kind == "text") + assert text.text == "Here is your answer." + + done = events[-1] + assert done.text == "Here is your answer." + assert "trajectory" in done.payload + assert "history_turns" in done.payload + + +@pytest.mark.asyncio +async def test_finish_first_turn_skips_tool_and_extract( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _disable_runtime_tool_discovery(monkeypatch) + from fleet_rlm.runtime.agent.runtime import AgentRuntime + + rt = AgentRuntime(use_escalation=False) + react_program = rt.agent + + _script_planner( + react_program, + [ + _pred( + next_thought="Direct answer, no tools needed.", + next_tool_name="finish", + next_tool_args={}, + ), + ], + ) + + events = [event async for event in rt.aiter_chat_turn_stream("hi")] + kinds = [event.kind for event in events] + + assert kinds[0] == "status" + assert "tool_call" not in kinds + assert "tool_result" not in kinds + text = next(e for e in events if e.kind == "text") + assert text.text == "Direct answer, no tools needed." + assert kinds[-1] == "done" + assert events[-1].text == "Direct answer, no tools needed." + + +@pytest.mark.asyncio +async def test_cancel_before_turn_emits_cancelled_done( + monkeypatch: pytest.MonkeyPatch, +) -> None: + _disable_runtime_tool_discovery(monkeypatch) + from fleet_rlm.runtime.agent.runtime import AgentRuntime + + rt = AgentRuntime(use_escalation=False) + + events = [event async for event in rt.aiter_chat_turn_stream("hi", cancel_check=lambda: True)] + + assert len(events) == 1 + assert events[0].kind == "done" + assert events[0].payload.get("cancelled") is True From e4b2634216c519261bcda2f57b99a1d403c11744 Mon Sep 17 00:00:00 2001 From: Zachary BENSALEM Date: Sat, 6 Jun 2026 23:12:58 +0200 Subject: [PATCH 4/7] codex: fix CI failure on PR #271 --- src/fleet_rlm/api/events/project_graph.py | 11 +- .../observability/mlflow_context.py | 16 +- src/fleet_rlm/runtime/agent/runtime.py | 29 ++- src/fleet_rlm/runtime/modules/factory.py | 4 +- .../agent-elements/input/model-picker.tsx | 31 +-- ...space-message-list.agent-elements.test.tsx | 50 ++-- .../transcript/workspace-message-list.tsx | 36 +-- .../workspace-agent-input-bar.tsx | 66 ++--- .../workspace/screen/workspace-screen.tsx | 28 +-- .../backend-chat-event-adapter.test.ts | 235 ++++-------------- .../workspace/backend-chat-event-adapter.ts | 166 +++---------- .../backend-chat-event-trajectory.ts | 25 +- 12 files changed, 200 insertions(+), 497 deletions(-) diff --git a/src/fleet_rlm/api/events/project_graph.py b/src/fleet_rlm/api/events/project_graph.py index 9b86eda3b..9c1134c02 100644 --- a/src/fleet_rlm/api/events/project_graph.py +++ b/src/fleet_rlm/api/events/project_graph.py @@ -17,7 +17,13 @@ from fleet_rlm.runtime.events import RuntimeEvent, RuntimeEventKind from .sanitizer import sanitize_event_payload -from .step_builder_extractors import ExecutionStepType, _derive_lane_key, _tool_step_type +from .step_builder_extractors import ( + ExecutionActorKind, + ExecutionStepType, + _derive_lane_key, + _map_actor_kind_text, + _tool_step_type, +) if TYPE_CHECKING: from .step_builder import ExecutionStepBuilder @@ -118,7 +124,8 @@ def project_graph(event: RuntimeEvent, builder: ExecutionStepBuilder) -> Any: actor_id = ctx.actor_id if ctx else None parent_id_hint = ctx.parent_id if ctx else None - actor_kind: str = actor_kind_raw or "unknown" + mapped_actor_kind = _map_actor_kind_text(actor_kind_raw) if actor_kind_raw else None + actor_kind: ExecutionActorKind = mapped_actor_kind or "unknown" if actor_kind == "unknown" and depth is None: actor_kind = "root_rlm" depth = 0 diff --git a/src/fleet_rlm/integrations/observability/mlflow_context.py b/src/fleet_rlm/integrations/observability/mlflow_context.py index cfb29dc07..e4dca0b82 100644 --- a/src/fleet_rlm/integrations/observability/mlflow_context.py +++ b/src/fleet_rlm/integrations/observability/mlflow_context.py @@ -9,7 +9,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from threading import Lock -from typing import Any +from typing import Any, cast from fleet_rlm.integrations.config._env_utils import env_bool as _env_bool @@ -152,9 +152,15 @@ def _flat_trajectory_indices(raw: dict[str, Any]) -> list[int]: return sorted(indices) +def _coerce_step_dict(step: dict[Any, Any], index: int) -> dict[str, Any]: + step_dict = dict(cast(dict[str, Any], step)) + step_dict.setdefault("index", index) + return step_dict + + def _coerce_trajectory_steps(raw: Any) -> list[dict[str, Any]]: if isinstance(raw, list): - return [dict(step, index=step.get("index", index)) for index, step in enumerate(raw) if isinstance(step, dict)] + return [_coerce_step_dict(step, index) for index, step in enumerate(raw) if isinstance(step, dict)] if not isinstance(raw, dict): return [] @@ -162,11 +168,7 @@ def _coerce_trajectory_steps(raw: Any) -> list[dict[str, Any]]: for key in ("trajectory", "steps"): nested = raw.get(key) if isinstance(nested, list): - return [ - dict(step, index=step.get("index", index)) - for index, step in enumerate(nested) - if isinstance(step, dict) - ] + return [_coerce_step_dict(step, index) for index, step in enumerate(nested) if isinstance(step, dict)] steps: list[dict[str, Any]] = [] for index in _flat_trajectory_indices(raw): diff --git a/src/fleet_rlm/runtime/agent/runtime.py b/src/fleet_rlm/runtime/agent/runtime.py index cc8d9b0dd..eef24a8e7 100644 --- a/src/fleet_rlm/runtime/agent/runtime.py +++ b/src/fleet_rlm/runtime/agent/runtime.py @@ -23,7 +23,7 @@ from fleet_rlm.runtime.execution.streaming_events import ( _normalize_trajectory, ) -from fleet_rlm.runtime.schemas import StreamEvent +from fleet_rlm.runtime.schemas import StreamEvent, StreamEventKind from fleet_rlm.runtime.tools import discover_tools from fleet_rlm.runtime.tools.binding import bind_runtime_tools, execute_sandbox_tool @@ -195,6 +195,15 @@ async def _call_react_tool(tool: Any, tool_args: dict[str, Any]) -> Any: return await asyncio.to_thread(tool, **tool_args) +def _stream_event_from_runtime_event(event: RuntimeEvent) -> StreamEvent: + return StreamEvent( + kind=cast(StreamEventKind, event.kind.value), + text=event.text, + payload=dict(event.payload), + timestamp=event.timestamp, + ) + + def _build_tool_call_event(*, tool_name: str, tool_args: dict[str, Any], step_index: int) -> RuntimeEvent: return RuntimeEvent.tool_call( tool_name=tool_name, @@ -826,7 +835,7 @@ async def aiter_chat_turn_stream( message=message, cancel_check=cancel_check, ): - yield event + yield _stream_event_from_runtime_event(event) return logger.info("streaming_path=native (dspy.streamify per-token streaming)") @@ -910,7 +919,9 @@ async def aiter_chat_turn_stream( break tool = react_program.tools[tool_name] - yield _build_tool_call_event(tool_name=tool_name, tool_args=tool_args, step_index=step_index) + yield _stream_event_from_runtime_event( + _build_tool_call_event(tool_name=tool_name, tool_args=tool_args, step_index=step_index) + ) try: observation = await _call_react_tool(tool, tool_args) @@ -920,15 +931,17 @@ async def aiter_chat_turn_stream( trajectory_raw[f"observation_{step_index}"] = observation if recursive_child_review is None: recursive_child_review = _recursive_child_review_payload(tool_name, observation) - yield _build_tool_result_event( - tool_name=tool_name, - observation=observation, - step_index=step_index, + yield _stream_event_from_runtime_event( + _build_tool_result_event( + tool_name=tool_name, + observation=observation, + step_index=step_index, + ) ) clarification_event = _build_clarification_event(observation) if clarification_event is not None: - yield clarification_event + yield _stream_event_from_runtime_event(clarification_event) # Fast path: skip the extract LLM call when the agent finished # with a finish tool or no tool. The planner thought already diff --git a/src/fleet_rlm/runtime/modules/factory.py b/src/fleet_rlm/runtime/modules/factory.py index 91a148fec..45f3d7a2e 100644 --- a/src/fleet_rlm/runtime/modules/factory.py +++ b/src/fleet_rlm/runtime/modules/factory.py @@ -7,8 +7,10 @@ import dspy +_DSPY_RLM_BASE: Any = dspy.RLM -class _NoCallbackRLM(dspy.RLM): + +class _NoCallbackRLM(_DSPY_RLM_BASE): """RLM variant for REPL-only tasks where host semantic callbacks are disabled.""" def _build_signatures(self) -> tuple[Any, Any]: diff --git a/src/frontend/src/components/agent-elements/input/model-picker.tsx b/src/frontend/src/components/agent-elements/input/model-picker.tsx index 785f25cb4..3d29d6adb 100644 --- a/src/frontend/src/components/agent-elements/input/model-picker.tsx +++ b/src/frontend/src/components/agent-elements/input/model-picker.tsx @@ -4,11 +4,7 @@ import { memo, useMemo, useState } from "react"; import { Check, ChevronDown, Cpu, Settings2 } from "lucide-react"; import { Button } from "@/components/ui/button"; -import { - Popover, - PopoverContent, - PopoverTrigger, -} from "@/components/ui/popover"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; import { cn } from "../utils/cn"; @@ -44,15 +40,9 @@ export const ModelPicker = memo(function ModelPicker({ const [internalValue, setInternalValue] = useState(defaultValue); const [open, setOpen] = useState(false); const activeId = isControlled ? value : internalValue; - const enabledModels = useMemo( - () => models.filter((model) => !model.disabled), - [models], - ); + const enabledModels = useMemo(() => models.filter((model) => !model.disabled), [models]); const activeModel = - models.find((model) => model.id === activeId) ?? - enabledModels[0] ?? - models[0] ?? - null; + models.find((model) => model.id === activeId) ?? enabledModels[0] ?? models[0] ?? null; if (!activeModel) return null; @@ -63,8 +53,7 @@ export const ModelPicker = memo(function ModelPicker({ aria-label={`Active model: ${activeModel.label}`} className={cn( "inline-flex h-7 min-w-0 max-w-45 items-center gap-1.5 rounded-[6px] px-2 text-[12px] leading-4 text-foreground/60 transition-colors hover:bg-foreground/6 hover:text-foreground", - disabled && - "cursor-not-allowed opacity-60 hover:bg-transparent hover:text-foreground/60", + disabled && "cursor-not-allowed opacity-60 hover:bg-transparent hover:text-foreground/60", className, )} disabled={disabled} @@ -106,18 +95,12 @@ export const ModelPicker = memo(function ModelPicker({ > - - {model.label} - + {model.label} {model.description ? ( - - {model.description} - + {model.description} ) : null} - {isActive && ( - - )} + {isActive && } ); })} diff --git a/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx b/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx index 4a7fd6293..b1588b002 100644 --- a/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx +++ b/src/frontend/src/features/workspace/conversation/transcript/__tests__/workspace-message-list.agent-elements.test.tsx @@ -103,8 +103,8 @@ describe("WorkspaceMessageList Agent Elements integration", () => { { onResolveHitl }, ); - const approveButton = Array.from(container.querySelectorAll("button")).find( - (button) => button.textContent?.includes("Approve"), + const approveButton = Array.from(container.querySelectorAll("button")).find((button) => + button.textContent?.includes("Approve"), ); expect(approveButton).toBeTruthy(); @@ -112,8 +112,8 @@ describe("WorkspaceMessageList Agent Elements integration", () => { approveButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); }); - const sendButton = Array.from(container.querySelectorAll("button")).find( - (button) => button.textContent?.includes("Send"), + const sendButton = Array.from(container.querySelectorAll("button")).find((button) => + button.textContent?.includes("Send"), ); expect(sendButton).toBeTruthy(); @@ -141,9 +141,7 @@ describe("WorkspaceMessageList Agent Elements integration", () => { }, { kind: "reasoning", - parts: [ - { type: "text", text: "I should inspect the repository files." }, - ], + parts: [{ type: "text", text: "I should inspect the repository files." }], isStreaming: false, }, { @@ -289,24 +287,21 @@ describe("WorkspaceMessageList Agent Elements integration", () => { attachButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); }); - const addDocumentButton = Array.from( - document.body.querySelectorAll("button"), - ).find((button) => button.textContent?.includes("Add document")); - const connectorsButton = Array.from( - document.body.querySelectorAll("button"), - ).find((button) => button.textContent?.includes("Connectors")); + const addDocumentButton = Array.from(document.body.querySelectorAll("button")).find((button) => + button.textContent?.includes("Add document"), + ); + const connectorsButton = Array.from(document.body.querySelectorAll("button")).find((button) => + button.textContent?.includes("Connectors"), + ); expect(addDocumentButton).toBeTruthy(); expect(connectorsButton).toBeTruthy(); expect(connectorsButton).toHaveProperty("disabled", true); act(() => { - addDocumentButton?.dispatchEvent( - new MouseEvent("click", { bubbles: true }), - ); + addDocumentButton?.dispatchEvent(new MouseEvent("click", { bubbles: true })); }); - const fileInput = - container.querySelector('input[type="file"]'); + const fileInput = container.querySelector('input[type="file"]'); expect(fileInput).toBeTruthy(); Object.defineProperty(fileInput, "files", { configurable: true, @@ -335,9 +330,7 @@ describe("WorkspaceMessageList Agent Elements integration", () => { expect(container.textContent).toContain("openai/gemini-3-flash-preview"); - const modelButton = container.querySelector( - 'button[aria-label^="Active model"]', - ); + const modelButton = container.querySelector('button[aria-label^="Active model"]'); expect(modelButton).toBeTruthy(); act(() => { @@ -346,9 +339,9 @@ describe("WorkspaceMessageList Agent Elements integration", () => { expect(document.body.textContent).toContain("openai/gemini-3-pro-preview"); - const settingsButton = Array.from( - document.body.querySelectorAll("button"), - ).find((button) => button.textContent?.includes("Model settings")); + const settingsButton = Array.from(document.body.querySelectorAll("button")).find((button) => + button.textContent?.includes("Model settings"), + ); expect(settingsButton).toBeTruthy(); act(() => { @@ -361,12 +354,9 @@ describe("WorkspaceMessageList Agent Elements integration", () => { }); it("renders the pending planning loader without a lazy component crash", async () => { - const { container, root } = mount( - [{ id: "u1", type: "user", content: "start working" }], - { - isTyping: true, - }, - ); + const { container, root } = mount([{ id: "u1", type: "user", content: "start working" }], { + isTyping: true, + }); await act(async () => { await new Promise((resolve) => setTimeout(resolve, 0)); diff --git a/src/frontend/src/features/workspace/conversation/transcript/workspace-message-list.tsx b/src/frontend/src/features/workspace/conversation/transcript/workspace-message-list.tsx index 167a416ee..c7f9f5037 100644 --- a/src/frontend/src/features/workspace/conversation/transcript/workspace-message-list.tsx +++ b/src/frontend/src/features/workspace/conversation/transcript/workspace-message-list.tsx @@ -70,9 +70,7 @@ export function WorkspaceMessageList({ showStatusBar = true, className, }: WorkspaceMessageListProps) { - const selectedAssistantTurnId = useWorkspaceUiStore( - (state) => state.selectedAssistantTurnId, - ); + const selectedAssistantTurnId = useWorkspaceUiStore((state) => state.selectedAssistantTurnId); const agentMessages = useMemo( () => toAgentChatMessages(messages, { @@ -81,11 +79,8 @@ export function WorkspaceMessageList({ }), [messages, onResolveClarification, onResolveHitl], ); - const lastUserIndex = messages.findLastIndex( - (message) => message.type === "user", - ); - const lastUserMessageId = - lastUserIndex >= 0 ? (messages[lastUserIndex]?.id ?? null) : null; + const lastUserIndex = messages.findLastIndex((message) => message.type === "user"); + const lastUserMessageId = lastUserIndex >= 0 ? (messages[lastUserIndex]?.id ?? null) : null; const activeTurnAssistantMessageId = lastUserIndex >= 0 ? (messages @@ -97,19 +92,11 @@ export function WorkspaceMessageList({ useEffect(() => { if (!selectedAssistantTurnId || !lastUserMessageId) return; const pendingTurnId = buildPendingAssistantTurnId(lastUserMessageId); - if ( - selectedAssistantTurnId !== pendingTurnId || - !activeTurnAssistantMessageId - ) - return; + if (selectedAssistantTurnId !== pendingTurnId || !activeTurnAssistantMessageId) return; useWorkspaceUiStore.setState({ selectedAssistantTurnId: activeTurnAssistantMessageId, }); - }, [ - activeTurnAssistantMessageId, - lastUserMessageId, - selectedAssistantTurnId, - ]); + }, [activeTurnAssistantMessageId, lastUserMessageId, selectedAssistantTurnId]); const status = chatStatus(isTyping); const inputSlot = useMemo( @@ -155,9 +142,7 @@ export function WorkspaceMessageList({ const selectedLabels = (answer.selectedIds ?? []).map( (id) => question.options?.find((option) => option.id === id)?.label ?? id, ); - const text = [selectedLabels.join(", "), answer.text] - .filter(Boolean) - .join(" - "); + const text = [selectedLabels.join(", "), answer.text].filter(Boolean).join(" - "); if (!text) return; if (target?.type === "hitl") { onResolveHitl(toolCallId, text); @@ -170,14 +155,9 @@ export function WorkspaceMessageList({ if (messages.length === 0 && showEmptyState) { return ( -
+
- +
) => { - const files = Array.from(event.currentTarget.files ?? []); - if (files.length > 0) { - setStagedDocuments((current) => [ - ...current, - ...files.map((file) => ({ - id: `document-${createAttachmentId()}`, - filename: file.name, - size: file.size, - })), - ]); - } - event.currentTarget.value = ""; - }, - [], - ); + const handleDocumentInputChange = useCallback((event: ChangeEvent) => { + const files = Array.from(event.currentTarget.files ?? []); + if (files.length > 0) { + setStagedDocuments((current) => [ + ...current, + ...files.map((file) => ({ + id: `document-${createAttachmentId()}`, + filename: file.name, + size: file.size, + })), + ]); + } + event.currentTarget.value = ""; + }, []); const handleRemoveFile = useCallback( (id: string) => { @@ -164,9 +147,7 @@ export function WorkspaceAgentInputBar({ ); return ( -
+
- - {runtimeWarning.title} - + {runtimeWarning.title}
@@ -218,10 +197,7 @@ export function WorkspaceAgentInputBar({ onSend={handleSend} rightActions={ <> - + 0 ? inferredContextPaths : undefined, + repoRef: inferredRepoContext?.repoRefCandidate ?? inferredRepoContext?.repoRef, + contextPaths: inferredContextPaths.length > 0 ? inferredContextPaths : undefined, }); }, [ @@ -142,15 +140,11 @@ export function WorkspaceScreen() { ], ); - const { - sessionRevision, - requestedConversationId, - clearRequestedConversation, - } = useWorkspaceUiStore(); + const { sessionRevision, requestedConversationId, clearRequestedConversation } = + useWorkspaceUiStore(); // Chat history - const { saveConversation, loadConversation: loadConv } = - useChatHistoryStore(); + const { saveConversation, loadConversation: loadConv } = useChatHistoryStore(); // ── Auto-save on session change ────────────────────────────────── // When sessionRevision increments (newSession() called), save the current @@ -194,16 +188,8 @@ export function WorkspaceScreen() { return; } - if ( - messagesRef.current.length > 0 && - messagesRef.current !== conversation.messages - ) { - saveConversation( - messagesRef.current, - phaseRef.current, - undefined, - turnArtifactsRef.current, - ); + if (messagesRef.current.length > 0 && messagesRef.current !== conversation.messages) { + saveConversation(messagesRef.current, phaseRef.current, undefined, turnArtifactsRef.current); } loadConversation(conversation); diff --git a/src/frontend/src/lib/workspace/__tests__/backend-chat-event-adapter.test.ts b/src/frontend/src/lib/workspace/__tests__/backend-chat-event-adapter.test.ts index 8aa87bb71..a128e35a8 100644 --- a/src/frontend/src/lib/workspace/__tests__/backend-chat-event-adapter.test.ts +++ b/src/frontend/src/lib/workspace/__tests__/backend-chat-event-adapter.test.ts @@ -1,23 +1,11 @@ import { describe, expect, it, vi } from "vite-plus/test"; import { QueryClient } from "@tanstack/react-query"; import { applyWsFrameToMessages } from "@/lib/workspace/backend-chat-event-adapter"; -import type { - ChatMessage, - ChatRenderPart, -} from "@/lib/workspace/workspace-types"; +import type { ChatMessage, ChatRenderPart } from "@/lib/workspace/workspace-types"; import type { WsServerMessage } from "@/lib/rlm-api"; -function makeEvent( - kind: string, - text: string, - payload?: Record, -): WsServerMessage { - if ( - kind === "done" || - kind === "turn_completed" || - kind === "error" || - kind === "turn_failed" - ) { +function makeEvent(kind: string, text: string, payload?: Record): WsServerMessage { + if (kind === "done" || kind === "turn_completed" || kind === "error" || kind === "turn_failed") { return { type: "event", data: { @@ -27,23 +15,14 @@ function makeEvent( ...payload, source_type: "execution_completed", run_summary: { - status: - kind === "error" || kind === "turn_failed" - ? "failed" - : "completed", + status: kind === "error" || kind === "turn_failed" ? "failed" : "completed", }, }, }, }; } - if ( - kind === "text" || - kind === "reasoning" || - kind === "tool_call" || - kind === "tool_result" - ) { - const stepType = - kind === "tool_call" || kind === "tool_result" ? "tool" : "llm"; + if (kind === "text" || kind === "reasoning" || kind === "tool_call" || kind === "tool_result") { + const stepType = kind === "tool_call" || kind === "tool_result" ? "tool" : "llm"; return { type: "event", data: { @@ -56,12 +35,7 @@ function makeEvent( type: stepType, label: text, input: kind === "tool_call" ? text : undefined, - output: - kind === "text" - ? { text } - : kind === "tool_result" - ? text - : undefined, + output: kind === "text" ? { text } : kind === "tool_result" ? text : undefined, ...payload, }, }, @@ -180,15 +154,12 @@ describe("applyWsFrameToMessages", () => { const reasoningRows = traceRows( messages, - (part, message) => - part.kind === "reasoning" && message.traceSource === "live", + (part, message) => part.kind === "reasoning" && message.traceSource === "live", ); expect(reasoningRows).toHaveLength(2); expect( - reasoningRows.map((row) => - row.part.kind === "reasoning" ? row.part.parts[0]?.text : "", - ), + reasoningRows.map((row) => (row.part.kind === "reasoning" ? row.part.parts[0]?.text : "")), ).toEqual(["Analyzing input ", "and checking constraints"]); }); @@ -211,10 +182,7 @@ describe("applyWsFrameToMessages", () => { }), ); - const reasoning = findFirstPart( - messages, - (part) => part.kind === "reasoning", - ); + const reasoning = findFirstPart(messages, (part) => part.kind === "reasoning"); expect(reasoning).toBeDefined(); if (reasoning?.kind === "reasoning") { expect(reasoning.runtimeContext).toEqual({ @@ -240,10 +208,7 @@ describe("applyWsFrameToMessages", () => { }), ); - const reasoning = findFirstPart( - messages, - (part) => part.kind === "reasoning", - ); + const reasoning = findFirstPart(messages, (part) => part.kind === "reasoning"); expect(reasoning).toBeDefined(); if (reasoning?.kind === "reasoning") { expect(reasoning.label).toBe("prompt_iter_1"); @@ -303,8 +268,7 @@ describe("applyWsFrameToMessages", () => { const liveReasoning = traceRows( messages, - (part, message) => - part.kind === "reasoning" && message.traceSource === "trajectory", + (part, message) => part.kind === "reasoning" && message.traceSource === "trajectory", ); expect(liveReasoning).toHaveLength(1); const reasoningPart = liveReasoning[0]?.part; @@ -313,10 +277,7 @@ describe("applyWsFrameToMessages", () => { expect(reasoningPart.parts[0]?.text).toBe("Inspect the repo first."); } - const cot = findFirstPart( - messages, - (part) => part.kind === "chain_of_thought", - ); + const cot = findFirstPart(messages, (part) => part.kind === "chain_of_thought"); if (cot?.kind === "chain_of_thought") { expect(cot.steps[0]?.body).toBe("Inspect the repo first."); } @@ -324,10 +285,7 @@ describe("applyWsFrameToMessages", () => { it.skip("suppresses trajectory fallback primary rows when live trace already exists", () => { let messages: ChatMessage[] = []; - messages = applyWsFrameToMessages( - messages, - makeEvent("reasoning", "Live reasoning"), - ).messages; + messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "Live reasoning")).messages; messages = applyWsFrameToMessages( messages, @@ -347,16 +305,10 @@ describe("applyWsFrameToMessages", () => { ); expect(trajectoryPrimary).toHaveLength(0); - const reasoningRows = traceRows( - messages, - (part) => part.kind === "reasoning", - ); + const reasoningRows = traceRows(messages, (part) => part.kind === "reasoning"); expect(reasoningRows).toHaveLength(1); - const cot = findFirstPart( - messages, - (part) => part.kind === "chain_of_thought", - ); + const cot = findFirstPart(messages, (part) => part.kind === "chain_of_thought"); expect(cot).toBeDefined(); if (cot?.kind === "chain_of_thought") { expect(cot.steps).toHaveLength(1); @@ -381,8 +333,7 @@ describe("applyWsFrameToMessages", () => { const trajectoryReasoning = traceRows( messages, - (part, message) => - part.kind === "reasoning" && message.traceSource === "trajectory", + (part, message) => part.kind === "reasoning" && message.traceSource === "trajectory", ); expect( trajectoryReasoning.map((row) => @@ -393,10 +344,7 @@ describe("applyWsFrameToMessages", () => { const tools = findAllParts(messages, (part) => part.kind === "tool"); expect(tools).toHaveLength(0); - const cot = findFirstPart( - messages, - (part) => part.kind === "chain_of_thought", - ); + const cot = findFirstPart(messages, (part) => part.kind === "chain_of_thought"); expect(cot).toBeDefined(); if (cot?.kind === "chain_of_thought") { expect(cot.steps).toHaveLength(2); @@ -427,10 +375,7 @@ describe("applyWsFrameToMessages", () => { }), ).messages; - const cot = findFirstPart( - messages, - (part) => part.kind === "chain_of_thought", - ); + const cot = findFirstPart(messages, (part) => part.kind === "chain_of_thought"); expect(cot).toBeDefined(); if (cot?.kind === "chain_of_thought") { expect(cot.steps.map((step) => step.index)).toEqual([0, 1]); @@ -441,10 +386,7 @@ describe("applyWsFrameToMessages", () => { it("keeps exact interleaved order for reasoning and tool events", () => { let messages: ChatMessage[] = []; - messages = applyWsFrameToMessages( - messages, - makeEvent("reasoning", "r1"), - ).messages; + messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "r1")).messages; messages = applyWsFrameToMessages( messages, makeEvent("tool_call", "call", { @@ -452,10 +394,7 @@ describe("applyWsFrameToMessages", () => { tool_args: { pattern: "foo" }, }), ).messages; - messages = applyWsFrameToMessages( - messages, - makeEvent("reasoning", "r2"), - ).messages; + messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "r2")).messages; messages = applyWsFrameToMessages( messages, makeEvent("tool_result", "result", { @@ -463,18 +402,13 @@ describe("applyWsFrameToMessages", () => { tool_output: "match", }), ).messages; - messages = applyWsFrameToMessages( - messages, - makeEvent("reasoning", "r3"), - ).messages; + messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "r3")).messages; const primaryRows = traceRows( messages, (part, message) => message.traceSource === "live" && - (part.kind === "reasoning" || - part.kind === "tool" || - part.kind === "sandbox"), + (part.kind === "reasoning" || part.kind === "tool" || part.kind === "sandbox"), ); expect(primaryRows.map((row) => row.part.kind)).toEqual([ @@ -487,10 +421,7 @@ describe("applyWsFrameToMessages", () => { const toolRows = primaryRows.filter((row) => row.part.kind === "tool"); expect(toolRows).toHaveLength(2); - if ( - toolRows[0]?.part.kind === "tool" && - toolRows[1]?.part.kind === "tool" - ) { + if (toolRows[0]?.part.kind === "tool" && toolRows[1]?.part.kind === "tool") { expect(toolRows[0].part.state).toBe("running"); expect(toolRows[1].part.state).toBe("output-available"); } @@ -499,10 +430,7 @@ describe("applyWsFrameToMessages", () => { it.skip("maps status, rlm_delegate, status to task rows in order", () => { let messages: ChatMessage[] = []; - messages = applyWsFrameToMessages( - messages, - makeEvent("status", "Moving to step 2"), - ).messages; + messages = applyWsFrameToMessages(messages, makeEvent("status", "Moving to step 2")).messages; messages = applyWsFrameToMessages( messages, makeEvent("rlm_delegate", "Delegating", { @@ -520,9 +448,7 @@ describe("applyWsFrameToMessages", () => { ); expect(taskRows).toHaveLength(3); - const taskTitles = taskRows.map((row) => - row.part.kind === "task" ? row.part.title : "", - ); + const taskTitles = taskRows.map((row) => (row.part.kind === "task" ? row.part.title : "")); expect(taskTitles).toEqual([ "Plan update", "Executing PythonInterpreter", @@ -537,9 +463,7 @@ describe("applyWsFrameToMessages", () => { const queue = findFirstPart(messages, (p) => p.kind === "queue"); expect(queue).toBeDefined(); if (queue?.kind === "queue") { - expect(queue.items[queue.items.length - 1]?.label).toBe( - "Moving to step 2", - ); + expect(queue.items[queue.items.length - 1]?.label).toBe("Moving to step 2"); } }); @@ -645,16 +569,11 @@ describe("applyWsFrameToMessages", () => { expect(sandbox.stepIndex).toBe(2); expect(sandbox.output).toBe("loading repository metadata"); expect(sandbox.runtimeContext?.runtimeMode).toBe("daytona_pilot"); - expect(sandbox.runtimeContext?.workspacePath).toBe( - "/workspace/workspace/repo", - ); + expect(sandbox.runtimeContext?.workspacePath).toBe("/workspace/workspace/repo"); expect(sandbox.runtimeContext?.sandboxTransition).toBe("created"); } - const statusNote = findFirstPart( - messages, - (part) => part.kind === "status_note", - ); + const statusNote = findFirstPart(messages, (part) => part.kind === "status_note"); expect(statusNote).toBeUndefined(); }); @@ -742,10 +661,7 @@ describe("applyWsFrameToMessages", () => { }), ).messages; - const reasoningRows = traceRows( - messages, - (part) => part.kind === "reasoning", - ); + const reasoningRows = traceRows(messages, (part) => part.kind === "reasoning"); const toolRows = traceRows(messages, (part) => part.kind === "tool"); expect(reasoningRows).toHaveLength(1); @@ -822,15 +738,13 @@ describe("applyWsFrameToMessages", () => { const summaryReasoning = traceRows( messages, - (part, message) => - part.kind === "reasoning" && message.traceSource === "summary", + (part, message) => part.kind === "reasoning" && message.traceSource === "summary", ); expect(summaryReasoning).toHaveLength(1); const sandbox = traceRows( messages, - (part, message) => - part.kind === "sandbox" && message.traceSource === "summary", + (part, message) => part.kind === "sandbox" && message.traceSource === "summary", )[0]?.part; expect(sandbox).toBeDefined(); if (sandbox?.kind === "sandbox") { @@ -912,14 +826,11 @@ describe("applyWsFrameToMessages", () => { const summarySandboxRows = traceRows( messages, - (part, message) => - part.kind === "sandbox" && message.traceSource === "summary", + (part, message) => part.kind === "sandbox" && message.traceSource === "summary", ); expect(summarySandboxRows).toHaveLength(1); const sandbox = summarySandboxRows[0]?.part; - expect(sandbox?.kind === "sandbox" ? sandbox.output : "").toContain( - "current", - ); + expect(sandbox?.kind === "sandbox" ? sandbox.output : "").toContain("current"); }); it("renders selected skills and routing decisions as compact status rows", () => { @@ -950,10 +861,7 @@ describe("applyWsFrameToMessages", () => { }), ); - const env = findFirstPart( - messages, - (p) => p.kind === "environment_variables", - ); + const env = findFirstPart(messages, (p) => p.kind === "environment_variables"); expect(env).toBeDefined(); if (env?.kind === "environment_variables") { expect(env.variables.map((v) => v.name)).toContain("OPENAI_API_KEY"); @@ -984,14 +892,8 @@ describe("applyWsFrameToMessages", () => { it("final finalizes trace summaries and attaches citations/sources/attachments", () => { let messages: ChatMessage[] = []; - messages = applyWsFrameToMessages( - messages, - makeEvent("text", "Hello"), - ).messages; - messages = applyWsFrameToMessages( - messages, - makeEvent("reasoning", "Thinking"), - ).messages; + messages = applyWsFrameToMessages(messages, makeEvent("text", "Hello")).messages; + messages = applyWsFrameToMessages(messages, makeEvent("reasoning", "Thinking")).messages; messages = applyWsFrameToMessages( messages, makeEvent("execution_step", "trace", { @@ -999,10 +901,7 @@ describe("applyWsFrameToMessages", () => { step_data: { thought: "step one", tool_name: "read_file" }, }), ).messages; - messages = applyWsFrameToMessages( - messages, - makeEvent("status", "Do X"), - ).messages; + messages = applyWsFrameToMessages(messages, makeEvent("status", "Do X")).messages; const result = applyWsFrameToMessages( messages, @@ -1057,19 +956,11 @@ describe("applyWsFrameToMessages", () => { const assistant = result.messages.find((m) => m.type === "assistant"); expect(assistant?.streaming).toBe(false); - expect( - assistant?.renderParts?.some((p) => p.kind === "inline_citation_group"), - ).toBe(true); - expect(assistant?.renderParts?.some((p) => p.kind === "sources")).toBe( - true, - ); - expect(assistant?.renderParts?.some((p) => p.kind === "attachments")).toBe( - true, - ); + expect(assistant?.renderParts?.some((p) => p.kind === "inline_citation_group")).toBe(true); + expect(assistant?.renderParts?.some((p) => p.kind === "sources")).toBe(true); + expect(assistant?.renderParts?.some((p) => p.kind === "attachments")).toBe(true); - const citationGroup = assistant?.renderParts?.find( - (p) => p.kind === "inline_citation_group", - ); + const citationGroup = assistant?.renderParts?.find((p) => p.kind === "inline_citation_group"); if (citationGroup?.kind === "inline_citation_group") { expect(citationGroup.citations[0]?.title).toBe("Doc A"); expect(citationGroup.citations[0]?.number).toBe("1"); @@ -1084,10 +975,7 @@ describe("applyWsFrameToMessages", () => { expect(sources.sources[1]?.sourceId).toBe("src-b"); } - const cot = findFirstPart( - result.messages, - (p) => p.kind === "chain_of_thought", - ); + const cot = findFirstPart(result.messages, (p) => p.kind === "chain_of_thought"); if (cot?.kind === "chain_of_thought") { expect(cot.steps.every((step) => step.status === "complete")).toBe(true); } @@ -1106,48 +994,35 @@ describe("applyWsFrameToMessages", () => { const finalReasoningRows = traceRows( result.messages, - (part, message) => - part.kind === "reasoning" && message.traceSource === "summary", + (part, message) => part.kind === "reasoning" && message.traceSource === "summary", ); expect(finalReasoningRows).toHaveLength(3); const summaryLabels = finalReasoningRows.map((row) => row.part.kind === "reasoning" ? row.part.label : undefined, ); - expect(summaryLabels).toEqual([ - "thought_0", - "thought_1", - "final_reasoning", - ]); + expect(summaryLabels).toEqual(["thought_0", "thought_1", "final_reasoning"]); const finalReasoning = finalReasoningRows[2]?.part; if (finalReasoning?.kind === "reasoning") { - expect(finalReasoning.parts[0]?.text).toBe( - "The evidence lines up with the cited sources.", - ); + expect(finalReasoning.parts[0]?.text).toBe("The evidence lines up with the cited sources."); } }); it("prefers final_artifact markdown over raw final event JSON text", () => { const result = applyWsFrameToMessages( [], - makeEvent( - "done", - '{ "final_markdown": "Hello there, it is great to meet you!" }', - { - final_artifact: { - kind: "markdown", - value: { - final_markdown: "Hello there, it is great to meet you!", - }, + makeEvent("done", '{ "final_markdown": "Hello there, it is great to meet you!" }', { + final_artifact: { + kind: "markdown", + value: { + final_markdown: "Hello there, it is great to meet you!", }, }, - ), + }), ); - const assistant = result.messages.find( - (message) => message.type === "assistant", - ); + const assistant = result.messages.find((message) => message.type === "assistant"); expect(assistant?.content).toBe("Hello there, it is great to meet you!"); }); @@ -1166,9 +1041,7 @@ describe("applyWsFrameToMessages", () => { }), ); - const assistant = result.messages.find( - (message) => message.type === "assistant", - ); + const assistant = result.messages.find((message) => message.type === "assistant"); expect(assistant?.content).toBe("Canonical completion text"); }); diff --git a/src/frontend/src/lib/workspace/backend-chat-event-adapter.ts b/src/frontend/src/lib/workspace/backend-chat-event-adapter.ts index 4e5f18162..4ae4ef1a6 100644 --- a/src/frontend/src/lib/workspace/backend-chat-event-adapter.ts +++ b/src/frontend/src/lib/workspace/backend-chat-event-adapter.ts @@ -1,7 +1,4 @@ -import type { - ChatMessage, - ChatRenderPart, -} from "@/lib/workspace/workspace-types"; +import type { ChatMessage, ChatRenderPart } from "@/lib/workspace/workspace-types"; import type { WsServerEvent, WsServerMessage } from "@/lib/rlm-api"; import { createLocalId } from "@/lib/id"; import { QueryClient } from "@tanstack/react-query"; @@ -66,10 +63,7 @@ function ensureStreamingAssistant(messages: ChatMessage[]): ChatMessage[] { ]; } -function appendAssistantToken( - messages: ChatMessage[], - token: string, -): ChatMessage[] { +function appendAssistantToken(messages: ChatMessage[], token: string): ChatMessage[] { if (!token) return messages; const withAssistant = ensureStreamingAssistant(messages); const idx = latestStreamingAssistantIndex(withAssistant); @@ -122,10 +116,7 @@ function finishReasoning(messages: ChatMessage[]): ChatMessage[] { return updated ? next : messages; } -function completeAssistant( - messages: ChatMessage[], - text: string, -): ChatMessage[] { +function completeAssistant(messages: ChatMessage[], text: string): ChatMessage[] { const idx = latestStreamingAssistantIndex(messages); if (idx >= 0) { @@ -161,13 +152,7 @@ function preferredFinalArtifactText(value: unknown): string | undefined { const record = asRecord(value); if (!record) return undefined; - for (const key of [ - "final_markdown", - "summary", - "text", - "content", - "message", - ]) { + for (const key of ["final_markdown", "summary", "text", "content", "message"]) { const candidate = asOptionalText(record[key]); if (candidate) return candidate; } @@ -180,48 +165,30 @@ function preferredFinalArtifactText(value: unknown): string | undefined { return undefined; } -function resolveFinalAssistantText( - text: string, - payload?: Record, -): string { - const preferred = preferredFinalArtifactText( - payload?.final_artifact ?? payload?.finalArtifact, - ); +function resolveFinalAssistantText(text: string, payload?: Record): string { + const preferred = preferredFinalArtifactText(payload?.final_artifact ?? payload?.finalArtifact); return preferred ?? text; } -function readGuardrailWarnings( - payload: Record | undefined, -): string[] { +function readGuardrailWarnings(payload: Record | undefined): string[] { const raw = payload?.guardrail_warnings; if (!Array.isArray(raw)) return []; - return raw - .map((item) => (typeof item === "string" ? item.trim() : "")) - .filter(Boolean); + return raw.map((item) => (typeof item === "string" ? item.trim() : "")).filter(Boolean); } function canonicalSummaryPayload( payload: Record | undefined, ): Record | undefined { - return asRecord( - payload?.run_summary ?? payload?.runSummary ?? payload?.summary, - ); + return asRecord(payload?.run_summary ?? payload?.runSummary ?? payload?.summary); } -function canonicalCompletionStatus( - payload: Record | undefined, -): string { +function canonicalCompletionStatus(payload: Record | undefined): string { const summary = canonicalSummaryPayload(payload); - return ( - asOptionalText(summary?.status ?? payload?.status)?.toLowerCase() ?? "" - ); + return asOptionalText(summary?.status ?? payload?.status)?.toLowerCase() ?? ""; } -function canonicalStepText( - step: Record, - fallback: string, -): string { +function canonicalStepText(step: Record, fallback: string): string { return ( asOptionalText(step.label) ?? asOptionalText(step.output) ?? @@ -230,22 +197,17 @@ function canonicalStepText( ); } -function routingStatusText( - text: string, - payload?: Record, -): string { +function routingStatusText(text: string, payload?: Record): string { const selectedSkills = Array.isArray(payload?.selected_skills) ? payload.selected_skills.map((item) => String(item)).filter(Boolean) : []; const routingDecision = asOptionalText(payload?.routing_decision); const sourceUrl = asOptionalText(payload?.source_url); - if (selectedSkills.length === 0 && !routingDecision && !sourceUrl) - return text; + if (selectedSkills.length === 0 && !routingDecision && !sourceUrl) return text; const parts = [text.trim()].filter(Boolean); if (routingDecision) parts.push(`route ${routingDecision}`); - if (selectedSkills.length > 0) - parts.push(`skills ${selectedSkills.join(", ")}`); + if (selectedSkills.length > 0) parts.push(`skills ${selectedSkills.join(", ")}`); if (sourceUrl) parts.push(`source ${sourceUrl}`); return parts.join(" | "); } @@ -287,14 +249,9 @@ function applyCanonicalExecutionStep( } if (stepType === "llm") { const output = asRecord(step.output); - const token = - typeof output?.text === "string" - ? output.text - : asOptionalText(step.output); + const token = typeof output?.text === "string" ? output.text : asOptionalText(step.output); const reasoning = - typeof step.label === "string" - ? step.label - : asOptionalText(output?.reasoning ?? step.input); + typeof step.label === "string" ? step.label : asOptionalText(output?.reasoning ?? step.input); return { messages: token ? appendAssistantToken(messages, token) @@ -345,44 +302,27 @@ function applyCanonicalExecutionCompleted( if (status === "failed" || status === "error") { let next = finishReasoning(messages); next = finalizeTraceParts(next); - next = appendSystem( - next, - `Backend error: ${text || "Unknown server error."}`, - ); + next = appendSystem(next, `Backend error: ${text || "Unknown server error."}`); return { messages: next, terminal: true, errored: true }; } - let next = completeAssistant( - messages, - resolveFinalAssistantText(text, payload), - ); + let next = completeAssistant(messages, resolveFinalAssistantText(text, payload)); next = finishReasoning(next); next = finalizeTraceParts(next); next = appendFinalTrajectoryThoughts(next, payload); next = appendFinalTrajectoryToolRows(next, payload); const finalReasoning = - typeof payload?.final_reasoning === "string" - ? payload.final_reasoning.trim() - : ""; + typeof payload?.final_reasoning === "string" ? payload.final_reasoning.trim() : ""; if (finalReasoning) { - next = appendReasoningEvent( - next, - finalReasoning, - "summary", - payload, - "final_reasoning", - ); + next = appendReasoningEvent(next, finalReasoning, "summary", payload, "final_reasoning"); } next = attachFinalReferences(next, payload); const warnings = readGuardrailWarnings(payload); if (warnings.length > 0) { - next = appendSystem( - next, - `Guardrail warnings:\n- ${warnings.join("\n- ")}`, - ); + next = appendSystem(next, `Guardrail warnings:\n- ${warnings.join("\n- ")}`); } return { messages: next, terminal: true, errored: false }; @@ -473,33 +413,21 @@ function appendFinalTrajectoryThoughts( return steps.reduce((acc, step) => { if (!step.thought) return acc; - return appendReasoningEvent( - acc, - step.thought, - "summary", - payload, - `thought_${step.index}`, - ); + return appendReasoningEvent(acc, step.thought, "summary", payload, `thought_${step.index}`); }, messages); } function currentTurnMessages(messages: ChatMessage[]): ChatMessage[] { - const lastUserIndex = messages.findLastIndex( - (message) => message.type === "user", - ); + const lastUserIndex = messages.findLastIndex((message) => message.type === "user"); return lastUserIndex >= 0 ? messages.slice(lastUserIndex + 1) : messages; } -function hasLiveToolOrSandboxTraceForCurrentTurn( - messages: ChatMessage[], -): boolean { +function hasLiveToolOrSandboxTraceForCurrentTurn(messages: ChatMessage[]): boolean { return currentTurnMessages(messages).some( (message) => message.type === "trace" && message.traceSource === "live" && - message.renderParts?.some( - (part) => part.kind === "tool" || part.kind === "sandbox", - ), + message.renderParts?.some((part) => part.kind === "tool" || part.kind === "sandbox"), ); } @@ -556,9 +484,7 @@ function finalizeTraceParts(messages: ChatMessage[]): ChatMessage[] { items: part.items.map((it) => ({ ...it, completed: true })), }; case "task": - return part.status === "in_progress" - ? { ...part, status: "completed" as const } - : part; + return part.status === "in_progress" ? { ...part, status: "completed" as const } : part; case "tool": case "sandbox": return part.state === "running" || part.state === "input-streaming" @@ -579,12 +505,7 @@ function resolveHitlByMessageId( ): ChatMessage[] { let changed = false; const next = messages.map((msg) => { - if ( - changed || - msg.id !== messageId || - msg.type !== "hitl" || - !msg.hitlData - ) { + if (changed || msg.id !== messageId || msg.type !== "hitl" || !msg.hitlData) { return msg; } changed = true; @@ -600,18 +521,10 @@ function resolveHitlByMessageId( return changed ? next : messages; } -function rollbackHitlByMessageId( - messages: ChatMessage[], - messageId: string, -): ChatMessage[] { +function rollbackHitlByMessageId(messages: ChatMessage[], messageId: string): ChatMessage[] { let changed = false; const next = messages.map((msg) => { - if ( - changed || - msg.id !== messageId || - msg.type !== "hitl" || - !msg.hitlData - ) { + if (changed || msg.id !== messageId || msg.type !== "hitl" || !msg.hitlData) { return msg; } changed = true; @@ -627,10 +540,7 @@ function rollbackHitlByMessageId( return changed ? next : messages; } -function applyEvent( - messages: ChatMessage[], - frame: WsServerEvent, -): ApplyFrameResult { +function applyEvent(messages: ChatMessage[], frame: WsServerEvent): ApplyFrameResult { const { kind, text, payload } = frame.data; switch (kind) { @@ -665,11 +575,8 @@ function applyEvent( const command = asOptionalText(payload?.command); const result = asRecord(payload?.result); const messageId = asOptionalText(result?.message_id ?? result?.messageId); - const resolution = - asOptionalText(result?.resolution) ?? - asOptionalText(result?.action_label); - const succeeded = - asOptionalText(result?.status)?.toLowerCase() !== "error"; + const resolution = asOptionalText(result?.resolution) ?? asOptionalText(result?.action_label); + const succeeded = asOptionalText(result?.status)?.toLowerCase() !== "error"; let next = messages; if (succeeded && command === "resolve_hitl" && messageId && resolution) { next = resolveHitlByMessageId(next, messageId, resolution); @@ -683,8 +590,7 @@ function applyEvent( { kind: "status_note", tone: succeeded ? "success" : "error", - text: - text || (succeeded ? "Action acknowledged" : "Action rejected"), + text: text || (succeeded ? "Action acknowledged" : "Action rejected"), }, text || (succeeded ? "Action acknowledged" : "Action rejected"), ), @@ -707,9 +613,7 @@ export function applyWsFrameToMessages( _queryClient?: QueryClient, ): ApplyFrameResult { if (frame.type === "error") { - const next = finalizeTraceParts( - appendSystem(messages, `Backend error: ${frame.message}`), - ); + const next = finalizeTraceParts(appendSystem(messages, `Backend error: ${frame.message}`)); return { messages: finishReasoning(next), terminal: true, errored: true }; } diff --git a/src/frontend/src/lib/workspace/backend-chat-event-trajectory.ts b/src/frontend/src/lib/workspace/backend-chat-event-trajectory.ts index 9968ac6aa..25537d5b8 100644 --- a/src/frontend/src/lib/workspace/backend-chat-event-trajectory.ts +++ b/src/frontend/src/lib/workspace/backend-chat-event-trajectory.ts @@ -35,11 +35,7 @@ function parseTrajectoryStepIndex( payload?: Record, stepData?: Record, ): number { - return ( - asOptionalNumber(payload?.step_index) ?? - asOptionalNumber(stepData?.index) ?? - 0 - ); + return asOptionalNumber(payload?.step_index) ?? asOptionalNumber(stepData?.index) ?? 0; } function normalizeTrajectoryStep( @@ -50,12 +46,9 @@ function normalizeTrajectoryStep( const action = asOptionalText(raw.action); const code = asOptionalText(raw.code); const toolName = - asOptionalText(raw.tool_name ?? raw.toolName) ?? - (code ? "repl_execute" : undefined); + asOptionalText(raw.tool_name ?? raw.toolName) ?? (code ? "repl_execute" : undefined); const thought = - asOptionalText(raw.thought) ?? - asOptionalText(raw.reasoning) ?? - asOptionalText(fallbackText); + asOptionalText(raw.thought) ?? asOptionalText(raw.reasoning) ?? asOptionalText(fallbackText); const toolInput = normalizeOptionalUnknown( raw.tool_args ?? raw.input ?? raw.tool_input ?? raw.toolInput ?? code, ); @@ -246,22 +239,16 @@ function summarizeTrajectoryValue(value: unknown): string | undefined { } } -export function trajectoryStepDetails( - step: NormalizedTrajectoryStep, -): string[] { +export function trajectoryStepDetails(step: NormalizedTrajectoryStep): string[] { const details: string[] = []; if (step.toolName) { details.push(`Tool · ${step.toolName}`); } if (step.toolInput !== undefined) { - details.push( - `Input · ${summarizeTrajectoryValue(step.toolInput) ?? "Available"}`, - ); + details.push(`Input · ${summarizeTrajectoryValue(step.toolInput) ?? "Available"}`); } if (step.toolOutput !== undefined) { - details.push( - `Observation · ${summarizeTrajectoryValue(step.toolOutput) ?? "Available"}`, - ); + details.push(`Observation · ${summarizeTrajectoryValue(step.toolOutput) ?? "Available"}`); } return details; } From 27417f4126b0dc01066fdf12937a8077ec2a1dfd Mon Sep 17 00:00:00 2001 From: Zachary BENSALEM Date: Sat, 6 Jun 2026 23:24:43 +0200 Subject: [PATCH 5/7] codex: address PR review feedback (#271) --- scripts/capture_phase0_baseline.sh | 5 ++--- scripts/mlflow_cli.py | 8 ++------ src/fleet_rlm/api/routers/ws/turn_runner.py | 2 +- .../api/runtime_services/run_lifecycle.py | 2 +- .../runtime_services/session_persistence.py | 10 ++++++---- .../observability/mlflow_runtime.py | 4 +++- src/fleet_rlm/runtime/tools/binding.py | 20 +++++++++++++++---- tests/unit/api/test_chat_persistence.py | 1 - .../test_daytona_sandbox_executor.py | 1 - .../unit/integrations/test_mlflow_context.py | 4 ++-- tests/unit/runtime/test_tools.py | 8 +++----- 11 files changed, 36 insertions(+), 29 deletions(-) diff --git a/scripts/capture_phase0_baseline.sh b/scripts/capture_phase0_baseline.sh index 3c37bd3eb..18200b93d 100755 --- a/scripts/capture_phase0_baseline.sh +++ b/scripts/capture_phase0_baseline.sh @@ -6,7 +6,8 @@ set -e echo "=== Phase 0: Capturing golden payloads and openapi.yaml baseline ===" -# Create golden payloads directory +# Clean up old golden payloads and recreate the directory +rm -rf tests/contracts/golden_payloads mkdir -p tests/contracts/golden_payloads # Capture openapi.yaml baseline @@ -19,8 +20,6 @@ cp src/frontend/src/lib/rlm-api/generated/openapi.ts tests/contracts/golden_payl # Run golden payload capture tests echo "Running golden payload capture tests..." -# Temporarily remove golden payloads directory to trigger capture -rm -rf tests/contracts/golden_payloads uv run pytest tests/contracts/test_golden_payloads.py::test_capture_chat_websocket_golden_payloads -v uv run pytest tests/contracts/test_golden_payloads.py::test_capture_passive_events_websocket_golden_payloads -v diff --git a/scripts/mlflow_cli.py b/scripts/mlflow_cli.py index faaf64fe9..fa5f8505d 100755 --- a/scripts/mlflow_cli.py +++ b/scripts/mlflow_cli.py @@ -150,7 +150,7 @@ def do_scorers_stop(args: argparse.Namespace) -> int: stop_scorer = getattr(scorer, "stop", None) if not callable(stop_scorer): raise RuntimeError("This MLflow scorer does not expose stop().") - stop_scorer(name=args.name, experiment_id=experiment_id) + stop_scorer() print(f"stopped_scorer={args.name}") print(f"experiment={config.experiment}") print(f"experiment_id={experiment_id or ''}") @@ -167,11 +167,7 @@ def do_scorers_start(args: argparse.Namespace) -> int: from mlflow.genai.scorers import ScorerSamplingConfig - start_scorer( - name=args.name, - experiment_id=experiment_id, - sampling_config=ScorerSamplingConfig(sample_rate=args.sample_rate, filter_string=args.filter_string), - ) + start_scorer(sampling_config=ScorerSamplingConfig(sample_rate=args.sample_rate, filter_string=args.filter_string)) print(f"started_scorer={args.name}") print(f"sample_rate={args.sample_rate}") print(f"experiment={config.experiment}") diff --git a/src/fleet_rlm/api/routers/ws/turn_runner.py b/src/fleet_rlm/api/routers/ws/turn_runner.py index 0cbe829e1..aa5ffb3e2 100644 --- a/src/fleet_rlm/api/routers/ws/turn_runner.py +++ b/src/fleet_rlm/api/routers/ws/turn_runner.py @@ -306,7 +306,7 @@ async def _stream_agent_events( try: await hosted_repl_bridge.stop() except Exception: - pass + logger.warning("Failed to stop hosted REPL bridge during stream cleanup.", exc_info=True) if not lifecycle.run_completed: lifecycle.raise_if_persistence_error() diff --git a/src/fleet_rlm/api/runtime_services/run_lifecycle.py b/src/fleet_rlm/api/runtime_services/run_lifecycle.py index 9831bb150..6c9ef5039 100644 --- a/src/fleet_rlm/api/runtime_services/run_lifecycle.py +++ b/src/fleet_rlm/api/runtime_services/run_lifecycle.py @@ -219,7 +219,7 @@ async def _stop_persist_worker(self) -> None: try: await self._persist_worker_task except asyncio.CancelledError: - pass + logger.debug("Persist worker task was cancelled during shutdown.") self._persist_worker_task = None self._persist_queue = None diff --git a/src/fleet_rlm/api/runtime_services/session_persistence.py b/src/fleet_rlm/api/runtime_services/session_persistence.py index f8eafbc21..08bd4bc4b 100644 --- a/src/fleet_rlm/api/runtime_services/session_persistence.py +++ b/src/fleet_rlm/api/runtime_services/session_persistence.py @@ -2,6 +2,7 @@ from __future__ import annotations +import inspect import logging import uuid from typing import Any @@ -158,8 +159,6 @@ async def _persist_manifest_to_local_store( if not callable(update_fn): return try: - import inspect - sig = inspect.signature(update_fn) # LocalStore.update_chat_session requires tenant_id + session_id UUIDs; the # async FleetRepository variant has the same shape. Both accept metadata_json. @@ -167,7 +166,9 @@ async def _persist_manifest_to_local_store( # locate it without a UUID round-trip. params = set(sig.parameters) if "external_session_id" in params: - await update_fn(external_session_id=sess_id, metadata_json={"_manifest_state": manifest}) + result = update_fn(external_session_id=sess_id, metadata_json={"_manifest_state": manifest}) + if inspect.iscoroutine(result): + await result else: # Async path: skip – we cannot derive the UUID here without identity_rows. pass @@ -190,7 +191,8 @@ async def _restore_manifest_from_local_store( if not callable(get_fn): return {} try: - row = await get_fn(external_session_id=sess_id) + result = get_fn(external_session_id=sess_id) + row = await result if inspect.iscoroutine(result) else result if row is None: return {} metadata = getattr(row, "metadata_json", None) diff --git a/src/fleet_rlm/integrations/observability/mlflow_runtime.py b/src/fleet_rlm/integrations/observability/mlflow_runtime.py index 0bd0c3536..e9600d160 100644 --- a/src/fleet_rlm/integrations/observability/mlflow_runtime.py +++ b/src/fleet_rlm/integrations/observability/mlflow_runtime.py @@ -513,12 +513,14 @@ def on_lm_end( # retry (adds ~8 s per turn). _THINK_TAG_RE = re.compile(r".*?", re.DOTALL | re.IGNORECASE) # Some models emit in one token batch and in another, so the -# paired regex above leaves orphaned closing tags after stripping complete pairs. +# paired regex above leaves orphaned tags after stripping complete pairs. +_ORPHAN_THINK_OPEN_RE = re.compile(r"", re.IGNORECASE) _ORPHAN_THINK_CLOSE_RE = re.compile(r"", re.IGNORECASE) def _strip_think_tags(text: str) -> str: text = _THINK_TAG_RE.sub("", text) + text = _ORPHAN_THINK_OPEN_RE.sub("", text) text = _ORPHAN_THINK_CLOSE_RE.sub("", text) return text.lstrip("\n") diff --git a/src/fleet_rlm/runtime/tools/binding.py b/src/fleet_rlm/runtime/tools/binding.py index 534ca9c0f..e2f809c1d 100644 --- a/src/fleet_rlm/runtime/tools/binding.py +++ b/src/fleet_rlm/runtime/tools/binding.py @@ -301,10 +301,14 @@ def browser_fetch_page( "Use a browser-capable sandbox (fleet-rlm-browser snapshot) for rendered page fetching.", ) else: + import logging + logger = logging.getLogger(__name__) with sync_playwright() as p: - browser = p.chromium.launch(headless=True, args=["--no-sandbox", "--disable-dev-shm-usage"]) - page = browser.new_page() + browser = None + page = None try: + browser = p.chromium.launch(headless=True, args=["--no-sandbox", "--disable-dev-shm-usage"]) + page = browser.new_page() page.goto(target_url, wait_until=wait_until, timeout=30000) text = page.inner_text("body") title = page.title() @@ -322,9 +326,17 @@ def browser_fetch_page( char_count=len(text), links=links[:100] if extract_links else [], ) + except Exception: + logger.exception("Browser execution failed") + SUBMIT( + status="error", + error="Browser execution failed. Please check the logs for details.", + ) finally: - page.close() - browser.close() + if page is not None: + page.close() + if browser is not None: + browser.close() """ diff --git a/tests/unit/api/test_chat_persistence.py b/tests/unit/api/test_chat_persistence.py index c9cd61d80..0dcdfa0f8 100644 --- a/tests/unit/api/test_chat_persistence.py +++ b/tests/unit/api/test_chat_persistence.py @@ -257,7 +257,6 @@ async def background() -> str: ) await stream_task - assert cancel_flag["cancelled"] is False assert calls == [] diff --git a/tests/unit/integrations/test_daytona_sandbox_executor.py b/tests/unit/integrations/test_daytona_sandbox_executor.py index 48a19e1f9..c3d18f55b 100644 --- a/tests/unit/integrations/test_daytona_sandbox_executor.py +++ b/tests/unit/integrations/test_daytona_sandbox_executor.py @@ -79,7 +79,6 @@ def fail_ensure_bridge(**kwargs: Any) -> None: callbacks=callbacks, ) - assert ensure_calls == 1 _BROKER_START_FAILURES.clear() diff --git a/tests/unit/integrations/test_mlflow_context.py b/tests/unit/integrations/test_mlflow_context.py index eb9bb7631..08de83866 100644 --- a/tests/unit/integrations/test_mlflow_context.py +++ b/tests/unit/integrations/test_mlflow_context.py @@ -14,7 +14,7 @@ def test_update_current_mlflow_trace_mirrors_fleet_metadata_to_tags(monkeypatch) captured: list[dict[str, Any]] = [] fake_mlflow = SimpleNamespace( - get_current_active_span=lambda: object(), + get_current_active_span=object, get_active_trace_id=lambda: "tr-test", update_current_trace=lambda **kwargs: captured.append(kwargs), ) @@ -59,7 +59,7 @@ def test_update_current_mlflow_trace_does_not_resend_active_trace_tags(monkeypat captured: list[dict[str, Any]] = [] fake_mlflow = SimpleNamespace( - get_current_active_span=lambda: object(), + get_current_active_span=object, get_active_trace_id=lambda: "tr-test", update_current_trace=lambda **kwargs: captured.append(kwargs), ) diff --git a/tests/unit/runtime/test_tools.py b/tests/unit/runtime/test_tools.py index 36ef16e7e..184e2c8db 100644 --- a/tests/unit/runtime/test_tools.py +++ b/tests/unit/runtime/test_tools.py @@ -452,8 +452,7 @@ def test_browser_fetch_page_in_discover_tools() -> None: def test_bound_browser_fetch_page_validates_public_url_before_sandbox(monkeypatch: pytest.MonkeyPatch) -> None: import fleet_rlm.runtime.tools.binding as binding_mod - import fleet_rlm.runtime.tools.document_tools as document_tools - from fleet_rlm.runtime.tools.binding import bind_runtime_tools + from fleet_rlm.runtime.tools import document_tools as document_tools from fleet_rlm.runtime.tools.browser_tools import browser_fetch_page calls: list[dict[str, Any]] = [] @@ -470,7 +469,7 @@ def fake_execute_sandbox_tool(interpreter: Any, code: str, variables: dict[str, monkeypatch.setattr(document_tools.socket, "getaddrinfo", fake_getaddrinfo) monkeypatch.setattr(binding_mod, "execute_sandbox_tool", fake_execute_sandbox_tool) - bound = bind_runtime_tools( + bound = binding_mod.bind_runtime_tools( [browser_fetch_page], runtime=types.SimpleNamespace(core_memory={}), interpreter=object(), @@ -502,7 +501,6 @@ def test_bound_browser_fetch_page_rejects_private_targets( url: str, ) -> None: import fleet_rlm.runtime.tools.binding as binding_mod - from fleet_rlm.runtime.tools.binding import bind_runtime_tools from fleet_rlm.runtime.tools.browser_tools import browser_fetch_page def fake_execute_sandbox_tool(*args: Any, **kwargs: Any) -> dict[str, Any]: @@ -510,7 +508,7 @@ def fake_execute_sandbox_tool(*args: Any, **kwargs: Any) -> dict[str, Any]: raise AssertionError("unsafe URL should be rejected before sandbox execution") monkeypatch.setattr(binding_mod, "execute_sandbox_tool", fake_execute_sandbox_tool) - bound = bind_runtime_tools( + bound = binding_mod.bind_runtime_tools( [browser_fetch_page], runtime=types.SimpleNamespace(core_memory={}), interpreter=object(), From 2cca8539109b91105dcdf8f854ceb3cc54fb3405 Mon Sep 17 00:00:00 2001 From: Zachary BENSALEM Date: Sat, 6 Jun 2026 23:28:17 +0200 Subject: [PATCH 6/7] codex: fix CI failure on PR #271 --- tests/unit/cli/test_mlflow_cli.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/tests/unit/cli/test_mlflow_cli.py b/tests/unit/cli/test_mlflow_cli.py index fc94eecf1..0ac121d5d 100644 --- a/tests/unit/cli/test_mlflow_cli.py +++ b/tests/unit/cli/test_mlflow_cli.py @@ -126,11 +126,9 @@ def test_scorers_stop_calls_registered_scorer_stop( capsys: pytest.CaptureFixture[str], ) -> None: apply_mlflow_env(clean_runtime_env) - stopped: list[dict[str, object]] = [] + stopped: list[str] = [] - scorer = SimpleNamespace( - stop=lambda *, name, experiment_id: stopped.append({"name": name, "experiment_id": experiment_id}) - ) + scorer = SimpleNamespace(stop=lambda: stopped.append("called")) fake_mlflow = _fake_mlflow_module(scorer=scorer) monkeypatch.setitem(sys.modules, "mlflow", fake_mlflow) @@ -139,7 +137,7 @@ def test_scorers_stop_calls_registered_scorer_stop( output = capsys.readouterr().out assert result == 0 - assert stopped == [{"name": "Trace Judge", "experiment_id": "exp-active"}] + assert stopped == ["called"] assert "stopped_scorer=Trace Judge" in output @@ -151,11 +149,9 @@ def test_scorers_start_calls_registered_scorer_start( apply_mlflow_env(clean_runtime_env) started: list[dict[str, object]] = [] - def start(*, name, experiment_id, sampling_config) -> None: + def start(*, sampling_config) -> None: started.append( { - "name": name, - "experiment_id": experiment_id, "sample_rate": sampling_config.sample_rate, "filter_string": sampling_config.filter_string, } @@ -181,8 +177,6 @@ def start(*, name, experiment_id, sampling_config) -> None: assert result == 0 assert started == [ { - "name": "Trace Judge", - "experiment_id": "exp-override", "sample_rate": 0.5, "filter_string": "status = 'OK'", } From 860e3d13cb5f8576720f992d904178e1fb7276d3 Mon Sep 17 00:00:00 2001 From: Zachary BENSALEM Date: Sat, 6 Jun 2026 23:34:31 +0200 Subject: [PATCH 7/7] codex: address PR review feedback (#271) --- src/fleet_rlm/api/routers/ws/turn_runner.py | 2 + tests/unit/api/test_runtime_diagnostics.py | 4 +- tests/unit/api/test_ws_turn_runner.py | 88 +++++++++++++++++++ .../unit/integrations/test_mlflow_context.py | 4 +- 4 files changed, 94 insertions(+), 4 deletions(-) create mode 100644 tests/unit/api/test_ws_turn_runner.py diff --git a/src/fleet_rlm/api/routers/ws/turn_runner.py b/src/fleet_rlm/api/routers/ws/turn_runner.py index aa5ffb3e2..e609d51ba 100644 --- a/src/fleet_rlm/api/routers/ws/turn_runner.py +++ b/src/fleet_rlm/api/routers/ws/turn_runner.py @@ -210,6 +210,8 @@ async def _emit_stream_event( run_id=lifecycle.run_id, ) is_terminal_event = _is_terminal_transport_event(event) + if websocket is not None: + await _try_send_json(websocket, event_dict) if isinstance(event, RuntimeEvent): step = step_builder.from_runtime_event(event) diff --git a/tests/unit/api/test_runtime_diagnostics.py b/tests/unit/api/test_runtime_diagnostics.py index d8d14fcae..f338a5a81 100644 --- a/tests/unit/api/test_runtime_diagnostics.py +++ b/tests/unit/api/test_runtime_diagnostics.py @@ -35,7 +35,7 @@ def test_runtime_status_surfaces_persisted_mlflow_scorers(monkeypatch) -> None: (), { "from_env": staticmethod( - lambda: type( + type( "FakeMlflowConfig", (), { @@ -43,7 +43,7 @@ def test_runtime_status_surfaces_persisted_mlflow_scorers(monkeypatch) -> None: "enable_auto_assessment": False, "tracking_uri": "http://127.0.0.1:5001", }, - )() + ) ) }, ), diff --git a/tests/unit/api/test_ws_turn_runner.py b/tests/unit/api/test_ws_turn_runner.py new file mode 100644 index 000000000..76515356d --- /dev/null +++ b/tests/unit/api/test_ws_turn_runner.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from fleet_rlm.api.events import ExecutionStepBuilder +from fleet_rlm.api.routers.ws.turn_runner import _emit_stream_event +from fleet_rlm.api.runtime_services.run_lifecycle import ExecutionLifecycleManager +from fleet_rlm.runtime.schemas import StreamEvent + + +class _FakeEmitter: + def __init__(self) -> None: + self.events: list[Any] = [] + + async def emit(self, event: Any) -> None: + self.events.append(event) + + +class _FakeWebSocket: + def __init__(self) -> None: + self.messages: list[Any] = [] + + async def send_json(self, payload: Any) -> None: + self.messages.append(payload) + + +def _lifecycle() -> tuple[ExecutionLifecycleManager, ExecutionStepBuilder, _FakeEmitter]: + step_builder = ExecutionStepBuilder(run_id="run-1") + emitter = _FakeEmitter() + lifecycle = ExecutionLifecycleManager( + run_id="run-1", + workspace_id="workspace", + user_id="user", + session_id="session", + execution_emitter=emitter, + step_builder=step_builder, + ) + return lifecycle, step_builder, emitter + + +async def _persist_session_state(**kwargs: Any) -> None: + _ = kwargs + + +@pytest.mark.asyncio +async def test_emit_stream_event_sends_non_terminal_frame_to_websocket() -> None: + lifecycle, step_builder, emitter = _lifecycle() + websocket = _FakeWebSocket() + + await _emit_stream_event( + websocket=websocket, # type: ignore[arg-type] + lifecycle=lifecycle, + step_builder=step_builder, + event=StreamEvent(kind="status", text="Starting turn..."), + persist_session_state=_persist_session_state, + request_message="hello", + execution_emitter=emitter, # type: ignore[arg-type] + ) + + assert websocket.messages + assert websocket.messages[0]["kind"] == "execution_step" + assert websocket.messages[0]["text"] == "Starting turn..." + assert websocket.messages[0]["payload"]["source_type"] == "status" + assert lifecycle.run_completed is False + + +@pytest.mark.asyncio +async def test_emit_stream_event_sends_terminal_frame_before_completion() -> None: + lifecycle, step_builder, emitter = _lifecycle() + websocket = _FakeWebSocket() + + await _emit_stream_event( + websocket=websocket, # type: ignore[arg-type] + lifecycle=lifecycle, + step_builder=step_builder, + event=StreamEvent(kind="done", text="done", payload={"history_turns": 1}), + persist_session_state=_persist_session_state, + request_message="hello", + execution_emitter=emitter, # type: ignore[arg-type] + ) + + assert websocket.messages + assert websocket.messages[0]["kind"] == "execution_completed" + assert websocket.messages[0]["text"] == "done" + assert websocket.messages[0]["payload"]["source_type"] == "turn_completed" + assert lifecycle.run_completed is True diff --git a/tests/unit/integrations/test_mlflow_context.py b/tests/unit/integrations/test_mlflow_context.py index 08de83866..93f3d6af3 100644 --- a/tests/unit/integrations/test_mlflow_context.py +++ b/tests/unit/integrations/test_mlflow_context.py @@ -259,7 +259,7 @@ def set_status(self, status: str) -> None: self.record["status"] = status fake_mlflow = SimpleNamespace( - get_current_active_span=lambda: object(), + get_current_active_span=object, start_span=lambda name, span_type=None, attributes=None: FakeSpan(name, span_type, attributes), ) monkeypatch.setattr( @@ -320,7 +320,7 @@ def set_status(self, status: str) -> None: self.record["status"] = status fake_mlflow = SimpleNamespace( - get_current_active_span=lambda: object(), + get_current_active_span=object, start_span=lambda name, span_type=None, attributes=None: FakeSpan(name, span_type, attributes), ) monkeypatch.setattr(