diff --git a/src/jaxsim/api/kin_dyn_parameters.py b/src/jaxsim/api/kin_dyn_parameters.py index 41764823a..6055558d3 100644 --- a/src/jaxsim/api/kin_dyn_parameters.py +++ b/src/jaxsim/api/kin_dyn_parameters.py @@ -805,11 +805,14 @@ def build_from(model_description: ModelDescription) -> ContactParameters: "false", "0", }: - points, idxs = jnp.unique(all_points, axis=0, return_index=True) - selected_points = map(collidable_points.__getitem__, idxs) + _, idxs = jnp.unique(all_points, axis=0, return_index=True) + idxs = jnp.sort(idxs) else: - points = all_points - selected_points = iter(collidable_points) + idxs = jnp.arange(len(collidable_points)) + + # Select unique points and corresponding link indices. + points = all_points[idxs] + selected_points = map(collidable_points.__getitem__, idxs) # Extract the indices of the links to which the collidable points are rigidly attached. link_index_of_points = tuple(