Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Filippo Luca Ferretti <[email protected]>
  • Loading branch information
diegoferigo and flferretti committed Sep 27, 2024
1 parent cc7b010 commit 33e0900
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 10 deletions.
6 changes: 3 additions & 3 deletions src/jaxsim/rbda/contacts/relaxed_rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,13 @@ def __eq__(self, other: RelaxedRigidContactsState) -> bool:
def build(cls: type[Self]) -> Self:
"""Create a `RelaxedRigidContactsState` instance"""

return RelaxedRigidContactsState()
return cls()

@classmethod
def zero(cls: type[Self], *, model: js.model.JaxSimModel) -> Self:
def zero(cls: type[Self]) -> Self:
"""Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`."""

return RelaxedRigidContactsState.build()
return cls.build()

def valid(self, *, model: js.model.JaxSimModel) -> jtp.BoolLike:
return True
Expand Down
6 changes: 3 additions & 3 deletions src/jaxsim/rbda/contacts/rigid.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def build(
) -> Self:
"""Create a `RigidContactParams` instance"""

return RigidContactsParams(
return cls(
mu=mu or cls.__dataclass_fields__["mu"].default,
K=K or cls.__dataclass_fields__["K"].default,
D=D or cls.__dataclass_fields__["D"].default,
Expand All @@ -89,13 +89,13 @@ def __eq__(self, other: RigidContactsState) -> bool:
def build(cls: type[Self]) -> Self:
"""Create a `RigidContactsState` instance"""

return RigidContactsState()
return cls()

@classmethod
def zero(cls: type[Self]) -> Self:
"""Build a zero `RigidContactsState` instance from a `JaxSimModel`."""

return RigidContactsState.build()
return cls.build()

def valid(self) -> jtp.BoolLike:
return True
Expand Down
10 changes: 6 additions & 4 deletions src/jaxsim/rbda/contacts/soft.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def valid(self) -> jtp.BoolLike:
return jnp.hstack(
[
self.K >= 0.0,
self.D >= 0,
self.D >= 0.0,
self.mu >= 0.0,
self.p >= 0.0,
self.q >= 0.0,
Expand Down Expand Up @@ -310,9 +310,11 @@ def compute_contact_forces(
contact_status += (δ <= 0).astype(int)

# Select the right material deformation rate depending on the contact status.
= jax.lax.switch(
index=contact_status,
branches=(lambda: ṁ_slipping, lambda: ṁ_sticking, lambda: ṁ_no_contact),
= jax.lax.select_n(
ṁ_slipping,
ṁ_sticking,
ṁ_no_contact,
which=contact_status,
)

# ==========================================
Expand Down

0 comments on commit 33e0900

Please sign in to comment.