From f0bed05298a657298bb44f8d10dff19377b3efae Mon Sep 17 00:00:00 2001 From: Filippo Luca Ferretti Date: Thu, 3 Oct 2024 19:10:39 +0200 Subject: [PATCH] Simplify integrators logic - Avoid to use `integrator_state` to pass around variables - Make `wrap_system_dynamics_for_integration` accept only the dynamics as `model` and `data` are required arguments for the integration anyway - Update `js.model.step` accodingly --- src/jaxsim/api/model.py | 20 ++---- src/jaxsim/api/ode.py | 20 +----- src/jaxsim/integrators/common.py | 85 ++----------------------- tests/test_automatic_differentiation.py | 21 ------ tests/test_simulations.py | 43 ++++--------- 5 files changed, 26 insertions(+), 163 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 9b1ee443c..ce2eb30c2 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1934,7 +1934,7 @@ def step( model: JaxSimModel, data: js.data.JaxSimModelData, *, - integrator_state: dict[str, Any] | None = None, + aux_dict: dict[str, Any] | None = None, joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, **kwargs, @@ -1945,7 +1945,7 @@ def step( Args: model: The model to consider. data: The data of the considered model. - integrator_state: The state of the integrator. + aux_dict: The auxiliary dictionary to store additional information. joint_forces: The joint forces to consider. link_forces: The 6D forces to apply to the links expressed in the frame corresponding to @@ -1954,32 +1954,25 @@ def step( Returns: A tuple containing the new data of the model - and the new state of the integrator. + and the new auxiliary dictionary. """ # Extract the integrator kwargs. # The following logic allows using integrators having kwargs colliding with the # kwargs of this step function. kwargs = kwargs if kwargs is not None else {} + aux_dict = aux_dict or {} integrator_kwargs = kwargs.pop("integrator_kwargs", {}) integrator_kwargs = kwargs | integrator_kwargs - integrator_state = ( - integrator_state - if integrator_state is not None - else model.integrator.init(x0=data.state, t0=0.0, dt=model.dt) - ) - # Extract the initial resources. t0_ns = data.time_ns state_t0 = data.state - integrator_state_x0 = integrator_state # Step the dynamics forward. - state_tf, integrator_state_tf = model.integrator.step( + state_tf, aux_dict = model._integrator.step( x0=state_t0, t0=jnp.array(t0_ns / 1e9).astype(float), dt=model.dt, - params=integrator_state_x0, # Always inject the current (model, data) pair into the system dynamics # considered by the integrator, and include the input variables represented # by the pair (joint_forces, link_forces). @@ -1995,6 +1988,7 @@ def step( link_forces=link_forces, ) | integrator_kwargs + | aux_dict ), ) @@ -2084,5 +2078,5 @@ def step( return ( data_tf, - integrator_state_tf, + aux_dict, ) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 33e95a22b..0dffcad8f 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -23,11 +23,7 @@ def __call__( def wrap_system_dynamics_for_integration( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, system_dynamics: SystemDynamicsFromModelAndData, - **kwargs, ) -> jaxsim.integrators.common.SystemDynamics[ODEState, ODEState]: """ Wrap generic system dynamics operating on `JaxSimModel` and `JaxSimModelData` @@ -43,21 +39,11 @@ def wrap_system_dynamics_for_integration( The system dynamics closed over the model, the data, and the additional kwargs. """ - # We allow to close `system_dynamics` over additional kwargs. - kwargs_closed = kwargs.copy() - - # Create a local copy of model and data. - # The wrapped dynamics will hold a reference of this object. - model_closed = model.copy() - data_closed = data.copy().replace( - state=js.ode_data.ODEState.zero(model=model_closed, data=data) - ) - def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]: # Allow caller to override the closed data and model objects. - data_f = kwargs_f.pop("data", data_closed) - model_f = kwargs_f.pop("model", model_closed) + data_f = kwargs_f.pop("data") + model_f = kwargs_f.pop("model") # Update the state and time stored inside data. with data_f.editable(validate=True) as data_rw: @@ -69,7 +55,7 @@ def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]: return system_dynamics( model=model_f, data=data_rw, - **(kwargs_closed | kwargs_f), + **kwargs_f, ) f: jaxsim.integrators.common.SystemDynamics[ODEState, ODEState] diff --git a/src/jaxsim/integrators/common.py b/src/jaxsim/integrators/common.py index e6d05c266..4275bb10f 100644 --- a/src/jaxsim/integrators/common.py +++ b/src/jaxsim/integrators/common.py @@ -65,7 +65,6 @@ class Integrator(JaxsimDataclass, abc.ABC, Generic[State, StateDerivative]): @classmethod def build( cls: type[Self], - *, dynamics: SystemDynamics[State, StateDerivative], **kwargs, ) -> Self: @@ -73,7 +72,6 @@ def build( Build the integrator object. Args: - dynamics: The system dynamics. **kwargs: Additional keyword arguments to build the integrator. Returns: @@ -87,8 +85,6 @@ def step( x0: State, t0: Time, dt: TimeStep, - *, - params: dict[str, Any], **kwargs, ) -> tuple[State, dict[str, Any]]: """ @@ -98,96 +94,26 @@ def step( x0: The initial state of the system. t0: The initial time of the system. dt: The time step of the integration. - params: The auxiliary dictionary of the integrator. **kwargs: Additional keyword arguments. Returns: The final state of the system and the updated auxiliary dictionary. """ - with self.editable(validate=False) as integrator: - integrator.params = params - - with integrator.mutable_context(mutability=Mutability.MUTABLE): + with self.mutable_context(mutability=Mutability.MUTABLE) as integrator: xf, aux_dict = integrator(x0, t0, dt, **kwargs) return ( xf, - integrator.params - | {Integrator.AfterInitKey: jnp.array(False).astype(bool)} - | aux_dict, + # integrator.params + # | {Integrator.AfterInitKey: jnp.array(False).astype(bool)} + aux_dict, ) @abc.abstractmethod def __call__(self, x0: State, t0: Time, dt: TimeStep, **kwargs) -> NextState: pass - def init( - self, - x0: State, - t0: Time, - dt: TimeStep, - *, - include_dynamics_aux_dict: bool = False, - **kwargs, - ) -> dict[str, Any]: - """ - Initialize the integrator. - - Args: - x0: The initial state of the system. - t0: The initial time of the system. - dt: The time step of the integration. - - Returns: - The auxiliary dictionary of the integrator. - - Note: - This method should have the same signature as the inherited `__call__` - method, including additional kwargs. - - Note: - If the integrator supports FSAL, the pair `(x0, t0)` must match the real - initial state and time of the system, otherwise the initial derivative of - the first step will be wrong. - """ - - with self.editable(validate=False) as integrator: - - # Initialize the integrator parameters. - # For initialization purpose, the integrators can check if the - # `Integrator.InitializingKey` is present in their parameters. - # The AfterInitKey is used in the first step after initialization. - integrator.params = { - Integrator.InitializingKey: jnp.array(True), - Integrator.AfterInitKey: jnp.array(False), - } - - # Run a dummy call of the integrator. - # It is used only to get the params so that we know the structure - # of the corresponding pytree. - _ = integrator(x0, t0, dt, **kwargs) - - # Remove the injected key. - _ = integrator.params.pop(Integrator.InitializingKey) - - # Make sure that all leafs of the dictionary are JAX arrays. - # Also, since these are dummy parameters, set them all to zero. - params_after_init = jax.tree.map(lambda l: jnp.zeros_like(l), integrator.params) - - # Mark the next step as first step after initialization. - params_after_init = params_after_init | { - Integrator.AfterInitKey: jnp.array(True) - } - - # Store the zero parameters in the integrator. - # When the integrator is stepped, this is used to check if the passed - # parameters are valid. - with self.mutable_context(mutability=Mutability.MUTABLE_NO_VALIDATION): - self.params = params_after_init - - return params_after_init - @jax_dataclasses.pytree_dataclass class ExplicitRungeKutta(Integrator[PyTreeType, PyTreeType], Generic[PyTreeType]): @@ -227,7 +153,6 @@ def order(self) -> int: def build( cls: type[Self], *, - dynamics: SystemDynamics[State, StateDerivative], fsal_enabled_if_supported: jtp.BoolLike = True, **kwargs, ) -> Self: @@ -235,7 +160,6 @@ def build( Build the integrator object. Args: - dynamics: The system dynamics. fsal_enabled_if_supported: Whether to enable the FSAL property, if supported. **kwargs: Additional keyword arguments to build the integrator. @@ -270,7 +194,6 @@ def build( # Build the integrator object. integrator = super().build( - dynamics=dynamics, index_of_fsal=index_of_fsal, fsal_enabled_if_supported=bool(fsal_enabled_if_supported), **kwargs, diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index c173003b1..f3716394b 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -362,24 +362,6 @@ def test_ad_integration( # Test # ==== - import jaxsim.integrators - - # Note that only fixes-step integrators support both FWD and RWD gradients. - # Select a second-order Heun scheme with quaternion integrated on SO(3). - # Note that it's always preferable using the SO(3) versions on AD applications so - # that the gradient of the integrated dynamics always considers unary quaternions. - integrator = jaxsim.integrators.fixed_step.Heun2SO3.build( - dynamics=js.ode.wrap_system_dynamics_for_integration( - model=model, - data=data, - system_dynamics=js.ode.system_dynamics, - ), - ) - - # Initialize the integrator. - t0, dt = 0.0, 0.001 - integrator_state = integrator.init(x0=data.state, t0=t0, dt=dt) - # Function exposing only the parameters to be differentiated. def step( W_p_B: jtp.Vector, @@ -411,11 +393,8 @@ def step( ) data_xf, _ = js.model.step( - dt=dt, model=model, data=data_x0, - integrator=integrator, - integrator_state=integrator_state, joint_forces=τ, link_forces=W_f_L, ) diff --git a/tests/test_simulations.py b/tests/test_simulations.py index 41807b8e2..9829e0d0d 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -60,18 +60,10 @@ def test_box_with_external_forces( additive=False, ) - # Create the integrator. - integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build( - dynamics=js.ode.wrap_system_dynamics_for_integration( - model=model, data=data0, system_dynamics=js.ode.system_dynamics - ) - ) - # Initialize the integrator. tf = 0.5 - dt = 0.001 - T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int) - integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt) + T = jnp.arange(start=0, stop=tf * 1e9, step=model.dt * 1e9, dtype=int) + aux_dict = None # Copy the initial data... data = data0.copy() @@ -79,17 +71,15 @@ def test_box_with_external_forces( # ... and step the simulation. for t_ns in T: - data, integrator_state = js.model.step( + data, aux_dict = js.model.step( model=model, data=data, - dt=dt, - integrator=integrator, - integrator_state=integrator_state, + aux_dict=aux_dict, link_forces=references.link_forces(model=model, data=data), ) # Check that the box didn't move. - assert data.time() == t_ns / 1e9 + dt + assert data.time() == t_ns / 1e9 + model.dt assert data.base_position() == pytest.approx(data0.base_position()) assert data.base_orientation() == pytest.approx(data0.base_orientation()) @@ -149,21 +139,14 @@ def test_box_with_zero_gravity( additive=False, ) - # Create the integrator. - integrator = jaxsim.integrators.fixed_step.RungeKutta4SO3.build( - dynamics=js.ode.wrap_system_dynamics_for_integration( - model=model, data=data0, system_dynamics=js.ode.system_dynamics - ) - ) - - # Initialize the integrator. - tf, dt = 1.0, 0.010 - T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int) - integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt) + tf = 0.01 + T = jnp.arange(start=0, stop=tf * 1e9, step=model.dt * 1e9, dtype=int) # Copy the initial data... data = data0.copy() + aux_dict = None + # ... and step the simulation. for t_ns in T: @@ -174,17 +157,15 @@ def test_box_with_zero_gravity( references.switch_velocity_representation(velocity_representation), ): - data, integrator_state = js.model.step( + data, aux_dict = js.model.step( model=model, data=data, - dt=dt, - integrator=integrator, - integrator_state=integrator_state, link_forces=references.link_forces(model=model, data=data), + aux_dict=aux_dict, ) # Check the final simulation time. - assert data.time() == T[-1] / 1e9 + dt + assert data.time() == T[-1] / 1e9 + model.dt # Check that the box moved as expected. assert data.base_position() == pytest.approx(