Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 47 additions & 81 deletions genesis/engine/solvers/rigid/constraint/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3327,20 +3327,12 @@ def _func_update_efc_force(
len_constraints = constraint_state.active.shape[0]
_B = constraint_state.grad.shape[1]

if qd.static(static_rigid_sim_config.constraint_layout_transposed):
qd.loop_config(
name="update_constraint_forces", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL
)
for i_b, i_c in qd.ndrange(_B, len_constraints):
if i_c < constraint_state.n_constraints[i_b]:
_func_update_efc_force_body(i_c, i_b, constraint_state, static_rigid_sim_config)
else:
qd.loop_config(
name="update_constraint_forces", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL
)
for i_c, i_b in qd.ndrange(len_constraints, _B):
if i_c < constraint_state.n_constraints[i_b]:
_func_update_efc_force_body(i_c, i_b, constraint_state, static_rigid_sim_config)
qd.loop_config(name="update_constraint_forces", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_c, i_b in qd.ndrange(
len_constraints, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
Comment on lines +3331 to +3332
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid using unreleased ndrange axes API

In environments created from this repo's current dependency spec, `quadrants` is pinned to 0.8.0, but the new `qd.ndrange(..., axes=...)` calls require the unreleased Quadrants API noted in the commit message. With the pinned wheel, any rigid constraint kernel that reaches this helper will fail during Quadrants parsing/compilation with an unexpected `axes` keyword instead of running. Either bump the dependency to a released version that supports `axes`, or keep the previous explicit branches until that release is available.

Useful? React with 👍 / 👎.

):
if i_c < constraint_state.n_constraints[i_b]:
_func_update_efc_force_body(i_c, i_b, constraint_state, static_rigid_sim_config)


@qd.func
Expand Down Expand Up @@ -3547,18 +3539,13 @@ def func_update_gradient_tiled(
# Compute Mgrad = H^{-1} @ grad, s.t. grad = M @ acc - q_force_ext - q_force_const.
# Under the DOF-vec flip, 3 of 4 in-loop accesses (grad, Ma, qfrc_constraint) are flipped and one (dofs_state.force)
# is canonical — swap the ndrange so adjacent lanes vary i_d.
if qd.static(static_rigid_sim_config.constraint_layout_transposed):
qd.loop_config(name="update_gradient_tiled", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_b, i_d in qd.ndrange(_B, n_dofs):
constraint_state.grad[i_d, i_b] = (
constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b]
)
else:
qd.loop_config(name="update_gradient_tiled", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in qd.ndrange(n_dofs, _B):
constraint_state.grad[i_d, i_b] = (
constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b]
)
qd.loop_config(name="update_gradient_tiled", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in qd.ndrange(
n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
):
constraint_state.grad[i_d, i_b] = (
constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b]
)

if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.CG):
qd.loop_config(
Expand Down Expand Up @@ -3750,18 +3737,14 @@ def _initialize_Jaref_parallel(
n_dofs = constraint_state.jac.shape[1]
len_constraints = constraint_state.Jaref.shape[0]

if qd.static(static_rigid_sim_config.constraint_layout_transposed):
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
# i_c innermost: matches stride-1 axis of flipped jac, jac loads coalesce.
for i_b, i_c in qd.ndrange(_B, len_constraints):
if i_c < constraint_state.n_constraints[i_b]:
_initialize_Jaref_body(i_c, i_b, n_dofs, qacc, constraint_state, static_rigid_sim_config)
else:
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
# i_b innermost: matches stride-1 axis of canonical jac, jac loads coalesce.
for i_c, i_b in qd.ndrange(len_constraints, _B):
if i_c < constraint_state.n_constraints[i_b]:
_initialize_Jaref_body(i_c, i_b, n_dofs, qacc, constraint_state, static_rigid_sim_config)
# Innermost ndrange axis matches the stride-1 axis of jac so jac loads coalesce: i_c-innermost under the flipped
# layout, i_b-innermost under canonical.
qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_c, i_b in qd.ndrange(
len_constraints, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
):
if i_c < constraint_state.n_constraints[i_b]:
_initialize_Jaref_body(i_c, i_b, n_dofs, qacc, constraint_state, static_rigid_sim_config)


@qd.func
Expand All @@ -3776,27 +3759,19 @@ def initialize_Ma(
_B = rigid_global_info.mass_mat.shape[2]
n_dofs = qacc.shape[0]

if qd.static(static_rigid_sim_config.constraint_layout_transposed):
# Flipped mass_mat layout=(2,1,0): physical (_B, n_dofs, n_dofs) with i_d1 stride-1. Make i_d1 the innermost
# ndrange axis so adjacent lanes vary i_d1 -> coalesced reads of mass_mat[i_d1, i_d2, i_b]. qacc[i_d2, i_b] is
# constant within the warp -> broadcast load.
qd.loop_config(name="init_ma", serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL))
for i_b, i_d1 in qd.ndrange(_B, n_dofs):
I_d1 = [i_d1, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d1
i_e = dofs_info.entity_idx[I_d1]
Ma_ = gs.qd_float(0.0)
for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]):
Ma_ = Ma_ + rigid_global_info.mass_mat[i_d1, i_d2, i_b] * qacc[i_d2, i_b]
Ma[i_d1, i_b] = Ma_
else:
qd.loop_config(name="init_ma", serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL))
for i_d1, i_b in qd.ndrange(n_dofs, _B):
I_d1 = [i_d1, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d1
i_e = dofs_info.entity_idx[I_d1]
Ma_ = gs.qd_float(0.0)
for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]):
Ma_ = Ma_ + rigid_global_info.mass_mat[i_d1, i_d2, i_b] * qacc[i_d2, i_b]
Ma[i_d1, i_b] = Ma_
# Under the flipped layout, mass_mat has layout=(2,1,0): physical (_B, n_dofs, n_dofs) with i_d1 stride-1.
# axes=(1, 0) then makes i_d1 the innermost ndrange axis so adjacent lanes vary i_d1 -> coalesced reads of
# mass_mat[i_d1, i_d2, i_b]; qacc[i_d2, i_b] is constant within the warp -> broadcast load.
# Under the canonical layout the default axis order is kept, so i_b stays innermost.
qd.loop_config(name="init_ma", serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL))
for i_d1, i_b in qd.ndrange(
n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
):
I_d1 = [i_d1, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d1
i_e = dofs_info.entity_idx[I_d1]
Ma_ = gs.qd_float(0.0)
for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]):
Ma_ = Ma_ + rigid_global_info.mass_mat[i_d1, i_d2, i_b] * qacc[i_d2, i_b]
Ma[i_d1, i_b] = Ma_


