Skip to content

Commit

Permalink
Simplify integrators logic
Browse files Browse the repository at this point in the history
- Avoid to use `integrator_state` to pass around variables
- Make `wrap_system_dynamics_for_integration` accept only the dynamics
  as `model` and `data` are required arguments for the integration anyway
- Update `js.model.step` accodingly
  • Loading branch information
flferretti committed Oct 3, 2024
1 parent 82dca28 commit f0bed05
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 163 deletions.
20 changes: 7 additions & 13 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1934,7 +1934,7 @@ def step(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
integrator_state: dict[str, Any] | None = None,
aux_dict: dict[str, Any] | None = None,
joint_forces: jtp.VectorLike | None = None,
link_forces: jtp.MatrixLike | None = None,
**kwargs,
Expand All @@ -1945,7 +1945,7 @@ def step(
Args:
model: The model to consider.
data: The data of the considered model.
integrator_state: The state of the integrator.
aux_dict: The auxiliary dictionary to store additional information.
joint_forces: The joint forces to consider.
link_forces:
The 6D forces to apply to the links expressed in the frame corresponding to
Expand All @@ -1954,32 +1954,25 @@ def step(
Returns:
A tuple containing the new data of the model
and the new state of the integrator.
and the new auxiliary dictionary.
"""
# Extract the integrator kwargs.
# The following logic allows using integrators having kwargs colliding with the
# kwargs of this step function.
kwargs = kwargs if kwargs is not None else {}
aux_dict = aux_dict or {}
integrator_kwargs = kwargs.pop("integrator_kwargs", {})
integrator_kwargs = kwargs | integrator_kwargs

integrator_state = (
integrator_state
if integrator_state is not None
else model.integrator.init(x0=data.state, t0=0.0, dt=model.dt)
)

# Extract the initial resources.
t0_ns = data.time_ns
state_t0 = data.state
integrator_state_x0 = integrator_state

# Step the dynamics forward.
state_tf, integrator_state_tf = model.integrator.step(
state_tf, aux_dict = model._integrator.step(
x0=state_t0,
t0=jnp.array(t0_ns / 1e9).astype(float),
dt=model.dt,
params=integrator_state_x0,
# Always inject the current (model, data) pair into the system dynamics
# considered by the integrator, and include the input variables represented
# by the pair (joint_forces, link_forces).
Expand All @@ -1995,6 +1988,7 @@ def step(
link_forces=link_forces,
)
| integrator_kwargs
| aux_dict
),
)

Expand Down Expand Up @@ -2084,5 +2078,5 @@ def step(

return (
data_tf,
integrator_state_tf,
aux_dict,
)
20 changes: 3 additions & 17 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ def __call__(


def wrap_system_dynamics_for_integration(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
system_dynamics: SystemDynamicsFromModelAndData,
**kwargs,
) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]:
"""
Wrap generic system dynamics operating on `JaxSimModel` and `JaxSimModelData`
Expand All @@ -43,21 +39,11 @@ def wrap_system_dynamics_for_integration(
The system dynamics closed over the model, the data, and the additional kwargs.
"""

# We allow to close `system_dynamics` over additional kwargs.
kwargs_closed = kwargs.copy()

# Create a local copy of model and data.
# The wrapped dynamics will hold a reference of this object.
model_closed = model.copy()
data_closed = data.copy().replace(
state=js.ode_data.ODEState.zero(model=model_closed, data=data)
)

def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:

# Allow caller to override the closed data and model objects.
data_f = kwargs_f.pop("data", data_closed)
model_f = kwargs_f.pop("model", model_closed)
data_f = kwargs_f.pop("data")
model_f = kwargs_f.pop("model")

# Update the state and time stored inside data.
with data_f.editable(validate=True) as data_rw:
Expand All @@ -69,7 +55,7 @@ def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
return system_dynamics(
model=model_f,
data=data_rw,
**(kwargs_closed | kwargs_f),
**kwargs_f,
)

f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]
Expand Down
85 changes: 4 additions & 81 deletions src/jaxsim/integrators/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,13 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]):
@classmethod
def build(
cls: type[Self],
*,
dynamics: SystemDynamics[State, StateDerivative],
**kwargs,
) -> Self:
"""
Build the integrator object.
Args:
dynamics: The system dynamics.
**kwargs: Additional keyword arguments to build the integrator.
Returns:
Expand All @@ -87,8 +85,6 @@ def step(
x0: State,
t0: Time,
dt: TimeStep,
*,
params: dict[str, Any],
**kwargs,
) -> tuple[State, dict[str, Any]]:
"""
Expand All @@ -98,96 +94,26 @@ def step(
x0: The initial state of the system.
t0: The initial time of the system.
dt: The time step of the integration.
params: The auxiliary dictionary of the integrator.
**kwargs: Additional keyword arguments.
Returns:
The final state of the system and the updated auxiliary dictionary.
"""

with self.editable(validate=False) as integrator:
integrator.params = params

with integrator.mutable_context(mutability=Mutability.MUTABLE):
with self.mutable_context(mutability=Mutability.MUTABLE) as integrator:
xf, aux_dict = integrator(x0, t0, dt, **kwargs)

return (
xf,
integrator.params
| {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
| aux_dict,
# integrator.params
# | {Integrator.AfterInitKey: jnp.array(False).astype(bool)}
aux_dict,
)

@abc.abstractmethod
def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState:
pass

def init(
self,
x0: State,
t0: Time,
dt: TimeStep,
*,
include_dynamics_aux_dict: bool = False,
**kwargs,
) -> dict[str, Any]:
"""
Initialize the integrator.
Args:
x0: The initial state of the system.
t0: The initial time of the system.
dt: The time step of the integration.
Returns:
The auxiliary dictionary of the integrator.
Note:
This method should have the same signature as the inherited `__call__`
method, including additional kwargs.
Note:
If the integrator supports FSAL, the pair `(x0, t0)` must match the real
initial state and time of the system, otherwise the initial derivative of
the first step will be wrong.
"""

with self.editable(validate=False) as integrator:

# Initialize the integrator parameters.
# For initialization purpose, the integrators can check if the
# `Integrator.InitializingKey` is present in their parameters.
# The AfterInitKey is used in the first step after initialization.
integrator.params = {
Integrator.InitializingKey: jnp.array(True),
Integrator.AfterInitKey: jnp.array(False),
}

# Run a dummy call of the integrator.
# It is used only to get the params so that we know the structure
# of the corresponding pytree.
_ = integrator(x0, t0, dt, **kwargs)

# Remove the injected key.
_ = integrator.params.pop(Integrator.InitializingKey)

# Make sure that all leafs of the dictionary are JAX arrays.
# Also, since these are dummy parameters, set them all to zero.
params_after_init = jax.tree.map(lambda l: jnp.zeros_like(l), integrator.params)

# Mark the next step as first step after initialization.
params_after_init = params_after_init | {
Integrator.AfterInitKey: jnp.array(True)
}

# Store the zero parameters in the integrator.
# When the integrator is stepped, this is used to check if the passed
# parameters are valid.
with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION):
self.params = params_after_init

return params_after_init


@jax_dataclasses.pytree_dataclass
class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]):
Expand Down Expand Up @@ -227,15 +153,13 @@ def order(self) -> int:
def build(
cls: type[Self],
*,
dynamics: SystemDynamics[State, StateDerivative],
fsal_enabled_if_supported: jtp.BoolLike = True,
**kwargs,
) -> Self:
"""
Build the integrator object.
Args:
dynamics: The system dynamics.
fsal_enabled_if_supported:
Whether to enable the FSAL property, if supported.
**kwargs: Additional keyword arguments to build the integrator.
Expand Down Expand Up @@ -270,7 +194,6 @@ def build(

# Build the integrator object.
integrator = super().build(
dynamics=dynamics,
index_of_fsal=index_of_fsal,
fsal_enabled_if_supported=bool(fsal_enabled_if_supported),
**kwargs,
Expand Down
21 changes: 0 additions & 21 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,24 +362,6 @@ def test_ad_integration(
# Test
# ====

import jaxsim.integrators

# Note that only fixes-step integrators support both FWD and RWD gradients.
# Select a second-order Heun scheme with quaternion integrated on SO(3).
# Note that it's always preferable using the SO(3) versions on AD applications so
# that the gradient of the integrated dynamics always considers unary quaternions.
integrator = jaxsim.integrators.fixed_step.Heun2SO3.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model,
data=data,
system_dynamics=js.ode.system_dynamics,
),
)

# Initialize the integrator.
t0, dt = 0.0, 0.001
integrator_state = integrator.init(x0=data.state, t0=t0, dt=dt)

# Function exposing only the parameters to be differentiated.
def step(
W_p_B: jtp.Vector,
Expand Down Expand Up @@ -411,11 +393,8 @@ def step(
)

data_xf, _ = js.model.step(
dt=dt,
model=model,
data=data_x0,
integrator=integrator,
integrator_state=integrator_state,
joint_forces=τ,
link_forces=W_f_L,
)
Expand Down
43 changes: 12 additions & 31 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,36 +60,26 @@ def test_box_with_external_forces(
additive=False,
)

# Create the integrator.
integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model, data=data0, system_dynamics=js.ode.system_dynamics
)
)

# Initialize the integrator.
tf = 0.5
dt = 0.001
T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt)
T = jnp.arange(start=0, stop=tf * 1e9, step=model.dt * 1e9, dtype=int)
aux_dict = None

# Copy the initial data...
data = data0.copy()

# ... and step the simulation.
for t_ns in T:

data, integrator_state = js.model.step(
data, aux_dict = js.model.step(
model=model,
data=data,
dt=dt,
integrator=integrator,
integrator_state=integrator_state,
aux_dict=aux_dict,
link_forces=references.link_forces(model=model, data=data),
)

# Check that the box didn't move.
assert data.time() == t_ns / 1e9 + dt
assert data.time() == t_ns / 1e9 + model.dt
assert data.base_position() == pytest.approx(data0.base_position())
assert data.base_orientation() == pytest.approx(data0.base_orientation())

Expand Down Expand Up @@ -149,21 +139,14 @@ def test_box_with_zero_gravity(
additive=False,
)

# Create the integrator.
integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build(
dynamics=js.ode.wrap_system_dynamics_for_integration(
model=model, data=data0, system_dynamics=js.ode.system_dynamics
)
)

# Initialize the integrator.
tf, dt = 1.0, 0.010
T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt)
tf = 0.01
T = jnp.arange(start=0, stop=tf * 1e9, step=model.dt * 1e9, dtype=int)

# Copy the initial data...
data = data0.copy()

aux_dict = None

# ... and step the simulation.
for t_ns in T:

Expand All @@ -174,17 +157,15 @@ def test_box_with_zero_gravity(
references.switch_velocity_representation(velocity_representation),
):

data, integrator_state = js.model.step(
data, aux_dict = js.model.step(
model=model,
data=data,
dt=dt,
integrator=integrator,
integrator_state=integrator_state,
link_forces=references.link_forces(model=model, data=data),
aux_dict=aux_dict,
)

# Check the final simulation time.
assert data.time() == T[-1] / 1e9 + dt
assert data.time() == T[-1] / 1e9 + model.dt

# Check that the box moved as expected.
assert data.base_position() == pytest.approx(
Expand Down

0 comments on commit f0bed05

Please sign in to comment.