Skip to content

Commit

Permalink
Merge pull request #108 from ami-iit/functional
Browse files Browse the repository at this point in the history
Finalize the functional APIs and replace the OOP classes
  • Loading branch information
diegoferigo authored Mar 29, 2024
2 parents 4fd2032 + 2b353d1 commit 83fcf58
Show file tree
Hide file tree
Showing 91 changed files with 7,640 additions and 8,742 deletions.
15 changes: 13 additions & 2 deletions .github/workflows/ci_cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,22 @@ jobs:
with:
fetch-depth: 0

- name: Install Gazebo Classic
# - name: Install Gazebo Classic
# if: contains(matrix.os, 'ubuntu')
# run: |
# sudo apt-get update
# sudo apt-get install gazebo

# https://gazebosim.org/docs/harmonic/install_ubuntu
- name: Install Gazebo Sim
if: contains(matrix.os, 'ubuntu')
run: |
sudo apt-get update
sudo apt-get install gazebo
sudo apt-get install lsb-release wget gnupg
sudo wget https://packages.osrfoundation.org/gazebo.gpg -O /usr/share/keyrings/pkgs-osrf-archive-keyring.gpg
echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/pkgs-osrf-archive-keyring.gpg] http://packages.osrfoundation.org/gazebo/ubuntu-stable $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/gazebo-stable.list > /dev/null
sudo apt-get update
sudo apt-get install gz-harmonic
- name: Run the Python tests
if: contains(matrix.os, 'ubuntu')
Expand Down
3 changes: 1 addition & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies:
- jaxlie >= 1.3.0
- jax-dataclasses >= 1.4.0
- pptree
- rod
- rod >= 0.2.0
- typing_extensions # python<3.12
# Optional dependencies from setup.cfg
# [style]
Expand All @@ -19,7 +19,6 @@ dependencies:
# [testing]
- idyntree
- pytest
- pytest-forked
- pytest-icdiff
- robot_descriptions
# [viz]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ multi_line_output = 3

[tool.pytest.ini_options]
minversion = "6.0"
addopts = "-rsxX -v --strict-markers --forked"
addopts = "-rsxX -v --strict-markers"
testpaths = [
"tests",
]
9 changes: 4 additions & 5 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ package_dir =
python_requires = >=3.11
install_requires =
coloredlogs
jax >= 0.4.13,< 0.4.25
jaxlib >= 0.4.13,< 0.4.25
jax >= 0.4.13
jaxlib >= 0.4.13
jaxlie >= 1.3.0
jax_dataclasses >= 1.4.0
pptree
rod
rod >= 0.2.0
typing_extensions ; python_version < '3.12'

[options.packages.find]
Expand All @@ -71,8 +71,7 @@ style =
pre-commit
testing =
idyntree
pytest >= 6.0
pytest-forked
pytest >=6.0
pytest-icdiff
robot-descriptions
viz =
Expand Down
7 changes: 3 additions & 4 deletions src/jaxsim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def _is_editable() -> bool:
del _np_options
del _is_editable

from . import high_level, logging, math, simulation, sixd
from .high_level.common import VelRepr
from .simulation.ode_integration import IntegratorType
from .simulation.simulator import JaxSim
from . import terrain # isort:skip
from . import api, integrators, logging, math, rbda
from .api.common import VelRepr
4 changes: 3 additions & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from . import contact, data, joint, link, model, ode
from . import common # isort:skip
from . import model, data # isort:skip
from . import com, contact, joint, kin_dyn_parameters, link, ode, ode_data, references
240 changes: 240 additions & 0 deletions src/jaxsim/api/com.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import jax
import jax.numpy as jnp
import jaxlie

import jaxsim.api as js
import jaxsim.math
import jaxsim.typing as jtp

from .common import VelRepr


