diff --git a/src/jaxsim/api/integrators.py b/src/jaxsim/api/integrators.py index ef50b5be6..e88ad0a89 100644 --- a/src/jaxsim/api/integrators.py +++ b/src/jaxsim/api/integrators.py @@ -14,16 +14,21 @@ def semi_implicit_euler_integration( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, - base_acceleration_inertial: jtp.Vector, - joint_accelerations: jtp.Vector, + link_forces: jtp.Vector, + joint_torques: jtp.Vector, ) -> JaxSimModelData: """Integrate the system state using the semi-implicit Euler method.""" # Step the dynamics forward. with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): + W_v̇_WB, s̈ = js.ode.system_acceleration( + model=model, + data=data, + link_forces=link_forces, + joint_torques=joint_torques, + ) + dt = model.time_step - W_v̇_WB = base_acceleration_inertial - s̈ = joint_accelerations B_H_W = Transform.inverse(data._base_transform).at[:3, :3].set(jnp.eye(3)) BW_X_W = Adjoint.from_transform(B_H_W) @@ -81,8 +86,6 @@ def semi_implicit_euler_integration( def rk4_integration( model: js.model.JaxSimModel, data: JaxSimModelData, - base_acceleration_inertial: jtp.Vector, - joint_accelerations: jtp.Vector, link_forces: jtp.Vector, joint_torques: jtp.Vector, ) -> JaxSimModelData: diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index f9bb48fc7..2cba64e45 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -2063,41 +2063,6 @@ def step( model, data, joint_force_references=τ_references ) - # ====================== - # Compute contact forces - # ====================== - - W_f_L_terrain = jnp.zeros_like(W_f_L_external) - - if len(model.kin_dyn_parameters.contact_parameters.body) > 0: - - # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact - # with the terrain. - W_f_L_terrain = js.contact_model.link_contact_forces( - model=model, - data=data, - link_forces=W_f_L_external, - joint_torques=τ_total, - ) - - # ============================== - # Compute the total link forces - # ============================== - - W_f_L_total = W_f_L_external + W_f_L_terrain - - # =============================== - # Compute the system acceleration - # =============================== - - with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): - W_v̇_WB, s̈ = js.ode.system_acceleration( - model=model, - data=data, - link_forces=W_f_L_total, - joint_torques=τ_total, - ) - # ============================= # Advance the simulation state # ============================= @@ -2108,14 +2073,8 @@ def step( data_tf = integrator_fn( model=model, data=data, - base_acceleration_inertial=W_v̇_WB, - joint_accelerations=s̈, - # Pass link_forces and joint_torques if the integrator is rk4 - **( - {"link_forces": W_f_L_total, "joint_torques": τ_total} - if model.integrator == IntegratorType.RungeKutta4 - else {} - ), + link_forces=W_f_L_external, + joint_torques=τ_total, ) return data_tf diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 2f61e1d43..32cee904e 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -35,6 +35,21 @@ def system_acceleration( and the joint accelerations. """ + # ====================== + # Compute contact forces + # ====================== + + if len(model.kin_dyn_parameters.contact_parameters.body) > 0: + + # Compute the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact + # with the terrain. + W_f_L_terrain = js.contact_model.link_contact_forces( + model=model, + data=data, + link_forces=link_forces, + joint_torques=joint_torques, + ) + # ==================== # Validate input data # ==================== @@ -54,6 +69,9 @@ def system_acceleration( link_forces=f_L, ) + with references.switch_velocity_representation(VelRepr.Inertial): + W_f_L_total = W_f_L_terrain + references.link_forces(model=model, data=data) + # Compute forward dynamics. # # - Joint accelerations: s̈ ∈ ℝⁿ @@ -61,12 +79,13 @@ def system_acceleration( # # Note that ABA returns the base acceleration in the velocity representation # stored in the `data` object. - v̇_WB, s̈ = js.model.forward_dynamics_aba( - model=model, - data=data, - joint_forces=joint_torques, - link_forces=references.link_forces(model=model, data=data), - ) + with data.switch_velocity_representation(VelRepr.Inertial): + v̇_WB, s̈ = js.model.forward_dynamics_aba( + model=model, + data=data, + joint_forces=joint_torques, + link_forces=W_f_L_total, + ) return v̇_WB, s̈ diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 9d28dc00c..be88808ec 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -314,11 +314,11 @@ def compute_contact_forces( BW_ν = data.generalized_velocity BW_ν̇_free = jnp.hstack( - js.ode.system_acceleration( + js.model.forward_dynamics_aba( model=model, data=data, link_forces=references.link_forces(model=model, data=data), - joint_torques=references.joint_force_references(model=model), + joint_forces=references.joint_force_references(model=model), ) )