-
Notifications
You must be signed in to change notification settings - Fork 196
[eval callback]Refactor eval callbacks into grid-based multi-condition sweep architecture #101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,172 @@ | ||
| """Gait cycle analysis for foot-height-based phase detection. | ||
|
|
||
| Collects per-env foot height data over a configurable analysis window, | ||
| counts full gait cycles via zero-crossings, then provides vectorized | ||
| per-env gait phase detection. | ||
|
|
||
| Phases detected (based on left foot state): | ||
| swing_to_stance = left foot touchdown (air -> ground) | ||
| stance_to_swing = left foot liftoff (ground -> air) | ||
| mid_stance = left foot on ground, right foot at peak | ||
| mid_swing = left foot at peak (in air) | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| from loguru import logger | ||
|
|
||
| _PEAK_THRESHOLD = 0.65 | ||
|
|
||
|
|
||
| class GaitAnalyser: | ||
| """Analyses foot height oscillations to detect gait phases. | ||
|
|
||
| Lifecycle: | ||
| 1. Call ``record_foot_heights()`` each step during the analysis window. | ||
| 2. Call ``try_finalize()`` once enough steps have elapsed. | ||
| When it returns True, phase detection is ready. | ||
| 3. Call ``detect_phases()`` each step to get per-env phase labels. | ||
| 4. Call ``get_metadata()`` to retrieve analysis results for recording. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| sim: Any, | ||
| left_foot_isaac_id: int, | ||
| right_foot_isaac_id: int, | ||
| num_conditions: int, | ||
| *, | ||
| foot_contact_threshold: float = 0.5, | ||
| min_gait_cycles: int = 3, | ||
| ): | ||
| self._sim = sim | ||
| self._left_foot_isaac_id = left_foot_isaac_id | ||
| self._right_foot_isaac_id = right_foot_isaac_id | ||
| self._num_conditions = num_conditions | ||
| self._foot_contact_threshold = foot_contact_threshold | ||
| self._min_gait_cycles = min_gait_cycles | ||
|
|
||
| self._done = False | ||
| self._left_foot_z_history: list[np.ndarray] = [] | ||
| self._right_foot_z_history: list[np.ndarray] = [] | ||
| self._left_foot_z_range: tuple[np.ndarray, np.ndarray] | None = None | ||
| self._right_foot_z_range: tuple[np.ndarray, np.ndarray] | None = None | ||
| self._prev_left_on_ground: np.ndarray = np.zeros(num_conditions, dtype=bool) | ||
|
|
||
| self._metadata: dict[str, Any] = {} | ||
|
|
||
| @property | ||
| def done(self) -> bool: | ||
| return self._done | ||
|
|
||
| def record_foot_heights(self) -> None: | ||
| n = self._num_conditions | ||
| left_z = self._sim._robot.data.body_pos_w[:n, self._left_foot_isaac_id, 2] | ||
| right_z = self._sim._robot.data.body_pos_w[:n, self._right_foot_isaac_id, 2] | ||
| self._left_foot_z_history.append(left_z.detach().cpu().numpy().copy()) | ||
| self._right_foot_z_history.append(right_z.detach().cpu().numpy().copy()) | ||
|
|
||
| def try_finalize(self, *, force: bool = False) -> bool: | ||
| if len(self._left_foot_z_history) < 2: | ||
| return False | ||
|
|
||
| left_z = np.stack(self._left_foot_z_history, axis=0) | ||
| right_z = np.stack(self._right_foot_z_history, axis=0) | ||
| cycles = _count_gait_cycles(left_z) | ||
|
|
||
| if not force and cycles < self._min_gait_cycles: | ||
| return False | ||
|
|
||
| if cycles < self._min_gait_cycles: | ||
| logger.warning( | ||
| f"GaitAnalyser: forced finalization with only " | ||
| f"{cycles} cycles (need {self._min_gait_cycles}). " | ||
| f"Phase detection may be unreliable." | ||
| ) | ||
|
|
||
| self._left_foot_z_range = (np.min(left_z, axis=0), np.max(left_z, axis=0)) | ||
| self._right_foot_z_range = (np.min(right_z, axis=0), np.max(right_z, axis=0)) | ||
| self._done = True | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. perhaps set done at the end of the function, just in case of confusion w.r.t debugging |
||
|
|
||
| # Initialize ground-contact state from the last recorded frame | ||
| left_range = np.maximum(self._left_foot_z_range[1] - self._left_foot_z_range[0], 1e-4) | ||
| left_norm = (left_z[-1] - self._left_foot_z_range[0]) / left_range | ||
| self._prev_left_on_ground = left_norm < self._foot_contact_threshold | ||
|
|
||
| self._metadata = { | ||
| "gait_left_foot_z_min": self._left_foot_z_range[0].tolist(), | ||
| "gait_left_foot_z_max": self._left_foot_z_range[1].tolist(), | ||
| "gait_right_foot_z_min": self._right_foot_z_range[0].tolist(), | ||
| "gait_right_foot_z_max": self._right_foot_z_range[1].tolist(), | ||
| "gait_analysis_cycles": cycles, | ||
| "gait_analysis_steps": len(self._left_foot_z_history), | ||
| } | ||
|
|
||
| self._left_foot_z_history.clear() | ||
| self._right_foot_z_history.clear() | ||
|
|
||
| logger.info( | ||
| f"GaitAnalyser: analysis complete ({cycles} cycles). " | ||
| f"Left Z: [{self._left_foot_z_range[0].mean():.3f}, {self._left_foot_z_range[1].mean():.3f}], " | ||
| f"Right Z: [{self._right_foot_z_range[0].mean():.3f}, {self._right_foot_z_range[1].mean():.3f}]" | ||
| ) | ||
| return True | ||
|
|
||
| def detect_phases(self) -> np.ndarray: | ||
| """Vectorized gait phase detection for all envs. Single GPU sync.""" | ||
| assert self._left_foot_z_range is not None | ||
| assert self._right_foot_z_range is not None | ||
|
|
||
| n = self._num_conditions | ||
| left_z = self._sim._robot.data.body_pos_w[:n, self._left_foot_isaac_id, 2].cpu().numpy() | ||
| right_z = self._sim._robot.data.body_pos_w[:n, self._right_foot_isaac_id, 2].cpu().numpy() | ||
|
|
||
| threshold = self._foot_contact_threshold | ||
| left_min, left_max = self._left_foot_z_range | ||
| right_min, right_max = self._right_foot_z_range | ||
|
|
||
| # --- Normalize foot heights to [0, 1] using observed range --- | ||
| left_range = np.maximum(left_max - left_min, 1e-4) | ||
| right_range = np.maximum(right_max - right_min, 1e-4) | ||
|
|
||
| left_norm = (left_z - left_min) / left_range | ||
| right_norm = (right_z - right_min) / right_range | ||
|
|
||
| # --- Detect ground contact and transitions --- | ||
| left_on_ground = left_norm < threshold | ||
| prev_left_on_ground = self._prev_left_on_ground | ||
| self._prev_left_on_ground = left_on_ground.copy() | ||
|
|
||
| phases: np.ndarray = np.full(n, "", dtype=object) | ||
|
|
||
| # Transitions: left foot crossing the ground threshold | ||
| swing_to_stance = left_on_ground & ~prev_left_on_ground | ||
| stance_to_swing = ~left_on_ground & prev_left_on_ground | ||
| phases[swing_to_stance] = "swing_to_stance" | ||
| phases[stance_to_swing] = "stance_to_swing" | ||
|
|
||
| # Mid-phase: no transition, detect by peak height of the airborne foot | ||
| no_transition = ~swing_to_stance & ~stance_to_swing | ||
| phases[no_transition & ~left_on_ground & (left_norm >= _PEAK_THRESHOLD)] = "mid_swing" | ||
| phases[no_transition & left_on_ground & (right_norm >= _PEAK_THRESHOLD)] = "mid_stance" | ||
|
|
||
| return phases | ||
|
|
||
| def get_metadata(self) -> dict[str, Any]: | ||
| return self._metadata.copy() | ||
|
|
||
|
|
||
| def _count_gait_cycles(foot_z: np.ndarray) -> int: | ||
| """Count full gait cycles from zero-crossings of centered foot-Z signal.""" | ||
| # Center signal around mean, count sign changes (zero-crossings), | ||
| # two crossings = one full cycle. Return min across all envs. | ||
| mean_z = np.mean(foot_z, axis=0, keepdims=True) | ||
| centered = foot_z - mean_z | ||
| sign = np.sign(centered) | ||
| sign[sign == 0] = 1 | ||
| crossings = np.sum(np.abs(np.diff(sign, axis=0)) > 0, axis=0) | ||
| cycles = crossings // 2 | ||
| return int(np.min(cycles)) if cycles.size > 0 else 0 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,195 @@ | ||
| """Grid condition manager for multi-condition evaluation sweeps. | ||
|
|
||
| Owns the cross-product logic that combines sweep axes (velocity, payload, | ||
| push, etc.) into a flat list of per-env conditions. Any callback can | ||
| register a sweep axis; after all registrations, ``finalize()`` expands | ||
| the Cartesian product and validates against ``num_envs``. | ||
|
|
||
| The manager is created by EvalRecordingCallback (the single recorder) and | ||
| shared with other callbacks via ``_require_recording_cb()``. | ||
|
|
||
| Terminology: | ||
| axis — one factor of the grid, added via ``add_axis()``. | ||
| Each axis contributes one "slot" to the Cartesian product. | ||
| An axis can control a single variable (``name="push_force_n"``) | ||
| or several coupled variables | ||
| (``name=["lin_vel_x", "lin_vel_y", "ang_vel_yaw"]``). | ||
| group — a label used only for metadata output grouping in the NPZ. | ||
| Does not affect the grid logic. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import itertools | ||
| from typing import Any | ||
|
|
||
| from loguru import logger | ||
|
|
||
|
|
||
| class GridConditionManager: | ||
| """Manages sweep axes and the resulting per-env condition list. | ||
|
|
||
| Lifecycle: | ||
| 1. Created during ``on_pre_evaluate_policy`` by the recording callback. | ||
| 2. Other callbacks call ``add_axis()`` during their own | ||
| ``on_pre_evaluate_policy`` (callbacks execute in field-declaration | ||
| order from ``EvalCallbacksConfig``). | ||
| 3. Recording callback calls ``finalize()`` on the first env step | ||
| (deferred so all axes are registered). | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| self._axes: list[dict[str, Any]] = [] | ||
| self.conditions: list[dict[str, Any]] = [] | ||
| self.num_conditions: int = 0 | ||
| self.warmup_steps: int = 0 | ||
| self._finalized: bool = False | ||
|
|
||
| def add_axis( | ||
| self, | ||
| name: str | list[str], | ||
| values: list, | ||
|
Comment on lines
+48
to
+51
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens for duplicate axis names? "add" vs. "register". Good to document the behavior/expectations. You append below without checking but idk (yet) if you expect a single name. edit: i see below you de-dupe. Good to document then, especially for callers -- right now, there's no way to know a duplication happened and the match is by key only? not values |
||
| *, | ||
| labels: list[str] | None = None, | ||
| group: str = "", | ||
| ) -> None: | ||
| """Register one axis of the evaluation grid. | ||
|
|
||
| The grid is the Cartesian product of all registered axes. | ||
|
|
||
| Single-variable axis (one key per condition entry):: | ||
|
|
||
| cm.add_axis("push_force_n", [100.0, 150.0], group="push") | ||
|
|
||
| Multi-variable axis (coupled keys that vary together):: | ||
|
|
||
| cm.add_axis( | ||
| ["lin_vel_x", "lin_vel_y", "ang_vel_yaw"], | ||
| [(0.5, 0, 0), (1.0, 0, 0), (0, 0.3, 0)], | ||
| group="velocity", | ||
| ) | ||
|
|
||
| Args: | ||
| name: A single key string, or a list of keys for coupled variables. | ||
| values: The sweep values for this axis. For a single key, a flat | ||
| list (e.g. ``[0.5, 1.0]``). For coupled keys, a list of | ||
| tuples with one element per key. | ||
| labels: Human-readable labels (optional, defaults to str(value)). | ||
| group: Logical group for metadata output (e.g. "velocity", "push"). | ||
|
|
||
| Must be called before ``finalize()``. | ||
| """ | ||
| if self._finalized: | ||
| raise RuntimeError(f"Cannot add axis '{name}' after conditions are finalized.") | ||
|
|
||
| if isinstance(name, str): | ||
| keys = [name] | ||
| normalized_values = [(v,) for v in values] | ||
| else: | ||
| keys = list(name) | ||
| normalized_values = [tuple(v) for v in values] | ||
|
|
||
| self._axes.append( | ||
| { | ||
| "keys": keys, | ||
| "values": normalized_values, | ||
| "labels": labels, | ||
| "group": group, | ||
| } | ||
| ) | ||
|
|
||
| def finalize(self, num_envs: int) -> None: | ||
| """Expand the Cartesian product of all axes and validate env count. | ||
|
|
||
| After this call, ``conditions`` and ``num_conditions`` are set. | ||
| """ | ||
| if self._finalized: | ||
| return | ||
|
|
||
| # --- Expand Cartesian product of all axes --- | ||
| if not self._axes: | ||
| self.conditions = [{}] | ||
| else: | ||
| value_lists = [ax["values"] for ax in self._axes] | ||
| key_lists = [ax["keys"] for ax in self._axes] | ||
| self.conditions = [] | ||
| for combo in itertools.product(*value_lists): | ||
| cond: dict[str, Any] = {} | ||
| for keys, vals in zip(key_lists, combo): | ||
| cond.update(dict(zip(keys, vals))) | ||
| self.conditions.append(cond) | ||
|
|
||
| # --- Deduplicate (preserving order) --- | ||
| seen: set[tuple] = set() | ||
| unique: list[dict[str, Any]] = [] | ||
| for cond in self.conditions: | ||
| key = tuple(sorted(cond.items())) | ||
| if key not in seen: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect at the least you'll want to log duplicates, but see above re: comparing keys vs. values, and what callers can/can't do if there are duplicates with different values |
||
| seen.add(key) | ||
| unique.append(cond) | ||
| self.conditions = unique | ||
| self.num_conditions = len(unique) | ||
|
|
||
| if num_envs < self.num_conditions: | ||
| raise RuntimeError( | ||
| f"GridConditionManager: need num_envs >= {self.num_conditions} " | ||
| f"for {self.num_conditions} conditions, but got num_envs={num_envs}. " | ||
| f"Set --training.num_envs={self.num_conditions}" | ||
| ) | ||
| if num_envs > self.num_conditions: | ||
| logger.warning( | ||
| f"GridConditionManager: num_envs={num_envs} > {self.num_conditions} conditions. " | ||
| f"Only using first {self.num_conditions} envs." | ||
| ) | ||
|
|
||
| self._finalized = True | ||
|
|
||
| logger.info(f"GridConditionManager: finalized {self.num_conditions} conditions from {len(self._axes)} axes") | ||
| for i, cond in enumerate(self.conditions): | ||
| parts = [f"{k}={v:.2f}" if isinstance(v, float) else f"{k}={v}" for k, v in cond.items()] | ||
| logger.info(f" env {i}: {' '.join(parts)}") | ||
|
|
||
| def get_metadata(self) -> dict[str, Any]: | ||
| """Return condition metadata for NPZ recording. | ||
|
|
||
| ``grid_conditions`` uses hierarchical dicts grouped by the ``group`` | ||
| parameter passed to ``add_axis()``. Ungrouped keys stay at | ||
| the top level. | ||
|
|
||
| Example condition:: | ||
|
|
||
| {'velocity': {'lin_vel_x': 0.5, 'lin_vel_y': 0.0, 'ang_vel_yaw': 0.0}, | ||
| 'push': {'body_label': 'torso', 'direction': 'forward', 'force_n': 150.0}} | ||
| """ | ||
|
Comment on lines
+152
to
+163
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know our current position towards unit testing, but this kind of logic is a good reason to add them, especially for future regressions. |
||
| key_to_group: dict[str, str] = {} | ||
| for ax in self._axes: | ||
| grp = ax.get("group", "") | ||
| for k in ax["keys"]: | ||
| key_to_group[k] = grp | ||
|
|
||
| hierarchical: list[dict[str, Any]] = [] | ||
| for cond in self.conditions: | ||
| grouped: dict[str, Any] = {} | ||
| for key, val in cond.items(): | ||
| grp = key_to_group.get(key, "") | ||
| if grp: | ||
| grouped.setdefault(grp, {})[key] = val | ||
| else: | ||
| grouped[key] = val | ||
| hierarchical.append(grouped) | ||
|
|
||
| meta: dict[str, Any] = { | ||
| "num_conditions": self.num_conditions, | ||
| "grid_conditions": hierarchical, | ||
| } | ||
| for ax in self._axes: | ||
| keys = ax["keys"] | ||
| values = ax["values"] | ||
| if len(keys) == 1: | ||
| meta[f"sweep_{keys[0]}_values"] = [v[0] for v in values] | ||
| if ax["labels"] is not None: | ||
| meta[f"sweep_{keys[0]}_labels"] = ax["labels"] | ||
| else: | ||
| for k_idx, k in enumerate(keys): | ||
| meta[f"sweep_{k}_values"] = [v[k_idx] for v in values] | ||
| return meta | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will there ever be a second eval CB? defensively, it can help to check here and error if so.