Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add helper to convert JaxSimModelData to mujoco.mjData #264

Merged
merged 4 commits into from
Oct 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/jaxsim/mujoco/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .loaders import RodModelToMjcf, SdfToMjcf, UrdfToMjcf
from .model import MujocoModelHelper
from .utils import mujoco_data_from_jaxsim
from .visualizer import MujocoVideoRecorder, MujocoVisualizer
2 changes: 1 addition & 1 deletion src/jaxsim/mujoco/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -646,7 +646,7 @@ class MujocoCamera:
def build(cls, **kwargs) -> MujocoCamera:

if not all(isinstance(value, str) for value in kwargs.values()):
raise ValueError("Values must be strings")
raise ValueError(f"Values must be strings: {kwargs}")

return cls(**kwargs)

Expand Down
7 changes: 4 additions & 3 deletions src/jaxsim/mujoco/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import functools
import pathlib
from collections.abc import Callable
from collections.abc import Callable, Sequence
from typing import Any

import mujoco as mj
Expand Down Expand Up @@ -107,7 +107,8 @@ def build_from_xml(
size = [float(el) for el in hfield_element["@size"].split(" ")]
size[0], size[1] = heightmap_radius_xy
size[2] = 1.0
size[3] = max(0, -min(hfield))
# The following could be zero but Mujoco complains if it's exactly zero.
size[3] = max(0.000_001, -min(hfield))

# Replace the 'size' attribute.
hfields_dict[heightmap_name]["@size"] = " ".join(str(el) for el in size)
Expand Down Expand Up @@ -315,7 +316,7 @@ def set_joint_position(
self.data.qpos[sl] = position

def set_joint_positions(
self, joint_names: list[str], positions: npt.NDArray | list[npt.NDArray]
self, joint_names: Sequence[str], positions: npt.NDArray | list[npt.NDArray]
) -> None:
"""Set the positions of multiple joints."""

Expand Down
101 changes: 101 additions & 0 deletions src/jaxsim/mujoco/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import mujoco as mj
import numpy as np

from . import MujocoModelHelper


def mujoco_data_from_jaxsim(
mujoco_model: mj.MjModel,
jaxsim_model,
jaxsim_data,
mujoco_data: mj.MjData | None = None,
update_removed_joints: bool = True,
) -> mj.MjData:
"""
Create a Mujoco data object from a JaxSim model and data objects.

Args:
mujoco_model: The Mujoco model object corresponding to the JaxSim model.
jaxsim_model: The JaxSim model object from which the Mujoco model was created.
jaxsim_data: The JaxSim data object containing the state of the model.
mujoco_data: An optional Mujoco data object. If None, a new one will be created.
update_removed_joints:
If True, the positions of the joints that have been removed during the
model reduction process will be set to their initial values.

Returns:
The Mujoco data object containing the state of the JaxSim model.

Note:
This method is useful to initialize a Mujoco data object used for visualization
with the state of a JaxSim model. In particular, this function takes care of
initializing the positions of the joints that have been removed during the
model reduction process. After the initial creation of the Mujoco data object,
it's faster to update the state using an external MujocoModelHelper object.
"""

# The package `jaxsim.mujoco` is supposed to be jax-independent.
# We import all the JaxSim resources privately.
import jaxsim.api as js

if not isinstance(jaxsim_model, js.model.JaxSimModel):
raise ValueError("The `jaxsim_model` argument must be a JaxSimModel object.")

if not isinstance(jaxsim_data, js.data.JaxSimModelData):
raise ValueError("The `jaxsim_data` argument must be a JaxSimModelData object.")

# Create the helper to operate on the Mujoco model and data.
model_helper = MujocoModelHelper(model=mujoco_model, data=mujoco_data)

# If the model is fixed-base, the Mujoco model won't have the joint corresponding
# to the floating base, and the helper would raise an exception.
if jaxsim_model.floating_base():

# Set the model position.
model_helper.set_base_position(position=np.array(jaxsim_data.base_position()))

# Set the model orientation.
model_helper.set_base_orientation(
orientation=np.array(jaxsim_data.base_orientation())
)

# Set the joint positions.
if jaxsim_model.dofs() > 0:

model_helper.set_joint_positions(
joint_names=list(jaxsim_model.joint_names()),
positions=np.array(
jaxsim_data.joint_positions(
model=jaxsim_model, joint_names=jaxsim_model.joint_names()
)
),
)

# Updating these joints is not necessary after the first time.
# Users can disable this update after initialization.
if update_removed_joints:

# Create a dictionary with the joints that have been removed for various reasons
# (like link lumping due to model reduction).
joints_removed_dict = {
j.name: j
for j in jaxsim_model.description._joints_removed
if j.name not in set(jaxsim_model.joint_names())
}

# Set the positions of the removed joints.
_ = [
model_helper.set_joint_position(
position=joints_removed_dict[joint_name].initial_position,
joint_name=joint_name,
)
# Select all original joint that have been removed from the JaxSim model
# that are still present in the Mujoco model.
for joint_name in joints_removed_dict
if joint_name in model_helper.joint_names()
]

# Return the mujoco data with updated kinematics.
mj.mj_forward(mujoco_model, model_helper.data)

return model_helper.data
2 changes: 1 addition & 1 deletion src/jaxsim/mujoco/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def write_video(self, path: pathlib.Path, exist_ok: bool = False) -> None:
if not exist_ok and path.is_file():
raise FileExistsError(f"The file '{path}' already exists.")

media.write_video(path=path, images=self.frames, fps=self.fps)
media.write_video(path=path, images=np.array(self.frames), fps=self.fps)

@staticmethod
def compute_down_sampling(original_fps: int, target_min_fps: int) -> int:
Expand Down