Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/holosoma/holosoma/agents/callbacks/base_callback.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import Any

from torch.nn import Module


Expand All @@ -8,6 +12,19 @@ def __init__(self, config, training_loop):
self.training_loop = training_loop
self.device = self.training_loop.device

def _get_env(self):
"""Get the unwrapped BaseTask environment."""
return self.training_loop._unwrap_env()

def _require_recording_cb(self) -> Any:
    """Return the first ``EvalRecordingCallback`` among the eval callbacks.

    Scans ``training_loop.eval_callbacks`` in order and returns the first
    entry that is an ``EvalRecordingCallback`` instance.

    Raises:
        RuntimeError: If no ``EvalRecordingCallback`` is registered.
    """
    # Imported at call time rather than at module top
    # (presumably to avoid a circular import -- TODO confirm).
    from holosoma.agents.callbacks.recording import EvalRecordingCallback

    recorder = next(
        (cb for cb in self.training_loop.eval_callbacks if isinstance(cb, EvalRecordingCallback)),
        None,
    )
    if recorder is None:
        raise RuntimeError(f"{type(self).__name__} requires EvalRecordingCallback. Set --recording.config.enabled=True")
    return recorder
Comment on lines +19 to +26
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will there ever be a second eval CB? Defensively, it can help to check for that here and raise an error if so.


def on_pre_evaluate_policy(self):
    """No-op in this base implementation.

    Presumably a hook that subclasses override to run setup before policy
    evaluation -- TODO confirm against the training loop that calls it.
    """
    return None

Expand Down
172 changes: 172 additions & 0 deletions src/holosoma/holosoma/agents/callbacks/gait_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
"""Gait cycle analysis for foot-height-based phase detection.

Collects per-env foot height data over a configurable analysis window,
counts full gait cycles via zero-crossings, then provides vectorized
per-env gait phase detection.

Phases detected (based on left foot state):
swing_to_stance = left foot touchdown (air -> ground)
stance_to_swing = left foot liftoff (ground -> air)
mid_stance = left foot on ground, right foot at peak
mid_swing = left foot at peak (in air)
"""

from __future__ import annotations

from typing import Any

import numpy as np
from loguru import logger

_PEAK_THRESHOLD = 0.65


class GaitAnalyser:
"""Analyses foot height oscillations to detect gait phases.

Lifecycle:
1. Call ``record_foot_heights()`` each step during the analysis window.
2. Call ``try_finalize()`` once enough steps have elapsed.
When it returns True, phase detection is ready.
3. Call ``detect_phases()`` each step to get per-env phase labels.
4. Call ``get_metadata()`` to retrieve analysis results for recording.
"""

def __init__(
    self,
    sim: Any,
    left_foot_isaac_id: int,
    right_foot_isaac_id: int,
    num_conditions: int,
    *,
    foot_contact_threshold: float = 0.5,
    min_gait_cycles: int = 3,
):
    """Store the simulator handle and configuration; start with empty buffers.

    Args:
        sim: Simulator object; foot heights are later read from
            ``sim._robot.data.body_pos_w``.
        left_foot_isaac_id: Body index of the left foot in ``body_pos_w``.
        right_foot_isaac_id: Body index of the right foot in ``body_pos_w``.
        num_conditions: Number of leading envs that are analysed.
        foot_contact_threshold: Normalized foot height below which the
            left foot is treated as on the ground.
        min_gait_cycles: Cycles required before ``try_finalize()`` succeeds
            without ``force=True``.
    """
    # --- Handles and configuration, stored verbatim ---
    self._sim = sim
    self._num_conditions = num_conditions
    self._left_foot_isaac_id = left_foot_isaac_id
    self._right_foot_isaac_id = right_foot_isaac_id
    self._min_gait_cycles = min_gait_cycles
    self._foot_contact_threshold = foot_contact_threshold

    # --- Per-step height samples collected during the analysis window ---
    self._left_foot_z_history: list[np.ndarray] = []
    self._right_foot_z_history: list[np.ndarray] = []

    # --- Results; the ranges stay None until try_finalize() fills them in ---
    self._left_foot_z_range: tuple[np.ndarray, np.ndarray] | None = None
    self._right_foot_z_range: tuple[np.ndarray, np.ndarray] | None = None
    self._metadata: dict[str, Any] = {}
    self._done = False

    # Previous-step contact flag per env, used to detect transitions.
    self._prev_left_on_ground: np.ndarray = np.zeros(num_conditions, dtype=bool)

@property
def done(self) -> bool:
    """Whether the analysis is finalized (``try_finalize()`` sets this to True)."""
    finished = self._done
    return finished

