diff --git a/examples/end_to_end/tbench2_pi_trl/README.md b/examples/end_to_end/tbench2_pi_trl/README.md new file mode 100644 index 000000000..07096a696 --- /dev/null +++ b/examples/end_to_end/tbench2_pi_trl/README.md @@ -0,0 +1,9 @@ +# Terminus + TRL Async GRPO + +Start a Terminus server and a vLLM server with weight transfer enabled, then run: + +```bash +TERMINUS_ENV_URL=http://localhost:8000 \ +TERMINUS_VLLM_SERVER_URL=http://localhost:8001 \ +uv run train_terminus_grpo.py +``` diff --git a/examples/end_to_end/tbench2_pi_trl/agents/terminus/.pi/skills/terminus-terminal-task/SKILL.md b/examples/end_to_end/tbench2_pi_trl/agents/terminus/.pi/skills/terminus-terminal-task/SKILL.md new file mode 100644 index 000000000..021809664 --- /dev/null +++ b/examples/end_to_end/tbench2_pi_trl/agents/terminus/.pi/skills/terminus-terminal-task/SKILL.md @@ -0,0 +1,26 @@ +--- +name: terminus-terminal-task +description: Use inside a Terminus environment session when solving one sandboxed terminal task with the terminal tool. +--- + +# Terminus Terminal Task + +Use this skill only inside a Terminus task session. + +## Workflow + +1. Read the task. +2. Use the `terminal` tool for each terminal action. +3. Pass `command` to inspect and modify the sandbox. +4. Check command output before choosing the next command. +5. When the task is complete, pass `final_answer` exactly once. + +## Guardrails + +- Do not change hidden checks or task configuration. +- Do not claim completion until the visible task requirements are satisfied. +- Stay focused on the current task and terminal outputs. +- Do not include both `command` and `final_answer` in the same tool call. +- For simple file writes, prefer commands like `printf %s 'text' > path`. +- If a command fails, inspect the error and continue with a smaller diagnostic + command. diff --git a/examples/end_to_end/tbench2_pi_trl/agents/terminus/AGENTS.md b/examples/end_to_end/tbench2_pi_trl/agents/terminus/AGENTS.md new file mode 100644 index 000000000..8defe822c --- /dev/null +++ b/examples/end_to_end/tbench2_pi_trl/agents/terminus/AGENTS.md @@ -0,0 +1,15 @@ +# Terminus Task Instructions + +You are solving one task inside a Terminus terminal environment. + +Use the available `terminal` tool to inspect and modify the sandbox. Prefer +short, direct shell commands. Read command output before deciding the next +action. While working, call `terminal` with a `command`. When done, call +`terminal` with a `final_answer`. Do not include both arguments in the same +tool call. + +When the requested task is complete, submit exactly one final answer. The final +answer should be concise and should not include implementation notes. + +Do not change hidden task checks or environment configuration. Stay focused on +the current task and the available terminal tool. diff --git a/examples/end_to_end/tbench2_pi_trl/pi_rollout_worker.py b/examples/end_to_end/tbench2_pi_trl/pi_rollout_worker.py new file mode 100644 index 000000000..9dc2434a0 --- /dev/null +++ b/examples/end_to_end/tbench2_pi_trl/pi_rollout_worker.py @@ -0,0 +1,945 @@ +"""Pi rollout worker for the Terminus async GRPO example.""" + +from __future__ import annotations + +import json +import logging +import queue +import re +import threading +import time +import uuid +from dataclasses import dataclass, field +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from typing import Any, Callable, Iterator, Sequence, cast + +import requests +from openenv.core.harness import HarnessRunLimits +from openenv.core.harness.pi_cli import PiCLIHarnessAdapter + +try: + from terminus_env.harness import build_terminal_tool_call +except Exception: # pragma: no cover - optional outside the Terminus example + build_terminal_tool_call = None + +try: + from trl.chat_template_utils import ( + add_response_schema, + get_training_chat_template, + is_chat_template_prefix_preserving, + parse_response, + ) +except Exception: # pragma: no cover - optional across TRL revisions + add_response_schema = None + get_training_chat_template = None + is_chat_template_prefix_preserving = None + parse_response = None + +try: + from vllm.distributed.weight_transfer.nccl_engine import ( + NCCLTrainerSendWeightsArgs, + NCCLWeightTransferEngine, + ) + from vllm.utils.network_utils import get_ip, get_open_port +except Exception: # pragma: no cover - optional outside training runtime + NCCLTrainerSendWeightsArgs = None + NCCLWeightTransferEngine = None + get_ip = None + get_open_port = None + +logger = logging.getLogger(__name__) + + +def _vllm_version() -> tuple[int, ...]: + try: + import vllm + + return tuple(int(part) for part in vllm.__version__.split(".")[:3]) + except Exception: + return (0, 0, 0) + + +_VLLM_NEEDS_WEIGHT_UPDATE_LIFECYCLE = _vllm_version() >= (0, 21, 0) + + +@dataclass +class RolloutSample: + input_ids: list[int] + completion_mask: list[int] + old_log_probs: list[float] + advantage: float + model_version: int + metrics: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class WorkerConfig: + max_inflight: int = 2 + queue_maxsize: int = 64 + max_turns: int = 8 + max_completion_tokens: int = 512 + temperature: float = 1.0 + request_timeout_s: float = 600.0 + server_timeout_s: float = 600.0 + idle_sleep_s: float = 0.25 + vllm_weight_name_prefix: str = "" + + +class InterceptionServer: + """Minimal OpenAI-compatible gate for PI chat completion requests.""" + + def __init__(self, *, host: str = "127.0.0.1", port: int = 0, secret: str = "openenv"): + self.host = host + self.port = port + self.secret = secret + self._server: ThreadingHTTPServer | None = None + self._thread: threading.Thread | None = None + self._lock = threading.RLock() + self._rollouts: dict[str, queue.Queue[str]] = {} + self._intercepts: dict[str, dict[str, Any]] = {} + + @property + def base_url(self) -> str: + if self._server is None: + raise RuntimeError("interception server is not running") + return f"http://127.0.0.1:{self.port}" + + def start(self) -> None: + if self._server is not None: + return + server = ThreadingHTTPServer((self.host, self.port), self._handler()) + self._server = server + self.port = int(server.server_port) + self._thread = threading.Thread( + target=server.serve_forever, + daemon=True, + name="terminus-pi-interception", + ) + self._thread.start() + + def stop(self) -> None: + server = self._server + if server is None: + return + server.shutdown() + server.server_close() + if self._thread is not None: + self._thread.join(timeout=2.0) + self._server = None + self._thread = None + + def register_rollout(self, rollout_id: str) -> queue.Queue[str]: + request_queue: queue.Queue[str] = queue.Queue() + with self._lock: + self._rollouts[rollout_id] = request_queue + return request_queue + + def unregister_rollout(self, rollout_id: str) -> None: + with self._lock: + self._rollouts.pop(rollout_id, None) + intercepts = [ + key + for key, intercept in self._intercepts.items() + if intercept.get("rollout_id") == rollout_id + ] + for key in intercepts: + intercept = self._intercepts.pop(key) + intercept["response"] = _error_response("rollout cancelled") + intercept["event"].set() + + def get_intercept(self, request_id: str) -> dict[str, Any] | None: + with self._lock: + return self._intercepts.get(request_id) + + def deliver(self, intercept: dict[str, Any], response: dict[str, Any]) -> None: + intercept["response"] = response + intercept["event"].set() + + def _authorized(self, headers: Any) -> bool: + auth = headers.get("Authorization", "") + api_key = headers.get("x-api-key", "") + return auth == f"Bearer {self.secret}" or api_key == self.secret + + def _handler(self) -> type[BaseHTTPRequestHandler]: + outer = self + + class Handler(BaseHTTPRequestHandler): + def log_message(self, format: str, *args: Any) -> None: + return None + + def do_GET(self) -> None: + if self.path == "/health": + self._json({"status": "ok"}) + return + self._json({"error": "not found"}, status=404) + + def do_POST(self) -> None: + if not outer._authorized(self.headers): + self._json({"error": "unauthorized"}, status=401) + return + match = re.fullmatch( + r"/rollout/([^/]+)/v1/chat/completions", + self.path.split("?", 1)[0], + ) + if match is None: + self._json({"error": "not found"}, status=404) + return + rollout_id = match.group(1) + try: + length = int(self.headers.get("content-length", "0")) + body = json.loads(self.rfile.read(length).decode("utf-8")) + except Exception as exc: + self._json({"error": f"invalid JSON: {exc}"}, status=400) + return + + with outer._lock: + request_queue = outer._rollouts.get(rollout_id) + if request_queue is None: + self._json({"error": "rollout not found"}, status=404) + return + + request_id = f"req_{uuid.uuid4().hex[:8]}" + intercept = { + "request_id": request_id, + "rollout_id": rollout_id, + "messages": body.get("messages"), + "tools": body.get("tools"), + "body": body, + "event": threading.Event(), + "response": None, + } + with outer._lock: + outer._intercepts[request_id] = intercept + request_queue.put(request_id) + + if not intercept["event"].wait(timeout=900): + self._json({"error": "interception timeout"}, status=504) + return + + with outer._lock: + outer._intercepts.pop(request_id, None) + response = intercept["response"] or _error_response("empty response") + if body.get("stream"): + self._sse(response) + return + self._json(response) + + def _json(self, payload: dict[str, Any], *, status: int = 200) -> None: + body = json.dumps(payload).encode("utf-8") + self.send_response(status) + self.send_header("content-type", "application/json") + self.send_header("content-length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def _sse(self, payload: dict[str, Any]) -> None: + self.send_response(200) + self.send_header("content-type", "text/event-stream") + self.send_header("cache-control", "no-cache") + self.end_headers() + for choice in payload.get("choices") or []: + message = choice.get("message") or {} + tool_calls = [ + {"index": index, **tool_call} + for index, tool_call in enumerate(message.get("tool_calls") or []) + ] + self._sse_data( + { + "id": payload.get("id", ""), + "object": "chat.completion.chunk", + "created": payload.get("created", int(time.time())), + "model": payload.get("model", ""), + "choices": [ + { + "index": choice.get("index", 0), + "delta": { + "role": "assistant", + "content": message.get("content"), + "tool_calls": tool_calls or None, + }, + "finish_reason": None, + } + ], + } + ) + self._sse_data( + { + "id": payload.get("id", ""), + "object": "chat.completion.chunk", + "created": payload.get("created", int(time.time())), + "model": payload.get("model", ""), + "choices": [ + { + "index": choice.get("index", 0), + "delta": {}, + "finish_reason": choice.get("finish_reason") or "stop", + } + ], + } + ) + self.wfile.write(b"data: [DONE]\n\n") + + def _sse_data(self, payload: dict[str, Any]) -> None: + self.wfile.write(f"data: {json.dumps(payload)}\n\n".encode("utf-8")) + + return Handler + + +class TerminusPiRolloutWorker: + """TRL rollout worker that lets PI drive tools while trainer owns generation.""" + + def __init__( + self, + *, + session_factory: Any, + tasks: Sequence[Any], + tokenizer: Any, + vllm_base_url: str, + vllm_model: str, + vllm_api_key: str = "openenv", + config: WorkerConfig | None = None, + chat_template_kwargs: dict[str, Any] | None = None, + pi_command: str = "pi", + command_runner: Callable[..., Any] | None = None, + ): + self._session_factory = session_factory + self._tasks = list(tasks) + if not self._tasks: + raise ValueError("tasks must not be empty") + self._tokenizer = tokenizer + if add_response_schema is not None: + try: + self._tokenizer = add_response_schema(tokenizer) + except Exception: + logger.debug("could not add response schema to tokenizer", exc_info=True) + self._chat_template = None + if ( + get_training_chat_template is not None + and is_chat_template_prefix_preserving is not None + ): + try: + if not is_chat_template_prefix_preserving(self._tokenizer): + self._chat_template = get_training_chat_template(self._tokenizer) + except Exception: + logger.debug("could not inspect chat template", exc_info=True) + self._chat_template_kwargs = dict(chat_template_kwargs or {}) + self._vllm_base_url = vllm_base_url.rstrip("/") + self._vllm_model = vllm_model + self._vllm_api_key = vllm_api_key + self._config = config or WorkerConfig() + self._pi_command = pi_command + self._command_runner = command_runner + + self.rollout_buffer: queue.Queue[RolloutSample] = queue.Queue( + maxsize=self._config.queue_maxsize, + ) + self._interception = InterceptionServer(secret=vllm_api_key) + self._threads: list[threading.Thread] = [] + self._stop = threading.Event() + self._pause = threading.Event() + self._lock = threading.Lock() + self._weight_sync_lock = threading.Lock() + self._task_index = 0 + self._model_version = 0 + self._last_heartbeat_s = time.monotonic() + self._model_update_group: Any | None = None + + self._wait_for_server_ready() + self._init_weight_transfer() + + def start(self) -> None: + self._interception.start() + self._stop.clear() + self._last_heartbeat_s = time.monotonic() + with self._lock: + if self._threads: + return + for index in range(max(1, self._config.max_inflight)): + thread = threading.Thread( + target=self._loop, + args=(index,), + daemon=True, + name=f"terminus-pi-rollout-{index}", + ) + thread.start() + self._threads.append(thread) + + def stop(self) -> None: + self._stop.set() + for thread in self._threads: + thread.join(timeout=5.0) + with self._lock: + self._threads = [] + self._interception.stop() + self._destroy_model_update_group() + + def pause(self) -> None: + self._pause.set() + if self._model_update_group is not None: + self._post_json("/pause", params={"mode": "keep"}, timeout=60) + + def resume(self) -> None: + if self._model_update_group is not None: + self._post_json("/resume", timeout=60) + self._pause.clear() + + def send_weights(self, iterator: Iterator[tuple[str, Any]]) -> None: + items = list(iterator) + if not items: + return + if self._config.vllm_weight_name_prefix: + prefix = self._config.vllm_weight_name_prefix + items = [ + (name if name.startswith(prefix) else f"{prefix}{name}", tensor) + for name, tensor in items + ] + if self._model_update_group is None: + raise RuntimeError("vLLM weight-transfer group is not initialized") + + update_info = { + "names": [name for name, _ in items], + "dtype_names": [ + str(getattr(tensor, "dtype", "float32")).split(".")[-1] + for _, tensor in items + ], + "shapes": [list(getattr(tensor, "shape", [])) for _, tensor in items], + "packed": True, + "is_checkpoint_format": True, + } + + with self._weight_sync_lock: + if _VLLM_NEEDS_WEIGHT_UPDATE_LIFECYCLE: + self._post_json( + "/start_weight_update", + json_body={"is_checkpoint_format": True}, + timeout=60, + ) + + post_error: list[Exception] = [] + + def post_update() -> None: + try: + self._post_json( + "/update_weights", + json_body={"update_info": update_info}, + timeout=1800, + ) + except Exception as exc: # noqa: BLE001 + post_error.append(exc) + + update_thread = threading.Thread(target=post_update, daemon=True) + update_thread.start() + + assert NCCLTrainerSendWeightsArgs is not None + assert NCCLWeightTransferEngine is not None + NCCLWeightTransferEngine.trainer_send_weights( + iterator=iter(items), + trainer_args=NCCLTrainerSendWeightsArgs( + group=self._model_update_group, + packed=True, + ), + ) + + update_thread.join(timeout=1800) + if update_thread.is_alive(): + raise TimeoutError("timed out waiting for vLLM /update_weights") + if post_error: + raise RuntimeError("vLLM /update_weights failed") from post_error[0] + if _VLLM_NEEDS_WEIGHT_UPDATE_LIFECYCLE: + self._post_json("/finish_weight_update", timeout=120) + + def update_model_version(self, version: int) -> None: + with self._lock: + self._model_version = version + + def check_health(self, stale_after_s: float) -> None: + if not self._threads or not any(thread.is_alive() for thread in self._threads): + raise RuntimeError("Terminus PI rollout worker is not running") + age = time.monotonic() - self._last_heartbeat_s + if age > stale_after_s: + raise RuntimeError( + f"Terminus PI rollout worker heartbeat stale: {age:.0f}s" + ) + + def _wait_for_server_ready(self) -> None: + start = time.time() + while True: + try: + response = requests.get(f"{self._vllm_base_url}/health", timeout=5) + if response.status_code == 200: + return + except requests.RequestException: + pass + if time.time() - start >= self._config.server_timeout_s: + raise TimeoutError( + f"timed out waiting for vLLM server at {self._vllm_base_url}" + ) + time.sleep(2.0) + + def _post_json( + self, + path: str, + *, + timeout: float, + json_body: dict[str, Any] | None = None, + params: dict[str, Any] | None = None, + ) -> requests.Response: + response = requests.post( + f"{self._vllm_base_url}{path}", + headers={"Authorization": f"Bearer {self._vllm_api_key}"}, + json=json_body, + params=params, + timeout=timeout, + ) + if response.status_code != 200: + raise RuntimeError( + f"{path} returned {response.status_code}: {response.text[:400]}" + ) + return response + + def _init_weight_transfer(self) -> None: + if ( + NCCLTrainerSendWeightsArgs is None + or NCCLWeightTransferEngine is None + or get_ip is None + or get_open_port is None + ): + raise RuntimeError("vLLM NCCL weight-transfer modules are unavailable") + + response = requests.get( + f"{self._vllm_base_url}/get_world_size", + headers={"Authorization": f"Bearer {self._vllm_api_key}"}, + timeout=10, + ) + if response.status_code != 200: + raise RuntimeError( + "vLLM weight sync requires /get_world_size. Start vLLM with " + 'VLLM_SERVER_DEV_MODE=1 and --weight-transfer-config \'{"backend":"nccl"}\'.' + ) + + inference_world_size = int(response.json()["world_size"]) + init_info = { + "master_address": get_ip(), + "master_port": get_open_port(), + "rank_offset": 1, + "world_size": inference_world_size + 1, + } + post_error: list[Exception] = [] + + def post_init() -> None: + try: + self._post_json( + "/init_weight_transfer_engine", + json_body={"init_info": init_info}, + timeout=120, + ) + except Exception as exc: # noqa: BLE001 + post_error.append(exc) + + init_thread = threading.Thread(target=post_init, daemon=True) + init_thread.start() + self._model_update_group = NCCLWeightTransferEngine.trainer_init( + { + "master_address": init_info["master_address"], + "master_port": init_info["master_port"], + "world_size": init_info["world_size"], + } + ) + init_thread.join(timeout=120) + if init_thread.is_alive(): + raise TimeoutError("timed out waiting for vLLM weight-transfer init") + if post_error: + raise RuntimeError("vLLM weight-transfer init failed") from post_error[0] + + def _destroy_model_update_group(self) -> None: + group = self._model_update_group + if group is None: + return + try: + group.group.store = None + group.group.socket = None + except Exception: + logger.debug("could not destroy vLLM weight-transfer group", exc_info=True) + self._model_update_group = None + + def _loop(self, worker_index: int) -> None: + while not self._stop.is_set(): + while self._pause.is_set() and not self._stop.is_set(): + time.sleep(0.05) + if self._stop.is_set(): + return + task = self._next_task() + try: + sample = self._rollout( + task, + f"terminus-{worker_index}-{uuid.uuid4().hex[:8]}", + ) + self.rollout_buffer.put(sample, timeout=2.0) + self._last_heartbeat_s = time.monotonic() + except Exception: + logger.exception("terminus PI rollout failed") + time.sleep(self._config.idle_sleep_s) + + def _next_task(self) -> Any: + with self._lock: + task = self._tasks[self._task_index % len(self._tasks)] + self._task_index += 1 + return task + + def _rollout(self, task: Any, rollout_id: str) -> RolloutSample: + session = self._session_factory.create( + task=_session_task(task), + episode_id=rollout_id, + ) + request_queue = self._interception.register_rollout(rollout_id) + result_box: dict[str, Any] = {} + error_box: list[BaseException] = [] + + adapter = PiCLIHarnessAdapter( + pi_command=self._pi_command, + model=self._vllm_model, + model_base_url=f"{self._interception.base_url}/rollout/{rollout_id}/v1", + model_api_key=self._vllm_api_key, + timeout_s=self._config.request_timeout_s, + command_runner=self._command_runner, + ) + + def run_pi() -> None: + try: + result_box["rollout"] = adapter.run_black_box( + session=session, + limits=HarnessRunLimits(max_turns=self._config.max_turns), + ) + except BaseException as exc: # noqa: BLE001 + error_box.append(exc) + + pi_thread = threading.Thread(target=run_pi, daemon=True, name=f"pi-{rollout_id}") + pi_thread.start() + + all_ids: list[int] = [] + all_mask: list[int] = [] + all_logprobs: list[float] = [] + previous_prompt_and_turn: list[int] | None = None + turns = 0 + + try: + while turns < self._config.max_turns: + self._last_heartbeat_s = time.monotonic() + if error_box: + raise RuntimeError("pi subprocess failed") from error_box[0] + try: + request_id = request_queue.get(timeout=0.5) + except queue.Empty: + if not pi_thread.is_alive(): + break + continue + + intercept = self._interception.get_intercept(request_id) + if intercept is None: + continue + prompt_ids = self._render_prompt_ids(intercept) + if previous_prompt_and_turn is None: + all_ids.extend(prompt_ids) + all_mask.extend([0] * len(prompt_ids)) + all_logprobs.extend([0.0] * len(prompt_ids)) + else: + suffix = prompt_ids[len(previous_prompt_and_turn) :] + all_ids.extend(suffix) + all_mask.extend([0] * len(suffix)) + all_logprobs.extend([0.0] * len(suffix)) + + turn_ids, turn_logprobs, text, finish_reason = self._generate(prompt_ids) + all_ids.extend(turn_ids) + all_mask.extend([1] * len(turn_ids)) + all_logprobs.extend(turn_logprobs) + previous_prompt_and_turn = prompt_ids + turn_ids + turns += 1 + + assistant_message = _parse_assistant_message( + tokenizer=self._tokenizer, + completion_ids=turn_ids, + fallback_text=text, + ) + self._interception.deliver( + intercept, + _chat_response( + assistant_message, + model=self._vllm_model, + finish_reason=finish_reason, + ), + ) + + pi_thread.join(timeout=2.0) + if error_box: + raise RuntimeError("pi subprocess failed") from error_box[0] + rollout = result_box.get("rollout") + verify = session.verify( + transcript=[] if rollout is None else rollout.messages, + final_state=None + if rollout is None + else {"done": rollout.done, "metrics": dict(rollout.metrics)}, + ) + reward = float(verify.env_reward or 0.0) + if not all_ids: + pad_id = getattr(self._tokenizer, "pad_token_id", None) or 0 + all_ids = [pad_id] + all_mask = [1] + all_logprobs = [0.0] + metrics = {"reward": reward, "turns": float(turns)} + if rollout is not None: + metrics.update( + { + "pi/tool_calls": float(len(rollout.tool_trace)), + "pi/events": float(rollout.metrics.get("pi_events", 0.0)), + "pi/done": float(bool(rollout.done)), + } + ) + for name, value in (verify.metrics or {}).items(): + if isinstance(value, bool): + metrics[f"verify/{name}"] = float(value) + elif isinstance(value, (int, float)): + metrics[f"verify/{name}"] = float(value) + with self._lock: + model_version = self._model_version + return RolloutSample( + input_ids=all_ids, + completion_mask=all_mask, + old_log_probs=all_logprobs, + advantage=reward, + model_version=model_version, + metrics=metrics, + ) + finally: + self._interception.unregister_rollout(rollout_id) + pi_thread.join(timeout=1.0) + session.close() + + def _render_prompt_ids(self, intercept: dict[str, Any]) -> list[int]: + body = intercept.get("body") or {} + messages = body.get("messages") or intercept.get("messages") + if not isinstance(messages, list): + raise RuntimeError("intercepted request did not include messages") + messages = _normalize_chat_messages(messages) + kwargs: dict[str, Any] = { + "add_generation_prompt": True, + "return_dict": False, + **self._chat_template_kwargs, + } + if self._chat_template is not None: + kwargs["chat_template"] = self._chat_template + tools = body.get("tools") or intercept.get("tools") + if tools: + kwargs["tools"] = tools + try: + return cast(list[int], self._tokenizer.apply_chat_template(messages, **kwargs)) + except TypeError: + kwargs.pop("tools", None) + return cast(list[int], self._tokenizer.apply_chat_template(messages, **kwargs)) + + def _generate(self, prompt_ids: list[int]) -> tuple[list[int], list[float], str, str | None]: + response = requests.post( + f"{self._vllm_base_url}/v1/completions", + headers={ + "Authorization": f"Bearer {self._vllm_api_key}", + "Content-Type": "application/json", + }, + json={ + "model": self._vllm_model, + "prompt": prompt_ids, + "max_tokens": self._config.max_completion_tokens, + "temperature": self._config.temperature, + "n": 1, + "return_token_ids": True, + "logprobs": 0, + }, + timeout=self._config.request_timeout_s, + ) + if response.status_code != 200: + raise RuntimeError(f"vLLM {response.status_code}: {response.text[:400]}") + choice = response.json()["choices"][0] + token_ids = list(choice["token_ids"]) + logprobs = choice.get("logprobs", {}).get("token_logprobs", []) + if len(logprobs) != len(token_ids): + logprobs = [0.0] * len(token_ids) + return ( + token_ids, + [0.0 if value is None else float(value) for value in logprobs], + str(choice.get("text", "")), + choice.get("finish_reason"), + ) + + +def _parse_assistant_message( + *, + tokenizer: Any, + completion_ids: list[int], + fallback_text: str, +) -> dict[str, Any]: + parsed: dict[str, Any] = {} + try: + if parse_response is not None: + parsed = parse_response(tokenizer, completion_ids) + except Exception: + logger.debug("could not parse TRL response schema", exc_info=True) + if not isinstance(parsed, dict): + parsed = {} + content = str(parsed.get("content") or "") + tool_calls = _normalize_tool_calls(parsed.get("tool_calls")) + if not tool_calls: + tool_calls = _terminal_tool_call_from_text(content or fallback_text) + if tool_calls: + return {"role": "assistant", "content": "", "tool_calls": tool_calls} + return {"role": "assistant", "content": content or fallback_text} + + +def _normalize_chat_messages(messages: list[Any]) -> list[dict[str, Any]]: + normalized = [] + for message in messages: + if not isinstance(message, dict): + continue + item = dict(message) + content = item.get("content") + if content is None: + item["content"] = "" + elif isinstance(content, list): + item["content"] = "\n".join( + str(part.get("text", "")) + for part in content + if isinstance(part, dict) and part.get("text") is not None + ) + tool_calls = item.get("tool_calls") + if isinstance(tool_calls, list): + item["tool_calls"] = _normalize_message_tool_calls(tool_calls) + normalized.append(item) + return normalized + + +def _normalize_message_tool_calls(raw_tool_calls: list[Any]) -> list[dict[str, Any]]: + tool_calls = [] + for raw_call in raw_tool_calls: + if not isinstance(raw_call, dict): + continue + call = dict(raw_call) + function = call.get("function") + if not isinstance(function, dict): + tool_calls.append(call) + continue + function = dict(function) + arguments = function.get("arguments") + if isinstance(arguments, str): + try: + function["arguments"] = json.loads(arguments) + except json.JSONDecodeError: + function["arguments"] = {"command": arguments} + call["function"] = function + tool_calls.append(call) + return tool_calls + + +def _session_task(task: Any) -> Any: + if not isinstance(task, dict) or not isinstance(task.get("prompt"), list): + return task + instruction = "\n\n".join( + str(message.get("content", "")) + for message in task["prompt"] + if isinstance(message, dict) and message.get("content") + ) + if not instruction: + return task + return {**task, "instruction": instruction} + + +def _terminal_tool_call_from_text(text: str) -> list[dict[str, Any]]: + if build_terminal_tool_call is None or not text.strip(): + return [] + try: + tool_call = build_terminal_tool_call( + text, + call_id=f"call_{uuid.uuid4().hex[:8]}", + ) + except Exception: + logger.debug("could not parse Terminus terminal text", exc_info=True) + return [] + arguments = getattr(tool_call, "args", None) or {} + name = getattr(tool_call, "name", "") + if name != "terminal" or not arguments: + return [] + return [ + { + "id": str(getattr(tool_call, "id", "") or f"call_{uuid.uuid4().hex[:8]}"), + "type": "function", + "function": { + "name": name, + "arguments": json.dumps(arguments), + }, + } + ] + + +def _normalize_tool_calls(raw_tool_calls: Any) -> list[dict[str, Any]]: + if not isinstance(raw_tool_calls, list): + return [] + tool_calls = [] + for raw_call in raw_tool_calls: + if not isinstance(raw_call, dict): + continue + function = raw_call.get("function") + if not isinstance(function, dict) or not function.get("name"): + continue + arguments = function.get("arguments") or {} + tool_calls.append( + { + "id": str(raw_call.get("id") or f"call_{uuid.uuid4().hex[:8]}"), + "type": "function", + "function": { + "name": str(function["name"]), + "arguments": arguments + if isinstance(arguments, str) + else json.dumps(arguments), + }, + } + ) + return tool_calls + + +def _chat_response( + assistant_message: dict[str, Any], + *, + model: str, + finish_reason: str | None, +) -> dict[str, Any]: + choice_finish_reason = ( + "tool_calls" if assistant_message.get("tool_calls") else finish_reason or "stop" + ) + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": model, + "choices": [ + { + "index": 0, + "message": assistant_message, + "finish_reason": choice_finish_reason, + } + ], + } + + +def _error_response(message: str) -> dict[str, Any]: + return { + "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", + "object": "chat.completion", + "created": int(time.time()), + "model": "openenv-error", + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": message}, + "finish_reason": "stop", + } + ], + } + + +__all__ = ["TerminusPiRolloutWorker", "WorkerConfig"] diff --git a/examples/end_to_end/tbench2_pi_trl/train_terminus_grpo.py b/examples/end_to_end/tbench2_pi_trl/train_terminus_grpo.py new file mode 100644 index 000000000..31cb1b637 --- /dev/null +++ b/examples/end_to_end/tbench2_pi_trl/train_terminus_grpo.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 +# /// script +# requires-python = ">=3.10" +# dependencies = [ +# "aiohttp>=3.9.0", +# "accelerate>=1.10.0", +# "datasets>=3.0.0", +# "huggingface-hub>=0.35.0", +# "openenv-core", +# "openenv-terminus-env", +# "peft>=0.17.0", +# "trackio<0.25.0", +# "transformers @ git+https://github.com/huggingface/transformers.git@e1a37d29cd4822d74f4f3323289fb69e1eec61a0", +# "trl @ git+https://github.com/huggingface/trl.git@a7ba987d05b1e9dbbdbd2e9091264623746e3528", +# "vllm==0.19.1", +# ] +# [tool.uv.sources] +# openenv-core = { git = "https://github.com/burtenshaw/OpenEnv.git", branch = "codex/terminus-pi-trl-space" } +# openenv-terminus-env = { git = "https://github.com/burtenshaw/OpenEnv.git", branch = "codex/terminus-env-harness", subdirectory = "envs/terminus_env" } +# /// +"""Run Terminus async GRPO with PI rollouts owned by TRL.""" + +from __future__ import annotations + +import os +from pathlib import Path + +from datasets import load_dataset +from huggingface_hub import HfApi +from terminus_env.client import TerminusEnv +from terminus_env.harness import TerminusSessionFactory +from transformers import AutoTokenizer, TrainerCallback +from trl.experimental.async_grpo import AsyncGRPOConfig, AsyncGRPOTrainer + +from pi_rollout_worker import TerminusPiRolloutWorker, WorkerConfig + +TASK_DATASET_ID = "burtenshaw/terminus-pi-trl-hard-tasks" +MODEL = "Qwen/Qwen3.5-4B" +TRAINER_MODEL = os.environ.get("TERMINUS_TRAINER_MODEL", MODEL) +ENV_URL = os.environ.get("TERMINUS_ENV_URL", "http://localhost:8000") +OUTPUT_DIR = Path(os.environ.get("TERMINUS_OUTPUT_DIR", "/tmp/terminus-pi-trl-output")) +HUB_MODEL_ID = os.environ.get( + "TERMINUS_HUB_MODEL_ID", + "burtenshaw/terminus-pi-trl-async-grpo-qwen3-5-4b-hard", +) +TRACKIO_PROJECT = "terminus-pi-trl" +TRACKIO_SPACE_ID = os.environ.get( + "TRACKIO_SPACE_ID", + "burtenshaw/terminus-pi-trl-trackio", +) +REPORT_TO = "trackio" +RUN_NAME = os.environ.get("TERMINUS_RUN_NAME") or ( + os.environ.get("JOB_ID", "local") + "-terminus" +) +VLLM_SERVER_URL = os.environ.get("TERMINUS_VLLM_SERVER_URL", "http://localhost:8001") +VLLM_API_KEY = os.environ.get("VLLM_API_KEY", "openenv") + +os.environ["TRACKIO_PROJECT"] = TRACKIO_PROJECT + + +def main() -> None: + OUTPUT_DIR.mkdir(parents=True, exist_ok=True) + + task_dataset = load_dataset(TASK_DATASET_ID, split="train") + task = task_dataset[0] + max_steps = int(task["max_steps"]) + train_dataset = task_dataset.select_columns(["prompt"]) + session_factory = TerminusSessionFactory( + client_factory=lambda: TerminusEnv( + base_url=ENV_URL, + connect_timeout_s=30.0, + message_timeout_s=600.0, + ).sync(), + default_verify=list(task["verify"]), + ) + + tokenizer = AutoTokenizer.from_pretrained(MODEL) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + + rank = int( + os.environ.get("RANK") + or os.environ.get("ACCELERATE_PROCESS_INDEX") + or os.environ.get("SLURM_PROCID") + or "0" + ) + worker = None + if rank == 0: + worker = TerminusPiRolloutWorker( + session_factory=session_factory, + tasks=list(task_dataset), + tokenizer=tokenizer, + vllm_base_url=VLLM_SERVER_URL, + vllm_model=MODEL, + vllm_api_key=VLLM_API_KEY, + chat_template_kwargs={"enable_thinking": False}, + config=WorkerConfig( + max_inflight=task["num_generations"], + max_turns=task["max_turns"], + max_completion_tokens=task["max_completion_length"], + vllm_weight_name_prefix="language_model.", + ), + ) + + def unused_reward(**kwargs: object) -> list[float]: + return [0.0] * len(kwargs.get("prompts", [])) + + trainer = AsyncGRPOTrainer( + model=TRAINER_MODEL, + args=AsyncGRPOConfig( + output_dir=str(OUTPUT_DIR), + max_steps=max_steps, + per_device_train_batch_size=task["batch_size"], + gradient_accumulation_steps=1, + num_generations=task["num_generations"], + max_completion_length=task["max_completion_length"], + max_inflight_tasks=task["num_generations"], + learning_rate=1e-6, + temperature=1.0, + weight_sync_steps=1, + logging_steps=1, + logging_strategy="steps", + log_completions=True, + report_to=REPORT_TO, + run_name=RUN_NAME, + project=TRACKIO_PROJECT, + trackio_space_id=TRACKIO_SPACE_ID, + save_strategy="no", + push_to_hub=True, + hub_model_id=HUB_MODEL_ID, + chat_template_kwargs={"enable_thinking": False}, + vllm_server_base_url=VLLM_SERVER_URL, + request_timeout=600, + vllm_server_timeout=600, + ), + processing_class=tokenizer, + train_dataset=train_dataset, + reward_funcs=unused_reward, + rollout_worker=worker, + ) + + class SaveAndUploadCallback(TrainerCallback): + def __init__(self) -> None: + self.saved = False + + def on_step_end(self, args, state, control, **kwargs): + if self.saved or state.global_step < max_steps: + return control + self.saved = True + state_dict = trainer.accelerator.get_state_dict(trainer.model) + if rank == 0: + trainer._save(str(OUTPUT_DIR), state_dict=state_dict) + del state_dict + api = HfApi() + api.create_repo(repo_id=HUB_MODEL_ID, repo_type="model", exist_ok=True) + api.upload_large_folder( + repo_id=HUB_MODEL_ID, + repo_type="model", + folder_path=OUTPUT_DIR, + num_workers=8, + ) + return control + + trainer.add_callback(SaveAndUploadCallback()) + trainer.train() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 8da12c016..798c4b187 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,7 @@ include-package-data = true [tool.setuptools.package-data] "openenv.cli" = ["templates/**/*"] +"openenv.core.harness" = ["*.mjs"] [tool.setuptools.exclude-package-data] "*" = ["*.pyc", "*.pyo", "__pycache__/*"] diff --git a/src/openenv/core/harness/__init__.py b/src/openenv/core/harness/__init__.py index 8c1412ca4..36a92120b 100644 --- a/src/openenv/core/harness/__init__.py +++ b/src/openenv/core/harness/__init__.py @@ -701,6 +701,8 @@ def rollout_func(prompts: list[Any], trainer: Any) -> dict[str, list[Any]]: return rollout_func +from .pi_cli import PiCLIHarnessAdapter + __all__ = [ "CLIHarnessAdapter", "HarnessAdapter", @@ -710,6 +712,7 @@ def rollout_func(prompts: list[Any], trainer: Any) -> dict[str, list[Any]]: "Message", "ModelStep", "ModelStepResult", + "PiCLIHarnessAdapter", "RESERVED_TOOL_NAMES", "ResourceSession", "ResourceSessionFactory", diff --git a/src/openenv/core/harness/pi_bridge.mjs b/src/openenv/core/harness/pi_bridge.mjs new file mode 100644 index 000000000..5d89a7bc9 --- /dev/null +++ b/src/openenv/core/harness/pi_bridge.mjs @@ -0,0 +1,78 @@ +import { Type } from "typebox"; + +const bridgeUrl = process.env.OPENENV_PI_BRIDGE_URL; +const modelBaseUrl = process.env.OPENENV_PI_MODEL_BASE_URL; +const modelId = process.env.OPENENV_PI_MODEL_ID; +const modelProvider = process.env.OPENENV_PI_MODEL_PROVIDER || "openenv-vllm"; +const modelApiKey = process.env.OPENENV_PI_MODEL_API_KEY || "openenv"; + +function registerModelProvider(pi) { + if (!modelBaseUrl || !modelId) { + return; + } + pi.registerProvider(modelProvider, { + baseUrl: modelBaseUrl, + apiKey: modelApiKey, + api: "openai-completions", + compat: { + supportsDeveloperRole: false, + supportsReasoningEffort: false, + thinkingFormat: "qwen-chat-template", + }, + models: [{ + id: modelId, + name: modelId, + reasoning: false, + input: ["text"], + contextWindow: 32768, + maxTokens: 4096, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + }], + }); +} + +async function callBridge(method, params = {}, id = method) { + if (!bridgeUrl) { + throw new Error("OPENENV_PI_BRIDGE_URL is not set"); + } + const response = await fetch(bridgeUrl, { + method: "POST", + headers: { "content-type": "application/json" }, + body: JSON.stringify({ + jsonrpc: "2.0", + id, + method, + params, + }), + }); + const payload = await response.json(); + if (!response.ok || payload.error) { + throw new Error(payload.error?.message || response.statusText); + } + return payload.result || {}; +} + +export default async function(pi) { + registerModelProvider(pi); + + const { tools = [] } = await callBridge("tools/list"); + for (const tool of tools) { + pi.registerTool({ + name: tool.name, + label: tool.name, + description: tool.description || tool.name, + parameters: Type.Unsafe(tool.inputSchema || { type: "object", properties: {} }), + async execute(toolCallId, params) { + const result = await callBridge( + "tools/call", + { name: tool.name, arguments: params || {} }, + toolCallId, + ); + return { + content: [{ type: "text", text: JSON.stringify(result.data ?? result) }], + details: result, + }; + }, + }); + } +} diff --git a/src/openenv/core/harness/pi_cli.py b/src/openenv/core/harness/pi_cli.py new file mode 100644 index 000000000..fe530d59f --- /dev/null +++ b/src/openenv/core/harness/pi_cli.py @@ -0,0 +1,254 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Pi CLI harness adapter.""" + +from __future__ import annotations + +import json +import os +import subprocess +import threading +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from importlib.resources import as_file, files +from typing import Any, Callable + +from . import ( + _serialize_for_message, + CLIHarnessAdapter, + HarnessRolloutResult, + HarnessRunLimits, + Message, + ResourceSession, + RolloutEvent, + SessionMCPBridge, + ToolResult, + ToolTraceEntry, +) + + +def _messages_to_prompt(messages: list[Message]) -> str: + if len(messages) == 1 and isinstance(messages[0].get("content"), str): + return str(messages[0]["content"]) + parts = [] + for message in messages: + role = str(message.get("role", "message")) + content = message.get("content", "") + parts.append(f"{role}:\n{_serialize_for_message(content)}") + return "\n\n".join(parts) + + +def _json_events(stdout: str) -> list[dict[str, Any]]: + events = [] + for line in stdout.splitlines(): + line = line.strip() + if not line: + continue + try: + event = json.loads(line) + except json.JSONDecodeError: + event = {"type": "raw", "text": line} + if isinstance(event, dict): + events.append(event) + return events + + +def _messages_from_events(events: list[dict[str, Any]], stdout: str) -> list[Message]: + for event in reversed(events): + messages = event.get("messages") + if event.get("type") == "agent_end" and isinstance(messages, list): + return [message for message in messages if isinstance(message, dict)] + + messages = [ + event["message"] + for event in events + if event.get("type") == "message_end" and isinstance(event.get("message"), dict) + ] + if messages: + return messages + return [{"role": "assistant", "content": stdout}] + + +def _response_body(handler: BaseHTTPRequestHandler, status: int, payload: Any) -> None: + body = json.dumps(payload, default=str).encode("utf-8") + handler.send_response(status) + handler.send_header("content-type", "application/json") + handler.send_header("content-length", str(len(body))) + handler.end_headers() + handler.wfile.write(body) + + +def _bridge_handler( + bridge: SessionMCPBridge, + tool_trace: list[ToolTraceEntry], +) -> type[BaseHTTPRequestHandler]: + class BridgeHandler(BaseHTTPRequestHandler): + def log_message(self, format: str, *args: Any) -> None: + return None + + def do_POST(self) -> None: + try: + content_length = int(self.headers.get("content-length", "0")) + request = json.loads(self.rfile.read(content_length).decode("utf-8")) + except Exception as exc: + _response_body( + self, + 400, + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32700, "message": str(exc)}, + }, + ) + return + + response = bridge.handle_request(request) + if request.get("method") == "tools/call": + params = request.get("params", {}) or {} + result = response.get("result") + if isinstance(result, dict): + tool_trace.append( + ToolTraceEntry( + tool_name=str(params.get("name", "")), + arguments=dict(params.get("arguments", {}) or {}), + result=ToolResult( + data=result.get("data"), + done=bool(result.get("done")), + metadata=dict(result.get("metadata", {}) or {}), + error=result.get("error"), + ), + ) + ) + _response_body(self, 200, response) + + return BridgeHandler + + +class PiCLIHarnessAdapter(CLIHarnessAdapter): + """Black-box harness adapter that drives an OpenEnv session with the Pi CLI.""" + + def __init__( + self, + *, + pi_command: str | list[str] | tuple[str, ...] = "pi", + provider: str | None = None, + model: str | None = None, + model_base_url: str | None = None, + model_api_key: str = "openenv", + model_provider: str = "openenv-vllm", + cwd: str | None = None, + timeout_s: float = 600.0, + extra_args: list[str] | None = None, + command_runner: Callable[..., subprocess.CompletedProcess[str]] | None = None, + ): + self._pi_command = ( + [pi_command] if isinstance(pi_command, str) else list(pi_command) + ) + self._provider = provider + self._model = model + self._model_base_url = model_base_url + self._model_api_key = model_api_key + self._model_provider = model_provider + self._cwd = cwd + self._timeout_s = timeout_s + self._extra_args = list(extra_args or []) + self._command_runner = command_runner or subprocess.run + super().__init__(runner=self._run_pi) + + def _run_pi( + self, + bridge: SessionMCPBridge, + session: ResourceSession, + limits: HarnessRunLimits, + ) -> HarnessRolloutResult: + del limits + tools = session.list_tools() + tool_trace: list[ToolTraceEntry] = [] + bridge_resource = files("openenv.core.harness").joinpath("pi_bridge.mjs") + + with as_file(bridge_resource) as extension_path: + server = ThreadingHTTPServer( + ("127.0.0.1", 0), + _bridge_handler(bridge, tool_trace), + ) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + bridge_url = f"http://127.0.0.1:{server.server_port}" + + env = dict(os.environ) + env["OPENENV_PI_BRIDGE_URL"] = bridge_url + if self._model_base_url: + if not self._model: + raise ValueError("model is required when model_base_url is set") + env["OPENENV_PI_MODEL_BASE_URL"] = self._model_base_url + env["OPENENV_PI_MODEL_ID"] = self._model + env["OPENENV_PI_MODEL_PROVIDER"] = self._provider or self._model_provider + env["OPENENV_PI_MODEL_API_KEY"] = self._model_api_key + + command = [ + *self._pi_command, + "--mode", + "json", + "--print", + "--no-session", + "--no-builtin-tools", + "--no-extensions", + "--no-skills", + "--no-prompt-templates", + "--no-context-files", + "--extension", + str(extension_path), + "--tools", + ",".join(tool.name for tool in tools), + ] + provider = self._provider or ( + self._model_provider if self._model_base_url else None + ) + if provider: + command.extend(["--provider", provider]) + if self._model: + command.extend(["--model", self._model]) + command.extend(self._extra_args) + command.append(_messages_to_prompt(session.initial_messages())) + + try: + completed = self._command_runner( + command, + cwd=self._cwd, + env=env, + text=True, + capture_output=True, + timeout=self._timeout_s, + ) + finally: + server.shutdown() + server.server_close() + thread.join(timeout=1.0) + + stdout = completed.stdout or "" + stderr = completed.stderr or "" + if completed.returncode != 0: + raise RuntimeError( + "pi CLI failed with exit code " + f"{completed.returncode}: {(stderr or stdout).strip()}" + ) + + events = _json_events(stdout) + return HarnessRolloutResult( + messages=_messages_from_events(events, stdout), + tool_trace=tool_trace, + events=[RolloutEvent(type="pi_event", payload=event) for event in events], + done=bool(tool_trace and tool_trace[-1].result.done), + metrics={ + "harness": "pi_cli", + "pi_events": len(events), + "tool_calls": len(tool_trace), + "stderr": stderr, + }, + ) + + +__all__ = ["PiCLIHarnessAdapter"] diff --git a/tests/core/test_harness_runtime.py b/tests/core/test_harness_runtime.py index 388f70845..9d6149625 100644 --- a/tests/core/test_harness_runtime.py +++ b/tests/core/test_harness_runtime.py @@ -8,6 +8,9 @@ from __future__ import annotations +import json +import subprocess +import urllib.request from dataclasses import dataclass, field from typing import Any @@ -22,6 +25,7 @@ HarnessRunLimits, MCPHarnessAdapter, ModelStepResult, + PiCLIHarnessAdapter, RESERVED_TOOL_NAMES, RolloutEvent, SessionMCPBridge, @@ -700,3 +704,98 @@ def runner(bridge, current_session, limits): assert result.done is True assert result.metrics["mode"] == "black_box" assert result.events[0].payload["result"]["done"] is True + + +class TestPiCLIHarnessAdapter: + """Tests for the black-box Pi CLI harness adapter.""" + + def test_pi_cli_adapter_forwards_pi_tool_calls_to_session_bridge(self): + env = FakeSyncEnv() + session = StepEnvSessionAdapter( + client=env, + task="pi-task", + tool_specs=[ + Tool( + name="finish", + description="Finish the task", + input_schema={"type": "object", "properties": {}}, + ) + ], + action_builder=lambda name, arguments: name, + initial_messages_builder=lambda result, task: [ + {"role": "user", "content": f"Solve {task}"} + ], + ) + seen_commands: list[list[str]] = [] + + def fake_pi(command, cwd, env, text, capture_output, timeout): + del cwd, text, capture_output, timeout + seen_commands.append(list(command)) + bridge_url = env["OPENENV_PI_BRIDGE_URL"] + assert env["OPENENV_PI_MODEL_BASE_URL"] == "http://trainer.example/v1" + assert env["OPENENV_PI_MODEL_ID"] == "test-model" + assert env["OPENENV_PI_MODEL_PROVIDER"] == "openenv-vllm" + + def post(payload): + request = urllib.request.Request( + bridge_url, + data=json.dumps(payload).encode("utf-8"), + headers={"content-type": "application/json"}, + method="POST", + ) + with urllib.request.urlopen(request, timeout=5) as response: + return json.loads(response.read().decode("utf-8")) + + tools = post( + { + "jsonrpc": "2.0", + "id": "tools", + "method": "tools/list", + "params": {}, + } + ) + payload = post( + { + "jsonrpc": "2.0", + "id": "call-1", + "method": "tools/call", + "params": {"name": "finish", "arguments": {}}, + } + ) + + assert tools["result"]["tools"][0]["name"] == "finish" + assert payload["result"]["done"] is True + stdout = "\n".join( + [ + json.dumps({"type": "session", "id": "pi-session"}), + json.dumps( + { + "type": "agent_end", + "messages": [{"role": "assistant", "content": "done"}], + } + ), + ] + ) + return subprocess.CompletedProcess(command, 0, stdout=stdout, stderr="") + + adapter = PiCLIHarnessAdapter( + pi_command="pi", + model="test-model", + model_base_url="http://trainer.example/v1", + command_runner=fake_pi, + ) + + result = adapter.run_black_box(session=session, limits=HarnessRunLimits()) + + assert result.done is True + assert result.metrics["harness"] == "pi_cli" + assert result.tool_trace[0].tool_name == "finish" + assert result.tool_trace[0].result.metadata["reward"] == 1.0 + assert result.messages == [{"role": "assistant", "content": "done"}] + assert seen_commands[0][seen_commands[0].index("--provider") + 1] == "openenv-vllm" + assert seen_commands[0][seen_commands[0].index("--model") + 1] == "test-model" + assert seen_commands[0][:3] == ["pi", "--mode", "json"] + extension_path = seen_commands[0][seen_commands[0].index("--extension") + 1] + assert extension_path.endswith("pi_bridge.mjs") + assert "--no-context-files" in seen_commands[0] + assert seen_commands[0][-1] == "Solve pi-task"