Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor SoftContacts algorithm #245

Merged
merged 9 commits into from
Sep 27, 2024
4 changes: 3 additions & 1 deletion src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def collidable_point_dynamics(
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
position=W_p_Ci,
velocity=W_ṗ_Ci,
tangential_deformation=data.state.contact.tangential_deformation,
)
aux_data = dict(m_dot=CW_ṁ)

Expand Down
8 changes: 8 additions & 0 deletions src/jaxsim/rbda/contacts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .common import ContactModel, ContactsParams, ContactsState
from .relaxed_rigid import (
RelaxedRigidContacts,
RelaxedRigidContactsParams,
RelaxedRigidContactsState,
)
from .rigid import RigidContacts, RigidContactsParams, RigidContactsState
from .soft import SoftContacts, SoftContactsParams, SoftContactsState
23 changes: 15 additions & 8 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass

try:
from typing import Self
except ImportError:
from typing_extensions import Self

class ContactsState(abc.ABC):

class ContactsState(JaxsimDataclass):
"""
Abstract class storing the state of the contacts model.
"""

@classmethod
@abc.abstractmethod
def build(cls, **kwargs) -> ContactsState:
def build(cls: type[Self], **kwargs) -> Self:
"""
Build the contact state object.

Expand All @@ -26,7 +31,7 @@ def build(cls, **kwargs) -> ContactsState:

@classmethod
@abc.abstractmethod
def zero(cls, **kwargs) -> ContactsState:
def zero(cls: type[Self], **kwargs) -> Self:
"""
Build a zero contact state.

Expand All @@ -36,7 +41,7 @@ def zero(cls, **kwargs) -> ContactsState:
pass

@abc.abstractmethod
def valid(self, **kwargs) -> bool:
def valid(self, **kwargs) -> jtp.BoolLike:
"""
Check if the contacts state is valid.
"""
Expand All @@ -50,18 +55,20 @@ class ContactsParams(JaxsimDataclass):

@classmethod
@abc.abstractmethod
def build(cls) -> ContactsParams:
def build(cls: type[Self], **kwargs) -> Self:
"""
Create a `ContactsParams` instance with specified parameters.

Returns:
The `ContactsParams` instance.
"""
pass

@abc.abstractmethod
def valid(self, *args, **kwargs) -> bool:
def valid(self, **kwargs) -> jtp.BoolLike:
"""
Check if the parameters are valid.

Returns:
True if the parameters are valid, False otherwise.
"""
Expand All @@ -83,8 +90,8 @@ class ContactModel(JaxsimDataclass):
@abc.abstractmethod
def compute_contact_forces(
self,
position: jtp.Vector,
velocity: jtp.Vector,
position: jtp.VectorLike,
velocity: jtp.VectorLike,
**kwargs,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
"""
Expand Down
34 changes: 22 additions & 12 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

from .common import ContactModel, ContactsParams, ContactsState

try:
from typing import Self
except ImportError:
from typing_extensions import Self


@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(ContactsParams):
Expand Down Expand Up @@ -106,7 +111,8 @@ def __eq__(self, other: RelaxedRigidContactsParams) -> bool:

@classmethod
def build(
cls,
cls: type[Self],
*,
time_constant: jtp.FloatLike | None = None,
damping_coefficient: jtp.FloatLike | None = None,
d_min: jtp.FloatLike | None = None,
Expand All @@ -119,7 +125,7 @@ def build(
mu: jtp.FloatLike | None = None,
max_iterations: jtp.IntLike | None = None,
tolerance: jtp.FloatLike | None = None,
) -> RelaxedRigidContactsParams:
) -> Self:
"""Create a `RelaxedRigidContactsParams` instance"""

return cls(
Expand All @@ -132,7 +138,8 @@ def build(
}
)

def valid(self) -> bool:
def valid(self) -> jtp.BoolLike:

return bool(
jnp.all(self.time_constant >= 0.0)
and jnp.all(self.damping_coefficient > 0.0)
Expand All @@ -155,18 +162,19 @@ class RelaxedRigidContactsState(ContactsState):
def __eq__(self, other: RelaxedRigidContactsState) -> bool:
return hash(self) == hash(other)

@staticmethod
def build() -> RelaxedRigidContactsState:
@classmethod
def build(cls: type[Self]) -> Self:
"""Create a `RelaxedRigidContactsState` instance"""

return RelaxedRigidContactsState()
return cls()

@staticmethod
def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
@classmethod
def zero(cls: type[Self]) -> Self:
"""Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
return RelaxedRigidContactsState.build()

def valid(self, model: js.model.JaxSimModel) -> bool:
return cls.build()

def valid(self, *, model: js.model.JaxSimModel) -> jtp.BoolLike:
return True


Expand All @@ -182,10 +190,12 @@ class RelaxedRigidContacts(ContactModel):
default_factory=FlatTerrain
)

