Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove absolute time from JaxSimModelData and introduce JaxSimModel.time_step #262

Merged
merged 7 commits into from
Oct 11, 2024
25 changes: 0 additions & 25 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,6 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):

contacts_params: jaxsim.rbda.contacts.ContactsParams = dataclasses.field(repr=False)

time_ns: jtp.Int = dataclasses.field(
default_factory=lambda: jnp.array(
0, dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32
),
)

def __hash__(self) -> int:

from jaxsim.utils.wrappers import HashedNumpyArray
Expand All @@ -52,7 +46,6 @@ def __hash__(self) -> int:
(
hash(self.state),
HashedNumpyArray.hash_of_array(self.gravity),
HashedNumpyArray.hash_of_array(self.time_ns),
hash(self.contacts_params),
)
)
Expand Down Expand Up @@ -115,7 +108,6 @@ def build(
standard_gravity: jtp.FloatLike = jaxsim.math.StandardGravity,
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
velocity_representation: VelRepr = VelRepr.Inertial,
time: jtp.FloatLike | None = None,
extended_ode_state: dict[str, jtp.PyTree] | None = None,
) -> JaxSimModelData:
"""
Expand All @@ -134,7 +126,6 @@ def build(
standard_gravity: The standard gravity constant.
contacts_params: The parameters of the soft contacts.
velocity_representation: The velocity representation to use.
time: The time at which the state is created.
extended_ode_state:
Additional user-defined state variables that are not part of the
standard `ODEState` object. Useful to extend the system dynamics
Expand Down Expand Up @@ -196,11 +187,6 @@ def build(
).squeeze()
)

time_ns = jnp.array(
time * 1e9 if time is not None else 0.0,
dtype=jnp.uint64 if jax.config.read("jax_enable_x64") else jnp.uint32,
)

W_H_B = jaxsim.math.Transform.from_quaternion_and_translation(
translation=base_position, quaternion=base_quaternion
)
Expand Down Expand Up @@ -246,7 +232,6 @@ def build(
contacts_params = model.contact_model.parameters

return JaxSimModelData(
time_ns=time_ns,
state=ode_state,
gravity=gravity,
contacts_params=contacts_params,
Expand All @@ -257,16 +242,6 @@ def build(
# Extract quantities
# ==================

def time(self) -> jtp.Float:
"""
Get the simulated time.

Returns:
The simulated time in seconds.
"""

return self.time_ns.astype(float) / 1e9

def standard_gravity(self) -> jtp.Float:
"""
Get the standard gravity constant.
Expand Down
67 changes: 38 additions & 29 deletions src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class JaxSimModel(JaxsimDataclass):

model_name: Static[str]

time_step: jaxsim.integrators.TimeStep = dataclasses.field(
default_factory=lambda: jnp.array(0.001, dtype=float),
)

terrain: Static[jaxsim.terrain.Terrain] = dataclasses.field(
default_factory=jaxsim.terrain.FlatTerrain.build, repr=False
)
Expand Down Expand Up @@ -64,6 +68,9 @@ def __eq__(self, other: JaxSimModel) -> bool:
if self.model_name != other.model_name:
return False

if self.time_step != other.time_step:
return False

if self.kin_dyn_parameters != other.kin_dyn_parameters:
return False

