Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions examples/hf_sandbox_coding_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#!/usr/bin/env python3
"""Smoke-check a real OpenEnv server through the HF sandbox provider."""
Comment on lines +1 to +2

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing license header. Same as the provider module (header goes after the shebang):

Suggested change
#!/usr/bin/env python3
"""Smoke-check a real OpenEnv server through the HF sandbox provider."""
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Smoke-check a real OpenEnv server through the HF sandbox provider."""


from __future__ import annotations

import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from envs.coding_env import CodeAction, CodingEnv
from openenv.core.containers.runtime.hf_sandbox_provider import HFSandboxProvider


IMAGE = "hf.co/spaces/openenv/coding_env"


def run_code(base_url: str, code: str) -> str:
    """Execute ``code`` in a fresh CodingEnv session and return its stdout.

    A new synchronous client is opened against ``base_url``, the environment
    is reset, and ``code`` is sent as a single step.

    Args:
        base_url: URL of the running environment server.
        code: Python source to execute remotely.

    Returns:
        The step's stdout with surrounding whitespace stripped.

    Raises:
        RuntimeError: If the step reports a non-zero exit code; the message
            is the step's stderr.
    """
    session = CodingEnv(base_url=base_url).sync()
    with session as env:
        env.reset()
        action = CodeAction(code=code)
        obs = env.step(action).observation
        if obs.exit_code == 0:
            return obs.stdout.strip()
        raise RuntimeError(obs.stderr)


def main() -> None:
    """Start the coding env on HF, then run two independent connections.

    Each connection executes a tiny snippet through ``run_code`` and the
    printed output is compared with the expected value.

    Raises:
        RuntimeError: If either connection returns unexpected output.
    """
    first_snippet = "answer = 40 + 2\nprint(answer)"
    second_snippet = "message = 'second connection ok'\nprint(message)"

    with HFSandboxProvider() as provider:
        base_url = provider.start_container(IMAGE)
        print(f"provider URL: {base_url}")
        provider.wait_for_ready(base_url, timeout_s=300.0)

        # First connection: simple arithmetic round trip.
        first_output = run_code(base_url, first_snippet)
        print(f"first connection output: {first_output!r}")
        if not first_output == "42":
            raise RuntimeError("first HF sandbox connection returned unexpected output")

        # Second connection: confirms the server accepts a new session.
        second_output = run_code(base_url, second_snippet)
        print(f"second connection output: {second_output!r}")
        if not second_output == "second connection ok":
            raise RuntimeError(
                "second HF sandbox connection returned unexpected output"
            )

        print("HF sandbox coding_env check passed")


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ dependencies = [
"typer>=0.9.0",
"rich>=13.0.0",
"pyyaml>=6.0",
"huggingface_hub>=0.20.0",
"huggingface_hub>=1.20.1",
"openai>=2.7.2",
"tomli>=2.3.0",
"tomli-w>=1.2.0",
Expand All @@ -46,7 +46,7 @@ cli = [
"typer>=0.9.0",
"rich>=13.0.0",
"pyyaml>=6.0",
"huggingface_hub>=0.20.0",
"huggingface_hub>=1.20.1",
"openai>=2.7.2",
"tomli>=2.3.0",
"tomli-w>=1.2.0",
Expand Down
269 changes: 269 additions & 0 deletions src/openenv/core/containers/runtime/hf_sandbox_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
"""Hugging Face-backed provider for OpenEnv environment servers."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing license header. Every other module in the repo carries the BSD header; it is part of the repo convention and should be added here too.

Suggested change
"""Hugging Face-backed provider for OpenEnv environment servers."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Hugging Face-backed provider for OpenEnv environment servers."""


from __future__ import annotations

import asyncio
import socket
import threading
import time
from contextlib import suppress
from typing import Any

