Skip to content

Commit

Permalink
Use jax.pure_callback to throw errors
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Jun 10, 2024
1 parent 13a377d commit 35e6811
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,19 +179,22 @@ def link_forces(
# serialization.
if model is None:

def not_inertial():
def inertial():
if link_names is not None:
raise ValueError("Link names cannot be provided without a model")

return self.input.physics_model.f_ext

jax.lax.cond(
return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
true_fun=not_inertial,
false_fun=lambda: (_ for _ in (None,)).throw(
ValueError(
"Missing model to use a representation different from `VelRepr.Inertial`"
)
true_fun=inertial,
false_fun=lambda: jax.pure_callback(
callback=lambda: (_ for _ in ()).throw(
ValueError(
"Missing model to use a representation different from `VelRepr.Inertial`"
)
),
result_shape_dtypes=self.input.physics_model.f_ext,
),
)

Expand Down Expand Up @@ -233,7 +236,10 @@ def convert(W_f_L: jtp.MatrixLike, W_H_L: jtp.ArrayLike) -> jtp.Matrix:
return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
true_fun=lambda: W_f_L[link_idxs, :],
false_fun=not_inertial,
false_fun=lambda: jax.pure_callback(
callback=not_inertial,
result_shape_dtypes=W_f_L[link_idxs, :],
),
)

def joint_force_references(
Expand Down Expand Up @@ -382,7 +388,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
# using the implicit link serialization.
if model is None:

def not_inertial():
def inertial():
if link_names is not None:
raise ValueError("Link names cannot be provided without a model")

Expand All @@ -398,11 +404,14 @@ def not_inertial():

jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
true_fun=not_inertial,
false_fun=lambda: (_ for _ in (None,)).throw(
ValueError(
"Missing model to use a representation different from `VelRepr.Inertial`"
)
true_fun=inertial,
false_fun=lambda: jax.pure_callback(
callback=lambda: (_ for _ in ()).throw(
ValueError(
"Missing model to use a representation different from `VelRepr.Inertial`"
)
),
result_shape_dtypes=self,
),
)

Expand Down Expand Up @@ -435,8 +444,7 @@ def inertial():
)
)

def not_inertial():

def not_inertial(data):
if data is None:
raise ValueError(
"Missing model data to use a representation different from `VelRepr.Inertial`"
Expand Down Expand Up @@ -473,5 +481,8 @@ def convert_using_link_frame(
return jax.lax.cond(
pred=(self.velocity_representation == VelRepr.Inertial),
true_fun=inertial,
false_fun=not_inertial,
false_fun=lambda: jax.experimental.io_callback(
callback=not_inertial,
result_shape_dtypes=self,
),
)

0 comments on commit 35e6811

Please sign in to comment.