@jax.jit
def com_position(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
"""
Compute the position of the center of mass of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The position of the center of mass of the model w.r.t. the world frame.
"""

m = js.model.total_mass(model=model)

W_H_L = js.model.forward_kinematics(model=model, data=data)
W_H_B = data.base_transform()
B_H_W = jaxlie.SE3.from_matrix(W_H_B).inverse().as_matrix()

def B_p̃_LCoM(i) -> jtp.Vector:
m = js.link.mass(model=model, link_index=i)
L_p_LCoM = js.link.com_position(
model=model, data=data, link_index=i, in_link_frame=True
)
return m * B_H_W @ W_H_L[i] @ jnp.hstack([L_p_LCoM, 1])

com_links = jax.vmap(B_p̃_LCoM)(jnp.arange(model.number_of_links()))

B_p̃_CoM = (1 / m) * com_links.sum(axis=0)
B_p̃_CoM = B_p̃_CoM.at[3].set(1)

return (W_H_B @ B_p̃_CoM)[0:3].astype(float)


@jax.jit
def com_linear_velocity(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
r"""
Compute the linear velocity of the center of mass of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The linear velocity of the center of mass of the model in the
active representation.
Note:
The linear velocity of the center of mass is expressed in the mixed frame
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
active velocity representation is either inertial-fixed or mixed,
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
"""

# Extract the linear component of the 6D average centroidal velocity.
# This is expressed in G[B] in body-fixed representation, and in G[W] in
# inertial-fixed or mixed representation.
G_vl_WG = average_centroidal_velocity(model=model, data=data)[0:3]

return G_vl_WG


@jax.jit
def centroidal_momentum(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
r"""
Compute the centroidal momentum of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The centroidal momentum of the model.
Note:
The centroidal momentum is expressed in the mixed frame
:math:`({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`C = W` if the
active velocity representation is either inertial-fixed or mixed,
and :math:`C = B` if the active velocity representation is body-fixed.
"""

ν = data.generalized_velocity()
G_J = centroidal_momentum_jacobian(model=model, data=data)

return G_J @ ν


@jax.jit
def centroidal_momentum_jacobian(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
r"""
Compute the Jacobian of the centroidal momentum of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The Jacobian of the centroidal momentum of the model.
Note:
The frame corresponding to the output representation of this Jacobian is either
:math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
or :math:`G[B]`, if the active velocity representation is body-fixed.
Note:
This Jacobian is also known in the literature as Centroidal Momentum Matrix.
"""

# Compute the Jacobian of the total momentum with body-fixed output representation.
# We convert the output representation either to G[W] or G[B] below.
B_Jh = js.model.total_momentum_jacobian(
model=model, data=data, output_vel_repr=VelRepr.Body
)

W_H_B = data.base_transform()
B_H_W = jaxsim.math.Transform.inverse(W_H_B)

W_p_CoM = com_position(model=model, data=data)

match data.velocity_representation:
case VelRepr.Inertial | VelRepr.Mixed:
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
case VelRepr.Body:
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
case _:
raise ValueError(data.velocity_representation)

# Compute the transform for 6D forces.
G_Xf_B = jaxsim.math.Adjoint.from_transform(transform=B_H_W @ W_H_G).T

return G_Xf_B @ B_Jh


@jax.jit
def locked_centroidal_spatial_inertia(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
):
"""
Compute the locked centroidal spatial inertia of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The locked centroidal spatial inertia of the model.
"""

with data.switch_velocity_representation(VelRepr.Body):
B_Mbb_B = js.model.locked_spatial_inertia(model=model, data=data)

W_H_B = data.base_transform()
W_p_CoM = com_position(model=model, data=data)

match data.velocity_representation:
case VelRepr.Inertial | VelRepr.Mixed:
W_H_G = W_H_GW = jnp.eye(4).at[0:3, 3].set(W_p_CoM)
case VelRepr.Body:
W_H_G = W_H_GB = W_H_B.at[0:3, 3].set(W_p_CoM)
case _:
raise ValueError(data.velocity_representation)

B_H_G = jaxlie.SE3.from_matrix(jaxsim.math.Transform.inverse(W_H_B) @ W_H_G)

B_Xv_G = B_H_G.adjoint()
G_Xf_B = B_Xv_G.transpose()

return G_Xf_B @ B_Mbb_B @ B_Xv_G


@jax.jit
def average_centroidal_velocity(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Vector:
r"""
Compute the average centroidal velocity of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The average centroidal velocity of the model.
Note:
The average velocity is expressed in the mixed frame
:math:`G = ({}^W \mathbf{p}_{\text{CoM}}, [C])`, where :math:`[C] = [W]` if the
active velocity representation is either inertial-fixed or mixed,
and :math:`[C] = [B]` if the active velocity representation is body-fixed.
"""

ν = data.generalized_velocity()
G_J = average_centroidal_velocity_jacobian(model=model, data=data)

return G_J @ ν


@jax.jit
def average_centroidal_velocity_jacobian(
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
r"""
Compute the Jacobian of the average centroidal velocity of the model.
Args:
model: The model to consider.
data: The data of the considered model.
Returns:
The Jacobian of the average centroidal velocity of the model.
Note:
The frame corresponding to the output representation of this Jacobian is either
:math:`G[W]`, if the active velocity representation is inertial-fixed or mixed,
or :math:`G[B]`, if the active velocity representation is body-fixed.
"""

G_J = centroidal_momentum_jacobian(model=model, data=data)
G_Mbb = locked_centroidal_spatial_inertia(model=model, data=data)

return jnp.linalg.inv(G_Mbb) @ G_J
15 changes: 13 additions & 2 deletions src/jaxsim/api/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
import contextlib
import dataclasses
import enum
import functools
from typing import ContextManager

Expand All @@ -11,7 +12,6 @@
from jax_dataclasses import Static

import jaxsim.typing as jtp
from jaxsim.high_level.common import VelRepr
from jaxsim.utils import JaxsimDataclass, Mutability

try:
Expand All @@ -20,6 +20,17 @@
from typing_extensions import Self


@enum.unique
class VelRepr(enum.IntEnum):
"""
Enumeration of all supported 6D velocity representations.
"""

Body = enum.auto()
Mixed = enum.auto()
Inertial = enum.auto()


@jax_dataclasses.pytree_dataclass
class ModelDataWithVelocityRepresentation(JaxsimDataclass, abc.ABC):
"""
Expand Down Expand Up @@ -59,7 +70,7 @@ def switch_velocity_representation(
# We run this in a mutable context with restoration so that any exception
# occurring, we restore the original object in case it was modified.
with self.mutable_context(
mutability=self._mutability(), restore_after_exception=True
mutability=self.mutability(), restore_after_exception=True
):
yield self

Expand Down
Loading

0 comments on commit 83fcf58

Please sign in to comment.