Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 5 additions & 24 deletions genesis/engine/solvers/rigid/collider/contact.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,8 @@ def func_clamp_prune_and_sort_contacts_coop(
LP_KEY_STRIDE = gs.qd_float(1.0e7)
EPS = rigid_global_info.EPS[None]

_K = qd.static(32)
_K = 32
_LOG2_K = _K.bit_length() - 1 # = log2(_K), assuming _K is a power of two.
qd.loop_config(name="clamp_prune_and_sort_contacts_coop", block_dim=_K)
for i_flat in range(_B * _K):
tid = i_flat % _K
Expand Down Expand Up @@ -980,9 +981,8 @@ def func_clamp_prune_and_sort_contacts_coop(
)
ii += _K

# Phase 1a sort: parallel bitonic sort across 32 lanes when n_con <= 32; fall back to serial-on-lane-0
# insertion sort otherwise. Bitonic is 15 compare-exchange stages (k=2..32, j=k/2..1), each a single
# subgroup shuffle + compare, replacing the O(n^2/2) lane-0 insertion sort.
# Phase 1a sort: bitonic sort across _K lanes when n_con <= _K, serial-on-lane-0 insertion sort
# otherwise.
if n_con <= _K:
# Load with sentinel for out-of-range lanes (pushes them to the end of ascending sort).
my_key = qd.cast(gs.qd_float(1.0e30), gs.qd_float)
Expand All @@ -991,26 +991,7 @@ def func_clamp_prune_and_sort_contacts_coop(
my_key = collider_state.contact_sort_key[tid, i_b]
my_idx = collider_state.contact_sort_idx[tid, i_b]

# 15 bitonic stages: (k, j) pairs walking the standard schedule. Stable compare (tiebreak on idx).
for k_log2 in qd.static(range(1, 6)):
k_mask = qd.static(1 << k_log2)
for j_log2 in qd.static(range(k_log2 - 1, -1, -1)):
j = qd.static(1 << j_log2)
partner = qd.u32(tid ^ j)
their_key = qd.simt.subgroup.shuffle(my_key, partner)
their_idx = qd.simt.subgroup.shuffle(my_idx, partner)
i_am_low = (tid & j) == 0
asc = (tid & k_mask) == 0
take_min = i_am_low == asc
their_lt_mine = (their_key < my_key) or (their_key == my_key and their_idx < my_idx)
if take_min:
if their_lt_mine:
my_key = their_key
my_idx = their_idx
else:
if not their_lt_mine and (their_key != my_key or their_idx != my_idx):
my_key = their_key
my_idx = their_idx
my_key, my_idx = qd.simt.subgroup.bitonic_sort_kv_tiled(my_key, my_idx, _LOG2_K)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have an API that does not require specifying _LOG2_K? Like this could be the default value, because I guess block_dim can be queried from context no? if not, then we should probably expose it. It should not be that hard, since it is already possible to set it from python.


# Write back the sorted values for the real range.
if tid < n_con:
Expand Down
Loading