@jax.jit
def compute_contact_forces(
self,
position: jtp.Vector,
velocity: jtp.Vector,
position: jtp.VectorLike,
velocity: jtp.VectorLike,
*,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
Expand Down
44 changes: 28 additions & 16 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@

from .common import ContactModel, ContactsParams, ContactsState

try:
from typing import Self
except ImportError:
from typing_extensions import Self


@jax_dataclasses.pytree_dataclass
class RigidContactsParams(ContactsParams):
Expand Down Expand Up @@ -50,19 +55,22 @@ def __eq__(self, other: RigidContactsParams) -> bool:

@classmethod
def build(
cls,
cls: type[Self],
*,
mu: jtp.FloatLike | None = None,
K: jtp.FloatLike | None = None,
D: jtp.FloatLike | None = None,
) -> RigidContactsParams:
) -> Self:
"""Create a `RigidContactParams` instance"""
return RigidContactsParams(

return cls(
mu=mu or cls.__dataclass_fields__["mu"].default,
K=K or cls.__dataclass_fields__["K"].default,
D=D or cls.__dataclass_fields__["D"].default,
)

def valid(self) -> bool:
def valid(self) -> jtp.BoolLike:

return bool(
jnp.all(self.mu >= 0.0)
and jnp.all(self.K >= 0.0)
Expand All @@ -77,18 +85,19 @@ class RigidContactsState(ContactsState):
def __eq__(self, other: RigidContactsState) -> bool:
return hash(self) == hash(other)

@staticmethod
def build(**kwargs) -> RigidContactsState:
@classmethod
def build(cls: type[Self]) -> Self:
"""Create a `RigidContactsState` instance"""

return RigidContactsState()
return cls()

@staticmethod
def zero(**kwargs) -> RigidContactsState:
@classmethod
def zero(cls: type[Self]) -> Self:
"""Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
return RigidContactsState.build()

def valid(self, **kwargs) -> bool:
return cls.build()

def valid(self) -> jtp.BoolLike:
return True


Expand Down Expand Up @@ -117,7 +126,8 @@ def detect_contacts(
terrain_height: The height of the terrain at the collidable point position.

Returns:
A tuple containing the activation state of the collidable points and the contact penetration depth h.
A tuple containing the activation state of the collidable points
and the contact penetration depth h.
"""

# TODO: reduce code duplication with js.contact.in_contact
Expand Down Expand Up @@ -154,8 +164,8 @@ def compute_impact_velocity(

Args:
inactive_collidable_points: The activation state of the collidable points.
M: The mass matrix of the system.
J_WC: The Jacobian matrix of the collidable points.
M: The mass matrix of the system (in mixed representation).
J_WC: The Jacobian matrix of the collidable points (in mixed representation).
data: The `JaxSimModelData` instance.
"""

Expand Down Expand Up @@ -206,10 +216,12 @@ def impact_velocity(

return BW_ν_post_impact

@jax.jit
def compute_contact_forces(
self,
position: jtp.Vector,
velocity: jtp.Vector,
position: jtp.VectorLike,
velocity: jtp.VectorLike,
*,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
Expand Down
Loading