def record_foot_heights(self) -> None:
    """Append the current left and right foot heights to the history buffers.

    Reads component 2 of ``body_pos_w`` for each foot body over the first
    ``num_conditions`` envs and stores a detached host-side NumPy copy.
    """
    count = self._num_conditions
    per_foot = (
        (self._left_foot_isaac_id, self._left_foot_z_history),
        (self._right_foot_isaac_id, self._right_foot_z_history),
    )
    for body_id, history in per_foot:
        heights = self._sim._robot.data.body_pos_w[:count, body_id, 2]
        # .copy() keeps the stored sample independent of the simulator buffer.
        history.append(heights.detach().cpu().numpy().copy())

def try_finalize(self, *, force: bool = False) -> bool:
if len(self._left_foot_z_history) < 2:
return False

left_z = np.stack(self._left_foot_z_history, axis=0)
right_z = np.stack(self._right_foot_z_history, axis=0)
cycles = _count_gait_cycles(left_z)

if not force and cycles < self._min_gait_cycles:
return False

if cycles < self._min_gait_cycles:
logger.warning(
f"GaitAnalyser: forced finalization with only "
f"{cycles} cycles (need {self._min_gait_cycles}). "
f"Phase detection may be unreliable."
)

self._left_foot_z_range = (np.min(left_z, axis=0), np.max(left_z, axis=0))
self._right_foot_z_range = (np.min(right_z, axis=0), np.max(right_z, axis=0))
self._done = True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps set `_done` at the end of the function, to avoid confusion when debugging.


# Initialize ground-contact state from the last recorded frame
left_range = np.maximum(self._left_foot_z_range[1] - self._left_foot_z_range[0], 1e-4)
left_norm = (left_z[-1] - self._left_foot_z_range[0]) / left_range
self._prev_left_on_ground = left_norm < self._foot_contact_threshold

self._metadata = {
"gait_left_foot_z_min": self._left_foot_z_range[0].tolist(),
"gait_left_foot_z_max": self._left_foot_z_range[1].tolist(),
"gait_right_foot_z_min": self._right_foot_z_range[0].tolist(),
"gait_right_foot_z_max": self._right_foot_z_range[1].tolist(),
"gait_analysis_cycles": cycles,
"gait_analysis_steps": len(self._left_foot_z_history),
}

self._left_foot_z_history.clear()
self._right_foot_z_history.clear()

logger.info(
f"GaitAnalyser: analysis complete ({cycles} cycles). "
f"Left Z: [{self._left_foot_z_range[0].mean():.3f}, {self._left_foot_z_range[1].mean():.3f}], "
f"Right Z: [{self._right_foot_z_range[0].mean():.3f}, {self._right_foot_z_range[1].mean():.3f}]"
)
return True

def detect_phases(self) -> np.ndarray:
    """Vectorized gait phase detection for all envs. Single GPU sync.

    Both foot heights are gathered into one tensor before the device-to-host
    copy, so only one transfer is issued per call. Heights are normalized by
    the per-env ranges captured in ``try_finalize()`` and classified from the
    left foot's contact state.

    Returns:
        Object array of length ``num_conditions``. Each entry is one of
        ``"swing_to_stance"``, ``"stance_to_swing"``, ``"mid_swing"``,
        ``"mid_stance"``, or ``""`` when no phase applies this step.

    Raises:
        RuntimeError: If the foot-height ranges are not set yet, i.e. the
            analysis has not been finalized.
    """
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if self._left_foot_z_range is None or self._right_foot_z_range is None:
        raise RuntimeError("GaitAnalyser.detect_phases() called before the analysis was finalized.")

    n = self._num_conditions
    foot_ids = [self._left_foot_isaac_id, self._right_foot_isaac_id]
    heights = self._sim._robot.data.body_pos_w[:n, :, 2]
    # Select both feet first, then detach and copy to host once.
    # ``detach()`` matches ``record_foot_heights()`` and avoids the error
    # ``numpy()`` raises on tensors that require grad.
    feet_z = heights[:, foot_ids].detach().cpu().numpy()
    left_z = feet_z[:, 0]
    right_z = feet_z[:, 1]

    threshold = self._foot_contact_threshold
    left_min, left_max = self._left_foot_z_range
    right_min, right_max = self._right_foot_z_range

    # --- Normalize foot heights to [0, 1] using observed range ---
    # The 1e-4 floor avoids dividing by zero when a foot never moved.
    left_range = np.maximum(left_max - left_min, 1e-4)
    right_range = np.maximum(right_max - right_min, 1e-4)

    left_norm = (left_z - left_min) / left_range
    right_norm = (right_z - right_min) / right_range

    # --- Detect ground contact and transitions ---
    left_on_ground = left_norm < threshold
    prev_left_on_ground = self._prev_left_on_ground
    self._prev_left_on_ground = left_on_ground.copy()

    phases: np.ndarray = np.full(n, "", dtype=object)

    # Transitions: left foot crossing the ground threshold
    swing_to_stance = left_on_ground & ~prev_left_on_ground
    stance_to_swing = ~left_on_ground & prev_left_on_ground
    phases[swing_to_stance] = "swing_to_stance"
    phases[stance_to_swing] = "stance_to_swing"

    # Mid-phase: no transition, detect by peak height of the airborne foot
    no_transition = ~swing_to_stance & ~stance_to_swing
    phases[no_transition & ~left_on_ground & (left_norm >= _PEAK_THRESHOLD)] = "mid_swing"
    phases[no_transition & left_on_ground & (right_norm >= _PEAK_THRESHOLD)] = "mid_stance"

    return phases

