From 53d985800a9e2029725daf0575e131bf3a5bb5fd Mon Sep 17 00:00:00 2001 From: Hui Kang Date: Tue, 23 Dec 2025 23:48:55 -0800 Subject: [PATCH 1/2] add socket code for policy --- docker/Dockerfile.gr00t_server | 40 +++++ docker/run_gr00t_server.sh | 131 ++++++++++++++ docker/setup/install_gr00t_deps.sh | 85 ++++++--- isaaclab_arena/evaluation/policy_runner.py | 7 + .../evaluation/policy_runner_cli.py | 3 +- isaaclab_arena/policy/policy_base.py | 31 +++- isaaclab_arena/remote_policy/__init__.py | 18 ++ .../remote_policy/message_serializer.py | 119 +++++++++++++ isaaclab_arena/remote_policy/policy_client.py | 93 ++++++++++ .../remote_policy/policy_registry.py | 79 +++++++++ isaaclab_arena/remote_policy/policy_server.py | 164 ++++++++++++++++++ .../remote_policy/remote_policy_config.py | 17 ++ .../remote_policy_server_runner.py | 62 +++++++ .../remote_policy/server_side_policy.py | 51 ++++++ isaaclab_arena_gr00t/gr00t_remote_policy.py | 95 ++++++++++ .../config/gr00t_closedloop_policy_config.py | 3 - .../policy/gr00t_closedloop_policy.py | 58 +++++-- 17 files changed, 1011 insertions(+), 45 deletions(-) create mode 100644 docker/Dockerfile.gr00t_server create mode 100644 docker/run_gr00t_server.sh create mode 100644 isaaclab_arena/remote_policy/__init__.py create mode 100644 isaaclab_arena/remote_policy/message_serializer.py create mode 100644 isaaclab_arena/remote_policy/policy_client.py create mode 100644 isaaclab_arena/remote_policy/policy_registry.py create mode 100644 isaaclab_arena/remote_policy/policy_server.py create mode 100644 isaaclab_arena/remote_policy/remote_policy_config.py create mode 100644 isaaclab_arena/remote_policy/remote_policy_server_runner.py create mode 100644 isaaclab_arena/remote_policy/server_side_policy.py create mode 100644 isaaclab_arena_gr00t/gr00t_remote_policy.py diff --git a/docker/Dockerfile.gr00t_server b/docker/Dockerfile.gr00t_server new file mode 100644 index 00000000..1f9cdcec --- /dev/null +++ b/docker/Dockerfile.gr00t_server @@ -0,0 +1,40 @@ +FROM nvcr.io/nvidia/pytorch:24.07-py3 + +ARG WORKDIR="/workspace" +ARG GROOT_DEPS_GROUP="base" +ENV WORKDIR=${WORKDIR} +ENV GROOT_DEPS_GROUP=${GROOT_DEPS_GROUP} +WORKDIR "${WORKDIR}" + +RUN apt-get update && apt-get install -y \ + git \ + git-lfs \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +RUN pip install --upgrade pip + +COPY ./submodules/Isaac-GR00T ${WORKDIR}/submodules/Isaac-GR00T + +COPY docker/setup/install_gr00t_deps.sh /tmp/install_gr00t_deps.sh +RUN chmod +x /tmp/install_gr00t_deps.sh && \ + /tmp/install_gr00t_deps.sh --server && \ + rm -f /tmp/install_gr00t_deps.sh + +RUN pip install -e ${WORKDIR}/submodules/Isaac-GR00T + +RUN pip uninstall -y \ + opencv-python opencv-python-headless \ + opencv-contrib-python opencv-contrib-python-headless \ + || true && \ + pip install --no-cache-dir "opencv-python-headless==4.8.0.74" + +COPY isaaclab_arena/remote_policy ${WORKDIR}/isaaclab_arena/remote_policy + +COPY isaaclab_arena_gr00t ${WORKDIR}/isaaclab_arena_gr00t + +RUN pip install --no-cache-dir pyzmq msgpack + +ENV PYTHONPATH=${WORKDIR} + +ENTRYPOINT ["python", "-u", "-m", "isaaclab_arena.remote_policy.remote_policy_server_runner"] diff --git a/docker/run_gr00t_server.sh b/docker/run_gr00t_server.sh new file mode 100644 index 00000000..f0a1f20b --- /dev/null +++ b/docker/run_gr00t_server.sh @@ -0,0 +1,131 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ------------------------- +# User-configurable defaults +# ------------------------- + +# Model directory on the host. +# By default, use $HOME/models, but this can be overridden +# by the MODELS_DIR environment variable or the -d / --models_dir flag. +MODELS_DIR="${MODELS_DIR:-$HOME/models}" + +# Other parameters (can also be overridden via environment variables) +HOST="${HOST:-0.0.0.0}" +PORT="${PORT:-5555}" +API_TOKEN="${API_TOKEN:-API_TOKEN_123}" +TIMEOUT_MS="${TIMEOUT_MS:-5000}" +POLICY_TYPE="${POLICY_TYPE:-gr00t_closedloop}" +POLICY_CONFIG_YAML_PATH="${POLICY_CONFIG_YAML_PATH:-/workspace/isaaclab_arena_gr00t/gr1_manip_gr00t_closedloop_config.yaml}" + +# ------------------------- +# Help message +# ------------------------- +usage() { + cat <> /etc/bash.bashrc +if [[ "$USE_SERVER_ENV" -eq 0 ]]; then + echo "Ensuring pytorch torchrun script is in PATH..." + echo "export PATH=/isaac-sim/kit/python/bin:\$PATH" >> /etc/bash.bashrc +fi echo "GR00T dependencies installation completed successfully" diff --git a/isaaclab_arena/evaluation/policy_runner.py b/isaaclab_arena/evaluation/policy_runner.py index 4abba57c..9f7cf990 100644 --- a/isaaclab_arena/evaluation/policy_runner.py +++ b/isaaclab_arena/evaluation/policy_runner.py @@ -127,6 +127,13 @@ def main(): if metrics is not None: print(f"Metrics: {metrics}") + # NOTE(huikang, 2025-12-30)Explicitly clean up the remote policy client / server. + # Do NOT rely on a __del__ destructor in policy for this, since destructors are + # triggered implicitly and their execution time (or even whether they run) + # is not guaranteed, which makes resource cleanup unreliable. + if policy.is_remote: + policy.shutdown_remote(kill_server=args_cli.remote_kill_on_exit) + # Close the environment. env.close() diff --git a/isaaclab_arena/evaluation/policy_runner_cli.py b/isaaclab_arena/evaluation/policy_runner_cli.py index 7fb4001e..219f9b19 100644 --- a/isaaclab_arena/evaluation/policy_runner_cli.py +++ b/isaaclab_arena/evaluation/policy_runner_cli.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse - +from pathlib import Path def add_policy_runner_arguments(parser: argparse.ArgumentParser) -> None: """Add policy runner specific arguments to the parser.""" @@ -20,3 +20,4 @@ def add_policy_runner_arguments(parser: argparse.ArgumentParser) -> None: default=100, help="Number of steps to run the policy for", ) + diff --git a/isaaclab_arena/policy/policy_base.py b/isaaclab_arena/policy/policy_base.py index d927c8d7..068d1b05 100644 --- a/isaaclab_arena/policy/policy_base.py +++ b/isaaclab_arena/policy/policy_base.py @@ -3,6 +3,7 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import argparse import gymnasium as gym import torch @@ -10,6 +11,14 @@ from gymnasium.spaces.dict import Dict as GymSpacesDict from typing import Any +from enum import Enum + +from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig +from isaaclab_arena.remote_policy.policy_client import PolicyClient + +class PolicyDeployment(Enum): + LOCAL = "local" + REMOTE = "remote" class PolicyBase(ABC): """ @@ -24,7 +33,11 @@ class PolicyBase(ABC): def __init__(self, config: Any): """ - Base class for policies. + Base class for policies with optional remote deployment. + + Args: + policy_deployment: "local" (default) or "remote". + remote_config: Required when policy_deployment == "remote". """ self.config = config @@ -97,3 +110,19 @@ def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentPars def from_args(args: argparse.Namespace) -> "PolicyBase": """Create a policy from the arguments.""" raise NotImplementedError("Function not implemented yet.") + def shutdown_remote(self, kill_server: bool = False) -> None: + """ + Clean up remote client, and optionally send 'kill' to stop the remote server. + + Args: + kill_server: If True, send a 'kill' RPC before closing the client. + """ + if not self.is_remote or self._policy_client is None: + return + if kill_server: + try: + self._policy_client.call_endpoint("kill", requires_input=False) + except Exception as exc: + print(f"[PolicyBase] Failed to send kill to remote server: {exc}") + self._policy_client.close() + self._policy_client = None diff --git a/isaaclab_arena/remote_policy/__init__.py b/isaaclab_arena/remote_policy/__init__.py new file mode 100644 index 00000000..befc257f --- /dev/null +++ b/isaaclab_arena/remote_policy/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from .remote_policy_config import RemotePolicyConfig +from .server_side_policy import ServerSidePolicy +from .message_serializer import MessageSerializer +from .policy_client import PolicyClient +from .policy_server import PolicyServer + +__all__ = [ + "RemotePolicyConfig", + "ServerSidePolicy", + "MessageSerializer", + "PolicyClient", + "PolicyServer", +] diff --git a/isaaclab_arena/remote_policy/message_serializer.py b/isaaclab_arena/remote_policy/message_serializer.py new file mode 100644 index 00000000..796161dd --- /dev/null +++ b/isaaclab_arena/remote_policy/message_serializer.py @@ -0,0 +1,119 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import io +from dataclasses import asdict, is_dataclass +from enum import Enum +from typing import Any, Dict + +import msgpack +import numpy as np + + +class MessageSerializer: + """Msgpack-based serializer for dict-based policy messages. + + Supports: + - standard Python types, + - dataclasses (via to_json_serializable), + - numpy.ndarray (tagged as __ndarray_class__), + - generic binary blobs (tagged as __blob_class__). + """ + + @staticmethod + def to_bytes(data: Any) -> bytes: + """Serialize a Python object to bytes using msgpack.""" + return msgpack.packb(data, default=MessageSerializer._encode_custom) + + @staticmethod + def from_bytes(data: bytes) -> Any: + """Deserialize bytes into Python objects, decoding custom tags.""" + return msgpack.unpackb(data, object_hook=MessageSerializer._decode_custom) + + # ------------------------------------------------------------------ # + # Custom encode / decode + # ------------------------------------------------------------------ # + + @staticmethod + def _decode_custom(obj: Any) -> Any: + """Decode tagged structures created in _encode_custom.""" + if not isinstance(obj, dict): + return obj + + # numpy array + if "__ndarray_class__" in obj: + return np.load(io.BytesIO(obj["as_npy"]), allow_pickle=False) + + # generic binary blob + if "__blob_class__" in obj: + return { + "mime": obj.get("mime"), + "data": obj.get("as_bytes"), + } + + # other tagged types can be added here + return obj + + @staticmethod + def _encode_custom(obj: Any) -> Any: + """Encode special Python objects into msgpack-friendly structures.""" + + # numpy array -> npy bytes + if isinstance(obj, np.ndarray): + output = io.BytesIO() + np.save(output, obj, allow_pickle=False) + return {"__ndarray_class__": True, "as_npy": output.getvalue()} + + # generic binary blob: bytes / bytearray + if isinstance(obj, (bytes, bytearray)): + return { + "__blob_class__": True, + "mime": None, + "as_bytes": bytes(obj), + } + + # optional: custom Image/Frame types with to_bytes() and mime attribute + if hasattr(obj, "to_bytes") and hasattr(obj, "mime"): + return { + "__blob_class__": True, + "mime": getattr(obj, "mime"), + "as_bytes": obj.to_bytes(), + } + + # fall back to JSON-serializable representation + return to_json_serializable(obj) + + +def to_json_serializable(obj: Any) -> Any: + """Recursively convert dataclasses and numpy arrays to JSON-serializable format. + + This is useful when encoding configuration objects or metadata. + """ + if is_dataclass(obj) and not isinstance(obj, type): + return to_json_serializable(asdict(obj)) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.bool_): + return bool(obj) + elif isinstance(obj, dict): + return {key: to_json_serializable(value) for key, value in obj.items()} + elif isinstance(obj, (list, tuple)): + return [to_json_serializable(item) for item in obj] + elif isinstance(obj, set): + return [to_json_serializable(item) for item in obj] + elif isinstance(obj, (str, int, float, bool, type(None))): + return obj + elif isinstance(obj, Enum): + return obj.name + else: + # Fallback: convert to string + return str(obj) + diff --git a/isaaclab_arena/remote_policy/policy_client.py b/isaaclab_arena/remote_policy/policy_client.py new file mode 100644 index 00000000..970e8c3d --- /dev/null +++ b/isaaclab_arena/remote_policy/policy_client.py @@ -0,0 +1,93 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import zmq + +from .message_serializer import MessageSerializer +from .remote_policy_config import RemotePolicyConfig + +class PolicyClient: + """Synchronous client for talking to a PolicyServer over ZeroMQ.""" + + def __init__(self, config: RemotePolicyConfig) -> None: + self._config = config + self._context = zmq.Context() + self._socket = self._context.socket(zmq.REQ) + self._socket.setsockopt(zmq.RCVTIMEO, self._config.timeout_ms) + self._socket.connect(f"tcp://{self._config.host}:{self._config.port}") + + # ------------------------------------------------------------------ # + # Public API + # ------------------------------------------------------------------ # + + def ping(self) -> bool: + """Check if the server is reachable.""" + try: + self.call_endpoint("ping", requires_input=False) + return True + except Exception: + warnings.warn( + f"[PolicyClient] Failed to ping remote policy server at " + f"{self._config.host}:{self._config.port}: {exc}" + ) + return False + + def reset(self, env_ids=None, options: Optional[Dict[str, Any]] = None) -> Any: + """Reset remote policy state.""" + return self.call_endpoint( + endpoint="reset", + data={"env_ids": env_ids, "options": options}, + requires_input=True, + ) + + def kill(self) -> Any: + """Ask remote server to stop main loop.""" + return self.call_endpoint("kill", requires_input=False) + + def get_action( + self, + observation: Dict[str, Any], + ) -> Dict[str, Any]: + """Send policy_observations and get back policy action dict.""" + payload: Dict[str, Any] = {"observation": observation} + + resp = self.call_endpoint( + endpoint="get_action", + data=payload, + requires_input=True, + ) + return resp + + def call_endpoint( + self, + endpoint: str, + data: Optional[Dict[str, Any]] = None, + requires_input: bool = True, + ) -> Any: + """Generic RPC helper.""" + request: Dict[str, Any] = {"endpoint": endpoint} + if requires_input: + request["data"] = data or {} + if self._config.api_token: + request["api_token"] = self._config.api_token + + self._socket.send(MessageSerializer.to_bytes(request)) + message = self._socket.recv() + response = MessageSerializer.from_bytes(message) + + if isinstance(response, dict) and "error" in response: + raise RuntimeError(f"Server error: {response['error']}") + return response + + def close(self) -> None: + """Close the underlying ZeroMQ socket and context.""" + self._socket.close() + self._context.term() + diff --git a/isaaclab_arena/remote_policy/policy_registry.py b/isaaclab_arena/remote_policy/policy_registry.py new file mode 100644 index 00000000..636cd1e3 --- /dev/null +++ b/isaaclab_arena/remote_policy/policy_registry.py @@ -0,0 +1,79 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, List, Type + +from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy + + +@dataclass(frozen=True) +class PolicyEntry: + policy_type: str + entry_point: str # "module_path:ClassName" + + +class PolicyRegistry: + def __init__(self) -> None: + self._entries: Dict[str, PolicyEntry] = {} + + def register(self, policy_type: str, entry_point: str) -> None: + if policy_type in self._entries: + raise ValueError(f"Policy type {policy_type!r} already registered") + if ":" not in entry_point: + raise ValueError( + f"Invalid entry_point {entry_point!r} for policy_type={policy_type!r} " + "(expected 'module_path:ClassName')" + ) + self._entries[policy_type] = PolicyEntry(policy_type, entry_point) + + def available_policy_types(self) -> List[str]: + return sorted(self._entries.keys()) + + def resolve_class(self, policy_type: str) -> Type[ServerSidePolicy]: + if policy_type not in self._entries: + raise ValueError( + f"Unknown policy_type={policy_type!r}. " + f"Available options: {self.available_policy_types()}" + ) + + entry = self._entries[policy_type] + module_path, class_name = entry.entry_point.split(":", 1) + + try: + module = __import__(module_path, fromlist=[class_name]) + except ImportError as exc: + raise ImportError( + f"Failed to import module '{module_path}' for policy_type={policy_type!r}. " + "This usually means the corresponding policy package is not installed " + "in the current server environment." + ) from exc + + try: + cls = getattr(module, class_name) + except AttributeError as exc: + raise ImportError( + f"Module '{module_path}' does not define class '{class_name}' " + f"for policy_type={policy_type!r}." + ) from exc + + if not issubclass(cls, ServerSidePolicy): + raise TypeError( + f"Resolved class '{class_name}' from '{module_path}' is not a ServerSidePolicy " + f"subclass (policy_type={policy_type!r})." + ) + return cls + + +policy_registry = PolicyRegistry() + +# Built-in registrations +policy_registry.register( + "gr00t_closedloop", + "isaaclab_arena_gr00t.gr00t_remote_policy:Gr00tRemoteServerSidePolicy", +) + diff --git a/isaaclab_arena/remote_policy/policy_server.py b/isaaclab_arena/remote_policy/policy_server.py new file mode 100644 index 00000000..c2d81553 --- /dev/null +++ b/isaaclab_arena/remote_policy/policy_server.py @@ -0,0 +1,164 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Type + +import zmq + +from .model_policy import ModelPolicy +from .message_serializer import MessageSerializer + + +@dataclass +class EndpointHandler: + handler: Callable[..., Any] + requires_input: bool = True + + +class PolicyServer: + def __init__( + self, + policy: ModelPolicy, + host: str = "*", + port: int = 5555, + api_token: Optional[str] = None, + timeout_ms: int = 15000, + ) -> None: + self._policy = policy + self._running = True + self._context = zmq.Context() + self._socket = self._context.socket(zmq.REP) + self._socket.setsockopt(zmq.RCVTIMEO, timeout_ms) + bind_addr = f"tcp://{host}:{port}" + print(f"[PolicyServer] binding on {bind_addr}") + self._socket.bind(bind_addr) + self._api_token = api_token + self._serializer = MessageSerializer + + self._endpoints: Dict[str, EndpointHandler] = {} + self._register_default_endpoints() + + def _register_default_endpoints(self) -> None: + self.register_endpoint("ping", self._handle_ping, requires_input=False) + self.register_endpoint("kill", self._handle_kill, requires_input=False) + self.register_endpoint("get_action", self._handle_get_action, requires_input=True) + self.register_endpoint("reset", self._handle_reset, requires_input=True) + print(f"[PolicyServer] registered endpoints: {list(self._endpoints.keys())}") + + def register_endpoint( + self, + name: str, + handler: Callable[..., Any], + requires_input: bool = True, + ) -> None: + self._endpoints[name] = EndpointHandler(handler=handler, requires_input=requires_input) + + def _handle_ping(self) -> Dict[str, Any]: + print("[PolicyServer] handle ping") + return {"status": "ok"} + + def _handle_kill(self) -> Dict[str, Any]: + print("[PolicyServer] handle kill -> stopping") + self._running = False + return {"status": "stopping"} + + def _handle_get_action( + self, + observation: Dict[str, Any], + options: Optional[Dict[str, Any]] = None, + **_: Any, + ) -> Dict[str, Any]: + print("[PolicyServer] handle get_action") + print(f" observation keys: {list(observation.keys())}") + if options is not None: + print(f" options keys: {list(options.keys())}") + action, info = self._policy.get_action( + observation=observation, + options=options, + ) + return {"action": action, "info": info} + + def _handle_reset(self, env_ids=None, options=None, **_: Any) -> Dict[str, Any]: + print(f"[PolicyServer] handle reset: env_ids={env_ids}, options={options}") + if hasattr(self._policy, "reset"): + self._policy.reset(env_ids=env_ids, reset_options=options) + return {"status": "reset"} + + def _validate_token(self, request: Dict[str, Any]) -> bool: + if self._api_token is None: + return True + ok = request.get("api_token") == self._api_token + if not ok: + print("[PolicyServer] invalid api_token in request") + return ok + + def run(self) -> None: + addr = self._socket.getsockopt_string(zmq.LAST_ENDPOINT) + print(f"[PolicyServer] listening on {addr}, api_token={self._api_token!r}") + while self._running: + try: + raw = self._socket.recv() + print(f"[PolicyServer] received {len(raw)} bytes") + request = self._serializer.from_bytes(raw) + + if not isinstance(request, dict): + raise TypeError(f"Expected dict request, got {type(request)!r}") + + print(f"[PolicyServer] request keys: {list(request.keys())}") + + if not self._validate_token(request): + self._socket.send( + self._serializer.to_bytes({"error": "Unauthorized: invalid api_token"}) + ) + continue + + endpoint = request.get("endpoint", "get_action") + handler = self._endpoints.get(endpoint) + if handler is None: + raise ValueError(f"Unknown endpoint: {endpoint}") + print(f"[PolicyServer] dispatch endpoint='{endpoint}'") + + data = request.get("data", {}) or {} + if not isinstance(data, dict): + raise TypeError(f"Expected dict data, got {type(data)!r}") + + if handler.requires_input: + result = handler.handler(**data) + else: + result = handler.handler() + + resp_bytes = self._serializer.to_bytes(result) + print(f"[PolicyServer] sending response ({len(resp_bytes)} bytes)") + self._socket.send(resp_bytes) + except zmq.Again: + # timeout, loop again + continue + except Exception as exc: + import traceback + + print(f"[PolicyServer] Error: {exc}") + print(traceback.format_exc()) + self._socket.send(self._serializer.to_bytes({"error": str(exc)})) + + @staticmethod + def start( + policy: ModelPolicy, + host: str = "*", + port: int = 5555, + api_token: Optional[str] = None, + timeout_ms: int = 15000, + ) -> None: + server = PolicyServer( + policy=policy, + host=host, + port=port, + api_token=api_token, + timeout_ms=timeout_ms, + ) + server.run() + diff --git a/isaaclab_arena/remote_policy/remote_policy_config.py b/isaaclab_arena/remote_policy/remote_policy_config.py new file mode 100644 index 00000000..41911bec --- /dev/null +++ b/isaaclab_arena/remote_policy/remote_policy_config.py @@ -0,0 +1,17 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +@dataclass +class RemotePolicyConfig: + """Configuration for using a remote PolicyServer.""" + host: str + port: int + api_token: Optional[str] = None + timeout_ms: int = 15000 diff --git a/isaaclab_arena/remote_policy/remote_policy_server_runner.py b/isaaclab_arena/remote_policy/remote_policy_server_runner.py new file mode 100644 index 00000000..ddbfbb16 --- /dev/null +++ b/isaaclab_arena/remote_policy/remote_policy_server_runner.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +from pathlib import Path +from typing import Type + +from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy +from isaaclab_arena.remote_policy.policy_server import PolicyServer +from isaaclab_arena.remote_policy.policy_registry import policy_registry + + +def resolve_policy_class(policy_type: str) -> Type[ServerSidePolicy]: + return policy_registry.resolve_class(policy_type) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser("IsaacLab Arena Remote Policy Server") + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=5555) + parser.add_argument("--api_token", type=str, default=None) + parser.add_argument("--timeout_ms", type=int, default=5000) + + parser.add_argument( + "--policy_type", + type=str, + required=True, + choices=policy_registry.available_policy_types(), + help="Which remote policy to run (e.g. 'gr00t_closedloop').", + ) + parser.add_argument( + "--policy_config_yaml_path", + type=str, + required=True, + help="Path to policy-specific config YAML.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + policy_cls = resolve_policy_class(args.policy_type) + policy = policy_cls(policy_config_yaml_path=Path(args.policy_config_yaml_path)) + + server = PolicyServer( + policy=policy, + host=args.host, + port=args.port, + api_token=args.api_token, + timeout_ms=args.timeout_ms, + ) + server.run() + + +if __name__ == "__main__": + main() + diff --git a/isaaclab_arena/remote_policy/server_side_policy.py b/isaaclab_arena/remote_policy/server_side_policy.py new file mode 100644 index 00000000..a1d0e966 --- /dev/null +++ b/isaaclab_arena/remote_policy/server_side_policy.py @@ -0,0 +1,51 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class ServerSidePolicy(ABC): + """Server-side policy interface. + + This interface is intentionally independent of IsaacLab-Arena. + The server only sees JSON-serializable observations and returns + JSON-serializable actions. + """ + + @abstractmethod + def get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Compute and return the next action based on current observation with validation. + + This is the main public interface. It validates the observation, calls + the internal _get_action(), and validates the resulting action. + + Args: + observation: Dictionary containing the current state/observation + options: Optional configuration dict for action computation + + Returns: + Tuple of (action, info): + - action: Dictionary containing the validated action + - info: Dictionary containing additional metadata + + Raises: + AssertionError/ValueError: If observation or action validation fails + """ + @abstractmethod + def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: + """Reset the policy to its initial state. + + Args: + options: Dictionary containing the options for the reset + + Returns: + Dictionary containing the info after resetting the policy + """ + pass diff --git a/isaaclab_arena_gr00t/gr00t_remote_policy.py b/isaaclab_arena_gr00t/gr00t_remote_policy.py new file mode 100644 index 00000000..f871163c --- /dev/null +++ b/isaaclab_arena_gr00t/gr00t_remote_policy.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from pathlib import Path +from typing import Any, Dict, Tuple + +from gr00t.experiment.data_config import DATA_CONFIG_MAP, load_data_config +from gr00t.model.policy import Gr00tPolicy + +from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy +from isaaclab_arena_gr00t.policy_config import Gr00tClosedloopPolicyConfig +from isaaclab_arena_gr00t.data_utils.io_utils import create_config_from_yaml + + +class Gr00tRemoteServerSidePolicy(ServerSidePolicy): + """Server-side wrapper around Gr00tPolicy.""" + + def __init__(self, policy_config_yaml_path: Path) -> None: + print(f"[Gr00tRemoteServerSidePolicy] loading config from: {policy_config_yaml_path}") + self._cfg = create_config_from_yaml(policy_config_yaml_path, Gr00tClosedloopPolicyConfig) + print( + "[Gr00tRemoteServerSidePolicy] config:\n" + f" model_path = {self._cfg.model_path}\n" + f" embodiment_tag = {self._cfg.embodiment_tag}\n" + f" task_mode_name = {self._cfg.task_mode_name}\n" + f" data_config = {self._cfg.data_config}\n" + f" action_horizon = {self._cfg.action_horizon}\n" + f" action_chunk_len = {self._cfg.action_chunk_length}\n" + f" pov_cam_name_sim = {self._cfg.pov_cam_name_sim}\n" + f" policy_device = {self._cfg.policy_device}" + ) + self._policy = self._load_gr00t_policy() + print("[Gr00tRemoteServerSidePolicy] Gr00tPolicy loaded successfully") + + def _load_gr00t_policy(self) -> Gr00tPolicy: + print(f"[Gr00tRemoteServerSidePolicy] loading data_config={self._cfg.data_config}") + if self._cfg.data_config in DATA_CONFIG_MAP: + data_config = DATA_CONFIG_MAP[self._cfg.data_config] + elif self._cfg.data_config == "unitree_g1_sim_wbc": + data_config = load_data_config("isaaclab_arena_gr00t.data_config:UnitreeG1SimWBCDataConfig") + else: + raise ValueError(f"Invalid data config: {self._cfg.data_config}") + + modality_config = data_config.modality_config() + modality_transform = data_config.transform() + + model_path = Path(self._cfg.model_path) + if not model_path.exists(): + raise FileNotFoundError(f"Model path does not exist: {model_path}") + print(f"[Gr00tRemoteServerSidePolicy] loading checkpoint from: {model_path}") + + policy = Gr00tPolicy( + model_path=str(model_path), + modality_config=modality_config, + modality_transform=modality_transform, + embodiment_tag=self._cfg.embodiment_tag, + denoising_steps=self._cfg.denoising_steps, + device=self._cfg.policy_device, + ) + return policy + + # ------------------------------------------------------------------ # + # ServerSidePolicy interface + # ------------------------------------------------------------------ # + + def get_action( + self, + observation: Dict[str, Any], + options: Dict[str, Any] | None = None, + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + print("[Gr00tRemoteServerSidePolicy] get_action called") + print(f" observation keys: {list(observation.keys())}") + if options is not None: + print(f" options keys: {list(options.keys())}") + + result = self._policy.get_action(observation) + # Gr00tPolicy.get_action usually returns a dict; wrap it with empty info. + if isinstance(result, tuple) and len(result) == 2: + action, info = result + else: + action, info = result, {} + + print("[Gr00tRemoteServerSidePolicy] get_action done") + return action, info + + def reset(self, options: Dict[str, Any] | None = None) -> Dict[str, Any]: + print(f"[Gr00tRemoteServerSidePolicy] reset called: options={options}") + if hasattr(self._policy, "reset"): + self._policy.reset(options=options) + return {} + diff --git a/isaaclab_arena_gr00t/policy/config/gr00t_closedloop_policy_config.py b/isaaclab_arena_gr00t/policy/config/gr00t_closedloop_policy_config.py index 41f4da5b..0451cba6 100644 --- a/isaaclab_arena_gr00t/policy/config/gr00t_closedloop_policy_config.py +++ b/isaaclab_arena_gr00t/policy/config/gr00t_closedloop_policy_config.py @@ -5,10 +5,8 @@ from dataclasses import dataclass, field from pathlib import Path - from isaaclab_arena_gr00t.policy.config.task_mode import TaskMode - @dataclass class Gr00tClosedloopPolicyConfig: @@ -98,7 +96,6 @@ def __post_init__(self): assert Path( self.state_joints_config_path ).exists(), f"state_joints_config_path does not exist: {self.state_joints_config_path}" - assert Path(self.model_path).exists(), f"model_path does not exist: {self.model_path}" # embodiment_tag assert self.embodiment_tag in [ "gr1", diff --git a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py index 6e631295..92741a04 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py @@ -3,17 +3,15 @@ # # SPDX-License-Identifier: Apache-2.0 +from __future__ import annotations import argparse import gymnasium as gym import torch from dataclasses import dataclass, field from pathlib import Path -from typing import Any +from typing import Any, Dict -from gr00t.experiment.data_config import DATA_CONFIG_MAP, load_data_config -from gr00t.model.policy import Gr00tPolicy - -from isaaclab_arena.policy.policy_base import PolicyBase +from isaaclab_arena.policy.policy_base import PolicyBase, PolicyDeployment from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.policy_constants import ( NUM_BASE_HEIGHT_CMD, NUM_NAVIGATE_CMD, @@ -106,6 +104,18 @@ def __init__(self, config: Gr00tClosedloopPolicyArgs): self.device = config.policy_device self.task_mode = TaskMode(self.policy_config.task_mode_name) + self.policy = None + + if self.is_remote: + if not self.remote_client.ping(): + cfg = self.remote_config + raise RuntimeError( + f"Failed to connect to remote policy server at " + f"{cfg.host}:{cfg.port}." + ) + else: + self.policy = self.load_local_policy() + self.policy_joints_config = self.load_policy_joints_config(self.policy_config.policy_joints_config_path) self.robot_action_joints_config = self.load_sim_action_joints_config( self.policy_config.action_joints_config_path @@ -114,8 +124,6 @@ def __init__(self, config: Gr00tClosedloopPolicyArgs): self.action_dim = len(self.robot_action_joints_config) if self.task_mode == TaskMode.G1_LOCOMANIPULATION: - # GR00T outputs are used for WBC inputs dim. So adding WBC commands to the action dim. - # WBC commands: navigate_command, base_height_command, torso_orientation_rpy_command self.action_dim += NUM_NAVIGATE_CMD + NUM_BASE_HEIGHT_CMD + NUM_TORSO_ORIENTATION_RPY_CMD self.current_action_chunk = torch.zeros( @@ -180,15 +188,22 @@ def load_sim_action_joints_config(self, action_config_path: Path) -> dict[str, A """Load the simulation action joint config from the data config.""" return load_robot_joints_config_from_yaml(action_config_path) - def load_policy(self) -> Gr00tPolicy: - """Load the dataset, whose iterator will be used as the policy.""" - assert Path( - self.policy_config.model_path - ).exists(), f"Dataset path {self.policy_config.dataset_path} does not exist" + def load_local_policy(self): + try: + from gr00t.experiment.data_config import DATA_CONFIG_MAP, load_data_config + from gr00t.model.policy import Gr00tPolicy + except ImportError as exc: + raise ImportError( + "GR00T policy dependencies are not installed. " + "Install gr00t packages or use policy_deployment=PolicyDeployment.REMOTE." + ) from exc + + assert Path(self.policy_config.model_path).exists(), ( + f"Model path {self.policy_config.model_path} does not exist" + ) - # Use the same data preprocessor specified in the data config map if self.policy_config.data_config in DATA_CONFIG_MAP: - self.data_config = DATA_CONFIG_MAP[self.policy_config.data_config] + data_config = DATA_CONFIG_MAP[self.policy_config.data_config] elif self.policy_config.data_config == "unitree_g1_sim_wbc": self.data_config = load_data_config( "isaaclab_arena_gr00t.embodiments.g1.g1_sim_wbc_data_config:UnitreeG1SimWBCDataConfig" @@ -196,8 +211,8 @@ def load_policy(self) -> Gr00tPolicy: else: raise ValueError(f"Invalid data config: {self.policy_config.data_config}") - modality_config = self.data_config.modality_config() - modality_transform = self.data_config.transform() + modality_config = data_config.modality_config() + modality_transform = data_config.transform() return Gr00tPolicy( model_path=self.policy_config.model_path, modality_config=modality_config, @@ -303,7 +318,16 @@ def get_action_chunk(self, observation: dict[str, Any], camera_name: str = "robo Shape: (num_envs, action_chunk_length, self.action_dim) """ policy_observations = self.get_observations(observation, camera_name) - robot_action_policy = self.policy.get_action(policy_observations) + + if not self.is_remote: + if self._policy is None: + raise RuntimeError("Local GR00T policy is not initialized.") + robot_action_policy = self.policy.get_action(policy_observations) + else: + robot_action_policy = self.remote_client.get_action( + observation=policy_observations, + ) + robot_action_sim = remap_policy_joints_to_sim_joints( robot_action_policy, self.policy_joints_config, self.robot_action_joints_config, self.device ) From 6fe0f9511ec7524fa7405d3afd12bbb903104530 Mon Sep 17 00:00:00 2001 From: Hui Kang Date: Sat, 10 Jan 2026 21:56:45 -0800 Subject: [PATCH 2/2] Move Gr00t policy to server side --- docker/Dockerfile.gr00t_server | 2 +- docker/run_gr00t_server.sh | 252 +++++++++++------ docs/index.rst | 2 + .../concept_remote_policies_design.rst | 241 ++++++++++++++++ .../locomanipulation/step_4_evaluation.rst | 48 ++++ .../static_manipulation/step_5_evaluation.rst | 48 ++++ docs/pages/quickstart/docker_containers.rst | 49 ++++ .../evaluation/policy_runner_cli.py | 3 +- isaaclab_arena/policy/__init__.py | 1 + .../policy/action_chunking_client.py | 182 ++++++++++++ isaaclab_arena/policy/client_side_policy.py | 202 ++++++++++++++ isaaclab_arena/policy/policy_base.py | 36 +-- isaaclab_arena/remote_policy/__init__.py | 10 +- .../remote_policy/action_protocol.py | 83 ++++++ .../remote_policy/message_serializer.py | 27 +- isaaclab_arena/remote_policy/policy_client.py | 72 +++-- .../remote_policy/policy_registry.py | 79 ------ isaaclab_arena/remote_policy/policy_server.py | 96 +++++-- .../remote_policy/remote_policy_config.py | 7 +- .../remote_policy_server_runner.py | 94 +++++-- .../remote_policy/server_side_policy.py | 208 ++++++++++++-- isaaclab_arena_gr00t/gr00t_remote_policy.py | 95 ------- .../config/gr00t_closedloop_policy_config.py | 3 + .../policy/gr00t_closedloop_policy.py | 263 +++++------------- isaaclab_arena_gr00t/policy/gr00t_core.py | 195 +++++++++++++ .../policy/gr00t_remote_policy.py | 202 ++++++++++++++ 26 files changed, 1903 insertions(+), 597 deletions(-) mode change 100644 => 100755 docker/run_gr00t_server.sh create mode 100644 docs/pages/concepts/concept_remote_policies_design.rst create mode 100644 isaaclab_arena/policy/action_chunking_client.py create mode 100644 isaaclab_arena/policy/client_side_policy.py create mode 100644 isaaclab_arena/remote_policy/action_protocol.py delete mode 100644 isaaclab_arena/remote_policy/policy_registry.py delete mode 100644 isaaclab_arena_gr00t/gr00t_remote_policy.py create mode 100644 isaaclab_arena_gr00t/policy/gr00t_core.py create mode 100644 isaaclab_arena_gr00t/policy/gr00t_remote_policy.py diff --git a/docker/Dockerfile.gr00t_server b/docker/Dockerfile.gr00t_server index 1f9cdcec..3b3adf22 100644 --- a/docker/Dockerfile.gr00t_server +++ b/docker/Dockerfile.gr00t_server @@ -30,8 +30,8 @@ RUN pip uninstall -y \ pip install --no-cache-dir "opencv-python-headless==4.8.0.74" COPY isaaclab_arena/remote_policy ${WORKDIR}/isaaclab_arena/remote_policy - COPY isaaclab_arena_gr00t ${WORKDIR}/isaaclab_arena_gr00t +COPY isaaclab_arena_g1 ${WORKDIR}/isaaclab_arena_g1 RUN pip install --no-cache-dir pyzmq msgpack diff --git a/docker/run_gr00t_server.sh b/docker/run_gr00t_server.sh old mode 100644 new mode 100755 index f0a1f20b..f40dd2c7 --- a/docker/run_gr00t_server.sh +++ b/docker/run_gr00t_server.sh @@ -5,15 +5,23 @@ set -euo pipefail # User-configurable defaults # ------------------------- -# Model directory on the host. -# By default, use $HOME/models, but this can be overridden -# by the MODELS_DIR environment variable or the -d / --models_dir flag. +# Default mount directories on the host machine +DATASETS_DIR="${DATASETS_DIR:-$HOME/datasets}" MODELS_DIR="${MODELS_DIR:-$HOME/models}" +EVAL_DIR="${EVAL_DIR:-$HOME/eval}" -# Other parameters (can also be overridden via environment variables) +# Docker image name and tag for the GR00T policy server +DOCKER_IMAGE_NAME="${DOCKER_IMAGE_NAME:-gr00t_policy_server}" +DOCKER_VERSION_TAG="${DOCKER_VERSION_TAG:-latest}" + +# Rebuild controls +FORCE_REBUILD="${FORCE_REBUILD:-false}" +NO_CACHE="" + +# Server parameters (can also be overridden via environment variables) HOST="${HOST:-0.0.0.0}" PORT="${PORT:-5555}" -API_TOKEN="${API_TOKEN:-API_TOKEN_123}" +API_TOKEN="${API_TOKEN:-}" TIMEOUT_MS="${TIMEOUT_MS:-5000}" POLICY_TYPE="${POLICY_TYPE:-gr00t_closedloop}" POLICY_CONFIG_YAML_PATH="${POLICY_CONFIG_YAML_PATH:-/workspace/isaaclab_arena_gr00t/gr1_manip_gr00t_closedloop_config.yaml}" @@ -22,110 +30,174 @@ POLICY_CONFIG_YAML_PATH="${POLICY_CONFIG_YAML_PATH:-/workspace/isaaclab_arena_gr # Help message # ------------------------- usage() { + script_name=$(basename "$0") cat < Path to datasets on the host. Default: "$DATASETS_DIR". + -m Path to models on the host. Default: "$MODELS_DIR". + -e Path to evaluation data on the host. Default: "$EVAL_DIR". + -n Docker image name. Default: "$DOCKER_IMAGE_NAME". + -r Force rebuilding of the Docker image. + -R Force rebuilding of the Docker image, without cache. + +Server-specific options (passed through to the policy server entrypoint): + --host HOST + --port PORT + --api_token TOKEN + --timeout_ms MS + --policy_type TYPE + --policy_config_yaml_path PATH Examples: - # Use default \$HOME/models - bash ./docker/run_gr00t_server.sh + # Minimal: use defaults, just build & run server + bash $script_name - # Use a custom models directory and port - bash ./docker/run_gr00t_server.sh -d /data/models --port 6000 --api_token MY_TOKEN + # Custom models directory and port + bash $script_name -m /data/models --port 6000 --api_token MY_TOKEN - # Use an environment variable to set the models directory - MODELS_DIR=/data/models bash ./docker/run_gr00t_server.sh + # Custom image name, force rebuild, and datasets/eval mounts + bash $script_name -n gr00t_server -r \\ + -d /data/datasets -m /data/models -e /data/eval \\ + --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy \\ + --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/gr1_manip_gr00t_closedloop_config.yaml EOF } # ------------------------- -# CLI parsing +# Parse docker/path options (short flags, like run_docker.sh) # ------------------------- +DOCKER_ARGS_DONE=false +SERVER_ARGS=() + while [[ $# -gt 0 ]]; do - case "$1" in - -d|--models_dir) - MODELS_DIR="$2" - shift 2 - ;; - --host) - HOST="$2" - shift 2 - ;; - --port) - PORT="$2" - shift 2 - ;; - --api_token) - API_TOKEN="$2" - shift 2 - ;; - --timeout_ms) - TIMEOUT_MS="$2" - shift 2 - ;; - --policy_type) - POLICY_TYPE="$2" - shift 2 - ;; - --policy_config_yaml_path) - POLICY_CONFIG_YAML_PATH="$2" - shift 2 - ;; - -h|--help) - usage - exit 0 - ;; - *) - echo "Unknown option: $1" - usage - exit 1 - ;; - esac + if [ "$DOCKER_ARGS_DONE" = false ]; then + case "$1" in + -v) + set -x + shift 1 + ;; + -d) + DATASETS_DIR="$2" + shift 2 + ;; + -m) + MODELS_DIR="$2" + shift 2 + ;; + -e) + EVAL_DIR="$2" + shift 2 + ;; + -n) + DOCKER_IMAGE_NAME="$2" + shift 2 + ;; + -r) + FORCE_REBUILD="true" + shift 1 + ;; + -R) + FORCE_REBUILD="true" + NO_CACHE="--no-cache" + shift 1 + ;; + -h|--help) + usage + exit 0 + ;; + --host|--port|--api_token|--timeout_ms|--policy_type|--policy_config_yaml_path) + # From here on, treat everything as server args and stop parsing docker flags + DOCKER_ARGS_DONE=true + SERVER_ARGS+=("$1") + shift 1 + ;; + --*) + # Unknown long option at docker level -> treat as server arg + DOCKER_ARGS_DONE=true + SERVER_ARGS+=("$1") + shift 1 + ;; + *) + # Anything else -> treat as server arg + DOCKER_ARGS_DONE=true + SERVER_ARGS+=("$1") + shift 1 + ;; + esac + else + SERVER_ARGS+=("$1") + shift 1 + fi done -echo "Using MODELS_DIR=${MODELS_DIR}" -echo "Server config:" -echo " HOST = ${HOST}" -echo " PORT = ${PORT}" -echo " API_TOKEN = ${API_TOKEN}" -echo " TIMEOUT_MS = ${TIMEOUT_MS}" -echo " POLICY_TYPE = ${POLICY_TYPE}" -echo " POLICY_CONFIG_YAML_PATH = ${POLICY_CONFIG_YAML_PATH}" +# If no server args were passed, use defaults +if [ ${#SERVER_ARGS[@]} -eq 0 ]; then + SERVER_ARGS=( + --host "${HOST}" + --port "${PORT}" + --api_token "${API_TOKEN}" + --timeout_ms "${TIMEOUT_MS}" + --policy_type "${POLICY_TYPE}" + --policy_config_yaml_path "${POLICY_CONFIG_YAML_PATH}" + ) +fi + +echo "Host paths:" +echo " DATASETS_DIR = ${DATASETS_DIR}" +echo " MODELS_DIR = ${MODELS_DIR}" +echo " EVAL_DIR = ${EVAL_DIR}" +echo "Docker image:" +echo " ${DOCKER_IMAGE_NAME}:${DOCKER_VERSION_TAG}" +echo "Rebuild:" +echo " FORCE_REBUILD = ${FORCE_REBUILD}, NO_CACHE = '${NO_CACHE}'" +echo "Server args:" +printf ' %q ' "${SERVER_ARGS[@]}"; echo # ------------------------- # 1) Build the Docker image # ------------------------- -docker build \ - -f docker/Dockerfile.gr00t_server \ - -t gr00t_policy_server:latest \ - . + +IMAGE_TAG_FULL="${DOCKER_IMAGE_NAME}:${DOCKER_VERSION_TAG}" + +if docker images -q "${IMAGE_TAG_FULL}" > /dev/null 2>&1 && [ "${FORCE_REBUILD}" != "true" ]; then + echo "Docker image ${IMAGE_TAG_FULL} already exists. Skipping rebuild." + echo "Use -r or -R to force rebuilding the image." +else + echo "Building Docker image ${IMAGE_TAG_FULL}..." + docker build --pull \ + ${NO_CACHE} \ + -f docker/Dockerfile.gr00t_server \ + -t "${IMAGE_TAG_FULL}" \ + . +fi # ------------------------- # 2) Run the container # ------------------------- -docker run --rm \ - --gpus all \ - --net host \ - --name gr00t_policy_server_container \ - -v "${MODELS_DIR}":/models \ - gr00t_policy_server:latest \ - --host "${HOST}" \ - --port "${PORT}" \ - --api_token "${API_TOKEN}" \ - --timeout_ms "${TIMEOUT_MS}" \ - --policy_type "${POLICY_TYPE}" \ - --policy_config_yaml_path "${POLICY_CONFIG_YAML_PATH}" +DOCKER_RUN_ARGS=( + --rm + --gpus all + --net host + --name gr00t_policy_server_container + -v "${MODELS_DIR}":/models +) + +# Only mount datasets / eval if the directories exist on host +if [ -d "${DATASETS_DIR}" ]; then + DOCKER_RUN_ARGS+=(-v "${DATASETS_DIR}":/datasets) +fi + +if [ -d "${EVAL_DIR}" ]; then + DOCKER_RUN_ARGS+=(-v "${EVAL_DIR}":/eval) +fi + +docker run "${DOCKER_RUN_ARGS[@]}" \ + "${IMAGE_TAG_FULL}" \ + "${SERVER_ARGS[@]}" diff --git a/docs/index.rst b/docs/index.rst index 56dfff13..200ad640 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -217,6 +217,8 @@ TABLE OF CONTENTS pages/concepts/concept_assets_design pages/concepts/concept_affordances_design pages/concepts/concept_policy_design + pages/concepts/concept_remote_policies_design + .. toctree:: :maxdepth: 1 diff --git a/docs/pages/concepts/concept_remote_policies_design.rst b/docs/pages/concepts/concept_remote_policies_design.rst new file mode 100644 index 00000000..31e76e4f --- /dev/null +++ b/docs/pages/concepts/concept_remote_policies_design.rst @@ -0,0 +1,241 @@ +Remote Policies Design +====================== + +This section describes the generic remote policy interface in Isaac Lab Arena, +how it is structured around server-side and client-side policies, and how to +plug in your own remote policies. + +Overview +-------- + +Isaac Lab Arena supports running policies in a separate process or machine +and communicating with them via a lightweight RPC protocol. + +The remote-policy design is centred around two main classes: + +- ``ServerSidePolicy``: implemented next to the model in a remote + environment. It defines how to initialise the policy, how to compute + actions for a given observation, and how to handle resets or task + descriptions. +- ``ClientSidePolicy``: implemented inside Isaac Lab Arena. It exposes + the usual policy interface to environments while handling all + RPC-related details (packing observations, sending requests, receiving + and post-processing actions). + +To make sure both sides agree on how observations and actions are +encoded, the server and client share a lightweight ``ActionProtocol``. +The protocol itself does not implement policy logic; it is simply a +contract that describes: + +- which observation entries are exchanged and how they are structured; +- how actions produced by the server should be interpreted on the client + side (for example, one action per step, or sequences of actions), + without prescribing a specific model or task. + +In practice, you implement a ``ServerSidePolicy`` in the remote +environment and a matching ``ClientSidePolicy`` inside Isaac Lab Arena. +As long as they agree on an ``ActionProtocol``, the environments and +evaluation scripts can remain unchanged. + +Server-side policy +------------------ + +Server-side code runs next to the model in its own Python environment +or container. The remote policy utilities are designed to be +self-contained: you can copy the ``isaaclab_arena/remote_policy`` +folder into your server repository and import from it without depending +on Isaac Sim. + +Using the generic server runner +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +In most cases you do not need to implement a custom RPC loop. Instead, +you can start a server using the generic runner +``isaaclab_arena.remote_policy.remote_policy_server_runner`` in the +server environment. + +The runner dynamically loads a ``ServerSidePolicy`` subclass and passes +command-line configuration to it. For example, to launch a GR00T-based +remote policy, you can run: + +.. code-block:: bash + + python -m isaaclab_arena.remote_policy.remote_policy_server_runner \ + --host 127.0.0.1 \ + --port 5555 \ + --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy \ + --policy_config_yaml_path /workspace/isaaclab_arena_gr00t/policy/config/gr1_manip_gr00t_closedloop_config.yaml + +In this example: + +- ``--policy_type`` is a dotted Python path to the GR00T + ``ServerSidePolicy`` implementation that will be imported at runtime. +- ``--policy_config_yaml_path`` points to a model-specific configuration + file. Other subclasses may accept different configuration arguments or + may not use a YAML file at all. + +For convenience, the Arena repository also provides a wrapper script +``docker/run_gr00t_server.sh`` and a dedicated Dockerfile +``docker/Dockerfile.gr00t_server`` that build and run a GR00T remote +policy server container using the same runner. + +Custom server-side policies +^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To add a new remote policy, implement your own subclass of +``ServerSidePolicy`` in your server repository and configure the +runner to load it. + +A typical implementation does the following: + +1. **Define the ActionProtocol** + + Implement ``_build_protocol(self)`` to return an appropriate + protocol instance that describes the interface between server and + client. For example, when using chunked actions: + + .. code-block:: python + + def _build_protocol(self) -> ChunkingActionProtocol: + return ChunkingActionProtocol( + action_dim=self._action_dim, + observation_keys=self._required_observation_keys, + action_chunk_length=self._action_chunk_length, + ) + + If your policy uses a different structure (for example, single-step + actions or additional metadata), you can define your own protocol + subclass instead of ``ChunkingActionProtocol``. The only requirement + is that the client-side policy uses the same protocol class. + +2. **Implement the action computation** + + Implement ``get_action(self, observation, options=None)`` to: + + - parse the incoming observation according to the protocol; + - run the model forward pass; + - return a dictionary that contains at least an ``"action"`` entry + matching the protocol (for example, a batch of chunked actions), + plus any optional info. + +3. **Handle resets and task descriptions** + + - Implement ``reset(self, env_ids=None, options=None)`` to clear + any server-side state when environments reset. + - Implement ``set_task_description(self, task_description)`` if + the policy needs a natural-language or structured description of + the current task; return a small status or updated config dict. + +The GR00T implementation +``isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy`` +follows this pattern: it declares required observation keys, uses +numpy-based preprocessing utilities, and outputs fixed-length action +chunks that are described by a ``ChunkingActionProtocol``. + +Client-side policy +------------------ + +Client-side policies live under ``isaaclab_arena.policy`` and inherit +from ``isaaclab_arena.policy.policy_base.PolicyBase``. They run inside +Isaac Lab Arena and present a standard policy interface to environments, +while internally talking to a remote server. + +A client-side policy is responsible for: + +- Managing a ``RemotePolicyConfig`` and the underlying RPC client used + to connect to the remote server. +- Performing an initial handshake to negotiate an ``ActionProtocol`` + with the server. +- Packing observations into a protocol-compatible format and sending + them over RPC. +- Receiving actions from the server and applying any client-side + post-processing or validation that is specific to the environment. + +Implementing a new client-side policy +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +To add a new client-side policy, you typically: + +1. Subclass ``ClientSidePolicy`` and choose an appropriate protocol + class (for example ``ChunkingActionProtocol`` or your own + ``ActionProtocol`` subclass). + +2. Implement the core ``get_action(...)`` method, which: + + - uses helper methods such as + ``pack_observation_for_server(observation)`` to build the request; + - calls the remote server to obtain actions; + - reshapes or transforms the returned actions into the format + expected by the environment (for example, per-step actions, or + batched actions across multiple envs). + +3. Optionally override ``reset(...)`` if you maintain client-side + state beyond what the base class handles, and call + ``shutdown_remote(...)`` when you want to proactively clean up the + remote connection. + +The base ``ClientSidePolicy`` also provides: + +- shared CLI helpers (``add_remote_args_to_parser()``, + ``build_remote_config_from_args()``) so that policies can be created + directly from command-line arguments; and +- a small set of convenience properties, such as ``protocol``, + ``action_dim`` and ``observation_keys``, which come from the + negotiated ``ActionProtocol``. + +Example: Action chunking on the client +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +``ActionChunkingClientSidePolicy`` is a concrete client-side policy that +implements one specific pattern of post-processing: consuming fixed-size +chunks of actions produced by the server. + +- It uses ``ChunkingActionProtocol`` to agree on: + + - how many action dimensions the policy outputs; and + - how many actions are grouped into each chunk. + +- Internally it keeps track, for each environment, of: + + - the current action chunk received from the server; and + - which index within the chunk should be used for the next step. + +On each call to ``get_action(...)`` the policy: + +1. Determines which environments need a new chunk. +2. Requests a chunk of actions from the remote server for those envs. +3. Validates shapes against the negotiated protocol. +4. Returns exactly one action per environment to the caller, while + caching the remaining actions in the chunk for future steps. + +This pattern is useful when the remote model predicts multiple future +actions at once, while the environment still steps one action at a +time. + +ActionProtocol +-------------- + +The ``ActionProtocol`` family defines the contract that the server and +client use to check that they agree on how to exchange data, without +encoding any policy-specific logic. + +All protocols share basic information such as: + +- how many action dimensions are produced; and +- which observation keys should be provided by the client. + +Specialised subclasses (such as ``ChunkingActionProtocol``) can add +extra fields that are only relevant for a particular pattern, for +example the length of an action chunk. Other use cases can define their +own protocol subclasses as needed, as long as both the server-side and +client-side policy use the same class. + +PolicyServer +------------ + +``PolicyServer`` is a small ZeroMQ-based loop that exposes a single +``ServerSidePolicy`` instance over a dict-based RPC API. It is +intentionally minimal: most users only need to implement a +``ServerSidePolicy`` and then start a server via the generic runner +or a domain-specific wrapper such as ``docker/run_gr00t_server.sh``, +without subclassing ``PolicyServer`` itself. diff --git a/docs/pages/example_workflows/locomanipulation/step_4_evaluation.rst b/docs/pages/example_workflows/locomanipulation/step_4_evaluation.rst index 7ae535f3..79a1cfa7 100644 --- a/docs/pages/example_workflows/locomanipulation/step_4_evaluation.rst +++ b/docs/pages/example_workflows/locomanipulation/step_4_evaluation.rst @@ -147,3 +147,51 @@ and the number of episodes is more than the single environment evaluation becaus The policy was trained on datasets generated using CPU-based physics, therefore the evaluation uses ``--device cpu`` to ensure physics reproducibility. If you have GPU-generated datasets, you can switch to using GPU-based physics for evaluation by providing the ``--device cuda`` flag. + +Step 3: Remote policy evaluation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The same GR00T N1.5 policy can also be evaluated as a *remote policy* +running in a separate process, using the generic remote policy interface +and the remote policy design described in +:doc:`../../concepts/concept_remote_policies_design`. + +Start the GR00T policy server in a separate terminal using the helper +script: + +.. code-block:: bash + + bash docker/run_gr00t_server.sh \ + --host 127.0.0.1 \ + --port 5555 \ + --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy \ + --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/g1_locomanip_gr00t_closedloop_config.yaml \ + --policy_device cuda + +If the models directory was already configured earlier (for example via +the ``MODELS_DIR`` environment variable), the ``-m`` flag can be +omitted. Otherwise, use ``-m`` to point to the directory that contains +the GR00T model checkpoint. + +Then, instead of running the closed-loop policy directly inside the +Arena process, connect from the evaluation script using a client-side +remote policy: + +.. code-block:: bash + + python isaaclab_arena/evaluation/policy_runner.py \ + --policy_type isaaclab_arena.policy.action_chunking_client.ActionChunkingClientSidePolicy \ + --remote_host 127.0.0.1 \ + --remote_port 5555 \ + --num_steps 1200 \ + --num_envs 5 \ + --enable_cameras \ + --device cpu \ + --remote_kill_on_exit \ + galileo_g1_locomanip_pick_and_place \ + --object brown_box \ + --embodiment g1_wbc_joint + +In this configuration, the environment and evaluation logic run inside +Isaac Lab Arena, while GR00T inference runs in the separate server +process, connected through the remote policy interface. diff --git a/docs/pages/example_workflows/static_manipulation/step_5_evaluation.rst b/docs/pages/example_workflows/static_manipulation/step_5_evaluation.rst index 04165dad..6609a0ea 100644 --- a/docs/pages/example_workflows/static_manipulation/step_5_evaluation.rst +++ b/docs/pages/example_workflows/static_manipulation/step_5_evaluation.rst @@ -148,3 +148,51 @@ than the single environment evaluation because of the parallel evaluation. which are realized by using the PINK IK controller. GR00T N1.5 policy is trained on upper body joint positions, so we use ``gr1_joint`` for closed-loop policy inference. + + +Step 3: Remote policy evaluation +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The same task can also be evaluated using a remote policy running in a +separate process, using the generic remote policy interface described +in :doc:`../../concepts/concept_remote_policies_design`. + +Start the remote policy server (for example, a GR00T-based policy) in a +separate terminal using the provided helper script: + +.. code-block:: bash + + bash docker/run_gr00t_server.sh \ + --host 127.0.0.1 \ + --port 5555 \ + --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy \ + --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/gr1_manip_gr00t_closedloop_config.yaml + +If the models directory was already configured in previous steps (for +example via the ``MODELS_DIR`` environment variable), the ``-m`` flag +can be omitted. Otherwise, use ``-m`` to point to the directory that +contains the GR00T model files on the host. + +This script builds and runs a dedicated remote policy server based on +the generic ``remote_policy_server_runner`` and exposes the policy +over ZeroMQ and msgpack. + +Then connect from the evaluation script using a client-side remote +policy: + +.. code-block:: bash + + python isaaclab_arena/evaluation/policy_runner.py \ + --policy_type isaaclab_arena.policy.action_chunking_client.ActionChunkingClientSidePolicy \ + --remote_host 127.0.0.1 \ + --remote_port 5555 \ + --num_steps 2000 \ + --num_envs 10 \ + --enable_cameras \ + --remote_kill_on_exit \ + gr1_open_microwave \ + --embodiment gr1_joint + +In this configuration, the environment runs inside Isaac Sim while +all policy inference happens in the separate GR00T server process, +connected through the remote policy interface. diff --git a/docs/pages/quickstart/docker_containers.rst b/docs/pages/quickstart/docker_containers.rst index e965a4f8..76ebc828 100644 --- a/docs/pages/quickstart/docker_containers.rst +++ b/docs/pages/quickstart/docker_containers.rst @@ -50,3 +50,52 @@ These directories are configurable through argument to the run docker script. For a full list of arguments see the ``run_docker.sh`` script at ``isaac_arena/docker/run_docker.sh``. + +Remote policies and GR00T +------------------------- + +Remote policy workflows (for example the GR1 Open Microwave Door Task) +are split into two containers: + +- The **Base** Isaac Lab Arena container, started via + ``docker/run_docker.sh``. This container does not need to install + GR00T when you run policies in remote mode. +- A separate **GR00T policy server** container, started via + ``docker/run_gr00t_server.sh``, which builds an image from + ``docker/Dockerfile.gr00t_server`` and runs the remote policy server + entrypoint. + +A typical workflow is: + +1. Start the Base container for simulation and evaluation: + + .. code-block:: bash + + bash docker/run_docker.sh + +2. In a second terminal, start the GR00T policy server container: + + .. code-block:: bash + + bash docker/run_gr00t_server.sh \ + --host 127.0.0.1 \ + --port 5555 \ + --policy_type isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy \ + --policy_config_yaml_path isaaclab_arena_gr00t/policy/config/gr1_manip_gr00t_closedloop_config.yaml + +3. Inside the Base container, run the evaluation script with a + client-side remote policy (see the static manipulation example + workflow for full command lines). + +This setup cleanly separates the Isaac Lab Arena simulation environment +from the GR00T policy server environment. The Base container focuses on +running environments and evaluation logic, while the GR00T server +container is responsible only for hosting the GR00T model and its +dependencies. + +If you want to host other policy models as remote servers, you can +follow the same pattern: create a dedicated server Dockerfile and +launcher script (similar to ``docker/Dockerfile.gr00t_server`` and +``docker/run_gr00t_server.sh``), and point it to a custom +``ServerSidePolicy`` implementation as described in +:doc:`../concepts/concept_remote_policies_design`. diff --git a/isaaclab_arena/evaluation/policy_runner_cli.py b/isaaclab_arena/evaluation/policy_runner_cli.py index 219f9b19..7fb4001e 100644 --- a/isaaclab_arena/evaluation/policy_runner_cli.py +++ b/isaaclab_arena/evaluation/policy_runner_cli.py @@ -4,7 +4,7 @@ # SPDX-License-Identifier: Apache-2.0 import argparse -from pathlib import Path + def add_policy_runner_arguments(parser: argparse.ArgumentParser) -> None: """Add policy runner specific arguments to the parser.""" @@ -20,4 +20,3 @@ def add_policy_runner_arguments(parser: argparse.ArgumentParser) -> None: default=100, help="Number of steps to run the policy for", ) - diff --git a/isaaclab_arena/policy/__init__.py b/isaaclab_arena/policy/__init__.py index c8fd8e52..2d60af4f 100644 --- a/isaaclab_arena/policy/__init__.py +++ b/isaaclab_arena/policy/__init__.py @@ -3,5 +3,6 @@ # # SPDX-License-Identifier: Apache-2.0 +from .action_chunking_client import * from .replay_action_policy import * from .zero_action_policy import * diff --git a/isaaclab_arena/policy/action_chunking_client.py b/isaaclab_arena/policy/action_chunking_client.py new file mode 100644 index 00000000..6c47f9c5 --- /dev/null +++ b/isaaclab_arena/policy/action_chunking_client.py @@ -0,0 +1,182 @@ +# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import gymnasium as gym +import torch +from typing import Any + +from isaaclab_arena.policy.client_side_policy import ClientSidePolicy +from isaaclab_arena.remote_policy.action_protocol import ChunkingActionProtocol +from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig + + +class ActionChunkingClientSidePolicy(ClientSidePolicy): + """Client-side policy that consumes fixed-length action chunks sequentially.""" + + def __init__( + self, + config: Any, + num_envs: int, + device: str, + remote_config: RemotePolicyConfig, + ) -> None: + super().__init__(config=config, remote_config=remote_config, protocol_cls=ChunkingActionProtocol) + + self._num_envs = num_envs + self._device = device + + self._current_action_chunk = torch.zeros( + self._num_envs, + self.protocol.action_chunk_length, + self.protocol.action_dim, + dtype=torch.float32, + device=self._device, + ) + self._current_action_index = torch.full( + (self._num_envs,), + fill_value=-1, + dtype=torch.int32, + device=self._device, + ) + self._env_requires_new_chunk = torch.ones( + self._num_envs, + dtype=torch.bool, + device=self._device, + ) + + self.task_description: str | None = None + + # ---------------------- CLI ---------------------------------------- + + @staticmethod + def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add CLI arguments for ActionChunkingClientSidePolicy.""" + # Shared remote policy args. + parser = ClientSidePolicy.add_remote_args_to_parser(parser) + + # Policy-specific args. + group = parser.add_argument_group( + "Action Chunking Client Policy", + "Arguments for client-side action chunking policy.", + ) + group.add_argument( + "--policy_device", + type=str, + default="cuda", + help="Device to use for the policy-related operations.", + ) + return parser + + @staticmethod + def from_args(args: argparse.Namespace) -> ActionChunkingClientSidePolicy: + """Create an ActionChunkingClientSidePolicy from CLI arguments.""" + remote_config = ClientSidePolicy.build_remote_config_from_args(args) + return ActionChunkingClientSidePolicy( + config=None, + num_envs=args.num_envs, + device=args.policy_device, + remote_config=remote_config, + ) + + # ---------------------- Task description ---------------------------- + + def set_task_description(self, task_description: str | None) -> str: + """Set the task description on both client-side and remote policy.""" + self.task_description = task_description + if task_description is not None: + self.remote_client.call_endpoint( + "set_task_description", + data={"task_description": task_description}, + requires_input=True, + ) + return self.task_description or "" + + # ---------------------- Chunking logic ------------------------------ + + def _request_new_chunk( + self, + observation: dict[str, Any], + ) -> torch.Tensor: + """Request a new action chunk from the remote policy and validate it.""" + protocol = self.protocol + packed_obs = self.pack_observation_for_server(observation) + + resp = self.remote_client.get_action(packed_obs) + if not isinstance(resp, dict): + raise TypeError(f"Expected dict from get_action, got {type(resp)!r}") + if "action" not in resp: + raise KeyError("Remote response does not contain key 'action' for ActionChunkingClientSidePolicy.") + + raw_chunk = resp["action"] + if not isinstance(raw_chunk, torch.Tensor): + raw_chunk = torch.tensor(raw_chunk, dtype=torch.float32, device=self._device) + else: + raw_chunk = raw_chunk.to(self._device, dtype=torch.float32) + + if raw_chunk.shape[0] != self._num_envs: + raise ValueError(f"Expected batch size {self._num_envs}, got {raw_chunk.shape[0]}") + if raw_chunk.shape[1] != protocol.action_chunk_length: + raise ValueError( + f"Expected at least {protocol.action_chunk_length} actions per chunk, got {raw_chunk.shape[1]}" + ) + if raw_chunk.shape[2] != protocol.action_dim: + raise ValueError(f"Expected action_dim {protocol.action_dim}, got {raw_chunk.shape[2]}") + + return raw_chunk + + def get_action( + self, + env: gym.Env, + observation: gym.spaces.Dict, + ) -> torch.Tensor: + """Return one action per env step, consuming action chunks sequentially.""" + protocol = self.protocol + + if bool(self._env_requires_new_chunk.any()): + new_chunk = self._request_new_chunk(observation) + mask = self._env_requires_new_chunk + + self._current_action_chunk[mask] = new_chunk[mask] + self._current_action_index[mask] = 0 + self._env_requires_new_chunk[mask] = False + + idx = self._current_action_index # [N] + batch_idx = torch.arange(self._num_envs, device=self._device) + + action = self._current_action_chunk[batch_idx, idx] + + if action.shape != (self._num_envs, protocol.action_dim): + raise RuntimeError( + f"Unexpected action shape {action.shape}, expected {(self._num_envs, protocol.action_dim)}" + ) + + self._current_action_index += 1 + self._env_requires_new_chunk = self._current_action_index >= protocol.action_chunk_length + + return action + + def reset(self, env_ids: torch.Tensor | None = None) -> None: + """Reset client-side chunking state and remote policy state.""" + if env_ids is None: + env_ids = torch.arange( + self._num_envs, + device=self._device, + dtype=torch.long, + ) + + self._current_action_chunk[env_ids] = 0.0 + self._current_action_index[env_ids] = -1 + self._env_requires_new_chunk[env_ids] = True + + # Reset remote state via ClientSidePolicy. + super().reset(env_ids=env_ids) diff --git a/isaaclab_arena/policy/client_side_policy.py b/isaaclab_arena/policy/client_side_policy.py new file mode 100644 index 00000000..afd0aa8d --- /dev/null +++ b/isaaclab_arena/policy/client_side_policy.py @@ -0,0 +1,202 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import torch +from typing import Any + +from isaaclab_arena.policy.policy_base import PolicyBase +from isaaclab_arena.remote_policy.action_protocol import ActionMode, ActionProtocol +from isaaclab_arena.remote_policy.policy_client import PolicyClient +from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig + + +class ClientSidePolicy(PolicyBase): + """Base class for policies that query a remote policy server. + + Responsibilities: + - Manage RemotePolicyConfig and PolicyClient. + - Handshake with the server via get_init_info(). + - Provide observation packing based on observation_keys. + - Provide shared CLI helpers for remote-related arguments. + + Subclasses: + - Must implement get_action(). + """ + + def __init__(self, config: Any, remote_config: RemotePolicyConfig, protocol_cls: type[ActionProtocol]) -> None: + super().__init__(config=config) + + if protocol_cls is None: + raise ValueError("protocol_cls is required.") + + if protocol_cls.MODE is None: + raise ValueError(f"{protocol_cls.__name__}.MODE must be defined as an ActionMode.") + + self.protocol_cls = protocol_cls + requested_action_mode: ActionMode = protocol_cls.MODE + + self._remote_config = remote_config + self._client = PolicyClient(config=self._remote_config) + + # 1) Ping server to ensure connectivity. + if not self._client.ping(): + raise RuntimeError( + f"Failed to connect to remote policy server at {self._remote_config.host}:{self._remote_config.port}." + ) + + # 2) Handshake: send requested_action_mode, parse response. + init_resp = self._client.get_init_info(requested_action_mode=requested_action_mode.value) + + if not isinstance(init_resp, dict): + raise TypeError(f"Expected dict from get_init_info, got {type(init_resp)!r}") + + status = init_resp.get("status", "error") + if status != "success": + message = init_resp.get("message", "no message") + raise RuntimeError(f"Remote policy get_init_info failed with status='{status}': {message}") + + cfg_dict = init_resp.get("config") + if not isinstance(cfg_dict, dict): + raise TypeError( + f"Remote policy get_init_info must return a 'config' dict inside the response, got {type(cfg_dict)!r}" + ) + + self._protocol: ActionProtocol = self.protocol_cls.from_dict(cfg_dict) + + # ---------------------- properties ---------------------------------- + @property + def protocol(self) -> ActionProtocol: + return self._protocol + + @property + def action_mode(self) -> ActionMode: + return self._protocol.mode + + @property + def action_dim(self) -> int: + return self._protocol.action_dim + + @property + def observation_keys(self) -> list[str]: + return list(self._protocol.observation_keys) + + @property + def remote_config(self) -> RemotePolicyConfig: + return self._remote_config + + @property + def remote_client(self) -> PolicyClient: + return self._client + + @property + def is_remote(self) -> bool: + return True + + # ---------------------- observation packing ------------------------- + @staticmethod + def _get_nested_observation(observation: dict[str, Any], key_path: str) -> Any: + """Get a nested value from a dict using 'a.b.c' path.""" + cur: Any = observation + for k in key_path.split("."): + cur = cur[k] + return cur + + def pack_observation_for_server( + self, + observation: dict[str, Any], + ) -> dict[str, Any]: + """Pack selected observation entries into a flat CPU dict for the server. + + Uses `self.observation_keys` from ClientSidePolicyConfig and: + - Extracts values using nested key paths. + - Moves torch.Tensor values to CPU numpy arrays. + """ + packed: dict[str, Any] = {} + for key_path in self.observation_keys: + value = self._get_nested_observation(observation, key_path) + if isinstance(value, torch.Tensor): + value = value.detach().cpu().numpy() + packed[key_path] = value + return packed + + def reset(self, env_ids: torch.Tensor | None = None) -> None: + """Optionally reset remote policy state. + + Client-side state should be reset in subclasses. + """ + if env_ids is None: + return + env_ids_list = env_ids.detach().cpu().tolist() + self._client.reset(env_ids=env_ids_list, options=None) + + def shutdown_remote(self, kill_server: bool = False) -> None: + """Clean up the remote client and optionally stop the remote server.""" + if kill_server: + try: + self._client.call_endpoint("kill", requires_input=False) + except Exception as exc: + print(f"[ClientSidePolicy] Failed to send kill to remote server: {exc}") + self._client.close() + + # ---------------------- shared CLI helpers -------------------------- + + @staticmethod + def add_remote_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add shared remote-policy arguments to the parser. + + This should be called from subclass.add_args_to_parser(). + """ + group = parser.add_argument_group( + "Remote Policy", + "Arguments for connecting to a remote policy server.", + ) + group.add_argument( + "--remote_host", + type=str, + default=None, + required=True, + help="Remote policy server host.", + ) + group.add_argument( + "--remote_port", + type=int, + default=5555, + help="Remote policy server port.", + ) + group.add_argument( + "--remote_api_token", + type=str, + default=None, + help="API token for the remote policy server.", + ) + group.add_argument( + "--remote_timeout_ms", + type=int, + default=15000, + help="Timeout (ms) for remote policy requests.", + ) + group.add_argument( + "--remote_kill_on_exit", + action="store_true", + help="If set, send a 'kill' request to the remote policy server when the run finishes.", + ) + return parser + + @staticmethod + def build_remote_config_from_args(args: argparse.Namespace) -> RemotePolicyConfig: + """Construct RemotePolicyConfig from CLI arguments. + + Assumes add_remote_args_to_parser() has been called on the parser. + """ + + return RemotePolicyConfig( + host=args.remote_host, + port=args.remote_port, + api_token=args.remote_api_token, + timeout_ms=args.remote_timeout_ms, + ) diff --git a/isaaclab_arena/policy/policy_base.py b/isaaclab_arena/policy/policy_base.py index 068d1b05..bf594ea5 100644 --- a/isaaclab_arena/policy/policy_base.py +++ b/isaaclab_arena/policy/policy_base.py @@ -3,7 +3,6 @@ # # SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations import argparse import gymnasium as gym import torch @@ -11,14 +10,6 @@ from gymnasium.spaces.dict import Dict as GymSpacesDict from typing import Any -from enum import Enum - -from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig -from isaaclab_arena.remote_policy.policy_client import PolicyClient - -class PolicyDeployment(Enum): - LOCAL = "local" - REMOTE = "remote" class PolicyBase(ABC): """ @@ -33,11 +24,7 @@ class PolicyBase(ABC): def __init__(self, config: Any): """ - Base class for policies with optional remote deployment. - - Args: - policy_deployment: "local" (default) or "remote". - remote_config: Required when policy_deployment == "remote". + Base class for policies. """ self.config = config @@ -99,6 +86,11 @@ def length(self) -> int | None: """Get the length of the policy (for dataset-driven policies).""" pass + @property + def is_remote(self) -> bool: + """Check if policy is run remotely.""" + return False + @staticmethod @abstractmethod def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: @@ -110,19 +102,3 @@ def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentPars def from_args(args: argparse.Namespace) -> "PolicyBase": """Create a policy from the arguments.""" raise NotImplementedError("Function not implemented yet.") - def shutdown_remote(self, kill_server: bool = False) -> None: - """ - Clean up remote client, and optionally send 'kill' to stop the remote server. - - Args: - kill_server: If True, send a 'kill' RPC before closing the client. - """ - if not self.is_remote or self._policy_client is None: - return - if kill_server: - try: - self._policy_client.call_endpoint("kill", requires_input=False) - except Exception as exc: - print(f"[PolicyBase] Failed to send kill to remote server: {exc}") - self._policy_client.close() - self._policy_client = None diff --git a/isaaclab_arena/remote_policy/__init__.py b/isaaclab_arena/remote_policy/__init__.py index befc257f..6b2258a3 100644 --- a/isaaclab_arena/remote_policy/__init__.py +++ b/isaaclab_arena/remote_policy/__init__.py @@ -1,13 +1,14 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 -from .remote_policy_config import RemotePolicyConfig -from .server_side_policy import ServerSidePolicy +from .action_protocol import ActionMode, ActionProtocol, ChunkingActionProtocol from .message_serializer import MessageSerializer from .policy_client import PolicyClient from .policy_server import PolicyServer +from .remote_policy_config import RemotePolicyConfig +from .server_side_policy import ServerSidePolicy __all__ = [ "RemotePolicyConfig", @@ -15,4 +16,7 @@ "MessageSerializer", "PolicyClient", "PolicyServer", + "ActionMode", + "ActionProtocol", + "ChunkingActionProtocol", ] diff --git a/isaaclab_arena/remote_policy/action_protocol.py b/isaaclab_arena/remote_policy/action_protocol.py new file mode 100644 index 00000000..d76688b6 --- /dev/null +++ b/isaaclab_arena/remote_policy/action_protocol.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, ClassVar + + +class ActionMode(str, Enum): + """Action output mode of a policy. + + Currently only CHUNK is used. + Other modes can be added later if needed. + """ + + CHUNK = "chunk" + + +@dataclass +class ActionProtocol(ABC): + """Base handshake/config for a policy's action output. + + - Encapsulates the ActionMode. + - Holds common fields (action_dim, observation_keys). + - Subclasses add mode-specific fields (e.g. chunk_length). + """ + + # Subclasses must override this. + MODE: ClassVar[ActionMode | None] = None + + # Common fields for all modes. + action_dim: int = 0 + observation_keys: list[str] = field(default_factory=list) + + def __post_init__(self) -> None: + """Validate that subclasses configured MODE properly.""" + mode = type(self).MODE + if mode is None: + raise NotImplementedError(f"{type(self).__name__} must define MODE as an ActionMode.") + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> ActionProtocol: + """Build protocol config from server-side config dict.""" + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """Serialize protocol config to a dict for RPC.""" + + @property + def mode(self) -> ActionMode: + return self.MODE + + +@dataclass +class ChunkingActionProtocol(ActionProtocol): + """ActionProtocol for CHUNK mode.""" + + MODE: ClassVar[ActionMode] = ActionMode.CHUNK + + # Mode-specific field. + action_chunk_length: int = 0 + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> ChunkingActionProtocol: + return cls( + action_dim=int(data["action_dim"]), + observation_keys=list(data["observation_keys"]), + action_chunk_length=int(data["action_chunk_length"]), + ) + + def to_dict(self) -> dict[str, Any]: + return { + "action_mode": self.mode.value, + "action_dim": self.action_dim, + "observation_keys": self.observation_keys, + "action_chunk_length": self.action_chunk_length, + } diff --git a/isaaclab_arena/remote_policy/message_serializer.py b/isaaclab_arena/remote_policy/message_serializer.py index 796161dd..94167cd2 100644 --- a/isaaclab_arena/remote_policy/message_serializer.py +++ b/isaaclab_arena/remote_policy/message_serializer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 @@ -6,12 +6,12 @@ from __future__ import annotations import io +import numpy as np from dataclasses import asdict, is_dataclass from enum import Enum -from typing import Any, Dict +from typing import Any import msgpack -import numpy as np class MessageSerializer: @@ -40,7 +40,21 @@ def from_bytes(data: bytes) -> Any: @staticmethod def _decode_custom(obj: Any) -> Any: - """Decode tagged structures created in _encode_custom.""" + """Decode tagged structures created in _encode_custom. + + This function is registered as the `object_hook` for msgpack.unpackb, + so it is called once for every decoded map/dict. + + - If the dict contains a special tag (e.g. '__ndarray_class__' or + '__blob_class__'), it is converted back into the corresponding + high-level type (numpy array, blob, etc.). + - If the dict has no special tag, it is returned unchanged. In that + case the object stays as whatever type msgpack's default decoder + produced (dict, list, int, str, ...). + + Untagged values and non-dict types are therefore handled entirely + by msgpack's built-in decoder. + """ if not isinstance(obj, dict): return obj @@ -105,9 +119,7 @@ def to_json_serializable(obj: Any) -> Any: return bool(obj) elif isinstance(obj, dict): return {key: to_json_serializable(value) for key, value in obj.items()} - elif isinstance(obj, (list, tuple)): - return [to_json_serializable(item) for item in obj] - elif isinstance(obj, set): + elif isinstance(obj, (list, tuple, set)): return [to_json_serializable(item) for item in obj] elif isinstance(obj, (str, int, float, bool, type(None))): return obj @@ -116,4 +128,3 @@ def to_json_serializable(obj: Any) -> Any: else: # Fallback: convert to string return str(obj) - diff --git a/isaaclab_arena/remote_policy/policy_client.py b/isaaclab_arena/remote_policy/policy_client.py index 970e8c3d..04e25b2e 100644 --- a/isaaclab_arena/remote_policy/policy_client.py +++ b/isaaclab_arena/remote_policy/policy_client.py @@ -1,17 +1,18 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -from dataclasses import dataclass -from typing import Any, Dict, Optional +import warnings +from typing import Any import zmq -from .message_serializer import MessageSerializer -from .remote_policy_config import RemotePolicyConfig +from isaaclab_arena.remote_policy.message_serializer import MessageSerializer +from isaaclab_arena.remote_policy.remote_policy_config import RemotePolicyConfig + class PolicyClient: """Synchronous client for talking to a PolicyServer over ZeroMQ.""" @@ -32,20 +33,24 @@ def ping(self) -> bool: try: self.call_endpoint("ping", requires_input=False) return True - except Exception: + except Exception as exc: warnings.warn( - f"[PolicyClient] Failed to ping remote policy server at " - f"{self._config.host}:{self._config.port}: {exc}" + f"[PolicyClient] Failed to ping remote policy server at {self._config.host}:{self._config.port}: {exc}" ) return False - def reset(self, env_ids=None, options: Optional[Dict[str, Any]] = None) -> Any: + def reset(self, env_ids=None, options: dict[str, Any] | None = None) -> Any: """Reset remote policy state.""" - return self.call_endpoint( + resp = self.call_endpoint( endpoint="reset", data={"env_ids": env_ids, "options": options}, requires_input=True, ) + if isinstance(resp, dict): + status = resp.get("status") + if status not in ("reset_success", "ok", "reset_ok", None): + raise RuntimeError(f"Remote reset failed with status={status}, resp={resp}") + return resp def kill(self) -> Any: """Ask remote server to stop main loop.""" @@ -53,11 +58,11 @@ def kill(self) -> Any: def get_action( self, - observation: Dict[str, Any], - ) -> Dict[str, Any]: + observation: dict[str, Any], + ) -> dict[str, Any]: """Send policy_observations and get back policy action dict.""" - payload: Dict[str, Any] = {"observation": observation} - + payload: dict[str, Any] = {"observation": observation} + resp = self.call_endpoint( endpoint="get_action", data=payload, @@ -65,14 +70,48 @@ def get_action( ) return resp + def get_init_info(self, requested_action_mode: str) -> dict[str, Any]: + """Call get_init_info on the server with a requested_action_mode. + + Args: + requested_action_mode: ActionMode value (e.g. "chunk"). + + Returns: + A dict returned by the server, expected to contain: + - "status" + - "message" (optional) + - "config" (on success) + """ + payload = {"requested_action_mode": requested_action_mode} + resp = self.call_endpoint( + "get_init_info", + data=payload, + requires_input=True, + ) + if not isinstance(resp, dict): + raise TypeError(f"Expected dict from get_init_info, got {type(resp)!r}") + return resp + + def set_task_description(self, task_description: str | None) -> dict[str, Any]: + """Send task description to the remote policy.""" + payload: dict[str, Any] = {"task_description": task_description} + resp = self.call_endpoint( + endpoint="set_task_description", + data=payload, + requires_input=True, + ) + if not isinstance(resp, dict): + raise TypeError(f"Expected dict from set_task_description, got {type(resp)!r}") + return resp + def call_endpoint( self, endpoint: str, - data: Optional[Dict[str, Any]] = None, + data: dict[str, Any] | None = None, requires_input: bool = True, ) -> Any: """Generic RPC helper.""" - request: Dict[str, Any] = {"endpoint": endpoint} + request: dict[str, Any] = {"endpoint": endpoint} if requires_input: request["data"] = data or {} if self._config.api_token: @@ -90,4 +129,3 @@ def close(self) -> None: """Close the underlying ZeroMQ socket and context.""" self._socket.close() self._context.term() - diff --git a/isaaclab_arena/remote_policy/policy_registry.py b/isaaclab_arena/remote_policy/policy_registry.py deleted file mode 100644 index 636cd1e3..00000000 --- a/isaaclab_arena/remote_policy/policy_registry.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Dict, List, Type - -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy - - -@dataclass(frozen=True) -class PolicyEntry: - policy_type: str - entry_point: str # "module_path:ClassName" - - -class PolicyRegistry: - def __init__(self) -> None: - self._entries: Dict[str, PolicyEntry] = {} - - def register(self, policy_type: str, entry_point: str) -> None: - if policy_type in self._entries: - raise ValueError(f"Policy type {policy_type!r} already registered") - if ":" not in entry_point: - raise ValueError( - f"Invalid entry_point {entry_point!r} for policy_type={policy_type!r} " - "(expected 'module_path:ClassName')" - ) - self._entries[policy_type] = PolicyEntry(policy_type, entry_point) - - def available_policy_types(self) -> List[str]: - return sorted(self._entries.keys()) - - def resolve_class(self, policy_type: str) -> Type[ServerSidePolicy]: - if policy_type not in self._entries: - raise ValueError( - f"Unknown policy_type={policy_type!r}. " - f"Available options: {self.available_policy_types()}" - ) - - entry = self._entries[policy_type] - module_path, class_name = entry.entry_point.split(":", 1) - - try: - module = __import__(module_path, fromlist=[class_name]) - except ImportError as exc: - raise ImportError( - f"Failed to import module '{module_path}' for policy_type={policy_type!r}. " - "This usually means the corresponding policy package is not installed " - "in the current server environment." - ) from exc - - try: - cls = getattr(module, class_name) - except AttributeError as exc: - raise ImportError( - f"Module '{module_path}' does not define class '{class_name}' " - f"for policy_type={policy_type!r}." - ) from exc - - if not issubclass(cls, ServerSidePolicy): - raise TypeError( - f"Resolved class '{class_name}' from '{module_path}' is not a ServerSidePolicy " - f"subclass (policy_type={policy_type!r})." - ) - return cls - - -policy_registry = PolicyRegistry() - -# Built-in registrations -policy_registry.register( - "gr00t_closedloop", - "isaaclab_arena_gr00t.gr00t_remote_policy:Gr00tRemoteServerSidePolicy", -) - diff --git a/isaaclab_arena/remote_policy/policy_server.py b/isaaclab_arena/remote_policy/policy_server.py index c2d81553..12f7c030 100644 --- a/isaaclab_arena/remote_policy/policy_server.py +++ b/isaaclab_arena/remote_policy/policy_server.py @@ -1,17 +1,18 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, Type +from typing import Any import zmq -from .model_policy import ModelPolicy -from .message_serializer import MessageSerializer +from isaaclab_arena.remote_policy.message_serializer import MessageSerializer +from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy @dataclass @@ -23,10 +24,10 @@ class EndpointHandler: class PolicyServer: def __init__( self, - policy: ModelPolicy, + policy: ServerSidePolicy, host: str = "*", port: int = 5555, - api_token: Optional[str] = None, + api_token: str | None = None, timeout_ms: int = 15000, ) -> None: self._policy = policy @@ -38,9 +39,8 @@ def __init__( print(f"[PolicyServer] binding on {bind_addr}") self._socket.bind(bind_addr) self._api_token = api_token - self._serializer = MessageSerializer - self._endpoints: Dict[str, EndpointHandler] = {} + self._endpoints: dict[str, EndpointHandler] = {} self._register_default_endpoints() def _register_default_endpoints(self) -> None: @@ -48,6 +48,8 @@ def _register_default_endpoints(self) -> None: self.register_endpoint("kill", self._handle_kill, requires_input=False) self.register_endpoint("get_action", self._handle_get_action, requires_input=True) self.register_endpoint("reset", self._handle_reset, requires_input=True) + self.register_endpoint("get_init_info", self._handle_get_init_info, requires_input=True) + self.register_endpoint("set_task_description", self._handle_set_task_description, requires_input=True) print(f"[PolicyServer] registered endpoints: {list(self._endpoints.keys())}") def register_endpoint( @@ -58,38 +60,73 @@ def register_endpoint( ) -> None: self._endpoints[name] = EndpointHandler(handler=handler, requires_input=requires_input) - def _handle_ping(self) -> Dict[str, Any]: + def _handle_get_init_info( + self, + requested_action_mode: str, + ) -> dict[str, Any]: + print(f"[PolicyServer] handle get_init_info: requested_action_mode={requested_action_mode!r}") + resp = self._policy.get_init_info(requested_action_mode=requested_action_mode) + if not isinstance(resp, dict): + raise TypeError(f"Policy.get_init_info() must return dict, got {type(resp)!r}") + return resp + + def _handle_set_task_description( + self, + task_description: str | None = None, + **_: Any, + ) -> dict[str, Any]: + print(f"[PolicyServer] handle set_task_description: {task_description!r}") + resp = self._policy.set_task_description(task_description) + if not isinstance(resp, dict): + raise TypeError(f"Policy.set_task_description() must return dict, got {type(resp)!r}") + return resp + + def _handle_ping(self) -> dict[str, Any]: print("[PolicyServer] handle ping") return {"status": "ok"} - def _handle_kill(self) -> Dict[str, Any]: + def _handle_kill(self) -> dict[str, Any]: print("[PolicyServer] handle kill -> stopping") self._running = False return {"status": "stopping"} def _handle_get_action( self, - observation: Dict[str, Any], - options: Optional[Dict[str, Any]] = None, + observation: dict[str, Any], + options: dict[str, Any] | None = None, **_: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: print("[PolicyServer] handle get_action") - print(f" observation keys: {list(observation.keys())}") if options is not None: print(f" options keys: {list(options.keys())}") action, info = self._policy.get_action( observation=observation, options=options, ) - return {"action": action, "info": info} - def _handle_reset(self, env_ids=None, options=None, **_: Any) -> Dict[str, Any]: + if not isinstance(action, dict): + raise TypeError(f"Policy.get_action() must return (dict, dict), got action type={type(action)!r}") + if not isinstance(info, dict): + raise TypeError(f"Policy.get_action() must return (dict, dict), got info type={type(info)!r}") + + merged: dict[str, Any] = {} + merged.update(action) + if any(k in merged for k in info.keys()): + raise ValueError(f"Policy info keys conflict with action keys: {set(merged.keys()) & set(info.keys())}") + merged.update(info) + + return merged + + def _handle_reset(self, env_ids=None, options=None, **_: Any) -> dict[str, Any]: print(f"[PolicyServer] handle reset: env_ids={env_ids}, options={options}") + status: dict[str, Any] = {"status": "reset_success"} if hasattr(self._policy, "reset"): - self._policy.reset(env_ids=env_ids, reset_options=options) - return {"status": "reset"} + resp = self._policy.reset(env_ids=env_ids, reset_options=options) + if isinstance(resp, dict): + status.update(resp) + return status - def _validate_token(self, request: Dict[str, Any]) -> bool: + def _validate_token(self, request: dict[str, Any]) -> bool: if self._api_token is None: return True ok = request.get("api_token") == self._api_token @@ -104,7 +141,7 @@ def run(self) -> None: try: raw = self._socket.recv() print(f"[PolicyServer] received {len(raw)} bytes") - request = self._serializer.from_bytes(raw) + request = MessageSerializer.from_bytes(raw) if not isinstance(request, dict): raise TypeError(f"Expected dict request, got {type(request)!r}") @@ -112,12 +149,16 @@ def run(self) -> None: print(f"[PolicyServer] request keys: {list(request.keys())}") if not self._validate_token(request): - self._socket.send( - self._serializer.to_bytes({"error": "Unauthorized: invalid api_token"}) - ) + self._socket.send(MessageSerializer.to_bytes({"error": "Unauthorized: invalid api_token"})) continue endpoint = request.get("endpoint", "get_action") + if "endpoint" not in request: + self._socket.send(MessageSerializer.to_bytes({"error": "Missing 'endpoint' in request"})) + continue + + endpoint = request["endpoint"] + handler = self._endpoints.get(endpoint) if handler is None: raise ValueError(f"Unknown endpoint: {endpoint}") @@ -132,7 +173,7 @@ def run(self) -> None: else: result = handler.handler() - resp_bytes = self._serializer.to_bytes(result) + resp_bytes = MessageSerializer.to_bytes(result) print(f"[PolicyServer] sending response ({len(resp_bytes)} bytes)") self._socket.send(resp_bytes) except zmq.Again: @@ -143,14 +184,14 @@ def run(self) -> None: print(f"[PolicyServer] Error: {exc}") print(traceback.format_exc()) - self._socket.send(self._serializer.to_bytes({"error": str(exc)})) + self._socket.send(MessageSerializer.to_bytes({"error": str(exc)})) @staticmethod def start( - policy: ModelPolicy, + policy: ServerSidePolicy, host: str = "*", port: int = 5555, - api_token: Optional[str] = None, + api_token: str | None = None, timeout_ms: int = 15000, ) -> None: server = PolicyServer( @@ -161,4 +202,3 @@ def start( timeout_ms=timeout_ms, ) server.run() - diff --git a/isaaclab_arena/remote_policy/remote_policy_config.py b/isaaclab_arena/remote_policy/remote_policy_config.py index 41911bec..a256f14c 100644 --- a/isaaclab_arena/remote_policy/remote_policy_config.py +++ b/isaaclab_arena/remote_policy/remote_policy_config.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 @@ -6,12 +6,13 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Optional + @dataclass class RemotePolicyConfig: """Configuration for using a remote PolicyServer.""" + host: str port: int - api_token: Optional[str] = None + api_token: str | None = None timeout_ms: int = 15000 diff --git a/isaaclab_arena/remote_policy/remote_policy_server_runner.py b/isaaclab_arena/remote_policy/remote_policy_server_runner.py index ddbfbb16..c96cd00b 100644 --- a/isaaclab_arena/remote_policy/remote_policy_server_runner.py +++ b/isaaclab_arena/remote_policy/remote_policy_server_runner.py @@ -1,52 +1,107 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 + from __future__ import annotations import argparse -from pathlib import Path -from typing import Type +from importlib import import_module -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy from isaaclab_arena.remote_policy.policy_server import PolicyServer -from isaaclab_arena.remote_policy.policy_registry import policy_registry +from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy -def resolve_policy_class(policy_type: str) -> Type[ServerSidePolicy]: - return policy_registry.resolve_class(policy_type) +def get_policy_cls(policy_type: str) -> type[ServerSidePolicy]: + """Dynamically import and return a ServerSidePolicy subclass. + The policy_type argument must be a fully qualified Python path of the form: + "package.subpackage.module.ClassName" + """ + print(f"[remote_policy_server_runner] Importing server-side policy from: {policy_type}") + if "." not in policy_type: + raise ValueError( + "policy_type must be a dotted Python import path of the form " + "'module.submodule.ClassName', " + f"got: {policy_type!r}" + ) + module_path, class_name = policy_type.rsplit(".", 1) + module = import_module(module_path) + policy_cls = getattr(module, class_name) + return policy_cls -def parse_args() -> argparse.Namespace: + +def build_base_parser() -> argparse.ArgumentParser: + """Build the base CLI parser for the remote policy server. + + This parser only contains arguments that are common to all server-side policies. + Policy-specific arguments are added later by the selected ServerSidePolicy subclass. + """ parser = argparse.ArgumentParser("IsaacLab Arena Remote Policy Server") + + # Generic server options. parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=5555) parser.add_argument("--api_token", type=str, default=None) parser.add_argument("--timeout_ms", type=int, default=5000) + # Which ServerSidePolicy implementation to run. parser.add_argument( "--policy_type", type=str, required=True, - choices=policy_registry.available_policy_types(), - help="Which remote policy to run (e.g. 'gr00t_closedloop').", - ) - parser.add_argument( - "--policy_config_yaml_path", - type=str, - required=True, - help="Path to policy-specific config YAML.", + help=( + "Dotted Python path of the server-side policy to run, e.g. " + "'isaaclab_arena_gr00t.policy.gr00t_remote_policy.Gr00tRemoteServerSidePolicy'." + ), ) - return parser.parse_args() + return parser + + +def parse_args() -> argparse.Namespace: + """Parse CLI arguments in two stages. + + 1) Parse only the base arguments to discover which policy class to use. + 2) Let that class extend the parser with its own arguments, then parse again. + """ + # Stage 1: parse base args to get policy_type. + base_parser = build_base_parser() + base_args, _ = base_parser.parse_known_args() + + policy_cls = get_policy_cls(base_args.policy_type) + print(f"[remote_policy_server_runner] Requested server-side policy: {base_args.policy_type} -> {policy_cls}") + + # Stage 2: build a fresh parser, extend it with policy-specific arguments, then parse fully. + full_parser = build_base_parser() + if not hasattr(policy_cls, "add_args_to_parser"): + raise TypeError( + f"Server-side policy class {policy_cls} must define a static 'add_args_to_parser(parser)' method." + ) + full_parser = policy_cls.add_args_to_parser(full_parser) # type: ignore[assignment] + + args = full_parser.parse_args() + return args def main() -> None: + """Entry point for running a remote policy server. + + The script: + 1) Parses CLI arguments in two stages. + 2) Instantiates the requested ServerSidePolicy via its from_args() helper. + 3) Wraps it in a PolicyServer and starts the RPC loop. + """ args = parse_args() - policy_cls = resolve_policy_class(args.policy_type) - policy = policy_cls(policy_config_yaml_path=Path(args.policy_config_yaml_path)) + policy_cls = get_policy_cls(args.policy_type) + if not hasattr(policy_cls, "from_args"): + raise TypeError(f"Server-side policy class {policy_cls} must define a static 'from_args(args)' method.") + # Construct the server-side policy from CLI arguments. + policy = policy_cls.from_args(args) # type: ignore[call-arg] + + # Start the RPC server. server = PolicyServer( policy=policy, host=args.host, @@ -59,4 +114,3 @@ def main() -> None: if __name__ == "__main__": main() - diff --git a/isaaclab_arena/remote_policy/server_side_policy.py b/isaaclab_arena/remote_policy/server_side_policy.py index a1d0e966..8b96f3bf 100644 --- a/isaaclab_arena/remote_policy/server_side_policy.py +++ b/isaaclab_arena/remote_policy/server_side_policy.py @@ -1,51 +1,207 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +# Copyright (c) 2025-2026, +# The Isaac Lab Arena Project Developers +# (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). # All rights reserved. # # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations +import argparse from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import Any + +from isaaclab_arena.remote_policy.action_protocol import ActionMode, ActionProtocol class ServerSidePolicy(ABC): - """Server-side policy interface. + """Base class for server-side remote policies. + + This class defines: + * The protocol- and handshake-related API that the PolicyServer relies on. + * A minimal configuration hook via ``config_class`` and ``from_dict``. + * A CLI construction pattern via ``add_args_to_parser`` and ``from_args``, + mirroring the design of :class:`isaaclab_arena.policy.policy_base.PolicyBase` + on the client side. - This interface is intentionally independent of IsaacLab-Arena. - The server only sees JSON-serializable observations and returns - JSON-serializable actions. + Concrete server-side policies (e.g. GR00T-based ones) should: + * Implement ``_build_protocol()`` and the core RPC methods. + * Optionally define a dataclass as ``config_class``. + * Implement ``add_args_to_parser(parser)`` and ``from_args(args)`` + so they can be instantiated directly from command-line arguments. """ - @abstractmethod - def get_action( - self, observation: dict[str, Any], options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: - """Compute and return the next action based on current observation with validation. + # Optional: subclasses can define this to enable from_dict() + config_class: type | None = None - This is the main public interface. It validates the observation, calls - the internal _get_action(), and validates the resulting action. + def __init__(self, config: Any | None = None) -> None: + """Base constructor for server-side policies. Args: - observation: Dictionary containing the current state/observation - options: Optional configuration dict for action computation + config: Optional configuration object (for example, a dataclass + instance). Subclasses are free to interpret this as needed. + """ + self.config = config + self._protocol: ActionProtocol | None = None + self._task_description: str | None = None - Returns: - Tuple of (action, info): - - action: Dictionary containing the validated action - - info: Dictionary containing additional metadata + # ------------------------------------------------------------------ + # Config helpers (mirroring PolicyBase.from_dict) + # ------------------------------------------------------------------ + + @classmethod + def from_dict(cls, config_dict: dict[str, Any]) -> ServerSidePolicy: + """Create a policy instance from a configuration dictionary. + + Path: dict -> ConfigDataclass -> Policy instance - Raises: - AssertionError/ValueError: If observation or action validation fails + This mirrors :meth:`PolicyBase.from_dict` on the client side. """ + if cls.config_class is None: + raise NotImplementedError(f"{cls.__name__} must define 'config_class' to use from_dict().") + + config = cls.config_class(**config_dict) # type: ignore[misc] + return cls(config) # type: ignore[call-arg] + + # ------------------------------------------------------------------ + # Protocol / handshake API + # ------------------------------------------------------------------ + @abstractmethod - def reset(self, options: dict[str, Any] | None = None) -> dict[str, Any]: - """Reset the policy to its initial state. + def _build_protocol(self) -> ActionProtocol: + """Subclasses must build and return an ActionProtocol instance.""" + raise NotImplementedError + + @property + def protocol(self) -> ActionProtocol: + """Return the ActionProtocol associated with this policy. + + The protocol is lazily constructed on first access via ``_build_protocol()``. + """ + if self._protocol is None: + self._protocol = self._build_protocol() + if self._protocol.mode is None: + raise ValueError(f"{self.__class__.__name__} has an ActionProtocol with mode=None, which is not allowed.") + return self._protocol + + def get_init_info(self, requested_action_mode: str) -> dict[str, Any]: + """Handle the initial handshake with the client. + + Checks that the requested action mode is valid and supported by + this policy's ActionProtocol, and returns either an error status + or the protocol configuration as a plain dictionary. + """ + proto = self.protocol + + try: + requested_mode_enum = ActionMode(requested_action_mode) + except ValueError: + return { + "status": "invalid_action_mode", + "message": f"Requested action_mode={requested_action_mode!r} is invalid.", + } + + if requested_mode_enum is not proto.mode: + return { + "status": "unsupported_action_mode", + "message": ( + f"Requested action_mode={requested_mode_enum.value!r} " + "is not supported by this policy. " + f"Supported: {proto.mode.value!r}." + ), + } + + return { + "status": "success", + "config": proto.to_dict(), + } + + # ------------------------------------------------------------------ + # Core RPC methods (to be used by PolicyServer) + # ------------------------------------------------------------------ + + @abstractmethod + def get_action( + self, + observation: dict[str, Any], + ) -> dict[str, Any]: + """Compute one or more actions given an observation payload. Args: - options: Dictionary containing the options for the reset + observation: Flat observation dictionary received from the client. Returns: - Dictionary containing the info after resetting the policy + A dictionary that must contain at least an ``"action"`` entry + whose structure is compatible with the negotiated ActionProtocol. + """ + raise NotImplementedError + + def reset(self) -> None: + """Reset the policy state. + + Subclasses may override this if they maintain per-environment or + global state that needs to be cleared between episodes. + """ + ... + + def set_task_description( + self, + task_description: str | None, + ) -> dict[str, Any]: + """Set the task description and return a small status/config payload. + + The default implementation stores the description locally and + echoes it back. Subclasses can override this to perform additional + updates or validation. + """ + self._task_description = task_description + return {"task_description": self._task_description or ""} + + # ------------------------------------------------------------------ + # Shared helpers + # ------------------------------------------------------------------ + + def unpack_observation(self, flat_obs: dict[str, Any]) -> dict[str, Any]: + """Convert a flat dotted-key observation dict into a nested dict. + + For example, a key ``"camera_obs.pov.rgb"`` becomes + ``nested["camera_obs"]["pov"]["rgb"]``. + """ + nested: dict[str, Any] = {} + for key_path, value in flat_obs.items(): + cur = nested + parts = key_path.split(".") + for k in parts[:-1]: + cur = cur.setdefault(k, {}) + cur[parts[-1]] = value + return nested + + # ------------------------------------------------------------------ + # CLI helpers (to mirror PolicyBase.add_args_to_parser / from_args) + # ------------------------------------------------------------------ + + @staticmethod + @abstractmethod + def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add policy-specific CLI arguments to the parser. + + Server-side policies are expected to implement this so that + :mod:`remote_policy_server_runner` can delegate CLI argument + definitions to the selected policy class. + """ + raise NotImplementedError("ServerSidePolicy subclasses must implement add_args_to_parser().") + + @staticmethod + @abstractmethod + def from_args(args: argparse.Namespace) -> ServerSidePolicy: + """Construct a server-side policy instance from CLI arguments. + + This mirrors the ``from_args(args)`` pattern used by client-side + policies deriving from :class:`PolicyBase`. """ - pass + raise NotImplementedError("ServerSidePolicy subclasses must implement from_args(args).") diff --git a/isaaclab_arena_gr00t/gr00t_remote_policy.py b/isaaclab_arena_gr00t/gr00t_remote_policy.py deleted file mode 100644 index f871163c..00000000 --- a/isaaclab_arena_gr00t/gr00t_remote_policy.py +++ /dev/null @@ -1,95 +0,0 @@ -# Copyright (c) 2025, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). -# All rights reserved. -# -# SPDX-License-Identifier: Apache-2.0 - -from __future__ import annotations - -from pathlib import Path -from typing import Any, Dict, Tuple - -from gr00t.experiment.data_config import DATA_CONFIG_MAP, load_data_config -from gr00t.model.policy import Gr00tPolicy - -from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy -from isaaclab_arena_gr00t.policy_config import Gr00tClosedloopPolicyConfig -from isaaclab_arena_gr00t.data_utils.io_utils import create_config_from_yaml - - -class Gr00tRemoteServerSidePolicy(ServerSidePolicy): - """Server-side wrapper around Gr00tPolicy.""" - - def __init__(self, policy_config_yaml_path: Path) -> None: - print(f"[Gr00tRemoteServerSidePolicy] loading config from: {policy_config_yaml_path}") - self._cfg = create_config_from_yaml(policy_config_yaml_path, Gr00tClosedloopPolicyConfig) - print( - "[Gr00tRemoteServerSidePolicy] config:\n" - f" model_path = {self._cfg.model_path}\n" - f" embodiment_tag = {self._cfg.embodiment_tag}\n" - f" task_mode_name = {self._cfg.task_mode_name}\n" - f" data_config = {self._cfg.data_config}\n" - f" action_horizon = {self._cfg.action_horizon}\n" - f" action_chunk_len = {self._cfg.action_chunk_length}\n" - f" pov_cam_name_sim = {self._cfg.pov_cam_name_sim}\n" - f" policy_device = {self._cfg.policy_device}" - ) - self._policy = self._load_gr00t_policy() - print("[Gr00tRemoteServerSidePolicy] Gr00tPolicy loaded successfully") - - def _load_gr00t_policy(self) -> Gr00tPolicy: - print(f"[Gr00tRemoteServerSidePolicy] loading data_config={self._cfg.data_config}") - if self._cfg.data_config in DATA_CONFIG_MAP: - data_config = DATA_CONFIG_MAP[self._cfg.data_config] - elif self._cfg.data_config == "unitree_g1_sim_wbc": - data_config = load_data_config("isaaclab_arena_gr00t.data_config:UnitreeG1SimWBCDataConfig") - else: - raise ValueError(f"Invalid data config: {self._cfg.data_config}") - - modality_config = data_config.modality_config() - modality_transform = data_config.transform() - - model_path = Path(self._cfg.model_path) - if not model_path.exists(): - raise FileNotFoundError(f"Model path does not exist: {model_path}") - print(f"[Gr00tRemoteServerSidePolicy] loading checkpoint from: {model_path}") - - policy = Gr00tPolicy( - model_path=str(model_path), - modality_config=modality_config, - modality_transform=modality_transform, - embodiment_tag=self._cfg.embodiment_tag, - denoising_steps=self._cfg.denoising_steps, - device=self._cfg.policy_device, - ) - return policy - - # ------------------------------------------------------------------ # - # ServerSidePolicy interface - # ------------------------------------------------------------------ # - - def get_action( - self, - observation: Dict[str, Any], - options: Dict[str, Any] | None = None, - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - print("[Gr00tRemoteServerSidePolicy] get_action called") - print(f" observation keys: {list(observation.keys())}") - if options is not None: - print(f" options keys: {list(options.keys())}") - - result = self._policy.get_action(observation) - # Gr00tPolicy.get_action usually returns a dict; wrap it with empty info. - if isinstance(result, tuple) and len(result) == 2: - action, info = result - else: - action, info = result, {} - - print("[Gr00tRemoteServerSidePolicy] get_action done") - return action, info - - def reset(self, options: Dict[str, Any] | None = None) -> Dict[str, Any]: - print(f"[Gr00tRemoteServerSidePolicy] reset called: options={options}") - if hasattr(self._policy, "reset"): - self._policy.reset(options=options) - return {} - diff --git a/isaaclab_arena_gr00t/policy/config/gr00t_closedloop_policy_config.py b/isaaclab_arena_gr00t/policy/config/gr00t_closedloop_policy_config.py index 0451cba6..41f4da5b 100644 --- a/isaaclab_arena_gr00t/policy/config/gr00t_closedloop_policy_config.py +++ b/isaaclab_arena_gr00t/policy/config/gr00t_closedloop_policy_config.py @@ -5,8 +5,10 @@ from dataclasses import dataclass, field from pathlib import Path + from isaaclab_arena_gr00t.policy.config.task_mode import TaskMode + @dataclass class Gr00tClosedloopPolicyConfig: @@ -96,6 +98,7 @@ def __post_init__(self): assert Path( self.state_joints_config_path ).exists(), f"state_joints_config_path does not exist: {self.state_joints_config_path}" + assert Path(self.model_path).exists(), f"model_path does not exist: {self.model_path}" # embodiment_tag assert self.embodiment_tag in [ "gr1", diff --git a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py index 92741a04..bef792ea 100644 --- a/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py +++ b/isaaclab_arena_gr00t/policy/gr00t_closedloop_policy.py @@ -3,56 +3,34 @@ # # SPDX-License-Identifier: Apache-2.0 -from __future__ import annotations import argparse import gymnasium as gym +import numpy as np import torch from dataclasses import dataclass, field -from pathlib import Path -from typing import Any, Dict - -from isaaclab_arena.policy.policy_base import PolicyBase, PolicyDeployment -from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.policy_constants import ( - NUM_BASE_HEIGHT_CMD, - NUM_NAVIGATE_CMD, - NUM_TORSO_ORIENTATION_RPY_CMD, -) +from typing import Any + +from gr00t.model.policy import Gr00tPolicy + +from isaaclab_arena.policy.policy_base import PolicyBase from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode -from isaaclab_arena_gr00t.utils.image_conversion import resize_frames_with_padding -from isaaclab_arena_gr00t.utils.io_utils import create_config_from_yaml, load_robot_joints_config_from_yaml -from isaaclab_arena_gr00t.utils.joints_conversion import ( - remap_policy_joints_to_sim_joints, - remap_sim_joints_to_policy_joints, +from isaaclab_arena_gr00t.policy.gr00t_core import ( + Gr00tBasePolicyArgs, + build_gr00t_action_tensor, + build_gr00t_policy_inputs_np, + compute_action_dim, + load_gr00t_joint_configs, + load_gr00t_policy_from_config, ) -from isaaclab_arena_gr00t.utils.robot_joints import JointsAbsPosition +from isaaclab_arena_gr00t.utils.io_utils import create_config_from_yaml @dataclass -class Gr00tClosedloopPolicyArgs: +class Gr00tClosedloopPolicyArgs(Gr00tBasePolicyArgs): """ Configuration dataclass for Gr00tClosedloopPolicy. - - This dataclass serves as the single source of truth for policy configuration, - supporting both dict-based (from JSON) and CLI-based configuration paths. - - Field metadata is used to auto-generate argparse arguments, ensuring consistency - between the dataclass definition and CLI argument parsing. """ - policy_config_yaml_path: str = field( - metadata={ - "help": "Path to the Gr00t closedloop policy config YAML file", - "required": True, - } - ) - - policy_device: str = field( - default="cuda", - metadata={ - "help": "Device to use for the policy-related operations", - }, - ) - num_envs: int = field( default=1, metadata={ @@ -60,20 +38,8 @@ class Gr00tClosedloopPolicyArgs: }, ) - # from_dict() is not needed - can use Gr00tClosedloopPolicyArgs(**dict) directly - # or use Gr00tClosedloopPolicy.from_dict() which is inherited from PolicyBase - @classmethod def from_cli_args(cls, args: argparse.Namespace) -> "Gr00tClosedloopPolicyArgs": - """ - Create configuration from parsed CLI arguments. - - Args: - args: Parsed command line arguments - - Returns: - Gr00tClosedloopPolicyArgs instance - """ return cls( policy_config_yaml_path=args.policy_config_yaml_path, policy_device=args.policy_device, @@ -96,48 +62,35 @@ def __init__(self, config: Gr00tClosedloopPolicyArgs): """ super().__init__(config) self.policy_config = create_config_from_yaml(config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig) - self.policy = self.load_policy() - # determine rollout how many action prediction per observation - self.action_chunk_length = self.policy_config.action_chunk_length self.num_envs = config.num_envs self.device = config.policy_device self.task_mode = TaskMode(self.policy_config.task_mode_name) - self.policy = None - - if self.is_remote: - if not self.remote_client.ping(): - cfg = self.remote_config - raise RuntimeError( - f"Failed to connect to remote policy server at " - f"{cfg.host}:{cfg.port}." - ) - else: - self.policy = self.load_local_policy() - - self.policy_joints_config = self.load_policy_joints_config(self.policy_config.policy_joints_config_path) - self.robot_action_joints_config = self.load_sim_action_joints_config( - self.policy_config.action_joints_config_path - ) - self.robot_state_joints_config = self.load_sim_state_joints_config(self.policy_config.state_joints_config_path) + # Joint configurations + ( + self.policy_joints_config, + self.robot_action_joints_config, + self.robot_state_joints_config, + ) = load_gr00t_joint_configs(self.policy_config) - self.action_dim = len(self.robot_action_joints_config) - if self.task_mode == TaskMode.G1_LOCOMANIPULATION: - self.action_dim += NUM_NAVIGATE_CMD + NUM_BASE_HEIGHT_CMD + NUM_TORSO_ORIENTATION_RPY_CMD + self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) + self.action_chunk_length = self.policy_config.action_chunk_length + self.policy: Gr00tPolicy = load_gr00t_policy_from_config(self.policy_config) + + # Chunking state (local-only logic) self.current_action_chunk = torch.zeros( - (config.num_envs, self.policy_config.action_horizon, self.action_dim), - dtype=torch.float, - device=config.policy_device, + (self.num_envs, self.policy_config.action_horizon, self.action_dim), + dtype=torch.float32, + device=self.device, ) - # Use a bool list toindicate that the action chunk is not yet computed for each env - # True means the action chunk is not yet computed, False means the action chunk is valid - self.env_requires_new_action_chunk = torch.ones(config.num_envs, dtype=torch.bool, device=config.policy_device) - - self.current_action_index = torch.zeros(config.num_envs, dtype=torch.int32, device=config.policy_device) + # True means the action chunk is not yet computed, False means the action chunk is valid. + self.env_requires_new_action_chunk = torch.ones(self.num_envs, dtype=torch.bool, device=self.device) + # Per-env index into the current action chunk + self.current_action_index = torch.zeros(self.num_envs, dtype=torch.int32, device=self.device) - # task description of task being evaluated. It will be set by the task being evaluated. + # Task description of the task being evaluated. It will be set externally. self.task_description: str | None = None @staticmethod @@ -176,52 +129,6 @@ def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentPars ) return parser - def load_policy_joints_config(self, policy_config_path: Path) -> dict[str, Any]: - """Load the GR00T policy joint config from the data config.""" - return load_robot_joints_config_from_yaml(policy_config_path) - - def load_sim_state_joints_config(self, state_config_path: Path) -> dict[str, Any]: - """Load the simulation state joint config from the data config.""" - return load_robot_joints_config_from_yaml(state_config_path) - - def load_sim_action_joints_config(self, action_config_path: Path) -> dict[str, Any]: - """Load the simulation action joint config from the data config.""" - return load_robot_joints_config_from_yaml(action_config_path) - - def load_local_policy(self): - try: - from gr00t.experiment.data_config import DATA_CONFIG_MAP, load_data_config - from gr00t.model.policy import Gr00tPolicy - except ImportError as exc: - raise ImportError( - "GR00T policy dependencies are not installed. " - "Install gr00t packages or use policy_deployment=PolicyDeployment.REMOTE." - ) from exc - - assert Path(self.policy_config.model_path).exists(), ( - f"Model path {self.policy_config.model_path} does not exist" - ) - - if self.policy_config.data_config in DATA_CONFIG_MAP: - data_config = DATA_CONFIG_MAP[self.policy_config.data_config] - elif self.policy_config.data_config == "unitree_g1_sim_wbc": - self.data_config = load_data_config( - "isaaclab_arena_gr00t.embodiments.g1.g1_sim_wbc_data_config:UnitreeG1SimWBCDataConfig" - ) - else: - raise ValueError(f"Invalid data config: {self.policy_config.data_config}") - - modality_config = data_config.modality_config() - modality_transform = data_config.transform() - return Gr00tPolicy( - model_path=self.policy_config.model_path, - modality_config=modality_config, - modality_transform=modality_transform, - embodiment_tag=self.policy_config.embodiment_tag, - denoising_steps=self.policy_config.denoising_steps, - device=self.policy_config.policy_device, - ) - def set_task_description(self, task_description: str | None) -> str: """Set the language instruction of the task being evaluated.""" if task_description is None: @@ -230,41 +137,34 @@ def set_task_description(self, task_description: str | None) -> str: return self.task_description def get_observations(self, observation: dict[str, Any], camera_name: str = "robot_head_cam_rgb") -> dict[str, Any]: - rgb = observation["camera_obs"][camera_name] - # gr00t uses numpy arrays - rgb = rgb.cpu().numpy() - # Apply preprocessing to rgb if size is not the same as the target size - if rgb.shape[1:3] != self.policy_config.target_image_size[:2]: - rgb = resize_frames_with_padding( - rgb, target_image_size=self.policy_config.target_image_size, bgr_conversion=False, pad_img=True - ) - # GR00T uses np arrays, needs to copy torch tensor from gpu to cpu before conversion - joint_pos_sim = observation["policy"]["robot_joint_pos"].cpu() - joint_pos_state_sim = JointsAbsPosition(joint_pos_sim, self.robot_state_joints_config) - # Retrieve joint positions as proprioceptive states and remap to policy joint orders - joint_pos_state_policy = remap_sim_joints_to_policy_joints(joint_pos_state_sim, self.policy_joints_config) - - # Pack inputs to dictionary and run the inference + """Adapter: torch env observation -> numpy GR00T policy inputs. + + The core GR00T logic in gr00t_core.py works on numpy. This method: + - extracts torch tensors from the environment observation, + - moves them to CPU and converts to numpy, + - uses the shared numpy-based preprocessing, + - returns a numpy dict suitable for Gr00tPolicy. + """ assert self.task_description is not None, "Task description is not set" - policy_observations = { - # TODO(xinejiayao, 2025-12-10): when multi-task with parallel envs feature is enabled, we need to pass in a list of task descriptions. - "annotation.human.task_description": [self.task_description] * self.num_envs, - "video.ego_view": rgb.reshape( - self.num_envs, - 1, - self.policy_config.target_image_size[0], - self.policy_config.target_image_size[1], - self.policy_config.target_image_size[2], - ), - "state.left_arm": joint_pos_state_policy["left_arm"].reshape(self.num_envs, 1, -1), - "state.right_arm": joint_pos_state_policy["right_arm"].reshape(self.num_envs, 1, -1), - "state.left_hand": joint_pos_state_policy["left_hand"].reshape(self.num_envs, 1, -1), - "state.right_hand": joint_pos_state_policy["right_hand"].reshape(self.num_envs, 1, -1), - } - # NOTE(xinjieyao, 2025-10-07): waist is not used in GR1 tabletop manipulation - if self.task_mode == TaskMode.G1_LOCOMANIPULATION: - policy_observations["state.waist"] = joint_pos_state_policy["waist"].reshape(self.num_envs, 1, -1) - return policy_observations + + # Extract torch tensors from observation + rgb_t: torch.Tensor = observation["camera_obs"][camera_name] + joint_pos_sim_t: torch.Tensor = observation["policy"]["robot_joint_pos"] + + # Convert to numpy for core logic + rgb_np: np.ndarray = rgb_t.detach().cpu().numpy() + joint_pos_sim_np: np.ndarray = joint_pos_sim_t.detach().cpu().numpy() + + # Use shared numpy-based preprocessing + policy_obs_np = build_gr00t_policy_inputs_np( + rgb_np=rgb_np, + joint_pos_sim_np=joint_pos_sim_np, + task_description=self.task_description, + policy_config=self.policy_config, + robot_state_joints_config=self.robot_state_joints_config, + policy_joints_config=self.policy_joints_config, + ) + return policy_obs_np def get_action(self, env: gym.Env, observation: dict[str, Any]) -> torch.Tensor: """Get the the immediate next action from the current action chunk. @@ -318,42 +218,15 @@ def get_action_chunk(self, observation: dict[str, Any], camera_name: str = "robo Shape: (num_envs, action_chunk_length, self.action_dim) """ policy_observations = self.get_observations(observation, camera_name) - - if not self.is_remote: - if self._policy is None: - raise RuntimeError("Local GR00T policy is not initialized.") - robot_action_policy = self.policy.get_action(policy_observations) - else: - robot_action_policy = self.remote_client.get_action( - observation=policy_observations, - ) - - robot_action_sim = remap_policy_joints_to_sim_joints( - robot_action_policy, self.policy_joints_config, self.robot_action_joints_config, self.device + robot_action_policy = self.policy.get_action(policy_observations) + action_tensor = build_gr00t_action_tensor( + robot_action_policy=robot_action_policy, + task_mode=self.task_mode, + policy_joints_config=self.policy_joints_config, + robot_action_joints_config=self.robot_action_joints_config, + device=self.device, ) - if self.task_mode == TaskMode.G1_LOCOMANIPULATION: - # NOTE(xinjieyao, 2025-09-29): GR00T output dim=32, does not fit the entire action space, - # including torso_orientation_rpy_command. Manually set it to 0. - torso_orientation_rpy_command = torch.zeros( - robot_action_policy["action.navigate_command"].shape, dtype=torch.float, device=self.device - ) - action_tensor = torch.cat( - [ - robot_action_sim.get_joints_pos(), - torch.tensor(robot_action_policy["action.navigate_command"], dtype=torch.float, device=self.device), - torch.tensor( - robot_action_policy["action.base_height_command"], dtype=torch.float, device=self.device - ), - torso_orientation_rpy_command, - ], - axis=2, - ) - elif self.task_mode == TaskMode.GR1_TABLETOP_MANIPULATION: - action_tensor = robot_action_sim.get_joints_pos() - else: - raise ValueError(f"Unsupported task mode: {self.task_mode}") - assert action_tensor.shape[0] == self.num_envs and action_tensor.shape[1] >= self.action_chunk_length return action_tensor diff --git a/isaaclab_arena_gr00t/policy/gr00t_core.py b/isaaclab_arena_gr00t/policy/gr00t_core.py new file mode 100644 index 00000000..4417db4a --- /dev/null +++ b/isaaclab_arena_gr00t/policy/gr00t_core.py @@ -0,0 +1,195 @@ +# Copyright (c) 2025-2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import numpy as np +import torch +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from gr00t.experiment.data_config import DATA_CONFIG_MAP, load_data_config +from gr00t.model.policy import Gr00tPolicy + +from isaaclab_arena_g1.g1_whole_body_controller.wbc_policy.policy.policy_constants import ( + NUM_BASE_HEIGHT_CMD, + NUM_NAVIGATE_CMD, + NUM_TORSO_ORIENTATION_RPY_CMD, +) +from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode +from isaaclab_arena_gr00t.utils.image_conversion import resize_frames_with_padding +from isaaclab_arena_gr00t.utils.io_utils import load_robot_joints_config_from_yaml +from isaaclab_arena_gr00t.utils.joints_conversion import ( + remap_policy_joints_to_sim_joints, + remap_sim_joints_to_policy_joints, +) +from isaaclab_arena_gr00t.utils.robot_joints import JointsAbsPosition + + +@dataclass +class Gr00tBasePolicyArgs: + """Base configuration for GR00T policies (shared by local and remote).""" + + policy_config_yaml_path: str = field( + metadata={ + "help": "Path to the Gr00t closedloop policy config YAML file", + "required": True, + } + ) + + policy_device: str = field( + default="cuda", + metadata={ + "help": "Device to use for the policy-related operations.", + }, + ) + + +# --------------------------------------------------------------------------- # +# Config / model helpers (backend-agnostic) +# --------------------------------------------------------------------------- # + + +def load_gr00t_policy_from_config(policy_config: Gr00tClosedloopPolicyConfig) -> Gr00tPolicy: + """Load a Gr00tPolicy from the closed-loop config.""" + if policy_config.data_config in DATA_CONFIG_MAP: + data_config = DATA_CONFIG_MAP[policy_config.data_config] + elif policy_config.data_config == "unitree_g1_sim_wbc": + data_config = load_data_config( + "isaaclab_arena_gr00t.embodiments.g1.g1_sim_wbc_data_config:UnitreeG1SimWBCDataConfig" + ) + else: + raise ValueError(f"Invalid data config: {policy_config.data_config}") + + modality_config = data_config.modality_config() + modality_transform = data_config.transform() + + model_path = Path(policy_config.model_path) + if not model_path.exists(): + raise FileNotFoundError(f"Model path does not exist: {model_path}") + + return Gr00tPolicy( + model_path=str(model_path), + modality_config=modality_config, + modality_transform=modality_transform, + embodiment_tag=policy_config.embodiment_tag, + denoising_steps=policy_config.denoising_steps, + device=policy_config.policy_device, + ) + + +def load_gr00t_joint_configs( + policy_config: Gr00tClosedloopPolicyConfig, +) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]: + """Load policy / action / state joint configs.""" + policy_joints_config = load_robot_joints_config_from_yaml(policy_config.policy_joints_config_path) + robot_action_joints_config = load_robot_joints_config_from_yaml(policy_config.action_joints_config_path) + robot_state_joints_config = load_robot_joints_config_from_yaml(policy_config.state_joints_config_path) + return policy_joints_config, robot_action_joints_config, robot_state_joints_config + + +def compute_action_dim(task_mode: TaskMode, robot_action_joints_config: dict[str, Any]) -> int: + """Compute action dimension given task_mode and action joints configuration.""" + action_dim = len(robot_action_joints_config) + if task_mode == TaskMode.G1_LOCOMANIPULATION: + action_dim += NUM_NAVIGATE_CMD + NUM_BASE_HEIGHT_CMD + NUM_TORSO_ORIENTATION_RPY_CMD + return action_dim + + +# --------------------------------------------------------------------------- # +# Core SSOT logic (numpy-based) +# --------------------------------------------------------------------------- # + + +def build_gr00t_policy_inputs_np( + rgb_np: np.ndarray, # (N, H, W, C) + joint_pos_sim_np: np.ndarray, # (N, num_joints) + task_description: str, + policy_config: Gr00tClosedloopPolicyConfig, + robot_state_joints_config: dict[str, Any], + policy_joints_config: dict[str, Any], +) -> dict[str, Any]: + """Convert numpy observations to numpy GR00T policy inputs.""" + num_envs = rgb_np.shape[0] + + # Resize RGB frames if needed + if rgb_np.shape[1:3] != tuple(policy_config.target_image_size[:2]): + rgb_np = resize_frames_with_padding( + rgb_np, + target_image_size=policy_config.target_image_size, + bgr_conversion=False, + pad_img=True, + ) + + # Use existing JointsAbsPosition / remap helpers by temporarily going through torch + joint_pos_state_sim = JointsAbsPosition(joint_pos_sim_np, robot_state_joints_config) + joint_pos_state_policy = remap_sim_joints_to_policy_joints(joint_pos_state_sim, policy_joints_config) + + left_arm = joint_pos_state_policy["left_arm"].reshape(num_envs, 1, -1) + right_arm = joint_pos_state_policy["right_arm"].reshape(num_envs, 1, -1) + left_hand = joint_pos_state_policy["left_hand"].reshape(num_envs, 1, -1) + right_hand = joint_pos_state_policy["right_hand"].reshape(num_envs, 1, -1) + + policy_inputs: dict[str, Any] = { + # TODO(xinejiayao, 2025-12-10): when multi-task with parallel envs feature is enabled, we need to pass in a list of task descriptions. + "annotation.human.task_description": [task_description] * num_envs, + "video.ego_view": rgb_np.reshape( + num_envs, + 1, + policy_config.target_image_size[0], + policy_config.target_image_size[1], + policy_config.target_image_size[2], + ), + "state.left_arm": left_arm, + "state.right_arm": right_arm, + "state.left_hand": left_hand, + "state.right_hand": right_hand, + } + # NOTE(xinjieyao, 2025-10-07): waist is not used in GR1 tabletop manipulation + if TaskMode(policy_config.task_mode_name) == TaskMode.G1_LOCOMANIPULATION: + waist = joint_pos_state_policy["waist"].reshape(num_envs, 1, -1) + policy_inputs["state.waist"] = waist + + return policy_inputs + + +def build_gr00t_action_tensor( + robot_action_policy: dict[str, Any], + task_mode: TaskMode, + policy_joints_config: dict[str, Any], + robot_action_joints_config: dict[str, Any], + device: str | torch.device, +) -> np.ndarray: + """Convert numpy GR00T outputs to numpy action tensor (N, horizon, action_dim).""" + + robot_action_sim = remap_policy_joints_to_sim_joints( + robot_action_policy, + policy_joints_config, + robot_action_joints_config, + device, + ) + + if task_mode == TaskMode.G1_LOCOMANIPULATION: + # NOTE(xinjieyao, 2025-09-29): GR00T output dim=32, does not fit the entire action space, + # including torso_orientation_rpy_command. Manually set it to 0. + torso_orientation_rpy_command = torch.zeros( + robot_action_policy["action.navigate_command"].shape, dtype=torch.float, device=device + ) + action_tensor = torch.cat( + [ + robot_action_sim.get_joints_pos(), + torch.tensor(robot_action_policy["action.navigate_command"], dtype=torch.float, device=device), + torch.tensor(robot_action_policy["action.base_height_command"], dtype=torch.float, device=device), + torso_orientation_rpy_command, + ], + axis=2, + ) + elif task_mode == TaskMode.GR1_TABLETOP_MANIPULATION: + action_tensor = robot_action_sim.get_joints_pos() + else: + raise ValueError(f"Unsupported task mode: {task_mode}") + + return action_tensor diff --git a/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py b/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py new file mode 100644 index 00000000..1c99724d --- /dev/null +++ b/isaaclab_arena_gr00t/policy/gr00t_remote_policy.py @@ -0,0 +1,202 @@ +# Copyright (c) 2026, The Isaac Lab Arena Project Developers (https://github.com/isaac-sim/IsaacLab-Arena/blob/main/CONTRIBUTORS.md). +# All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import numpy as np +from dataclasses import dataclass +from typing import Any + +from gr00t.model.policy import Gr00tPolicy + +from isaaclab_arena.remote_policy.action_protocol import ChunkingActionProtocol +from isaaclab_arena.remote_policy.server_side_policy import ServerSidePolicy +from isaaclab_arena_gr00t.policy.config.gr00t_closedloop_policy_config import Gr00tClosedloopPolicyConfig, TaskMode +from isaaclab_arena_gr00t.policy.gr00t_core import ( + Gr00tBasePolicyArgs, + build_gr00t_action_tensor, + build_gr00t_policy_inputs_np, + compute_action_dim, + load_gr00t_joint_configs, + load_gr00t_policy_from_config, +) +from isaaclab_arena_gr00t.utils.io_utils import create_config_from_yaml + + +@dataclass +class Gr00tRemotePolicyArgs(Gr00tBasePolicyArgs): + """Configuration for Gr00tRemoteServerSidePolicy. + + Reuses policy_config_yaml_path and policy_device from the base. + """ + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> Gr00tRemotePolicyArgs: + return cls( + policy_config_yaml_path=args.policy_config_yaml_path, + policy_device=args.policy_device, + ) + + +class Gr00tRemoteServerSidePolicy(ServerSidePolicy): + """Server-side wrapper around Gr00tPolicy.""" + + config_class = Gr00tRemotePolicyArgs + + def __init__(self, config: Gr00tRemotePolicyArgs) -> None: + super().__init__(config) + + print(f"[Gr00tRemoteServerSidePolicy] loading config from: {config.policy_config_yaml_path}") + self.policy_config = create_config_from_yaml(config.policy_config_yaml_path, Gr00tClosedloopPolicyConfig) + print( + "[Gr00tRemoteServerSidePolicy] config:\n" + f" model_path = {self.policy_config.model_path}\n" + f" embodiment_tag = {self.policy_config.embodiment_tag}\n" + f" task_mode_name = {self.policy_config.task_mode_name}\n" + f" data_config = {self.policy_config.data_config}\n" + f" action_horizon = {self.policy_config.action_horizon}\n" + f" action_chunk_len = {self.policy_config.action_chunk_length}\n" + f" pov_cam_name_sim = {self.policy_config.pov_cam_name_sim}\n" + f" policy_device = {self.policy_config.policy_device}\n" + ) + + self.device = config.policy_device + self.task_mode = TaskMode(self.policy_config.task_mode_name) + + # Joint configurations + ( + self.policy_joints_config, + self.robot_action_joints_config, + self.robot_state_joints_config, + ) = load_gr00t_joint_configs(self.policy_config) + + self.action_dim = compute_action_dim(self.task_mode, self.robot_action_joints_config) + self.action_chunk_length = self.policy_config.action_chunk_length + + self.required_observation_keys: list[str] = [ + f"camera_obs.{self.policy_config.pov_cam_name_sim}", + "policy.robot_joint_pos", + ] + + # Underlying GR00T policy + self.policy: Gr00tPolicy = load_gr00t_policy_from_config(self.policy_config) + print("[Gr00tRemoteServerSidePolicy] Gr00tPolicy loaded successfully") + + # Task description will be set via set_task_description RPC + self._task_description: str | None = None + + # ---------------------- CLI helpers (server-side) ------------------- + + @staticmethod + def add_args_to_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Add server-side GR00T remote policy arguments.""" + group = parser.add_argument_group( + "Gr00t Remote Server Policy", + "Arguments for GR00T-based server-side remote policy.", + ) + group.add_argument( + "--policy_config_yaml_path", + type=str, + required=True, + help="Path to the GR00T closedloop policy config YAML file.", + ) + group.add_argument( + "--policy_device", + type=str, + default="cuda", + help="Device to use for server-side GR00T inference (default: cuda).", + ) + return parser + + @staticmethod + def from_args(args: argparse.Namespace) -> Gr00tRemoteServerSidePolicy: + """Create a Gr00tRemoteServerSidePolicy from CLI arguments.""" + config = Gr00tRemotePolicyArgs.from_cli_args(args) + return Gr00tRemoteServerSidePolicy(config) + + # ------------ protocol ------------ + + def _build_protocol(self) -> ChunkingActionProtocol: + proto = ChunkingActionProtocol( + action_dim=self.action_dim, + observation_keys=self.required_observation_keys, + action_chunk_length=self.action_chunk_length, + ) + print(f"[Gr00tRemoteServerSidePolicy] protocol mode = {proto.mode.value}") + return proto + + # ------------------------------------------------------------------ # + # Helper methods + # ------------------------------------------------------------------ # + + def _build_policy_observations( + self, + observation: dict[str, Any], + camera_name: str, + ) -> dict[str, Any]: + """Convert packed numpy observation into numpy GR00T policy inputs. + + The client sends a flat dict of numpy arrays. + ServerSidePolicy.unpack_observation reconstructs the nested structure: + - observation["camera_obs"][camera_name] : (N, H, W, C) numpy + - observation["policy"]["robot_joint_pos"]: (N, num_joints) numpy + """ + nested_obs = self.unpack_observation(observation) + rgb_np: np.ndarray = nested_obs["camera_obs"][camera_name] + joint_pos_sim_np: np.ndarray = nested_obs["policy"]["robot_joint_pos"] + + assert self._task_description is not None, "Task description is not set" + + policy_obs_np = build_gr00t_policy_inputs_np( + rgb_np=rgb_np, + joint_pos_sim_np=joint_pos_sim_np, + task_description=self._task_description, + policy_config=self.policy_config, + robot_state_joints_config=self.robot_state_joints_config, + policy_joints_config=self.policy_joints_config, + ) + return policy_obs_np + + # ------------------------------------------------------------------ # + # ServerSidePolicy interface + # ------------------------------------------------------------------ # + + def set_task_description(self, task_description: str | None) -> dict[str, Any]: + if task_description is None: + task_description = self.policy_config.language_instruction + self._task_description = task_description + return {"status": "ok"} + + def get_action( + self, observation: dict[str, Any], options: dict[str, Any] | None = None + ) -> tuple[dict[str, Any], dict[str, Any]]: + camera_name = self.policy_config.pov_cam_name_sim + + # 1) Shared numpy-based preprocessing + policy_observations = self._build_policy_observations(observation, camera_name) + + # 2) GR00T forward pass + robot_action_policy = self.policy.get_action(policy_observations) + + # 3) postprocessing + action_tensor = build_gr00t_action_tensor( + robot_action_policy=robot_action_policy, + task_mode=self.task_mode, + policy_joints_config=self.policy_joints_config, + robot_action_joints_config=self.robot_action_joints_config, + device=self.device, + ) + + assert action_tensor.shape[1] >= self.action_chunk_length + + action_chunk = action_tensor[:, : self.action_chunk_length, :].cpu().numpy() + action: dict[str, Any] = {"action": action_chunk} + info: dict[str, Any] = {} + return action, info + + def reset(self, env_ids: list[int] | None = None, reset_options: dict[str, Any] | None = None) -> dict[str, Any]: + # GR00T policy is stateless for this closed-loop usage; nothing to reset + return {"status": "reset_success"}