Skip to content

Add HF sandbox provider#841

Open
burtenshaw wants to merge 5 commits into
huggingface:mainfrom
burtenshaw:ben/hf-sandbox-provider
Open

Add HF sandbox provider#841
burtenshaw wants to merge 5 commits into
huggingface:mainfrom
burtenshaw:ben/hf-sandbox-provider

Conversation

@burtenshaw

Copy link
Copy Markdown
Collaborator

This PR adds a minimal Hugging Face sandbox-backed provider for OpenEnv environment servers.

return url


class _LocalAuthProxy:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wauplin This class is a by product of hf auth and the url construction. I'm kinda ok with it here, but just wondered if there was a simpler way to deal with it.

@burtenshaw burtenshaw requested a review from adithya-s-k June 24, 2026 10:15
@burtenshaw burtenshaw marked this pull request as ready for review June 24, 2026 10:15
@burtenshaw

Copy link
Copy Markdown
Collaborator Author

HF sandbox benchmark

Run id: 20260624T115124Z
Namespace: burtenshaw
Image/flavor: python:3.12 / cpu-basic
Workload: each job allocates and touches the requested memory, then holds it while looping/sleeping for the requested duration.
Wall/throughput use the conservative group wall, falling back to HF backend total_secs when it is larger than the local timer.

Throughput by workload

workload concurrency success wall scheduling p50/p95 running p50/p95 total p50/p95 throughput avg peak RSS
short (10s, 128MB) 1 1/1 17.1s 4.0s/4.0s 11.0s/11.0s 15.0s/15.0s 3.50 jobs/min 142.1 MB
short (10s, 128MB) 2 2/2 17.0s 4.0s/4.0s 11.0s/11.0s 16.0s/16.0s 7.05 jobs/min 142.2 MB
short (10s, 128MB) 4 4/4 17.0s 4.0s/5.0s 11.0s/11.0s 15.5s/16.0s 14.12 jobs/min 142.1 MB
medium (45s, 512MB) 1 1/1 53.1s 5.0s/5.0s 46.0s/46.0s 51.0s/51.0s 1.13 jobs/min 527.7 MB
medium (45s, 512MB) 2 2/2 53.1s 4.5s/5.0s 46.0s/46.0s 50.5s/51.0s 2.26 jobs/min 527.5 MB
medium (45s, 512MB) 4 4/4 53.3s 4.0s/5.0s 46.0s/46.0s 51.0s/51.0s 4.51 jobs/min 527.7 MB
long (120s, 1024MB) 1 1/1 126.0s 4.0s/4.0s 121.0s/121.0s 126.0s/126.0s 0.48 jobs/min 1041.5 MB
long (120s, 1024MB) 2 2/2 128.8s 4.5s/5.0s 121.0s/121.0s 126.5s/127.0s 0.93 jobs/min 1041.6 MB
long (120s, 1024MB) 4 4/4 136.0s 8.0s/11.0s 121.0s/121.0s 129.5s/133.0s 1.77 jobs/min 1041.6 MB

Stage timing

workload concurrency first running p50 all done alloc mean
short (10s, 128MB) 1 5.3s 17.1s 0.1s
short (10s, 128MB) 2 6.3s 17.0s 0.1s
short (10s, 128MB) 4 7.4s 17.0s 0.1s
medium (45s, 512MB) 1 7.3s 53.1s 0.2s
medium (45s, 512MB) 2 6.2s 53.1s 0.2s
medium (45s, 512MB) 4 7.5s 53.3s 0.2s
long (120s, 1024MB) 1 7.3s 126.0s 0.5s
long (120s, 1024MB) 2 6.3s 128.8s 0.4s
long (120s, 1024MB) 4 9.9s 136.0s 0.4s

All benchmark jobs completed successfully.

@sergiopaniego sergiopaniego left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the first iteration!! tested locally and works nicely. looking forward to this integration :) Some AI-assisted review below with some ideas about when this is scaled.


I left inline comments (with suggested changes) on the specific spots. The framing I'd suggest: the correctness items and repo-convention items should land in this PR, while the scale-hardening items below are fine as documented limitations plus follow-up issues, as long as the design doesn't close the door on them.

