Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 42 additions & 37 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,48 +1111,53 @@ def to_reordered_spike_vector(
key = str(lexsort)

if key not in self._cached_lexsorted_spike_vector.keys():
spikes = self.to_spike_vector()
order = np.lexsort((spikes[lexsort[0]], spikes[lexsort[1]], spikes[lexsort[2]]))
ordered_spikes = spikes[order]
self._cached_lexsorted_spike_vector[key] = {}
self._cached_lexsorted_spike_vector[key]["ordered_spikes"] = ordered_spikes
self._cached_lexsorted_spike_vector[key]["order"] = order
from .sorting_tools import reorder_spike_vector_by_buckets

spikes = self.to_spike_vector()
num_units = len(self.unit_ids)
num_segments = self.get_num_segments()

# precompute the slices with nested search sorted
# Both supported `lexsort` keys are equivalent to grouping spikes into
# `num_units * num_segments` "buckets" while **preserving** ascending
# `sample_index`` within each bucket.
# (Within each bucket, samples are **already** sorted!)
#
# We can do this with a (stable) counting sort in O(N), if we know the
# bucket index for each spike. That's easy: the bucket index is just a
# straightforward linear combination of the `unit_index` and `segment_index`
# fields, with the order depending on the lexsort` key.
if lexsort == ("sample_index", "segment_index", "unit_index"):
# this case make spiketrain per unit compact in memory

slices = np.zeros((num_units, num_segments, 2), dtype=np.int64)
unit_slices = np.searchsorted(ordered_spikes["unit_index"], np.arange(num_units + 1), side="left")
for unit_index, unit_id in enumerate(self.unit_ids):
u0 = unit_slices[unit_index]
u1 = unit_slices[unit_index + 1]
seg_slices = np.searchsorted(
ordered_spikes[u0:u1]["segment_index"], np.arange(num_segments + 1), side="left"
)
for segment_index in range(num_segments):
s0 = seg_slices[segment_index]
s1 = seg_slices[segment_index + 1]
slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1]

elif lexsort == ("sample_index", "unit_index", "segment_index"):
slices = np.zeros((num_segments, num_units, 2), dtype=np.int64)
seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left")
for segment_index in range(self.get_num_segments()):
s0 = seg_slices[segment_index]
s1 = seg_slices[segment_index + 1]
unit_slices = np.searchsorted(
ordered_spikes[s0:s1]["unit_index"], np.arange(num_units + 1), side="left"
)
for unit_index, unit_id in enumerate(self.unit_ids):
u0 = unit_slices[unit_index]
u1 = unit_slices[unit_index + 1]
slices[segment_index, unit_index, :] = [s0 + u0, s0 + u1]

self._cached_lexsorted_spike_vector[key]["slices"] = slices
# primary key unit_index, then segment_index, then sample_index
bucket_index = spikes["unit_index"].astype(np.int64, copy=False) * num_segments + spikes[
"segment_index"
].astype(np.int64, copy=False)
num_buckets = num_units * num_segments
ordered_spikes, order, counts = reorder_spike_vector_by_buckets(spikes, bucket_index, num_buckets)
# counts is laid out as [unit_index * num_segments + segment_index]
counts_2d = counts.reshape(num_units, num_segments)

else: # ("sample_index", "unit_index", "segment_index")
# primary key segment_index, then unit_index, then sample_index
bucket_index = spikes["segment_index"].astype(np.int64, copy=False) * num_units + spikes[
"unit_index"
].astype(np.int64, copy=False)
num_buckets = num_segments * num_units
ordered_spikes, order, counts = reorder_spike_vector_by_buckets(spikes, bucket_index, num_buckets)
# counts is laid out as [segment_index * num_units + unit_index]
counts_2d = counts.reshape(num_segments, num_units)

# Build slices from cumulative counts. Stops are exclusive cumulative sums
# (aka "prefix sums" in the language of counting sort) shifted by one;
# starts are the same prefix without the last element.
ends = np.cumsum(counts_2d.ravel()).reshape(counts_2d.shape)
starts = ends - counts_2d
slices = np.stack([starts, ends], axis=-1).astype(np.int64, copy=False)

self._cached_lexsorted_spike_vector[key] = {
"ordered_spikes": ordered_spikes,
"order": order,
"slices": slices,
}

ordered_spikes = self._cached_lexsorted_spike_vector[key]["ordered_spikes"]
out = (ordered_spikes,)
Expand Down
Loading