From 1836c9a930e7345d402b4bf44292322673abec48 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 14 Mar 2024 12:48:27 +0100 Subject: [PATCH] Disable some exception while tracing --- src/jaxsim/api/data.py | 15 +++++++++++++-- src/jaxsim/api/references.py | 9 +++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/jaxsim/api/data.py b/src/jaxsim/api/data.py index 926098b06..d67e2e305 100644 --- a/src/jaxsim/api/data.py +++ b/src/jaxsim/api/data.py @@ -14,6 +14,7 @@ import jaxsim.rbda import jaxsim.typing as jtp from jaxsim.utils import Mutability +from jaxsim.utils.tracing import not_tracing from . import common from .common import VelRepr @@ -260,9 +261,14 @@ def joint_positions( """ if model is None: + if joint_names is not None: + raise ValueError("Joint names cannot be provided without a model") + return self.state.physics_model.joint_positions - if not self.valid(model=model): + if not_tracing(self.state.physics_model.joint_positions) and not self.valid( + model=model + ): msg = "The data object is not compatible with the provided model" raise ValueError(msg) @@ -300,9 +306,14 @@ def joint_velocities( """ if model is None: + if joint_names is not None: + raise ValueError("Joint names cannot be provided without a model") + return self.state.physics_model.joint_velocities - if not self.valid(model=model): + if not_tracing(self.state.physics_model.joint_velocities) and not self.valid( + model=model + ): msg = "The data object is not compatible with the provided model" raise ValueError(msg) diff --git a/src/jaxsim/api/references.py b/src/jaxsim/api/references.py index c955d9135..e9c122bbf 100644 --- a/src/jaxsim/api/references.py +++ b/src/jaxsim/api/references.py @@ -9,6 +9,7 @@ import jaxsim.api as js import jaxsim.typing as jtp from jaxsim.simulation.ode_data import ODEInput +from jaxsim.utils.tracing import not_tracing from .common import VelRepr @@ -198,7 +199,7 @@ def link_forces( msg = "Missing model data to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) - if not data.valid(model=model): + if not_tracing(self.input.physics_model.f_ext) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") # Helper function to convert a single 6D force to the active representation. @@ -252,7 +253,7 @@ def joint_force_references( return self.input.physics_model.tau - if not self.valid(model=model): + if not_tracing(self.input.physics_model.tau) and not self.valid(model=model): msg = "The actuation object is not compatible with the provided model" raise ValueError(msg) @@ -303,7 +304,7 @@ def replace(forces: jtp.VectorLike) -> JaxSimModelReferences: if model is None: return replace(forces=forces) - if not self.valid(model=model): + if not_tracing(forces) and not self.valid(model=model): msg = "The references object is not compatible with the provided model" raise ValueError(msg) @@ -401,7 +402,7 @@ def replace(forces: jtp.MatrixLike) -> JaxSimModelReferences: msg = "Missing model data to use a representation different from {}" raise ValueError(msg.format(VelRepr.Inertial.name)) - if not data.valid(model=model): + if not_tracing(forces) and not data.valid(model=model): raise ValueError("The provided data is not valid for the model") # Helper function to convert a single 6D force to the inertial representation.