From 1ea1d832f7f2aad849495dd8b1b9e3851ec0a87b Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Fri, 17 Apr 2026 12:47:47 -0500 Subject: [PATCH 01/10] Add tests for `sorting_tools.is_spike_vector_sorted()` --- .../core/tests/test_sorting_tools.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 4194f459b3..84b6591582 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -13,6 +13,7 @@ _get_ids_after_merging, generate_unit_ids_for_merge_group, remap_unit_indices_in_vector, + is_spike_vector_sorted, ) from spikeinterface.core.base import minimum_spike_dtype @@ -45,6 +46,32 @@ def test_spike_vector_to_indices(): ) +def test_is_spike_vector_sorted(): + spikes = np.zeros(5, dtype=minimum_spike_dtype) + spikes["segment_index"] = [0, 0, 1, 1, 1] + spikes["sample_index"] = [100, 200, 0, 100, 100] + spikes["unit_index"] = [0, 1, 0, 0, 1] + assert is_spike_vector_sorted(spikes) + + segment_unsorted = spikes.copy() + segment_unsorted["segment_index"] = [0, 1, 0, 1, 1] + segment_unsorted["sample_index"] = [0, 100, 200, 300, 400] + segment_unsorted["unit_index"] = [0, 0, 0, 0, 0] + assert not is_spike_vector_sorted(segment_unsorted) + + sample_unsorted = spikes.copy() + sample_unsorted["segment_index"] = 0 + sample_unsorted["sample_index"] = [0, 100, 50, 200, 300] + sample_unsorted["unit_index"] = [0, 0, 0, 0, 0] + assert not is_spike_vector_sorted(sample_unsorted) + + tie_unsorted = spikes.copy() + tie_unsorted["segment_index"] = 0 + tie_unsorted["sample_index"] = [0, 100, 100, 200, 300] + tie_unsorted["unit_index"] = [0, 1, 0, 0, 0] + assert not is_spike_vector_sorted(tie_unsorted) + + def test_random_spikes_selection(): recording, sorting = generate_ground_truth_recording( durations=[20.0, 10.0], @@ -61,7 +88,6 @@ def test_random_spikes_selection(): random_spikes_indices = random_spikes_selection( sorting, num_samples, method="uniform", max_spikes_per_unit=max_spikes_per_unit, margin_size=None, seed=2205 ) - random_spikes_indices1 = random_spikes_indices spikes = sorting.to_spike_vector() some_spikes = spikes[random_spikes_indices] for unit_index, unit_id in enumerate(sorting.unit_ids): From 98fa0044f51bfcae3b977a8fd87a0c91a056b145 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Fri, 17 Apr 2026 12:54:20 -0500 Subject: [PATCH 02/10] Leverage single-segment nature of BasePhyKilosortSortingExtractor to improve _compute_and_cache_spike_vector() Gains come from: - Avoiding an unnecessary `np.concatenate()` - Dropping `segment_index` from lexsort keys. Another gain, not related to the single-segment thing, is that `np.empty()` can replace `np.zeros()`. --- .../extractors/phykilosortextractors.py | 56 ++++++++--------- .../tests/test_phykilosortextractors.py | 61 +++++++++++++++++-- 2 files changed, 81 insertions(+), 36 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 1530072288..98ca881774 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -17,7 +17,7 @@ from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.core.core_tools import define_function_from_class -from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations +from spikeinterface.postprocessing import ComputeSpikeLocations from probeinterface import read_prb, Probe HAVE_NUMBA = importlib.util.find_spec("numba") is not None @@ -229,44 +229,38 @@ def __init__( self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) def _compute_and_cache_spike_vector(self) -> None: - """Build the spike vector directly from the flat per-segment arrays. + """Build the spike vector directly from the flat single-segment arrays. - Since Phy/Kilosort segments already hold the full spike_times and + Since Phy/Kilosort segment already holds the full spike_times and spike_clusters arrays in memory, we can construct the spike vector in one shot. """ + assert self.get_num_segments() == 1 + unit_ids = np.asarray(self.unit_ids) sorter = np.argsort(unit_ids) sorted_unit_ids = unit_ids[sorter] - num_seg = self.get_num_segments() - spikes_list = [] - segment_slices = np.zeros((num_seg, 2), dtype="int64") - pos = 0 - - for seg_idx in range(num_seg): - seg = self.segments[seg_idx] - all_spikes = seg._all_spikes - all_clusters = seg._all_clusters - - # Map cluster ids -> unit indices. `spike_clusters_clean` is guaranteed - # to only contain ids present in `self.unit_ids` (filtered in __init__), - # so searchsorted always returns a valid position. - unit_indices = sorter[np.searchsorted(sorted_unit_ids, all_clusters)] - - n = all_spikes.size - segment_slices[seg_idx] = [pos, pos + n] - pos += n - - seg_spikes = np.zeros(n, dtype=minimum_spike_dtype) - seg_spikes["sample_index"] = all_spikes - seg_spikes["unit_index"] = unit_indices - seg_spikes["segment_index"] = seg_idx - spikes_list.append(seg_spikes) - - spikes = np.concatenate(spikes_list) if spikes_list else np.zeros(0, dtype=minimum_spike_dtype) - # Canonical order: (segment_index, sample_index, unit_index). - order = np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"])) + seg = self.segments[0] + all_spikes = seg._all_spikes + all_clusters = seg._all_clusters + n = all_spikes.size + + # Map cluster ids -> unit indices. `spike_clusters_clean` is guaranteed + # to only contain ids present in `self.unit_ids` (filtered in __init__), + # so searchsorted always returns a valid position. + unit_indices = sorter[np.searchsorted(sorted_unit_ids, all_clusters)] + + segment_slices = np.array([[0, n]], dtype="int64") + spikes = np.empty(n, dtype=minimum_spike_dtype) + spikes["sample_index"] = all_spikes + spikes["unit_index"] = unit_indices + spikes["segment_index"] = 0 + + # Kilosort and Phy seem to always output spikes sorted by sample_index, but + # they DO NOT sort cluster_ids within a sample_index. + # No need to sort by segment_index since we know there's only one segment. + order = np.lexsort((spikes["unit_index"], spikes["sample_index"])) spikes = spikes[order] self._cached_spike_vector = spikes diff --git a/src/spikeinterface/extractors/tests/test_phykilosortextractors.py b/src/spikeinterface/extractors/tests/test_phykilosortextractors.py index 095e79cac1..3be84a341c 100644 --- a/src/spikeinterface/extractors/tests/test_phykilosortextractors.py +++ b/src/spikeinterface/extractors/tests/test_phykilosortextractors.py @@ -2,6 +2,7 @@ import numpy as np from spikeinterface.extractors.phykilosortextractors import PhySortingSegment +from spikeinterface.core.sorting_tools import is_spike_vector_sorted import spikeinterface.extractors.phykilosortextractors as phymod # Sorted spike times with known cluster assignments. @@ -45,29 +46,75 @@ def test_phy_sorting_segment_get_unit_spike_trains(monkeypatch, force_numpy_fall assert seg.get_unit_spike_trains([], start_frame=None, end_frame=None) == {} -def _make_phy_folder(tmp_path): +def _make_phy_folder(tmp_path, spike_times=None, spike_clusters=None, cluster_ids=None): """Create a minimal Phy output folder for testing.""" - spike_times = np.array([100, 100, 200, 300, 300, 300, 400, 500], dtype=np.int64) - spike_clusters = np.array([10, 20, 30, 10, 20, 30, 10, 20], dtype=np.int64) + if spike_times is None: + spike_times = np.array([100, 100, 200, 300, 300, 300, 400, 500], dtype=np.int64) + if spike_clusters is None: + spike_clusters = np.array([10, 20, 30, 10, 20, 30, 10, 20], dtype=np.int64) np.save(tmp_path / "spike_times.npy", spike_times) np.save(tmp_path / "spike_clusters.npy", spike_clusters) (tmp_path / "params.py").write_text("sample_rate = 30000.0\n") + if cluster_ids is not None: + cluster_lines = "\n".join(str(cluster_id) for cluster_id in cluster_ids) + (tmp_path / "cluster_info.tsv").write_text(f"cluster_id\n{cluster_lines}\n") return tmp_path -def test_phy_compute_and_cache_spike_vector(tmp_path): +@pytest.mark.parametrize( + ("spike_times", "spike_clusters", "cluster_ids"), + [ + pytest.param( + np.array([100, 200, 300, 400], dtype=np.int64), + np.array([20, 10, 30, 20], dtype=np.int64), + None, + id="canonical-no-cotemporal-ties", + ), + pytest.param( + np.array([100, 100, 200, 300, 300], dtype=np.int64), + np.array([10, 20, 30, 10, 30], dtype=np.int64), + None, + id="canonical-cotemporal-ties", + ), + pytest.param( + np.array([100, 100, 200, 300], dtype=np.int64), + np.array([20, 10, 30, 10], dtype=np.int64), + None, + id="cotemporal-ties-require-lexsort", + ), + pytest.param( + np.array([200, 100, 300, 100], dtype=np.int64), + np.array([10, 20, 30, 10], dtype=np.int64), + None, + id="sample-times-require-lexsort", + ), + pytest.param( + np.array([], dtype=np.int64), + np.array([], dtype=np.int64), + [10, 20], + id="empty-spike-vector", + ), + ], +) +def test_phy_compute_and_cache_spike_vector(tmp_path, spike_times, spike_clusters, cluster_ids): """Phy override of _compute_and_cache_spike_vector must produce the same spike vector as the base class (per-unit) implementation.""" from spikeinterface.core.basesorting import BaseSorting from spikeinterface.extractors.phykilosortextractors import BasePhyKilosortSortingExtractor - phy_folder = _make_phy_folder(tmp_path) + phy_folder = _make_phy_folder( + tmp_path, + spike_times=spike_times, + spike_clusters=spike_clusters, + cluster_ids=cluster_ids, + ) sorting = BasePhyKilosortSortingExtractor(phy_folder) # Phy override path sorting._compute_and_cache_spike_vector() phy_vector = sorting._cached_spike_vector.copy() + phy_segment_slices = sorting._cached_spike_vector_segment_slices.copy() # Base class (per-unit) path sorting._cached_spike_vector = None @@ -76,3 +123,7 @@ def test_phy_compute_and_cache_spike_vector(tmp_path): base_vector = sorting._cached_spike_vector assert np.array_equal(phy_vector, base_vector) + assert np.array_equal(phy_segment_slices, np.array([[0, len(phy_vector)]], dtype="int64")) + assert len(phy_vector) == len(spike_times) + assert np.all(phy_vector["segment_index"] == 0) + assert is_spike_vector_sorted(phy_vector) From 72d7395412f8caf79055fe77eaa985de82afd4db Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Fri, 17 Apr 2026 16:25:12 -0500 Subject: [PATCH 03/10] Make `is_spike_vector_sorted()` chunked, add early stopping, add `assume_single_segment` shortcut. Just checking whether the spike vector is sorted was allocating quite a bit of memory (for the diff arrays; ~3.25GB for 100M spikes). This chunked version only requires a constant amount of memory (~40MB). This version also adds early stopping, of sorts. That was accidental consequence of the chunking tbh, but it makes sense that it beats diff'ing every element in the spike vector (the 1M spike benchmark below is a good indicator). We also don't have to waste time/space checking the segment index if we know it is a single-segment vector. Depending on how pathological the spike vector is, and whether or not it is single-segment, the speedup is ~1.5x-2.0x for 50M+ spikes, and a bit larger (~1.8-2.5x) for 1M spikes. --- src/spikeinterface/core/sorting_tools.py | 97 +++++++++++++++---- .../core/tests/test_sorting_tools.py | 45 +++++++++ 2 files changed, 124 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index e94ebe2950..4580837b34 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -236,14 +236,14 @@ def random_spikes_selection( elif method == "percentage": if percentage is None or not (0 < percentage <= 1): - raise ValueError(f"percentage must be in the interval (0, 1]") + raise ValueError("percentage must be in the interval (0, 1]") rng_size = min(max_spikes_per_unit, int(all_unit_indices.size * percentage)) selected_unit_indices = rng.choice(all_unit_indices, size=rng_size, replace=False, shuffle=False) elif method == "maximum_rate": if maximum_rate is None: - raise ValueError(f"maximum_rate must be defined") + raise ValueError("maximum_rate must be defined") t_duration = np.sum(get_segment_durations(sorting)) rng_size = min(int(t_duration * maximum_rate), max_spikes_per_unit, all_unit_indices.size) @@ -998,26 +998,87 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee return new_vector, keep_mask_vector -def is_spike_vector_sorted(spike_vector: np.ndarray) -> bool: +def is_spike_vector_sorted( + spike_vector: np.ndarray, *, chunk_size: int | None = 10_000_000, assume_single_segment: bool = False +) -> bool: """Return True iff the spike vector is sorted by (segment_index, sample_index, unit_index). - O(n) sequential scan. Used to avoid an O(n log n) lexsort when the vector already - happens to be in canonical order. + This is an O(n) sequential scan used to avoid an O(n log n) lexsort when the + vector already happens to be in canonical order. + + The strategy is: compare pairs of adjacent spikes in chunks to avoid allocating + (possibly big) temporary arrays for diffs. + + Each adjacent pair has to be fully "lexsorted": + + * segment_index is nondecreasing; + * within the same segment, sample_index is nondecreasing; + * within the same segment and same sample, unit_index is nondecreasing. + + Parameters + ---------- + spike_vector : np.ndarray + Spike vector with fields "sample_index", "unit_index", and + "segment_index". + chunk_size : int | None, default 10_000_000 + Number of adjacent pairs to check per chunk. None checks the full vector + in one chunk. + assume_single_segment : bool, default False + If True, skip segment_index checks and require only sample_index/unit_index + ordering. """ n = len(spike_vector) if n <= 1: return True - seg = spike_vector["segment_index"] - samp = spike_vector["sample_index"] - unit = spike_vector["unit_index"] - d_seg = np.diff(seg) - if np.any(d_seg < 0): - return False - seg_eq = d_seg == 0 - d_samp = np.diff(samp) - if np.any(d_samp[seg_eq] < 0): - return False - samp_eq = seg_eq & (d_samp == 0) - if np.any(np.diff(unit)[samp_eq] < 0): - return False + + if chunk_size is None: + chunk_size = n - 1 + elif chunk_size < 1: + raise ValueError("chunk_size must be >= 1 or None") + + sample_index = spike_vector["sample_index"] + unit_index = spike_vector["unit_index"] + + if assume_single_segment: + for start in range(0, n - 1, chunk_size): + stop = min(start + chunk_size, n - 1) + + # Compare each sample_index value to the following one. The shifted + # slices have equal length and represent adjacent spike pairs. + sample0 = sample_index[start:stop] + sample1 = sample_index[start + 1 : stop + 1] + if np.any(sample1 < sample0): + return False + + # Unit order only matters for cotemporal (same sample) spikes + same_sample = sample1 == sample0 + if np.any((unit_index[start + 1 : stop + 1] < unit_index[start:stop]) & same_sample): + return False + + return True + + segment_index = spike_vector["segment_index"] + + for start in range(0, n - 1, chunk_size): + stop = min(start + chunk_size, n - 1) + + # First enforce segment ordering. Later checks are masked to adjacent + # pairs in the same segment because sample/unit ordering is segment-local. + segment0 = segment_index[start:stop] + segment1 = segment_index[start + 1 : stop + 1] + if np.any(segment1 < segment0): + return False + + same_segment = segment1 == segment0 + + sample0 = sample_index[start:stop] + sample1 = sample_index[start + 1 : stop + 1] + if np.any((sample1 < sample0) & same_segment): + return False + + # Unit order is only part of canonical order for cotemporal spikes. + same_sample = same_segment & (sample1 == sample0) + if np.any((unit_index[start + 1 : stop + 1] < unit_index[start:stop]) & same_sample): + return False + return True diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 84b6591582..dcc3386b55 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -47,11 +47,19 @@ def test_spike_vector_to_indices(): def test_is_spike_vector_sorted(): + empty_spikes = np.zeros(0, dtype=minimum_spike_dtype) + assert is_spike_vector_sorted(empty_spikes) + + one_spike = np.zeros(1, dtype=minimum_spike_dtype) + assert is_spike_vector_sorted(one_spike) + spikes = np.zeros(5, dtype=minimum_spike_dtype) spikes["segment_index"] = [0, 0, 1, 1, 1] spikes["sample_index"] = [100, 200, 0, 100, 100] spikes["unit_index"] = [0, 1, 0, 0, 1] assert is_spike_vector_sorted(spikes) + assert is_spike_vector_sorted(spikes, chunk_size=None) + assert is_spike_vector_sorted(spikes, chunk_size=1) segment_unsorted = spikes.copy() segment_unsorted["segment_index"] = [0, 1, 0, 1, 1] @@ -71,6 +79,43 @@ def test_is_spike_vector_sorted(): tie_unsorted["unit_index"] = [0, 1, 0, 0, 0] assert not is_spike_vector_sorted(tie_unsorted) + with pytest.raises(ValueError, match="chunk_size"): + is_spike_vector_sorted(spikes, chunk_size=0) + + +def test_is_spike_vector_sorted_chunk_boundaries(): + spikes = np.zeros(6, dtype=minimum_spike_dtype) + spikes["segment_index"] = [0, 0, 1, 0, 1, 1] + spikes["sample_index"] = [0, 100, 200, 300, 400, 500] + spikes["unit_index"] = 0 + assert not is_spike_vector_sorted(spikes, chunk_size=3) + + spikes["segment_index"] = 0 + spikes["sample_index"] = [0, 100, 300, 200, 400, 500] + assert not is_spike_vector_sorted(spikes, chunk_size=3) + + spikes["sample_index"] = [0, 100, 200, 200, 400, 500] + spikes["unit_index"] = [0, 0, 1, 0, 0, 0] + assert not is_spike_vector_sorted(spikes, chunk_size=3) + + +def test_is_spike_vector_sorted_assume_single_segment(): + spikes = np.zeros(5, dtype=minimum_spike_dtype) + spikes["segment_index"] = [0, 1, 0, 1, 1] + spikes["sample_index"] = [0, 100, 200, 300, 400] + spikes["unit_index"] = [0, 0, 0, 0, 0] + assert not is_spike_vector_sorted(spikes) + assert is_spike_vector_sorted(spikes, assume_single_segment=True) + + sample_unsorted = spikes.copy() + sample_unsorted["sample_index"] = [0, 100, 50, 200, 300] + assert not is_spike_vector_sorted(sample_unsorted, assume_single_segment=True) + + tie_unsorted = spikes.copy() + tie_unsorted["sample_index"] = [0, 100, 100, 200, 300] + tie_unsorted["unit_index"] = [0, 1, 0, 0, 0] + assert not is_spike_vector_sorted(tie_unsorted, assume_single_segment=True) + def test_random_spikes_selection(): recording, sorting = generate_ground_truth_recording( From 79285fb0883dbd8c13cfa56c75ddd6e623c5bb4e Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Fri, 17 Apr 2026 16:30:53 -0500 Subject: [PATCH 04/10] Leverage single-segment shortcuts when possible in `UnitSelectionSorting._compute_and_cache_spike_vector()`. We can easily check whether the USS is single-segement, and improve both the sortedness check and the lexsorting if so. --- .../core/tests/test_unitsselectionsorting.py | 28 +++++++++---------- .../core/unitsselectionsorting.py | 14 +++++++--- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index b5a7965ca5..6096cd73d3 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -1,9 +1,9 @@ import pytest import numpy as np -from pathlib import Path from spikeinterface.core import UnitsSelectionSorting from spikeinterface.core.numpyextractors import NumpySorting +from spikeinterface.core.sorting_tools import is_spike_vector_sorted from spikeinterface.core.generate import generate_sorting @@ -41,7 +41,7 @@ def test_failure_with_non_unique_unit_ids(): seed = 10 sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed) with pytest.raises(AssertionError): - sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) + UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"]) def test_compute_and_cache_spike_vector(): @@ -89,6 +89,8 @@ def test_uss_get_unit_spike_trains_with_renamed_ids(use_cache): def test_spike_vector_sorted_after_reorder_with_cotemporal_spikes(): """USS spike vector must be correctly sorted even when selection reverses unit order and co-temporal spikes exist (same sample_index, different units).""" + from spikeinterface.core.basesorting import BaseSorting + # Build a sorting with guaranteed co-temporal spikes: # units 0, 1, 2 all fire at sample 100 and 200 samples = np.array([100, 100, 100, 200, 200, 200, 300, 400], dtype=np.int64) @@ -102,19 +104,15 @@ def test_spike_vector_sorted_after_reorder_with_cotemporal_spikes(): spike_vector = sub.to_spike_vector() - # Spike vector must be sorted by (segment_index, sample_index, unit_index) - n = len(spike_vector) - if n > 1: - seg = spike_vector["segment_index"] - samp = spike_vector["sample_index"] - unit = spike_vector["unit_index"] - d_seg = np.diff(seg) - assert np.all(d_seg >= 0), "segment_index not non-decreasing" - seg_eq = d_seg == 0 - d_samp = np.diff(samp) - assert np.all(d_samp[seg_eq] >= 0), "sample_index not non-decreasing within segment" - samp_eq = seg_eq & (d_samp == 0) - assert np.all(np.diff(unit)[samp_eq] >= 0), "unit_index not non-decreasing within same sample" + sub._cached_spike_vector = None + sub._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(sub) + base_vector = sub._cached_spike_vector + + assert np.array_equal(spike_vector, base_vector) + assert np.all(spike_vector["segment_index"] == 0) + assert is_spike_vector_sorted(spike_vector) + if __name__ == "__main__": diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 2e9e4df5d8..bf27b38733 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -67,10 +67,16 @@ def _compute_and_cache_spike_vector(self) -> None: # relative order as in the parent (an O(k) check). If not, the vector may still # happen to be sorted -- verify with an O(n) scan before falling back to O(n log n) # lexsort. - if not self._is_order_preserving_selection() and not is_spike_vector_sorted(spike_vector): - sort_indices = np.lexsort( - (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) - ) + assume_single_segment = self.get_num_segments() == 1 + if not self._is_order_preserving_selection() and not is_spike_vector_sorted( + spike_vector, assume_single_segment=assume_single_segment + ): + if assume_single_segment: + sort_indices = np.lexsort((spike_vector["unit_index"], spike_vector["sample_index"])) + else: + sort_indices = np.lexsort( + (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) + ) spike_vector = spike_vector[sort_indices] self._cached_spike_vector = spike_vector From c495e3ae231d6235ea07940307953b7da8e08545 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Wed, 20 May 2026 11:32:28 -0500 Subject: [PATCH 05/10] Optimize Phy/Kilosort `_compute_and_cache_spike_vector()`. Two main improvements: 1. The cluster_id -> unit_index mapping now uses a dense `cluster_to_unit` lookup table: O(N). Replaces `np.searchsorted`, which was O(N log M). On 392M spikes and 342 units, this is 27s -> 3s. 2. A numba kernel builds the spike vector in one pass. This is possible because (a) the input arrays are already sorted by sample_index, and the only remaining tie-breaking is by unit_index within each sample_index run, and (b) the runs are short (single-digit number of spikes ) and rare (single-digit percentage of total spikes). On 392M spikes and 342 units, ~170s -> ~6-10s. --- src/spikeinterface/core/sorting_tools.py | 156 +++++++++++++++++- .../core/tests/test_sorting_tools.py | 107 ++++++++++++ .../extractors/phykilosortextractors.py | 40 ++--- 3 files changed, 280 insertions(+), 23 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 4580837b34..69a77cc114 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -5,7 +5,7 @@ import numpy as np -from spikeinterface.core.base import BaseExtractor, unit_period_dtype +from spikeinterface.core.base import BaseExtractor, minimum_spike_dtype, unit_period_dtype from spikeinterface.core.basesorting import BaseSorting from spikeinterface.core.numpyextractors import NumpySorting @@ -1006,8 +1006,8 @@ def is_spike_vector_sorted( This is an O(n) sequential scan used to avoid an O(n log n) lexsort when the vector already happens to be in canonical order. - The strategy is: compare pairs of adjacent spikes in chunks to avoid allocating - (possibly big) temporary arrays for diffs. + The strategy is: compare pairs of adjacent spikes in chunks to avoid allocating + (possibly big) temporary arrays for diffs. Each adjacent pair has to be fully "lexsorted": @@ -1068,7 +1068,7 @@ def is_spike_vector_sorted( segment1 = segment_index[start + 1 : stop + 1] if np.any(segment1 < segment0): return False - + same_segment = segment1 == segment0 sample0 = sample_index[start:stop] @@ -1082,3 +1082,151 @@ def is_spike_vector_sorted( return False return True + + +def build_spike_vector_from_sorted_arrays( + sample_indices: np.ndarray, + unit_indices: np.ndarray, + segment_index: int = 0, +) -> np.ndarray: + """Build a `minimum_spike_dtype` spike vector when sample_indices is already sorted. + + Some sorting extractors (notably Phy/Kilosort) hold their spike samples in + a flat array that is already monotonic non-decreasing in `sample_index`. + Building the spike vector then only requires sorting `unit_index` *within* + runs of equal `sample_index` — a single O(N) pass, instead of an O(N log N) + global `np.lexsort`. + + Parameters + ---------- + sample_indices : np.ndarray + 1-D integer array of spike sample positions. Expected to be monotonic + non-decreasing; if a violation is detected the function falls back to + a global lexsort so the result is still correct. + unit_indices : np.ndarray + 1-D integer array, parallel to `sample_indices`, giving each spike's + unit index (position in the parent sorting's `unit_ids`). + segment_index : int, default 0 + Value to broadcast into the output `segment_index` field. + + Returns + ------- + spikes : np.ndarray + Structured array of length `sample_indices.size` with dtype + `minimum_spike_dtype`. The ordering is identical to what you would + get by building the structured array from the inputs and then + applying ``np.lexsort((unit_indices, sample_indices))`` — i.e. + primary key `sample_index` ascending, secondary key `unit_index` + ascending within ties. + """ + n = sample_indices.size + if unit_indices.size != n: + raise ValueError(f"sample_indices and unit_indices must have the same length; got {n} and {unit_indices.size}.") + + if n == 0: + return np.empty(0, dtype=minimum_spike_dtype) + + # Since the numba kernel is compiled for int64, this ensures we don't re-JIT if, + # for examples, the caller passes unit ids as int32. More importantly, this allows + # the kernel to index with a constant stride no matter what (e.g. if the caller + # passes a non-contiguous view like `arr[::2]`), and costs nothing if no-op. + sample_arr = np.ascontiguousarray(sample_indices, dtype=np.int64) + unit_arr = np.ascontiguousarray(unit_indices, dtype=np.int64) + + if HAVE_NUMBA: + # Allocate the output as a flat (N, 3) int64 buffer and let one numba + # kernel pass do everything: monotonicity check, unit-index + # tie resolution, and writing all three fields. + flat = np.empty((n, 3), dtype=np.int64) + is_monotonic = _build_spike_vector_kernel(sample_arr, unit_arr, int(segment_index), flat) + if is_monotonic: + # NB: This is zero-copy, becuase the (N, 3) int64 layout matches + # `minimum_spike_dtype` exactly. + return flat.view(minimum_spike_dtype).reshape(n) + + # Fallback: caller's sample_indices invariant did not hold (or numba + # is unavailable). Do a global lexsort. ='( + spikes = np.empty(n, dtype=minimum_spike_dtype) + spikes["segment_index"] = segment_index + order = np.lexsort((unit_arr, sample_arr)) + spikes["sample_index"] = sample_arr[order] + spikes["unit_index"] = unit_arr[order] + return spikes + + +if HAVE_NUMBA: + import numba + + @numba.jit(nopython=True, nogil=True, cache=True) + def _build_spike_vector_kernel(sample_indices, unit_indices, segment_index, flat_out): + """Single-pass build of a `minimum_spike_dtype` spike vector. + + Walks `sample_indices` once and, for each spike, writes all three + fields (sample_index, unit_index, segment_index) into `flat_out` + — a contiguous (N, 3) int64 buffer whose memory layout matches + `minimum_spike_dtype` exactly. + + While walking, the kernel also: + * verifies `sample_indices` is monotonic non-decreasing — returns + False at the first violation (caller falls back to a global + lexsort), + * insertion-sorts `unit_indices` within each run of equal + `sample_indices` before emitting that run. Runs of length 1 + (the common case for Kilosort/Phy output) skip the sort + entirely. + """ + n = sample_indices.shape[0] + i = 0 + # Walk one tie-run at a time: [i, j) is the next run of equal sample_indices. + while i < n: + # Monotonicity guard — bail out and lexsort instead of it fails + if i > 0 and sample_indices[i] < sample_indices[i - 1]: + return False + + # Find the end of the current tie-run. + j = i + 1 + while j < n and sample_indices[j] == sample_indices[i]: + j += 1 + + sample = sample_indices[i] + if j - i == 1: + # Fast path: no co-temporal spikes. + # Column order (0, 1, 2) matches minimum_spike_dtype field order + # (sample_index, unit_index, segment_index). Essential! + flat_out[i, 0] = sample + flat_out[i, 1] = unit_indices[i] + flat_out[i, 2] = segment_index + else: + # Tied run: sort unit_indices within the run. + # In practice, runs are short (single-digit numbers of spikes) and rare + # (single-digit percentage of total spikes), so the per-run allocation + # + insertion sort are cheap. + run_len = j - i + + # Stage the tied unit_indices into a small working buffer. + # This _could_ be expensive if the runs were long, but I tested + # with/without, and this _always_ wins. + buf = np.empty(run_len, dtype=np.int64) + for k in range(run_len): + buf[k] = unit_indices[i + k] + + # Insertion-sort the buffer in place. Beats anything fancier on + # tiny arrays (maybe because there is zero setup cost). + for k in range(1, run_len): + key = buf[k] + m = k - 1 + while m >= 0 and buf[m] > key: + buf[m + 1] = buf[m] + m -= 1 + buf[m + 1] = key + + # Emit the run with sorted unit_indices. + for k in range(run_len): + flat_out[i + k, 0] = sample + flat_out[i + k, 1] = buf[k] + flat_out[i + k, 2] = segment_index + + # Advance past the run we just emitted. + i = j + + return True diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index dcc3386b55..6d197d9e47 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -14,6 +14,7 @@ generate_unit_ids_for_merge_group, remap_unit_indices_in_vector, is_spike_vector_sorted, + build_spike_vector_from_sorted_arrays, ) from spikeinterface.core.base import minimum_spike_dtype @@ -117,6 +118,112 @@ def test_is_spike_vector_sorted_assume_single_segment(): assert not is_spike_vector_sorted(tie_unsorted, assume_single_segment=True) +def _reference_spike_vector(sample_indices, unit_indices, segment_index=0): + """Reference implementation: global lexsort, used as ground truth in tests.""" + n = sample_indices.size + spikes = np.empty(n, dtype=minimum_spike_dtype) + spikes["sample_index"] = sample_indices + spikes["unit_index"] = unit_indices + spikes["segment_index"] = segment_index + order = np.lexsort((spikes["unit_index"], spikes["sample_index"])) + return spikes[order] + + +@pytest.fixture(params=[True, False], ids=["numba", "numpy"]) +def force_numba(request, monkeypatch): + """Run each test once with numba enabled (if installed) and once with the fallback.""" + if request.param and importlib.util.find_spec("numba") is None: + pytest.skip("numba not installed") + monkeypatch.setattr("spikeinterface.core.sorting_tools.HAVE_NUMBA", request.param) + return request.param + + +def test_build_spike_vector_no_ties(force_numba): + sample_indices = np.array([10, 20, 30, 40, 50], dtype=np.int64) + unit_indices = np.array([3, 1, 4, 1, 5], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + assert out.dtype == np.dtype(minimum_spike_dtype) + assert np.array_equal(out["sample_index"], sample_indices) + assert np.array_equal(out["unit_index"], unit_indices) + assert np.all(out["segment_index"] == 0) + + +def test_build_spike_vector_with_ties(force_numba): + # Three runs of ties (lengths 3, 2, 1, 4) with shuffled unit_indices + sample_indices = np.array( + [10, 10, 10, 20, 20, 30, 40, 40, 40, 40, 50], + dtype=np.int64, + ) + unit_indices = np.array([7, 2, 5, 9, 1, 3, 4, 0, 8, 2, 6], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def test_build_spike_vector_all_same_sample_index(force_numba): + n = 64 + sample_indices = np.full(n, 42, dtype=np.int64) + rng = np.random.default_rng(0) + unit_indices = rng.permutation(n).astype(np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def test_build_spike_vector_ties_at_edges(force_numba): + # Ties at the very start, the very end, and an isolated single in between. + sample_indices = np.array([5, 5, 5, 9, 12, 12, 12], dtype=np.int64) + unit_indices = np.array([2, 0, 1, 7, 3, 1, 2], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def test_build_spike_vector_empty(force_numba): + out = build_spike_vector_from_sorted_arrays( + np.array([], dtype=np.int64), + np.array([], dtype=np.int64), + ) + assert out.size == 0 + assert out.dtype == np.dtype(minimum_spike_dtype) + + +def test_build_spike_vector_segment_index(force_numba): + sample_indices = np.array([0, 1, 2], dtype=np.int64) + unit_indices = np.array([0, 0, 0], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices, segment_index=3) + assert np.all(out["segment_index"] == 3) + + +def test_build_spike_vector_length_mismatch(): + with pytest.raises(ValueError): + build_spike_vector_from_sorted_arrays( + np.array([1, 2, 3], dtype=np.int64), + np.array([1, 2], dtype=np.int64), + ) + + +def test_build_spike_vector_randomized_against_lexsort(force_numba): + rng = np.random.default_rng(1234) + n = 10_000 + # Build ~30% ties by drawing sample positions from a small space. + sample_indices = np.sort(rng.integers(0, n // 3, size=n).astype(np.int64)) + unit_indices = rng.integers(0, 200, size=n).astype(np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + +def test_build_spike_vector_unsorted_falls_back(force_numba): + # Caller violates the "sample_indices is sorted" invariant; helper must + # still return a globally lexsorted vector via the fallback. + sample_indices = np.array([200, 100, 300, 100], dtype=np.int64) + unit_indices = np.array([0, 1, 2, 0], dtype=np.int64) + out = build_spike_vector_from_sorted_arrays(sample_indices, unit_indices) + ref = _reference_spike_vector(sample_indices, unit_indices) + assert np.array_equal(out, ref) + + def test_random_spikes_selection(): recording, sorting = generate_ground_truth_recording( durations=[20.0, 10.0], diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 98ca881774..34d332b8d4 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -14,8 +14,8 @@ create_sorting_analyzer, SortingAnalyzer, ) -from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.core.core_tools import define_function_from_class +from spikeinterface.core.sorting_tools import build_spike_vector_from_sorted_arrays from spikeinterface.postprocessing import ComputeSpikeLocations from probeinterface import read_prb, Probe @@ -238,30 +238,32 @@ def _compute_and_cache_spike_vector(self) -> None: assert self.get_num_segments() == 1 unit_ids = np.asarray(self.unit_ids) - sorter = np.argsort(unit_ids) - sorted_unit_ids = unit_ids[sorter] - seg = self.segments[0] all_spikes = seg._all_spikes all_clusters = seg._all_clusters n = all_spikes.size - # Map cluster ids -> unit indices. `spike_clusters_clean` is guaranteed - # to only contain ids present in `self.unit_ids` (filtered in __init__), - # so searchsorted always returns a valid position. - unit_indices = sorter[np.searchsorted(sorted_unit_ids, all_clusters)] - + # Map cluster ids -> unit indices via a direct lookup table. + # cluster_ids are non-negative integers (Phy/Kilosort convention) and + # the max id is small (one per neural unit), so a "dense" table of size + # max_id + 1 is cheap (kilobytes), even though it reserves space for unit ids + # that don't exist, and lets the mapping run in a single O(N) gather. + # This is ~10x faster than `sorter[searchsorted(sorted_unit_ids, all_clusters)]` + # on large N. + max_id = int(max(unit_ids.max() if unit_ids.size else -1, all_clusters.max() if n else -1)) + cluster_to_unit = np.empty(max_id + 1, dtype=np.int64) + cluster_to_unit[unit_ids] = np.arange(unit_ids.size, dtype=np.int64) + unit_indices = cluster_to_unit[all_clusters] + + # Kilosort/Phy always emit spikes ascending in sample_index but DO NOT + # order cluster_ids within a sample_index. The helper sorts unit_index + # within tied sample_index runs in O(N), avoiding a global lexsort. + spikes = build_spike_vector_from_sorted_arrays( + sample_indices=all_spikes, + unit_indices=unit_indices, + segment_index=0, + ) segment_slices = np.array([[0, n]], dtype="int64") - spikes = np.empty(n, dtype=minimum_spike_dtype) - spikes["sample_index"] = all_spikes - spikes["unit_index"] = unit_indices - spikes["segment_index"] = 0 - - # Kilosort and Phy seem to always output spikes sorted by sample_index, but - # they DO NOT sort cluster_ids within a sample_index. - # No need to sort by segment_index since we know there's only one segment. - order = np.lexsort((spikes["unit_index"], spikes["sample_index"])) - spikes = spikes[order] self._cached_spike_vector = spikes self._cached_spike_vector_segment_slices = segment_slices From 9a728708846b6a9c034e65a0de10c8a0f1997020 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Wed, 20 May 2026 12:19:55 -0500 Subject: [PATCH 06/10] Further optimize Phy/Kilosort `get_unit_spike_trains()` This fixes a performance regression (relative to my prototype, not relative to a previous commit) in Phy/Kilosort's `get_unit_spike_trains()`. On 392M spikes and 342 clusters, this takes the numba implementation from ~35s down to ~5s, and the numpy implementation from ~110s down to ~80s. --- .../extractors/phykilosortextractors.py | 39 ++++++++++--------- 1 file changed, 21 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 34d332b8d4..e1b5c98414 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -306,33 +306,36 @@ def get_unit_spike_trains( clusters = self._all_clusters[start:end] unit_ids_arr = np.asarray(unit_ids) - num_units = len(unit_ids_arr) + num_units = unit_ids_arr.size if num_units == 0: return {} - # Map each spike's cluster id to a destination index in the caller-supplied - # unit_ids order. -1 means "this spike's cluster is not in unit_ids, skip it". - sorter = np.argsort(unit_ids_arr, kind="stable") - sorted_unit_ids = unit_ids_arr[sorter] - idx_in_sorted = np.searchsorted(sorted_unit_ids, clusters, side="left") - idx_clamped = np.minimum(idx_in_sorted, num_units - 1) - matches = (idx_in_sorted < num_units) & (sorted_unit_ids[idx_clamped] == clusters) - dest = np.where(matches, sorter[idx_clamped], -1).astype(np.int64) - - spikes_i64 = np.ascontiguousarray(spikes, dtype=np.int64) + # Map cluster ids -> unit indices via a direct lookup table. + # See `_compute_and_cache_spike_vector()`. + max_id = int(max(unit_ids_arr.max(), clusters.max() if clusters.size else -1)) + cluster_to_dest = np.full(max_id + 1, -1, dtype=np.int64) + cluster_to_dest[unit_ids_arr] = np.arange(num_units, dtype=np.int64) + dest = cluster_to_dest[clusters] if HAVE_NUMBA: - offsets, flat_out = _counting_sort_spikes_by_unit(spikes_i64, dest, num_units) + offsets, flat_out = _counting_sort_spikes_by_unit(spikes, dest, num_units) else: # NumPy fallback: stable argsort by destination index, then split on offsets. # Stable sort preserves the input order of spikes within each unit group, # and since _all_spikes is sorted by sample_index, so is each group. - valid = dest >= 0 - valid_spikes = spikes_i64[valid] - valid_dest = dest[valid] - order = np.argsort(valid_dest, kind="stable") - flat_out = valid_spikes[order] - counts = np.bincount(valid_dest, minlength=num_units) + if dest.size and dest.min() >= 0: + # Trick: Every cluster in `clusters` is in `unit_ids`, so no + # boolean-mask filtering is needed. Skips two N-sized copies. + order = np.argsort(dest, kind="stable") + flat_out = spikes[order] + counts = np.bincount(dest, minlength=num_units) + else: + valid = dest >= 0 + valid_spikes = spikes[valid] + valid_dest = dest[valid] + order = np.argsort(valid_dest, kind="stable") + flat_out = valid_spikes[order] + counts = np.bincount(valid_dest, minlength=num_units) offsets = np.empty(num_units + 1, dtype=np.int64) offsets[0] = 0 np.cumsum(counts, out=offsets[1:]) From 74d871aadff7649543b8c4518c99e63e07539b86 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Wed, 20 May 2026 12:43:26 -0500 Subject: [PATCH 07/10] Phy/Kilosort skips bad cluster removal if there are no bad clusters. Use Samuel's trick from PR 4579: `__init__()` doesn't waste time trying to remove spikes and unit ids from "bad clusters" from the full flat (`.npy``) arrays on load (`read_phy()``), if there aren't any bad clusters to begin with. --- .../extractors/phykilosortextractors.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index e1b5c98414..1a5e610a44 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -156,9 +156,14 @@ def __init__( # update spike clusters and times values bad_clusters = [clust for clust in clust_id if clust not in cluster_info["cluster_id"].values] - spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters) - spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs] - spike_times_clean = spike_times[spike_clusters_clean_idxs] + if len(bad_clusters) > 0: + spike_clusters_clean_idxs = ~np.isin(spike_clusters, bad_clusters) + spike_clusters_clean = spike_clusters[spike_clusters_clean_idxs] + spike_times_clean = spike_times[spike_clusters_clean_idxs] + else: + # No bad clusters — skip the O(N) isin mask and two N-sized copies. + spike_clusters_clean = spike_clusters + spike_times_clean = spike_times if "si_unit_id" in cluster_info.columns: unit_ids = cluster_info["si_unit_id"].values @@ -311,7 +316,7 @@ def get_unit_spike_trains( return {} # Map cluster ids -> unit indices via a direct lookup table. - # See `_compute_and_cache_spike_vector()`. + # See `_compute_and_cache_spike_vector()`. max_id = int(max(unit_ids_arr.max(), clusters.max() if clusters.size else -1)) cluster_to_dest = np.full(max_id + 1, -1, dtype=np.int64) cluster_to_dest[unit_ids_arr] = np.arange(num_units, dtype=np.int64) From 8d288def9bc5d62305703ec19d99837ec596de90 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Wed, 20 May 2026 19:22:12 -0500 Subject: [PATCH 08/10] Optimize UnitSelectionSorting.to_spike_vector() This uses essentially the same dense LUT + flat-view buffer trick as the Phy/Kilosort `_compute_and_cache_spike_vector` method, in order to allow remapping and filtering the parent's spike vector in a single pass, without needing to create any intermediate allocations. On 392M spikes, selecting 258 out of 342 units in a single pass, starting from a cached parent spike vector, this reduces the additional time needed to get the USS spike vector to ~5.5s. --- src/spikeinterface/core/sorting_tools.py | 92 ++++++++++++++++++- .../core/tests/test_sorting_tools.py | 76 +++++++++++++++ .../core/unitsselectionsorting.py | 27 +++--- 3 files changed, 182 insertions(+), 13 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 69a77cc114..4409c72b8d 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -949,6 +949,9 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee * select unit and recompute quickly the "unit_index" in the spike vector * merging/spliting periods or spikes and update the "unit_index" in the vector + Do not use this if you are operating on `minimum_spike_dtype` inputs! In such cases, + it is much more efficient to use a dense LUT + `filter_and_remap_spike_vector()` + (see `UnitSelectionSorting._compute_and_cache_spike_vector()`). Parameters ---------- @@ -1140,7 +1143,7 @@ def build_spike_vector_from_sorted_arrays( flat = np.empty((n, 3), dtype=np.int64) is_monotonic = _build_spike_vector_kernel(sample_arr, unit_arr, int(segment_index), flat) if is_monotonic: - # NB: This is zero-copy, becuase the (N, 3) int64 layout matches + # NB: This is zero-copy, because the (N, 3) int64 layout matches # `minimum_spike_dtype` exactly. return flat.view(minimum_spike_dtype).reshape(n) @@ -1154,6 +1157,63 @@ def build_spike_vector_from_sorted_arrays( return spikes +def filter_and_remap_spike_vector( + spike_vector: np.ndarray, + unit_mapping: np.ndarray, +) -> np.ndarray: + """Filter a `minimum_spike_dtype` spike vector by unit and remap unit_index in one pass. + + For each spike `i` in `spike_vector`: + * look up ``new_idx = unit_mapping[spike_vector[i]["unit_index"]]``, + * if ``new_idx >= 0``, copy the spike to the output with that new unit_index; + * otherwise drop it. + + Spikes are written in their original order in `spike_vector`, so if the input + is sorted by ``(segment_index, sample_index, parent_unit_index)`` and + `unit_mapping`, restricted to its kept entries, is monotonic increasing, + then the output is also sorted by ``(segment_index, sample_index, new_unit_index)``. + + If `unit_mapping` is not order-preserving, the caller is responsible for re-sorting + cotemporal spike groups (e.g. by calling `is_spike_vector_sorted` + `np.lexsort`). + + Parameters + ---------- + spike_vector : np.ndarray + Structured array with dtype `minimum_spike_dtype`. + unit_mapping : np.ndarray + 1-D int64 array of length ``parent_num_units``. ``unit_mapping[old]`` + gives the new unit_index, or any negative value to drop the spike. + + Returns + ------- + out : np.ndarray + Structured array of `minimum_spike_dtype`, length `n_kept`. + """ + n_parent = spike_vector.size + if n_parent == 0: + return np.empty(0, dtype=minimum_spike_dtype) + + # See implementation note in `build_spike_vector_from_sorted_arrays()` + mapping_arr = np.ascontiguousarray(unit_mapping, dtype=np.int64) + + if HAVE_NUMBA: + # Same trick as `build_spike_vector_from_sorted_arrays()`: + # These flat-buffier views are zero-copy because the (N, 3) int64 layouts match + # `minimum_spike_dtype` exactly. + parent_flat = spike_vector.view(np.int64).reshape(n_parent, 3) + out_flat = np.empty((n_parent, 3), dtype=np.int64) + n_kept = _filter_and_remap_kernel(parent_flat, mapping_arr, out_flat) + return out_flat[:n_kept].view(minimum_spike_dtype).reshape(n_kept) + + # NumPy fallback: bool-mask + remap. + old_unit_idx = spike_vector["unit_index"] + new_unit_idx_full = mapping_arr[old_unit_idx] + keep = new_unit_idx_full >= 0 + out = spike_vector[keep].copy() + out["unit_index"] = new_unit_idx_full[keep] + return out + + if HAVE_NUMBA: import numba @@ -1230,3 +1290,33 @@ def _build_spike_vector_kernel(sample_indices, unit_indices, segment_index, flat i = j return True + + @numba.jit(nopython=True, nogil=True, cache=True) + def _filter_and_remap_kernel(parent_flat, unit_mapping, out_flat): + """Single-pass filter + remap for a (N, 3) int64 spike-vector view. + + For each row i of `parent_flat` (columns 0/1/2 = sample, unit, segment): + look up ``new_unit = unit_mapping[parent_flat[i, 1]]``; if it is + non-negative, copy (sample, new_unit, segment) to ``out_flat[write_pos]`` + and advance `write_pos`. + + Returns the number of spikes written. The caller is expected to slice + `out_flat[:n_kept]` and view as `minimum_spike_dtype`. + + Spikes are emitted in the order they appear in `parent_flat`, so + ordering on (segment_index, sample_index) is preserved automatically; + unit_index ordering within tied sample_index groups follows whatever + `unit_mapping` does to the parent's unit_index values. + """ + n = parent_flat.shape[0] + write_pos = 0 + for i in range(n): + new_unit = unit_mapping[parent_flat[i, 1]] + if new_unit >= 0: + # Column order (0, 1, 2) is coupled to minimum_spike_dtype field order + # (sample_index, unit_index, segment_index) — keep them in sync. + out_flat[write_pos, 0] = parent_flat[i, 0] + out_flat[write_pos, 1] = new_unit + out_flat[write_pos, 2] = parent_flat[i, 2] + write_pos += 1 + return write_pos diff --git a/src/spikeinterface/core/tests/test_sorting_tools.py b/src/spikeinterface/core/tests/test_sorting_tools.py index 6d197d9e47..407cd18334 100644 --- a/src/spikeinterface/core/tests/test_sorting_tools.py +++ b/src/spikeinterface/core/tests/test_sorting_tools.py @@ -15,6 +15,7 @@ remap_unit_indices_in_vector, is_spike_vector_sorted, build_spike_vector_from_sorted_arrays, + filter_and_remap_spike_vector, ) from spikeinterface.core.base import minimum_spike_dtype @@ -224,6 +225,81 @@ def test_build_spike_vector_unsorted_falls_back(force_numba): assert np.array_equal(out, ref) +def _make_spike_vector(samples, units, segments=None): + """Build a minimum_spike_dtype array from parallel arrays. Test helper.""" + n = len(samples) + sv = np.empty(n, dtype=minimum_spike_dtype) + sv["sample_index"] = samples + sv["unit_index"] = units + sv["segment_index"] = segments if segments is not None else 0 + return sv + + +def test_filter_and_remap_keep_all(force_numba): + # Identity mapping: every parent unit_index maps to itself. + sv = _make_spike_vector([10, 20, 30, 40], [0, 1, 2, 0]) + mapping = np.arange(3, dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + assert np.array_equal(out, sv) + + +def test_filter_and_remap_drop_some(force_numba): + # Drop unit 1 entirely; keep 0 and 2 with new indices [0, 1]. + sv = _make_spike_vector([10, 20, 30, 40, 50], [0, 1, 2, 0, 1]) + mapping = np.array([0, -1, 1], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + expected = _make_spike_vector([10, 30, 40], [0, 1, 0]) + assert np.array_equal(out, expected) + + +def test_filter_and_remap_renamed_only(force_numba): + # Selection is full but unit indices are permuted: 0->2, 1->0, 2->1. + sv = _make_spike_vector([10, 20, 30], [0, 1, 2]) + mapping = np.array([2, 0, 1], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + expected = _make_spike_vector([10, 20, 30], [2, 0, 1]) + assert np.array_equal(out, expected) + + +def test_filter_and_remap_empty_selection(force_numba): + sv = _make_spike_vector([10, 20, 30], [0, 1, 2]) + mapping = np.full(3, -1, dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + assert out.size == 0 + assert out.dtype == np.dtype(minimum_spike_dtype) + + +def test_filter_and_remap_empty_input(force_numba): + sv = np.empty(0, dtype=minimum_spike_dtype) + mapping = np.array([0, 1, 2], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + assert out.size == 0 + assert out.dtype == np.dtype(minimum_spike_dtype) + + +def test_filter_and_remap_preserves_tie_order(force_numba): + # Two cotemporal spikes at sample 100 (units 1 and 2). After dropping unit 0, + # the two cotemporals must appear in their original relative order — the kernel + # never reorders within ties. + sv = _make_spike_vector( + [50, 100, 100, 200], + [0, 1, 2, 1], + ) + mapping = np.array([-1, 0, 1], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + expected = _make_spike_vector([100, 100, 200], [0, 1, 0]) + assert np.array_equal(out, expected) + + +def test_filter_and_remap_segment_index_preserved(force_numba): + sv = _make_spike_vector([10, 20, 30, 40], [0, 1, 0, 1], segments=[0, 0, 1, 1]) + mapping = np.array([0, 1], dtype=np.int64) + out = filter_and_remap_spike_vector(sv, mapping) + assert np.array_equal(out["segment_index"], [0, 0, 1, 1]) + assert np.array_equal(out["sample_index"], [10, 20, 30, 40]) + assert np.array_equal(out["unit_index"], [0, 1, 0, 1]) + + def test_random_spikes_selection(): recording, sorting = generate_ground_truth_recording( durations=[20.0, 10.0], diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index bf27b38733..c6c2096eb7 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -1,7 +1,7 @@ import numpy as np from .basesorting import BaseSorting, BaseSortingSegment -from .sorting_tools import is_spike_vector_sorted +from .sorting_tools import filter_and_remap_spike_vector, is_spike_vector_sorted class UnitsSelectionSorting(BaseSorting): @@ -47,26 +47,29 @@ def __init__(self, parent_sorting, unit_ids=None, renamed_unit_ids=None): self._kwargs = dict(parent_sorting=parent_sorting, unit_ids=unit_ids, renamed_unit_ids=renamed_unit_ids) def _compute_and_cache_spike_vector(self) -> None: - from spikeinterface.core.sorting_tools import remap_unit_indices_in_vector - if self._parent_sorting._cached_spike_vector is None: self._parent_sorting._compute_and_cache_spike_vector() if self._parent_sorting._cached_spike_vector is None: return - spike_vector, _ = remap_unit_indices_in_vector( - vector=self._parent_sorting._cached_spike_vector, - all_old_unit_ids=self._parent_sorting.unit_ids, - all_new_unit_ids=self._unit_ids, + # Build a dense LUT from parent unit_index -> new unit_index (-1 = drop). + parent_unit_ids = self._parent_sorting.unit_ids + parent_id_to_pos = {uid: i for i, uid in enumerate(parent_unit_ids)} + unit_mapping = np.full(parent_unit_ids.size, -1, dtype=np.int64) + for new_idx, uid in enumerate(self._unit_ids): + unit_mapping[parent_id_to_pos[uid]] = new_idx + + spike_vector = filter_and_remap_spike_vector( + spike_vector=self._parent_sorting._cached_spike_vector, + unit_mapping=unit_mapping, ) # The parent's spike vector is sorted by (segment_index, sample_index, unit_index). - # Boolean filtering by unit preserves that order; the remap only changes unit_index - # values. The result stays sorted iff the selected unit_ids appear in the same - # relative order as in the parent (an O(k) check). If not, the vector may still - # happen to be sorted -- verify with an O(n) scan before falling back to O(n log n) - # lexsort. + # Filtering preserves that order and the remap only changes unit_index values. + # The result stays sorted iff the selected unit_ids appear in the same relative + # order as in the parent (an O(k) check). If not, the vector may still happen to + # be sorted -- verify with an O(n) scan before falling back to O(n log n) lexsort. assume_single_segment = self.get_num_segments() == 1 if not self._is_order_preserving_selection() and not is_spike_vector_sorted( spike_vector, assume_single_segment=assume_single_segment From 9d4be479c4cdd42528e20c5b40f727ef25195096 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Wed, 20 May 2026 20:03:38 -0500 Subject: [PATCH 09/10] Shortcut handling of "identity selection" in UnitSelectionSorting.to_spike_vector() Check if the user requested an "identity selection": all parent units, in parent order, possibly renamed (the spike vector uses unit _index_, and renaming doesn't affect that). If so, the cached parent spike vector is identical to the one we want, so just share the reference and skip the rest. --- .../core/tests/test_unitsselectionsorting.py | 29 +++++++++++++++++++ .../core/unitsselectionsorting.py | 14 ++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 6096cd73d3..6f23fb5fcb 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -115,5 +115,34 @@ def test_spike_vector_sorted_after_reorder_with_cotemporal_spikes(): +def test_compute_and_cache_spike_vector_identity_selection_shares_parent_cache(): + """A USS that selects all of its parent's units in parent order should reuse the + parent's cached spike vector by reference, not rebuild it.""" + from spikeinterface.core.basesorting import BaseSorting + + sorting = generate_sorting(num_units=4, durations=[0.100, 0.100], sampling_frequency=30000.0) + + # First USS: identity selection over `sorting`. Force its cache. + uss1 = UnitsSelectionSorting(sorting, unit_ids=list(sorting.unit_ids)) + uss1._compute_and_cache_spike_vector() + assert uss1._cached_spike_vector is not None + + # Second USS: identity selection over uss1, with renamed ids to exercise the + # rename-only path. The cached spike vector must be the same Python object. + renamed = [f"r{uid}" for uid in uss1.unit_ids] + uss2 = UnitsSelectionSorting(uss1, unit_ids=list(uss1.unit_ids), renamed_unit_ids=renamed) + uss2._compute_and_cache_spike_vector() + assert uss2._cached_spike_vector is uss1._cached_spike_vector + if uss1._cached_spike_vector_segment_slices is not None: + assert uss2._cached_spike_vector_segment_slices is uss1._cached_spike_vector_segment_slices + + # Belt-and-suspenders: the shared vector must still match the slow base-class path. + uss2._cached_spike_vector = None + uss2._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(uss2) + base_vector = uss2._cached_spike_vector + assert np.array_equal(uss1._cached_spike_vector, base_vector) + + if __name__ == "__main__": test_basic_functions() diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index c6c2096eb7..489d33d88a 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -53,8 +53,20 @@ def _compute_and_cache_spike_vector(self) -> None: if self._parent_sorting._cached_spike_vector is None: return - # Build a dense LUT from parent unit_index -> new unit_index (-1 = drop). parent_unit_ids = self._parent_sorting.unit_ids + + # Check if the user requested an "identity selection": all parent units, in + # parent order, possibly renamed (the spike vector uses unit _index_, and + # renaming doesn't affect that). If so, the cached parent spike vector is + # identical to the one we want, so just share the reference and skip the rest. + if self._unit_ids.size == parent_unit_ids.size and np.array_equal(self._unit_ids, parent_unit_ids): + self._cached_spike_vector = self._parent_sorting._cached_spike_vector + parent_slices = self._parent_sorting._cached_spike_vector_segment_slices + if parent_slices is not None: + self._cached_spike_vector_segment_slices = parent_slices + return + + # Build a dense LUT from parent unit_index -> new unit_index (-1 = drop). parent_id_to_pos = {uid: i for i, uid in enumerate(parent_unit_ids)} unit_mapping = np.full(parent_unit_ids.size, -1, dtype=np.int64) for new_idx, uid in enumerate(self._unit_ids): From 375fb23caf6f008432d07e7f056ad6fe17ed7c09 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Wed, 20 May 2026 23:05:33 -0500 Subject: [PATCH 10/10] Optimize `to_reordered_spike_vector` `to_spike_vector()` already returns a (canonical) vector whose sample_index is ascending within each segment. The two supported lexsort keys only differ from the canonical order by how they grouping spikes into (unit, segment) buckets. Sample order within each bucket is already correct in the canonical vector. No re-sorting necessary! Perfect for a (linear time!) counting sort. A numba kernel based on textbook implementation was fast enough that I didn't try alternatives. Plus, the same array striding tricks used to avoid intermediate copies are reused from recent optimization commits, before passing to the numba kernel. --- src/spikeinterface/core/basesorting.py | 79 ++++++----- src/spikeinterface/core/sorting_tools.py | 128 ++++++++++++++++-- .../core/tests/test_basesorting.py | 125 +++++++++++++++++ .../core/tests/test_unitsselectionsorting.py | 35 ++++- .../core/unitsselectionsorting.py | 39 +++++- 5 files changed, 355 insertions(+), 51 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 26de57c24d..db5d99b5c3 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -1111,48 +1111,53 @@ def to_reordered_spike_vector( key = str(lexsort) if key not in self._cached_lexsorted_spike_vector.keys(): - spikes = self.to_spike_vector() - order = np.lexsort((spikes[lexsort[0]], spikes[lexsort[1]], spikes[lexsort[2]])) - ordered_spikes = spikes[order] - self._cached_lexsorted_spike_vector[key] = {} - self._cached_lexsorted_spike_vector[key]["ordered_spikes"] = ordered_spikes - self._cached_lexsorted_spike_vector[key]["order"] = order + from .sorting_tools import reorder_spike_vector_by_buckets + spikes = self.to_spike_vector() num_units = len(self.unit_ids) num_segments = self.get_num_segments() - # precompute the slices with nested search sorted + # Both supported `lexsort` keys are equivalent to grouping spikes into + # `num_units * num_segments` "buckets" while **preserving** ascending + # `sample_index`` within each bucket. + # (Within each bucket, samples are **already** sorted!) + # + # We can do this with a (stable) counting sort in O(N), if we know the + # bucket index for each spike. That's easy: the bucket index is just a + # straightforward linear combination of the `unit_index` and `segment_index` + # fields, with the order depending on the lexsort` key. if lexsort == ("sample_index", "segment_index", "unit_index"): - # this case make spiketrain per unit compact in memory - - slices = np.zeros((num_units, num_segments, 2), dtype=np.int64) - unit_slices = np.searchsorted(ordered_spikes["unit_index"], np.arange(num_units + 1), side="left") - for unit_index, unit_id in enumerate(self.unit_ids): - u0 = unit_slices[unit_index] - u1 = unit_slices[unit_index + 1] - seg_slices = np.searchsorted( - ordered_spikes[u0:u1]["segment_index"], np.arange(num_segments + 1), side="left" - ) - for segment_index in range(num_segments): - s0 = seg_slices[segment_index] - s1 = seg_slices[segment_index + 1] - slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] - - elif lexsort == ("sample_index", "unit_index", "segment_index"): - slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) - seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") - for segment_index in range(self.get_num_segments()): - s0 = seg_slices[segment_index] - s1 = seg_slices[segment_index + 1] - unit_slices = np.searchsorted( - ordered_spikes[s0:s1]["unit_index"], np.arange(num_units + 1), side="left" - ) - for unit_index, unit_id in enumerate(self.unit_ids): - u0 = unit_slices[unit_index] - u1 = unit_slices[unit_index + 1] - slices[segment_index, unit_index, :] = [s0 + u0, s0 + u1] - - self._cached_lexsorted_spike_vector[key]["slices"] = slices + # primary key unit_index, then segment_index, then sample_index + bucket_index = spikes["unit_index"].astype(np.int64, copy=False) * num_segments + spikes[ + "segment_index" + ].astype(np.int64, copy=False) + num_buckets = num_units * num_segments + ordered_spikes, order, counts = reorder_spike_vector_by_buckets(spikes, bucket_index, num_buckets) + # counts is laid out as [unit_index * num_segments + segment_index] + counts_2d = counts.reshape(num_units, num_segments) + + else: # ("sample_index", "unit_index", "segment_index") + # primary key segment_index, then unit_index, then sample_index + bucket_index = spikes["segment_index"].astype(np.int64, copy=False) * num_units + spikes[ + "unit_index" + ].astype(np.int64, copy=False) + num_buckets = num_segments * num_units + ordered_spikes, order, counts = reorder_spike_vector_by_buckets(spikes, bucket_index, num_buckets) + # counts is laid out as [segment_index * num_units + unit_index] + counts_2d = counts.reshape(num_segments, num_units) + + # Build slices from cumulative counts. Stops are exclusive cumulative sums + # (aka "prefix sums" in the language of counting sort) shifted by one; + # starts are the same prefix without the last element. + ends = np.cumsum(counts_2d.ravel()).reshape(counts_2d.shape) + starts = ends - counts_2d + slices = np.stack([starts, ends], axis=-1).astype(np.int64, copy=False) + + self._cached_lexsorted_spike_vector[key] = { + "ordered_spikes": ordered_spikes, + "order": order, + "slices": slices, + } ordered_spikes = self._cached_lexsorted_spike_vector[key]["ordered_spikes"] out = (ordered_spikes,) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 4409c72b8d..5b904def19 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -950,7 +950,7 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee * merging/spliting periods or spikes and update the "unit_index" in the vector Do not use this if you are operating on `minimum_spike_dtype` inputs! In such cases, - it is much more efficient to use a dense LUT + `filter_and_remap_spike_vector()` + it is much more efficient to use a dense LUT + `filter_and_remap_spike_vector()` (see `UnitSelectionSorting._compute_and_cache_spike_vector()`). Parameters @@ -1171,15 +1171,15 @@ def filter_and_remap_spike_vector( Spikes are written in their original order in `spike_vector`, so if the input is sorted by ``(segment_index, sample_index, parent_unit_index)`` and `unit_mapping`, restricted to its kept entries, is monotonic increasing, - then the output is also sorted by ``(segment_index, sample_index, new_unit_index)``. - - If `unit_mapping` is not order-preserving, the caller is responsible for re-sorting + then the output is also sorted by ``(segment_index, sample_index, new_unit_index)``. + + If `unit_mapping` is not order-preserving, the caller is responsible for re-sorting cotemporal spike groups (e.g. by calling `is_spike_vector_sorted` + `np.lexsort`). Parameters ---------- spike_vector : np.ndarray - Structured array with dtype `minimum_spike_dtype`. + Structured array with dtype `minimum_spike_dtype`. unit_mapping : np.ndarray 1-D int64 array of length ``parent_num_units``. ``unit_mapping[old]`` gives the new unit_index, or any negative value to drop the spike. @@ -1198,14 +1198,14 @@ def filter_and_remap_spike_vector( if HAVE_NUMBA: # Same trick as `build_spike_vector_from_sorted_arrays()`: - # These flat-buffier views are zero-copy because the (N, 3) int64 layouts match - # `minimum_spike_dtype` exactly. + # These flat-buffier views are zero-copy because the (N, 3) int64 layouts match + # `minimum_spike_dtype` exactly. parent_flat = spike_vector.view(np.int64).reshape(n_parent, 3) out_flat = np.empty((n_parent, 3), dtype=np.int64) n_kept = _filter_and_remap_kernel(parent_flat, mapping_arr, out_flat) return out_flat[:n_kept].view(minimum_spike_dtype).reshape(n_kept) - # NumPy fallback: bool-mask + remap. + # NumPy fallback: bool-mask + remap. old_unit_idx = spike_vector["unit_index"] new_unit_idx_full = mapping_arr[old_unit_idx] keep = new_unit_idx_full >= 0 @@ -1214,6 +1214,77 @@ def filter_and_remap_spike_vector( return out +def reorder_spike_vector_by_buckets( + spike_vector: np.ndarray, + bucket_index: np.ndarray, + num_buckets: int, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Stable counting-sort of a `minimum_spike_dtype` spike vector by precomputed + bucket index. + + If numba is available, runs in O(N) using a kernel based on the algorithm described + in Cormen, Leiserson, Rivest, and Stein (CLRS) Chapter 8.2. Uses the same + flat-buffer view trick as `build_spike_vector_from_sorted_arrays()` and + `filter_and_remap_spike_vector()` to avoid intermediate copies. + + Stability means rows are emitted in input order within each bucket, + so any pre-existing sort within a bucket + (e.g. ascending `sample_index` within a (segment, unit) group) is preserved. + + Parameters + ---------- + spike_vector : np.ndarray + Structured array with dtype `minimum_spike_dtype`. + bucket_index : np.ndarray + 1-D integer array of length ``spike_vector.size`` giving the + destination bucket for each spike. Values must be in + ``[0, num_buckets)``. + num_buckets : int + Total number of buckets. + + Returns + ------- + ordered_spikes : np.ndarray + Structured array of `minimum_spike_dtype`, same length as input, + with rows grouped by `bucket_index`. + order : np.ndarray + 1-D int64 array such that ``spike_vector[order] == ordered_spikes``. + counts : np.ndarray + 1-D int64 array of length `num_buckets`, the number of spikes in + each bucket. + """ + n = spike_vector.size + if n == 0: + return ( + np.empty(0, dtype=minimum_spike_dtype), + np.empty(0, dtype=np.int64), + np.zeros(int(num_buckets), dtype=np.int64), + ) + + # See implementation note in `build_spike_vector_from_sorted_arrays()`. + bucket_arr = np.ascontiguousarray(bucket_index, dtype=np.int64) + if bucket_arr.size != n: + raise ValueError(f"bucket_index and spike_vector must have the same length; got {bucket_arr.size} and {n}.") + + if HAVE_NUMBA: + # Same trick as `build_spike_vector_from_sorted_arrays()`: + # These flat-buffier views are zero-copy because the (N, 3) int64 layouts match + # `minimum_spike_dtype` exactly. + in_flat = spike_vector.view(np.int64).reshape(n, 3) + out_flat = np.empty((n, 3), dtype=np.int64) + order = np.empty(n, dtype=np.int64) + counts = np.empty(int(num_buckets), dtype=np.int64) + _reorder_spike_vector_kernel(in_flat, bucket_arr, int(num_buckets), out_flat, order, counts) + ordered_spikes = out_flat.view(minimum_spike_dtype).reshape(n) + return ordered_spikes, order, counts + + # NumPy fallback: stable argsort by bucket, then fancy-index the structured array. + order = np.argsort(bucket_arr, kind="stable") + ordered_spikes = spike_vector[order] + counts = np.bincount(bucket_arr, minlength=int(num_buckets)).astype(np.int64, copy=False) + return ordered_spikes, order, counts + + if HAVE_NUMBA: import numba @@ -1320,3 +1391,44 @@ def _filter_and_remap_kernel(parent_flat, unit_mapping, out_flat): out_flat[write_pos, 2] = parent_flat[i, 2] write_pos += 1 return write_pos + + @numba.jit(nopython=True, nogil=True, cache=True) + def _reorder_spike_vector_kernel(in_flat, bucket_index, num_buckets, out_flat, order, counts): + """Stable counting-sort of a (N, 3) int64 spike-vector flat-buffer view by bucket. + + Adapted from Cormen, Leiserson, Rivest, and Stein (CLRS) Chapter 8.2. + + Two O(N) passes: + 1. histogram `bucket_index` into `counts`, + 2. cumulative-sum to per-bucket write positions, then scatter each + row of `in_flat` to its destination in `out_flat` and record the + source index in `order` so that ``in[order] == out``. + + Stability: within each bucket, rows keep their input order, so any + ordering already present in `in_flat` (e.g. ascending sample_index + within a (segment, unit) group) carries over to `out_flat`. + """ + n = in_flat.shape[0] + + for b in range(num_buckets): + counts[b] = 0 + for i in range(n): + counts[bucket_index[i]] += 1 + + # Exclusive prefix sum into a local write_pos buffer. `counts` keeps + # the per-bucket sizes (the caller uses them to derive slices). + # (If slices weren't needed, maybe this could be done in-place in `counts`?) + write_pos = np.empty(num_buckets, dtype=np.int64) + running = 0 + for b in range(num_buckets): + write_pos[b] = running + running += counts[b] + + for i in range(n): + b = bucket_index[i] + pos = write_pos[b] + out_flat[pos, 0] = in_flat[i, 0] + out_flat[pos, 1] = in_flat[i, 1] + out_flat[pos, 2] = in_flat[i, 2] + order[pos] = i + write_pos[b] = pos + 1 diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 10874a8362..5617b881d7 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -155,6 +155,131 @@ def test_BaseSorting(create_cache_folder): assert sorting.get_annotation(annotation_name) == sorting_zarr_loaded.get_annotation(annotation_name) +def _reference_reordered_spike_vector(spikes, lexsort, num_units, num_segments): + """Pre-optimization reference: np.lexsort + nested searchsorted. + + Mirrors the implementation that lived in `to_reordered_spike_vector` + before the counting-sort rewrite. Used to assert byte-for-byte parity + of the new implementation. + """ + order = np.lexsort((spikes[lexsort[0]], spikes[lexsort[1]], spikes[lexsort[2]])) + ordered_spikes = spikes[order] + + if lexsort == ("sample_index", "segment_index", "unit_index"): + slices = np.zeros((num_units, num_segments, 2), dtype=np.int64) + unit_slices = np.searchsorted(ordered_spikes["unit_index"], np.arange(num_units + 1), side="left") + for unit_index in range(num_units): + u0 = unit_slices[unit_index] + u1 = unit_slices[unit_index + 1] + seg_slices = np.searchsorted( + ordered_spikes[u0:u1]["segment_index"], np.arange(num_segments + 1), side="left" + ) + for segment_index in range(num_segments): + s0 = seg_slices[segment_index] + s1 = seg_slices[segment_index + 1] + slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] + elif lexsort == ("sample_index", "unit_index", "segment_index"): + slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) + seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") + for segment_index in range(num_segments): + s0 = seg_slices[segment_index] + s1 = seg_slices[segment_index + 1] + unit_slices = np.searchsorted(ordered_spikes[s0:s1]["unit_index"], np.arange(num_units + 1), side="left") + for unit_index in range(num_units): + u0 = unit_slices[unit_index] + u1 = unit_slices[unit_index + 1] + slices[segment_index, unit_index, :] = [s0 + u0, s0 + u1] + else: + raise ValueError(lexsort) + + return ordered_spikes, order, slices + + +def test_to_reordered_spike_vector_parity(): + """The counting-sort rewrite must match the prior np.lexsort implementation.""" + rng = np.random.default_rng(42) + num_units = 6 + num_segments = 3 + sampling_frequency = 30_000.0 + + # Build per-segment, per-unit spike trains with deliberate cotemporal spikes + # (multiple units firing at the same sample_index) so the unit-index tiebreaker + # is exercised. + spike_dicts = [] + for seg in range(num_segments): + seg_dict = {} + for u in range(num_units): + n = int(rng.integers(50, 200)) + times = np.sort(rng.integers(0, 10_000, size=n)) + # Inject a handful of cotemporal spikes that collide with the unit-0 train. + if u > 0 and n > 5: + times[:5] = np.array([100, 200, 300, 400, 500]) + seg * 10 + times = np.sort(times) + seg_dict[str(u)] = times.astype("int64") + spike_dicts.append(seg_dict) + + sorting = NumpySorting.from_unit_dict(spike_dicts, sampling_frequency) + spikes = sorting.to_spike_vector() + + for lexsort in [ + ("sample_index", "segment_index", "unit_index"), + ("sample_index", "unit_index", "segment_index"), + ]: + # Clear the cache between iterations so each call exercises the fresh build. + sorting._cached_lexsorted_spike_vector = {} + + ordered_spikes, order, slices = sorting.to_reordered_spike_vector( + lexsort=lexsort, return_order=True, return_slices=True + ) + + ref_ordered, ref_order, ref_slices = _reference_reordered_spike_vector(spikes, lexsort, num_units, num_segments) + + # ordered_spikes must agree with the reference exactly (cotemporal spikes + # are now ordered by unit_index — stable counting sort by bucket preserves + # the canonical unit-index ordering within each tied sample_index). + assert np.array_equal(ordered_spikes, ref_ordered), f"ordered mismatch for {lexsort}" + # The invariant `spikes[order] == ordered_spikes` must hold; the exact + # `order` permutation can differ from np.lexsort's because stable counting + # sort and np.lexsort may pick different tie-break orderings of source rows + # that map to the same destination (different source rows can carry the + # same (sample, unit, segment) triple). + assert np.array_equal(spikes[order], ordered_spikes) + assert np.array_equal(slices, ref_slices), f"slices mismatch for {lexsort}" + + # Each (unit, segment) — or (segment, unit) — slice must yield exactly the + # spikes for that group, with monotonic sample_index. + if lexsort == ("sample_index", "segment_index", "unit_index"): + for u in range(num_units): + for s in range(num_segments): + s0, s1 = slices[u, s] + block = ordered_spikes[s0:s1] + assert np.all(block["unit_index"] == u) + assert np.all(block["segment_index"] == s) + assert np.all(np.diff(block["sample_index"]) >= 0) + else: + for s in range(num_segments): + for u in range(num_units): + s0, s1 = slices[s, u] + block = ordered_spikes[s0:s1] + assert np.all(block["unit_index"] == u) + assert np.all(block["segment_index"] == s) + assert np.all(np.diff(block["sample_index"]) >= 0) + + +def test_to_reordered_spike_vector_empty(): + """An empty sorting must round-trip through the counting-sort path.""" + sorting = NumpySorting.from_unit_dict({"0": np.array([], dtype="int64")}, 30_000.0) + ordered_spikes, order, slices = sorting.to_reordered_spike_vector( + lexsort=("sample_index", "segment_index", "unit_index"), + return_order=True, + return_slices=True, + ) + assert ordered_spikes.size == 0 + assert order.size == 0 + assert slices.shape == (1, 1, 2) + assert np.array_equal(slices, np.zeros((1, 1, 2), dtype=np.int64)) + + def test_npy_sorting(): sfreq = 10 spike_times_0 = { diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 6f23fb5fcb..d2b24d90ab 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -114,7 +114,6 @@ def test_spike_vector_sorted_after_reorder_with_cotemporal_spikes(): assert is_spike_vector_sorted(spike_vector) - def test_compute_and_cache_spike_vector_identity_selection_shares_parent_cache(): """A USS that selects all of its parent's units in parent order should reuse the parent's cached spike vector by reference, not rebuild it.""" @@ -144,5 +143,39 @@ def test_compute_and_cache_spike_vector_identity_selection_shares_parent_cache() assert np.array_equal(uss1._cached_spike_vector, base_vector) +def test_to_reordered_spike_vector_identity_selection_shares_parent_cache(): + """A USS that selects all of its parent's units in parent order should reuse the + parent's lexsorted spike vector cache by reference, not re-run the counting sort.""" + sorting = generate_sorting(num_units=5, durations=[0.200, 0.200], sampling_frequency=30000.0) + + # Identity selection, with renamed ids to also exercise the rename-only path. + renamed = [f"r{uid}" for uid in sorting.unit_ids] + uss = UnitsSelectionSorting(sorting, unit_ids=list(sorting.unit_ids), renamed_unit_ids=renamed) + + for lexsort in [ + ("sample_index", "segment_index", "unit_index"), + ("sample_index", "unit_index", "segment_index"), + ]: + # Force the parent to build the lexsorted cache. + parent_ordered, _, parent_slices = sorting.to_reordered_spike_vector( + lexsort=lexsort, return_order=True, return_slices=True + ) + key = str(lexsort) + assert key in sorting._cached_lexsorted_spike_vector + + # Reset USS cache and force a build through the override. + uss._cached_lexsorted_spike_vector = {} + uss_ordered, _, uss_slices = uss.to_reordered_spike_vector( + lexsort=lexsort, return_order=True, return_slices=True + ) + + # The cache entry must be the *same* dict object as the parent's. + assert ( + uss._cached_lexsorted_spike_vector[key] is sorting._cached_lexsorted_spike_vector[key] + ), f"identity USS did not share parent lexsorted cache for {lexsort}" + assert uss_ordered is parent_ordered + assert uss_slices is parent_slices + + if __name__ == "__main__": test_basic_functions() diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 489d33d88a..823bc477c7 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -55,11 +55,11 @@ def _compute_and_cache_spike_vector(self) -> None: parent_unit_ids = self._parent_sorting.unit_ids - # Check if the user requested an "identity selection": all parent units, in - # parent order, possibly renamed (the spike vector uses unit _index_, and - # renaming doesn't affect that). If so, the cached parent spike vector is - # identical to the one we want, so just share the reference and skip the rest. - if self._unit_ids.size == parent_unit_ids.size and np.array_equal(self._unit_ids, parent_unit_ids): + # If the user requested an "identity selection" (all parent units, in + # parent order, possibly renamed), the cached parent spike vector is + # identical to the one we want — share the reference and skip the rest. + # See `_is_identity_selection` for the definition. + if self._is_identity_selection(): self._cached_spike_vector = self._parent_sorting._cached_spike_vector parent_slices = self._parent_sorting._cached_spike_vector_segment_slices if parent_slices is not None: @@ -96,6 +96,35 @@ def _compute_and_cache_spike_vector(self) -> None: self._cached_spike_vector = spike_vector + def _is_identity_selection(self) -> bool: + """Return True if self._unit_ids are exactly the parent's unit_ids, in parent order. + + Renaming via ``renamed_unit_ids`` does not affect this — the spike vector + carries unit *indices*, not ids. When True, every cached form of the + parent's spike vector (canonical, lexsorted, etc.) can be shared with + ``self`` by reference. + """ + parent_unit_ids = self._parent_sorting.unit_ids + return self._unit_ids.size == parent_unit_ids.size and np.array_equal(self._unit_ids, parent_unit_ids) + + def to_reordered_spike_vector( + self, lexsort=("sample_index", "segment_index", "unit_index"), return_order=True, return_slices=True + ): + # On an identity selection, the parent's lexsorted cache is exactly + # what we'd compute — just reference it so we don't re-run the counting sort! + if self._is_identity_selection(): + key = str(tuple(lexsort)) + if key not in self._cached_lexsorted_spike_vector: + # Force the parent to populate its own cache (a no-op if already + # cached) before we share the entry. + self._parent_sorting.to_reordered_spike_vector(lexsort=lexsort, return_order=True, return_slices=True) + parent_entry = self._parent_sorting._cached_lexsorted_spike_vector.get(key) + if parent_entry is not None: + self._cached_lexsorted_spike_vector[key] = parent_entry + return super().to_reordered_spike_vector( + lexsort=lexsort, return_order=return_order, return_slices=return_slices + ) + def _is_order_preserving_selection(self) -> bool: """Return True if self._unit_ids appear in the same relative order as in the parent.