import httpx
import requests
import uvicorn
from fastapi import FastAPI, Request, Response, WebSocket
from huggingface_hub import HfApi
from huggingface_hub.utils import get_token
from starlette.websockets import WebSocketDisconnect
from websockets.asyncio.client import connect as ws_connect
from websockets.exceptions import ConnectionClosed

from .providers import ContainerProvider


_DEFAULT_PORT = 8000
_SERVER_COMMAND = "server"
_JOB_TIMEOUT = "24h"
_STARTUP_TIMEOUT_S = 120.0
_HOP_BY_HOP_HEADERS = {
"connection",
"content-encoding",
"content-length",
"host",
"keep-alive",
"proxy-authenticate",
"proxy-authorization",
"te",
"trailer",
"transfer-encoding",
"upgrade",
}


def _find_available_port() -> int:
    """Return a loopback TCP port number that was free when probed.

    A throwaway socket is bound to ``127.0.0.1`` with port ``0`` and the port
    it received is read back. The socket is closed before returning, so the
    port is not reserved for the caller.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        probe.listen(1)
        bound_address = probe.getsockname()
        return bound_address[1]


def _job_port_url(job: Any, port: int) -> str | None:
    """Return the first exposed job URL that contains ``--{port}.``.

    Args:
        job: Object with a ``status.expose_urls`` attribute (may be ``None``
            or empty).
        port: Port number to look for in the exposed URLs.

    Returns:
        The first matching URL as a ``str``, or ``None`` if none matches.
    """
    marker = f"--{port}."
    exposed = job.status.expose_urls or []
    return next((str(candidate) for candidate in exposed if marker in candidate), None)


def _to_ws_url(url: str) -> str:
    """Map an HTTP(S) URL to the corresponding WebSocket URL.

    ``https://`` becomes ``wss://`` and ``http://`` becomes ``ws://``. The
    scheme is matched case-insensitively (URI schemes are case-insensitive),
    so ``HTTPS://host`` is converted as well. Anything else, including URLs
    that already use a WebSocket scheme or have no scheme, is returned
    unchanged.

    Args:
        url: The URL to convert.

    Returns:
        The URL with its scheme swapped, or ``url`` itself if no swap applies.
    """
    scheme, separator, remainder = url.partition("://")
    if separator:
        ws_scheme = {"https": "wss", "http": "ws"}.get(scheme.lower())
        if ws_scheme is not None:
            return f"{ws_scheme}://{remainder}"
    return url
Comment on lines +58 to +63

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enforce secure transport on the upstream URL. The sibling cloud provider rejects non-HTTPS URLs (the token is a bearer secret, and EnvClient derives its WebSocket URL from this base). _to_ws_url silently downgrades to ws://, and the proxy injects the Bearer token on it. Worth asserting target_url starts with https:// in _wait_for_job_url so the invariant is explicit. (Left as a comment rather than a suggestion since the guard belongs in _wait_for_job_url, not here.)



class _LocalAuthProxy:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wauplin This class is a by-product of HF auth and the URL construction. I'm fairly comfortable keeping it here, but I wondered whether there is a simpler way to handle it.

def __init__(self, *, target_url: str, token: str):
    """Record the upstream target and pick a local port; nothing is started.

    Args:
        target_url: Upstream base URL; trailing slashes are removed.
        token: Token stored on the instance for upstream requests.
    """
    # Server and thread stay unset until the proxy is actually started.
    self._server: uvicorn.Server | None = None
    self._thread: threading.Thread | None = None
    self.token = token
    self.target_url = target_url.rstrip("/")
    self.port = _find_available_port()

@property
def base_url(self) -> str:
    """Loopback HTTP URL built from this proxy's ``port``."""
    return "http://127.0.0.1:{}".format(self.port)

def start(self) -> str:
app = FastAPI()

