Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/setup_mujoco_py312.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128
pip install astor easydict ipdb joblib loguru lxml matplotlib meshcat omegaconf \
opencv-python plotly pygame pynput rich scipy tensorboard tensordict \
termcolor tqdm trimesh "yourdfpy>=0.0.58" zmq shapely click \
"warp-lang>=1.10" pydantic "tyro>=1.0.0" "numpy<2"
"warp-lang>=1.10" pydantic "tyro>=1.0.10" "numpy<2"

# --- Install holosoma + extensions (--no-deps to skip numpy==1.23.5 pin) ---
pip install --no-deps -e $ROOT_DIR/src/holosoma
Expand Down
2 changes: 1 addition & 1 deletion src/holosoma/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ dependencies = [
"types-python-dateutil",
"types-requests",
"types-tqdm",
"tyro>=1.0.0",
"tyro>=1.0.10",
"wandb==0.22.0", # pin version until https://github.com/wandb/wandb/issues/10647 is fixed
"warp-lang==1.10.0",
"yourdfpy>=0.0.58",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from __future__ import annotations

import tyro
from pydantic.dataclasses import dataclass
from typing_extensions import Annotated

from .observation import ObservationConfig
from .robot import RobotConfig
Expand All @@ -26,8 +28,31 @@ class InferenceConfig:
task: TaskConfig
"""Task execution configuration."""

secondary: InferenceConfig | None = None
"""Secondary policy config for dual-mode (X-button switch).
When set, enables runtime switching between this (primary) policy and the secondary.
Override any field: --secondary.task.model-path, --secondary.task.rl-rate, etc.
Set to None to disable dual-mode."""

def _make_secondary_constructor():
    """Return the tyro subcommand type used to pick the secondary config.

    ``get_defaults`` is imported inside the function body rather than at
    module level to avoid a circular import between config_types and
    config_values. Each registered ``DualModePolicyConfig`` contributes its
    ``.primary`` under the same key, so the secondary selector lists exactly
    the presets that are registered as primaries.
    """
    from holosoma_inference.config.config_values.inference import get_defaults

    # Map preset name -> primary InferenceConfig of that preset.
    primary_by_name = {}
    for preset_name, dual_cfg in get_defaults().items():
        primary_by_name[preset_name] = dual_cfg.primary

    return tyro.extras.subcommand_type_from_defaults(primary_by_name)


@dataclass
class DualModePolicyConfig:
    """A primary inference config paired with an optional secondary for dual-mode."""

    # Registered with tyro under an empty argument name.
    # NOTE(review): presumably this drops the ``primary.`` prefix so primary
    # overrides stay in the form ``--task.model-path ...`` - confirm against
    # the tyro.conf.arg documentation.
    primary: Annotated[InferenceConfig, tyro.conf.arg(name="")]
    """Primary inference config."""

    # Optional; None is the default. The tyro constructor is supplied through
    # a factory (_make_secondary_constructor) rather than as a ready-made
    # constructor; that factory imports the defaults registry inside its body.
    secondary: (
        Annotated[
            InferenceConfig,
            tyro.conf.arg(constructor_factory=_make_secondary_constructor),
        ]
        | None
    ) = None
    """Secondary inference config for dual-mode (X-button switch)."""
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
"""Default inference configurations for holosoma_inference."""

from __future__ import annotations

from dataclasses import replace
from importlib.metadata import entry_points

import tyro
from typing_extensions import Annotated

from holosoma_inference.config.config_types.inference import InferenceConfig
from holosoma_inference.config.config_types.inference import DualModePolicyConfig, InferenceConfig
from holosoma_inference.config.config_values import observation, robot, task

# Shared safety secondary for all G1 configs - FastSAC locomotion.
# Each config references the same object; users can override any field
# with --secondary.task.model-path etc., or disable with --secondary none.
_g1_safety_secondary = InferenceConfig(
# Safety secondary for G1 configs - FastSAC locomotion.
# Users can select this explicitly via `secondary:g1-29dof-safety-loco`,
# or disable the secondary with `secondary:none`.
g1_safety_loco = InferenceConfig(
robot=robot.g1_29dof,
observation=observation.loco_g1_29dof,
task=task.safety_locomotion_g1,
Expand All @@ -22,7 +24,6 @@
robot=robot.g1_29dof,
observation=observation.loco_g1_29dof,
task=task.locomotion,
secondary=_g1_safety_secondary,
)

t1_29dof_loco = InferenceConfig(
Expand Down Expand Up @@ -62,14 +63,14 @@
# fmt: on
observation=observation.wbt,
task=task.wbt,
secondary=_g1_safety_secondary,
)

# Core defaults - no extension imports at module load time
DEFAULTS = {
"g1-29dof-loco": g1_29dof_loco,
"t1-29dof-loco": t1_29dof_loco,
"g1-29dof-wbt": g1_29dof_wbt,
DEFAULTS: dict[str, DualModePolicyConfig] = {
"g1-29dof-loco": DualModePolicyConfig(primary=g1_29dof_loco, secondary=g1_safety_loco),
"t1-29dof-loco": DualModePolicyConfig(primary=t1_29dof_loco, secondary=None),
"g1-29dof-wbt": DualModePolicyConfig(primary=g1_29dof_wbt, secondary=g1_safety_loco),
"g1-29dof-safety-loco": DualModePolicyConfig(primary=g1_safety_loco, secondary=None),
}

# Track whether extensions have been loaded
Expand All @@ -90,31 +91,31 @@ def _load_extensions() -> None:
DEFAULTS[ep.name] = ep.load()


def get_annotated_inference_config() -> type:
"""Build the annotated InferenceConfig type with all discovered configs.

This function loads extension configs lazily and returns a tyro-compatible
annotated type for CLI subcommand generation.
def _make_dual_mode_constructor():
"""Build subcommand constructor for the top-level config selector."""
defaults = get_defaults()

Returns:
Annotated type suitable for use with tyro.cli()
"""
_load_extensions()
return Annotated[
InferenceConfig,
tyro.conf.arg(
constructor=tyro.extras.subcommand_type_from_defaults(
{f"inference:{k}": v for k, v in DEFAULTS.items()}
# Validate: each default secondary must be registered as a primary in DEFAULTS,
# because we ask tyro to build the secondary subcommand choices from the same set.
primaries = {id(v.primary) for v in defaults.values()}
for name, cfg in defaults.items():
if cfg.secondary is not None and id(cfg.secondary) not in primaries:
raise ValueError(
f"Default secondary for '{name}' is not a registered primary in DEFAULTS."
)
),
]

return tyro.extras.subcommand_type_from_defaults(
{f"inference:{k}": v for k, v in defaults.items()}
)

def get_defaults() -> dict:
"""Get all inference config defaults, including extensions.

Returns:
Dictionary mapping config names to InferenceConfig instances.
"""
AnnotatedDualModePolicyConfig = Annotated[
DualModePolicyConfig,
tyro.conf.arg(constructor_factory=_make_dual_mode_constructor),
]


def get_defaults() -> dict[str, DualModePolicyConfig]:
    """Return the named config registry, extensions included.

    ``_load_extensions()`` runs first, so the result reflects whatever that
    call adds. The module-level ``DEFAULTS`` mapping itself is returned,
    not a copy.
    """
    _load_extensions()
    registry = DEFAULTS
    return registry
12 changes: 6 additions & 6 deletions src/holosoma_inference/holosoma_inference/policies/dual_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from loguru import logger
from termcolor import colored

from holosoma_inference.config.config_types.inference import InferenceConfig
from holosoma_inference.config.config_types.inference import DualModePolicyConfig, InferenceConfig


def _select_policy_class(config: InferenceConfig):
Expand Down Expand Up @@ -48,22 +48,22 @@ class DualModePolicy:
The existing Select/1-9 multi-model switching still works within each policy.
"""

def __init__(self, primary_config: InferenceConfig, secondary_config: InferenceConfig):
primary_cls = _select_policy_class(primary_config)
secondary_cls = _select_policy_class(secondary_config)
def __init__(self, config: DualModePolicyConfig):
primary_cls = _select_policy_class(config.primary)
secondary_cls = _select_policy_class(config.secondary)

logger.info(
colored(f"Dual-mode: primary={primary_cls.__name__}, secondary={secondary_cls.__name__}", "magenta")
)

# Fully init primary (owns hardware)
self.primary = primary_cls(config=primary_config)
self.primary = primary_cls(config=config.primary)

# Init secondary with shared hardware
logger.info(colored("Initializing secondary policy (shared hardware)...", "magenta"))
secondary = object.__new__(secondary_cls)
secondary._shared_hardware_source = self.primary
secondary.__init__(config=secondary_config)
secondary.__init__(config=config.secondary)
self.secondary = secondary

self.active = self.primary
Expand Down
113 changes: 16 additions & 97 deletions src/holosoma_inference/holosoma_inference/run_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

Usage:
python run_policy.py inference:g1-29dof-loco --task.model-path path/to/model.onnx
python run_policy.py inference:g1-29dof-loco --task.model-path wandb://project/run/model.onnx
python run_policy.py inference:g1-29dof-loco --task.model-path https://wandb-url/files/model.onnx
python run_policy.py inference:g1-29dof-loco secondary:none
python run_policy.py inference:g1-29dof-loco secondary:g1-29dof-safety-loco
"""

from __future__ import annotations
Expand All @@ -18,8 +18,8 @@
import tyro
from loguru import logger

from holosoma_inference.config.config_types.inference import InferenceConfig
from holosoma_inference.config.config_values.inference import get_annotated_inference_config
from holosoma_inference.config.config_types.inference import DualModePolicyConfig
from holosoma_inference.config.config_values.inference import AnnotatedDualModePolicyConfig
from holosoma_inference.config.utils import TYRO_CONFIG
from holosoma_inference.policies.dual_mode import DualModePolicy, _select_policy_class
from holosoma_inference.utils.misc import restore_terminal_settings
Expand Down Expand Up @@ -98,28 +98,28 @@ def _print_control_guide(policy_class, use_joystick: bool, dual_mode: bool = Fal
logger.info("")


def run_policy(config: InferenceConfig):
def run_policy(config: DualModePolicyConfig):
"""Run policy with Tyro configuration."""
logger.info("🚀 Starting Policy with Tyro configuration...")
logger.info(f"🤖 Robot: {config.robot.robot_type}")
logger.info(f"📋 Observation groups: {list(config.observation.obs_dict.keys())}")
logger.info(f"⚙️ RL Rate: {config.task.rl_rate} Hz")
logger.info(f"📁 Model path: {config.task.model_path}")
logger.info(f"🤖 Robot: {config.primary.robot.robot_type}")
logger.info(f"📋 Observation groups: {list(config.primary.observation.obs_dict.keys())}")
logger.info(f"⚙️ RL Rate: {config.primary.task.rl_rate} Hz")
logger.info(f"📁 Model path: {config.primary.task.model_path}")

try:
# Determine policy class based on observation type
policy_class = _select_policy_class(config)
policy_class = _select_policy_class(config.primary)
dual_mode = config.secondary is not None

if dual_mode:
logger.info(f"Using {policy_class.__name__} (dual-mode enabled)")
policy = DualModePolicy(primary_config=config, secondary_config=config.secondary)
policy = DualModePolicy(config=config)
else:
logger.info(f"Using {policy_class.__name__}")
policy = policy_class(config=config)
policy = policy_class(config=config.primary)

logger.info("✅ Policy initialized successfully!")
_print_control_guide(policy_class, config.task.use_joystick, dual_mode=dual_mode)
_print_control_guide(policy_class, config.primary.task.use_joystick, dual_mode=dual_mode)
policy.run()
logger.info("✅ Policy execution completed!")

Expand All @@ -131,90 +131,9 @@ def run_policy(config: InferenceConfig):
restore_terminal_settings()


def _split_secondary_args(argv: list[str]) -> tuple[list[str], list[str]]:
"""Split --secondary.* args out of argv, renaming them for standalone parsing.

Returns (primary_argv, secondary_argv) where secondary args have the
``--secondary.`` prefix stripped, e.g. ``--secondary.task.model-path X``
becomes ``--task.model-path X``.
"""
primary = []
secondary = []
expect_secondary_value = False
for arg in argv:
if arg.startswith("--secondary."):
renamed = "--" + arg[len("--secondary.") :]
secondary.append(renamed)
# If not --key=value form, the next token might be the value
expect_secondary_value = "=" not in renamed
elif expect_secondary_value and not arg.startswith("--"):
secondary.append(arg)
expect_secondary_value = False
else:
primary.append(arg)
expect_secondary_value = False
return primary, secondary


def main(annotated_config=None):
"""Main entry point. Extensions can pass their own AnnotatedInferenceConfig."""
import argparse

from holosoma_inference.config.config_values.inference import DEFAULTS

# Pre-parse --secondary-preset and --secondary none before tyro.
# Tyro can't build a CLI parser for InferenceConfig | None when it
# contains dict[str, Any] fields, so we handle secondary selection ourselves.
pre = argparse.ArgumentParser(add_help=False, allow_abbrev=False)
pre.add_argument(
"--secondary-preset",
default=None,
metavar="NAME",
help=f"Select a preset for the secondary policy. Choices: {list(DEFAULTS.keys())}",
)
pre.add_argument("--secondary", default=None, help="Set to 'none' to disable dual-mode.")
known, remaining = pre.parse_known_args()

disable_secondary = known.secondary is not None and known.secondary.lower() == "none"
secondary_preset = known.secondary_preset

# Strip --secondary.* args from remaining so tyro doesn't see them
primary_argv, secondary_argv = _split_secondary_args(remaining)
sys.argv = [sys.argv[0]] + primary_argv

if annotated_config is None:
# Use factory function to lazily load extension configs
annotated_config = get_annotated_inference_config()
config = tyro.cli(annotated_config, config=TYRO_CONFIG)

from dataclasses import replace as _replace

if disable_secondary:
config = _replace(config, secondary=None)
elif secondary_preset:
preset = DEFAULTS.get(secondary_preset)
if preset is None:
logger.error(f"Unknown secondary preset: {secondary_preset}")
logger.info(f"Available presets: {list(DEFAULTS.keys())}")
sys.exit(1)
preset = _replace(preset, secondary=None)

# Parse secondary overrides against the preset defaults
if secondary_argv:
sys.argv = [sys.argv[0]] + secondary_argv
secondary = tyro.cli(InferenceConfig, default=preset, config=TYRO_CONFIG)
else:
secondary = preset
config = _replace(config, secondary=secondary)
elif secondary_argv:
# --secondary.* overrides on the config's default secondary
if config.secondary is not None:
sys.argv = [sys.argv[0]] + secondary_argv
secondary = tyro.cli(InferenceConfig, default=config.secondary, config=TYRO_CONFIG)
config = _replace(config, secondary=secondary)
else:
logger.warning("--secondary.* args ignored: no default secondary in this config")

def main():
    """Parse the command line with tyro and run the resulting policy config."""
    run_policy(tyro.cli(AnnotatedDualModePolicyConfig, config=TYRO_CONFIG))


Expand Down
7 changes: 3 additions & 4 deletions src/holosoma_inference/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
"sshkeyboard",
"termcolor",
"pyyaml",
"tyro>=0.10.0a4",
"tyro>=1.0.10",
"wandb",
"zmq",
"defusedxml",
Expand All @@ -69,9 +69,8 @@
"t1-29dof = holosoma_inference.config.config_values.robot:t1_29dof",
],
"holosoma.config.inference": [
"g1-29dof-loco = holosoma_inference.config.config_values.inference:g1_29dof_loco",
"t1-29dof-loco = holosoma_inference.config.config_values.inference:t1_29dof_loco",
"g1-29dof-wbt = holosoma_inference.config.config_values.inference:g1_29dof_wbt",
# Extensions register DualModePolicyConfig objects here.
# Core configs are in DEFAULTS directly; these are for external packages.
],
},
keywords="humanoid robotics inference policy onnx",
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/test_run_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def assert_run_policy_with_hsinference(config_name: str, model_path: str, timeou
f"python {REPO_ROOT}/src/holosoma_inference/holosoma_inference/run_policy.py "
f"inference:{config_name} "
f"--task.model-path={model_path} "
f"--secondary none",
f"secondary:none",
],
stdin=subprocess.DEVNULL,
stdout=subprocess.PIPE,
Expand Down