diff --git a/genesis/engine/solvers/rigid/constraint/solver.py b/genesis/engine/solvers/rigid/constraint/solver.py index 0283f7c20..6c2b981fb 100644 --- a/genesis/engine/solvers/rigid/constraint/solver.py +++ b/genesis/engine/solvers/rigid/constraint/solver.py @@ -3327,20 +3327,12 @@ def _func_update_efc_force( len_constraints = constraint_state.active.shape[0] _B = constraint_state.grad.shape[1] - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - qd.loop_config( - name="update_constraint_forces", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL - ) - for i_b, i_c in qd.ndrange(_B, len_constraints): - if i_c < constraint_state.n_constraints[i_b]: - _func_update_efc_force_body(i_c, i_b, constraint_state, static_rigid_sim_config) - else: - qd.loop_config( - name="update_constraint_forces", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL - ) - for i_c, i_b in qd.ndrange(len_constraints, _B): - if i_c < constraint_state.n_constraints[i_b]: - _func_update_efc_force_body(i_c, i_b, constraint_state, static_rigid_sim_config) + qd.loop_config(name="update_constraint_forces", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_c, i_b in qd.ndrange( + len_constraints, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + if i_c < constraint_state.n_constraints[i_b]: + _func_update_efc_force_body(i_c, i_b, constraint_state, static_rigid_sim_config) @qd.func @@ -3547,18 +3539,13 @@ def func_update_gradient_tiled( # Compute Mgrad = H^{-1} @ grad, s.t. grad = M @ acc - q_force_ext - q_force_const. # Under the DOF-vec flip, 3 of 4 in-loop accesses (grad, Ma, qfrc_constraint) are flipped and one (dofs_state.force) # is canonical — swap the ndrange so adjacent lanes vary i_d. - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - qd.loop_config(name="update_gradient_tiled", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b, i_d in qd.ndrange(_B, n_dofs): - constraint_state.grad[i_d, i_b] = ( - constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b] - ) - else: - qd.loop_config(name="update_gradient_tiled", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in qd.ndrange(n_dofs, _B): - constraint_state.grad[i_d, i_b] = ( - constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b] - ) + qd.loop_config(name="update_gradient_tiled", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange( + n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + constraint_state.grad[i_d, i_b] = ( + constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b] + ) if qd.static(static_rigid_sim_config.solver_type == gs.constraint_solver.CG): qd.loop_config( @@ -3750,18 +3737,14 @@ def _initialize_Jaref_parallel( n_dofs = constraint_state.jac.shape[1] len_constraints = constraint_state.Jaref.shape[0] - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - # i_c innermost: matches stride-1 axis of flipped jac, jac loads coalesce. - for i_b, i_c in qd.ndrange(_B, len_constraints): - if i_c < constraint_state.n_constraints[i_b]: - _initialize_Jaref_body(i_c, i_b, n_dofs, qacc, constraint_state, static_rigid_sim_config) - else: - qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - # i_b innermost: matches stride-1 axis of canonical jac, jac loads coalesce. - for i_c, i_b in qd.ndrange(len_constraints, _B): - if i_c < constraint_state.n_constraints[i_b]: - _initialize_Jaref_body(i_c, i_b, n_dofs, qacc, constraint_state, static_rigid_sim_config) + # Innermost ndrange axis matches the stride-1 axis of jac so jac loads coalesce: i_c-innermost under the flipped + # layout, i_b-innermost under canonical. + qd.loop_config(serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_c, i_b in qd.ndrange( + len_constraints, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + if i_c < constraint_state.n_constraints[i_b]: + _initialize_Jaref_body(i_c, i_b, n_dofs, qacc, constraint_state, static_rigid_sim_config) @qd.func @@ -3776,27 +3759,19 @@ def initialize_Ma( _B = rigid_global_info.mass_mat.shape[2] n_dofs = qacc.shape[0] - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - # Flipped mass_mat layout=(2,1,0): physical (_B, n_dofs, n_dofs) with i_d1 stride-1. Make i_d1 the innermost - # ndrange axis so adjacent lanes vary i_d1 -> coalesced reads of mass_mat[i_d1, i_d2, i_b]. qacc[i_d2, i_b] is - # constant within the warp -> broadcast load. - qd.loop_config(name="init_ma", serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) - for i_b, i_d1 in qd.ndrange(_B, n_dofs): - I_d1 = [i_d1, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d1 - i_e = dofs_info.entity_idx[I_d1] - Ma_ = gs.qd_float(0.0) - for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - Ma_ = Ma_ + rigid_global_info.mass_mat[i_d1, i_d2, i_b] * qacc[i_d2, i_b] - Ma[i_d1, i_b] = Ma_ - else: - qd.loop_config(name="init_ma", serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) - for i_d1, i_b in qd.ndrange(n_dofs, _B): - I_d1 = [i_d1, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d1 - i_e = dofs_info.entity_idx[I_d1] - Ma_ = gs.qd_float(0.0) - for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): - Ma_ = Ma_ + rigid_global_info.mass_mat[i_d1, i_d2, i_b] * qacc[i_d2, i_b] - Ma[i_d1, i_b] = Ma_ + # Flipped mass_mat layout=(2,1,0): physical (_B, n_dofs, n_dofs) with i_d1 stride-1. Make i_d1 the innermost + # ndrange axis so adjacent lanes vary i_d1 -> coalesced reads of mass_mat[i_d1, i_d2, i_b]. qacc[i_d2, i_b] is + # constant within the warp -> broadcast load. + qd.loop_config(name="init_ma", serialize=qd.static(static_rigid_sim_config.para_level < gs.PARA_LEVEL.PARTIAL)) + for i_d1, i_b in qd.ndrange( + n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + I_d1 = [i_d1, i_b] if qd.static(static_rigid_sim_config.batch_dofs_info) else i_d1 + i_e = dofs_info.entity_idx[I_d1] + Ma_ = gs.qd_float(0.0) + for i_d2 in range(entities_info.dof_start[i_e], entities_info.dof_end[i_e]): + Ma_ = Ma_ + rigid_global_info.mass_mat[i_d1, i_d2, i_b] * qacc[i_d2, i_b] + Ma[i_d1, i_b] = Ma_ # ======================================================= Core ======================================================== @@ -3876,20 +3851,14 @@ def func_solve_init( # Under the DOF-vec flip, both qacc and qacc_ws are env-leading; swap the ndrange so adjacent lanes vary i_d # to coalesce those writes/reads. The dofs_state.acc_smooth read remains canonical (small per-env working # set, dominated by the qacc write). - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - qd.loop_config(name="from_warmstart", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b, i_d in qd.ndrange(_B, n_dofs): - if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]: - constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] - else: - constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] - else: - qd.loop_config(name="from_warmstart", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in qd.ndrange(n_dofs, _B): - if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]: - constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] - else: - constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] + qd.loop_config(name="from_warmstart", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange( + n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.is_warmstart[i_b]: + constraint_state.qacc[i_d, i_b] = constraint_state.qacc_ws[i_d, i_b] + else: + constraint_state.qacc[i_d, i_b] = dofs_state.acc_smooth[i_d, i_b] initialize_Ma( Ma=constraint_state.Ma, @@ -3937,14 +3906,11 @@ def func_solve_init( static_rigid_sim_config=static_rigid_sim_config, ) - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - qd.loop_config(name="assign_search", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_b, i_d in qd.ndrange(_B, n_dofs): - constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] - else: - qd.loop_config(name="assign_search", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) - for i_d, i_b in qd.ndrange(n_dofs, _B): - constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] + qd.loop_config(name="assign_search", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange( + n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + constraint_state.search[i_d, i_b] = -constraint_state.Mgrad[i_d, i_b] @qd.func diff --git a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py index ec8ca8d6d..6fb67544a 100644 --- a/genesis/engine/solvers/rigid/constraint/solver_breakdown.py +++ b/genesis/engine/solvers/rigid/constraint/solver_breakdown.py @@ -562,14 +562,11 @@ def _func_update_constraint_forces( _B = constraint_state.grad.shape[1] qd.loop_config(name="update_constraint_forces") - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - for i_b, i_c in qd.ndrange(_B, len_constraints): - if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]: - _func_update_constraint_forces_body(i_c, i_b, constraint_state, static_rigid_sim_config) - else: - for i_c, i_b in qd.ndrange(len_constraints, _B): - if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]: - _func_update_constraint_forces_body(i_c, i_b, constraint_state, static_rigid_sim_config) + for i_c, i_b in qd.ndrange( + len_constraints, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + if i_c < constraint_state.n_constraints[i_b] and constraint_state.improved[i_b]: + _func_update_constraint_forces_body(i_c, i_b, constraint_state, static_rigid_sim_config) @qd.func @@ -586,24 +583,16 @@ def _func_update_qfrc_constraint_per_dof( n_dofs = constraint_state.qfrc_constraint.shape[0] _B = constraint_state.grad.shape[1] - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - qd.loop_config(name="update_constraint_qfrc") - for i_b, i_d in qd.ndrange(_B, n_dofs): - if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: - n_con = constraint_state.n_constraints[i_b] - qfrc = gs.qd_float(0.0) - for i_c in range(n_con): - qfrc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.efc_force[i_c, i_b] - constraint_state.qfrc_constraint[i_d, i_b] = qfrc - else: - qd.loop_config(name="update_constraint_qfrc") - for i_d, i_b in qd.ndrange(n_dofs, _B): - if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: - n_con = constraint_state.n_constraints[i_b] - qfrc = gs.qd_float(0.0) - for i_c in range(n_con): - qfrc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.efc_force[i_c, i_b] - constraint_state.qfrc_constraint[i_d, i_b] = qfrc + qd.loop_config(name="update_constraint_qfrc") + for i_d, i_b in qd.ndrange( + n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: + n_con = constraint_state.n_constraints[i_b] + qfrc = gs.qd_float(0.0) + for i_c in range(n_con): + qfrc += constraint_state.jac[i_c, i_d, i_b] * constraint_state.efc_force[i_c, i_b] + constraint_state.qfrc_constraint[i_d, i_b] = qfrc @qd.func @@ -911,28 +900,14 @@ def _func_update_gradient_no_solve( """ _B = constraint_state.grad.shape[1] n_dofs = constraint_state.grad.shape[0] - if qd.static(static_rigid_sim_config.constraint_layout_transposed): - qd.loop_config( - name="update_gradient_no_solve", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL - ) - for i_b, i_d in qd.ndrange(_B, n_dofs): - if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: - constraint_state.grad[i_d, i_b] = ( - constraint_state.Ma[i_d, i_b] - - dofs_state.force[i_d, i_b] - - constraint_state.qfrc_constraint[i_d, i_b] - ) - else: - qd.loop_config( - name="update_gradient_no_solve", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL - ) - for i_d, i_b in qd.ndrange(n_dofs, _B): - if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: - constraint_state.grad[i_d, i_b] = ( - constraint_state.Ma[i_d, i_b] - - dofs_state.force[i_d, i_b] - - constraint_state.qfrc_constraint[i_d, i_b] - ) + qd.loop_config(name="update_gradient_no_solve", serialize=static_rigid_sim_config.para_level < gs.PARA_LEVEL.ALL) + for i_d, i_b in qd.ndrange( + n_dofs, _B, axes=qd.static((1, 0) if static_rigid_sim_config.constraint_layout_transposed else None) + ): + if constraint_state.n_constraints[i_b] > 0 and constraint_state.improved[i_b]: + constraint_state.grad[i_d, i_b] = ( + constraint_state.Ma[i_d, i_b] - dofs_state.force[i_d, i_b] - constraint_state.qfrc_constraint[i_d, i_b] + ) @qd.func