Skip to content

Commit

Permalink
Merge pull request #78 from ami-iit/flferretti-patch-2
Browse files Browse the repository at this point in the history
Enhance maintainability and performance
  • Loading branch information
flferretti authored Feb 5, 2024
2 parents 0cdfa79 + d3c65ac commit e2ea0e8
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 128 deletions.
19 changes: 5 additions & 14 deletions src/jaxsim/high_level/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,8 @@ def valid(self) -> jtp.Bool:
""""""

valid = True
valid = valid and all([l.valid() for l in self.links()])
valid = valid and all([j.valid() for j in self.joints()])
valid = valid and all(l.valid() for l in self.links())
valid = valid and all(j.valid() for j in self.joints())
return jnp.array(valid, dtype=bool)

@functools.partial(oop.jax_tf.method_ro, jit=False)
Expand Down Expand Up @@ -1414,27 +1414,18 @@ def integrate(
physics_model=self.data.model_input
)

if integrator_type is IntegratorType.EulerForward:
integrator_fn = ode_integration.ode_integration_euler

elif integrator_type is IntegratorType.EulerSemiImplicit:
integrator_fn = ode_integration.ode_integration_euler_semi_implicit

elif integrator_type is IntegratorType.RungeKutta4:
integrator_fn = ode_integration.ode_integration_rk4

else:
raise ValueError(integrator_type)
assert isinstance(integrator_type, IntegratorType)

