diff --git a/src/jaxsim/api/contact.py b/src/jaxsim/api/contact.py index 95655be0f..f230fd021 100644 --- a/src/jaxsim/api/contact.py +++ b/src/jaxsim/api/contact.py @@ -98,6 +98,7 @@ def collidable_point_forces( data: js.data.JaxSimModelData, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, + **kwargs, ) -> jtp.Matrix: """ Compute the 6D forces applied to each collidable point. @@ -110,6 +111,7 @@ def collidable_point_forces( representation of data. joint_force_references: The joint force references to apply to the joints. + kwargs: Additional keyword arguments to pass to the active contact model. Returns: The 6D forces applied to each collidable point expressed in the frame @@ -121,6 +123,7 @@ def collidable_point_forces( data=data, link_forces=link_forces, joint_force_references=joint_force_references, + **kwargs, ) return f_Ci @@ -132,7 +135,8 @@ def collidable_point_dynamics( data: js.data.JaxSimModelData, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, -) -> tuple[jtp.Matrix, dict[str, jtp.Array]]: + **kwargs, +) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: r""" Compute the 6D force applied to each collidable point. @@ -144,6 +148,7 @@ def collidable_point_dynamics( representation of data. joint_force_references: The joint force references to apply to the joints. + kwargs: Additional keyword arguments to pass to the active contact model. Returns: The 6D force applied to each collidable point and additional data based @@ -158,86 +163,46 @@ def collidable_point_dynamics( Instead, the 6D forces are returned in the active representation. """ - # Build the soft contact model. + # Build the common kw arguments to pass to the computation of the contact forces. + common_kwargs = dict( + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + + # Build the additional kwargs to pass to the computation of the contact forces. match model.contact_model: case contacts.SoftContacts(): - assert isinstance(model.contact_model, contacts.SoftContacts) - # Compute the 6D force expressed in the inertial frame and applied to each - # collidable point, and the corresponding material deformation rate. - # Note that the material deformation rate is always returned in the mixed frame - # C[W] = (W_p_C, [W]). This is convenient for integration purpose. - W_f_Ci, (CW_ṁ,) = model.contact_model.compute_contact_forces( - model=model, data=data - ) - - # Create the dictionary of auxiliary data. - # This contact model considers the material deformation as additional state - # of the ODE system. We need to pass its dynamics to the integrator. - aux_data = dict(m_dot=CW_ṁ) + kwargs_contact_model = {} case contacts.RigidContacts(): - assert isinstance(model.contact_model, contacts.RigidContacts) - # Compute the 6D force expressed in the inertial frame and applied to each - # collidable point. - W_f_Ci, _ = model.contact_model.compute_contact_forces( - model=model, - data=data, - link_forces=link_forces, - joint_force_references=joint_force_references, - ) - - aux_data = dict() + kwargs_contact_model = common_kwargs | kwargs case contacts.RelaxedRigidContacts(): - assert isinstance(model.contact_model, contacts.RelaxedRigidContacts) - # Compute the 6D force expressed in the inertial frame and applied to each - # collidable point. - W_f_Ci, _ = model.contact_model.compute_contact_forces( - model=model, - data=data, - link_forces=link_forces, - joint_force_references=joint_force_references, - ) - - aux_data = dict() + kwargs_contact_model = common_kwargs | kwargs case contacts.ViscoElasticContacts(): - assert isinstance(model.contact_model, contacts.ViscoElasticContacts) - # It is not yet clear how to pass the time step to this stage. - # A possibility is to restrict the integrator to only forward Euler - # and store the Δt inside the model. - module = jaxsim.rbda.contacts.visco_elastic.step.__module__ - name = jaxsim.rbda.contacts.visco_elastic.step.__name__ - msg = "You need to use the custom '{}.{}' function with this contact model." - jaxsim.exceptions.raise_runtime_error_if( - condition=True, msg=msg.format(module, name) - ) - - # Compute the 6D force expressed in the inertial frame and applied to each - # collidable point. - W_f_Ci, (W_f̿_Ci, m_tf) = model.contact_model.compute_contact_forces( - model=model, - data=data, - dt=None, # TODO - link_forces=link_forces, - joint_force_references=joint_force_references, - ) - - aux_data = dict(W_f_avg2_C=W_f̿_Ci, m_tf=m_tf) + kwargs_contact_model = common_kwargs | dict(dt=model.time_step) | kwargs case _: - raise ValueError(f"Invalid contact model {model.contact_model}") + raise ValueError(f"Invalid contact model: {model.contact_model}") + + # Compute the contact forces with the active contact model. + W_f_C, aux_data = model.contact_model.compute_contact_forces( + model=model, + data=data, + **kwargs_contact_model, + ) # Compute the transforms of the implicit frames `C[L] = (W_p_C, [L])` # associated to each collidable point. # In inertial-fixed representation, the computation of these transforms # is not necessary and the conversion below becomes a no-op. - W_H_Ci = ( + W_H_C = ( js.contact.transforms(model=model, data=data) if data.velocity_representation is not VelRepr.Inertial else jnp.zeros( @@ -253,7 +218,7 @@ def collidable_point_dynamics( transform=W_H_C, is_force=True, ) - )(W_f_Ci, W_H_Ci) + )(W_f_C, W_H_C) return f_Ci, aux_data @@ -392,11 +357,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: max_penetration=max_δ, number_of_active_collidable_points_steady_state=nc, damping_ratio=damping_ratio, - **dict( - p=model.contact_model.parameters.p, - q=model.contact_model.parameters.q, - ) - | kwargs, + **( + dict( + p=model.contact_model.parameters.p, + q=model.contact_model.parameters.q, + ) + | kwargs + ), ) case contacts.ViscoElasticContacts(): @@ -410,11 +377,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: max_penetration=max_δ, number_of_active_collidable_points_steady_state=nc, damping_ratio=damping_ratio, - **dict( - p=model.contact_model.parameters.p, - q=model.contact_model.parameters.q, - ) - | kwargs, + **( + dict( + p=model.contact_model.parameters.p, + q=model.contact_model.parameters.q, + ) + | kwargs + ), ) ) @@ -427,11 +396,13 @@ def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float: parameters = contacts.RigidContactsParams.build( mu=static_friction_coefficient, - **dict( - K=K, - D=2 * jnp.sqrt(K), - ) - | kwargs, + **( + dict( + K=K, + D=2 * jnp.sqrt(K), + ) + | kwargs + ), ) case contacts.RelaxedRigidContacts(): diff --git a/src/jaxsim/api/model.py b/src/jaxsim/api/model.py index 0c9e63f62..044a2b4d3 100644 --- a/src/jaxsim/api/model.py +++ b/src/jaxsim/api/model.py @@ -1770,8 +1770,10 @@ def body_to_other_representation( def link_contact_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, + *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, + **kwargs, ) -> jtp.Matrix: """ Compute the 6D contact forces of all links of the model. @@ -1784,6 +1786,7 @@ def link_contact_forces( representation of data. joint_force_references: The joint force references to apply to the joints. + kwargs: Additional keyword arguments to pass to the active contact model.. Returns: A `(nL, 6)` array containing the stacked 6D contact forces of the links, @@ -1820,47 +1823,16 @@ def link_contact_forces( joint_force_references=joint_force_references, ) - # Compute the 6D forces applied to each collidable point expressed in the - # inertial frame. - with ( - data.switch_velocity_representation(VelRepr.Inertial), - input_references.switch_velocity_representation(VelRepr.Inertial), - ): - W_f_C = js.contact.collidable_point_forces( - model=model, - data=data, - link_forces=input_references.link_forces(), - joint_force_references=input_references.joint_force_references(), - ) - - # Construct the vector defining the parent link index of each collidable point. - # We use this vector to sum the 6D forces of all collidable points rigidly - # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - ) - - # Create the mask that associate each collidable point to their parent link. - # We use this mask to sum the collidable points to the right link. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( - model.number_of_links() - ) - - # Sum the forces of all collidable points rigidly attached to a body. - # Since the contact forces W_f_C are expressed in the world frame, - # we don't need any coordinate transformation. - W_f_L = mask.T @ W_f_C - - # Create a references object to store the link forces. - references = js.references.JaxSimModelReferences.build( - model=model, link_forces=W_f_L, velocity_representation=VelRepr.Inertial + # Compute the 6D forces applied to the links equivalent to the forces applied + # to the frames associated to the collidable points. + f_L, _ = model.contact_model.compute_link_contact_forces( + model=model, + data=data, + link_forces=input_references.link_forces(model=model, data=data), + joint_force_references=input_references.joint_force_references(), + **kwargs, ) - # Use the references object to convert the link forces to the velocity - # representation of data. - with references.switch_velocity_representation(data.velocity_representation): - f_L = references.link_forces(model=model, data=data) - return f_L @@ -1967,6 +1939,11 @@ def step( Returns: A tuple containing the new data of the model and the new state of the integrator. + + Note: + In order to reduce the occurrences of frame conversions performed internally, + it is recommended to use inertial-fixed velocity representation. This can be + particularly useful for automatically differentiated logic. """ # Extract the integrator kwargs. @@ -1976,15 +1953,61 @@ def step( integrator_kwargs = kwargs.pop("integrator_kwargs", {}) integrator_kwargs = kwargs | integrator_kwargs - integrator_state = integrator_state if integrator_state is not None else dict() + # Initialize the integrator state. + integrator_state_t0 = integrator_state if integrator_state is not None else dict() # Initialize the time-related variables. state_t0 = data.state t0 = jnp.array(t0, dtype=float) dt = jnp.array(dt if dt is not None else model.time_step).astype(float) - # Rename the integrator state. - integrator_state_t0 = integrator_state + # The visco-elastic contacts operate at best with their own integrator. + # They can be used with Euler-like integrators, paying the price of ignoring + # some of the benefits of continuous-time integration on the system position. + # Furthermore, the requirement to know the Δt used by the integrator is not + # compatible with high-order integrators, that use advanced RK stages to evaluate + # the dynamics at intermediate times. + module = jaxsim.rbda.contacts.visco_elastic.step.__module__ + name = jaxsim.rbda.contacts.visco_elastic.step.__name__ + msg = "You need to use the custom '{}.{}' function with this contact model." + jaxsim.exceptions.raise_runtime_error_if( + condition=( + isinstance(model.contact_model, jaxsim.rbda.contacts.ViscoElasticContacts) + & ( + ~jnp.allclose(dt, model.time_step) + | ~isinstance(integrator, jaxsim.integrators.fixed_step.ForwardEuler) + ) + ), + msg=msg.format(module, name), + ) + + # ================= + # Phase 1: pre-step + # ================= + + # TODO: some contact models here may want to perform a dynamic filtering of + # the enabled collidable points. + + # Build the references object. + # We assume that the link forces are expressed in the frame corresponding to the + # velocity representation of the data. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + velocity_representation=data.velocity_representation, + link_forces=link_forces, + joint_force_references=joint_force_references, + ) + + # ============= + # Phase 2: step + # ============= + + # Prepare the references to pass. + with references.switch_velocity_representation(data.velocity_representation): + + f_L = references.link_forces(model=model, data=data) + τ_references = references.joint_force_references(model=model) # Step the dynamics forward. state_tf, integrator_state_tf = integrator.step( @@ -1994,7 +2017,7 @@ def step( params=integrator_state_t0, # Always inject the current (model, data) pair into the system dynamics # considered by the integrator, and include the input variables represented - # by the pair (joint_force_references, link_forces). + # by the pair (f_L, τ_references). # Note that the wrapper of the system dynamics will override (state_x0, t0) # inside the passed data even if it is not strictly needed. This logic is # necessary to re-use the jit-compiled step function of compatible pytrees @@ -2003,8 +2026,8 @@ def step( dict( model=model, data=data, - joint_force_references=joint_force_references, - link_forces=link_forces, + link_forces=f_L, + joint_force_references=τ_references, ) | integrator_kwargs ), @@ -2013,6 +2036,10 @@ def step( # Store the new state of the model. data_tf = data.replace(state=state_tf) + # ================== + # Phase 3: post-step + # ================== + # Post process the simulation state, if needed. match model.contact_model: @@ -2040,17 +2067,18 @@ def step( msg="Baumgarte stabilization is not supported with ForwardEuler integrators", ) + W_p_C = js.contact.collidable_point_positions(model, data_tf) + + # Compute the penetration depth of the collidable points. + δ, *_ = jax.vmap( + jaxsim.rbda.contacts.common.compute_penetration_data, + in_axes=(0, 0, None), + )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) + with data_tf.switch_velocity_representation(VelRepr.Mixed): J_WC = js.contact.jacobian(model, data_tf) M = js.model.free_floating_mass_matrix(model, data_tf) - W_p_C = js.contact.collidable_point_positions(model, data_tf) - - # Compute the penetration depth of the collidable points. - δ, *_ = jax.vmap( - jaxsim.rbda.contacts.common.compute_penetration_data, - in_axes=(0, 0, None), - )(W_p_C, jnp.zeros_like(W_p_C), model.terrain) # Compute the impact velocity. # It may be discontinuous in case new contacts are made. @@ -2063,13 +2091,13 @@ def step( ) ) - # Reset the generalized velocity. - data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6]) - data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:]) + # Reset the generalized velocity. + data_tf = data_tf.reset_base_velocity(BW_nu_post_impact[0:6]) + data_tf = data_tf.reset_joint_velocities(BW_nu_post_impact[6:]) - # Restore the input velocity representation. - data_tf = data_tf.replace( - velocity_representation=data.velocity_representation, validate=False - ) + # Restore the input velocity representation. + data_tf = data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) return data_tf, integrator_state_tf diff --git a/src/jaxsim/api/ode.py b/src/jaxsim/api/ode.py index 5dc0da8bd..ea7616f18 100644 --- a/src/jaxsim/api/ode.py +++ b/src/jaxsim/api/ode.py @@ -131,7 +131,7 @@ def system_velocity_dynamics( # Initialize the 6D forces W_f ∈ ℝ^{n_L × 6} applied to links due to contact # with the terrain. - W_f_Li_terrain = jnp.zeros_like(O_f_L).astype(float) + W_f_L_terrain = jnp.zeros_like(O_f_L).astype(float) # Initialize a dictionary of auxiliary data. # This dictionary is used to store additional data computed by the contact model. @@ -139,66 +139,59 @@ def system_velocity_dynamics( if len(model.kin_dyn_parameters.contact_parameters.body) > 0: - # Note: the following code should be kept in sync with the function - # `jaxsim.api.model.link_contact_forces`. We cannot merge them since - # here we need to get also aux_data. - - # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point - # along with contact-specific auxiliary states. with ( data.switch_velocity_representation(VelRepr.Inertial), references.switch_velocity_representation(VelRepr.Inertial), ): - W_f_Ci, aux_data = js.contact.collidable_point_dynamics( + + # Compute the 6D forces W_f ∈ ℝ^{n_c × 6} applied to each collidable point + # along with contact-specific auxiliary states. + W_f_C, aux_data = js.contact.collidable_point_dynamics( model=model, data=data, link_forces=references.link_forces(model=model, data=data), joint_force_references=references.joint_force_references(model=model), ) - # Construct the vector defining the parent link index of each collidable point. - # We use this vector to sum the 6D forces of all collidable points rigidly - # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - ) - - # Sum the forces of all collidable points rigidly attached to a body. - # Since the contact forces W_f_Ci are expressed in the world frame, - # we don't need any coordinate transformation. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( - model.number_of_links() - ) - - W_f_Li_terrain = mask.T @ W_f_Ci + # Compute the 6D forces applied to the links equivalent to the forces applied + # to the frames associated to the collidable points. + W_f_L_terrain = model.contact_model.link_forces_from_contact_forces( + model=model, + data=data, + contact_forces=W_f_C, + ) # =========================== # Compute system acceleration # =========================== - # Compute the total link forces + # Compute the total link forces. with ( data.switch_velocity_representation(VelRepr.Inertial), references.switch_velocity_representation(VelRepr.Inertial), ): + + # Sum the contact forces just computed with the link forces applied by the user. references = references.apply_link_forces( model=model, data=data, - forces=W_f_Li_terrain, + forces=W_f_L_terrain, additive=True, ) - # Get the link forces in inertial representation + # Get the link forces in inertial-fixed representation. f_L_total = references.link_forces(model=model, data=data) - v̇_WB, s̈ = system_acceleration( + # Compute the system acceleration in inertial-fixed representation. + # This representation is useful for integration purpose. + W_v̇_WB, s̈ = system_acceleration( model=model, data=data, joint_force_references=joint_force_references, link_forces=f_L_total, ) - return v̇_WB, s̈, aux_data + return W_v̇_WB, s̈, aux_data def system_acceleration( @@ -390,17 +383,15 @@ def system_dynamics( case contacts.ViscoElasticContacts(): - extended_ode_state["contacts_state"] = { - "tangential_deformation": jnp.zeros_like( - data.state.extended["tangential_deformation"] - ) - } + extended_ode_state["tangential_deformation"] = jnp.zeros_like( + data.state.extended["tangential_deformation"] + ) case contacts.RigidContacts() | contacts.RelaxedRigidContacts(): pass case _: - raise ValueError(f"Invalid contact model {model.contact_model}") + raise ValueError(f"Invalid contact model: {model.contact_model}") # Extract the velocities. W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics( diff --git a/src/jaxsim/rbda/contacts/common.py b/src/jaxsim/rbda/contacts/common.py index e6892704a..517ecb483 100644 --- a/src/jaxsim/rbda/contacts/common.py +++ b/src/jaxsim/rbda/contacts/common.py @@ -2,7 +2,6 @@ import abc import functools -from typing import Any import jax import jax.numpy as jnp @@ -10,6 +9,7 @@ import jaxsim.api as js import jaxsim.terrain import jaxsim.typing as jtp +from jaxsim.api.common import ModelDataWithVelocityRepresentation from jaxsim.utils import JaxsimDataclass try: @@ -131,7 +131,7 @@ def compute_contact_forces( model: js.model.JaxSimModel, data: js.data.JaxSimModelData, **kwargs, - ) -> tuple[jtp.Matrix, tuple[Any, ...]]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -142,11 +142,145 @@ def compute_contact_forces( Returns: A tuple containing as first element the computed 6D contact force applied to the contact points and expressed in the world frame, and as second element - a tuple of optional additional information. + a dictionary of optional additional information. """ pass + def compute_link_contact_forces( + self, + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + **kwargs, + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: + """ + Compute the link contact forces. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + + Returns: + A tuple containing as first element the 6D contact force applied to the + links and expressed in the frame of the velocity representation of data, + and as second element a dictionary of optional additional information. + """ + + # Compute the contact forces expressed in the inertial frame. + # This function, contrarily to `compute_contact_forces`, already handles how + # the optional kwargs should be passed to the specific contact models. + W_f_C, aux_dict = js.contact.collidable_point_dynamics( + model=model, data=data, **kwargs + ) + + # Compute the 6D forces applied to the links equivalent to the forces applied + # to the frames associated to the collidable points. + with data.switch_velocity_representation(jaxsim.VelRepr.Inertial): + + W_f_L = self.link_forces_from_contact_forces( + model=model, data=data, contact_forces=W_f_C + ) + + # Store the link forces in the references object for easy conversion. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + link_forces=W_f_L, + velocity_representation=jaxsim.VelRepr.Inertial, + ) + + # Convert the link forces to the frame corresponding to the velocity + # representation of data. + with references.switch_velocity_representation(data.velocity_representation): + f_L = references.link_forces(model=model, data=data) + + return f_L, aux_dict + + @staticmethod + def link_forces_from_contact_forces( + model: js.model.JaxSimModel, + data: js.data.JaxSimModelData, + *, + contact_forces: jtp.MatrixLike, + ) -> jtp.Matrix: + """ + Compute the link forces from the contact forces. + + Args: + model: The robot model considered by the contact model. + data: The data of the considered model. + contact_forces: The contact forces computed by the contact model. + + Returns: + The 6D contact forces applied to the links and expressed in the frame of + the velocity representation of data. + """ + + # Convert the contact forces to a JAX array. + f_C = jnp.atleast_2d(jnp.array(contact_forces, dtype=float).squeeze()) + + # Get the pose of the enabled collidable points. + W_H_C = js.contact.transforms(model=model, data=data) + + # Convert the contact forces to inertial-fixed representation. + W_f_C = jax.vmap( + lambda f_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=f_C, + other_representation=data.velocity_representation, + transform=W_H_C, + is_force=True, + ) + ) + )(f_C, W_H_C) + + # Get the object storing the contact parameters of the model. + contact_parameters = model.kin_dyn_parameters.contact_parameters + + # Extract the indices corresponding to the enabled collidable points. + indices_of_enabled_collidable_points = ( + contact_parameters.indices_of_enabled_collidable_points + ) + + # Construct the vector defining the parent link index of each collidable point. + # We use this vector to sum the 6D forces of all collidable points rigidly + # attached to the same link. + parent_link_index_of_collidable_points = jnp.array( + contact_parameters.body, dtype=int + )[indices_of_enabled_collidable_points] + + # Create the mask that associate each collidable point to their parent link. + # We use this mask to sum the collidable points to the right link. + mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( + model.number_of_links() + ) + + # Sum the forces of all collidable points rigidly attached to a body. + # Since the contact forces W_f_C are expressed in the world frame, + # we don't need any coordinate transformation. + W_f_L = mask.T @ W_f_C + + # Compute the link transforms. + W_H_L = ( + js.model.forward_kinematics(model=model, data=data) + if data.velocity_representation is not jaxsim.VelRepr.Inertial + else jnp.zeros(shape=(model.number_of_links(), 4, 4)) + ) + + # Convert the inertial-fixed link forces to the velocity representation of data. + f_L = jax.vmap( + lambda W_f_L, W_H_L: ( + ModelDataWithVelocityRepresentation.inertial_to_other_representation( + array=W_f_L, + other_representation=data.velocity_representation, + transform=W_H_L, + is_force=True, + ) + ) + )(W_f_L, W_H_L) + + return f_L + @classmethod def zero_state_variables(cls, model: js.model.JaxSimModel) -> dict[str, jtp.Array]: """ diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 25fb35b5e..ee58790d6 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -120,19 +120,44 @@ def default(name: str): return cls( time_constant=jnp.array( - time_constant or default("time_constant"), dtype=float + ( + time_constant + if time_constant is not None + else default("time_constant") + ), + dtype=float, ), damping_coefficient=jnp.array( - damping_coefficient or default("damping_coefficient"), dtype=float + ( + damping_coefficient + if damping_coefficient is not None + else default("damping_coefficient") + ), + dtype=float, + ), + d_min=jnp.array( + d_min if d_min is not None else default("d_min"), dtype=float + ), + d_max=jnp.array( + d_max if d_max is not None else default("d_max"), dtype=float + ), + width=jnp.array( + width if width is not None else default("width"), dtype=float + ), + midpoint=jnp.array( + midpoint if midpoint is not None else default("midpoint"), dtype=float ), - d_min=jnp.array(d_min or default("d_min"), dtype=float), - d_max=jnp.array(d_max or default("d_max"), dtype=float), - width=jnp.array(width or default("width"), dtype=float), - midpoint=jnp.array(midpoint or default("midpoint"), dtype=float), - power=jnp.array(power or default("power"), dtype=float), - stiffness=jnp.array(stiffness or default("stiffness"), dtype=float), - damping=jnp.array(damping or default("damping"), dtype=float), - mu=jnp.array(mu or default("mu"), dtype=float), + power=jnp.array( + power if power is not None else default("power"), dtype=float + ), + stiffness=jnp.array( + stiffness if stiffness is not None else default("stiffness"), + dtype=float, + ), + damping=jnp.array( + damping if damping is not None else default("damping"), dtype=float + ), + mu=jnp.array(mu if mu is not None else default("mu"), dtype=float), ) def valid(self) -> jtp.BoolLike: @@ -210,7 +235,9 @@ def build( # Create the solver options to set by combining the default solver options # with the user-provided solver options. - solver_options = default_solver_options | (solver_options or {}) + solver_options = default_solver_options | ( + solver_options if solver_options is not None else {} + ) # Make sure that the solver options are hashable. # We need to check this because the solver options are static. @@ -223,9 +250,15 @@ def build( return cls( parameters=( - parameters or cls.__dataclass_fields__["parameters"].default_factory() + parameters + if parameters is not None + else cls.__dataclass_fields__["parameters"].default_factory() + ), + terrain=( + terrain + if terrain is not None + else cls.__dataclass_fields__["terrain"].default_factory() ), - terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), _solver_options_keys=tuple(solver_options.keys()), _solver_options_values=tuple(solver_options.values()), ) @@ -238,7 +271,7 @@ def compute_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Matrix, tuple]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -458,7 +491,7 @@ def continuing_criterion(carry: OptimizationCarry) -> jtp.Bool: ), )(CW_fl_C, W_H_C) - return W_f_C, () + return W_f_C, {} @staticmethod def _regularizers( diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index bdbbb2937..220d65722 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -66,9 +66,17 @@ def build( """Create a `RigidContactParams` instance""" return cls( - mu=mu or cls.__dataclass_fields__["mu"].default, - K=K or cls.__dataclass_fields__["K"].default, - D=D or cls.__dataclass_fields__["D"].default, + mu=jnp.array( + mu + if mu is not None + else cls.__dataclass_fields__["mu"].default_factory() + ).astype(float), + K=jnp.array( + K if K is not None else cls.__dataclass_fields__["K"].default_factory() + ).astype(float), + D=jnp.array( + D if D is not None else cls.__dataclass_fields__["D"].default_factory() + ).astype(float), ) def valid(self) -> jtp.BoolLike: @@ -147,7 +155,9 @@ def build( # Create the solver options to set by combining the default solver options # with the user-provided solver options. - solver_options = default_solver_options | (solver_options or {}) + solver_options = default_solver_options | ( + solver_options if solver_options is not None else {} + ) # Make sure that the solver options are hashable. # We need to check this because the solver options are static. @@ -160,12 +170,19 @@ def build( return cls( parameters=( - parameters or cls.__dataclass_fields__["parameters"].default_factory() + parameters + if parameters is not None + else cls.__dataclass_fields__["parameters"].default_factory() + ), + terrain=( + terrain + if terrain is not None + else cls.__dataclass_fields__["terrain"].default_factory() ), - terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), regularization_delassus=float( regularization_delassus - or cls.__dataclass_fields__["regularization_delassus"].default + if regularization_delassus is not None + else cls.__dataclass_fields__["regularization_delassus"].default ), _solver_options_keys=tuple(solver_options.keys()), _solver_options_values=tuple(solver_options.values()), @@ -242,7 +259,7 @@ def compute_contact_forces( *, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Matrix, tuple]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -402,7 +419,7 @@ def compute_contact_forces( ), )(CW_fl_C, W_H_C) - return W_f_C, () + return W_f_C, {} @staticmethod def _delassus_matrix( diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 4af693527..e726df379 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -237,9 +237,13 @@ def build( else cls.__dataclass_fields__["parameters"].default_factory() ) - return SoftContacts( + return cls( parameters=parameters, - terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + terrain=( + terrain + if terrain is not None + else cls.__dataclass_fields__["terrain"].default_factory() + ), ) @classmethod @@ -423,7 +427,7 @@ def compute_contact_forces( self, model: js.model.JaxSimModel, data: js.data.JaxSimModelData, - ) -> tuple[jtp.Matrix, tuple[jtp.Matrix]]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -433,7 +437,7 @@ def compute_contact_forces( Returns: A tuple containing as first element the computed contact forces, and as - second element the derivative of the material deformation. + second element a dictionary with derivative of the material deformation. """ # Initialize the model and data this contact model is operating on. @@ -460,4 +464,4 @@ def compute_contact_forces( ) )(W_p_C, W_ṗ_C, m) - return W_f, (ṁ,) + return W_f, dict(m_dot=ṁ) diff --git a/src/jaxsim/rbda/contacts/visco_elastic.py b/src/jaxsim/rbda/contacts/visco_elastic.py index 25019fcfc..490122f98 100644 --- a/src/jaxsim/rbda/contacts/visco_elastic.py +++ b/src/jaxsim/rbda/contacts/visco_elastic.py @@ -13,6 +13,7 @@ import jaxsim.exceptions import jaxsim.typing as jtp from jaxsim import logging +from jaxsim.api.common import ModelDataWithVelocityRepresentation from jaxsim.math import StandardGravity from jaxsim.terrain import FlatTerrain, Terrain @@ -235,11 +236,17 @@ def build( else cls.__dataclass_fields__["parameters"].default_factory() ) - return ViscoElasticContacts( + return cls( parameters=parameters, - terrain=terrain or cls.__dataclass_fields__["terrain"].default_factory(), + terrain=( + terrain + if terrain is not None + else cls.__dataclass_fields__["terrain"].default_factory() + ), max_squarings=int( - max_squarings or cls.__dataclass_fields__["max_squarings"].default + max_squarings + if max_squarings is not None + else cls.__dataclass_fields__["max_squarings"].default ), ) @@ -266,7 +273,7 @@ def compute_contact_forces( dt: jtp.FloatLike | None = None, link_forces: jtp.MatrixLike | None = None, joint_force_references: jtp.VectorLike | None = None, - ) -> tuple[jtp.Matrix, tuple[jtp.Matrix, jtp.Matrix]]: + ) -> tuple[jtp.Matrix, dict[str, jtp.PyTree]]: """ Compute the contact forces. @@ -291,7 +298,7 @@ def compute_contact_forces( Returns: A tuple containing as first element the computed 6D contact force applied to the contact point and expressed in the world frame, and as second element - a tuple of optional additional information. + a dictionary of optional additional information. """ # Initialize the model and data this contact model is operating on. @@ -315,8 +322,8 @@ def compute_contact_forces( model=model, data=data, dt=jnp.array(dt).astype(float), - joint_force_references=joint_force_references, link_forces=link_forces, + joint_force_references=joint_force_references, indices_of_enabled_collidable_points=indices_of_enabled_collidable_points, max_squarings=self.max_squarings, ) @@ -334,11 +341,13 @@ def compute_contact_forces( # Vmapped transformation from mixed to inertial-fixed representation. compute_forces_inertial_fixed_vmap = jax.vmap( - lambda CW_fl_C, W_H_C: data.other_representation_to_inertial( - array=jnp.zeros(6).at[0:3].set(CW_fl_C), - other_representation=jaxsim.VelRepr.Mixed, - transform=W_H_C, - is_force=True, + lambda CW_fl_C, W_H_C: ( + ModelDataWithVelocityRepresentation.other_representation_to_inertial( + array=jnp.zeros(6).at[0:3].set(CW_fl_C), + other_representation=jaxsim.VelRepr.Mixed, + transform=W_H_C, + is_force=True, + ) ) ) @@ -347,7 +356,7 @@ def compute_contact_forces( lambda CW_fl: compute_forces_inertial_fixed_vmap(CW_fl, W_H_C) )(jnp.stack([CW_f̅l, CW_fl̿])) - return W_f̅_C, (W_f̿_C, m_tf) + return W_f̅_C, dict(W_f_avg2_C=W_f̿_C, m_tf=m_tf) @staticmethod @functools.partial(jax.jit, static_argnames=("max_squarings",)) @@ -407,8 +416,8 @@ def _compute_contact_forces_with_exponential_integration( A, b, A_sc, b_sc = ViscoElasticContacts._contact_points_dynamics( model=model, data=data, - joint_force_references=joint_force_references, link_forces=link_forces, + joint_force_references=joint_force_references, indices_of_enabled_collidable_points=indices, p_t0=p_t0, v_t0=v_t0, @@ -657,8 +666,8 @@ def _contact_points_dynamics( BW_v̇_free_WB, s̈_free = js.ode.system_acceleration( model=model, data=data, - joint_force_references=references.joint_force_references(model=model), link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), ) # Pack the free system acceleration in mixed representation. @@ -688,7 +697,20 @@ def _linearize_contact_model( parameters: ViscoElasticContactsParams, terrain: Terrain, ) -> tuple[jtp.Matrix, jtp.Vector]: - """""" + """ + Linearize the Hunt/Crossley contact model at the initial state. + + Args: + position: The position of the contact point. + velocity: The velocity of the contact point. + tangential_deformation: The tangential deformation of the contact point. + parameters: The parameters of the contact model. + terrain: The considered terrain. + + Returns: + A tuple containing the `A` matrix and the `b` vector of the linear system + corresponding to the contact dynamics linearized at the initial state. + """ # Initialize the state at which the model is linearized. p0 = jnp.array(position, dtype=float).squeeze() @@ -969,58 +991,67 @@ def step( assert isinstance(model.contact_model, ViscoElasticContacts) assert isinstance(data.contacts_params, ViscoElasticContactsParams) + # Compute the contact forces in inertial-fixed representation. + # TODO: understand what's wrong in other representations. + data_inertial_fixed = data.replace( + velocity_representation=jaxsim.VelRepr.Inertial, validate=False + ) + + # Create the references object. + references = js.references.JaxSimModelReferences.build( + model=model, + data=data, + link_forces=link_forces, + joint_force_references=joint_force_references, + velocity_representation=data.velocity_representation, + ) + # Initialize the time step. dt = dt if dt is not None else model.time_step # Compute the contact forces with the exponential integrator. - W_f̅_C, (W_f̿_C, m_tf) = model.contact_model.compute_contact_forces( + W_f̅_C, aux_data = model.contact_model.compute_contact_forces( model=model, - data=data, + data=data_inertial_fixed, dt=jnp.array(dt).astype(float), - link_forces=link_forces, - joint_force_references=joint_force_references, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), ) + # Extract the final material deformation and the average of average forces + # from the dictionary containing auxiliary data. + m_tf = aux_data["m_tf"] + W_f̿_C = aux_data["W_f_avg2_C"] + # =============================== # Compute the link contact forces # =============================== - # Extract the indices corresponding to the enabled collidable points. - # The visco-elastic contact model computed only their contact forces. - indices_of_enabled_collidable_points = ( - model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points - ) + # Get the link contact forces by summing the forces of contact points belonging + # to the same link. + W_f̅_L, W_f̿_L = jax.vmap( + lambda W_f_C: model.contact_model.link_forces_from_contact_forces( + model=model, data=data_inertial_fixed, contact_forces=W_f_C + ) + )(jnp.stack([W_f̅_C, W_f̿_C])) # Compute the link transforms. - W_H_L = js.model.forward_kinematics(model=model, data=data) - - # Construct the vector defining the parent link index of each collidable point. - # We use this vector to sum the 6D forces of all collidable points rigidly - # attached to the same link. - parent_link_index_of_collidable_points = jnp.array( - model.kin_dyn_parameters.contact_parameters.body, dtype=int - )[indices_of_enabled_collidable_points] - - # Create the mask that associate each collidable point to their parent link. - # We use this mask to sum the collidable points to the right link. - mask = parent_link_index_of_collidable_points[:, jnp.newaxis] == jnp.arange( - model.number_of_links() + W_H_L = ( + js.model.forward_kinematics(model=model, data=data) + if data.velocity_representation is not jaxsim.VelRepr.Inertial + else jnp.zeros(shape=(model.number_of_links(), 4, 4)) ) - # Sum the forces of all collidable points rigidly attached to a body. - # Since the contact forces W_f_C are expressed in the world frame, - # we don't need any coordinate transformation. - W_f̅_L = mask.T @ W_f̅_C - W_f̿_L = mask.T @ W_f̿_C - - # For integration purpose, we need these average of averages expressed in + # For integration purpose, we need the average of average forces expressed in # mixed representation. LW_f̿_L = jax.vmap( - lambda W_f_L, W_H_L: data.inertial_to_other_representation( - array=W_f_L, - other_representation=jaxsim.VelRepr.Mixed, - transform=W_H_L, - is_force=True, + lambda W_f_L, W_H_L: ( + ModelDataWithVelocityRepresentation.inertial_to_other_representation( + array=W_f_L, + other_representation=jaxsim.VelRepr.Mixed, + transform=W_H_L, + is_force=True, + ) ) )(W_f̿_L, W_H_L) @@ -1032,10 +1063,10 @@ def step( data_tf: js.data.JaxSimModelData = ( model.contact_model.integrate_data_with_average_contact_forces( model=model, - data=data, + data=data_inertial_fixed, dt=dt, - link_forces=link_forces, - joint_force_references=joint_force_references, + link_forces=references.link_forces(model=model, data=data), + joint_force_references=references.joint_force_references(model=model), average_link_contact_forces_inertial=W_f̅_L, average_of_average_link_contact_forces_mixed=LW_f̿_L, ) @@ -1046,10 +1077,21 @@ def step( # be much more accurate than the one computed with the discrete soft contacts. with data_tf.mutable_context(): + # Extract the indices corresponding to the enabled collidable points. + # The visco-elastic contact model computed only their contact forces. + indices_of_enabled_collidable_points = ( + model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points + ) + data_tf.state.extended |= { "tangential_deformation": data_tf.state.extended["tangential_deformation"] .at[indices_of_enabled_collidable_points] .set(m_tf) } + # Restore the original velocity representation. + data_tf = data_tf.replace( + velocity_representation=data.velocity_representation, validate=False + ) + return data_tf, {}