Skip to content

Commit 21b9126

Browse files
author
eraykeskinmac
committed
feat: Add GR00T N1.6 (Isaac-GR00T N1D6) compatibility
Add support for Isaac-GR00T N1.6 observation and action formats while maintaining full backward compatibility with N1.5. - Add versioned data configs with auto-detection and fallback - Support N1.6 flat key format (video.X, state.X) with (B,T,H,W,C) video shape - Handle N1.6 action response with batch dimension stripping - Add N1.6 client protocol wrapping for Gr00tSimPolicyWrapper - Support colon syntax for explicit version selection (e.g. "libero:n1d6") - Add 29 new tests covering config, observation, action, and version handling - Zero regressions on existing 11 tests
1 parent 02f2389 commit 21b9126

4 files changed

Lines changed: 683 additions & 90 deletions

File tree

strands_robots_sim/policies/groot/__init__.py

Lines changed: 201 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#!/usr/bin/env python3
22
"""GR00T Policy — natural language robot control via GR00T inference servers.
33
4+
Supports both GR00T N1.5 and N1.6 observation/action formats.
5+
46
SPDX-License-Identifier: Apache-2.0
57
"""
68

@@ -12,33 +14,47 @@
1214

1315
from .. import Policy
1416
from .client import GR00TClient
15-
from .data_config import load_data_config
17+
from .data_config import LIBERO_STATE_TO_N1D6, load_data_config
1618

1719
logger = logging.getLogger(__name__)
1820

1921

2022
class Gr00tPolicy(Policy):
21-
"""GR00T policy: connects to a GR00T inference server via ZMQ."""
23+
"""GR00T policy: connects to a GR00T inference server via ZMQ.
24+
25+
Supports both N1.5 (prefixed keys: video.X, state, action.X)
26+
and N1.6 (direct keys: X, individual state components) formats.
27+
"""
2228

2329
def __init__(self, data_config: Union[str, dict], host: str = "localhost", port: int = 5555, **kwargs):
2430
"""Initialize GR00T policy.
2531
2632
Args:
27-
data_config: Config name (e.g. "libero") or dict with video/state/action/language keys
33+
data_config: Config name (e.g. "libero") or dict with video/state/action/language keys.
34+
Pass "libero:n1d6" or set groot_version="n1d6" in kwargs to force N1.6 format.
2835
host: Inference service host
2936
port: Inference service port
3037
"""
31-
self.config = load_data_config(data_config)
38+
groot_version = kwargs.pop("groot_version", "auto")
39+
40+
# Support "config_name:version" syntax (e.g. "libero:n1d6")
41+
if isinstance(data_config, str) and ":" in data_config:
42+
parts = data_config.split(":", 1)
43+
data_config = parts[0]
44+
groot_version = parts[1]
45+
46+
self.config = load_data_config(data_config, groot_version=groot_version)
3247
self.data_config_name = data_config if isinstance(data_config, str) else "custom"
33-
self.client = GR00TClient(host=host, port=port)
48+
self.groot_version = self.config.get("groot_version", "n1d5")
49+
self.client = GR00TClient(host=host, port=port, groot_version=self.groot_version)
3450

3551
self.camera_keys = self.config["video"]
3652
self.state_keys = self.config["state"]
3753
self.action_keys = self.config["action"]
3854
self.language_keys = self.config["language"]
3955
self.robot_state_keys = []
4056

41-
logger.info(f"🧠 GR00T Policy: {self.data_config_name} @ {host}:{port}")
57+
logger.info(f"🧠 GR00T Policy: {self.data_config_name} @ {host}:{port} (version: {self.groot_version})")
4258

