Sampling-based model predictive control on GPU with JAX and MuJoCo MJX.
Hydrax implements various sampling-based MPC algorithms on GPU. It is heavily inspired by MJPC, but focuses exclusively on sampling-based algorithms, runs on hardware accelerators via JAX and MJX, and includes support for online domain randomization.
Available methods:
Algorithm | Description | Import |
---|---|---|
Predictive sampling | Take the lowest-cost rollout at each iteration. | `hydrax.algs.PredictiveSampling` |
MPPI | Take an exponentially weighted average of the rollouts. | `hydrax.algs.MPPI` |
Cross entropy method | Fit a Gaussian distribution to the n best "elite" rollouts. | `hydrax.algs.CEM` |
Evosax | Any of the 30+ evolution strategies implemented in evosax. Includes CMA-ES, differential evolution, and many more. | `hydrax.algs.Evosax` |
- April 13, 2024: Large changes to the core `hydrax` functionality, including some breaking changes.
  - Splines (and their knots) are now the default parameterization of the control signals and decision variables. Previously, every control step was assumed to apply a zero-order hold; this is now a special case of the new spline parameterization.
  - All "time-based" variables are now specified in the controller. Previously, variables like the planning horizon and the number of simulation steps per control step were specified in the task. Now the main variables to specify are `plan_horizon` (the length of the planning horizon in seconds), `num_knots` (the number of spline knots to plan with), and `dt` (the planning time step, set in the model XML). This is a breaking change! See the sketch below for how these settings fit together.
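As a rough illustration of the new interface, the time-based settings are passed to the controller rather than the task. The snippet below is a sketch only: `num_samples` and `noise_level` follow the predictive sampling example later in this README, and the exact constructor signature may differ.

```python
# Sketch only: constructing a controller under the spline-based interface.
from hydrax.algs import PredictiveSampling

ctrl = PredictiveSampling(
    task,               # a hydrax Task instance
    num_samples=32,     # number of rollouts per planning iteration
    noise_level=0.1,    # sampling noise scale
    plan_horizon=1.0,   # planning horizon in seconds
    num_knots=4,        # number of spline knots over the horizon
)
```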
Set up a conda env with CUDA support (first time only):

```bash
conda env create -f environment.yml
```

Enter the conda env:

```bash
conda activate hydrax
```

Install the package and dependencies:

```bash
# option 1: required deps only
pip install -e .

# option 2: all deps (including dev)
pip install -e .[dev]
```

(Optional) Set up pre-commit hooks if using development dependencies:

```bash
pre-commit autoupdate
pre-commit install
```

(Optional) Run unit tests if using development dependencies:

```bash
pytest
```
Launch an interactive pendulum swingup simulation with predictive sampling:

```bash
python examples/pendulum.py ps
```

Launch an interactive humanoid standup simulation with MPPI and online domain randomization:

```bash
python examples/humanoid_standup.py
```

Other demos can be found in the `examples` folder.
Hydrax considers optimal control problems of the form

$$
\begin{align}
\min_{u_{0:T-1}} \quad & \phi(x_T) + \sum_{t=0}^{T-1} \ell(x_t, u_t), \\
\mathrm{s.t.} \quad & x_{t+1} = f(x_t, u_t),
\end{align}
$$

where $x_t$ is the system state and $u_t$ is the control input at time $t$, $\ell$ is the running cost, $\phi$ is the terminal cost, and $f$ represents the system dynamics.
To design a new task, you'll need to specify the costs ($\ell$ and $\phi$) and the dynamics ($f$). You can do this by creating a subclass of `hydrax.task_base.Task`:
```python
class MyNewTask(Task):
    def __init__(self, ...):
        # Create or load a mujoco model defining the dynamics (f)
        mj_model = ...
        super().__init__(mj_model, ...)

    def running_cost(self, x: mjx.Data, u: jax.Array) -> float:
        # Implement the running cost (l) here
        return ...

    def terminal_cost(self, x: jax.Array) -> float:
        # Implement the terminal cost (phi) here
        return ...
```
The dynamics ($f$) are defined by the `mujoco.MjModel` that is passed to the constructor. Other constructor arguments specify the planning horizon and other task details.

For the cost, simply implement the `running_cost` ($\ell$) and `terminal_cost` ($\phi$) methods.

See `hydrax.tasks` for some example task implementations.
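As a concrete (hypothetical) illustration, a minimal pendulum swingup task might look something like the following. The model path, cost weights, and constructor arguments here are illustrative assumptions, not part of Hydrax:

```python
import jax
import jax.numpy as jnp
import mujoco
from mujoco import mjx

from hydrax.task_base import Task


class PendulumSwingup(Task):
    """Hypothetical swingup task for a single hinge-joint pendulum."""

    def __init__(self):
        # Assumes a pendulum model exists at this (illustrative) path; the real
        # Task constructor may require additional arguments.
        mj_model = mujoco.MjModel.from_xml_path("models/pendulum.xml")
        super().__init__(mj_model)

    def running_cost(self, x: mjx.Data, u: jax.Array) -> float:
        # Penalize distance from upright, joint velocity, and control effort
        upright_err = x.qpos[0] - jnp.pi
        return upright_err**2 + 0.1 * x.qvel[0] ** 2 + 0.01 * jnp.sum(u**2)

    def terminal_cost(self, x: jax.Array) -> float:
        # Keep things simple: no extra terminal cost in this sketch
        return 0.0
```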
Hydrax considers sampling-based MPC algorithms that follow a generic structure: at each planning iteration,

1. Sample a set of control sequences $U^{(1)}, \dots, U^{(N)}$ from a distribution with parameters $\theta$.
2. Roll out each control sequence in parallel and record its cost.
3. Update the parameters $\theta$ based on the resulting rollouts.

The meaning of the parameters $\theta$ differs from algorithm to algorithm; in predictive sampling, for example, $\theta$ is simply the mean of a Gaussian sampling distribution.
To implement a new planning algorithm, you'll need to inherit from `hydrax.alg_base.SamplingBasedController` and implement the three methods shown below:
```python
class MyControlAlgorithm(SamplingBasedController):
    def init_params(self) -> Any:
        # Initialize the policy parameters (theta).
        ...
        return params

    def sample_knots(self, params: Any) -> Tuple[jax.Array, Any]:
        # Sample the spline knots U from the policy. Return the samples
        # and the (updated) parameters.
        ...
        return knots, params

    def update_params(self, params: Any, rollouts: Trajectory) -> Any:
        # Update the policy parameters (theta) based on the trajectory data
        # (costs, controls, observations, etc.) stored in the rollouts.
        ...
        return new_params
```
These three methods define a unique sampling-based MPC algorithm. Hydrax takes care of the rest, including parallelizing rollouts on GPU and collecting the rollout data in a `Trajectory` object.
Note: because of the way JAX handles randomness, we assume the PRNG key is stored as one of the parameters $\theta$. This is why `sample_knots` returns updated parameters along with the control samples.

For some examples, take a look at `hydrax.algs`.
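To illustrate how these pieces fit together, here is a rough sketch of a planning loop built from the three methods above. It is simplified and hypothetical: `rollout_in_parallel` is a stand-in for Hydrax's internal GPU rollout machinery, not a real Hydrax function.

```python
# Rough, hypothetical sketch of the generic sampling-based MPC loop.
num_iterations = 10

params = controller.init_params()                       # theta_0 (includes a PRNG key)
for _ in range(num_iterations):
    knots, params = controller.sample_knots(params)     # U^(1..N) ~ pi(theta)
    rollouts = rollout_in_parallel(task, knots)         # costs, controls, observations
    params = controller.update_params(params, rollouts) # new theta from rollout data
```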
One benefit of GPU-based simulation is the ability to roll out trajectories with different model parameters in parallel. Such domain randomization can improve robustness and help reduce the sim-to-real gap.
Hydrax provides tools to make online domain randomization easy. In particular, you can add domain randomization to any task by overriding the `domain_randomize_model` and `domain_randomize_data` methods of a given `Task`. For example:
```python
class MyDomainRandomizedTask(Task):
    ...

    def domain_randomize_model(self, rng: jax.Array) -> Dict[str, jax.Array]:
        """Randomize the friction coefficients."""
        n_geoms = self.model.geom_friction.shape[0]
        multiplier = jax.random.uniform(rng, (n_geoms,), minval=0.5, maxval=2.0)
        new_frictions = self.model.geom_friction.at[:, 0].set(
            self.model.geom_friction[:, 0] * multiplier
        )
        return {"geom_friction": new_frictions}

    def domain_randomize_data(self, data: mjx.Data, rng: jax.Array) -> Dict[str, jax.Array]:
        """Randomly shift the measured configurations."""
        shift = 0.005 * jax.random.normal(rng, (self.model.nq,))
        return {"qpos": data.qpos + shift}
```
These methods return a dictionary of randomized parameters, given a particular random seed (`rng`). Hydrax takes care of the details of applying these parameters to the model and data, and performing rollouts in parallel.

To use a domain randomized task, you'll need to tell the planner how many random models to use with the `num_randomizations` flag. For example,
```python
task = MyDomainRandomizedTask(...)
ctrl = PredictiveSampling(
    task,
    num_samples=32,
    noise_level=0.1,
    num_randomizations=16,
)
```
sets up a predictive sampling controller that rolls out 32 control sequences across 16 domain randomized models.
The resulting `Trajectory` rollouts will have dimensions `(num_randomizations, num_samples, num_time_steps, ...)`.
With domain randomization, we need to somehow aggregate costs across the
different domains. By default, we take the average cost over the randomizations,
similar to domain randomization in reinforcement learning. Other strategies are
available via the `RiskStrategy` interface.
For example, to plan using the worst-case maximum cost across randomizations:
```python
from hydrax.risk import WorstCase

...

task = MyDomainRandomizedTask(...)
ctrl = PredictiveSampling(
    task,
    num_samples=32,
    noise_level=0.1,
    num_randomizations=16,
    risk_strategy=WorstCase(),
)
```
Available risk strategies:
Strategy | Description | Import |
---|---|---|
Average (default) | Take the expected cost across randomizations. | `hydrax.risk.AverageCost` |
Worst-case | Take the maximum cost across randomizations. | `hydrax.risk.WorstCase` |
Best-case | Take the minimum cost across randomizations. | `hydrax.risk.BestCase` |
Exponential | Take an exponentially weighted average with a tunable weighting parameter. | `hydrax.risk.ExponentialWeightedAverage` |
VaR | Use the Value at Risk (VaR). | `hydrax.risk.ValueAtRisk` |
CVaR | Use the Conditional Value at Risk (CVaR). | `hydrax.risk.ConditionalValueAtRisk` |
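Conceptually, each of these strategies reduces the cost over the randomization axis before the controller compares samples. The snippet below is an illustrative sketch of those reductions in plain JAX, not Hydrax's actual implementation; the 90% level used for VaR/CVaR is an arbitrary choice for illustration.

```python
import jax.numpy as jnp

# Costs for 3 randomized models x 2 sampled control sequences
costs = jnp.array([[1.0, 2.0],
                   [3.0, 0.5],
                   [2.0, 4.0]])  # shape (num_randomizations, num_samples)

average = costs.mean(axis=0)  # AverageCost
worst = costs.max(axis=0)     # WorstCase
best = costs.min(axis=0)      # BestCase

# Value at Risk: the 90th-percentile cost across randomizations (illustrative level)
var_90 = jnp.quantile(costs, 0.9, axis=0)

# Conditional Value at Risk: average of the costs at or above the VaR threshold
tail_mask = costs >= var_90
cvar_90 = (costs * tail_mask).sum(axis=0) / tail_mask.sum(axis=0)
```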
```bibtex
@misc{kurtz2024hydrax,
  title={Hydrax: Sampling-based model predictive control on GPU with JAX and MuJoCo MJX},
  author={Kurtz, Vince},
  year={2024},
  note={https://github.com/vincekurtz/hydrax}
}
```