11#!/usr/bin/env python3
22"""GR00T Policy — natural language robot control via GR00T inference servers.
33
4+ Supports both GR00T N1.5 and N1.6 observation/action formats.
5+
46SPDX-License-Identifier: Apache-2.0
57"""
68
1214
1315from .. import Policy
1416from .client import GR00TClient
15- from .data_config import load_data_config
17+ from .data_config import LIBERO_STATE_TO_N1D6 , load_data_config
1618
1719logger = logging .getLogger (__name__ )
1820
1921
2022class Gr00tPolicy (Policy ):
21- """GR00T policy: connects to a GR00T inference server via ZMQ."""
23+ """GR00T policy: connects to a GR00T inference server via ZMQ.
24+
25+ Supports both N1.5 (prefixed keys: video.X, state, action.X)
26+ and N1.6 (direct keys: X, individual state components) formats.
27+ """
2228
2329 def __init__ (self , data_config : Union [str , dict ], host : str = "localhost" , port : int = 5555 , ** kwargs ):
2430 """Initialize GR00T policy.
2531
2632 Args:
27- data_config: Config name (e.g. "libero") or dict with video/state/action/language keys
33+ data_config: Config name (e.g. "libero") or dict with video/state/action/language keys.
34+ Pass "libero:n1d6" or set groot_version="n1d6" in kwargs to force N1.6 format.
2835 host: Inference service host
2936 port: Inference service port
3037 """
31- self .config = load_data_config (data_config )
38+ groot_version = kwargs .pop ("groot_version" , "auto" )
39+
40+ # Support "config_name:version" syntax (e.g. "libero:n1d6")
41+ if isinstance (data_config , str ) and ":" in data_config :
42+ parts = data_config .split (":" , 1 )
43+ data_config = parts [0 ]
44+ groot_version = parts [1 ]
45+
46+ self .config = load_data_config (data_config , groot_version = groot_version )
3247 self .data_config_name = data_config if isinstance (data_config , str ) else "custom"
33- self .client = GR00TClient (host = host , port = port )
48+ self .groot_version = self .config .get ("groot_version" , "n1d5" )
49+ self .client = GR00TClient (host = host , port = port , groot_version = self .groot_version )
3450
3551 self .camera_keys = self .config ["video" ]
3652 self .state_keys = self .config ["state" ]
3753 self .action_keys = self .config ["action" ]
3854 self .language_keys = self .config ["language" ]
3955 self .robot_state_keys = []
4056
41- logger .info (f"🧠 GR00T Policy: { self .data_config_name } @ { host } :{ port } " )
57+ logger .info (f"🧠 GR00T Policy: { self .data_config_name } @ { host } :{ port } (version: { self . groot_version } ) " )
4258
4359 @property
4460 def provider_name (self ) -> str :
@@ -50,12 +66,66 @@ def set_robot_state_keys(self, robot_state_keys: List[str]) -> None:
5066 async def get_actions (self , observation_dict : Dict [str , Any ], instruction : str , ** kwargs ) -> List [Dict [str , Any ]]:
5167 """Get actions from GR00T policy server.
5268
53- Args:
54- observation_dict: Robot observations (cameras + state)
55- instruction: Natural language instruction
69+ Automatically formats observations for N1.5 or N1.6 based on config.
70+ """
71+ if self .groot_version == "n1d6" :
72+ obs = self ._build_n1d6_observation (observation_dict , instruction )
73+ else :
74+ obs = self ._build_n1d5_observation (observation_dict , instruction )
75+
76+ try :
77+ action_chunk = self .client .get_action (obs )
78+ except Exception as e :
79+ logger .error (f"GR00T inference failed: { e } " )
80+ action_chunk = self ._create_fallback_actions ()
5681
57- Returns:
58- List of action dicts for robot execution
82+ return self ._to_robot_actions (action_chunk )
83+
84+ def _build_n1d6_observation (self , observation_dict : Dict [str , Any ], instruction : str ) -> dict :
85+ """Build observation dict for GR00T N1.6 format.
86+
87+ When the server uses Gr00tSimPolicyWrapper (--use-sim-policy-wrapper),
88+ it expects flat keys with prefixes: video.image, state.x, etc.
89+ The wrapper then converts these to the nested format internally.
90+
91+ Flat format (for SimPolicyWrapper):
92+ {
93+ "video.image": array(B, T, H, W, C),
94+ "video.wrist_image": array(B, T, H, W, C),
95+ "state.x": array(B, T, 1),
96+ ...
97+ "annotation.human.action.task_description": ("instruction",),
98+ }
99+ """
100+ obs = {}
101+
102+ # Camera observations — flat keys with "video." prefix, shape (B, T, H, W, C)
103+ for vkey in self .camera_keys :
104+ cam = self ._find_camera (vkey , observation_dict )
105+ flat_key = f"video.{ vkey } "
106+ if cam and cam in observation_dict :
107+ image = self ._resize_image (observation_dict [cam ], target_size = (256 , 256 ))
108+ obs [flat_key ] = image .reshape (1 , 1 , * image .shape ).astype (np .uint8 )
109+ else :
110+ obs [flat_key ] = np .zeros ((1 , 1 , 256 , 256 , 3 ), dtype = np .uint8 )
111+
112+ # State observations — flat keys with "state." prefix
113+ if "libero" in self .data_config_name .lower ():
114+ self ._map_libero_state_n1d6 (obs , observation_dict )
115+ else :
116+ for skey in self .state_keys :
117+ obs [f"state.{ skey } " ] = np .array ([[[0.0 ]]], dtype = np .float32 )
118+
119+ # Language instruction — as tuple for batch
120+ if self .language_keys :
121+ obs [self .language_keys [0 ]] = (instruction ,)
122+
123+ return obs
124+
125+ def _build_n1d5_observation (self , observation_dict : Dict [str , Any ], instruction : str ) -> dict :
126+ """Build observation dict for GR00T N1.5 format (legacy).
127+
128+ N1.5 uses prefixed keys: video.X, state, action.X
59129 """
60130 obs = {}
61131
@@ -82,34 +152,70 @@ async def get_actions(self, observation_dict: Dict[str, Any], instruction: str,
82152 robot_state_parts .extend (np .atleast_1d (value ).flatten ())
83153 else :
84154 robot_state_parts .append (float (value ))
85- robot_state = np .array (robot_state_parts , dtype = np .float64 )
155+ robot_state = np .array (robot_state_parts , dtype = np .float32 )
86156
87157 if "libero" in self .data_config_name .lower ():
88- self ._map_libero_state (obs , observation_dict )
158+ self ._map_libero_state_n1d5 (obs , observation_dict )
89159 else :
90160 self ._map_state (obs , robot_state )
91161
92162 # Language instruction
93163 if self .language_keys :
94164 obs [self .language_keys [0 ]] = instruction
95165
96- # Batch dimension
166+ # Batch dimension for N1.5
97167 for k in obs :
98168 if isinstance (obs [k ], np .ndarray ) and k .startswith ("video." ):
99169 obs [k ] = np .expand_dims (obs [k ], axis = 0 )
100170 elif isinstance (obs [k ], str ):
101171 obs [k ] = [obs [k ]]
102172
103- try :
104- action_chunk = self .client .get_action (obs )
105- except Exception as e :
106- logger .error (f"GR00T inference failed: { e } " )
107- action_chunk = self ._create_fallback_actions ()
173+ return obs
108174
109- return self ._to_robot_actions (action_chunk )
175+ def _map_libero_state_n1d6 (self , obs : dict , observation_dict : dict ):
176+ """Map Libero observation to N1.6 flat state keys (state.x, state.y, etc.).
177+
178+ State values have shape (B, T, dim) where B=1, T=1.
179+ Uses "state." prefix for SimPolicyWrapper compatibility.
180+ """
181+ if "robot0_eef_pos" in observation_dict and "robot0_eef_quat" in observation_dict :
182+ xyz = observation_dict ["robot0_eef_pos" ]
183+ quat = observation_dict ["robot0_eef_quat" ]
184+ gripper = observation_dict .get ("robot0_gripper_qpos" , np .array ([0.0 , 0.0 ]))
185+ rpy = self ._quat2axisangle (quat )
186+ obs ["state.x" ] = np .array ([[[xyz [0 ]]]], dtype = np .float32 )
187+ obs ["state.y" ] = np .array ([[[xyz [1 ]]]], dtype = np .float32 )
188+ obs ["state.z" ] = np .array ([[[xyz [2 ]]]], dtype = np .float32 )
189+ obs ["state.roll" ] = np .array ([[[rpy [0 ]]]], dtype = np .float32 )
190+ obs ["state.pitch" ] = np .array ([[[rpy [1 ]]]], dtype = np .float32 )
191+ obs ["state.yaw" ] = np .array ([[[rpy [2 ]]]], dtype = np .float32 )
192+ obs ["state.gripper" ] = np .asarray (gripper , dtype = np .float32 ).reshape (1 , 1 , - 1 )
193+ else :
194+ for key in ("x" , "y" , "z" , "roll" , "pitch" , "yaw" ):
195+ obs [f"state.{ key } " ] = np .array ([[[0.0 ]]], dtype = np .float32 )
196+ obs ["state.gripper" ] = np .array ([[[0.0 ]]], dtype = np .float32 )
197+
198+ def _map_libero_state_n1d5 (self , obs : dict , observation_dict : dict ):
199+ """Map Libero end-effector pose to N1.5 state format (state.x, state.y, etc.)."""
200+ if "robot0_eef_pos" in observation_dict and "robot0_eef_quat" in observation_dict :
201+ xyz = observation_dict ["robot0_eef_pos" ]
202+ quat = observation_dict ["robot0_eef_quat" ]
203+ gripper = observation_dict .get ("robot0_gripper_qpos" , np .array ([0.0 , 0.0 ]))
204+ rpy = self ._quat2axisangle (quat )
205+ obs ["state.x" ] = np .array ([[xyz [0 ]]])
206+ obs ["state.y" ] = np .array ([[xyz [1 ]]])
207+ obs ["state.z" ] = np .array ([[xyz [2 ]]])
208+ obs ["state.roll" ] = np .array ([[rpy [0 ]]])
209+ obs ["state.pitch" ] = np .array ([[rpy [1 ]]])
210+ obs ["state.yaw" ] = np .array ([[rpy [2 ]]])
211+ obs ["state.gripper" ] = np .expand_dims (gripper , axis = 0 )
212+ else :
213+ for key in ("x" , "y" , "z" , "roll" , "pitch" , "yaw" ):
214+ obs [f"state.{ key } " ] = np .array ([[0.0 ]], dtype = np .float32 )
215+ obs ["state.gripper" ] = np .array ([[0.0 ]], dtype = np .float32 )
110216
111217 def _find_camera (self , video_key : str , obs : dict ) -> str :
112- """Map GR00T video key to available camera key."""
218+ """Map GR00T video key to available camera key in observation ."""
113219 if video_key in obs :
114220 return video_key
115221
@@ -189,27 +295,8 @@ def _resize_image(self, image: np.ndarray, target_size: tuple = (256, 256)) -> n
189295 except Exception :
190296 return image
191297
192- def _map_libero_state (self , obs : dict , observation_dict : dict ):
193- """Map Libero end-effector pose to GR00T state format."""
194- if "robot0_eef_pos" in observation_dict and "robot0_eef_quat" in observation_dict :
195- xyz = observation_dict ["robot0_eef_pos" ]
196- quat = observation_dict ["robot0_eef_quat" ]
197- gripper = observation_dict .get ("robot0_gripper_qpos" , np .array ([0.0 , 0.0 ]))
198- rpy = self ._quat2axisangle (quat )
199- obs ["state.x" ] = np .array ([[xyz [0 ]]])
200- obs ["state.y" ] = np .array ([[xyz [1 ]]])
201- obs ["state.z" ] = np .array ([[xyz [2 ]]])
202- obs ["state.roll" ] = np .array ([[rpy [0 ]]])
203- obs ["state.pitch" ] = np .array ([[rpy [1 ]]])
204- obs ["state.yaw" ] = np .array ([[rpy [2 ]]])
205- obs ["state.gripper" ] = np .expand_dims (gripper , axis = 0 )
206- else :
207- for key in ("x" , "y" , "z" , "roll" , "pitch" , "yaw" ):
208- obs [f"state.{ key } " ] = np .array ([[0.0 ]], dtype = np .float64 )
209- obs ["state.gripper" ] = np .array ([[0.0 ]], dtype = np .float64 )
210-
211298 def _map_state (self , obs : dict , state : np .ndarray ):
212- """Map robot state array to GR00T state keys."""
299+ """Map robot state array to GR00T state keys (N1.5 format) ."""
213300 name = self .data_config_name .lower ()
214301 if "so100" in name and len (state ) >= 6 :
215302 obs ["state.single_arm" ] = state [:5 ].astype (np .float64 )
@@ -229,17 +316,16 @@ def _map_state(self, obs: dict, state: np.ndarray):
229316 obs [self .state_keys [0 ]] = state .astype (np .float64 )
230317
231318 def _to_robot_actions (self , chunk : dict ) -> List [Dict [str , Any ]]:
232- """Convert GR00T action chunk to list of robot action dicts."""
233- act_key = None
234- for k in self .action_keys :
235- base = k .replace ("action." , "" ) if k .startswith ("action." ) else k
236- full = f"action.{ base } "
237- if full in chunk :
238- act_key = full
239- break
240- if not act_key :
241- act_keys = [k for k in chunk if k .startswith ("action." )]
242- act_key = act_keys [0 ] if act_keys else None
319+ """Convert GR00T action chunk to list of robot action dicts.
320+
321+ Handles both N1.5 format (shape: (T, dim)) and
322+ N1.6 format (shape: (B, T, dim) where B=1).
323+ """
324+ # Strip batch dimension from N1.6 response: (B, T, dim) -> (T, dim)
325+ chunk = self ._strip_batch_dim (chunk )
326+
327+ # Find action key
328+ act_key = self ._find_action_key (chunk )
243329 if not act_key :
244330 return []
245331
@@ -254,13 +340,15 @@ def _to_robot_actions(self, chunk: dict) -> List[Dict[str, Any]]:
254340 for i in range (horizon ):
255341 parts = []
256342 for k in self .action_keys :
257- mod = k .split ("." )[- 1 ]
258- if f"action.{ mod } " in chunk :
259- parts .append (np .atleast_1d (chunk [f"action.{ mod } " ][i ]))
343+ mod = k .split ("." )[- 1 ] if "." in k else k
344+ for candidate in (mod , f"action.{ mod } " ):
345+ if candidate in chunk :
346+ parts .append (np .atleast_1d (chunk [candidate ][i ]).flatten ())
347+ break
260348 if not parts :
261349 for k , v in chunk .items ():
262- if k .startswith ("action." ):
263- parts .append (np .atleast_1d (v [i ]))
350+ if k .startswith ("action." ) or k in self . action_keys :
351+ parts .append (np .atleast_1d (v [i ]). flatten () )
264352
265353 concat = np .concatenate (parts ) if parts else np .zeros (len (self .robot_state_keys ) or 6 )
266354 actions .append (
@@ -269,6 +357,34 @@ def _to_robot_actions(self, chunk: dict) -> List[Dict[str, Any]]:
269357
270358 return actions
271359
360+ @staticmethod
361+ def _strip_batch_dim (chunk : dict ) -> dict :
362+ """Strip batch dimension from N1.6 action response.
363+
364+ N1.6 returns shape (B, T, dim), we need (T, dim).
365+ N1.5 returns shape (T, dim), no change needed.
366+ """
367+ result = {}
368+ for k , v in chunk .items ():
369+ if isinstance (v , np .ndarray ) and v .ndim == 3 and v .shape [0 ] == 1 :
370+ result [k ] = v [0 ] # (1, T, dim) -> (T, dim)
371+ else :
372+ result [k ] = v
373+ return result
374+
375+ def _find_action_key (self , chunk : dict ) -> str :
376+ """Find the first available action key in chunk."""
377+ for k in self .action_keys :
378+ base = k .replace ("action." , "" ) if k .startswith ("action." ) else k
379+ for candidate in (base , f"action.{ base } " ):
380+ if candidate in chunk :
381+ return candidate
382+ # Fallback: any action-like key
383+ for k in chunk :
384+ if k .startswith ("action." ) or k in ("x" , "y" , "z" , "roll" , "pitch" , "yaw" , "gripper" ):
385+ return k
386+ return None
387+
272388 @staticmethod
273389 def _quat2axisangle (quat : np .ndarray ) -> np .ndarray :
274390 """Convert quaternion (x,y,z,w) to axis-angle (roll,pitch,yaw)."""
@@ -280,12 +396,17 @@ def _quat2axisangle(quat: np.ndarray) -> np.ndarray:
280396 return (quat [:3 ] * 2.0 * math .acos (quat [3 ])) / den
281397
282398 def _to_libero_action (self , action_chunk : dict , idx : int = 0 ) -> np .ndarray :
283- """Convert GR00T action chunk to Libero 7-dim: [dx,dy,dz,droll,dpitch,dyaw,gripper]."""
399+ """Convert GR00T action chunk to Libero 7-dim: [dx,dy,dz,droll,dpitch,dyaw,gripper].
400+
401+ After _strip_batch_dim, chunk values have shape (T, dim).
402+ """
284403 components = []
285404 for key in ("x" , "y" , "z" , "roll" , "pitch" , "yaw" , "gripper" ):
286- full_key = f"action.{ key } "
287- if full_key in action_chunk :
288- components .append (np .atleast_1d (action_chunk [full_key ][idx ])[0 ])
405+ for candidate in (key , f"action.{ key } " ):
406+ if candidate in action_chunk :
407+ val = action_chunk [candidate ][idx ]
408+ components .append (float (np .asarray (val ).flatten ()[0 ]))
409+ break
289410 else :
290411 components .append (0.0 )
291412 action = np .array (components , dtype = np .float32 )
@@ -305,21 +426,25 @@ def _create_fallback_actions(self) -> dict:
305426 """Create zero-action fallback when inference fails."""
306427 chunk = {}
307428 horizon = 8
308- for key in self .action_keys :
309- mod = key .split ("." )[- 1 ]
310- if "joint_pos" in mod .lower ():
311- dim = 7
312- elif "eef_pos" in mod .lower ():
313- dim = 3
314- elif "eef_quat" in mod .lower ():
315- dim = 4
316- elif "gripper" in mod .lower ():
317- dim = 1
318- else :
319- dim = len (self .robot_state_keys ) // 5 if self .robot_state_keys else 7
320- chunk [f"action.{ mod } " ] = np .zeros ((horizon , dim ), dtype = np .float64 )
321- if not chunk :
322- chunk ["action.robot0_joint_pos" ] = np .zeros ((horizon , 7 ), dtype = np .float64 )
429+ if self .groot_version == "n1d6" :
430+ for key in ("x" , "y" , "z" , "roll" , "pitch" , "yaw" , "gripper" ):
431+ chunk [key ] = np .zeros ((horizon , 1 ), dtype = np .float32 )
432+ else :
433+ for key in self .action_keys :
434+ mod = key .split ("." )[- 1 ]
435+ if "joint_pos" in mod .lower ():
436+ dim = 7
437+ elif "eef_pos" in mod .lower ():
438+ dim = 3
439+ elif "eef_quat" in mod .lower ():
440+ dim = 4
441+ elif "gripper" in mod .lower ():
442+ dim = 1
443+ else :
444+ dim = len (self .robot_state_keys ) // 5 if self .robot_state_keys else 7
445+ chunk [f"action.{ mod } " ] = np .zeros ((horizon , dim ), dtype = np .float32 )
446+ if not chunk :
447+ chunk ["action.robot0_joint_pos" ] = np .zeros ((horizon , 7 ), dtype = np .float32 )
323448 return chunk
324449
325450
0 commit comments