# Integrate the model dynamics
ode_states, aux = integrator_fn(
ode_states, aux = ode_integration.ode_integration_fixed_step(
x0=x0,
t=jnp.array([t0, tf], dtype=float),
ode_input=ode_input,
physics_model=self.physics_model,
soft_contacts_params=contact_parameters,
num_sub_steps=sub_steps,
terrain=terrain,
integrator_type=integrator_type,
return_aux=True,
)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/parsers/kinematic_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __post_init__(self):

# Check that joint indices are unique
assert len([j.index for j in self.joints]) == len(
set([j.index for j in self.joints])
{j.index for j in self.joints}
)

# Order joints with their indices
Expand Down Expand Up @@ -268,7 +268,7 @@ def reduce(self, considered_joints: List[str]) -> "KinematicGraph":

# Check if all considered joints are part of the full kinematic graph
if len(set(considered_joints) - set(j.name for j in full_graph.joints)) != 0:
extra_j = set(considered_joints) - set([j.name for j in full_graph.joints])
extra_j = set(considered_joints) - {j.name for j in full_graph.joints}
msg = f"Not all joints to consider are part of the graph ({{{extra_j}}})"
raise ValueError(msg)

Expand Down
4 changes: 2 additions & 2 deletions src/jaxsim/parsers/rod/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from jaxsim.math.quaternion import Quaternion
from jaxsim.parsers import descriptions, kinematic_graph

from . import utils as utils
from . import utils


class SDFData(NamedTuple):
Expand Down Expand Up @@ -357,6 +357,6 @@ def build_model_description(
)

# Store the parsed SDF tree as extra info
model = dataclasses.replace(model, extra_info=dict(sdf_model=sdf_data.sdf_model))
model = dataclasses.replace(model, extra_info={"sdf_model": sdf_data.sdf_model})

return model
18 changes: 9 additions & 9 deletions src/jaxsim/simulation/ode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -277,14 +277,14 @@ def fix_one_dof(vector: jtp.Vector) -> jtp.Vector | None:
W_nud_WB = jnp.hstack([W_a_WB.squeeze(), qdd.squeeze()])

# Build the auxiliary data
aux_dict = dict(
model_acceleration=W_nud_WB,
ode_input=ode_input,
ode_input_real=ode_input_real,
contact_forces_links=contact_forces_links,
contact_forces_points=contact_forces_points,
tangential_deformation_dot=tangential_deformation_dot,
)
aux_dict = {
"model_acceleration": W_nud_WB,
"ode_input": ode_input,
"ode_input_real": ode_input_real,
"contact_forces_links": contact_forces_links,
"contact_forces_points": contact_forces_points,
"tangential_deformation_dot": tangential_deformation_dot,
}

# Return the state derivative as a generic PyTree, and the dict with auxiliary info
return state_derivative, aux_dict
110 changes: 14 additions & 96 deletions src/jaxsim/simulation/ode_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@ class IntegratorType(enum.IntEnum):
EulerSemiImplicitManifold = enum.auto()


_integrator_registry = {
IntegratorType.RungeKutta4: integrators.odeint_rk4,
IntegratorType.EulerForward: integrators.odeint_euler,
IntegratorType.EulerSemiImplicit: integrators.odeint_euler_semi_implicit,
IntegratorType.EulerSemiImplicitManifold: integrators.odeint_euler_semi_implicit_manifold_one_step,
}


@jax.jit
def ode_integration_rk4_adaptive(
x0: jtp.Array,
Expand All @@ -33,118 +41,28 @@ def ode_integration_rk4_adaptive(
return odeint(dx_dt_closure, x0, t, **kwargs)


@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_euler(
@functools.partial(
jax.jit, static_argnames=["num_sub_steps", "integrator_type", "return_aux"]
)
def ode_integration_fixed_step(
x0: ode.ode_data.ODEState,
t: integrators.TimeHorizon,
physics_model: PhysicsModel,
integrator_type: IntegratorType,
soft_contacts_params: SoftContactsParams = SoftContactsParams(),
terrain: Terrain = FlatTerrain(),
ode_input: ode.ode_data.ODEInput | None = None,
*args,
num_sub_steps: int = 1,
return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict[str, Any]]]:
# Close func over additional inputs and parameters
dx_dt_closure = lambda x, ts: ode.dx_dt(
x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args
)

# Integrate over the horizon
out = integrators.odeint_euler(
func=dx_dt_closure,
y0=x0,
t=t,
num_sub_steps=num_sub_steps,
return_aux=return_aux,
)

# Return output pytree and, optionally, the aux dict
state = out if not return_aux else out[0]
return (state, out[1]) if return_aux else state


@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_euler_semi_implicit(
x0: ode.ode_data.ODEState,
t: integrators.TimeHorizon,
physics_model: PhysicsModel,
soft_contacts_params: SoftContactsParams = SoftContactsParams(),
terrain: Terrain = FlatTerrain(),
ode_input: ode.ode_data.ODEInput | None = None,
*args,
num_sub_steps: int = 1,
return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict[str, Any]]]:
# Close func over additional inputs and parameters
dx_dt_closure = lambda x, ts: ode.dx_dt(
x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args
)

# Integrate over the horizon
out = integrators.odeint_euler_semi_implicit(
func=dx_dt_closure,
y0=x0,
t=t,
num_sub_steps=num_sub_steps,
return_aux=return_aux,
)

# Return output pytree and, optionally, the aux dict
state = out if not return_aux else out[0]
return (state, out[1]) if return_aux else state


@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_euler_semi_implicit_manifold(
x0: ode.ode_data.ODEState,
t: integrators.TimeHorizon,
physics_model: PhysicsModel,
soft_contacts_params: SoftContactsParams = SoftContactsParams(),
terrain: Terrain = FlatTerrain(),
ode_input: ode.ode_data.ODEInput = None,
*args,
num_sub_steps: int = 1,
return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict[str, Any]]]:
# Close func over additional inputs and parameters
dx_dt_closure = lambda x, ts: ode.dx_dt(
x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args
)

# Integrate over the horizon
out = integrators.odeint_euler_semi_implicit_manifold(
func=dx_dt_closure,
y0=x0,
t=t,
num_sub_steps=num_sub_steps,
return_aux=return_aux,
)

# Return output pytree and, optionally, the aux dict
state = out if not return_aux else out[0]
return (state, out[1]) if return_aux else state


@functools.partial(jax.jit, static_argnames=["num_sub_steps", "return_aux"])
def ode_integration_rk4(
x0: ode.ode_data.ODEState,
t: integrators.TimeHorizon,
physics_model: PhysicsModel,
soft_contacts_params: SoftContactsParams = SoftContactsParams(),
terrain: Terrain = FlatTerrain(),
ode_input: ode.ode_data.ODEInput | None = None,
*args,
num_sub_steps=1,
return_aux: bool = False,
) -> Union[ode.ode_data.ODEState, Tuple[ode.ode_data.ODEState, Dict]]:
# Close func over additional inputs and parameters
dx_dt_closure = lambda x, ts: ode.dx_dt(
x, ts, physics_model, soft_contacts_params, ode_input, terrain, *args
)

# Integrate over the horizon
out = integrators.odeint_rk4(
out = _integrator_registry[integrator_type](
func=dx_dt_closure,
y0=x0,
t=t,
Expand Down
10 changes: 5 additions & 5 deletions src/jaxsim/simulation/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
from jax_dataclasses import Static

import jaxsim.high_level
import jaxsim.parsers.descriptions as descriptions
import jaxsim.physics
import jaxsim.typing as jtp
from jaxsim import logging
from jaxsim.high_level.common import VelRepr
from jaxsim.high_level.model import Model, StepData
from jaxsim.parsers import descriptions
from jaxsim.physics.algos.soft_contacts import SoftContactsParams
from jaxsim.physics.algos.terrain import FlatTerrain, Terrain
from jaxsim.physics.model.physics_model import PhysicsModel
Expand Down Expand Up @@ -129,7 +129,7 @@ def reset(self, remove_models: bool = True) -> None:
self.data.time_ns = jnp.zeros_like(self.data.time_ns)

if remove_models:
self.data.models = dict()
self.data.models = {}
else:
_ = [m.zero() for m in self.models()]

Expand Down Expand Up @@ -213,7 +213,7 @@ def get_model(self, model_name: str) -> Model:
The model with the given name.
"""

if model_name not in self.data.models.keys():
if model_name not in self.data.models:
raise ValueError(f"Failed to find model '{model_name}'")

return self.data.models[model_name]
Expand Down Expand Up @@ -250,7 +250,7 @@ def set_gravity(self, gravity: jtp.Vector) -> None:

self.data.gravity = gravity

for model_name, model in self.data.models.items():
for model in self.data.models.values():
model.physics_model.set_gravity(gravity=gravity)

@functools.partial(oop.jax_tf.method_rw, jit=False, vmap=False, validate=False)
Expand Down Expand Up @@ -400,7 +400,7 @@ def step(self, clear_inputs: bool = False) -> Dict[str, StepData]:
tf_ns = t0_ns + dt_ns

# We collect the StepData of all models
step_data = dict()
step_data = {}

for model in self.models():
# Integrate individually all models and collect their StepData.
Expand Down

0 comments on commit e2ea0e8

Please sign in to comment.