Skip to content

Commit

Permalink
Add Integrator option in JaxSimModel
Browse files Browse the repository at this point in the history
  • Loading branch information
flferretti committed Feb 14, 2025
1 parent 938d9ce commit 7a76ec7
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion src/jaxsim/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import copy
import dataclasses
import enum
import functools
import pathlib
from collections.abc import Sequence
Expand All @@ -23,6 +24,13 @@
from .common import VelRepr


class Integrator(enum.IntEnum):
"""The integrators available for the simulation."""

SemiImplicitEuler = enum.auto()
RungeKutta4 = enum.auto()


@jax_dataclasses.pytree_dataclass(eq=False, unsafe_hash=False)
class JaxSimModel(JaxsimDataclass):
"""
Expand Down Expand Up @@ -55,6 +63,10 @@ class JaxSimModel(JaxsimDataclass):
dataclasses.field(default=None, repr=False)
)

integrator: Static[Integrator] = dataclasses.field(
default=Integrator.SemiImplicitEuler, repr=False
)

built_from: Static[str | pathlib.Path | rod.Model | None] = dataclasses.field(
default=None, repr=False
)
Expand Down Expand Up @@ -111,6 +123,7 @@ def build_from_model_description(
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
contact_params: jaxsim.rbda.contacts.ContactsParams | None = None,
integrator: Integrator | None = None,
is_urdf: bool | None = None,
considered_joints: Sequence[str] | None = None,
) -> JaxSimModel:
Expand All @@ -131,6 +144,7 @@ def build_from_model_description(
The contact model to consider.
If not specified, a soft contacts model is used.
contact_params: The parameters of the contact model.
integrator: The integrator to use for the simulation.
is_urdf:
The optional flag to force the model description to be parsed as a URDF.
This is usually automatically inferred.
Expand Down Expand Up @@ -164,6 +178,7 @@ def build_from_model_description(
terrain=terrain,
contact_model=contact_model,
contacts_params=contact_params,
integrator=integrator,
)

# Store the origin of the model, in case downstream logic needs it.
Expand All @@ -182,6 +197,7 @@ def build(
terrain: jaxsim.terrain.Terrain | None = None,
contact_model: jaxsim.rbda.contacts.ContactModel | None = None,
contacts_params: jaxsim.rbda.contacts.ContactsParams | None = None,
integrator: Integrator | None = None,
gravity: jtp.FloatLike = jaxsim.math.STANDARD_GRAVITY,
) -> JaxSimModel:
"""
Expand All @@ -202,6 +218,7 @@ def build(
The contact model to consider.
If not specified, a soft contacts model is used.
contacts_params: The parameters of the soft contacts.
integrator: The integrator to use for the simulation.
gravity: The gravity constant.
Returns:
Expand Down Expand Up @@ -237,6 +254,13 @@ def build(
if contacts_params is None:
contacts_params = contact_model._parameters_class()

# Consider the default integrator if not specified.
integrator = (
integrator
if integrator is not None
else JaxSimModel.__dataclass_fields__["integrator"].default
)

# Build the model.
model = cls(
model_name=model_name,
Expand All @@ -247,6 +271,7 @@ def build(
terrain=terrain,
contact_model=contact_model,
contacts_params=contacts_params,
integrator=integrator,
gravity=gravity,
# The following is wrapped as hashless since it's a static argument, and we
# don't want to trigger recompilation if it changes. All relevant parameters
Expand Down Expand Up @@ -449,6 +474,7 @@ def reduce(
contact_model=model.contact_model,
contacts_params=model.contacts_params,
gravity=model.gravity,
integrator=model.integrator,
)

# Store the origin of the model, in case downstream logic needs it.
Expand Down Expand Up @@ -2075,12 +2101,21 @@ def step(
# =============================
# Advance the simulation state
# =============================
from .integrators import _INTEGRATORS_MAP

integrator_fn = _INTEGRATORS_MAP[model.integrator]

data_tf = js.integrators.semi_implicit_euler_integration(
data_tf = integrator_fn(
model=model,
data=data,
base_acceleration_inertial=W_v̇_WB,
joint_accelerations=,
# Pass link_forces and joint_torques if the integrator is rk4
**(
{"link_forces": W_f_L_total, "joint_torques": τ_total}
if model.integrator == js.integrators.rk4_integration
else {}
),
)

return data_tf

0 comments on commit 7a76ec7

Please sign in to comment.