From 8b9a69539ed16ee253052bf32aa5bf64161063f5 Mon Sep 17 00:00:00 2001 From: Diego Ferigo Date: Fri, 27 Sep 2024 09:13:48 +0200 Subject: [PATCH] Apply suggestions from code review Co-authored-by: Filippo Luca Ferretti <102977828+flferretti@users.noreply.github.com> --- src/jaxsim/rbda/contacts/relaxed_rigid.py | 6 +++--- src/jaxsim/rbda/contacts/rigid.py | 6 +++--- src/jaxsim/rbda/contacts/soft.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/jaxsim/rbda/contacts/relaxed_rigid.py b/src/jaxsim/rbda/contacts/relaxed_rigid.py index 4225aaa79..c1fb3da9e 100644 --- a/src/jaxsim/rbda/contacts/relaxed_rigid.py +++ b/src/jaxsim/rbda/contacts/relaxed_rigid.py @@ -166,13 +166,13 @@ def __eq__(self, other: RelaxedRigidContactsState) -> bool: def build(cls: type[Self]) -> Self: """Create a `RelaxedRigidContactsState` instance""" - return RelaxedRigidContactsState() + return cls() @classmethod - def zero(cls: type[Self], *, model: js.model.JaxSimModel) -> Self: + def zero(cls: type[Self]) -> Self: """Build a zero `RelaxedRigidContactsState` instance from a `JaxSimModel`.""" - return RelaxedRigidContactsState.build() + return cls.build() def valid(self, *, model: js.model.JaxSimModel) -> jtp.BoolLike: return True diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index b0b3acb5e..428d500b8 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -63,7 +63,7 @@ def build( ) -> Self: """Create a `RigidContactParams` instance""" - return RigidContactsParams( + return cls( mu=mu or cls.__dataclass_fields__["mu"].default, K=K or cls.__dataclass_fields__["K"].default, D=D or cls.__dataclass_fields__["D"].default, @@ -89,13 +89,13 @@ def __eq__(self, other: RigidContactsState) -> bool: def build(cls: type[Self]) -> Self: """Create a `RigidContactsState` instance""" - return RigidContactsState() + return cls() @classmethod def zero(cls: type[Self]) -> Self: """Build a zero `RigidContactsState` instance from a `JaxSimModel`.""" - return RigidContactsState.build() + return cls.build() def valid(self) -> jtp.BoolLike: return True diff --git a/src/jaxsim/rbda/contacts/soft.py b/src/jaxsim/rbda/contacts/soft.py index 7ff186532..c7318df79 100644 --- a/src/jaxsim/rbda/contacts/soft.py +++ b/src/jaxsim/rbda/contacts/soft.py @@ -178,7 +178,7 @@ def valid(self) -> jtp.BoolLike: return jnp.hstack( [ self.K >= 0.0, - self.D >= 0, + self.D >= 0.0, self.mu >= 0.0, self.p >= 0.0, self.q >= 0.0, @@ -310,9 +310,9 @@ def compute_contact_forces( contact_status += (δ <= 0).astype(int) # Select the right material deformation rate depending on the contact status. - ṁ = jax.lax.switch( - index=contact_status, - branches=(lambda: ṁ_slipping, lambda: ṁ_sticking, lambda: ṁ_no_contact), + ṁ = jax.lax.select_n( + ṁ_slipping, ṁ_sticking, ṁ_no_contact, + which=contact_status, ) # ==========================================