Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 138 additions & 8 deletions strands_robots/tools/gr00t_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,90 @@
"""

import os
import re
import socket
import subprocess
import time
from typing import Any

from strands import tool

# ─────────────────────────────────────────────────────────────────────
# Input validation helpers
# ─────────────────────────────────────────────────────────────────────

# Characters that must never appear in values interpolated into commands.
# Covers shell statement separators, pipes, substitution, globbing braces,
# redirection, quoting, escapes, newlines, and NUL.
_SHELL_META = re.compile(r"[;&|`$(){}\[\]!<>\\'\"\n\r\x00]")

# Strict patterns for enumerable parameters.
# NOTE(review): these are anchored with ^...$; with re.match(), "$" also
# matches just before a trailing newline, so callers should prefer
# re.fullmatch() to reject values like "gr1\n".
_DATA_CONFIG_RE = re.compile(r"^[a-z][a-z0-9_]{0,63}$")
_EMBODIMENT_TAG_RE = re.compile(r"^[a-z][a-z0-9_]{0,31}$")
# Docker container names: alphanumeric first char, then [a-zA-Z0-9._-].
_CONTAINER_NAME_RE = re.compile(r"^[a-zA-Z0-9][a-zA-Z0-9._-]{0,127}$")

# Allowlists for TensorRT dtype parameters.
_VALID_VIT_DTYPES = {"fp16", "fp8"}
_VALID_LLM_DTYPES = {"fp16", "nvfp4", "fp8"}
_VALID_DIT_DTYPES = {"fp16", "fp8"}


def _validate_path(value: str, label: str) -> None:
"""Reject paths containing shell metacharacters, null bytes, or traversal sequences."""
if "\x00" in value:
raise ValueError(f"{label} must not contain null bytes")
if ".." in value.split("/"):
raise ValueError(f"{label} must not contain '..' path traversal components")
if _SHELL_META.search(value):
raise ValueError(f"{label} contains disallowed characters: {value!r}")


def validate_inputs(
    *,
    data_config: str,
    embodiment_tag: str,
    port: int,
    vit_dtype: str,
    llm_dtype: str,
    dit_dtype: str,
    checkpoint_path: str | None = None,
    trt_engine_path: str = "gr00t_engine",
    container_name: str | None = None,
) -> None:
    """Validate all user-supplied parameters in one place.

    Raises ValueError for any invalid input. This centralises validation so
    that the main tool function stays focused on orchestration and each
    check is independently testable via this single entry-point.
    """
    # Enumerable string parameters.
    # fullmatch(), not match(): with match(), a "$"-anchored pattern still
    # accepts a trailing newline (e.g. "gr1\n"), which would slip a newline
    # past validation into values later interpolated into shell commands.
    if not _DATA_CONFIG_RE.fullmatch(data_config):
        raise ValueError(
            f"data_config must be lowercase alphanumeric/underscore (got {data_config!r}). "
            f"See the tool docstring for the full list of accepted configs."
        )
    if not _EMBODIMENT_TAG_RE.fullmatch(embodiment_tag):
        raise ValueError(f"embodiment_tag must be lowercase alphanumeric/underscore (got {embodiment_tag!r})")

    # Docker container name (only validated when explicitly supplied).
    if container_name is not None and not _CONTAINER_NAME_RE.fullmatch(container_name):
        raise ValueError(f"container_name must match Docker naming rules (got {container_name!r})")

    # Filesystem paths — reject shell metacharacters and traversal
    if checkpoint_path is not None:
        _validate_path(checkpoint_path, "checkpoint_path")
    _validate_path(trt_engine_path, "trt_engine_path")

    # TensorRT dtype allowlists
    if vit_dtype not in _VALID_VIT_DTYPES:
        raise ValueError(f"vit_dtype must be one of {_VALID_VIT_DTYPES}, got {vit_dtype!r}")
    if llm_dtype not in _VALID_LLM_DTYPES:
        raise ValueError(f"llm_dtype must be one of {_VALID_LLM_DTYPES}, got {llm_dtype!r}")
    if dit_dtype not in _VALID_DIT_DTYPES:
        raise ValueError(f"dit_dtype must be one of {_VALID_DIT_DTYPES}, got {dit_dtype!r}")

    # Port range (1-65535; 0 is not a bindable service port here)
    if not (1 <= port <= 65535):
        raise ValueError(f"port must be between 1 and 65535, got {port}")


@tool
def gr00t_inference(
Expand All @@ -24,7 +101,7 @@ def gr00t_inference(
data_config: str = "fourier_gr1_arms_only",
embodiment_tag: str = "gr1",
denoising_steps: int = 4,
host: str = "0.0.0.0",
host: str = "127.0.0.1",
container_name: str | None = None,
timeout: int = 60,
use_tensorrt: bool = False,
Expand Down Expand Up @@ -112,7 +189,7 @@ def gr00t_inference(
data_config: Embodiment data config name (see Data configs above).
embodiment_tag: Embodiment tag for the model (e.g., ``gr1``, ``so100``).
denoising_steps: Number of denoising steps for action generation (default: 4).
host: Host address to bind the service to (default: ``0.0.0.0``).
host: Host address to bind the service to (default: ``127.0.0.1``).
container_name: Specific Docker container name. Auto-detected if omitted.
timeout: Seconds to wait for service startup (default: 60).
use_tensorrt: Enable TensorRT acceleration (default: False).
Expand Down Expand Up @@ -180,6 +257,19 @@ def gr00t_inference(
if api_token is None:
api_token = os.environ.get("GROOT_API_TOKEN")

# ── Validate all inputs in one call ───────────────────────────────
validate_inputs(
data_config=data_config,
embodiment_tag=embodiment_tag,
port=port,
vit_dtype=vit_dtype,
llm_dtype=llm_dtype,
dit_dtype=dit_dtype,
checkpoint_path=checkpoint_path,
trt_engine_path=trt_engine_path,
container_name=container_name,
)

if action == "find_containers":
return _find_gr00t_containers()
elif action == "list":
Expand Down Expand Up @@ -314,6 +404,27 @@ def _check_service_status(port: int) -> dict[str, Any]:
}


def _is_gr00t_process(container_name: str, pid: str) -> bool:
"""Verify that a PID inside a container belongs to a GR00T inference process.

