Skip to content
Merged
54 changes: 54 additions & 0 deletions genesis/engine/entities/rigid_entity/rigid_entity.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def __init__(
vgeom_start: int,
vvert_start: int,
vface_start: int,
custom_vvert_start: int,
custom_vface_start: int,
morph_heterogeneous: list[Morph] | None = None,
name: str | None = None,
):
Expand All @@ -97,6 +99,8 @@ def __init__(
self._vgeom_start = vgeom_start
self._vvert_start = vvert_start
self._vface_start = vface_start
self._custom_vvert_start = custom_vvert_start
self._custom_vface_start = custom_vface_start

self._is_built: bool = False
self._is_attached: bool = False
Expand Down Expand Up @@ -1916,6 +1920,52 @@ def vgeoms(self):
return self._vgeoms
return gs.List(vgeom for link in self._links for vgeom in link.vgeoms)

@gs.assert_built
def set_vverts(self, vverts, envs_idx=None):
"""Override this entity's visual vertex positions for rendering and sensors.

vverts is broadcast to (len(envs_idx), n_vverts, 3); scalars, (3,) and (n_vverts, 3) are accepted. vverts=None
re-runs FK over the entity's vgeoms and writes the result back into the custom buffer. Requires the entity's
morph to be created with enable_custom_vverts=True.
"""
if self._enable_heterogeneous:
gs.raise_exception("This method is not supported by heterogeneous entities.")
if not self._morph.enable_custom_vverts:
gs.raise_exception(
"'set_vverts' requires the entity's morph to be created with 'enable_custom_vverts=True'."
)
self._solver.set_vverts(
self._custom_vvert_start,
self._custom_vvert_start + self.n_vverts,
np.array([vg.idx for vg in self.vgeoms], dtype=gs.np_int),
vverts,
envs_idx,
)

@gs.assert_built
def get_vverts(self, envs_idx=None):
"""Return a copy of this entity's visual vertex positions in world space.

For entities created with enable_custom_vverts=True the positions are read from the engine custom buffer; for
other entities they are computed on the fly from each vgeom's current pose applied to its rest-pose init_vverts.
"""
if self._enable_heterogeneous:
gs.raise_exception("This method is not supported by heterogeneous entities.")
if self._morph.enable_custom_vverts:
return self._solver.get_vverts(self._custom_vvert_start, self._custom_vvert_start + self.n_vverts, envs_idx)

self._solver.update_vgeoms()
vgeoms_pos = qd_to_torch(self._solver.vgeoms_state.pos, envs_idx, transpose=True, copy=None)
vgeoms_quat = qd_to_torch(self._solver.vgeoms_state.quat, envs_idx, transpose=True, copy=None)
parts = []
for vgeom in self.vgeoms:
init = torch.as_tensor(vgeom.init_vverts, dtype=gs.tc_float, device=gs.device)
pos = vgeoms_pos[..., vgeom.idx, :].unsqueeze(-2)
quat = vgeoms_quat[..., vgeom.idx, :].unsqueeze(-2)
parts.append(gu.transform_by_trans_quat(init, pos, quat))
tensor = torch.cat(parts, dim=-2)
return tensor[0] if self._solver.n_envs == 0 else tensor

@property
def links(self) -> list[RigidLink]:
"""The list of links (`RigidLink`) in the entity."""
Expand Down Expand Up @@ -1977,6 +2027,8 @@ def __init__(
vgeom_start=0,
vvert_start=0,
vface_start=0,
custom_vvert_start=0,
custom_vface_start=0,
equality_start=0,
visualize_contact: bool = False,
morph_heterogeneous: list[Morph] | None = None,
Expand Down Expand Up @@ -2011,6 +2063,8 @@ def __init__(
vgeom_start,
vvert_start,
vface_start,
custom_vvert_start,
custom_vface_start,
morph_heterogeneous,
name,
)
Expand Down
42 changes: 42 additions & 0 deletions genesis/engine/entities/rigid_entity/rigid_geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,48 @@ def get_vAABB(self, envs_idx=None):
vverts_pos = pos[..., None, :] + gu.transform_by_quat(self._aabb_verts, quat[..., None, :])
return torch.stack((vverts_pos.min(dim=-2).values, vverts_pos.max(dim=-2).values), dim=-2)

@gs.assert_built
def set_vverts(self, vverts, envs_idx=None):
"""Override this vgeom's visual vertex positions for rendering and sensors. See
:meth:`KinematicEntity.set_vverts` for the full behavior; this method writes only this vgeom's slice.

Requires the owning entity's morph to be created with 'enable_custom_vverts=True'.
"""
if not self._entity._morph.enable_custom_vverts:
gs.raise_exception(
"'set_vverts' requires the entity's morph to be created with 'enable_custom_vverts=True'."
)
custom_offset = self._entity._custom_vvert_start - self._entity._vvert_start
self._entity._solver.set_vverts(
self.vvert_start + custom_offset,
self.vvert_end + custom_offset,
np.array([self.idx], dtype=gs.np_int),
vverts,
envs_idx,
)

@gs.assert_built
def get_vverts(self, envs_idx=None):
"""Return a copy of this vgeom's visual vertex positions in world space.

Entities created with 'morph.enable_custom_vverts=True' read from the engine-owned 'vverts_state.pos' buffer.
Other entities compute the positions on the fly from the vgeom's current world pose applied to 'init_vverts'.
"""
if self._entity._morph.enable_custom_vverts:
custom_offset = self._entity._custom_vvert_start - self._entity._vvert_start
return self._entity._solver.get_vverts(
self.vvert_start + custom_offset, self.vvert_end + custom_offset, envs_idx
)

self._solver.update_vgeoms()
vgeoms_pos = qd_to_torch(self._solver.vgeoms_state.pos, envs_idx, transpose=True, copy=None)
vgeoms_quat = qd_to_torch(self._solver.vgeoms_state.quat, envs_idx, transpose=True, copy=None)
init = torch.as_tensor(self.init_vverts, dtype=gs.tc_float, device=gs.device)
pos = vgeoms_pos[..., self.idx, :].unsqueeze(-2)
quat = vgeoms_quat[..., self.idx, :].unsqueeze(-2)
tensor = gu.transform_by_trans_quat(init, pos, quat)
return tensor[0] if self._solver.n_envs == 0 else tensor

# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
Expand Down
133 changes: 120 additions & 13 deletions genesis/engine/solvers/kinematic_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
kernel_masked_forward_kinematics,
kernel_masked_forward_velocity,
kernel_update_vgeoms,
kernel_update_vverts_for_vgeoms,
)
from .rigid.abd.accessor import (
kernel_get_kinematic_state,
Expand All @@ -48,6 +49,7 @@
kernel_set_links_pos,
kernel_set_links_quat,
kernel_set_qpos,
kernel_set_vverts,
kernel_get_links_vel,
)

Expand Down Expand Up @@ -126,6 +128,8 @@ def add_entity(self, idx, material, morph, surface, visualize_contact=False, nam
vgeom_start=self.n_vgeoms,
vvert_start=self.n_vverts,
vface_start=self.n_vfaces,
custom_vvert_start=self.n_custom_vverts,
custom_vface_start=self.n_custom_vfaces,
morph_heterogeneous=morph_heterogeneous,
name=name,
)
Expand Down Expand Up @@ -153,6 +157,8 @@ def build(self):
self._n_vgeoms = self.n_vgeoms
self._n_vfaces = self.n_vfaces
self._n_vverts = self.n_vverts
self._n_custom_vverts = self.n_custom_vverts
self._n_custom_vfaces = self.n_custom_vfaces
self._n_entities = self.n_entities

self._vgeoms = self.vgeoms
Expand All @@ -175,6 +181,8 @@ def build(self):
self.n_vgeoms_ = max(1, self.n_vgeoms)
self.n_vfaces_ = max(1, self.n_vfaces)
self.n_vverts_ = max(1, self.n_vverts)
self.n_custom_vverts_ = max(1, self.n_custom_vverts)
self.n_custom_vfaces_ = max(1, self.n_custom_vfaces)
self.n_entities_ = max(1, self.n_entities)

# batch_links_info is required when heterogeneous simulation is used.
Expand All @@ -192,6 +200,21 @@ def build(self):
self._init_entity_fields()

self._init_envs_offset()
self._init_vverts_state()

def _init_vverts_state(self):
# Seed vverts_state.pos with one FK pass so opt-in entities show their initial pose from frame 0 even before
# the user touches the buffer with set_vverts.
if self.n_custom_vverts == 0:
return
opt_in_vgeoms_idx = np.array(
[vgeom.idx for entity in self._entities if entity._morph.enable_custom_vverts for vgeom in entity.vgeoms],
dtype=gs.np_int,
)
if opt_in_vgeoms_idx.size == 0:
return
self.update_vgeoms()
self.update_vverts_for_vgeoms(opt_in_vgeoms_idx)

def _build_static_config(self):
# Static config with all physics disabled
Expand Down Expand Up @@ -372,20 +395,34 @@ def _dispatch_heterogeneous_vgeoms(self):

def _init_vvert_fields(self):
self.vverts_info = self.data_manager.vverts_info
self.vverts_state = self.data_manager.vverts_state
self.vfaces_info = self.data_manager.vfaces_info
if self.n_vverts > 0:
vgeoms = self.vgeoms
kernel_init_vvert_fields(
vverts=np.concatenate([vgeom.init_vverts for vgeom in vgeoms], dtype=gs.np_float),
vfaces=np.concatenate([vgeom.init_vfaces + vgeom.vvert_start for vgeom in vgeoms], dtype=gs.np_int),
vnormals=np.concatenate([vgeom.init_vnormals for vgeom in vgeoms], dtype=gs.np_float),
vverts_vgeom_idx=np.concatenate(
[np.full(vgeom.n_vverts, vgeom.idx) for vgeom in vgeoms], dtype=gs.np_int
),
vverts_info=self.vverts_info,
vfaces_info=self.vfaces_info,
static_rigid_sim_config=self._static_rigid_sim_config,
)
if self.n_vverts == 0:
return

vgeoms = self.vgeoms
vverts = np.concatenate([vg.init_vverts for vg in vgeoms], dtype=gs.np_float)
vnormals = np.concatenate([vg.init_vnormals for vg in vgeoms], dtype=gs.np_float)
vfaces = np.concatenate([vg.init_vfaces + vg._vvert_start for vg in vgeoms], dtype=gs.np_int)
vverts_vgeom_idx = np.concatenate([np.full(vg.n_vverts, vg.idx) for vg in vgeoms], dtype=gs.np_int)
vverts_state_idx = np.full(self.n_vverts, -1, dtype=gs.np_int)
for entity in self._entities:
if not entity._morph.enable_custom_vverts:
continue
entity_custom_offset = entity._custom_vvert_start - entity._vvert_start
for vgeom in entity.vgeoms:
local = np.arange(vgeom.n_vverts, dtype=gs.np_int)
vverts_state_idx[vgeom._vvert_start + local] = vgeom._vvert_start + entity_custom_offset + local
kernel_init_vvert_fields(
vverts=vverts,
vfaces=vfaces,
vnormals=vnormals,
vverts_vgeom_idx=vverts_vgeom_idx,
vverts_state_idx=vverts_state_idx,
vverts_info=self.vverts_info,
vfaces_info=self.vfaces_info,
static_rigid_sim_config=self._static_rigid_sim_config,
)

def _init_vgeom_fields(self):
self.vgeoms_info = self.data_manager.vgeoms_info
Expand Down Expand Up @@ -1015,6 +1052,64 @@ def get_dofs_limit(self, dofs_idx=None, envs_idx=None):
def update_vgeoms(self):
kernel_update_vgeoms(self.vgeoms_info, self.vgeoms_state, self.links_state, self._static_rigid_sim_config)

def update_vverts_for_vgeoms(self, vgeoms_idx):
"""Refresh the vverts_state.pos slice for the requested vgeoms by re-running FK.

Used by set_vverts(None, ...) and at scene-build time to initialize the custom buffer with a meaningful pose.
The kernel is a no-op for vgeoms whose vverts have no state slot.
"""
if self.n_custom_vverts == 0:
return
kernel_update_vverts_for_vgeoms(
vgeoms_idx,
self.vgeoms_info,
self.vgeoms_state,
self.vverts_info,
self.vverts_state,
self._static_rigid_sim_config,
)

def set_vverts(self, custom_vvert_start, custom_vvert_end, vgeoms_idx, vverts, envs_idx=None):
"""Write the slice [custom_vvert_start:custom_vvert_end] of vverts_state.pos.

vverts=None re-populates the slice from FK by running update_vverts_for_vgeoms over the vgeoms owning the slice.
Otherwise vverts is broadcast to the slice shape and written directly. vverts_state.pos is always batched;
envs_idx selects which envs to write to.
"""
if vverts is None:
self.update_vverts_for_vgeoms(vgeoms_idx)
return

if gs.use_zerocopy:
data = qd_to_torch(self.vverts_state.pos, transpose=True, copy=False)
if isinstance(envs_idx, torch.Tensor) and envs_idx.dtype == torch.bool:
pos_slice = data[:, custom_vvert_start:custom_vvert_end]
if vverts.ndim == 3 and len(vverts) != len(pos_slice):
pos_slice.masked_scatter_(envs_idx[:, None, None], vverts.view_as(vverts))
else:
vverts_b = broadcast_tensor(vverts, gs.tc_float, pos_slice.shape)
torch.where(envs_idx[:, None, None], vverts_b, pos_slice, out=pos_slice)
else:
vverts_mask = indices_to_mask(slice(custom_vvert_start, custom_vvert_end))
pos_mask = (0, *vverts_mask) if self.n_envs == 0 else indices_to_mask(envs_idx, *vverts_mask)
assign_indexed_tensor(data, pos_mask, vverts)
return

envs_idx = self._scene._sanitize_envs_idx(envs_idx)
target_shape = (envs_idx.shape[0], custom_vvert_end - custom_vvert_start, 3)
vverts = broadcast_tensor(vverts, gs.tc_float, target_shape, ("envs", "vverts", "xyz")).contiguous()
kernel_set_vverts(vverts, custom_vvert_start, envs_idx, self.vverts_state, self._static_rigid_sim_config)

def get_vverts(self, custom_vvert_start, custom_vvert_end, envs_idx=None):
"""Return a copy of the vverts_state.pos slice for the given custom-vvert range.

Shape: (len(envs_idx), custom_vvert_end - custom_vvert_start, 3). envs_idx=None returns every env.
"""
tensor = qd_to_torch(
self.vverts_state.pos, envs_idx, slice(custom_vvert_start, custom_vvert_end), transpose=True, copy=True
)
return tensor[0] if self.n_envs == 0 else tensor

# ------------------------------------------------------------------------------------
# ----------------------------------- properties -------------------------------------
# ------------------------------------------------------------------------------------
Expand Down Expand Up @@ -1073,6 +1168,18 @@ def n_vfaces(self):
return self._n_vfaces
return sum(entity.n_vfaces for entity in self._entities)

@property
def n_custom_vverts(self):
if self.is_built:
return self._n_custom_vverts
return sum(entity.n_vverts for entity in self._entities if entity._morph.enable_custom_vverts)

@property
def n_custom_vfaces(self):
if self.is_built:
return self._n_custom_vfaces
return sum(entity.n_vfaces for entity in self._entities if entity._morph.enable_custom_vverts)

@property
def n_qs(self):
if self.is_built:
Expand Down
19 changes: 19 additions & 0 deletions genesis/engine/solvers/rigid/abd/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,3 +1095,22 @@ def kernel_set_geoms_friction(
qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL))
for i_g_ in range(geoms_idx.shape[0]):
geoms_info.friction[geoms_idx[i_g_]] = friction[i_g_]


@qd.kernel(fastcache=True)
def kernel_set_vverts(
vverts: qd.types.ndarray(),
vvert_start: qd.i32,
envs_idx: qd.types.ndarray(),
vverts_state: array_class.VVertsState,
static_rigid_sim_config: qd.template(),
):
n_envs_in = envs_idx.shape[0]
n_vverts_in = vverts.shape[1]

qd.loop_config(serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL))
for i_b_, i_vv_ in qd.ndrange(n_envs_in, n_vverts_in):
i_b = envs_idx[i_b_]
i_vv = vvert_start + i_vv_
for j in qd.static(range(3)):
vverts_state.pos[i_vv, i_b][j] = vverts[i_b_, i_vv_, j]
Loading
Loading