diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index cf5181d79..58693546b 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1907,8 +1907,8 @@ def step( dt: jtp.FloatLike, integrator: jaxsim.integrators.Integrator, integrator_state: dict[str, Any] | None = None, - joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, **kwargs, ) -> tuple[js.data.JaxSimModelData, dict[str, Any]]: """ @@ -1920,10 +1920,10 @@ def step( dt: The time step to consider. integrator: The integrator to use. integrator_state: The state of the integrator. - joint_forces: The joint forces to consider. link_forces: The 6D forces to apply to the links expressed in the frame corresponding to the velocity representation of `data`. + joint_force_references: The joint force references to consider. kwargs: Additional kwargs to pass to the integrator. Returns: @@ -1962,7 +1962,7 @@ def step( dict( model=model, data=data, - joint_forces=joint_forces, + joint_force_references=joint_force_references, link_forces=link_forces, ) | integrator_kwargs diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 33e95a22b..d2177b379 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -86,8 +86,8 @@ def system_velocity_dynamics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, - joint_forces: jtp.Vector | None = None, link_forces: jtp.Vector | None = None, + joint_force_references: jtp.Vector | None = None, ) -> tuple[jtp.Vector, jtp.Vector, dict[str, Any]]: """ Compute the dynamics of the system velocity. @@ -95,10 +95,10 @@ def system_velocity_dynamics( Args: model: The model to consider. data: The data of the considered model. - joint_forces: The joint force references to apply. link_forces: The 6D forces to apply to the links expressed in the frame corresponding to the velocity representation of `data`. + joint_force_references: The joint force references to apply. Returns: A tuple containing the derivative of the base 6D velocity in inertial-fixed @@ -120,7 +120,7 @@ def system_velocity_dynamics( references = js.references.JaxSimModelReferences.build( model=model, link_forces=O_f_L, - joint_force_references=joint_forces, + joint_force_references=joint_force_references, data=data, velocity_representation=data.velocity_representation, ) @@ -192,7 +192,10 @@ def system_velocity_dynamics( f_L_total = references.link_forces(model=model, data=data) v̇_WB, s̈ = system_acceleration( - model=model, data=data, joint_forces=joint_forces, link_forces=f_L_total + model=model, + data=data, + joint_force_references=joint_force_references, + link_forces=f_L_total, ) return v̇_WB, s̈, aux_data @@ -202,8 +205,8 @@ def system_acceleration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, - joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, ) -> tuple[jtp.Vector, jtp.Vector]: """ Compute the system acceleration in the active representation. @@ -211,12 +214,13 @@ def system_acceleration( Args: model: The model to consider. data: The data of the considered model. - joint_forces: The joint forces to apply. link_forces: - The 6D forces to apply to the links expressed in the same representation of data. + The 6D forces to apply to the links expressed in the same + velocity representation of data. + joint_force_references: The joint force references to apply. Returns: - A tuple containing the base 6D acceleration in in the active representation + A tuple containing the base 6D acceleration in the active representation and the joint accelerations. """ @@ -232,9 +236,9 @@ def system_acceleration( ).astype(float) # Build joint torques if not provided. - τ = ( - jnp.atleast_1d(joint_forces.squeeze()) - if joint_forces is not None + τ_references = ( + jnp.atleast_1d(joint_force_references.squeeze()) + if joint_force_references is not None else jnp.zeros_like(data.joint_positions()) ).astype(float) @@ -243,15 +247,16 @@ def system_acceleration( # ==================== # TODO: enforce joint limits - τ_position_limit = jnp.zeros_like(τ).astype(float) + τ_position_limit = jnp.zeros_like(τ_references).astype(float) # ==================== # Joint friction model # ==================== - τ_friction = jnp.zeros_like(τ).astype(float) + τ_friction = jnp.zeros_like(τ_references).astype(float) if model.dofs() > 0: + # Static and viscous joint friction parameters kc = jnp.array( model.kin_dyn_parameters.joint_parameters.friction_static @@ -271,22 +276,27 @@ def system_acceleration( # ======================== # Compute the total joint forces. - τ_total = τ + τ_friction + τ_position_limit + τ_total = τ_references + τ_friction + τ_position_limit + # Store the link forces in a references object. references = js.references.JaxSimModelReferences.build( model=model, data=data, velocity_representation=data.velocity_representation, - joint_force_references=τ_total, link_forces=f_L, ) + # Compute forward dynamics. + # # - Joint accelerations: s̈ ∈ ℝⁿ # - Base acceleration: v̇_WB ∈ ℝ⁶ + # + # Note that ABA returns the base acceleration in the velocity representation + # stored in the `data` object. v̇_WB, s̈ = js.model.forward_dynamics_aba( model=model, data=data, - joint_forces=references.joint_force_references(model=model), + joint_forces=τ_total, link_forces=references.link_forces(model=model, data=data), ) @@ -337,8 +347,8 @@ def system_dynamics( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, *, - joint_forces: jtp.Vector | None = None, link_forces: jtp.Vector | None = None, + joint_force_references: jtp.Vector | None = None, baumgarte_quaternion_regularization: jtp.FloatLike = 1.0, ) -> tuple[ODEState, dict[str, Any]]: """ @@ -347,10 +357,10 @@ def system_dynamics( Args: model: The model to consider. data: The data of the considered model. - joint_forces: The joint forces to apply. link_forces: The 6D forces to apply to the links expressed in the frame corresponding to the velocity representation of `data`. + joint_force_references: The joint force references to apply. baumgarte_quaternion_regularization: The Baumgarte regularization coefficient used to adjust the norm of the quaternion (only used in integrators not operating on the SO(3) manifold). diff --git a/src/jaxsim/api/ode_data.py b/src/jaxsim/api/ode_data.py index cb96fd14c..622e0e196 100644 --- a/src/jaxsim/api/ode_data.py +++ b/src/jaxsim/api/ode_data.py @@ -38,16 +38,16 @@ class ODEInput(JaxsimDataclass): @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, - joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, ) -> ODEInput: """ Build an `ODEInput` from a `JaxSimModel`. Args: model: The `JaxSimModel` associated with the ODE input. - joint_forces: The vector of joint forces. link_forces: The matrix of external forces applied to the links. + joint_force_references: The vector of joint force references. Returns: The `ODEInput` built from the `JaxSimModel`. @@ -60,8 +60,8 @@ def build_from_jaxsim_model( return ODEInput.build( physics_model_input=PhysicsModelInput.build_from_jaxsim_model( model=model, - joint_forces=joint_forces, link_forces=link_forces, + joint_force_references=joint_force_references, ), model=model, ) @@ -526,16 +526,16 @@ class PhysicsModelInput(JaxsimDataclass): @staticmethod def build_from_jaxsim_model( model: js.model.JaxSimModel | None = None, - joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, ) -> PhysicsModelInput: """ Build a `PhysicsModelInput` from a `JaxSimModel`. Args: model: The `JaxSimModel` associated with the input. - joint_forces: The vector of joint forces. link_forces: The matrix of external forces applied to the links. + joint_force_references: The vector of joint force references. Returns: A `PhysicsModelInput` instance. @@ -546,7 +546,7 @@ def build_from_jaxsim_model( """ return PhysicsModelInput.build( - joint_forces=joint_forces, + joint_force_references=joint_force_references, link_forces=link_forces, number_of_dofs=model.dofs(), number_of_links=model.number_of_links(), @@ -554,8 +554,8 @@ def build_from_jaxsim_model( @staticmethod def build( - joint_forces: jtp.VectorLike | None = None, link_forces: jtp.MatrixLike | None = None, + joint_force_references: jtp.VectorLike | None = None, number_of_dofs: jtp.Int | None = None, number_of_links: jtp.Int | None = None, ) -> PhysicsModelInput: @@ -563,8 +563,8 @@ def build( Build a `PhysicsModelInput`. Args: - joint_forces: The vector of joint forces. link_forces: The matrix of external forces applied to the links. + joint_force_references: The vector of joint force references. number_of_dofs: The number of degrees of freedom of the model. number_of_links: The number of links of the model. @@ -572,19 +572,21 @@ def build( A `PhysicsModelInput` instance. """ - joint_forces = ( - joint_forces if joint_forces is not None else jnp.zeros(number_of_dofs) - ) + joint_force_references = jnp.atleast_1d( + jnp.array(joint_force_references, dtype=float).squeeze() + if joint_force_references is not None + else jnp.zeros(number_of_dofs) + ).astype(float) - link_forces = ( - link_forces + link_forces = jnp.atleast_2d( + jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros(shape=(number_of_links, 6)) - ) + ).astype(float) return PhysicsModelInput( - tau=jnp.array(joint_forces, dtype=float), - f_ext=jnp.array(link_forces, dtype=float), + tau=joint_force_references, + f_ext=link_forces, ) @staticmethod diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index c34aa2240..4a216d6fe 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -55,8 +55,8 @@ def zero( @staticmethod def build( model: js.model.JaxSimModel, - joint_force_references: jtp.Vector | None = None, - link_forces: jtp.Matrix | None = None, + joint_force_references: jtp.VectorLike | None = None, + link_forces: jtp.MatrixLike | None = None, data: js.data.JaxSimModelData | None = None, velocity_representation: VelRepr | None = None, ) -> JaxSimModelReferences: @@ -78,14 +78,14 @@ def build( # Create or adjust joint force references. joint_force_references = jnp.atleast_1d( - joint_force_references.squeeze() + jnp.array(joint_force_references, dtype=float).squeeze() if joint_force_references is not None else jnp.zeros(model.dofs()) ).astype(float) # Create or adjust link forces. f_L = jnp.atleast_2d( - link_forces.squeeze() + jnp.array(link_forces, dtype=float).squeeze() if link_forces is not None else jnp.zeros((model.number_of_links(), 6)) ).astype(float) @@ -299,9 +299,9 @@ def set_joint_force_references( A new `JaxSimModelReferences` object with the given joint force references. """ - forces = jnp.array(forces) + forces = jnp.atleast_1d(jnp.array(forces, dtype=float).squeeze()) - def replace(forces: jtp.VectorLike) -> JaxSimModelReferences: + def replace(forces: jtp.Vector) -> JaxSimModelReferences: return self.replace( validate=True, input=self.input.replace(