This prevents accidentally killing unrelated processes that happen to
be listening on the same port.
"""
try:
result = subprocess.run(
["docker", "exec", container_name, "cat", f"/proc/{pid}/cmdline"],
capture_output=True,
text=True,
check=False,
)
if result.returncode == 0:
cmdline = result.stdout.replace("\x00", " ")
return "inference_service" in cmdline or "gr00t" in cmdline.lower()
except Exception:
pass
return False


def _stop_service(port: int) -> dict[str, Any]:
"""Stop GR00T inference service running on specific port."""
try:
Expand All @@ -334,13 +445,21 @@ def _stop_service(port: int) -> dict[str, Any]:
if result.returncode == 0 and result.stdout.strip():
pids = result.stdout.strip().split("\n")
for pid in pids:
if pid:
pid = pid.strip()
if pid and _is_gr00t_process(container_name, pid):
subprocess.run(["docker", "exec", container_name, "kill", "-TERM", pid], check=True)

time.sleep(2)

result = subprocess.run(
["docker", "exec", container_name, "pgrep", "-f", f"inference_service.py.*--port {port}"],
[
"docker",
"exec",
container_name,
"pgrep",
"-f",
f"inference_service.py.*--port {port}",
],
capture_output=True,
text=True,
check=False,
Expand All @@ -349,7 +468,8 @@ def _stop_service(port: int) -> dict[str, Any]:
if result.returncode == 0 and result.stdout.strip():
pids = result.stdout.strip().split("\n")
for pid in pids:
if pid:
pid = pid.strip()
if pid and _is_gr00t_process(container_name, pid):
subprocess.run(["docker", "exec", container_name, "kill", "-KILL", pid], check=True)

return {
Expand All @@ -362,22 +482,32 @@ def _stop_service(port: int) -> dict[str, Any]:
except subprocess.CalledProcessError:
continue

# Fallback: try host system
result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True)
# Fallback: try host system — only kill processes that match inference_service
result = subprocess.run(
["pgrep", "-f", f"inference_service.py.*--port {port}"],
capture_output=True,
text=True,
)

if result.returncode == 0:
pids = result.stdout.strip().split("\n")
for pid in pids:
pid = pid.strip()
if pid:
subprocess.run(["kill", "-TERM", pid], check=True)

time.sleep(2)

result = subprocess.run(["lsof", "-t", f"-i:{port}"], capture_output=True, text=True)
result = subprocess.run(
["pgrep", "-f", f"inference_service.py.*--port {port}"],
capture_output=True,
text=True,
)

if result.returncode == 0:
pids = result.stdout.strip().split("\n")
for pid in pids:
pid = pid.strip()
if pid:
subprocess.run(["kill", "-KILL", pid], check=True)

Expand Down
Loading
Loading