Skip to content

Commit

Permalink
Merge pull request #101 from ami-iit/feature/kin_dyn_parameters
Browse files Browse the repository at this point in the history
Initial support of parametric hardware models
  • Loading branch information
diegoferigo authored Mar 14, 2024
2 parents 795df2b + 7b6a88b commit 1dbd1c8
Show file tree
Hide file tree
Showing 14 changed files with 1,096 additions and 169 deletions.
2 changes: 1 addition & 1 deletion src/jaxsim/api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from . import model, data # isort:skip
from . import common, contact, joint, link, ode, references
from . import common, contact, joint, kin_dyn_parameters, link, ode, references
22 changes: 10 additions & 12 deletions src/jaxsim/api/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
import jax
import jax.numpy as jnp

import jaxsim.api as js
import jaxsim.typing as jtp
from jaxsim.physics.algos import soft_contacts

from . import data as Data
from . import model as Model


@jax.jit
def collidable_point_kinematics(
model: Model.JaxSimModel, data: Data.JaxSimModelData
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> tuple[jtp.Matrix, jtp.Matrix]:
"""
Compute the position and 3D velocity of the collidable points in the world frame.
Expand Down Expand Up @@ -44,7 +42,7 @@ def collidable_point_kinematics(

@jax.jit
def collidable_point_positions(
model: Model.JaxSimModel, data: Data.JaxSimModelData
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the position of the collidable points in the world frame.
Expand All @@ -62,7 +60,7 @@ def collidable_point_positions(

@jax.jit
def collidable_point_velocities(
model: Model.JaxSimModel, data: Data.JaxSimModelData
model: js.model.JaxSimModel, data: js.data.JaxSimModelData
) -> jtp.Matrix:
"""
Compute the 3D velocity of the collidable points in the world frame.
Expand All @@ -80,8 +78,8 @@ def collidable_point_velocities(

@functools.partial(jax.jit, static_argnames=["link_names"])
def in_contact(
model: Model.JaxSimModel,
data: Data.JaxSimModelData,
model: js.model.JaxSimModel,
data: js.data.JaxSimModelData,
*,
link_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
Expand Down Expand Up @@ -131,7 +129,7 @@ def in_contact(

@jax.jit
def estimate_good_soft_contacts_parameters(
model: Model.JaxSimModel,
model: js.model.JaxSimModel,
static_friction_coefficient: jtp.FloatLike = 0.5,
number_of_active_collidable_points_steady_state: jtp.IntLike = 1,
damping_ratio: jtp.FloatLike = 1.0,
Expand Down Expand Up @@ -160,14 +158,14 @@ def estimate_good_soft_contacts_parameters(
specific application.
"""

def estimate_model_height(model: Model.JaxSimModel) -> jtp.Float:
def estimate_model_height(model: js.model.JaxSimModel) -> jtp.Float:
""""""

zero_data = Data.JaxSimModelData.build(
zero_data = js.data.JaxSimModelData.build(
model=model, soft_contacts_params=soft_contacts.SoftContactsParams()
)

W_pz_CoM = Model.com_position(model=model, data=zero_data)[2]
W_pz_CoM = js.model.com_position(model=model, data=zero_data)[2]

if model.physics_model.is_floating_base:
W_pz_C = collidable_point_positions(model=model, data=zero_data)[:, -1]
Expand Down
32 changes: 16 additions & 16 deletions src/jaxsim/api/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import jaxlie
import numpy as np

import jaxsim.api
import jaxsim.api as js
import jaxsim.physics.algos.aba
import jaxsim.physics.algos.crba
import jaxsim.physics.algos.forward_kinematics
Expand Down Expand Up @@ -48,7 +48,7 @@ class JaxSimModelData(common.ModelDataWithVelocityRepresentation):
default_factory=lambda: jnp.array(0, dtype=jnp.uint64)
)

def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool:
def valid(self, model: js.model.JaxSimModel | None = None) -> bool:
"""
Check if the current state is valid for the given model.
Expand All @@ -68,7 +68,7 @@ def valid(self, model: jaxsim.api.model.JaxSimModel | None = None) -> bool:

@staticmethod
def zero(
model: jaxsim.api.model.JaxSimModel,
model: js.model.JaxSimModel,
velocity_representation: VelRepr = VelRepr.Inertial,
) -> JaxSimModelData:
"""
Expand All @@ -88,7 +88,7 @@ def zero(

@staticmethod
def build(
model: jaxsim.api.model.JaxSimModel,
model: js.model.JaxSimModel,
base_position: jtp.Vector | None = None,
base_quaternion: jtp.Vector | None = None,
joint_positions: jtp.Vector | None = None,
Expand Down Expand Up @@ -167,7 +167,7 @@ def build(
soft_contacts_params = (
soft_contacts_params
if soft_contacts_params is not None
else jaxsim.api.contact.estimate_good_soft_contacts_parameters(model=model)
else js.contact.estimate_good_soft_contacts_parameters(model=model)
)

W_H_B = jaxlie.SE3.from_rotation_and_translation(
Expand Down Expand Up @@ -225,7 +225,7 @@ def time(self) -> jtp.Float:
@functools.partial(jax.jit, static_argnames=["joint_names"])
def joint_positions(
self,
model: jaxsim.api.model.JaxSimModel | None = None,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
"""
Expand Down Expand Up @@ -259,13 +259,13 @@ def joint_positions(
joint_names = joint_names if joint_names is not None else model.joint_names()

return self.state.physics_model.joint_positions[
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
js.joint.names_to_idxs(joint_names=joint_names, model=model)
]

@functools.partial(jax.jit, static_argnames=["joint_names"])
def joint_velocities(
self,
model: jaxsim.api.model.JaxSimModel | None = None,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> jtp.Vector:
"""
Expand Down Expand Up @@ -299,7 +299,7 @@ def joint_velocities(
joint_names = joint_names if joint_names is not None else model.joint_names()

return self.state.physics_model.joint_velocities[
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
js.joint.names_to_idxs(joint_names=joint_names, model=model)
]

@jax.jit
Expand Down Expand Up @@ -430,7 +430,7 @@ def generalized_velocity(self) -> jtp.Vector:
def reset_joint_positions(
self,
positions: jtp.VectorLike,
model: jaxsim.api.model.JaxSimModel | None = None,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> Self:
"""
Expand Down Expand Up @@ -468,15 +468,15 @@ def replace(s: jtp.VectorLike) -> JaxSimModelData:

return replace(
s=self.state.physics_model.joint_positions.at[
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
js.joint.names_to_idxs(joint_names=joint_names, model=model)
].set(positions)
)

@functools.partial(jax.jit, static_argnames=["joint_names"])
def reset_joint_velocities(
self,
velocities: jtp.VectorLike,
model: jaxsim.api.model.JaxSimModel | None = None,
model: js.model.JaxSimModel | None = None,
joint_names: tuple[str, ...] | None = None,
) -> Self:
"""
Expand Down Expand Up @@ -514,7 +514,7 @@ def replace(ṡ: jtp.VectorLike) -> JaxSimModelData:

return replace(
=self.state.physics_model.joint_velocities.at[
jaxsim.api.joint.names_to_idxs(joint_names=joint_names, model=model)
js.joint.names_to_idxs(joint_names=joint_names, model=model)
].set(velocities)
)

Expand Down Expand Up @@ -692,7 +692,7 @@ def reset_base_velocity(


def random_model_data(
model: jaxsim.api.model.JaxSimModel,
model: js.model.JaxSimModel,
*,
key: jax.Array | None = None,
velocity_representation: VelRepr | None = None,
Expand Down Expand Up @@ -762,8 +762,8 @@ def random_model_data(
).as_quaternion_xyzw()[np.array([3, 0, 1, 2])]

if model.number_of_joints() > 0:
physics_model_state.joint_positions = (
jaxsim.api.joint.random_joint_positions(model=model, key=k3)
physics_model_state.joint_positions = js.joint.random_joint_positions(
model=model, key=k3
)

physics_model_state.joint_velocities = jax.random.uniform(
Expand Down
Loading

0 comments on commit 1dbd1c8

Please sign in to comment.