From a1c5770ceacef819c7e76da9b6afb6ebaa7d7da3 Mon Sep 17 00:00:00 2001 From: Alessandro Croci <57228872+xela-95@users.noreply.github.com> Date: Fri, 31 Jan 2025 16:28:06 +0100 Subject: [PATCH] [Sprint] Set mixed as default representation in `data.build` (#361) * Rename base linear and angular velocity parameters in `forward_kinematics_model` * Update JaxSimModelData to use Mixed as default repr. in `build` function * Restore `step` function to accept link forces in the same reprensentation of data * Refactor tests to reflect changes in API * Remove redundant `system_velocity_dynamics` function and update `system_dynamics` to use `system_acceleration` directly * Update `step` function to use `system_acceleration` instead of `system_velocity_dynamics` * Format `test_automatic_differentiation.py` --- src/jaxsim/api/data.py | 40 +++++++++--------- src/jaxsim/api/model.py | 25 ++++++++---- src/jaxsim/api/ode.py | 54 +------------------------ src/jaxsim/rbda/forward_kinematics.py | 12 +++--- tests/test_automatic_differentiation.py | 6 +-- tests/test_simulations.py | 22 +++++----- 6 files changed, 62 insertions(+), 97 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index cf64af93a..fb538a392 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -41,7 +41,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation): base_transform: The base transform. joint_transforms: The joint transforms. link_transforms: The link transforms. - link_velocities: The link velocities. + link_velocities: The link velocities in inertial-fixed representation. """ # Joint state @@ -69,7 +69,7 @@ def build( base_linear_velocity: jtp.VectorLike | None = None, base_angular_velocity: jtp.VectorLike | None = None, joint_velocities: jtp.VectorLike | None = None, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: VelRepr = VelRepr.Mixed, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with the given state. @@ -84,7 +84,7 @@ def build( base_angular_velocity: The base angular velocity in the selected representation. joint_velocities: The joint velocities. - velocity_representation: The velocity representation to use. + velocity_representation: The velocity representation to use. It defaults to mixed if not provided. Returns: A `JaxSimModelData` initialized with the given state. @@ -144,7 +144,7 @@ def build( translation=base_position, quaternion=base_quaternion ) - v_WB = JaxSimModelData.other_representation_to_inertial( + W_v_WB = JaxSimModelData.other_representation_to_inertial( array=jnp.hstack([base_linear_velocity, base_angular_velocity]), other_representation=velocity_representation, transform=W_H_B, @@ -155,28 +155,30 @@ def build( joint_positions=joint_positions, base_transform=W_H_B ) - link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model( - model=model, - base_position=base_position, - base_quaternion=base_quaternion, - joint_positions=joint_positions, - base_linear_velocity=v_WB[0:3], - base_angular_velocity=v_WB[3:6], - joint_velocities=joint_velocities, + link_transforms, link_velocities_inertial = ( + jaxsim.rbda.forward_kinematics_model( + model=model, + base_position=base_position, + base_quaternion=base_quaternion, + joint_positions=joint_positions, + base_linear_velocity_inertial=W_v_WB[0:3], + base_angular_velocity_inertial=W_v_WB[3:6], + joint_velocities=joint_velocities, + ) ) model_data = JaxSimModelData( base_quaternion=base_quaternion, base_position=base_position, joint_positions=joint_positions, - base_linear_velocity=v_WB[0:3], - base_angular_velocity=v_WB[3:6], + base_linear_velocity=W_v_WB[0:3], + base_angular_velocity=W_v_WB[3:6], joint_velocities=joint_velocities, velocity_representation=velocity_representation, base_transform=W_H_B, joint_transforms=joint_transforms, link_transforms=link_transforms, - link_velocities=link_velocities, + link_velocities=link_velocities_inertial, ) if not model_data.valid(model=model): @@ -189,14 +191,14 @@ def build( @staticmethod def zero( model: js.model.JaxSimModel, - velocity_representation: VelRepr = VelRepr.Inertial, + velocity_representation: VelRepr = VelRepr.Mixed, ) -> JaxSimModelData: """ Create a `JaxSimModelData` object with zero state. Args: model: The model for which to create the state. - velocity_representation: The velocity representation to use. + velocity_representation: The velocity representation to use. It defaults to mixed if not provided. Returns: A `JaxSimModelData` initialized with zero state. @@ -603,8 +605,8 @@ def update_cached(self, model: js.model.JaxSimModel) -> JaxSimModelData: base_quaternion=self.base_quaternion, joint_positions=self.joint_positions, joint_velocities=self.joint_velocities, - base_linear_velocity=self.base_linear_velocity, - base_angular_velocity=self.base_angular_velocity, + base_linear_velocity_inertial=self.base_linear_velocity, + base_angular_velocity_inertial=self.base_angular_velocity, ) return self.replace( diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 3610e26f6..b831da27d 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1989,7 +1989,7 @@ def step( model: JaxSimModel, data: js.data.JaxSimModelData, *, - link_forces_inertial: jtp.MatrixLike | None = None, + link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, ) -> js.data.JaxSimModelData: """ @@ -1999,8 +1999,8 @@ def step( model: The model to consider. data: The data of the considered model. dt: The time step to consider. If not specified, it is read from the model. - link_forces_inertial: - The 6D forces to apply to the links expressed in inertial-representation. + link_forces: + The 6D forces to apply to the links expressed in same representation of data. joint_force_references: The joint force references to consider. Returns: @@ -2016,11 +2016,22 @@ def step( # the enabled collidable points # Extract the inputs - W_f_L_external = jnp.atleast_2d( - jnp.array(link_forces_inertial, dtype=float).squeeze() - if link_forces_inertial is not None + O_f_L_external = jnp.atleast_2d( + jnp.array(link_forces, dtype=float).squeeze() + if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ) + + # Get the external forces in inertial-fixed representation. + W_f_L_external = jax.vmap( + lambda f_L, W_H_L: js.data.JaxSimModelData.other_representation_to_inertial( + f_L, + other_representation=data.velocity_representation, + transform=W_H_L, + is_force=True, + ) + )(O_f_L_external, data.link_transforms) + τ_references = jnp.atleast_1d( jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None @@ -2063,7 +2074,7 @@ def step( # =============================== with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - W_v̇_WB, s̈ = js.ode.system_velocity_dynamics( + W_v̇_WB, s̈ = js.ode.system_acceleration( model=model, data=data, link_forces=W_f_L_total, diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index a7b96ec72..8870098ae 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -1,5 +1,3 @@ -from typing import Any - import jax import jax.numpy as jnp @@ -15,56 +13,6 @@ # ================================== -@jax.jit -@js.common.named_scope -def system_velocity_dynamics( - model: js.model.JaxSimModel, - data: js.data.JaxSimModelData, - *, - link_forces: jtp.Vector | None = None, - joint_torques: jtp.Vector | None = None, -) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]: - """ - Compute the dynamics of the system velocity. - - Args: - model: The model to consider. - data: The data of the considered model. - link_forces: - The 6D forces to apply to the links expressed in inertial-fixed representation. - joint_torques: The joint torques acting on the joints. - - Returns: - A tuple containing the derivative of the base 6D velocity in inertial-fixed - representation, the derivative of the joint velocities, and auxiliary data - returned by the system dynamics evaluation. - """ - - # Build link forces if not provided. - # These forces are expressed in the frame corresponding to the velocity - # representation of data. - W_f_L = ( - jnp.atleast_2d(link_forces.squeeze()) - if link_forces is not None - else jnp.zeros((model.number_of_links(), 6)) - ).astype(float) - - # =========================== - # Compute system acceleration - # =========================== - - # Compute the system acceleration in inertial-fixed representation. - # This representation is useful for integration purpose. - W_v̇_WB, s̈ = system_acceleration( - model=model, - data=data, - joint_torques=joint_torques, - link_forces=W_f_L, - ) - - return W_v̇_WB, s̈ - - def system_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, @@ -197,7 +145,7 @@ def system_dynamics( """ with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial): - W_v̇_WB, s̈ = system_velocity_dynamics( + W_v̇_WB, s̈ = system_acceleration( model=model, data=data, joint_torques=joint_torques, diff --git a/src/jaxsim/rbda/forward_kinematics.py b/src/jaxsim/rbda/forward_kinematics.py index 5f9c2348b..355fe347c 100644 --- a/src/jaxsim/rbda/forward_kinematics.py +++ b/src/jaxsim/rbda/forward_kinematics.py @@ -15,8 +15,8 @@ def forward_kinematics_model( base_position: jtp.VectorLike, base_quaternion: jtp.VectorLike, joint_positions: jtp.VectorLike, - base_linear_velocity: jtp.VectorLike, - base_angular_velocity: jtp.VectorLike, + base_linear_velocity_inertial: jtp.VectorLike, + base_angular_velocity_inertial: jtp.VectorLike, joint_velocities: jtp.VectorLike, ) -> jtp.Array: """ @@ -27,8 +27,8 @@ def forward_kinematics_model( base_position: The position of the base link. base_quaternion: The quaternion of the base link. joint_positions: The positions of the joints. - base_linear_velocity: The linear velocity of the base link. - base_angular_velocity: The angular velocity of the base link. + base_linear_velocity_inertial: The linear velocity of the base link in inertial-fixed representation. + base_angular_velocity_inertial: The angular velocity of the base link in inertial-fixed representation. joint_velocities: The velocities of the joints. Returns: @@ -40,8 +40,8 @@ def forward_kinematics_model( base_position=base_position, base_quaternion=base_quaternion, joint_positions=joint_positions, - base_linear_velocity=base_linear_velocity, - base_angular_velocity=base_angular_velocity, + base_linear_velocity=base_linear_velocity_inertial, + base_angular_velocity=base_angular_velocity_inertial, joint_velocities=joint_velocities, ) diff --git a/tests/test_automatic_differentiation.py b/tests/test_automatic_differentiation.py index d3153f2ed..682e163b1 100644 --- a/tests/test_automatic_differentiation.py +++ b/tests/test_automatic_differentiation.py @@ -234,8 +234,8 @@ def test_ad_fk( base_position=W_p_B, base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B), joint_positions=s, - base_linear_velocity=W_v_lin, - base_angular_velocity=W_v_ang, + base_linear_velocity_inertial=W_v_lin, + base_angular_velocity_inertial=W_v_ang, joint_velocities=ṡ, ) @@ -344,7 +344,7 @@ def step( model=model, data=data_x0, joint_force_references=τ, - link_forces_inertial=W_f_L, + link_forces=W_f_L, ) xf_W_p_B = data_xf.base_position diff --git a/tests/test_simulations.py b/tests/test_simulations.py index e56c2c1cf..62d490a98 100644 --- a/tests/test_simulations.py +++ b/tests/test_simulations.py @@ -74,7 +74,7 @@ def test_box_with_external_forces( data = js.model.step( model=model, data=data, - link_forces_inertial=references._link_forces, + link_forces=references.link_forces(model, data), ) # Check that the box didn't move. @@ -84,6 +84,7 @@ def test_box_with_external_forces( def test_box_with_zero_gravity( jaxsim_model_box: js.model.JaxSimModel, + velocity_representation: VelRepr, prng_key: jnp.ndarray, ): @@ -101,14 +102,14 @@ def test_box_with_zero_gravity( data0 = js.data.JaxSimModelData.build( model=model, base_position=jax.random.uniform(subkey, shape=(3,)), - velocity_representation=jaxsim.VelRepr.Inertial, + velocity_representation=velocity_representation, ) # Initialize a references object that simplifies handling external forces. references = js.references.JaxSimModelReferences.build( model=model, data=data0, - velocity_representation=jaxsim.VelRepr.Inertial, + velocity_representation=velocity_representation, ) # Apply a link forces to the base link. @@ -144,12 +145,15 @@ def test_box_with_zero_gravity( # ... and step the simulation. for _ in T: - - data = js.model.step( - model=model, - data=data, - link_forces_inertial=references.link_forces(model=model, data=data), - ) + with ( + data.switch_velocity_representation(velocity_representation), + references.switch_velocity_representation(velocity_representation), + ): + data = js.model.step( + model=model, + data=data, + link_forces=references.link_forces(model=model, data=data), + ) # Check that the box moved as expected. assert data.base_position == pytest.approx(