4359
@property
4460
def provider_name(self) -> str:
@@ -50,12 +66,66 @@ def set_robot_state_keys(self, robot_state_keys: List[str]) -> None:
5066
async def get_actions(self, observation_dict: Dict[str, Any], instruction: str, **kwargs) -> List[Dict[str, Any]]:
5167
"""Get actions from GR00T policy server.
5268
53-
Args:
54-
observation_dict: Robot observations (cameras + state)
55-
instruction: Natural language instruction
69+
Automatically formats observations for N1.5 or N1.6 based on config.
70+
"""
71+
if self.groot_version == "n1d6":
72+
obs = self._build_n1d6_observation(observation_dict, instruction)
73+
else:
74+
obs = self._build_n1d5_observation(observation_dict, instruction)
75+
76+
try:
77+
action_chunk = self.client.get_action(obs)
78+
except Exception as e:
79+
logger.error(f"GR00T inference failed: {e}")
80+
action_chunk = self._create_fallback_actions()
5681

57-
Returns:
58-
List of action dicts for robot execution
82+
return self._to_robot_actions(action_chunk)
83+
84+
def _build_n1d6_observation(self, observation_dict: Dict[str, Any], instruction: str) -> dict:
85+
"""Build observation dict for GR00T N1.6 format.
86+
87+
When the server uses Gr00tSimPolicyWrapper (--use-sim-policy-wrapper),
88+
it expects flat keys with prefixes: video.image, state.x, etc.
89+
The wrapper then converts these to the nested format internally.
90+
91+
Flat format (for SimPolicyWrapper):
92+
{
93+
"video.image": array(B, T, H, W, C),
94+
"video.wrist_image": array(B, T, H, W, C),
95+
"state.x": array(B, T, 1),
96+
...
97+
"annotation.human.action.task_description": ("instruction",),
98+
}
99+
"""
100+
obs = {}
101+
102+
# Camera observations — flat keys with "video." prefix, shape (B, T, H, W, C)
103+
for vkey in self.camera_keys:
104+
cam = self._find_camera(vkey, observation_dict)
105+
flat_key = f"video.{vkey}"
106+
if cam and cam in observation_dict:
107+
image = self._resize_image(observation_dict[cam], target_size=(256, 256))
108+
obs[flat_key] = image.reshape(1, 1, *image.shape).astype(np.uint8)
109+
else:
110+
obs[flat_key] = np.zeros((1, 1, 256, 256, 3), dtype=np.uint8)
111+
112+
# State observations — flat keys with "state." prefix
113+
if "libero" in self.data_config_name.lower():
114+
self._map_libero_state_n1d6(obs, observation_dict)
115+
else:
116+
for skey in self.state_keys:
117+
obs[f"state.{skey}"] = np.array([[[0.0]]], dtype=np.float32)
118+
119+
# Language instruction — as tuple for batch
120+
if self.language_keys:
121+
obs[self.language_keys[0]] = (instruction,)
122+
123+
return obs
124+
125+
def _build_n1d5_observation(self, observation_dict: Dict[str, Any], instruction: str) -> dict:
126+
"""Build observation dict for GR00T N1.5 format (legacy).
127+
128+
N1.5 uses prefixed keys: video.X, state, action.X
59129
"""
60130
obs = {}
61131