# ======================================================= Core ========================================================
Expand Down Expand Up @@ -3876,20 +3851,14 @@ def func_solve_init(
# Under the DOF-vec flip, both qacc and qacc_ws are env-leading; swap the ndrange so adjacent lanes vary i_d
# to coalesce those writes/reads. The dofs_state.acc_smooth read remains canonical (small per-env working
# set, dominated by the qacc write).
if qd.static(static_rigid_sim_config.constraint_layout_transposed):
qd.loop_config(name="from_warmstart", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_b, i_d in qd.ndrange(_B, n_dofs):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]:
constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b]
else:
constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b]
else:
qd.loop_config(name="from_warmstart", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in qd.ndrange(n_dofs, _B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]:
constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b]
else:
constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b]
qd.loop_config(name="from_warmstart", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in qd.ndrange(
n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]:
constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b]
else:
constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b]

initialize_Ma(
Ma=constraint_state.Ma,
Expand Down Expand Up @@ -3937,14 +3906,11 @@ def func_solve_init(
static_rigid_sim_config=static_rigid_sim_config,
)

if qd.static(static_rigid_sim_config.constraint_layout_transposed):
qd.loop_config(name="assign_search", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_b, i_d in qd.ndrange(_B, n_dofs):
constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b]
else:
qd.loop_config(name="assign_search", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in qd.ndrange(n_dofs, _B):
constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b]
qd.loop_config(name="assign_search", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in qd.ndrange(
n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
):
constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b]


@qd.func
Expand Down
71 changes: 23 additions & 48 deletions genesis/engine/solvers/rigid/constraint/solver_breakdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,14 +562,11 @@ def _func_update_constraint_forces(
_B = constraint_state.grad.shape[1]

qd.loop_config(name="update_constraint_forces")
if qd.static(static_rigid_sim_config.constraint_layout_transposed):
for i_b, i_c in qd.ndrange(_B, len_constraints):
if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]:
_func_update_constraint_forces_body(i_c, i_b, constraint_state, static_rigid_sim_config)
else:
for i_c, i_b in qd.ndrange(len_constraints, _B):
if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]:
_func_update_constraint_forces_body(i_c, i_b, constraint_state, static_rigid_sim_config)
for i_c, i_b in qd.ndrange(
len_constraints, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
):
if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]:
_func_update_constraint_forces_body(i_c, i_b, constraint_state, static_rigid_sim_config)


