diff --git a/scripts/setup_mujoco_py312.sh b/scripts/setup_mujoco_py312.sh index ae45831fa..b2ed30863 100755 --- a/scripts/setup_mujoco_py312.sh +++ b/scripts/setup_mujoco_py312.sh @@ -48,7 +48,7 @@ pip install torch torchvision --index-url https://download.pytorch.org/whl/cu128 pip install astor easydict ipdb joblib loguru lxml matplotlib meshcat omegaconf \ opencv-python plotly pygame pynput rich scipy tensorboard tensordict \ termcolor tqdm trimesh "yourdfpy>=0.0.58" zmq shapely click \ - "warp-lang>=1.10" pydantic "tyro>=1.0.0" "numpy<2" + "warp-lang>=1.10" pydantic "tyro>=1.0.10" "numpy<2" # --- Install holosoma + extensions (--no-deps to skip numpy==1.23.5 pin) --- pip install --no-deps -e $ROOT_DIR/src/holosoma diff --git a/src/holosoma/pyproject.toml b/src/holosoma/pyproject.toml index 31d77d52d..b6f4a5d8a 100644 --- a/src/holosoma/pyproject.toml +++ b/src/holosoma/pyproject.toml @@ -44,7 +44,7 @@ dependencies = [ "types-python-dateutil", "types-requests", "types-tqdm", - "tyro>=1.0.0", + "tyro>=1.0.10", "wandb==0.22.0", # pin version until https://github.com/wandb/wandb/issues/10647 is fixed "warp-lang==1.10.0", "yourdfpy>=0.0.58", diff --git a/src/holosoma_inference/holosoma_inference/config/config_types/inference.py b/src/holosoma_inference/holosoma_inference/config/config_types/inference.py index e99db9bfb..489a75167 100644 --- a/src/holosoma_inference/holosoma_inference/config/config_types/inference.py +++ b/src/holosoma_inference/holosoma_inference/config/config_types/inference.py @@ -2,7 +2,9 @@ from __future__ import annotations +import tyro from pydantic.dataclasses import dataclass +from typing_extensions import Annotated from .observation import ObservationConfig from .robot import RobotConfig @@ -26,8 +28,31 @@ class InferenceConfig: task: TaskConfig """Task execution configuration.""" - secondary: InferenceConfig | None = None - """Secondary policy config for dual-mode (X-button switch). - When set, enables runtime switching between this (primary) policy and the secondary. - Override any field: --secondary.task.model-path, --secondary.task.rl-rate, etc. - Set to None to disable dual-mode.""" + +def _make_secondary_constructor(): + """Build subcommand constructor for the secondary config selector. + + Lazily imports get_defaults() to avoid circular import (config_types <=> config_values). + Extracts .primary from each DualModePolicyConfig so the secondary offers the same presets + as the top-level selector. A config must be registered as a primary to be used as a secondary. + """ + from holosoma_inference.config.config_values.inference import get_defaults + + return tyro.extras.subcommand_type_from_defaults({k: v.primary for k, v in get_defaults().items()}) + + +@dataclass +class DualModePolicyConfig: + """A primary inference config paired with an optional secondary for dual-mode.""" + + primary: Annotated[InferenceConfig, tyro.conf.arg(name="")] + """Primary inference config.""" + + secondary: ( + Annotated[ + InferenceConfig, + tyro.conf.arg(constructor_factory=_make_secondary_constructor), + ] + | None + ) = None + """Secondary inference config for dual-mode (X-button switch).""" diff --git a/src/holosoma_inference/holosoma_inference/config/config_values/inference.py b/src/holosoma_inference/holosoma_inference/config/config_values/inference.py index 2a00e39d3..fa61ab6a0 100644 --- a/src/holosoma_inference/holosoma_inference/config/config_values/inference.py +++ b/src/holosoma_inference/holosoma_inference/config/config_values/inference.py @@ -1,18 +1,20 @@ """Default inference configurations for holosoma_inference.""" +from __future__ import annotations + from dataclasses import replace from importlib.metadata import entry_points import tyro from typing_extensions import Annotated -from holosoma_inference.config.config_types.inference import InferenceConfig +from holosoma_inference.config.config_types.inference import DualModePolicyConfig, InferenceConfig from holosoma_inference.config.config_values import observation, robot, task -# Shared safety secondary for all G1 configs — FastSAC locomotion. -# Each config references the same object; users can override any field -# with --secondary.task.model-path etc., or disable with --secondary none. -_g1_safety_secondary = InferenceConfig( +# Safety secondary for G1 configs - FastSAC locomotion. +# Users can select this explicitly via `secondary:g1-29dof-safety-loco`, +# or disable the secondary with `secondary:none`. +g1_safety_loco = InferenceConfig( robot=robot.g1_29dof, observation=observation.loco_g1_29dof, task=task.safety_locomotion_g1, @@ -22,7 +24,6 @@ robot=robot.g1_29dof, observation=observation.loco_g1_29dof, task=task.locomotion, - secondary=_g1_safety_secondary, ) t1_29dof_loco = InferenceConfig( @@ -62,14 +63,14 @@ # fmt: on observation=observation.wbt, task=task.wbt, - secondary=_g1_safety_secondary, ) # Core defaults - no extension imports at module load time -DEFAULTS = { - "g1-29dof-loco": g1_29dof_loco, - "t1-29dof-loco": t1_29dof_loco, - "g1-29dof-wbt": g1_29dof_wbt, +DEFAULTS: dict[str, DualModePolicyConfig] = { + "g1-29dof-loco": DualModePolicyConfig(primary=g1_29dof_loco, secondary=g1_safety_loco), + "t1-29dof-loco": DualModePolicyConfig(primary=t1_29dof_loco, secondary=None), + "g1-29dof-wbt": DualModePolicyConfig(primary=g1_29dof_wbt, secondary=g1_safety_loco), + "g1-29dof-safety-loco": DualModePolicyConfig(primary=g1_safety_loco, secondary=None), } # Track whether extensions have been loaded @@ -90,31 +91,31 @@ def _load_extensions() -> None: DEFAULTS[ep.name] = ep.load() -def get_annotated_inference_config() -> type: - """Build the annotated InferenceConfig type with all discovered configs. - - This function loads extension configs lazily and returns a tyro-compatible - annotated type for CLI subcommand generation. +def _make_dual_mode_constructor(): + """Build subcommand constructor for the top-level config selector.""" + defaults = get_defaults() - Returns: - Annotated type suitable for use with tyro.cli() - """ - _load_extensions() - return Annotated[ - InferenceConfig, - tyro.conf.arg( - constructor=tyro.extras.subcommand_type_from_defaults( - {f"inference:{k}": v for k, v in DEFAULTS.items()} + # Validate: each default secondary must be registered as a primary in DEFAULTS, + # because we ask tyro to build the secondary subcommand choices from the same set. + primaries = {id(v.primary) for v in defaults.values()} + for name, cfg in defaults.items(): + if cfg.secondary is not None and id(cfg.secondary) not in primaries: + raise ValueError( + f"Default secondary for '{name}' is not a registered primary in DEFAULTS." ) - ), - ] + return tyro.extras.subcommand_type_from_defaults( + {f"inference:{k}": v for k, v in defaults.items()} + ) -def get_defaults() -> dict: - """Get all inference config defaults, including extensions. - Returns: - Dictionary mapping config names to InferenceConfig instances. - """ +AnnotatedDualModePolicyConfig = Annotated[ + DualModePolicyConfig, + tyro.conf.arg(constructor_factory=_make_dual_mode_constructor), +] + + +def get_defaults() -> dict[str, DualModePolicyConfig]: + """Get all config defaults, including extensions.""" _load_extensions() return DEFAULTS diff --git a/src/holosoma_inference/holosoma_inference/policies/dual_mode.py b/src/holosoma_inference/holosoma_inference/policies/dual_mode.py index 16df88d18..7f1a1e113 100644 --- a/src/holosoma_inference/holosoma_inference/policies/dual_mode.py +++ b/src/holosoma_inference/holosoma_inference/policies/dual_mode.py @@ -7,7 +7,7 @@ from loguru import logger from termcolor import colored -from holosoma_inference.config.config_types.inference import InferenceConfig +from holosoma_inference.config.config_types.inference import DualModePolicyConfig, InferenceConfig def _select_policy_class(config: InferenceConfig): @@ -48,22 +48,22 @@ class DualModePolicy: The existing Select/1-9 multi-model switching still works within each policy. """ - def __init__(self, primary_config: InferenceConfig, secondary_config: InferenceConfig): - primary_cls = _select_policy_class(primary_config) - secondary_cls = _select_policy_class(secondary_config) + def __init__(self, config: DualModePolicyConfig): + primary_cls = _select_policy_class(config.primary) + secondary_cls = _select_policy_class(config.secondary) logger.info( colored(f"Dual-mode: primary={primary_cls.__name__}, secondary={secondary_cls.__name__}", "magenta") ) # Fully init primary (owns hardware) - self.primary = primary_cls(config=primary_config) + self.primary = primary_cls(config=config.primary) # Init secondary with shared hardware logger.info(colored("Initializing secondary policy (shared hardware)...", "magenta")) secondary = object.__new__(secondary_cls) secondary._shared_hardware_source = self.primary - secondary.__init__(config=secondary_config) + secondary.__init__(config=config.secondary) self.secondary = secondary self.active = self.primary diff --git a/src/holosoma_inference/holosoma_inference/run_policy.py b/src/holosoma_inference/holosoma_inference/run_policy.py index b35735df5..3b25952e8 100644 --- a/src/holosoma_inference/holosoma_inference/run_policy.py +++ b/src/holosoma_inference/holosoma_inference/run_policy.py @@ -6,8 +6,8 @@ Usage: python run_policy.py inference:g1-29dof-loco --task.model-path path/to/model.onnx - python run_policy.py inference:g1-29dof-loco --task.model-path wandb://project/run/model.onnx - python run_policy.py inference:g1-29dof-loco --task.model-path https://wandb-url/files/model.onnx + python run_policy.py inference:g1-29dof-loco secondary:none + python run_policy.py inference:g1-29dof-loco secondary:g1-29dof-safety-loco """ from __future__ import annotations @@ -18,8 +18,8 @@ import tyro from loguru import logger -from holosoma_inference.config.config_types.inference import InferenceConfig -from holosoma_inference.config.config_values.inference import get_annotated_inference_config +from holosoma_inference.config.config_types.inference import DualModePolicyConfig +from holosoma_inference.config.config_values.inference import AnnotatedDualModePolicyConfig from holosoma_inference.config.utils import TYRO_CONFIG from holosoma_inference.policies.dual_mode import DualModePolicy, _select_policy_class from holosoma_inference.utils.misc import restore_terminal_settings @@ -98,28 +98,28 @@ def _print_control_guide(policy_class, use_joystick: bool, dual_mode: bool = Fal logger.info("") -def run_policy(config: InferenceConfig): +def run_policy(config: DualModePolicyConfig): """Run policy with Tyro configuration.""" logger.info("🚀 Starting Policy with Tyro configuration...") - logger.info(f"🤖 Robot: {config.robot.robot_type}") - logger.info(f"📋 Observation groups: {list(config.observation.obs_dict.keys())}") - logger.info(f"⚙️ RL Rate: {config.task.rl_rate} Hz") - logger.info(f"📁 Model path: {config.task.model_path}") + logger.info(f"🤖 Robot: {config.primary.robot.robot_type}") + logger.info(f"📋 Observation groups: {list(config.primary.observation.obs_dict.keys())}") + logger.info(f"⚙️ RL Rate: {config.primary.task.rl_rate} Hz") + logger.info(f"📁 Model path: {config.primary.task.model_path}") try: # Determine policy class based on observation type - policy_class = _select_policy_class(config) + policy_class = _select_policy_class(config.primary) dual_mode = config.secondary is not None if dual_mode: logger.info(f"Using {policy_class.__name__} (dual-mode enabled)") - policy = DualModePolicy(primary_config=config, secondary_config=config.secondary) + policy = DualModePolicy(config=config) else: logger.info(f"Using {policy_class.__name__}") - policy = policy_class(config=config) + policy = policy_class(config=config.primary) logger.info("✅ Policy initialized successfully!") - _print_control_guide(policy_class, config.task.use_joystick, dual_mode=dual_mode) + _print_control_guide(policy_class, config.primary.task.use_joystick, dual_mode=dual_mode) policy.run() logger.info("✅ Policy execution completed!") @@ -131,90 +131,9 @@ def run_policy(config: InferenceConfig): restore_terminal_settings() -def _split_secondary_args(argv: list[str]) -> tuple[list[str], list[str]]: - """Split --secondary.* args out of argv, renaming them for standalone parsing. - - Returns (primary_argv, secondary_argv) where secondary args have the - ``--secondary.`` prefix stripped, e.g. ``--secondary.task.model-path X`` - becomes ``--task.model-path X``. - """ - primary = [] - secondary = [] - expect_secondary_value = False - for arg in argv: - if arg.startswith("--secondary."): - renamed = "--" + arg[len("--secondary.") :] - secondary.append(renamed) - # If not --key=value form, the next token might be the value - expect_secondary_value = "=" not in renamed - elif expect_secondary_value and not arg.startswith("--"): - secondary.append(arg) - expect_secondary_value = False - else: - primary.append(arg) - expect_secondary_value = False - return primary, secondary - - -def main(annotated_config=None): - """Main entry point. Extensions can pass their own AnnotatedInferenceConfig.""" - import argparse - - from holosoma_inference.config.config_values.inference import DEFAULTS - - # Pre-parse --secondary-preset and --secondary none before tyro. - # Tyro can't build a CLI parser for InferenceConfig | None when it - # contains dict[str, Any] fields, so we handle secondary selection ourselves. - pre = argparse.ArgumentParser(add_help=False, allow_abbrev=False) - pre.add_argument( - "--secondary-preset", - default=None, - metavar="NAME", - help=f"Select a preset for the secondary policy. Choices: {list(DEFAULTS.keys())}", - ) - pre.add_argument("--secondary", default=None, help="Set to 'none' to disable dual-mode.") - known, remaining = pre.parse_known_args() - - disable_secondary = known.secondary is not None and known.secondary.lower() == "none" - secondary_preset = known.secondary_preset - - # Strip --secondary.* args from remaining so tyro doesn't see them - primary_argv, secondary_argv = _split_secondary_args(remaining) - sys.argv = [sys.argv[0]] + primary_argv - - if annotated_config is None: - # Use factory function to lazily load extension configs - annotated_config = get_annotated_inference_config() - config = tyro.cli(annotated_config, config=TYRO_CONFIG) - - from dataclasses import replace as _replace - - if disable_secondary: - config = _replace(config, secondary=None) - elif secondary_preset: - preset = DEFAULTS.get(secondary_preset) - if preset is None: - logger.error(f"Unknown secondary preset: {secondary_preset}") - logger.info(f"Available presets: {list(DEFAULTS.keys())}") - sys.exit(1) - preset = _replace(preset, secondary=None) - - # Parse secondary overrides against the preset defaults - if secondary_argv: - sys.argv = [sys.argv[0]] + secondary_argv - secondary = tyro.cli(InferenceConfig, default=preset, config=TYRO_CONFIG) - else: - secondary = preset - config = _replace(config, secondary=secondary) - elif secondary_argv: - # --secondary.* overrides on the config's default secondary - if config.secondary is not None: - sys.argv = [sys.argv[0]] + secondary_argv - secondary = tyro.cli(InferenceConfig, default=config.secondary, config=TYRO_CONFIG) - config = _replace(config, secondary=secondary) - else: - logger.warning("--secondary.* args ignored: no default secondary in this config") - +def main(): + """Main entry point.""" + config = tyro.cli(AnnotatedDualModePolicyConfig, config=TYRO_CONFIG) run_policy(config) diff --git a/src/holosoma_inference/setup.py b/src/holosoma_inference/setup.py index 7cc1776d6..53bc35f9c 100644 --- a/src/holosoma_inference/setup.py +++ b/src/holosoma_inference/setup.py @@ -44,7 +44,7 @@ "sshkeyboard", "termcolor", "pyyaml", - "tyro>=0.10.0a4", + "tyro>=1.0.10", "wandb", "zmq", "defusedxml", @@ -69,9 +69,8 @@ "t1-29dof = holosoma_inference.config.config_values.robot:t1_29dof", ], "holosoma.config.inference": [ - "g1-29dof-loco = holosoma_inference.config.config_values.inference:g1_29dof_loco", - "t1-29dof-loco = holosoma_inference.config.config_values.inference:t1_29dof_loco", - "g1-29dof-wbt = holosoma_inference.config.config_values.inference:g1_29dof_wbt", + # Extensions register DualModePolicyConfig objects here. + # Core configs are in DEFAULTS directly; these are for external packages. ], }, keywords="humanoid robotics inference policy onnx", diff --git a/tests/e2e/test_run_policy.py b/tests/e2e/test_run_policy.py index 3aa5cc398..f5cfa6d96 100644 --- a/tests/e2e/test_run_policy.py +++ b/tests/e2e/test_run_policy.py @@ -72,7 +72,7 @@ def assert_run_policy_with_hsinference(config_name: str, model_path: str, timeou f"python {REPO_ROOT}/src/holosoma_inference/holosoma_inference/run_policy.py " f"inference:{config_name} " f"--task.model-path={model_path} " - f"--secondary none", + f"secondary:none", ], stdin=subprocess.DEVNULL, stdout=subprocess.PIPE,