Skip to content

Commit

Permalink
Merge pull request #245 from ami-iit/refactor_soft_contacts
Browse files Browse the repository at this point in the history
Refactor `SoftContacts` algorithm
  • Loading branch information
diegoferigo authored Sep 27, 2024
2 parents 8080f3c + d59d2c2 commit a4d1476
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 217 deletions.
4 changes: 3 additions & 1 deletion src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@ def collidable_point_dynamics(
# Note that the material deformation rate is always returned in the mixed frame
# C[W] = (W_p_C, [W]). This is convenient for integration purpose.
W_f_Ci, (CW_ṁ,) = jax.vmap(soft_contacts.compute_contact_forces)(
W_p_Ci, W_ṗ_Ci, data.state.contact.tangential_deformation
position=W_p_Ci,
velocity=W_ṗ_Ci,
tangential_deformation=data.state.contact.tangential_deformation,
)
aux_data = dict(m_dot=CW_ṁ)

Expand Down
8 changes: 8 additions & 0 deletions src/jaxsim/rbda/contacts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .common import ContactModel, ContactsParams, ContactsState
from .relaxed_rigid import (
RelaxedRigidContacts,
RelaxedRigidContactsParams,
RelaxedRigidContactsState,
)
from .rigid import RigidContacts, RigidContactsParams, RigidContactsState
from .soft import SoftContacts, SoftContactsParams, SoftContactsState
23 changes: 15 additions & 8 deletions src/jaxsim/rbda/contacts/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
import jaxsim.typing as jtp
from jaxsim.utils import JaxsimDataclass

try:
from typing import Self
except ImportError:
from typing_extensions import Self

class ContactsState(abc.ABC):

class ContactsState(JaxsimDataclass):
"""
Abstract class storing the state of the contacts model.
"""

@classmethod
@abc.abstractmethod
def build(cls, **kwargs) -> ContactsState:
def build(cls: type[Self], **kwargs) -> Self:
"""
Build the contact state object.
Expand All @@ -26,7 +31,7 @@ def build(cls, **kwargs) -> ContactsState:

@classmethod
@abc.abstractmethod
def zero(cls, **kwargs) -> ContactsState:
def zero(cls: type[Self], **kwargs) -> Self:
"""
Build a zero contact state.
Expand All @@ -36,7 +41,7 @@ def zero(cls, **kwargs) -> ContactsState:
pass

@abc.abstractmethod
def valid(self, **kwargs) -> bool:
def valid(self, **kwargs) -> jtp.BoolLike:
"""
Check if the contacts state is valid.
"""
Expand All @@ -50,18 +55,20 @@ class ContactsParams(JaxsimDataclass):

@classmethod
@abc.abstractmethod
def build(cls) -> ContactsParams:
def build(cls: type[Self], **kwargs) -> Self:
"""
Create a `ContactsParams` instance with specified parameters.
Returns:
The `ContactsParams` instance.
"""
pass

@abc.abstractmethod
def valid(self, *args, **kwargs) -> bool:
def valid(self, **kwargs) -> jtp.BoolLike:
"""
Check if the parameters are valid.
Returns:
True if the parameters are valid, False otherwise.
"""
Expand All @@ -83,8 +90,8 @@ class ContactModel(JaxsimDataclass):
@abc.abstractmethod
def compute_contact_forces(
self,
position: jtp.Vector,
velocity: jtp.Vector,
position: jtp.VectorLike,
velocity: jtp.VectorLike,
**kwargs,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
"""
Expand Down
34 changes: 22 additions & 12 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@

from .common import ContactModel, ContactsParams, ContactsState

try:
from typing import Self
except ImportError:
from typing_extensions import Self


@jax_dataclasses.pytree_dataclass
class RelaxedRigidContactsParams(ContactsParams):
Expand Down Expand Up @@ -106,7 +111,8 @@ def __eq__(self, other: RelaxedRigidContactsParams) -> bool:

@classmethod
def build(
cls,
cls: type[Self],
*,
time_constant: jtp.FloatLike | None = None,
damping_coefficient: jtp.FloatLike | None = None,
d_min: jtp.FloatLike | None = None,
Expand All @@ -119,7 +125,7 @@ def build(
mu: jtp.FloatLike | None = None,
max_iterations: jtp.IntLike | None = None,
tolerance: jtp.FloatLike | None = None,
) -> RelaxedRigidContactsParams:
) -> Self:
"""Create a `RelaxedRigidContactsParams` instance"""

return cls(
Expand All @@ -132,7 +138,8 @@ def build(
}
)

def valid(self) -> bool:
def valid(self) -> jtp.BoolLike:

return bool(
jnp.all(self.time_constant >= 0.0)
and jnp.all(self.damping_coefficient > 0.0)
Expand All @@ -155,18 +162,19 @@ class RelaxedRigidContactsState(ContactsState):
def __eq__(self, other: RelaxedRigidContactsState) -> bool:
return hash(self) == hash(other)

@staticmethod
def build() -> RelaxedRigidContactsState:
@classmethod
def build(cls: type[Self]) -> Self:
"""Create a `RelaxedRigidContactsState` instance"""

return RelaxedRigidContactsState()
return cls()

@staticmethod
def zero(model: js.model.JaxSimModel) -> RelaxedRigidContactsState:
@classmethod
def zero(cls: type[Self]) -> Self:
"""Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""
return RelaxedRigidContactsState.build()

def valid(self, model: js.model.JaxSimModel) -> bool:
return cls.build()

def valid(self, *, model: js.model.JaxSimModel) -> jtp.BoolLike:
return True


Expand All @@ -182,10 +190,12 @@ class RelaxedRigidContacts(ContactModel):
default_factory=FlatTerrain
)

@jax.jit
def compute_contact_forces(
self,
position: jtp.Vector,
velocity: jtp.Vector,
position: jtp.VectorLike,
velocity: jtp.VectorLike,
*,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
Expand Down
44 changes: 28 additions & 16 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@

from .common import ContactModel, ContactsParams, ContactsState

try:
from typing import Self
except ImportError:
from typing_extensions import Self


@jax_dataclasses.pytree_dataclass
class RigidContactsParams(ContactsParams):
Expand Down Expand Up @@ -50,19 +55,22 @@ def __eq__(self, other: RigidContactsParams) -> bool:

@classmethod
def build(
cls,
cls: type[Self],
*,
mu: jtp.FloatLike | None = None,
K: jtp.FloatLike | None = None,
D: jtp.FloatLike | None = None,
) -> RigidContactsParams:
) -> Self:
"""Create a `RigidContactParams` instance"""
return RigidContactsParams(

return cls(
mu=mu or cls.__dataclass_fields__["mu"].default,
K=K or cls.__dataclass_fields__["K"].default,
D=D or cls.__dataclass_fields__["D"].default,
)

def valid(self) -> bool:
def valid(self) -> jtp.BoolLike:

return bool(
jnp.all(self.mu >= 0.0)
and jnp.all(self.K >= 0.0)
Expand All @@ -77,18 +85,19 @@ class RigidContactsState(ContactsState):
def __eq__(self, other: RigidContactsState) -> bool:
return hash(self) == hash(other)

@staticmethod
def build(**kwargs) -> RigidContactsState:
@classmethod
def build(cls: type[Self]) -> Self:
"""Create a `RigidContactsState` instance"""

return RigidContactsState()
return cls()

@staticmethod
def zero(**kwargs) -> RigidContactsState:
@classmethod
def zero(cls: type[Self]) -> Self:
"""Build a zero `RigidContactsState` instance from a `JaxSimModel`."""
return RigidContactsState.build()

def valid(self, **kwargs) -> bool:
return cls.build()

def valid(self) -> jtp.BoolLike:
return True


Expand Down Expand Up @@ -117,7 +126,8 @@ def detect_contacts(
terrain_height: The height of the terrain at the collidable point position.
Returns:
A tuple containing the activation state of the collidable points and the contact penetration depth h.
A tuple containing the activation state of the collidable points
and the contact penetration depth h.
"""

# TODO: reduce code duplication with js.contact.in_contact
Expand Down Expand Up @@ -154,8 +164,8 @@ def compute_impact_velocity(
Args:
inactive_collidable_points: The activation state of the collidable points.
M: The mass matrix of the system.
J_WC: The Jacobian matrix of the collidable points.
M: The mass matrix of the system (in mixed representation).
J_WC: The Jacobian matrix of the collidable points (in mixed representation).
data: The `JaxSimModelData` instance.
"""

Expand Down Expand Up @@ -206,10 +216,12 @@ def impact_velocity(

return BW_ν_post_impact

@jax.jit
def compute_contact_forces(
self,
position: jtp.Vector,
velocity: jtp.Vector,
position: jtp.VectorLike,
velocity: jtp.VectorLike,
*,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
link_forces: jtp.MatrixLike | None = None,
Expand Down
Loading

0 comments on commit a4d1476

Please sign in to comment.