@app.api_route(
"/{path:path}",
methods=["GET", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"],
)
async def proxy_http(path: str, request: Request) -> Response:
query = request.url.query
target = f"{self.target_url}/{path}"
if query:
target = f"{target}?{query}"
headers = {
key: value
for key, value in request.headers.items()
if key.lower() not in _HOP_BY_HOP_HEADERS
}
headers["authorization"] = f"Bearer {self.token}"
async with httpx.AsyncClient(follow_redirects=True) as client:
upstream = await client.request(
request.method,
target,
content=await request.body(),
headers=headers,
timeout=60.0,
)
Comment on lines +96 to +103

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Distinguish "env died" from "env returned an error".

When the upstream job dies mid-rollout, the httpx transport error propagates and FastAPI returns a generic 500, indistinguishable from a legitimate application-level error from the env server itself. For a training loop these need different handling (restart the env vs. record a step error). Returning a 502 makes infra failures explicit and restartable:

Suggested change
async with httpx.AsyncClient(follow_redirects=True) as client:
upstream = await client.request(
request.method,
target,
content=await request.body(),
headers=headers,
timeout=60.0,
)
try:
async with httpx.AsyncClient(follow_redirects=True) as client:
upstream = await client.request(
request.method,
target,
content=await request.body(),
headers=headers,
timeout=60.0,
)
except httpx.HTTPError:
return Response(
content=b"upstream HF job unreachable",
status_code=502,
)

response_headers = {
key: value
for key, value in upstream.headers.items()
if key.lower() not in _HOP_BY_HOP_HEADERS
}
return Response(
content=upstream.content,
status_code=upstream.status_code,
headers=response_headers,
)

@app.websocket("/{path:path}")
async def proxy_websocket(path: str, websocket: WebSocket) -> None:
query = websocket.url.query
target = f"{_to_ws_url(self.target_url)}/{path}"
if query:
target = f"{target}?{query}"
await websocket.accept()
async with ws_connect(
target,
additional_headers={"Authorization": f"Bearer {self.token}"},
) as upstream:
to_upstream = asyncio.create_task(
self._client_to_upstream(websocket, upstream)
)
to_client = asyncio.create_task(
self._upstream_to_client(websocket, upstream)
)
done, pending = await asyncio.wait(
{to_upstream, to_client},
return_when=asyncio.FIRST_COMPLETED,
)
for task in pending:
task.cancel()
for task in done:
with suppress(ConnectionClosed, WebSocketDisconnect):
task.result()

config = uvicorn.Config(
app,
host="127.0.0.1",
port=self.port,
log_level="warning",
access_log=False,
)
self._server = uvicorn.Server(config)
self._thread = threading.Thread(target=self._server.run, daemon=True)
self._thread.start()
while not self._server.started:
if not self._thread.is_alive():
raise RuntimeError("HF sandbox auth proxy failed to start")
time.sleep(0.05)
return self.base_url

async def _client_to_upstream(self, websocket: WebSocket, upstream: Any) -> None:
    """Forward every text message from the local client to ``upstream``.

    Only text frames are relayed in this direction (``iter_text``), whereas
    the upstream-to-client direction handles both text and bytes.
    NOTE(review): this presumes local clients only send text — confirm
    before relying on the proxy for binary payloads.
    """
    async for text_frame in websocket.iter_text():
        await upstream.send(text_frame)
Comment on lines +158 to +160

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted, not a blocker: the relay handles bytes and text upstream-to-client but only text client-to-upstream (iter_text()). Fine in practice since EnvClient only ever sends JSON text, but a short comment noting the assumption would help future readers.


async def _upstream_to_client(self, websocket: WebSocket, upstream: Any) -> None:
    """Forward every upstream message to the local client.

    ``bytes`` messages go out via ``send_bytes``; everything else via
    ``send_text``.
    """
    async for frame in upstream:
        is_binary = isinstance(frame, bytes)
        forward = websocket.send_bytes if is_binary else websocket.send_text
        await forward(frame)

def stop(self) -> None:
    """Ask the proxy server to exit and wait briefly for its thread.

    Does nothing when the server or thread is unset. Otherwise sets the
    server's ``should_exit`` flag, joins the thread for at most five
    seconds, and clears both references.
    """
    server, worker = self._server, self._thread
    if server is None or worker is None:
        return
    server.should_exit = True
    worker.join(timeout=5.0)
    self._server = None
    self._thread = None


class HFSandboxProvider(ContainerProvider):
"""Run an OpenEnv server on Hugging Face infrastructure."""

def __init__(self, *, flavor: str = "cpu-basic"):
    """Create an idle provider; no job is launched here.

    Args:
        flavor: Flavor name stored on the instance for later job launches.
    """
    # Per-run state starts empty.
    self._proxy: _LocalAuthProxy | None = None
    self._job: Any = None
    self._token: str | None = None
    self._api = HfApi()
    self.flavor = flavor

def start_container(
self,
image: str,
port: int | None = None,
env_vars: dict[str, str] | None = None,
) -> str:
Comment on lines +188 to +193

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_container drops **kwargs.

The abstract ContainerProvider.start_container declares **kwargs: Any, and EnvClient calls provider.start_container(image, **kwargs) (see env_client.py:296). Without it, the generic path raises TypeError for any forwarded kwarg. Any is already imported:

Suggested change
def start_container(
self,
image: str,
port: int | None = None,
env_vars: dict[str, str] | None = None,
) -> str:
def start_container(
self,
image: str,
port: int | None = None,
env_vars: dict[str, str] | None = None,
**kwargs: Any,
) -> str:

if self._job is not None:
raise RuntimeError("HFSandboxProvider already has an active job")

if port not in (None, _DEFAULT_PORT):
raise ValueError(
f"HFSandboxProvider only supports port {_DEFAULT_PORT} (got {port})."
)

effective_token = get_token()
if not effective_token:
raise ValueError(
"HFSandboxProvider requires a Hugging Face token. Run `hf auth login`."
)

self._job = self._api.run_job(
image=image,
command=[_SERVER_COMMAND],
env=env_vars,
flavor=self.flavor,
timeout=_JOB_TIMEOUT,
expose=[_DEFAULT_PORT],
token=effective_token,
)
self._token = effective_token
target_url = self._wait_for_job_url()
self._proxy = _LocalAuthProxy(target_url=target_url, token=effective_token)
return self._proxy.start()

def _wait_for_job_url(self) -> str:
deadline = time.time() + _STARTUP_TIMEOUT_S
target_url = _job_port_url(self._job, _DEFAULT_PORT)
while target_url is None and time.time() < deadline:
time.sleep(0.5)
self._job = self._api.inspect_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
target_url = _job_port_url(self._job, _DEFAULT_PORT)
Comment on lines +225 to +232

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fail fast when the job terminates before exposing its port.

This loop only watches expose_urls, so a job that goes to a terminal state (bad image, server command crashes) waits the full timeout and then raises a generic message with no signal about what actually failed. Detecting a terminal stage lets a broken env fail fast and identifiably:

Suggested change
while target_url is None and time.time() < deadline:
time.sleep(0.5)
self._job = self._api.inspect_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
target_url = _job_port_url(self._job, _DEFAULT_PORT)
while target_url is None and time.time() < deadline:
time.sleep(0.5)
self._job = self._api.inspect_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
stage = getattr(self._job.status, "stage", None)
if stage in ("ERROR", "COMPLETED", "DELETED"):
raise RuntimeError(
f"HF job {self._job.id} terminated early (stage={stage}) "
f"before exposing port {_DEFAULT_PORT}"
)
target_url = _job_port_url(self._job, _DEFAULT_PORT)

(Worth confirming the exact terminal stage names against the huggingface_hub version you target.)

if target_url is None:
raise RuntimeError(
f"HF job did not expose port {_DEFAULT_PORT} within "
f"{_STARTUP_TIMEOUT_S:.1f}s"
)
return target_url

def stop_container(self) -> None:
    """Stop the local proxy and cancel the remote job, clearing all state.

    The provider's ``_proxy``, ``_job`` and ``_token`` references are cleared
    *before* any teardown call is made, so the provider is left clean even
    if stopping the proxy or cancelling the job raises, and calling this
    method again afterwards is a no-op. The job cancellation is attempted
    even when stopping the proxy fails.

    Raises:
        Exception: Whatever ``proxy.stop()`` or ``cancel_job`` raises is
            propagated to the caller after state has been cleared.
    """
    # Detach state first so a failing teardown call cannot leave it dirty.
    proxy, self._proxy = self._proxy, None
    job, self._job = self._job, None
    token, self._token = self._token, None

    try:
        if proxy is not None:
            proxy.stop()
    finally:
        # Always try to cancel the job, even if the proxy failed to stop.
        if job is not None:
            self._api.cancel_job(
                job_id=job.id,
                namespace=job.owner.name,
                token=token,
            )
Comment on lines +244 to +251

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make teardown idempotent and non-throwing.

self._job and self._token are reset after cancel_job. If cancel_job raises (network blip), the state is left dirty, and since close() runs on __exit__, that exception masks the original error from the with block. This matters at scale: training loops oversample and cancel unfinished rollouts aggressively, so teardown runs often and on the hot path. A teardown that can leak a (billable) job or throw will accumulate orphans over a long run.

Suggested change
if self._job is not None:
self._api.cancel_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
self._job = None
self._token = None
if self._job is not None:
try:
self._api.cancel_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
finally:
self._job = None
self._token = None


def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
    """Poll ``{base_url}/health`` until it returns 200 or time runs out.

    Request-level failures while polling (connection refused, read timeout,
    and other ``requests`` errors) are treated as "not ready yet" and
    retried, so the full ``timeout_s`` budget is honored instead of
    aborting on the first transient error.

    Args:
        base_url: Base URL of the server to probe.
        timeout_s: Maximum number of seconds to keep polling.

    Raises:
        TimeoutError: If no 200 response is seen within ``timeout_s``. The
            last request error, if any, is attached as the cause.
    """
    deadline = time.time() + timeout_s
    health_url = f"{base_url}/health"
    last_error: Exception | None = None
    while time.time() < deadline:
        try:
            response = requests.get(health_url, timeout=5.0)
        except requests.exceptions.RequestException as exc:
            # Server not reachable yet; remember why and keep polling.
            last_error = exc
        else:
            if response.status_code == 200:
                return
        time.sleep(1.0)
    raise TimeoutError(
        f"HF sandbox job at {base_url} did not become ready within {timeout_s}s"
    ) from last_error
Comment on lines +253 to +263

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait_for_ready aborts on the first transient error instead of honoring timeout_s.

requests.get(..., timeout=5.0) raises on a slow or refused response, and with no try/except the exception escapes the loop and kills the call long before the timeout_s budget is reached. I reproduced it by pointing wait_for_ready at a closed port: current code dies in 0.00s with ConnectionError, whereas with the try/except it retries and dies in ~8.0s with a clean TimeoutError.

Suggested change
def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
deadline = time.time() + timeout_s
health_url = f"{base_url}/health"
while time.time() < deadline:
response = requests.get(health_url, timeout=5.0)
if response.status_code == 200:
return
time.sleep(1.0)
raise TimeoutError(
f"HF sandbox job at {base_url} did not become ready within {timeout_s}s"
)
def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
deadline = time.time() + timeout_s
health_url = f"{base_url}/health"
while time.time() < deadline:
try:
response = requests.get(health_url, timeout=5.0)
if response.status_code == 200:
return
except requests.exceptions.RequestException:
pass
time.sleep(1.0)
raise TimeoutError(
f"HF sandbox job at {base_url} did not become ready within {timeout_s}s"
)

Good candidate for the first unit test (mock requests.get to raise, assert it still respects timeout_s).


def close(self) -> None:
    """Release provider resources by delegating to ``stop_container``."""
    teardown = self.stop_container
    teardown()


__all__ = ["HFSandboxProvider"]
Loading