From 62254385ab0f056dc859da9d59adc5d80852f417 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Thu, 30 Jan 2025 12:54:06 +0100 Subject: [PATCH 1/7] Rename base linear and angular velocity parameters in `forward_kinematics_model` --- src/jaxsim/rbda/forward_kinematics.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 5f9c2348b..355fe347c 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -15,8 +15,8 @@ def forward_kinematics_model( base_position: jtp.VectorLike, base_quaternion: jtp.VectorLike, joint_positions: jtp.VectorLike, - base_linear_velocity: jtp.VectorLike, - base_angular_velocity: jtp.VectorLike, + base_linear_velocity_inertial: jtp.VectorLike, + base_angular_velocity_inertial: jtp.VectorLike, joint_velocities: jtp.VectorLike, ) -> jtp.Array: """ @@ -27,8 +27,8 @@ def forward_kinematics_model( base_position: The position of the base link. base_quaternion: The quaternion of the base link. joint_positions: The positions of the joints. - base_linear_velocity: The linear velocity of the base link. - base_angular_velocity: The angular velocity of the base link. + base_linear_velocity_inertial: The linear velocity of the base link in inertial-fixed representation. + base_angular_velocity_inertial: The angular velocity of the base link in inertial-fixed representation. joint_velocities: The velocities of the joints. Returns: @@ -40,8 +40,8 @@ def forward_kinematics_model( base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, - base_linear_velocity=base_linear_velocity, - base_angular_velocity=base_angular_velocity, + base_linear_velocity=base_linear_velocity_inertial, + base_angular_velocity=base_angular_velocity_inertial, joint_velocities=joint_velocities, ) From 9bcde5b0869ca686f25930c789ad8ca95a831784 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Thu, 30 Jan 2025 12:55:43 +0100 Subject: [PATCH 2/7] Update JaxSimModelData to use Mixed as default repr. in `build` function --- src/jaxsim/api/data.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index cf64af93a..fb538a392 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -41,7 +41,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): base_transform: The base transform. joint_transforms: The joint transforms. link_transforms: The link transforms. - link_velocities: The link velocities. + link_velocities: The link velocities in inertial-fixed representation. """ # Joint state @@ -69,7 +69,7 @@ def build( base_linear_velocity: jtp.VectorLike | None = None, base_angular_velocity: jtp.VectorLike | None = None, joint_velocities: jtp.VectorLike | None = None, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: VelRepr = VelRepr.Mixed, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with the given state. @@ -84,7 +84,7 @@ def build( base_angular_velocity: The base angular velocity in the selected representation. joint_velocities: The joint velocities. - velocity_representation: The velocity representation to use. + velocity_representation: The velocity representation to use. It defaults to mixed if not provided. Returns: A `JaxSimModelData` initialized with the given state. @@ -144,7 +144,7 @@ def build( translation=base_position, quaternion=base_quaternion ) - v_WB = JaxSimModelData.other_representation_to_inertial( + W_v_WB = JaxSimModelData.other_representation_to_inertial( array=jnp.hstack([base_linear_velocity, base_angular_velocity]), other_representation=velocity_representation, transform=W_H_B, @@ -155,28 +155,30 @@ def build( joint_positions=joint_positions, base_transform=W_H_B ) - link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - base_linear_velocity=v_WB[0:3], - base_angular_velocity=v_WB[3:6], - joint_velocities=joint_velocities, + link_transforms, link_velocities_inertial = ( + jaxsim.rbda.forward_kinematics_model( + model=model, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + base_linear_velocity_inertial=W_v_WB[0:3], + base_angular_velocity_inertial=W_v_WB[3:6], + joint_velocities=joint_velocities, + ) ) model_data = JaxSimModelData( base_quaternion=base_quaternion, base_position=base_position, joint_positions=joint_positions, - base_linear_velocity=v_WB[0:3], - base_angular_velocity=v_WB[3:6], + base_linear_velocity=W_v_WB[0:3], + base_angular_velocity=W_v_WB[3:6], joint_velocities=joint_velocities, velocity_representation=velocity_representation, base_transform=W_H_B, joint_transforms=joint_transforms, link_transforms=link_transforms, - link_velocities=link_velocities, + link_velocities=link_velocities_inertial, ) if not model_data.valid(model=model): @@ -189,14 +191,14 @@ def build( @staticmethod def zero( model: js.model.JaxSimModel, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: VelRepr = VelRepr.Mixed, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with zero state. Args: model: The model for which to create the state. - velocity_representation: The velocity representation to use. + velocity_representation: The velocity representation to use. It defaults to mixed if not provided. Returns: A `JaxSimModelData` initialized with zero state. @@ -603,8 +605,8 @@ def update_cached(self, model: js.model.JaxSimModel) -> JaxSimModelData: base_quaternion=self.base_quaternion, joint_positions=self.joint_positions, joint_velocities=self.joint_velocities, - base_linear_velocity=self.base_linear_velocity, - base_angular_velocity=self.base_angular_velocity, + base_linear_velocity_inertial=self.base_linear_velocity, + base_angular_velocity_inertial=self.base_angular_velocity, ) return self.replace( From a6cdd58d99cf6dc8c7ab855401f9e07d9adc99f1 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Thu, 30 Jan 2025 12:56:21 +0100 Subject: [PATCH 3/7] Restore `step` function to accept link forces in the same reprensentation of data --- src/jaxsim/api/model.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 3610e26f6..f6129da40 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1989,7 +1989,7 @@ def step( model: JaxSimModel, data: js.data.JaxSimModelData, *, - link_forces_inertial: jtp.MatrixLike | None = None, + link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, ) -> js.data.JaxSimModelData: """ @@ -1999,8 +1999,8 @@ def step( model: The model to consider. data: The data of the considered model. dt: The time step to consider. If not specified, it is read from the model. - link_forces_inertial: - The 6D forces to apply to the links expressed in inertial-representation. + link_forces: + The 6D forces to apply to the links expressed in same representation of data. joint_force_references: The joint force references to consider. Returns: @@ -2016,11 +2016,22 @@ def step( # the enabled collidable points # Extract the inputs - W_f_L_external = jnp.atleast_2d( - jnp.array(link_forces_inertial, dtype=float).squeeze() - if link_forces_inertial is not None + O_f_L_external = jnp.atleast_2d( + jnp.array(link_forces, dtype=float).squeeze() + if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ) + + # Get the external forces in inertial-fixed representation. + W_f_L_external = jax.vmap( + lambda f_L, W_H_L: js.data.JaxSimModelData.other_representation_to_inertial( + f_L, + other_representation=data.velocity_representation, + transform=W_H_L, + is_force=True, + ) + )(O_f_L_external, data.link_transforms) + τ_references = jnp.atleast_1d( jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None From 9a9c1e80f8f86615a0abf584dc2ff145444eaa16 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Thu, 30 Jan 2025 12:59:44 +0100 Subject: [PATCH 4/7] Refactor tests to reflect changes in API --- tests/test_automatic_differentiation.py | 25 ++++++++++++++++--------- tests/test_simulations.py | 22 +++++++++++++--------- 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index d3153f2ed..5d9e57ec2 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -229,14 +229,21 @@ def test_ad_fk( # ==== # Get a closure exposing only the parameters to be differentiated. - fk = lambda W_p_B, W_Q_B, s, W_v_lin, W_v_ang, ṡ: jaxsim.rbda.forward_kinematics_model( - model=model, - base_position=W_p_B, - base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), - joint_positions=s, - base_linear_velocity=W_v_lin, - base_angular_velocity=W_v_ang, - joint_velocities=ṡ, + fk = ( + lambda W_p_B, + W_Q_B, + s, + W_v_lin, + W_v_ang, + ṡ: jaxsim.rbda.forward_kinematics_model( + model=model, + base_position=W_p_B, + base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), + joint_positions=s, + base_linear_velocity_inertial=W_v_lin, + base_angular_velocity_inertial=W_v_ang, + joint_velocities=ṡ, + ) ) # Check derivatives against finite differences. @@ -344,7 +351,7 @@ def step( model=model, data=data_x0, joint_force_references=τ, - link_forces_inertial=W_f_L, + link_forces=W_f_L, ) xf_W_p_B = data_xf.base_position diff --git a/tests/test_simulations.py b/tests/test_simulations.py index e56c2c1cf..62d490a98 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -74,7 +74,7 @@ def test_box_with_external_forces( data = js.model.step( model=model, data=data, - link_forces_inertial=references._link_forces, + link_forces=references.link_forces(model, data), ) # Check that the box didn't move. @@ -84,6 +84,7 @@ def test_box_with_external_forces( def test_box_with_zero_gravity( jaxsim_model_box: js.model.JaxSimModel, + velocity_representation: VelRepr, prng_key: jnp.ndarray, ): @@ -101,14 +102,14 @@ def test_box_with_zero_gravity( data0 = js.data.JaxSimModelData.build( model=model, base_position=jax.random.uniform(subkey, shape=(3,)), - velocity_representation=jaxsim.VelRepr.Inertial, + velocity_representation=velocity_representation, ) # Initialize a references object that simplifies handling external forces. references = js.references.JaxSimModelReferences.build( model=model, data=data0, - velocity_representation=jaxsim.VelRepr.Inertial, + velocity_representation=velocity_representation, ) # Apply a link forces to the base link. @@ -144,12 +145,15 @@ def test_box_with_zero_gravity( # ... and step the simulation. for _ in T: - - data = js.model.step( - model=model, - data=data, - link_forces_inertial=references.link_forces(model=model, data=data), - ) + with ( + data.switch_velocity_representation(velocity_representation), + references.switch_velocity_representation(velocity_representation), + ): + data = js.model.step( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + ) # Check that the box moved as expected. assert data.base_position == pytest.approx( From 1abdfe539e645f6783a2aa5fe2dcda99e24b2ee8 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Thu, 30 Jan 2025 13:04:53 +0100 Subject: [PATCH 5/7] Remove redundant `system_velocity_dynamics` function and update `system_dynamics` to use `system_acceleration` directly --- src/jaxsim/api/ode.py | 54 +------------------------------------------ 1 file changed, 1 insertion(+), 53 deletions(-) diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index a7b96ec72..8870098ae 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -1,5 +1,3 @@ -from typing import Any - import jax import jax.numpy as jnp @@ -15,56 +13,6 @@ # ================================== -@jax.jit -@js.common.named_scope -def system_velocity_dynamics( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - link_forces: jtp.Vector | None = None, - joint_torques: jtp.Vector | None = None, -) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]: - """ - Compute the dynamics of the system velocity. - - Args: - model: The model to consider. - data: The data of the considered model. - link_forces: - The 6D forces to apply to the links expressed in inertial-fixed representation. - joint_torques: The joint torques acting on the joints. - - Returns: - A tuple containing the derivative of the base 6D velocity in inertial-fixed - representation, the derivative of the joint velocities, and auxiliary data - returned by the system dynamics evaluation. - """ - - # Build link forces if not provided. - # These forces are expressed in the frame corresponding to the velocity - # representation of data. - W_f_L = ( - jnp.atleast_2d(link_forces.squeeze()) - if link_forces is not None - else jnp.zeros((model.number_of_links(), 6)) - ).astype(float) - - # =========================== - # Compute system acceleration - # =========================== - - # Compute the system acceleration in inertial-fixed representation. - # This representation is useful for integration purpose. - W_v̇_WB, s̈ = system_acceleration( - model=model, - data=data, - joint_torques=joint_torques, - link_forces=W_f_L, - ) - - return W_v̇_WB, s̈ - - def system_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -197,7 +145,7 @@ def system_dynamics( """ with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): - W_v̇_WB, s̈ = system_velocity_dynamics( + W_v̇_WB, s̈ = system_acceleration( model=model, data=data, joint_torques=joint_torques, From f7c073ea64cfa57ffe7df75716a9a2fdf2215f74 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Thu, 30 Jan 2025 13:05:15 +0100 Subject: [PATCH 6/7] Update `step` function to use `system_acceleration` instead of `system_velocity_dynamics` --- src/jaxsim/api/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index f6129da40..b831da27d 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2074,7 +2074,7 @@ def step( # =============================== with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - W_v̇_WB, s̈ = js.ode.system_velocity_dynamics( + W_v̇_WB, s̈ = js.ode.system_acceleration( model=model, data=data, link_forces=W_f_L_total, From ecbf1502f75194a7691f6e447db121e6cd983369 Mon Sep 17 00:00:00 2001 From: Alessandro Croci Date: Thu, 30 Jan 2025 13:17:35 +0100 Subject: [PATCH 7/7] Format `test_automatic_differentiation.py` --- tests/test_automatic_differentiation.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index 5d9e57ec2..682e163b1 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -229,21 +229,14 @@ def test_ad_fk( # ==== # Get a closure exposing only the parameters to be differentiated. - fk = ( - lambda W_p_B, - W_Q_B, - s, - W_v_lin, - W_v_ang, - ṡ: jaxsim.rbda.forward_kinematics_model( - model=model, - base_position=W_p_B, - base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), - joint_positions=s, - base_linear_velocity_inertial=W_v_lin, - base_angular_velocity_inertial=W_v_ang, - joint_velocities=ṡ, - ) + fk = lambda W_p_B, W_Q_B, s, W_v_lin, W_v_ang, ṡ: jaxsim.rbda.forward_kinematics_model( + model=model, + base_position=W_p_B, + base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), + joint_positions=s, + base_linear_velocity_inertial=W_v_lin, + base_angular_velocity_inertial=W_v_ang, + joint_velocities=ṡ, ) # Check derivatives against finite differences.