Why the correctness items matter here specifically: this provider will be used to run many environments in parallel for RL training, where a single env that fails silently is worse than one that fails loudly. It either wastes wall-clock (the whole batch waits on a straggler) or feeds garbage observations and rewards into the policy.

Scale hardening (follow-up OK, but worth flagging)

  • Configurable startup timeouts. The 120s budgets are hardcoded. Heavy images and GPU flavors can take much longer to schedule, so these should be __init__ params.
  • Retry/backoff on run_job. At high concurrency the Jobs API will rate-limit or return transient init errors. A bounded retry would avoid losing envs to a single blip.
  • Reuse a shared httpx.AsyncClient in the proxy. A new client (and connection pool) is created per request. With many tool-call steps across many envs this is measurable hot-path overhead.
  • Per-env proxy footprint. Each provider instance starts its own uvicorn server, thread, and event loop. At high env concurrency that's a lot of local servers on the trainer host. Worth documenting as a known limit, or considering a single shared proxy that multiplexes upstreams.
  • Observability. The provider logs nothing. With many concurrent envs, structured logging of job id, stage, and failure cause is what makes a failing env debuggable.

Tests & docs

  • No tests yet. The pure helpers (_job_port_url, _to_ws_url, _find_available_port, hop-by-hop filtering) plus wait_for_ready are all testable with no network, matching the existing provider test suite under tests/test_core/.
  • Docstrings. Public methods need the HF doc-builder docstring format used by the other providers.

