diff --git a/src/jaxsim/rbda/contacts/rigid.py b/src/jaxsim/rbda/contacts/rigid.py index 86b64e022..26fefa5b7 100644 --- a/src/jaxsim/rbda/contacts/rigid.py +++ b/src/jaxsim/rbda/contacts/rigid.py @@ -415,11 +415,16 @@ def _compute_ineq_constraint_matrix( return G -@jax.jit -@js.common.named_scope def _compute_ineq_bounds(n_collidable_points: int) -> jtp.Vector: + """ + Compute the inequality bounds for the contact forces. + + Note: + Do not JIT this function as the output shape depends on the number of collidable points. + """ n_constraints = 6 * n_collidable_points + return jnp.zeros(shape=(n_constraints,))