def get_metadata(self) -> dict[str, Any]:
    """Return a shallow copy of the analysis metadata dict.

    The dict is empty at construction and is filled in by ``try_finalize()``.
    """
    snapshot = self._metadata.copy()
    return snapshot


def _count_gait_cycles(foot_z: np.ndarray) -> int:
"""Count full gait cycles from zero-crossings of centered foot-Z signal."""
# Center signal around mean, count sign changes (zero-crossings),
# two crossings = one full cycle. Return min across all envs.
mean_z = np.mean(foot_z, axis=0, keepdims=True)
centered = foot_z - mean_z
sign = np.sign(centered)
sign[sign == 0] = 1
crossings = np.sum(np.abs(np.diff(sign, axis=0)) > 0, axis=0)
cycles = crossings // 2
return int(np.min(cycles)) if cycles.size > 0 else 0
195 changes: 195 additions & 0 deletions src/holosoma/holosoma/agents/callbacks/grid_conditions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
"""Grid condition manager for multi-condition evaluation sweeps.

Owns the cross-product logic that combines sweep axes (velocity, payload,
push, etc.) into a flat list of per-env conditions. Any callback can
register a sweep axis; after all registrations, ``finalize()`` expands
the Cartesian product and validates against ``num_envs``.

The manager is created by EvalRecordingCallback (the single recorder) and
shared with other callbacks via ``_require_recording_cb()``.

Terminology:
axis — one factor of the grid, added via ``add_axis()``.
Each axis contributes one "slot" to the Cartesian product.
An axis can control a single variable (``name="push_force_n"``)
or several coupled variables
(``name=["lin_vel_x", "lin_vel_y", "ang_vel_yaw"]``).
group — a label used only for metadata output grouping in the NPZ.
Does not affect the grid logic.
"""

from __future__ import annotations

import itertools
from typing import Any

from loguru import logger


class GridConditionManager:
"""Manages sweep axes and the resulting per-env condition list.

Lifecycle:
1. Created during ``on_pre_evaluate_policy`` by the recording callback.
2. Other callbacks call ``add_axis()`` during their own
``on_pre_evaluate_policy`` (callbacks execute in field-declaration
order from ``EvalCallbacksConfig``).
3. Recording callback calls ``finalize()`` on the first env step
(deferred so all axes are registered).
"""

def __init__(self) -> None:
    """Start with no registered axes and an empty, unfinalized condition list."""
    # Flipped to True by finalize(); add_axis() refuses to run afterwards.
    self._finalized: bool = False

    # One dict per add_axis() call, in registration order.
    self._axes: list[dict[str, Any]] = []

    # Filled in by finalize() from the Cartesian product of the axes.
    self.conditions: list[dict[str, Any]] = []
    self.num_conditions: int = 0

    self.warmup_steps: int = 0

def add_axis(
self,
name: str | list[str],
values: list,
Comment on lines +48 to +51
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens with duplicate axis names? (Also: "add" vs. "register".) It would be good to document the behavior/expectations. You append below without checking, but I don't know (yet) whether you expect each name to appear only once.

Edit: I see that you de-dupe below. It would be good to document that, especially for callers -- right now, there's no way to know that a duplication happened. Also, is the match by key only, not by values?

*,
labels: list[str] | None = None,
group: str = "",
) -> None:
"""Register one axis of the evaluation grid.

The grid is the Cartesian product of all registered axes.

Single-variable axis (one key per condition entry)::

cm.add_axis("push_force_n", [100.0, 150.0], group="push")

Multi-variable axis (coupled keys that vary together)::

cm.add_axis(
["lin_vel_x", "lin_vel_y", "ang_vel_yaw"],
[(0.5, 0, 0), (1.0, 0, 0), (0, 0.3, 0)],
group="velocity",
)

Args:
name: A single key string, or a list of keys for coupled variables.
values: The sweep values for this axis. For a single key, a flat
list (e.g. ``[0.5, 1.0]``). For coupled keys, a list of
tuples with one element per key.
labels: Human-readable labels (optional, defaults to str(value)).
group: Logical group for metadata output (e.g. "velocity", "push").

Must be called before ``finalize()``.
"""
if self._finalized:
raise RuntimeError(f"Cannot add axis '{name}' after conditions are finalized.")