@@ -82,34 +152,70 @@ async def get_actions(self, observation_dict: Dict[str, Any], instruction: str,
82152
robot_state_parts.extend(np.atleast_1d(value).flatten())
83153
else:
84154
robot_state_parts.append(float(value))
85-
robot_state = np.array(robot_state_parts, dtype=np.float64)
155+
robot_state = np.array(robot_state_parts, dtype=np.float32)
86156

87157
if "libero" in self.data_config_name.lower():
88-
self._map_libero_state(obs, observation_dict)
158+
self._map_libero_state_n1d5(obs, observation_dict)
89159
else:
90160
self._map_state(obs, robot_state)
91161

92162
# Language instruction
93163
if self.language_keys:
94164
obs[self.language_keys[0]] = instruction
95165

96-
# Batch dimension
166+
# Batch dimension for N1.5
97167
for k in obs:
98168
if isinstance(obs[k], np.ndarray) and k.startswith("video."):
99169
obs[k] = np.expand_dims(obs[k], axis=0)
100170
elif isinstance(obs[k], str):
101171
obs[k] = [obs[k]]
102172

103-
try:
104-
action_chunk = self.client.get_action(obs)
105-
except Exception as e:
106-
logger.error(f"GR00T inference failed: {e}")
107-
action_chunk = self._create_fallback_actions()
173+
return obs
108174

109-
return self._to_robot_actions(action_chunk)
175+
def _map_libero_state_n1d6(self, obs: dict, observation_dict: dict):
176+
"""Map Libero observation to N1.6 flat state keys (state.x, state.y, etc.).
177+
178+
State values have shape (B, T, dim) where B=1, T=1.
179+
Uses "state." prefix for SimPolicyWrapper compatibility.
180+
"""
181+
if "robot0_eef_pos" in observation_dict and "robot0_eef_quat" in observation_dict:
182+
xyz = observation_dict["robot0_eef_pos"]
183+
quat = observation_dict["robot0_eef_quat"]
184+
gripper = observation_dict.get("robot0_gripper_qpos", np.array([0.0, 0.0]))
185+
rpy = self._quat2axisangle(quat)
186+
obs["state.x"] = np.array([[[xyz[0]]]], dtype=np.float32)
187+
obs["state.y"] = np.array([[[xyz[1]]]], dtype=np.float32)
188+
obs["state.z"] = np.array([[[xyz[2]]]], dtype=np.float32)
189+
obs["state.roll"] = np.array([[[rpy[0]]]], dtype=np.float32)
190+
obs["state.pitch"] = np.array([[[rpy[1]]]], dtype=np.float32)
191+
obs["state.yaw"] = np.array([[[rpy[2]]]], dtype=np.float32)
192+
obs["state.gripper"] = np.asarray(gripper, dtype=np.float32).reshape(1, 1, -1)
193+
else:
194+
for key in ("x", "y", "z", "roll", "pitch", "yaw"):
195+
obs[f"state.{key}"] = np.array([[[0.0]]], dtype=np.float32)
196+
obs["state.gripper"] = np.array([[[0.0]]], dtype=np.float32)
197+
198+
def _map_libero_state_n1d5(self, obs: dict, observation_dict: dict):
199+
"""Map Libero end-effector pose to N1.5 state format (state.x, state.y, etc.)."""
200+
if "robot0_eef_pos" in observation_dict and "robot0_eef_quat" in observation_dict:
201+
xyz = observation_dict["robot0_eef_pos"]
202+
quat = observation_dict["robot0_eef_quat"]
203+
gripper = observation_dict.get("robot0_gripper_qpos", np.array([0.0, 0.0]))
204+
rpy = self._quat2axisangle(quat)
205+
obs["state.x"] = np.array([[xyz[0]]])
206+
obs["state.y"] = np.array([[xyz[1]]])
207+
obs["state.z"] = np.array([[xyz[2]]])
208+
obs["state.roll"] = np.array([[rpy[0]]])
209+
obs["state.pitch"] = np.array([[rpy[1]]])
210+
obs["state.yaw"] = np.array([[rpy[2]]])
211+
obs["state.gripper"] = np.expand_dims(gripper, axis=0)
212+
else:
213+
for key in ("x", "y", "z", "roll", "pitch", "yaw"):
214+
obs[f"state.{key}"] = np.array([[0.0]], dtype=np.float32)
215+
obs["state.gripper"] = np.array([[0.0]], dtype=np.float32)
110216

111217
def _find_camera(self, video_key: str, obs: dict) -> str:
112-
"""Map GR00T video key to available camera key."""
218+
"""Map GR00T video key to available camera key in observation."""
113219
if video_key in obs:
114220
return video_key
115221

@@ -189,27 +295,8 @@ def _resize_image(self, image: np.ndarray, target_size: tuple = (256, 256)) -> n
189295
except Exception:
190296
return image
191297

192-
def _map_libero_state(self, obs: dict, observation_dict: dict):
193-
"""Map Libero end-effector pose to GR00T state format."""
194-
if "robot0_eef_pos" in observation_dict and "robot0_eef_quat" in observation_dict:
195-
xyz = observation_dict["robot0_eef_pos"]
196-
quat = observation_dict["robot0_eef_quat"]
197-
gripper = observation_dict.get("robot0_gripper_qpos", np.array([0.0, 0.0]))
198-
rpy = self._quat2axisangle(quat)
199-
obs["state.x"] = np.array([[xyz[0]]])
200-
obs["state.y"] = np.array([[xyz[1]]])
201-
obs["state.z"] = np.array([[xyz[2]]])
202-
obs["state.roll"] = np.array([[rpy[0]]])
203-
obs["state.pitch"] = np.array([[rpy[1]]])
204-
obs["state.yaw"] = np.array([[rpy[2]]])
205-
obs["state.gripper"] = np.expand_dims(gripper, axis=0)
206-
else:
207-
for key in ("x", "y", "z", "roll", "pitch", "yaw"):
208-
obs[f"state.{key}"] = np.array([[0.0]], dtype=np.float64)
209-
obs["state.gripper"] = np.array([[0.0]], dtype=np.float64)
210-
211298
def _map_state(self, obs: dict, state: np.ndarray):
212-
"""Map robot state array to GR00T state keys."""
299+
"""Map robot state array to GR00T state keys (N1.5 format)."""
213300
name = self.data_config_name.lower()
214301
if "so100" in name and len(state) >= 6:
215302
obs["state.single_arm"] = state[:5].astype(np.float64)
@@ -229,17 +316,16 @@ def _map_state(self, obs: dict, state: np.ndarray):
229316
obs[self.state_keys[0]] = state.astype(np.float64)
230317

231318
def _to_robot_actions(self, chunk: dict) -> List[Dict[str, Any]]:
232-
"""Convert GR00T action chunk to list of robot action dicts."""
233-
act_key = None
234-
for k in self.action_keys:
235-
base = k.replace("action.", "") if k.startswith("action.") else k
236-
full = f"action.{base}"
237-
if full in chunk:
238-
act_key = full
239-
break
240-
if not act_key:
241-
act_keys = [k for k in chunk if k.startswith("action.")]
242-
act_key = act_keys[0] if act_keys else None
319+
"""Convert GR00T action chunk to list of robot action dicts.
320+
321+
Handles both N1.5 format (shape: (T, dim)) and
322+
N1.6 format (shape: (B, T, dim) where B=1).
323+
"""
324+
# Strip batch dimension from N1.6 response: (B, T, dim) -> (T, dim)
325+
chunk = self._strip_batch_dim(chunk)
326+
327+
# Find action key
328+
act_key = self._find_action_key(chunk)
243329
if not act_key:
244330
return []
245331

@@ -254,13 +340,15 @@ def _to_robot_actions(self, chunk: dict) -> List[Dict[str, Any]]:
254340
for i in range(horizon):
255341
parts = []
256342
for k in self.action_keys:
257-
mod = k.split(".")[-1]
258-
if f"action.{mod}" in chunk:
259-
parts.append(np.atleast_1d(chunk[f"action.{mod}"][i]))
343+
mod = k.split(".")[-1] if "." in k else k
344+
for candidate in (mod, f"action.{mod}"):
345+
if candidate in chunk:
346+
parts.append(np.atleast_1d(chunk[candidate][i]).flatten())
347+
break
260348
if not parts:
261349
for k, v in chunk.items():
262-
if k.startswith("action."):
263-
parts.append(np.atleast_1d(v[i]))
350+
if k.startswith("action.") or k in self.action_keys:
351+
parts.append(np.atleast_1d(v[i]).flatten())
264352

265353
concat = np.concatenate(parts) if parts else np.zeros(len(self.robot_state_keys) or 6)
266354
actions.append(
@@ -269,6 +357,34 @@ def _to_robot_actions(self, chunk: dict) -> List[Dict[str, Any]]:
269357

270358
return actions
271359

360+
@staticmethod
361+
def _strip_batch_dim(chunk: dict) -> dict:
362+
"""Strip batch dimension from N1.6 action response.
363+
364+
N1.6 returns shape (B, T, dim), we need (T, dim).
365+
N1.5 returns shape (T, dim), no change needed.
366+
"""
367+
result = {}
368+
for k, v in chunk.items():
369+
if isinstance(v, np.ndarray) and v.ndim == 3 and v.shape[0] == 1:
370+
result[k] = v[0] # (1, T, dim) -> (T, dim)
371+
else:
372+
result[k] = v
373+
return result
374+
375+
def _find_action_key(self, chunk: dict) -> str:
376+
"""Find the first available action key in chunk."""
377+
for k in self.action_keys:
378+
base = k.replace("action.", "") if k.startswith("action.") else k
379+
for candidate in (base, f"action.{base}"):
380+
if candidate in chunk:
381+
return candidate
382+
# Fallback: any action-like key
383+
for k in chunk:
384+
if k.startswith("action.") or k in ("x", "y", "z", "roll", "pitch", "yaw", "gripper"):
385+
return k
386+
return None
387+
272388
@staticmethod
273389
def _quat2axisangle(quat: np.ndarray) -> np.ndarray:
274390
"""Convert quaternion (x,y,z,w) to axis-angle (roll,pitch,yaw)."""
@@ -280,12 +396,17 @@ def _quat2axisangle(quat: np.ndarray) -> np.ndarray:
280396
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
281397

282398
def _to_libero_action(self, action_chunk: dict, idx: int = 0) -> np.ndarray:
283-
"""Convert GR00T action chunk to Libero 7-dim: [dx,dy,dz,droll,dpitch,dyaw,gripper]."""
399+
"""Convert GR00T action chunk to Libero 7-dim: [dx,dy,dz,droll,dpitch,dyaw,gripper].
400+
401+
After _strip_batch_dim, chunk values have shape (T, dim).
402+
"""
284403
components = []
285404
for key in ("x", "y", "z", "roll", "pitch", "yaw", "gripper"):
286-
full_key = f"action.{key}"
287-
if full_key in action_chunk:
288-
components.append(np.atleast_1d(action_chunk[full_key][idx])[0])
405+
for candidate in (key, f"action.{key}"):
406+
if candidate in action_chunk:
407+
val = action_chunk[candidate][idx]
408+
components.append(float(np.asarray(val).flatten()[0]))
409+
break
289410
else:
290411
components.append(0.0)
291412
action = np.array(components, dtype=np.float32)
@@ -305,21 +426,25 @@ def _create_fallback_actions(self) -> dict:
305426
"""Create zero-action fallback when inference fails."""
306427
chunk = {}
307428
horizon = 8
308-
for key in self.action_keys:
309-
mod = key.split(".")[-1]
310-
if "joint_pos" in mod.lower():
311-
dim = 7
312-
elif "eef_pos" in mod.lower():
313-
dim = 3
314-
elif "eef_quat" in mod.lower():
315-
dim = 4
316-
elif "gripper" in mod.lower():
317-
dim = 1
318-
else:
319-
dim = len(self.robot_state_keys) // 5 if self.robot_state_keys else 7
320-
chunk[f"action.{mod}"] = np.zeros((horizon, dim), dtype=np.float64)
321-
if not chunk:
322-
chunk["action.robot0_joint_pos"] = np.zeros((horizon, 7), dtype=np.float64)
429+
if self.groot_version == "n1d6":
430+
for key in ("x", "y", "z", "roll", "pitch", "yaw", "gripper"):
431+
chunk[key] = np.zeros((horizon, 1), dtype=np.float32)
432+
else:
433+
for key in self.action_keys:
434+
mod = key.split(".")[-1]
435+
if "joint_pos" in mod.lower():
436+
dim = 7
437+
elif "eef_pos" in mod.lower():
438+
dim = 3
439+
elif "eef_quat" in mod.lower():
440+
dim = 4
441+
elif "gripper" in mod.lower():
442+
dim = 1
443+
else:
444+
dim = len(self.robot_state_keys) // 5 if self.robot_state_keys else 7
445+
chunk[f"action.{mod}"] = np.zeros((horizon, dim), dtype=np.float32)
446+
if not chunk:
447+
chunk["action.robot0_joint_pos"] = np.zeros((horizon, 7), dtype=np.float32)
323448
return chunk
324449

325450

0 commit comments

Comments
 (0)