Optimize for-loop merging of cohorts. (#378)
* Optimize for-loop merging of cohorts.

Do this by skipping perfect cohorts that we already know about.

* Add new benchmark

* Fix

* Cleanup print statements

* minimize diff

* cleanup

* Update snapshot
dcherian authored Aug 2, 2024
1 parent cb3fc1f commit f8f34b9
Showing 3 changed files with 7,945 additions and 15,876 deletions.
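
For orientation, here is a rough, hypothetical illustration of the "perfect cohorts" mentioned in the commit message (toy labels and chunk indices, not taken from the repository): labels that occur in exactly the same set of chunks already form a cohort, so the for-loop that merges similar cohorts only needs to examine one representative of each such group.

# Toy data: each label mapped to the tuple of chunk indices it occupies.
label_chunks = {
    "jan": (0, 1),
    "feb": (0, 1),  # identical chunks to "jan" -> a "perfect" (exact) cohort with it
    "mar": (2,),
}

# Grouping labels by their chunk tuple recovers the exact cohorts directly,
# so only one member per group needs to enter the merge loop.
exact_cohorts: dict[tuple[int, ...], list[str]] = {}
for label, chunks in label_chunks.items():
    exact_cohorts.setdefault(chunks, []).append(label)

print(exact_cohorts)  # {(0, 1): ['jan', 'feb'], (2,): ['mar']}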
6 changes: 5 additions & 1 deletion asv_bench/benchmarks/cohorts.py
@@ -47,6 +47,9 @@ def time_find_group_cohorts(self):
         except AttributeError:
             pass

+    def track_num_cohorts(self):
+        return len(self.chunks_cohorts())
+
     def time_graph_construct(self):
         flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)

@@ -60,10 +63,11 @@ def track_num_tasks_optimized(self):
     def track_num_layers(self):
         return len(self.result.dask.layers)

+    track_num_cohorts.unit = "cohorts"  # type: ignore[attr-defined] # Lazy
     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_layers.unit = "layers"  # type: ignore[attr-defined] # Lazy
-    for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers]:
+    for f in [track_num_tasks, track_num_tasks_optimized, track_num_layers, track_num_cohorts]:
         f.repeat = 1  # type: ignore[attr-defined] # Lazy
         f.rounds = 1  # type: ignore[attr-defined] # Lazy
         f.number = 1  # type: ignore[attr-defined] # Lazy
32 changes: 24 additions & 8 deletions flox/core.py
Expand Up @@ -403,7 +403,7 @@ def find_group_cohorts(
# Invert the label_chunks mapping so we know which labels occur together.
def invert(x) -> tuple[np.ndarray, ...]:
arr = label_chunks[x]
return tuple(arr)
return tuple(arr.tolist())

chunks_cohorts = tlz.groupby(invert, label_chunks.keys())
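
A rough sketch of why returning tuple(arr.tolist()) is preferable as the groupby key (toy data, not from the repo): tolist() yields plain Python ints, which hash and compare faster than NumPy scalars, and toolz.groupby hashes the key once per label, so the key type matters when there are many labels.

import numpy as np
import toolz as tlz

# Toy mapping from label to the array of chunk indices it occupies.
label_chunks = {"a": np.array([0, 1]), "b": np.array([0, 1]), "c": np.array([2])}

def invert(label):
    # tuple(arr) would produce a tuple of np.int64 scalars; .tolist() gives
    # plain ints, which are cheaper to hash when groupby uses the tuple as a dict key.
    return tuple(label_chunks[label].tolist())

print(tlz.groupby(invert, label_chunks.keys()))
# {(0, 1): ['a', 'b'], (2,): ['c']}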

@@ -477,22 +477,37 @@ def invert(x) -> tuple[np.ndarray, ...]:
             containment.nnz / math.prod(containment.shape)
         )
     )
-    # Use a threshold to force some merging. We do not use the filtered
-    # containment matrix for estimating "sparsity" because it is a bit
-    # hard to reason about.

+    # Next we for-loop over groups and merge those that are quite similar.
+    # Use a threshold on containment to always force some merging.
+    # Note that we do not use the filtered containment matrix for estimating "sparsity"
+    # because it is a bit hard to reason about.
     MIN_CONTAINMENT = 0.75  # arbitrary
     mask = containment.data < MIN_CONTAINMENT

+    # Now we also know "exact cohorts" -- cohorts whose constituent groups
+    # occur in exactly the same chunks. We only need examine one member of each group.
+    # Skip the others by first looping over the exact cohorts, and zero out those rows.
+    repeated = np.concatenate([v[1:] for v in chunks_cohorts.values()]).astype(int)
+    repeated_idx = np.searchsorted(present_labels, repeated)
+    for i in repeated_idx:
+        mask[containment.indptr[i] : containment.indptr[i + 1]] = True
     containment.data[mask] = 0
     containment.eliminate_zeros()

-    # Iterate over labels, beginning with those with most chunks
+    # Figure out all the labels we need to loop over later
+    n_overlapping_labels = containment.astype(bool).sum(axis=1)
+    order = np.argsort(n_overlapping_labels, kind="stable")[::-1]
+    # Order is such that we iterate over labels, beginning with those with most overlaps
+    # Also filter out any "exact" cohorts
+    order = order[n_overlapping_labels[order] > 0]

     logger.debug("find_group_cohorts: merging cohorts")
-    order = np.argsort(containment.sum(axis=LABEL_AXIS), kind="stable")[::-1]
     merged_cohorts = {}
     merged_keys = set()
-    # TODO: we can optimize this to loop over chunk_cohorts instead
-    # by zeroing out rows that are already in a cohort
     for rowidx in order:
         if present_labels[rowidx] in merged_keys:
             continue
         cohidx = containment.indices[
             slice(containment.indptr[rowidx], containment.indptr[rowidx + 1])
         ]
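
The row-zeroing added above relies on the CSR layout: the stored entries of row i live in containment.data[containment.indptr[i] : containment.indptr[i + 1]], so masking that slice drops the whole row without converting the sparse matrix to dense. Rows emptied this way then fall out of the merge loop because their boolean row sums are zero. A small sketch with assumed toy data (scipy required; 0.75 plays the role of MIN_CONTAINMENT):

import numpy as np
import scipy.sparse

# Toy 3x2 containment matrix; row 1 stands in for a label that belongs to an
# exact cohort already accounted for, so its row can be dropped up front.
containment = scipy.sparse.csr_matrix(np.array([[0.9, 0.2], [0.8, 0.8], [0.1, 1.0]]))

mask = containment.data < 0.75  # same role as MIN_CONTAINMENT
for i in [1]:  # stand-in for repeated_idx from the searchsorted step
    mask[containment.indptr[i] : containment.indptr[i + 1]] = True

containment.data[mask] = 0
containment.eliminate_zeros()
print(containment.toarray())
# [[0.9 0. ]
#  [0.  0. ]
#  [0.  1. ]]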
@@ -507,6 +522,7 @@ def invert(x) -> tuple[np.ndarray, ...]:

     actual_ngroups = np.concatenate(tuple(merged_cohorts.values())).size
     expected_ngroups = present_labels.size
+    assert len(merged_keys) == actual_ngroups
     assert expected_ngroups == actual_ngroups, (expected_ngroups, actual_ngroups)

     # sort by first label in cohort
(Diff for the third changed file, the updated snapshot, is not shown.)
