Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Sprint] Avoid evaluation of branches leading to NaNs in the optimizer #362

Closed
wants to merge 103 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
103 commits
Select commit Hold shift + click to select a range
7911bbb
erase integrators
younik Jan 20, 2025
3642e9a
fix import
younik Jan 20, 2025
88395a4
remotveintegrators
younik Jan 20, 2025
55f5c4a
restore heun2 integrator
younik Jan 20, 2025
002ad30
tentative change
younik Jan 20, 2025
f019cba
adding quaternion integration from axis angle
CarlottaSartore Jan 20, 2025
054107b
improve semi implicit euler and add velocity representation
CarlottaSartore Jan 20, 2025
342f788
tentative fix
younik Jan 21, 2025
d7cde88
Pass link and joint forces to `js.ode.system_dynamics`
flferretti Jan 21, 2025
deb9378
fix tests but one
younik Jan 21, 2025
42e2c77
Fix position derivative extraction in semi-implicit Euler
flferretti Jan 21, 2025
2091bd0
Add `pytest-benchmark` in environment YAML
flferretti Jan 20, 2025
68560c7
Remove `rigid.py`
xela-95 Jan 20, 2025
9e22325
Remove `soft.py`
xela-95 Jan 20, 2025
f7a37d1
Remove `visco_elastic.py`
xela-95 Jan 20, 2025
67615ac
Remove commented note about default contact model in JaxSimModel
xela-95 Jan 20, 2025
ef8c4b6
Update `jaxsim.rbda.contacts.__init__.py`
xela-95 Jan 20, 2025
ebe90de
Add `_hunt_crossley_contact_model` to initialize RelaxedRigidContacts
xela-95 Jan 20, 2025
5f337dc
Refactor `contact.py` to use only RelaxedRigidContacts
xela-95 Jan 20, 2025
0a7bd9f
Update `JaxSimModel` to use only RelaxedRigidContacts
xela-95 Jan 20, 2025
20078a7
Update `JaxSimModelData` to use only RelaxedRigidContacts
xela-95 Jan 20, 2025
696bd55
Refactor `ode.py` to simplify `system_dynamics` and `system_velocity_…
xela-95 Jan 20, 2025
e40d796
Remove unused soft contacts test and related parameters from automati…
xela-95 Jan 20, 2025
f9627ad
Refactor simulation tests to remove soft contacts and streamline cont…
xela-95 Jan 20, 2025
42e86b8
Create `contact_model.py` to contain all functions related to contact…
xela-95 Jan 21, 2025
d38a1a8
Add `contact_model` __init__.py
xela-95 Jan 21, 2025
b48e7f8
Remove `collidable_point_dynamics` function from contact module
xela-95 Jan 21, 2025
f7aa6a7
Remove `link_contact_forces` function from model API
xela-95 Jan 21, 2025
2a0903b
Remove compute_link_contact_forces and link_forces_from_contact_force…
xela-95 Jan 21, 2025
7be7794
Update link_contact_forces reference to use contact_model module
xela-95 Jan 21, 2025
721db59
Remove collidable_point_forces function from contact.py
xela-95 Jan 21, 2025
15a7ff1
Refactor `link_contact_forces` to compute contact forces in inertial …
xela-95 Jan 21, 2025
3a61ac6
Refactor `link_forces_from_contact_forces` to work only in inertial
xela-95 Jan 21, 2025
bb2afa3
Update docstring to clarify that computed contact forces are in inert…
xela-95 Jan 21, 2025
d7494de
Refactor heun2_integration to remove aux_dict
xela-95 Jan 21, 2025
0370e26
Refactor system_velocity_dynamics and system_dynamics to use inertial…
xela-95 Jan 21, 2025
9f4111b
Fix semi-implicit integrators
flferretti Jan 21, 2025
e3af2db
Fix pre-commit
flferretti Jan 21, 2025
3441f67
Improve error message for incompatible PyTree structures
flferretti Jan 21, 2025
ff88446
Update AD test to include aux dict and fix gradient check modes
flferretti Jan 21, 2025
76ce235
Merge pull request #347 from ami-iit/sprint/contacts
flferretti Jan 21, 2025
4a28d96
Remove unnecessary `VelRepr` switch from `rbda.contacts`
flferretti Jan 20, 2025
f122b70
Use inertial-fixed representation in `system_dynamics`
flferretti Jan 20, 2025
73f1e63
Remove unnecessary `VelRepr` switch in `forward_dynamics_aba`
flferretti Jan 20, 2025
125250a
Merge pull request #346 from ami-iit/refactor/algorithms
flferretti Jan 21, 2025
8dce161
remove ode_data
younik Jan 22, 2025
a13f592
add comments
younik Jan 22, 2025
d08770e
remove hashing
younik Jan 22, 2025
3963fb5
fix tests
younik Jan 22, 2025
c9898a9
Enable cache in `JaxSimModelData` and extract W_v_L in FK
flferretti Jan 22, 2025
213a1f6
WIP
flferretti Jan 22, 2025
d73d93c
WIP fix tests
flferretti Jan 22, 2025
fcd5738
WIP fix tests round 2
flferretti Jan 23, 2025
c592e23
Fix tests
flferretti Jan 23, 2025
a909469
Fix `test_model_creation_and_reduction`
flferretti Jan 23, 2025
8c0a62f
Avoid extra call to contact Jacobian
flferretti Jan 23, 2025
6c55234
Remove leftover `jax.disable_jit()`
flferretti Jan 23, 2025
8ac7139
Use `JointType` instead of integers in motion subspaces
flferretti Jan 23, 2025
39630ce
Fix pre-commit
flferretti Jan 23, 2025
1735f63
Fix formatting in tests
flferretti Jan 23, 2025
20fd82f
Merge pull request #348 from ami-iit/sprint/data
CarlottaSartore Jan 23, 2025
2071f5d
WIP
flferretti Jan 23, 2025
e58f20e
Remove Heun's method integration function from integrators.py
xela-95 Jan 23, 2025
482fd39
removed usage of reference
xela-95 Jan 23, 2025
8be38ef
Remove commented-out reference handling in system_velocity_dynamics
xela-95 Jan 23, 2025
a559020
Add contact_params to JaxSimModel `build`
xela-95 Jan 23, 2025
2f04f8c
wip collidable_points.py
xela-95 Jan 23, 2025
ab4ea12
Revert "wip collidable_points.py"
xela-95 Jan 23, 2025
c9c05dd
Add jaxlie import to joint_model.py
xela-95 Jan 23, 2025
1515425
Fix attribute access for base_transform in `semi_implicit_euler_integ…
xela-95 Jan 23, 2025
0410673
Update cached data in `test_automatic_differentiation`
xela-95 Jan 23, 2025
cab088b
Merge pull request #349 from ami-iit/sprint/cleanup-xela95
CarlottaSartore Jan 23, 2025
dcceef1
Refactor `collidable_points_pos_vel` to use cached link transforms an…
xela-95 Jan 23, 2025
e241e5f
Refactor `collidable_point_kinematics` to use link transforms and vel…
xela-95 Jan 23, 2025
9ce6018
Add initialization for terrain contact forces in `system_velocity_dyn…
xela-95 Jan 23, 2025
273baf3
Add `actuation_model.py` to compute resultant joint torques
xela-95 Jan 23, 2025
18948ff
Add `actuation_model` import to API module
xela-95 Jan 23, 2025
32e6ccb
Refactor `system_velocity_dynamics` and `system_acceleration`
xela-95 Jan 23, 2025
ce815ed
Rename `joint_force_references` to `joint_torques` in `RelaxedRigidCo…
xela-95 Jan 23, 2025
19a27b9
Refactor `semi_implicit_euler_integration`
xela-95 Jan 23, 2025
8d6631c
Refactor `step` function
xela-95 Jan 23, 2025
e9ad2f1
Refactor `test_box_with_zero_gravity` to use only `jaxsim.VelRepr.Ine…
xela-95 Jan 23, 2025
12f2019
Rename `joint_force_references` to `joint_torques` in `link_contact_f…
xela-95 Jan 23, 2025
096a5f4
Update `step` function changing order of torques and link forces comp…
xela-95 Jan 23, 2025
cd2381a
Rename `link_forces` parameter to `link_forces_inertial` in `step` fu…
xela-95 Jan 24, 2025
976ff4a
Merge pull request #350 from ami-iit/sprint/day4-xela95
xela-95 Jan 24, 2025
6e23c12
Add data.switch in `semi_implicit_euler_integration`
xela-95 Jan 24, 2025
0f5f298
Refactor `link_contact_forces` function by removing unused kwargs and…
xela-95 Jan 24, 2025
563f33a
Fix call to `link_contact_forces` in `step`
xela-95 Jan 24, 2025
15e402e
Merge pull request #353 from ami-iit/sprint/fix-xela95
CarlottaSartore Jan 24, 2025
5af6b84
fix device transfers
younik Jan 23, 2025
15e31c7
Disable exceptions by default
flferretti Jan 24, 2025
e974a30
Change flag name to enable exceptions
flferretti Jan 24, 2025
56d080e
Enable exceptions in tests
flferretti Jan 24, 2025
478660b
Fix argument name in AD integration test
flferretti Jan 24, 2025
7aff901
Save `time_step` and reduced joint positions as float
flferretti Jan 24, 2025
ccb7f27
Merge pull request #352 from ami-iit/sprint/gpu-transfert
flferretti Jan 24, 2025
8cf5fb2
Fix missing switch to mixed in RelaxedRigidContacts `compute_contact_…
xela-95 Jan 28, 2025
17264a8
Update src/jaxsim/rbda/contacts/relaxed_rigid.py
xela-95 Jan 28, 2025
060f14b
Merge pull request #356 from ami-iit/sprint/fix-relaxed-contacts
xela-95 Jan 30, 2025
487b14b
Reintroduce soft, viscoelastic and rigid contact models
flferretti Jan 28, 2025
82e8d9f
Streamline new API changes to alternative contact models
flferretti Jan 28, 2025
1f24739
Avoid evaluation of branches leading to NaNs in the optimizer
flferretti Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/guide/configuration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ The logging and exceptions configurations is controlled by the following environ

*Default:* ``DEBUG`` for development, ``WARNING`` for production.

- ``JAXSIM_DISABLE_EXCEPTIONS``: Disables the runtime checks and exceptions.
- ``JAXSIM_ENABLE_EXCEPTIONS``: Enables the runtime checks and exceptions. Note that enabling exceptions might lead to device-to-host transfer of data, increasing the computational time required.

*Default:* ``False``.

Expand Down
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies:
# [testing]
- idyntree >= 12.2.1
- pytest
- pytest-benchmark
- pytest-icdiff
- robot_descriptions
- icub-models
Expand Down
5 changes: 2 additions & 3 deletions examples/jaxsim_as_multibody_dynamics_library.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@
"# Alternatively, the same transform can be extracted from the tensor of\n",
"# all link transforms.\n",
"assert jnp.allclose(\n",
" W_H_L, js.model.forward_kinematics(model=model, data=data)[link_index]\n",
" W_H_L, data.link_transforms[link_index]\n",
")\n",
"\n",
"print(f\"Transform of '{link_name}': shape={W_H_L.shape}\\n{W_H_L}\")"
Expand Down Expand Up @@ -782,12 +782,11 @@
")\n",
"\n",
"# Compute the 3D gravity vector and the total mass of the robot.\n",
"W_g = data.gravity\n",
"m = js.model.total_mass(model=model)\n",
"\n",
"# The centroidal dynamics can be computed as follows.\n",
"G_ḣ = 0\n",
"G_ḣ += m * jnp.hstack([W_g, jnp.zeros(3)])\n",
"G_ḣ += m * jnp.hstack([0, 0, model.gravity, 0, 0, 0])\n",
"G_ḣ += jnp.einsum(\"c66,c6->6\", G_Xf_C, jnp.hstack([C_fl, jnp.zeros_like(C_fl)]))\n",
"print(f\"G_ḣ: shape={G_ḣ.shape}\")"
]
Expand Down
6 changes: 3 additions & 3 deletions examples/jaxsim_as_physics_engine.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
"\n",
"# Simulate\n",
"for _t in T:\n",
" data, _ = js.model.step(\n",
" data = js.model.step(\n",
" model=model,\n",
" data=data,\n",
" link_forces=None,\n",
Expand Down Expand Up @@ -239,7 +239,7 @@
"\n",
"data = data_batch_t0\n",
"for _t in T:\n",
" data, _ = step_parallel(model, data)"
" data = step_parallel(model, data)"
]
}
],
Expand All @@ -266,7 +266,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.0"
"version": "3.13.1"
}
},
"nbformat": 4,
Expand Down
8 changes: 3 additions & 5 deletions examples/jaxsim_as_physics_engine_advanced.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
"import jax\n",
"import jax.numpy as jnp\n",
"import jaxsim.api as js\n",
"import jaxsim\n",
"import rod\n",
"from jaxsim import logging\n",
"from rod.builder.primitives import SphereBuilder\n",
Expand Down Expand Up @@ -163,7 +162,6 @@
"model = js.model.JaxSimModel.build_from_model_description(\n",
" model_description=model_sdf_string,\n",
" time_step=0.001,\n",
" integrator=jaxsim.integrators.fixed_step.Heun2,\n",
")\n",
"\n",
"# Create the data of a single model.\n",
Expand Down Expand Up @@ -299,7 +297,7 @@
" )\n",
")(jnp.vstack(subkeys))\n",
"\n",
"print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position()[0:10])"
"print(\"W_p_B(t0)=\\n\", data_batch_t0.base_position[0:10])"
]
},
{
Expand Down Expand Up @@ -398,7 +396,7 @@
"# This operation is called 'tree transpose' in JAX.\n",
"data_trajectory = jax.tree.map(lambda *leafs: jnp.stack(leafs), *data_trajectory_list)\n",
"\n",
"print(f\"W_p_B: shape={data_trajectory.base_position().shape}\")"
"print(f\"W_p_B: shape={data_trajectory.base_position.shape}\")"
]
},
{
Expand All @@ -412,7 +410,7 @@
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"plt.plot(T, data_trajectory.base_position()[:, 0:5, 2])\n",
"plt.plot(T, data_trajectory.base_position[:, 0:5, 2])\n",
"plt.grid(True)\n",
"plt.xlabel(\"Time [s]\")\n",
"plt.ylabel(\"Height [m]\")\n",
Expand Down
8 changes: 4 additions & 4 deletions examples/jaxsim_for_robot_controllers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@
"\n",
" # Update the MuJoCo data.\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions(), joint_names=model.joint_names()\n",
" positions=data.joint_positions, joint_names=model.joint_names()\n",
" )\n",
"\n",
" # Record a new video frame.\n",
Expand Down Expand Up @@ -343,8 +343,8 @@
" Mss = js.model.free_floating_mass_matrix(model=model, data=data)[6:, 6:]\n",
"\n",
" # Get the current joint positions and velocities.\n",
" s = data.joint_positions()\n",
" ṡ = data.joint_velocities()\n",
" s = data.joint_positions\n",
" ṡ = data.joint_velocities\n",
"\n",
" # Compute the actuated joint torques.\n",
" s_star = -kp * (s - s_des) - kd * (ṡ - s_dot_des)\n",
Expand Down Expand Up @@ -403,7 +403,7 @@
"\n",
" # Update the MuJoCo data.\n",
" mj_model_helper.set_joint_positions(\n",
" positions=data.joint_positions(), joint_names=model.joint_names()\n",
" positions=data.joint_positions, joint_names=model.joint_names()\n",
" )\n",
"\n",
" # Record a new video frame.\n",
Expand Down
2 changes: 1 addition & 1 deletion src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,5 +114,5 @@ def _get_default_logging_level(env_var: str) -> logging.LoggingLevel:
del _is_editable

from . import terrain # isort:skip
from . import api, integrators, logging, math, rbda
from . import api, logging, math, rbda
from .api.common import VelRepr
4 changes: 3 additions & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from . import common # isort:skip
from . import model, data # isort:skip
from . import (
actuation_model,
com,
contact,
contact_model,
frame,
integrators,
joint,
kin_dyn_parameters,
link,
ode,
ode_data,
references,
)
96 changes: 96 additions & 0 deletions src/jaxsim/api/actuation_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.typing as jtp


def compute_resultant_torques(
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
joint_force_references: jtp.Vector | None = None,
) -> jtp.Vector:
"""
Compute the resultant torques acting on the joints.

Args:
model: The model to consider.
data: The data of the considered model.
joint_force_references: The joint force references to apply.

Returns:
The resultant torques acting on the joints.
"""

# Build joint torques if not provided.
τ_references = (
jnp.atleast_1d(joint_force_references.squeeze())
if joint_force_references is not None
else jnp.zeros_like(data.joint_positions)
).astype(float)

# ====================
# Enforce joint limits
# ====================

τ_position_limit = jnp.zeros_like(τ_references).astype(float)

if model.dofs() > 0:

# Stiffness and damper parameters for the joint position limits.
k_j = jnp.array(
model.kin_dyn_parameters.joint_parameters.position_limit_spring
).astype(float)
d_j = jnp.array(
model.kin_dyn_parameters.joint_parameters.position_limit_damper
).astype(float)

# Compute the joint position limit violations.
lower_violation = jnp.clip(
data.joint_positions
- model.kin_dyn_parameters.joint_parameters.position_limits_min,
max=0.0,
)

upper_violation = jnp.clip(
data.joint_positions
- model.kin_dyn_parameters.joint_parameters.position_limits_max,
min=0.0,
)

# Compute the joint position limit torque.
τ_position_limit -= jnp.diag(k_j) @ (lower_violation + upper_violation)

τ_position_limit -= (
jnp.positive(τ_position_limit) * jnp.diag(d_j) @ data.joint_velocities
)

# ====================
# Joint friction model
# ====================

τ_friction = jnp.zeros_like(τ_references).astype(float)

if model.dofs() > 0:

# Static and viscous joint friction parameters
kc = jnp.array(
model.kin_dyn_parameters.joint_parameters.friction_static
).astype(float)
kv = jnp.array(
model.kin_dyn_parameters.joint_parameters.friction_viscous
).astype(float)

# Compute the joint friction torque.
τ_friction = -(
jnp.diag(kc) @ jnp.sign(data.joint_velocities)
+ jnp.diag(kv) @ data.joint_velocities
)

# ===============================
# Compute the total joint forces.
# ===============================

τ_total = τ_references + τ_friction + τ_position_limit

return τ_total
12 changes: 6 additions & 6 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def com_position(

m = js.model.total_mass(model=model)

W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_B = data.base_transform()
W_H_L = data.link_transforms
W_H_B = data.base_transform
B_H_W = jaxsim.math.Transform.inverse(transform=W_H_B)

def B_p̃_LCoM(i) -> jtp.Vector:
Expand Down Expand Up @@ -134,7 +134,7 @@ def centroidal_momentum_jacobian(
model=model, data=data, output_vel_repr=VelRepr.Body
)

W_H_B = data.base_transform()
W_H_B = data.base_transform
B_H_W = jaxsim.math.Transform.inverse(W_H_B)

W_p_CoM = com_position(model=model, data=data)
Expand Down Expand Up @@ -172,7 +172,7 @@ def locked_centroidal_spatial_inertia(
with data.switch_velocity_representation(VelRepr.Body):
B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)

W_H_B = data.base_transform()
W_H_B = data.base_transform
W_p_CoM = com_position(model=model, data=data)

match data.velocity_representation:
Expand Down Expand Up @@ -269,7 +269,7 @@ def bias_acceleration(
"""

# Compute the pose of all links with forward kinematics.
W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_L = data.link_transforms

# Compute the bias acceleration of all links by zeroing the generalized velocity
# in the active representation.
Expand Down Expand Up @@ -411,7 +411,7 @@ def bias_momentum_derivative_term(
case VelRepr.Body:

GB_Xf_W = jaxsim.math.Adjoint.from_transform(
transform=data.base_transform().at[0:3].set(W_p_CoM)
transform=data.base_transform.at[0:3].set(W_p_CoM)
).T

GB_ḣ_bias = GB_Xf_W @ W_ḣ_bias
Expand Down
Loading
Loading