Comment on lines +253 to +263
def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
deadline = time.time() + timeout_s
health_url = f"{base_url}/health"
while time.time() < deadline:
response = requests.get(health_url, timeout=5.0)
if response.status_code == 200:
return
time.sleep(1.0)
raise TimeoutError(
f"HF sandbox job at {base_url} did not become ready within {timeout_s}s"
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wait_for_ready aborts on the first transient error instead of honoring timeout_s.

requests.get(..., timeout=5.0) raises on a slow or refused response, and with no try/except the exception escapes the loop and kills the call long before the timeout_s budget is reached. I reproduced it by pointing wait_for_ready at a closed port: current code dies in 0.00s with ConnectionError, whereas with the try/except it retries and dies in ~8.0s with a clean TimeoutError.

Suggested change
def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
deadline = time.time() + timeout_s
health_url = f"{base_url}/health"
while time.time() < deadline:
response = requests.get(health_url, timeout=5.0)
if response.status_code == 200:
return
time.sleep(1.0)
raise TimeoutError(
f"HF sandbox job at {base_url} did not become ready within {timeout_s}s"
)
def wait_for_ready(self, base_url: str, timeout_s: float = 120.0) -> None:
deadline = time.time() + timeout_s
health_url = f"{base_url}/health"
while time.time() < deadline:
try:
response = requests.get(health_url, timeout=5.0)
if response.status_code == 200:
return
except requests.exceptions.RequestException:
pass
time.sleep(1.0)
raise TimeoutError(
f"HF sandbox job at {base_url} did not become ready within {timeout_s}s"
)

Good candidate for the first unit test (mock requests.get to raise, assert it still respects timeout_s).

Comment on lines +96 to +103
async with httpx.AsyncClient(follow_redirects=True) as client:
upstream = await client.request(
request.method,
target,
content=await request.body(),
headers=headers,
timeout=60.0,
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Distinguish "env died" from "env returned an error".

When the upstream job dies mid-rollout, the httpx transport error propagates and FastAPI returns a generic 500, indistinguishable from a legitimate application-level error from the env server itself. For a training loop these need different handling (restart the env vs. record a step error). Returning a 502 makes infra failures explicit and restartable:

Suggested change
async with httpx.AsyncClient(follow_redirects=True) as client:
upstream = await client.request(
request.method,
target,
content=await request.body(),
headers=headers,
timeout=60.0,
)
try:
async with httpx.AsyncClient(follow_redirects=True) as client:
upstream = await client.request(
request.method,
target,
content=await request.body(),
headers=headers,
timeout=60.0,
)
except httpx.HTTPError:
return Response(
content=b"upstream HF job unreachable",
status_code=502,
)

Comment on lines +225 to +232
while target_url is None and time.time() < deadline:
time.sleep(0.5)
self._job = self._api.inspect_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
target_url = _job_port_url(self._job, _DEFAULT_PORT)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fail fast when the job terminates before exposing its port.

This loop only watches expose_urls, so a job that goes to a terminal state (bad image, server command crashes) waits the full timeout and then raises a generic message with no signal about what actually failed. Detecting a terminal stage lets a broken env fail fast and identifiably:

Suggested change
while target_url is None and time.time() < deadline:
time.sleep(0.5)
self._job = self._api.inspect_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
target_url = _job_port_url(self._job, _DEFAULT_PORT)
while target_url is None and time.time() < deadline:
time.sleep(0.5)
self._job = self._api.inspect_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
stage = getattr(self._job.status, "stage", None)
if stage in ("ERROR", "COMPLETED", "DELETED"):
raise RuntimeError(
f"HF job {self._job.id} terminated early (stage={stage}) "
f"before exposing port {_DEFAULT_PORT}"
)
target_url = _job_port_url(self._job, _DEFAULT_PORT)

(Worth confirming the exact terminal stage names against the huggingface_hub version you target.)

Comment on lines +244 to +251
if self._job is not None:
self._api.cancel_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
self._job = None
self._token = None

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make teardown idempotent and non-throwing.

self._job and self._token are reset after cancel_job. If cancel_job raises (network blip), the state is left dirty, and since close() runs on __exit__, that exception masks the original error from the with block. This matters at scale: training loops oversample and cancel unfinished rollouts aggressively, so teardown runs often and on the hot path. A teardown that can leak a (billable) job or throw will accumulate orphans over a long run.

Suggested change
if self._job is not None:
self._api.cancel_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
self._job = None
self._token = None
if self._job is not None:
try:
self._api.cancel_job(
job_id=self._job.id,
namespace=self._job.owner.name,
token=self._token,
)
finally:
self._job = None
self._token = None

Comment on lines +188 to +193
def start_container(
self,
image: str,
port: int | None = None,
env_vars: dict[str, str] | None = None,
) -> str:

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

start_container drops **kwargs.

The abstract ContainerProvider.start_container declares **kwargs: Any, and EnvClient calls provider.start_container(image, **kwargs) (see env_client.py:296). Without it, the generic path raises TypeError for any forwarded kwarg. Any is already imported:

Suggested change
def start_container(
self,
image: str,
port: int | None = None,
env_vars: dict[str, str] | None = None,
) -> str:
def start_container(
self,
image: str,
port: int | None = None,
env_vars: dict[str, str] | None = None,
**kwargs: Any,
) -> str:

@@ -0,0 +1,269 @@
"""Hugging Face-backed provider for OpenEnv environment servers."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing license header. Every other module in the repo carries the BSD header. Part of the repo convention, should be added.

Suggested change
"""Hugging Face-backed provider for OpenEnv environment servers."""
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Hugging Face-backed provider for OpenEnv environment servers."""

Comment on lines +1 to +2
#!/usr/bin/env python3
"""Smoke-check a real OpenEnv server through the HF sandbox provider."""

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing license header. Same as the provider module (header goes after the shebang):

Suggested change
#!/usr/bin/env python3
"""Smoke-check a real OpenEnv server through the HF sandbox provider."""
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Smoke-check a real OpenEnv server through the HF sandbox provider."""

Comment on lines +58 to +63
def _to_ws_url(url: str) -> str:
if url.startswith("https://"):
return "wss://" + url[len("https://") :]
if url.startswith("http://"):
return "ws://" + url[len("http://") :]
return url

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enforce secure transport on the upstream URL. The sibling cloud provider rejects non-HTTPS URLs (the token is a bearer secret, and EnvClient derives its WebSocket URL from this base). _to_ws_url silently downgrades to ws://, and the proxy injects the Bearer token on it. Worth asserting target_url starts with https:// in _wait_for_job_url so the invariant is explicit. (Left as a comment rather than a suggestion since the guard belongs in _wait_for_job_url, not here.)

Comment on lines +158 to +160
async def _client_to_upstream(self, websocket: WebSocket, upstream: Any) -> None:
async for message in websocket.iter_text():
await upstream.send(message)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted, not a blocker: the relay handles bytes and text upstream-to-client but only text client-to-upstream (iter_text()). Fine in practice since EnvClient only ever sends JSON text, but a short comment noting the assumption would help future readers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants