Skip to content

Commit

Permalink
Merge pull request #370 from ami-iit/update_documentation
Browse files Browse the repository at this point in the history
Refactor documentation and example scripts in README
  • Loading branch information
flferretti authored Feb 10, 2025
2 parents 4cfb2d2 + 8f65639 commit 0020a58
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 59 deletions.
71 changes: 47 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
<br/>
<table>
<tr>
<th><img src="https://github.com/user-attachments/assets/f9661fae-9a85-41dd-9a58-218758ec8c9c" width="500"></th>
<th><img src="https://github.com/user-attachments/assets/62b88b9d-45ea-4d22-99d2-f24fc842dd29" width="500"></th>
<th><img src="https://github.com/user-attachments/assets/f9661fae-9a85-41dd-9a58-218758ec8c9c"></th>
<th><img src="https://github.com/user-attachments/assets/62b88b9d-45ea-4d22-99d2-f24fc842dd29"></th>
</tr>
</table>
<br/>
Expand All @@ -27,13 +27,16 @@


```python
import pathlib

import icub_models
import jax.numpy as jnp

import jaxsim.api as js
import icub_models
import pathlib

# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")

joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
Expand All @@ -43,33 +46,45 @@ joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',

# Build and reduce the model
model_description = pathlib.Path(model_path)

full_model = js.model.JaxSimModel.build_from_model_description(
model_description=model_description, time_step=0.0001, is_urdf=True
)

model = js.model.reduce(model=full_model, considered_joints=joints)

# Get the number of degrees of freedom
ndof = model.dofs()

# Initialize data and simulation
# Note that the default data representation is mixed velocity representation
data = js.data.JaxSimModelData.build(model=model,base_position=jnp.array([0.0, 0.0, 1.0]))
data = js.data.JaxSimModelData.build(
model=model, base_position=jnp.array([0.0, 0.0, 1.0])
)

T = jnp.arange(start=0, stop=1.0, step=model.time_step)

tau = jnp.zeros(ndof)

# Simulate
for t in T:
data = js.model.step(model=model, data=data, link_forces=None, joint_force_references=tau)

for _ in T:
data = js.model.step(
model=model, data=data, link_forces=None, joint_force_references=tau
)
```

### Using JaxSim as a multibody dynamics library
``` python
import pathlib

import icub_models
import jax.numpy as jnp

import jaxsim.api as js
import icub_models
import pathlib

# Load the iCub model
model_path = icub_models.get_model_file("iCubGazeboV2_5")

joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',
'l_shoulder_roll', 'l_shoulder_yaw', 'l_elbow', 'r_shoulder_pitch',
'r_shoulder_roll', 'r_shoulder_yaw', 'r_elbow', 'l_hip_pitch',
Expand All @@ -79,21 +94,31 @@ joints = ('torso_pitch', 'torso_roll', 'torso_yaw', 'l_shoulder_pitch',

# Build and reduce the model
model_description = pathlib.Path(model_path)

full_model = js.model.JaxSimModel.build_from_model_description(
model_description=model_description, time_step=0.0001, is_urdf=True
)

model = js.model.reduce(model=full_model, considered_joints=joints)

# Initialize model data
data = js.data.JaxSimModelData.build(
model=model,
base_position=jnp.array([0.0, 0.0, 1.0],
base_position=jnp.array([0.0, 0.0, 1.0]),
)

# Frame and dynamics computations
frame_index = js.frame.name_to_idx(model=model, frame_name="l_foot")
W_H_F = js.frame.transform(model=model, data=data, frame_index=frame_index) # Frame transformation
W_J_F = js.frame.jacobian(model=model, data=data, frame_index=frame_index) # Frame Jacobian

# Frame transformation
W_H_F = js.frame.transform(
model=model, data=data, frame_index=frame_index
)

# Frame Jacobian
W_J_F = js.frame.jacobian(
model=model, data=data, frame_index=frame_index
)

# Dynamics properties
M = js.model.free_floating_mass_matrix(model=model, data=data) # Mass matrix
Expand All @@ -102,16 +127,15 @@ g = js.model.free_floating_gravity_forces(model=model, data=data) # Gravity for
C = js.model.free_floating_coriolis_matrix(model=model, data=data) # Coriolis matrix

# Print dynamics results
print(f"M: shape={M.shape}, h: shape={h.shape}, g: shape={g.shape}, C: shape={C.shape}")

print(f"{M.shape=} \n{h.shape=} \n{g.shape=} \n{C.shape=}")
```

### Additional features

- Full support for automatic differentiation of RBDAs (forward and reverse modes) with JAX.
- Support for automatically differentiating against kinematics and dynamics parameters.
- All fixed-step integrators are forward and reverse differentiable.
- All variable-step integrators are forward differentiable.
- Check the example folder for additional usecase !
- Check the example folder for additional use cases!

[jax]: https://github.com/google/jax/
[sdformat]: https://github.com/gazebosim/sdformat
Expand Down Expand Up @@ -243,23 +267,21 @@ The JaxSim API documentation is available at [jaxsim.readthedocs.io][readthedocs
src/jaxsim
|-- api..........................# Package containing the main functional APIs.
| |-- com.py...................# |-- APIs for computing quantities related to the center of mass.
| |-- actuation_model.py.......# |-- APIs for computing quantities related to the actuation model.
| |-- common.py................# |-- Common utilities used in the current package.
| |-- com.py...................# |-- APIs for computing quantities related to the center of mass.
| |-- contact_model.py.........# |-- APIs for computing quantities related to the contact model.
| |-- contact.py...............# |-- APIs for computing quantities related to the collidable points.
| |-- data.py..................# |-- Class storing the data of a simulated model.
| |-- frame.py.................# |-- APIs for computing quantities related to additional frames.
| |-- integrators.py...........# |-- APIs for integrating the system dynamics.
| |-- joint.py.................# |-- APIs for computing quantities related to the joints.
| |-- kin_dyn_parameters.py....# |-- Class storing kinematic and dynamic parameters of a model.
| |-- link.py..................# |-- APIs for computing quantities related to the links.
| |-- model.py.................# |-- Class defining a simulated model and APIs for computing related quantities.
| |-- ode.py...................# |-- APIs for computing quantities related to the system dynamics.
| |-- ode_data.py..............# |-- Set of classes to store the data of the system dynamics.
| `-- references.py............# `-- Helper class to create references (link forces and joint torques).
|-- exceptions.py................# Module containing functions to raise exceptions from JIT-compiled functions.
|-- integrators..................# Package containing the integrators used to simulate the system dynamics.
| |-- common.py................# |-- Common utilities used in the current package.
| |-- fixed_step.py............# |-- Fixed-step integrators (explicit Runge-Kutta schemes).
| `-- variable_step.py.........# `-- Variable-step integrators (embedded Runge-Kutta schemes).
|-- logging.py...................# Module containing logging utilities.
|-- math.........................# Package containing mathematical utilities.
| |-- adjoint.py...............# |-- APIs for creating and manipulating 6D transformations.
Expand All @@ -269,7 +291,8 @@ src/jaxsim
| |-- quaternion.py............# |-- APIs for creating and manipulating quaternions.
| |-- rotation.py..............# |-- APIs for creating and manipulating rotation matrices.
| |-- skew.py..................# |-- APIs for creating and manipulating skew-symmetric matrices.
| `-- transform.py.............# `-- APIs for creating and manipulating homogeneous transformations.
| |-- transform.py.............# |-- APIs for creating and manipulating homogeneous transformations.
| |-- utils.py.................# |-- Common utilities used in the current package.
|-- mujoco.......................# Package containing utilities to interact with the Mujoco passive viewer.
| |-- loaders.py...............# |-- Utilities for converting JaxSim models to Mujoco models.
| |-- model.py.................# |-- Class providing high-level methods to compute quantities using Mujoco.
Expand Down
4 changes: 2 additions & 2 deletions docs/guide/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ Alternatively, you can use `pypa/pip`_, preferably in a `virtual environment`_:
pip install jaxsim
Have a look to `pyproject.toml`_ for a complete list of optional dependencies.
You can install all of them by specifying ``jaxsim[all]`.
You can install all by using ``pip install "jaxsim[all]"``.
.. note::

If you need GPU support, please follow the official `installation instruction`_ of JAX.

.. _conda: https://anaconda.org/
.. _pyproject.toml: https://github.com/ami-iit/jaxsim/blob/main/pyproject.toml
.. _pypa/pip: https://github.com/pypa/pip/
.. _virtual environment: https://docs.python.org/3.8/tutorial/venv.html
.. _installation instruction: https://github.com/google/jax/#installation
1 change: 0 additions & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ Features
:caption: JAXsim API

modules/api
modules/integrators
modules/math
modules/mujoco
modules/parsers
Expand Down
17 changes: 13 additions & 4 deletions docs/modules/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@ Functional API
data
contact
kin_dyn_parameters
integrators
joint
link
frame
com
ode_data
ode
references
actuation_model
common


Expand All @@ -27,6 +28,10 @@ Model
:members:
:no-index:

.. automodule:: jaxsim.api.actuation_model
:members:
:no-index:

Data
~~~~

Expand Down Expand Up @@ -57,29 +62,33 @@ Joint

Link
~~~~~

.. automodule:: jaxsim.api.link
:members:
:no-index:

Frame
~~~~~

.. automodule:: jaxsim.api.frame
:members:
:no-index:

CoM
~~~

.. automodule:: jaxsim.api.com
:members:
:no-index:

ODE Data
~~~~~~~~
Integration
~~~~~~~~~~~

.. automodule:: jaxsim.api.ode_data
.. automodule:: jaxsim.api.integrators
:members:
:no-index:


.. automodule:: jaxsim.api.ode
:members:
:no-index:
Expand Down
23 changes: 0 additions & 23 deletions docs/modules/integrators.rst

This file was deleted.

10 changes: 10 additions & 0 deletions docs/modules/rbda.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ This module provides a set of algorithms for rigid body dynamics.
aba
collidable_points
contacts.soft
contacts.rigid
contacts.relaxed_rigid
crba
forward_kinematics
jacobian
Expand All @@ -30,6 +32,14 @@ Contact Models
:members:
:no-index:

.. automodule:: jaxsim.rbda.contacts.rigid
:members:
:no-index:

.. automodule:: jaxsim.rbda.contacts.relaxed_rigid
:members:
:no-index:

Utilities
~~~~~~~~~

Expand Down
11 changes: 6 additions & 5 deletions docs/modules/typing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ Typing
Int
Float
Vector
FloatJax
IntJax
ArrayJax
VectorJax
MatrixJax
BoolLike
FloatLike
IntLike
ArrayLike
VectorLike
MatrixLike

0 comments on commit 0020a58

Please sign in to comment.