Skip to content

Commit

Permalink
Disable some exception while tracing
Browse files Browse the repository at this point in the history
  • Loading branch information
diegoferigo committed Mar 14, 2024
1 parent 0ae0806 commit 1836c9a
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 6 deletions.
15 changes: 13 additions & 2 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import jaxsim.rbda
import jaxsim.typing as jtp
from jaxsim.utils import Mutability
from jaxsim.utils.tracing import not_tracing

from . import common
from .common import VelRepr
Expand Down Expand Up @@ -260,9 +261,14 @@ def joint_positions(
"""

if model is None:
if joint_names is not None:
raise ValueError("Joint names cannot be provided without a model")

return self.state.physics_model.joint_positions

if not self.valid(model=model):
if not_tracing(self.state.physics_model.joint_positions) and not self.valid(
model=model
):
msg = "The data object is not compatible with the provided model"
raise ValueError(msg)

Expand Down Expand Up @@ -300,9 +306,14 @@ def joint_velocities(
"""

if model is None:
if joint_names is not None:
raise ValueError("Joint names cannot be provided without a model")

return self.state.physics_model.joint_velocities

if not self.valid(model=model):
if not_tracing(self.state.physics_model.joint_velocities) and not self.valid(
model=model
):
msg = "The data object is not compatible with the provided model"
raise ValueError(msg)

Expand Down
9 changes: 5 additions & 4 deletions src/jaxsim/api/references.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.simulation.ode_data import ODEInput
from jaxsim.utils.tracing import not_tracing

from .common import VelRepr

Expand Down Expand Up @@ -198,7 +199,7 @@ def link_forces(
msg = "Missing model data to use a representation different from {}"
raise ValueError(msg.format(VelRepr.Inertial.name))

if not data.valid(model=model):
if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model):
raise ValueError("The provided data is not valid for the model")

# Helper function to convert a single 6D force to the active representation.
Expand Down Expand Up @@ -252,7 +253,7 @@ def joint_force_references(

return self.input.physics_model.tau

if not self.valid(model=model):
if not_tracing(self.input.physics_model.tau) and not self.valid(model=model):
msg = "The actuation object is not compatible with the provided model"
raise ValueError(msg)

Expand Down Expand Up @@ -303,7 +304,7 @@ def replace(forces: jtp.VectorLike) -> JaxSimModelReferences:
if model is None:
return replace(forces=forces)

if not self.valid(model=model):
if not_tracing(forces) and not self.valid(model=model):
msg = "The references object is not compatible with the provided model"
raise ValueError(msg)

Expand Down Expand Up @@ -401,7 +402,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences:
msg = "Missing model data to use a representation different from {}"
raise ValueError(msg.format(VelRepr.Inertial.name))

if not data.valid(model=model):
if not_tracing(forces) and not data.valid(model=model):
raise ValueError("The provided data is not valid for the model")

# Helper function to convert a single 6D force to the inertial representation.
Expand Down

0 comments on commit 1836c9a

Please sign in to comment.