Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove *ContactsState classes #256

Merged
merged 11 commits into from
Oct 7, 2024
15 changes: 3 additions & 12 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,8 @@ def collidable_point_dynamics(
The joint force references to apply to the joints.

Returns:
The 6D force applied to each collidable point and additional data based on the contact model configured:
The 6D force applied to each collidable point and additional data based
on the contact model configured:
- Soft: the material deformation rate.
- Rigid: no additional data.
- QuasiRigid: no additional data.
Expand All @@ -156,21 +157,13 @@ def collidable_point_dynamics(
"""

# Import privately the contacts classes.
from jaxsim.rbda.contacts import (
RelaxedRigidContacts,
RelaxedRigidContactsState,
RigidContacts,
RigidContactsState,
SoftContacts,
SoftContactsState,
)
from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts

# Build the soft contact model.
match model.contact_model:

case SoftContacts():
assert isinstance(model.contact_model, SoftContacts)
assert isinstance(data.state.contact, SoftContactsState)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point, and the corresponding material deformation rate.
Expand All @@ -187,7 +180,6 @@ def collidable_point_dynamics(

case RigidContacts():
assert isinstance(model.contact_model, RigidContacts)
assert isinstance(data.state.contact, RigidContactsState)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
Expand All @@ -203,7 +195,6 @@ def collidable_point_dynamics(

case RelaxedRigidContacts():
assert isinstance(model.contact_model, RelaxedRigidContacts)
assert isinstance(data.state.contact, RelaxedRigidContactsState)

# Compute the 6D force expressed in the inertial frame and applied to each
# collidable point.
Expand Down
106 changes: 62 additions & 44 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import jaxsim.math
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.rbda.contacts import SoftContacts
from jaxsim.utils import Mutability
from jaxsim.utils.tracing import not_tracing

Expand Down Expand Up @@ -107,17 +106,17 @@ def zero(
@staticmethod
def build(
model: js.model.JaxSimModel,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
joint_positions: jtp.Vector | None = None,
base_linear_velocity: jtp.Vector | None = None,
base_angular_velocity: jtp.Vector | None = None,
joint_velocities: jtp.Vector | None = None,
base_position: jtp.VectorLike | None = None,
base_quaternion: jtp.VectorLike | None = None,
joint_positions: jtp.VectorLike | None = None,
base_linear_velocity: jtp.VectorLike | None = None,
base_angular_velocity: jtp.VectorLike | None = None,
joint_velocities: jtp.VectorLike | None = None,
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
contact: jaxsim.rbda.contacts.ContactsState | None = None,
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
extended_ode_state: dict[str, jtp.PyTree] | None = None,
) -> JaxSimModelData:
"""
Create a `JaxSimModelData` object with the given state.
Expand All @@ -133,56 +132,73 @@ def build(
The base angular velocity in the selected representation.
joint_velocities: The joint velocities.
standard_gravity: The standard gravity constant.
contact: The state of the soft contacts.
contacts_params: The parameters of the soft contacts.
velocity_representation: The velocity representation to use.
time: The time at which the state is created.
extended_ode_state:
Additional user-defined state variables that are not part of the
standard `ODEState` object. Useful to extend the system dynamics
considered by default in JaxSim.

Returns:
A `JaxSimModelData` object with the given state.
A `JaxSimModelData` initialized with the given state.
"""

base_position = jnp.array(
base_position if base_position is not None else jnp.zeros(3)
base_position if base_position is not None else jnp.zeros(3),
dtype=float,
).squeeze()

base_quaternion = jnp.array(
base_quaternion
if base_quaternion is not None
else jnp.array([1.0, 0, 0, 0])
(
base_quaternion
if base_quaternion is not None
else jnp.array([1.0, 0, 0, 0])
),
dtype=float,
).squeeze()

base_linear_velocity = jnp.array(
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3)
base_linear_velocity if base_linear_velocity is not None else jnp.zeros(3),
dtype=float,
).squeeze()

base_angular_velocity = jnp.array(
base_angular_velocity if base_angular_velocity is not None else jnp.zeros(3)
(
base_angular_velocity
if base_angular_velocity is not None
else jnp.zeros(3)
),
dtype=float,
).squeeze()

gravity = jnp.zeros(3).at[2].set(-standard_gravity)

joint_positions = jnp.atleast_1d(
joint_positions.squeeze()
if joint_positions is not None
else jnp.zeros(model.dofs())
jnp.array(
(
joint_positions
if joint_positions is not None
else jnp.zeros(model.dofs())
),
dtype=float,
).squeeze()
)

joint_velocities = jnp.atleast_1d(
joint_velocities.squeeze()
if joint_velocities is not None
else jnp.zeros(model.dofs())
jnp.array(
(
joint_velocities
if joint_velocities is not None
else jnp.zeros(model.dofs())
),
dtype=float,
).squeeze()
)

time_ns = (
jnp.array(
time * 1e9,
dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
)
if time is not None
else jnp.array(
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
)
time_ns = jnp.array(
time * 1e9 if time is not None else 0.0,
dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
)

W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
Expand All @@ -194,21 +210,22 @@ def build(
other_representation=velocity_representation,
transform=W_H_B,
is_force=False,
)
).astype(float)

ode_state = ODEState.build_from_jaxsim_model(
model=model,
base_position=base_position.astype(float),
base_quaternion=base_quaternion.astype(float),
joint_positions=joint_positions.astype(float),
base_linear_velocity=v_WB[0:3].astype(float),
base_angular_velocity=v_WB[3:6].astype(float),
joint_velocities=joint_velocities.astype(float),
tangential_deformation=(
contact.tangential_deformation
if contact is not None and isinstance(model.contact_model, SoftContacts)
else None
),
base_position=base_position,
base_quaternion=base_quaternion,
joint_positions=joint_positions,
base_linear_velocity=v_WB[0:3],
base_angular_velocity=v_WB[3:6],
joint_velocities=joint_velocities,
# Unpack all the additional ODE states. If the contact model requires an
# additional state that is not explicitly passed to this builder, ODEState
# automatically populates that state with zeroed variables.
# This is not true for any other custom state that the user might want to
# pass to the integrator.
**(extended_ode_state if extended_ode_state else {}),
)

if not ode_state.valid(model=model):
Expand All @@ -220,13 +237,14 @@ def build(
contacts_params = js.contact.estimate_good_soft_contacts_parameters(
model=model, standard_gravity=standard_gravity
)

else:
contacts_params = model.contact_model.parameters

return JaxSimModelData(
time_ns=time_ns,
state=ode_state,
gravity=gravity.astype(float),
gravity=gravity,
contacts_params=contacts_params,
velocity_representation=velocity_representation,
)
Expand Down
45 changes: 28 additions & 17 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class JaxSimModel(JaxsimDataclass):
model_name: Static[str]

terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default=jaxsim.terrain.FlatTerrain(), repr=False
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
)

contact_model: jaxsim.rbda.contacts.ContactModel | None = dataclasses.field(
Expand Down Expand Up @@ -101,13 +101,14 @@ def build_from_model_description(
A path to an SDF/URDF file, a string containing
its content, or a pre-parsed/pre-built rod model.
model_name:
The optional name of the model that overrides the one in
the description.
terrain:
The optional terrain to consider.
The name of the model. If not specified, it is read from the description.
terrain: The terrain to consider (the default is a flat infinite plane).
contact_model:
The contact model to consider.
If not specified, a soft contacts model is used.
is_urdf:
The optional flag to force the model description to be parsed as a
URDF or a SDF. This is otherwise automatically inferred.
The optional flag to force the model description to be parsed as a URDF.
This is usually automatically inferred.
considered_joints:
The list of joints to consider. If None, all joints are considered.

Expand All @@ -120,7 +121,7 @@ def build_from_model_description(
# Parse the input resource (either a path to file or a string with the URDF/SDF)
# and build the -intermediate- model description.
intermediate_description = jaxsim.parsers.rod.build_model_description(
model_description=model_description
model_description=model_description, is_urdf=is_urdf
)

# Lump links together if not all joints are considered.
Expand Down Expand Up @@ -160,11 +161,11 @@ def build(
The intermediate model description defining the kinematics and dynamics
of the model.
model_name:
The optional name of the model overriding the physics model name.
terrain:
The optional terrain to consider.
The name of the model. If not specified, it is read from the description.
terrain: The terrain to consider (the default is a flat infinite plane).
contact_model:
The optional contact model to consider. If None, the soft contact model is used.
The contact model to consider.
If not specified, a soft contacts model is used.

Returns:
The built Model object.
Expand All @@ -173,21 +174,31 @@ def build(
# Set the model name (if not provided, use the one from the model description).
model_name = model_name if model_name is not None else model_description.name

# Set the terrain (if not provided, use the default flat terrain).
terrain = terrain or JaxSimModel.__dataclass_fields__["terrain"].default
contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts(
terrain=terrain
# Consider the default terrain (a flat infinite plane) if not specified.
terrain = (
terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory()
)

# Create the default contact model.
# It will be populated with an initial estimation of good parameters.
# While these might not be the best, they are a good starting point.
contact_model = contact_model or jaxsim.rbda.contacts.SoftContacts.build(
terrain=terrain, parameters=None
)

# Build the model.
model = JaxSimModel(
model_name=model_name,
_description=wrappers.HashlessObject(obj=model_description),
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
terrain=terrain,
contact_model=contact_model,
# The following is wrapped as hashless since it's a static argument, and we
# don't want to trigger recompilation if it changes. All relevant parameters
# needed to compute kinematics and dynamics quantities are stored in the
# kin_dyn_parameters attribute.
_description=wrappers.HashlessObject(obj=model_description),
)

return model
Expand Down
16 changes: 9 additions & 7 deletions src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,9 +370,8 @@ def system_dynamics(
corresponding derivative, and the dictionary of auxiliary data returned
by the system dynamics evaluation.
"""
from jaxsim.rbda.contacts.relaxed_rigid import RelaxedRigidContacts
from jaxsim.rbda.contacts.rigid import RigidContacts
from jaxsim.rbda.contacts.soft import SoftContacts

from jaxsim.rbda.contacts import RelaxedRigidContacts, RigidContacts, SoftContacts

# Compute the accelerations and the material deformation rate.
W_v̇_WB, s̈, aux_dict = system_velocity_dynamics(
Expand All @@ -382,17 +381,20 @@ def system_dynamics(
link_forces=link_forces,
)

ode_state_kwargs = {}
# Initialize the dictionary storing the derivative of the additional state variables
# that extend the state vector of the integrated ODE system.
extended_ode_state = {}

match model.contact_model:

case SoftContacts():
ode_state_kwargs["tangential_deformation"] = aux_dict["m_dot"]
extended_ode_state["tangential_deformation"] = aux_dict["m_dot"]

case RigidContacts() | RelaxedRigidContacts():
pass

case _:
raise ValueError("Unable to determine contact state class prefix.")
raise ValueError(f"Invalid contact model {model.contact_model}")

# Extract the velocities.
W_ṗ_B, W_Q̇_B, ṡ = system_position_dynamics(
Expand All @@ -412,7 +414,7 @@ def system_dynamics(
base_linear_velocity=W_v̇_WB[0:3],
base_angular_velocity=W_v̇_WB[3:6],
joint_velocities=s̈,
**ode_state_kwargs,
**extended_ode_state,
)

return ode_state_derivative, aux_dict
Loading