Expand All @@ -74,6 +81,7 @@ def __hash__(self) -> int:
return hash(
(
hash(self.model_name),
hash(float(self.time_step)),
hash(self.kin_dyn_parameters),
hash(self.contact_model),
)
Expand All @@ -88,6 +96,7 @@ def build_from_model_description(
model_description: str | pathlib.Path | rod.Model,
model_name: str | None = None,
*,
time_step: jtp.FloatLike | None = None,
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
is_urdf: bool | None = None,
Expand All @@ -102,6 +111,9 @@ def build_from_model_description(
its content, or a pre-parsed/pre-built rod model.
model_name:
The name of the model. If not specified, it is read from the description.
time_step:
The default time step to consider for the simulation. It can be
manually overridden in the function that steps the simulation.
terrain: The terrain to consider (the default is a flat infinite plane).
contact_model:
The contact model to consider.
Expand Down Expand Up @@ -135,6 +147,7 @@ def build_from_model_description(
model = JaxSimModel.build(
model_description=intermediate_description,
model_name=model_name,
time_step=time_step,
terrain=terrain,
contact_model=contact_model,
)
Expand All @@ -150,6 +163,7 @@ def build(
model_description: ModelDescription,
model_name: str | None = None,
*,
time_step: jtp.FloatLike | None = None,
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
) -> JaxSimModel:
Expand All @@ -162,6 +176,9 @@ def build(
of the model.
model_name:
The name of the model. If not specified, it is read from the description.
time_step:
The default time step to consider for the simulation. It can be
manually overridden in the function that steps the simulation.
terrain: The terrain to consider (the default is a flat infinite plane).
contact_model:
The contact model to consider.
Expand All @@ -179,6 +196,11 @@ def build(
terrain or JaxSimModel.__dataclass_fields__["terrain"].default_factory()
)

# Consider the default time step if not specified.
time_step = (
time_step or JaxSimModel.__dataclass_fields__["time_step"].default_factory()
)

# Create the default contact model.
# It will be populated with an initial estimation of good parameters.
# While these might not be the best, they are a good starting point.
Expand All @@ -192,6 +214,7 @@ def build(
kin_dyn_parameters=js.kin_dyn_parameters.KynDynParameters.build(
model_description=model_description
),
time_step=time_step,
terrain=terrain,
contact_model=contact_model,
# The following is wrapped as hashless since it's a static argument, and we
Expand Down Expand Up @@ -1915,8 +1938,9 @@ def step(
model: JaxSimModel,
data: js.data.JaxSimModelData,
*,
dt: jtp.FloatLike,
integrator: jaxsim.integrators.Integrator,
t0: jtp.FloatLike = 0.0,
dt: jtp.FloatLike | None = None,
integrator_state: dict[str, Any] | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
Expand All @@ -1928,9 +1952,10 @@ def step(
Args:
model: The model to consider.
data: The data of the considered model.
dt: The time step to consider.
integrator: The integrator to use.
integrator_state: The state of the integrator.
t0: The initial time to consider. Only relevant for time-dependent dynamics.
dt: The time step to consider. If not specified, it is read from the model.
link_forces:
The 6D forces to apply to the links expressed in the frame corresponding to
the velocity representation of `data`.
Expand All @@ -1951,17 +1976,20 @@ def step(

integrator_state = integrator_state if integrator_state is not None else dict()

# Extract the initial resources.
t0_ns = data.time_ns
# Initialize the time-related variables.
state_t0 = data.state
integrator_state_x0 = integrator_state
t0 = jnp.array(t0, dtype=float)
dt = jnp.array(dt if dt is not None else model.time_step).astype(float)

# Rename the integrator state.
integrator_state_t0 = integrator_state

# Step the dynamics forward.
state_tf, integrator_state_tf = integrator.step(
x0=state_t0,
t0=jnp.array(t0_ns / 1e9).astype(float),
t0=t0,
dt=dt,
params=integrator_state_x0,
params=integrator_state_t0,
# Always inject the current (model, data) pair into the system dynamics
# considered by the integrator, and include the input variables represented
# by the pair (joint_force_references, link_forces).
Expand All @@ -1980,24 +2008,8 @@ def step(
),
)

tf_ns = t0_ns + jnp.array(dt * 1e9, dtype=t0_ns.dtype)
tf_ns = jnp.where(tf_ns >= t0_ns, tf_ns, jnp.array(0, dtype=t0_ns.dtype))

jax.lax.cond(
pred=tf_ns < t0_ns,
true_fun=lambda: jax.debug.print(
"The simulation time overflowed, resetting simulation time to 0."
),
false_fun=lambda: None,
)

data_tf = (
# Store the new state of the model and the new time.
data.replace(
state=state_tf,
time_ns=tf_ns,
)
)
# Store the new state of the model.
data_tf = data.replace(state=state_tf)

# Post process the simulation state, if needed.
match model.contact_model:
Expand Down Expand Up @@ -2064,7 +2076,4 @@ def step(
velocity_representation=data.velocity_representation, validate=False
)

return (
data_tf,
integrator_state_tf,
)
return data_tf, integrator_state_tf
1 change: 0 additions & 1 deletion src/jaxsim/api/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def f(x: ODEState, t: Time, **kwargs_f) -> tuple[ODEState, dict[str, Any]]:
# Update the state and time stored inside data.
with data_f.editable(validate=True) as data_rw:
data_rw.state = x
data_rw.time_ns = jnp.array(t * 1e9).astype(data_rw.time_ns.dtype)

# Evaluate the system dynamics, allowing to override the kwargs originally
# passed when the closure was created.
Expand Down
24 changes: 13 additions & 11 deletions src/jaxsim/rbda/contacts/visco_elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def compute_contact_forces(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
dt: jtp.FloatLike,
dt: jtp.FloatLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> tuple[jtp.Vector, tuple[Any, ...]]:
Expand All @@ -273,7 +273,7 @@ def compute_contact_forces(
Args:
model: The robot model considered by the contact model.
data: The data of the considered model.
dt: The integration time step.
dt: The time step to consider. If not specified, it is read from the model.
link_forces:
The 6D forces to apply to the links expressed in the frame corresponding
to the velocity representation of `data`.
Expand Down Expand Up @@ -305,13 +305,16 @@ def compute_contact_forces(
model.kin_dyn_parameters.contact_parameters.indices_of_enabled_collidable_points
)

# Initialize the time step.
dt = dt if dt is not None else model.time_step

# Compute the average contact linear forces in mixed representation by
# integrating the contact dynamics in the continuous time domain.
CW_f̅l, CW_fl̿, m_tf = (
ViscoElasticContacts._compute_contact_forces_with_exponential_integration(
model=model,
data=data,
dt=dt,
dt=jnp.array(dt).astype(float),
joint_force_references=joint_force_references,
link_forces=link_forces,
indices_of_enabled_collidable_points=indices_of_enabled_collidable_points,
Expand Down Expand Up @@ -923,14 +926,10 @@ def integrate_data_with_average_contact_forces(
)

# Create the data at the final time.
with data.editable(validate=True) as data_tf:
data_tf: js.data.JaxSimModelData
data_tf.time_ns = data.time_ns + (dt * 1e9).astype(data.time_ns.dtype)

data_tf = data.copy()
data_tf = data_tf.reset_joint_positions(q_plus[7:])
data_tf = data_tf.reset_base_position(q_plus[0:3])
data_tf = data_tf.reset_base_quaternion(q_plus[3:7])

data_tf = data_tf.reset_joint_velocities(W_ν_plus[6:])
data_tf = data_tf.reset_base_velocity(
W_ν_plus[0:6], velocity_representation=jaxsim.VelRepr.Inertial
Expand All @@ -946,7 +945,7 @@ def step(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
dt: jtp.FloatLike,
dt: jtp.FloatLike | None = None,
link_forces: jtp.MatrixLike | None = None,
joint_force_references: jtp.VectorLike | None = None,
) -> tuple[js.data.JaxSimModelData, dict[str, Any]]:
Expand All @@ -956,7 +955,7 @@ def step(
Args:
model: The model to consider.
data: The data of the considered model.
dt: The time step to consider.
dt: The time step to consider. If not specified, it is read from the model.
link_forces:
The 6D forces to apply to the links expressed in the frame corresponding to
the velocity representation of `data`.
Expand All @@ -970,11 +969,14 @@ def step(
assert isinstance(model.contact_model, ViscoElasticContacts)
assert isinstance(data.contacts_params, ViscoElasticContactsParams)

# Initialize the time step.
dt = dt if dt is not None else model.time_step

# Compute the contact forces with the exponential integrator.
W_f̅_C, (W_f̿_C, m_tf) = model.contact_model.compute_contact_forces(
model=model,
data=data,
dt=dt,
dt=jnp.array(dt).astype(float),
link_forces=link_forces,
joint_force_references=joint_force_references,
)
Expand Down
14 changes: 4 additions & 10 deletions tests/test_simulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ def test_box_with_external_forces(
# Initialize the integrator.
tf = 0.5
dt = 0.001
T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
T_ns = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt)

# Copy the initial data...
data = data0.copy()

# ... and step the simulation.
for t_ns in T:
for _ in T_ns:

data, integrator_state = js.model.step(
model=model,
Expand All @@ -89,7 +89,6 @@ def test_box_with_external_forces(
)

# Check that the box didn't move.
assert data.time() == t_ns / 1e9 + dt
assert data.base_position() == pytest.approx(data0.base_position())
assert data.base_orientation() == pytest.approx(data0.base_orientation())

Expand Down Expand Up @@ -158,16 +157,14 @@ def test_box_with_zero_gravity(

# Initialize the integrator.
tf, dt = 1.0, 0.010
T = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
T_ns = jnp.arange(start=0, stop=tf * 1e9, step=dt * 1e9, dtype=int)
integrator_state = integrator.init(x0=data0.state, t0=0.0, dt=dt)

# Copy the initial data...
data = data0.copy()

# ... and step the simulation.
for t_ns in T:

assert data.time() == t_ns / 1e9
for _ in T_ns:

with (
data.switch_velocity_representation(velocity_representation),
Expand All @@ -183,9 +180,6 @@ def test_box_with_zero_gravity(
link_forces=references.link_forces(model=model, data=data),
)

# Check the final simulation time.
assert data.time() == T[-1] / 1e9 + dt

# Check that the box moved as expected.
assert data.base_position() == pytest.approx(
data0.base_position()
Expand Down