if isinstance(name, str):
keys = [name]
normalized_values = [(v,) for v in values]
else:
keys = list(name)
normalized_values = [tuple(v) for v in values]

self._axes.append(
{
"keys": keys,
"values": normalized_values,
"labels": labels,
"group": group,
}
)

def finalize(self, num_envs: int) -> None:
"""Expand the Cartesian product of all axes and validate env count.

After this call, ``conditions`` and ``num_conditions`` are set.
"""
if self._finalized:
return

# --- Expand Cartesian product of all axes ---
if not self._axes:
self.conditions = [{}]
else:
value_lists = [ax["values"] for ax in self._axes]
key_lists = [ax["keys"] for ax in self._axes]
self.conditions = []
for combo in itertools.product(*value_lists):
cond: dict[str, Any] = {}
for keys, vals in zip(key_lists, combo):
cond.update(dict(zip(keys, vals)))
self.conditions.append(cond)

# --- Deduplicate (preserving order) ---
seen: set[tuple] = set()
unique: list[dict[str, Any]] = []
for cond in self.conditions:
key = tuple(sorted(cond.items()))
if key not in seen:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect that, at the very least, you'll want to log duplicates, but see above re: comparing keys vs. values, and what callers can/can't do if there are duplicates with different values.

seen.add(key)
unique.append(cond)
self.conditions = unique
self.num_conditions = len(unique)

if num_envs < self.num_conditions:
raise RuntimeError(
f"GridConditionManager: need num_envs >= {self.num_conditions} "
f"for {self.num_conditions} conditions, but got num_envs={num_envs}. "
f"Set --training.num_envs={self.num_conditions}"
)
if num_envs > self.num_conditions:
logger.warning(
f"GridConditionManager: num_envs={num_envs} > {self.num_conditions} conditions. "
f"Only using first {self.num_conditions} envs."
)

self._finalized = True

logger.info(f"GridConditionManager: finalized {self.num_conditions} conditions from {len(self._axes)} axes")
for i, cond in enumerate(self.conditions):
parts = [f"{k}={v:.2f}" if isinstance(v, float) else f"{k}={v}" for k, v in cond.items()]
logger.info(f" env {i}: {' '.join(parts)}")

def get_metadata(self) -> dict[str, Any]:
"""Return condition metadata for NPZ recording.

``grid_conditions`` uses hierarchical dicts grouped by the ``group``
parameter passed to ``add_axis()``. Ungrouped keys stay at
the top level.

Example condition::

{'velocity': {'lin_vel_x': 0.5, 'lin_vel_y': 0.0, 'ang_vel_yaw': 0.0},
'push': {'body_label': 'torso', 'direction': 'forward', 'force_n': 150.0}}
"""
Comment on lines +152 to +163
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know our current position towards unit testing, but this kind of logic is a good reason to add them, especially for future regressions.

key_to_group: dict[str, str] = {}
for ax in self._axes:
grp = ax.get("group", "")
for k in ax["keys"]:
key_to_group[k] = grp

hierarchical: list[dict[str, Any]] = []
for cond in self.conditions:
grouped: dict[str, Any] = {}
for key, val in cond.items():
grp = key_to_group.get(key, "")
if grp:
grouped.setdefault(grp, {})[key] = val
else:
grouped[key] = val
hierarchical.append(grouped)

meta: dict[str, Any] = {
"num_conditions": self.num_conditions,
"grid_conditions": hierarchical,
}
for ax in self._axes:
keys = ax["keys"]
values = ax["values"]
if len(keys) == 1:
meta[f"sweep_{keys[0]}_values"] = [v[0] for v in values]
if ax["labels"] is not None:
meta[f"sweep_{keys[0]}_labels"] = ax["labels"]
else:
for k_idx, k in enumerate(keys):
meta[f"sweep_{k}_values"] = [v[k_idx] for v in values]
return meta
Loading
Loading