Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 22 additions & 27 deletions genesis/engine/solvers/rigid/constraint/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
import genesis.utils.array_class as array_class
import genesis.utils.geom as gu
from genesis.engine.solvers.rigid.abd import func_solve_mass_batch
from genesis.utils._tile16 import Tile16x16Cholesky
from genesis.utils._tile32 import Tile32x32Cholesky
from genesis.utils.misc import qd_to_torch, indices_to_mask, assign_indexed_tensor

from ..collider.contact_island import ContactIsland
Expand Down Expand Up @@ -1912,17 +1910,14 @@ def _cholesky_factor_direct_tiled_impl(

# Load diagonal tile H[k,k] (rows beyond n_dofs stay as identity from the .eye() init)
L_kk = TileCls.eye(dtype=gs.qd_float)
# FIXME: migrate back to using slice index, i.e. L_kk[:] = constraint_state.nt_H[i_b, k0:k1, k0:k1]
# and similar.
# We'll do this once we move _tile16.py changes back into Quadrants.
L_kk._load3d(constraint_state.nt_H, i_b, k0, k1, k0, k1)
L_kk[:] = constraint_state.nt_H[i_b, k0:k1, k0:k1]

# Subtract prior-column contributions: L_kk -= sum_j L[k,j] @ L[k,j]^T
for jb in range(kb):
j0 = jb * T
for t in range(T):
v = L_kk._resolve_vec3d(constraint_state.nt_H, i_b, k0, k1, j0 + t)
L_kk._ger_sub(v, v)
v = constraint_state.nt_H[i_b, k0:k1, j0 + t]
L_kk -= qd.outer(v, v)

# Factor diagonal tile in-place
L_kk.cholesky_(EPS)
Expand All @@ -1934,24 +1929,24 @@ def _cholesky_factor_direct_tiled_impl(

# Load off-diagonal tile H[i,k] (rows beyond n_dofs stay as zero from the .zeros() init)
L_ik = TileCls.zeros(dtype=gs.qd_float)
L_ik._load3d(constraint_state.nt_H, i_b, i0, i1, k0, k1)
L_ik[:] = constraint_state.nt_H[i_b, i0:i1, k0:k1]

# Subtract prior-column contributions: L_ik -= sum_j L[i,j] @ L[k,j]^T
for jb in range(kb):
j0 = jb * T
for t in range(T):
v_own = L_ik._resolve_vec3d(constraint_state.nt_H, i_b, i0, i1, j0 + t)
v_diag = L_ik._resolve_vec3d(constraint_state.nt_H, i_b, k0, k1, j0 + t)
L_ik._ger_sub(v_own, v_diag)
v_own = constraint_state.nt_H[i_b, i0:i1, j0 + t]
v_diag = constraint_state.nt_H[i_b, k0:k1, j0 + t]
L_ik -= qd.outer(v_own, v_diag)

# Triangular solve: L[i,k] = L_ik @ inv(L[k,k]^T)
L_kk.solve_triangular_(L_ik)

# Write L[i,k] back to global memory
L_ik._store3d(constraint_state.nt_H, i_b, i0, i1, k0, k1)
constraint_state.nt_H[i_b, i0:i1, k0:k1] = L_ik

# Write L[k,k] back to global memory
L_kk._store3d(constraint_state.nt_H, i_b, k0, k1, k0, k1)
constraint_state.nt_H[i_b, k0:k1, k0:k1] = L_kk


@qd.func
Expand Down Expand Up @@ -2008,14 +2003,14 @@ def _cholesky_and_solve_fused_tiled_impl(

# Load diagonal tile H[k,k] (rows beyond n_dofs stay as identity from the .eye() init)
L_kk = TileCls.eye(dtype=gs.qd_float)
L_kk._load3d(constraint_state.nt_H, i_b, k0, k1, k0, k1)
L_kk[:] = constraint_state.nt_H[i_b, k0:k1, k0:k1]

# Subtract prior-column contributions from shared memory
for jb in range(kb):
j0 = jb * T
for t in range(T):
v = L_kk._resolve_vec2d(L_sh, k0, k1, j0 + t)
L_kk._ger_sub(v, v)
v = L_sh[k0:k1, j0 + t]
L_kk -= qd.outer(v, v)

# Factor diagonal tile in-place
L_kk.cholesky_(EPS)
Expand All @@ -2027,24 +2022,24 @@ def _cholesky_and_solve_fused_tiled_impl(

# Load off-diagonal tile H[i,k] (rows beyond n_dofs stay as zero from the .zeros() init)
L_ik = TileCls.zeros(dtype=gs.qd_float)
L_ik._load3d(constraint_state.nt_H, i_b, i0, i1, k0, k1)
L_ik[:] = constraint_state.nt_H[i_b, i0:i1, k0:k1]

# Subtract prior-column contributions from shared memory
for jb in range(kb):
j0 = jb * T
for t in range(T):
v_own = L_ik._resolve_vec2d(L_sh, i0, i1, j0 + t)
v_diag = L_ik._resolve_vec2d(L_sh, k0, k1, j0 + t)
L_ik._ger_sub(v_own, v_diag)
v_own = L_sh[i0:i1, j0 + t]
v_diag = L_sh[k0:k1, j0 + t]
L_ik -= qd.outer(v_own, v_diag)

# Triangular solve: L[i,k] = L_ik @ inv(L[k,k]^T)
L_kk.solve_triangular_(L_ik)

# Write L[i,k] to shared memory
L_ik._store(L_sh, i0, i1, k0, k1)
L_sh[i0:i1, k0:k1] = L_ik

# Write L[k,k] to shared memory
L_kk._store(L_sh, k0, k1, k0, k1)
L_sh[k0:k1, k0:k1] = L_kk

# --- Scalar triangular solve using L from shared memory ---
# No longer using TxT tiles; the T threads parallelize each row's dot product by striping across columns,
Expand Down Expand Up @@ -2112,11 +2107,11 @@ def func_cholesky_factor_direct_tiled(
"""Tile-size dispatcher; see _cholesky_factor_direct_tiled_impl for the algorithm and dispatch rule."""
if qd.static(static_rigid_sim_config.cholesky_tile_size == 32):
_cholesky_factor_direct_tiled_impl(
constraint_state, rigid_global_info, static_rigid_sim_config, Tile32x32Cholesky
constraint_state, rigid_global_info, static_rigid_sim_config, qd.simt.Tile32x32
)
else:
_cholesky_factor_direct_tiled_impl(
constraint_state, rigid_global_info, static_rigid_sim_config, Tile16x16Cholesky
constraint_state, rigid_global_info, static_rigid_sim_config, qd.simt.Tile16x16
)


Expand All @@ -2130,11 +2125,11 @@ def func_cholesky_and_solve_fused_tiled(
"""Tile-size dispatcher; see _cholesky_and_solve_fused_tiled_impl for the algorithm and dispatch rule."""
if qd.static(static_rigid_sim_config.cholesky_tile_size == 32):
_cholesky_and_solve_fused_tiled_impl(
constraint_state, rigid_global_info, static_rigid_sim_config, Tile32x32Cholesky, write_L_to_nt_H
constraint_state, rigid_global_info, static_rigid_sim_config, qd.simt.Tile32x32, write_L_to_nt_H
)
else:
_cholesky_and_solve_fused_tiled_impl(
constraint_state, rigid_global_info, static_rigid_sim_config, Tile16x16Cholesky, write_L_to_nt_H
constraint_state, rigid_global_info, static_rigid_sim_config, qd.simt.Tile16x16, write_L_to_nt_H
)


Expand Down
Loading
Loading