Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sprint] Set mixed as default representation in data.build #361

Merged
merged 7 commits into from
Jan 31, 2025
40 changes: 21 additions & 19 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
base_transform: The base transform.
joint_transforms: The joint transforms.
link_transforms: The link transforms.
link_velocities: The link velocities.
link_velocities: The link velocities in inertial-fixed representation.
"""

# Joint state
Expand Down Expand Up @@ -69,7 +69,7 @@ def build(
base_linear_velocity: jtp.VectorLike | None = None,
base_angular_velocity: jtp.VectorLike | None = None,
joint_velocities: jtp.VectorLike | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
velocity_representation: VelRepr = VelRepr.Mixed,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with the given state.
Expand All @@ -84,7 +84,7 @@ def build(
base_angular_velocity:
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
velocity_representation: The velocity representation to use.
velocity_representation: The velocity representation to use. It defaults to mixed if not provided.

Returns:
A `JaxSimModelData` initialized with the given state.
Expand Down Expand Up @@ -144,7 +144,7 @@ def build(
translation=base_position, quaternion=base_quaternion
)

v_WB = JaxSimModelData.other_representation_to_inertial(
W_v_WB = JaxSimModelData.other_representation_to_inertial(
array=jnp.hstack([base_linear_velocity, base_angular_velocity]),
other_representation=velocity_representation,
transform=W_H_B,
Expand All @@ -155,28 +155,30 @@ def build(
joint_positions=joint_positions, base_transform=W_H_B
)

link_transforms, link_velocities = jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity=v_WB[0:3],
base_angular_velocity=v_WB[3:6],
joint_velocities=joint_velocities,
link_transforms, link_velocities_inertial = (
jaxsim.rbda.forward_kinematics_model(
model=model,
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity_inertial=W_v_WB[0:3],
base_angular_velocity_inertial=W_v_WB[3:6],
joint_velocities=joint_velocities,
)
)

model_data = JaxSimModelData(
base_quaternion=base_quaternion,
base_position=base_position,
joint_positions=joint_positions,
base_linear_velocity=v_WB[0:3],
base_angular_velocity=v_WB[3:6],
base_linear_velocity=W_v_WB[0:3],
base_angular_velocity=W_v_WB[3:6],
joint_velocities=joint_velocities,
velocity_representation=velocity_representation,
base_transform=W_H_B,
joint_transforms=joint_transforms,
link_transforms=link_transforms,
link_velocities=link_velocities,
link_velocities=link_velocities_inertial,
)

if not model_data.valid(model=model):
Expand All @@ -189,14 +191,14 @@ def build(
@staticmethod
def zero(
model: js.model.JaxSimModel,
velocity_representation: VelRepr = VelRepr.Inertial,
velocity_representation: VelRepr = VelRepr.Mixed,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with zero state.

Args:
model: The model for which to create the state.
velocity_representation: The velocity representation to use.
velocity_representation: The velocity representation to use. It defaults to mixed if not provided.

Returns:
A `JaxSimModelData` initialized with zero state.
Expand Down Expand Up @@ -603,8 +605,8 @@ def update_cached(self, model: js.model.JaxSimModel) -> JaxSimModelData:
base_quaternion=self.base_quaternion,
joint_positions=self.joint_positions,
joint_velocities=self.joint_velocities,
base_linear_velocity=self.base_linear_velocity,
base_angular_velocity=self.base_angular_velocity,
base_linear_velocity_inertial=self.base_linear_velocity,
base_angular_velocity_inertial=self.base_angular_velocity,
)

return self.replace(
Expand Down
25 changes: 18 additions & 7 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,7 +1989,7 @@ def step(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces_inertial: jtp.MatrixLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> js.data.JaxSimModelData:
"""
Expand All @@ -1999,8 +1999,8 @@ def step(
model: The model to consider.
data: The data of the considered model.
dt: The time step to consider. If not specified, it is read from the model.
link_forces_inertial:
The 6D forces to apply to the links expressed in inertial-representation.
link_forces:
The 6D forces to apply to the links expressed in same representation of data.
joint_force_references: The joint force references to consider.

Returns:
Expand All @@ -2016,11 +2016,22 @@ def step(
# the enabled collidable points

# Extract the inputs
W_f_L_external = jnp.atleast_2d(
jnp.array(link_forces_inertial, dtype=float).squeeze()
if link_forces_inertial is not None
O_f_L_external = jnp.atleast_2d(
jnp.array(link_forces, dtype=float).squeeze()
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
)

# Get the external forces in inertial-fixed representation.
W_f_L_external = jax.vmap(
lambda f_L, W_H_L: js.data.JaxSimModelData.other_representation_to_inertial(
f_L,
other_representation=data.velocity_representation,
transform=W_H_L,
is_force=True,
)
)(O_f_L_external, data.link_transforms)

τ_references = jnp.atleast_1d(
jnp.array(joint_force_references, dtype=float).squeeze()
if joint_force_references is not None
Expand Down Expand Up @@ -2063,7 +2074,7 @@ def step(
# ===============================

with data.switch_velocity_representation(jaxsim.VelRepr.Inertial):
W_v̇_WB, s̈ = js.ode.system_velocity_dynamics(
W_v̇_WB, s̈ = js.ode.system_acceleration(
model=model,
data=data,
link_forces=W_f_L_total,
Expand Down
54 changes: 1 addition & 53 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any

import jax
import jax.numpy as jnp

Expand All @@ -15,56 +13,6 @@
# ==================================


@jax.jit
@js.common.named_scope
def system_velocity_dynamics(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces: jtp.Vector | None = None,
joint_torques: jtp.Vector | None = None,
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
"""
Compute the dynamics of the system velocity.

Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D forces to apply to the links expressed in inertial-fixed representation.
joint_torques: The joint torques acting on the joints.

Returns:
A tuple containing the derivative of the base 6D velocity in inertial-fixed
representation, the derivative of the joint velocities, and auxiliary data
returned by the system dynamics evaluation.
"""

# Build link forces if not provided.
# These forces are expressed in the frame corresponding to the velocity
# representation of data.
W_f_L = (
jnp.atleast_2d(link_forces.squeeze())
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
).astype(float)

# ===========================
# Compute system acceleration
# ===========================

# Compute the system acceleration in inertial-fixed representation.
# This representation is useful for integration purpose.
W_v̇_WB, s̈ = system_acceleration(
model=model,
data=data,
joint_torques=joint_torques,
link_forces=W_f_L,
)

return W_v̇_WB, s̈


def system_acceleration(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -197,7 +145,7 @@ def system_dynamics(
"""

with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
W_v̇_WB, s̈ = system_velocity_dynamics(
W_v̇_WB, s̈ = system_acceleration(
model=model,
data=data,
joint_torques=joint_torques,
Expand Down
12 changes: 6 additions & 6 deletions src/jaxsim/rbda/forward_kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def forward_kinematics_model(
base_position: jtp.VectorLike,
base_quaternion: jtp.VectorLike,
joint_positions: jtp.VectorLike,
base_linear_velocity: jtp.VectorLike,
base_angular_velocity: jtp.VectorLike,
base_linear_velocity_inertial: jtp.VectorLike,
base_angular_velocity_inertial: jtp.VectorLike,
joint_velocities: jtp.VectorLike,
) -> jtp.Array:
"""
Expand All @@ -27,8 +27,8 @@ def forward_kinematics_model(
base_position: The position of the base link.
base_quaternion: The quaternion of the base link.
joint_positions: The positions of the joints.
base_linear_velocity: The linear velocity of the base link.
base_angular_velocity: The angular velocity of the base link.
base_linear_velocity_inertial: The linear velocity of the base link in inertial-fixed representation.
base_angular_velocity_inertial: The angular velocity of the base link in inertial-fixed representation.
joint_velocities: The velocities of the joints.

Returns:
Expand All @@ -40,8 +40,8 @@ def forward_kinematics_model(
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity=base_linear_velocity,
base_angular_velocity=base_angular_velocity,
base_linear_velocity=base_linear_velocity_inertial,
base_angular_velocity=base_angular_velocity_inertial,
joint_velocities=joint_velocities,
)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_automatic_differentiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def test_ad_fk(
base_position=W_p_B,
base_quaternion=W_Q_B / jnp.linalg.norm(W_Q_B),
joint_positions=s,
base_linear_velocity=W_v_lin,
base_angular_velocity=W_v_ang,
base_linear_velocity_inertial=W_v_lin,
base_angular_velocity_inertial=W_v_ang,
joint_velocities=ṡ,
)

Expand Down Expand Up @@ -344,7 +344,7 @@ def step(
model=model,
data=data_x0,
joint_force_references=τ,
link_forces_inertial=W_f_L,
link_forces=W_f_L,
)

xf_W_p_B = data_xf.base_position
Expand Down
22 changes: 13 additions & 9 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_box_with_external_forces(
data = js.model.step(
model=model,
data=data,
link_forces_inertial=references._link_forces,
link_forces=references.link_forces(model, data),
)

# Check that the box didn't move.
Expand All @@ -84,6 +84,7 @@ def test_box_with_external_forces(

def test_box_with_zero_gravity(
jaxsim_model_box: js.model.JaxSimModel,
velocity_representation: VelRepr,
prng_key: jnp.ndarray,
):

Expand All @@ -101,14 +102,14 @@ def test_box_with_zero_gravity(
data0 = js.data.JaxSimModelData.build(
model=model,
base_position=jax.random.uniform(subkey, shape=(3,)),
velocity_representation=jaxsim.VelRepr.Inertial,
velocity_representation=velocity_representation,
)

# Initialize a references object that simplifies handling external forces.
references = js.references.JaxSimModelReferences.build(
model=model,
data=data0,
velocity_representation=jaxsim.VelRepr.Inertial,
velocity_representation=velocity_representation,
)

# Apply a link forces to the base link.
Expand Down Expand Up @@ -144,12 +145,15 @@ def test_box_with_zero_gravity(

# ... and step the simulation.
for _ in T:

data = js.model.step(
model=model,
data=data,
link_forces_inertial=references.link_forces(model=model, data=data),
)
with (
data.switch_velocity_representation(velocity_representation),
references.switch_velocity_representation(velocity_representation),
):
data = js.model.step(
model=model,
data=data,
link_forces=references.link_forces(model=model, data=data),
)

# Check that the box moved as expected.
assert data.base_position == pytest.approx(
Expand Down
Loading