Skip to content

Commit

Permalink
Remove redundant system_velocity_dynamics function and update `syst…
Browse files Browse the repository at this point in the history
…em_dynamics` to use `system_acceleration` directly
  • Loading branch information
xela-95 committed Jan 30, 2025
1 parent 9a9c1e8 commit 1abdfe5
Showing 1 changed file with 1 addition and 53 deletions.
54 changes: 1 addition & 53 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Any

import jax
import jax.numpy as jnp

Expand All @@ -15,56 +13,6 @@
# ==================================


@jax.jit
@js.common.named_scope
def system_velocity_dynamics(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_forces: jtp.Vector | None = None,
joint_torques: jtp.Vector | None = None,
) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]:
"""
Compute the dynamics of the system velocity.
Args:
model: The model to consider.
data: The data of the considered model.
link_forces:
The 6D forces to apply to the links expressed in inertial-fixed representation.
joint_torques: The joint torques acting on the joints.
Returns:
A tuple containing the derivative of the base 6D velocity in inertial-fixed
representation, the derivative of the joint velocities, and auxiliary data
returned by the system dynamics evaluation.
"""

# Build link forces if not provided.
# These forces are expressed in the frame corresponding to the velocity
# representation of data.
W_f_L = (
jnp.atleast_2d(link_forces.squeeze())
if link_forces is not None
else jnp.zeros((model.number_of_links(), 6))
).astype(float)

# ===========================
# Compute system acceleration
# ===========================

# Compute the system acceleration in inertial-fixed representation.
# This representation is useful for integration purpose.
W_v̇_WB, = system_acceleration(
model=model,
data=data,
joint_torques=joint_torques,
link_forces=W_f_L,
)

return W_v̇_WB,


def system_acceleration(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
Expand Down Expand Up @@ -197,7 +145,7 @@ def system_dynamics(
"""

with data.switch_velocity_representation(velocity_representation=VelRepr.Inertial):
W_v̇_WB, = system_velocity_dynamics(
W_v̇_WB, = system_acceleration(
model=model,
data=data,
joint_torques=joint_torques,
Expand Down

0 comments on commit 1abdfe5

Please sign in to comment.