@qd.func
Expand All @@ -586,24 +583,16 @@ def _func_update_qfrc_constraint_per_dof(
n_dofs = constraint_state.qfrc_constraint.shape[0]
_B = constraint_state.grad.shape[1]

if qd.static(static_rigid_sim_config.constraint_layout_transposed):
qd.loop_config(name="update_constraint_qfrc")
for i_b, i_d in qd.ndrange(_B, n_dofs):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
n_con = constraint_state.n_constraints[i_b]
qfrc = gs.qd_float(0.0)
for i_c in range(n_con):
qfrc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.efc_force[i_c, i_b]
constraint_state.qfrc_constraint[i_d, i_b] = qfrc
else:
qd.loop_config(name="update_constraint_qfrc")
for i_d, i_b in qd.ndrange(n_dofs, _B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
n_con = constraint_state.n_constraints[i_b]
qfrc = gs.qd_float(0.0)
for i_c in range(n_con):
qfrc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.efc_force[i_c, i_b]
constraint_state.qfrc_constraint[i_d, i_b] = qfrc
qd.loop_config(name="update_constraint_qfrc")
for i_d, i_b in qd.ndrange(
n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
n_con = constraint_state.n_constraints[i_b]
qfrc = gs.qd_float(0.0)
for i_c in range(n_con):
qfrc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.efc_force[i_c, i_b]
constraint_state.qfrc_constraint[i_d, i_b] = qfrc


@qd.func
Expand Down Expand Up @@ -911,28 +900,14 @@ def _func_update_gradient_no_solve(
"""
_B = constraint_state.grad.shape[1]
n_dofs = constraint_state.grad.shape[0]
if qd.static(static_rigid_sim_config.constraint_layout_transposed):
qd.loop_config(
name="update_gradient_no_solve", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL
)
for i_b, i_d in qd.ndrange(_B, n_dofs):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
constraint_state.grad[i_d, i_b] = (
constraint_state.Ma[i_d, i_b]
- dofs_state.force[i_d, i_b]
- constraint_state.qfrc_constraint[i_d, i_b]
)
else:
qd.loop_config(
name="update_gradient_no_solve", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL
)
for i_d, i_b in qd.ndrange(n_dofs, _B):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
constraint_state.grad[i_d, i_b] = (
constraint_state.Ma[i_d, i_b]
- dofs_state.force[i_d, i_b]
- constraint_state.qfrc_constraint[i_d, i_b]
)
qd.loop_config(name="update_gradient_no_solve", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL)
for i_d, i_b in qd.ndrange(
n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None)
):
if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]:
constraint_state.grad[i_d, i_b] = (
constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b]
)


@qd.func
Expand Down
Loading