Speed up template similarity computing using numba #4343

Merged
samuelgarcia merged 20 commits into SpikeInterface:main from tayheau:soft_merges_and_refactor
Mar 13, 2026

Conversation

@tayheau
Contributor

@tayheau tayheau commented Jan 26, 2026

Following #4310, @chrishalcrow showed that when the difference between two merged templates is small, the distance between a merged template and a new one can be approximated as a linear function. This should let us significantly speed up template similarity computations for merged templates.
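
One possible reading of the linear approximation, as a minimal sketch and not necessarily the formula used in #4310: assume the merged template is a spike-count-weighted average of the two originals and use a plain (unnormalized) L2 distance. The weighted average of the two original distances then over-estimates the exact distance by at most `2 * w * (1 - w) * ||A - B||`, which is small when the two merged templates are close. All names and the spike counts below are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(3)

# Toy templates with shape (num_samples, num_channels); template_b is close to template_a.
template_a = rng.normal(size=(60, 8))
template_b = template_a + 0.01 * rng.normal(size=(60, 8))
template_c = rng.normal(size=(60, 8))  # the "new" template to compare against

# Hypothetical spike counts used as merge weights (an assumption of this sketch).
n_a, n_b = 300, 100
w = n_a / (n_a + n_b)
merged = w * template_a + (1 - w) * template_b

# Exact distance from the merged template to template_c.
exact = np.linalg.norm(merged - template_c)

# Linear approximation: weighted average of the distances already known before the merge.
approx = w * np.linalg.norm(template_a - template_c) + (1 - w) * np.linalg.norm(template_b - template_c)

# The triangle inequality bounds the error of the approximation.
bound = 2 * w * (1 - w) * np.linalg.norm(template_a - template_b)
print(f"exact={exact:.6f} approx={approx:.6f} error bound={bound:.6f}")
```

The bound follows from the triangle inequality alone, so it holds for any norm; whether it carries over to the normalized distances used in the PR is not shown here.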

@alejoe91 alejoe91 added the postprocessing Related to postprocessing module label Jan 28, 2026
@tayheau tayheau force-pushed the soft_merges_and_refactor branch from ce491e6 to 99985d4 Compare February 16, 2026 16:20
@tayheau
Contributor Author

tayheau commented Feb 24, 2026

Everything should now work. I see some small differences, mostly due to casting, but they are minimal in the context of a normalized distance.

To sum up quickly: I moved the computation of the support matrix out of the loop. This costs more memory when the two template arrays differ, but since they are usually the same, I think it is a net gain here. The code also relies more on NumPy views instead of copies.
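
As a small illustration of the views-versus-copies point (a sketch on toy data, not code from the PR): basic slicing of a NumPy array returns a view that shares memory with the original, while boolean-mask indexing allocates a copy. The array shapes and the `channel_mask` name are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
num_templates, num_samples, num_channels, num_shifts = 3, 20, 4, 2

# Toy templates array with shape (num_templates, num_samples, num_channels).
templates = rng.normal(size=(num_templates, num_samples, num_channels)).astype(np.float32)

# Basic slicing along the time axis: no data is copied, the result is a view.
window = templates[:, num_shifts : num_samples - num_shifts]
window_is_view = np.shares_memory(window, templates)

# Boolean-mask (advanced) indexing along the channel axis: NumPy allocates a new array.
channel_mask = np.array([True, False, True, False])
masked = window[:, :, channel_mask]
masked_is_view = np.shares_memory(masked, templates)

print(window_is_view, masked_is_view)
```

Slicing the shifted time windows is therefore cheap, whereas every channel selection by mask inside a loop pays for an allocation.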

I'm not that familiar with numba/numpy performance tuning, but for the union it was faster to do it the "dummy" way (mine, lol) than the vectorised one. So if you have any rule-of-thumb tips for Numba, I'm in ;) @samuelgarcia
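
For reference, the two styles being compared could look like the sketch below, which builds a pairwise "union" support mask from two boolean sparsity masks of shape `(num_templates, num_channels)`. The function names are hypothetical and this is not the PR's implementation; the explicit-loop version is the style that `numba.njit` typically compiles well, but it is left as plain Python here so the example runs without numba.

```python
import numpy as np


def union_support_loop(mask, other_mask):
    """Explicit triple loop: support[i, j, k] is True if channel k is active in either template."""
    num_templates, num_channels = mask.shape
    other_num_templates = other_mask.shape[0]
    support = np.zeros((num_templates, other_num_templates, num_channels), dtype=np.bool_)
    for i in range(num_templates):
        for j in range(other_num_templates):
            for k in range(num_channels):
                support[i, j, k] = mask[i, k] or other_mask[j, k]
    return support


def union_support_vectorised(mask, other_mask):
    """Same result using broadcasting, computed once for all template pairs."""
    return np.logical_or(mask[:, None, :], other_mask[None, :, :])


rng = np.random.default_rng(1)
mask = rng.random((5, 8)) > 0.5
other_mask = rng.random((6, 8)) > 0.5

loop_result = union_support_loop(mask, other_mask)
vec_result = union_support_vectorised(mask, other_mask)
print(np.array_equal(loop_result, vec_result))
```

Both versions produce the same array; which one is faster depends on whether the loop is compiled and on the array sizes, so it is worth timing on real data.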


@tayheau tayheau changed the title Speed up template similarity for soft merges Speed up template similarity computing using numba Feb 24, 2026
@tayheau tayheau marked this pull request as ready for review February 24, 2026 13:32
@chrishalcrow
Member

Looks great as a user! I get a small ~30% speed up on initial compute, and 10x speed-up on recompute (=> trying out different methods in the gui is super fast, and hard merges are very fast).

And on some real data with 800 neurons, the biggest absmax in difference is...

np.max(np.abs(old_temps - new_temps))
>> np.float32(4.2915344e-06)

Nice!!!
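
A difference of this order is what float32 rounding alone can produce. As a rough sketch on random toy data (not the real templates, and with a hypothetical helper name), the same normalized L2 distance computed in float64 and in float32 differs only around the float32 rounding level:

```python
import numpy as np


def normalized_l2(a, b):
    """Normalized L2 distance, following the form ||a - b|| / (||a|| + ||b||)."""
    return np.linalg.norm(a - b) / (np.linalg.norm(a) + np.linalg.norm(b))


rng = np.random.default_rng(2)
a64 = rng.normal(size=(90, 16))
b64 = rng.normal(size=(90, 16))

d64 = normalized_l2(a64, b64)
d32 = normalized_l2(a64.astype(np.float32), b64.astype(np.float32))

# Absolute difference caused only by the change of dtype.
diff = abs(float(d32) - float(d64))
print(diff)
```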

Comment thread src/spikeinterface/postprocessing/template_similarity.py
@alejoe91 alejoe91 added this to the 0.104.0 milestone Feb 25, 2026
Collaborator

@yger yger left a comment


I would be careful, because in the meantime we patched a bug in the similarity that was due to over-optimization. We cannot compute only half of the lags and only the upper part of the matrix, then use symmetry everywhere; otherwise the result is incomplete. If we symmetrize in time, we need to compute all indices.

Comment thread src/spikeinterface/postprocessing/template_similarity.py Outdated
Comment thread src/spikeinterface/postprocessing/template_similarity.py Outdated
@yger
Collaborator

yger commented Feb 25, 2026

I would be careful here, because we recently patched a bug in #4345 and I think the fix has not been propagated here.

@tayheau tayheau requested a review from yger March 4, 2026 09:51
@alejoe91
Member

alejoe91 commented Mar 9, 2026

@yger I think that the symmetry issue is fixed. Can you double check?

Comment thread src/spikeinterface/postprocessing/template_similarity.py Outdated
@samuelgarcia
Member

Hi Theo.
We checked it with Pierre and pushed a small commit.
You made the same mistake Pierre made some time ago. Now it is correct.

@samuelgarcia
Member

This is now OK for me.
I am not sure that the speedup will be as big as @chrishalcrow has seen.
Could you check ?

@tayheau
Contributor Author

tayheau commented Mar 11, 2026

Hi Samuel,
but with the fix you made, won't the array be only half populated in the same_array case?

@yger
Collaborator

yger commented Mar 11, 2026

No because if this is not the same array, we'll explore all shifts, am I right?

@tayheau
Contributor Author

tayheau commented Mar 11, 2026

To me there are two symmetries: a "spatial" one, where dist(i, j) = dist(j, i) at lag t, and a temporal one, where dist(i, j) at lag t = dist(i, j) at lag -t (we agree that both apply only in the same_array case). That is why we go over the similarity matrix twice. The num_shifts != 0 check is only there so that the transpose does not overwrite the result, but it should be shift != 0. I think this fix erases the "spatial" symmetry optimisation.

@yger
Collaborator

yger commented Mar 12, 2026

You are right. Can we sort this out together in a short call?

@yger
Collaborator

yger commented Mar 13, 2026

@tayheau if you can finish today and propagate the changes to compute_(...)_numpy, that would be great, since we would like to merge this quickly for the release. Otherwise let me know and I can have a go on the branch.

Comment thread src/spikeinterface/postprocessing/template_similarity.py Outdated
@yger
Collaborator

yger commented Mar 13, 2026

As I told @samuelgarcia, this is good for me, except that the optimizations are currently only implemented at the numba level, not yet in pure numpy (for when numba is not installed).

@yger
Collaborator

yger commented Mar 13, 2026

I'm not able to push to your branch, but here is code to replace _compute_similarity_matrix_numpy()

def _compute_similarity_matrix_numpy(
    templates_array, other_templates_array, num_shifts, method, sparsity_mask, other_sparsity_mask, support="union"
):

    num_templates = templates_array.shape[0]
    num_samples = templates_array.shape[1]
    other_num_templates = other_templates_array.shape[0]

    num_shifts_both_sides = 2 * num_shifts + 1
    distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
    same_array = np.array_equal(templates_array, other_templates_array)

    # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at lag -t,
    # so the matrix only needs to be computed for negative lags and then transposed

    if same_array:
        # optimisation when the arrays are the same, because of the symmetry in shift
        shift_loop = range(-num_shifts, 1)
    else:
        shift_loop = range(-num_shifts, num_shifts + 1)

    for count, shift in enumerate(shift_loop):
        src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
        tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
        for i in range(num_templates):
            src_template = src_sliced_templates[i]
            local_mask = get_overlapping_mask_for_one_template(i, sparsity_mask, other_sparsity_mask, support=support)
            overlapping_templates = np.flatnonzero(np.sum(local_mask, 1))
            tgt_templates = tgt_sliced_templates[overlapping_templates]
            for gcount, j in enumerate(overlapping_templates):
                if j < i and same_array:
                    continue
                src = src_template[:, local_mask[j]].reshape(1, -1)
                tgt = (tgt_templates[gcount][:, local_mask[j]]).reshape(1, -1)

                if method == "l1":
                    norm_i = np.sum(np.abs(src))
                    norm_j = np.sum(np.abs(tgt))
                    distances[count, i, j] = np.sum(np.abs(src - tgt))
                    distances[count, i, j] /= norm_i + norm_j
                elif method == "l2":
                    norm_i = np.linalg.norm(src, ord=2)
                    norm_j = np.linalg.norm(tgt, ord=2)
                    distances[count, i, j] = np.linalg.norm(src - tgt, ord=2)
                    distances[count, i, j] /= norm_i + norm_j
                elif method == "cosine":
                    norm_i = np.linalg.norm(src, ord=2)
                    norm_j = np.linalg.norm(tgt, ord=2)
                    distances[count, i, j] = np.sum(src * tgt)
                    distances[count, i, j] /= norm_i * norm_j
                    distances[count, i, j] = 1 - distances[count, i, j]

                if same_array:
                    distances[count, j, i] = distances[count, i, j]

        if same_array and shift != 0:
            distances[num_shifts_both_sides - count - 1] = distances[count].T

    return distances
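
The identity stated in the comment of the snippet above (dist[i, j] at lag t equals dist[j, i] at lag -t) can be checked on toy data. The sketch below is a simplified, dense, L1-only re-creation, not the library code: it assumes templates that are exactly zero within `2 * num_shifts` samples of each edge, which makes the identity exact; with real templates that merely decay towards the edges it holds only approximately. Helper names are hypothetical.

```python
import numpy as np


def l1_distance(a, b):
    """Normalized L1 distance between two equally shaped arrays."""
    return np.sum(np.abs(a - b)) / (np.sum(np.abs(a)) + np.sum(np.abs(b)))


def shifted_distance(templates, i, j, shift, num_shifts):
    """Template i in a fixed central window versus template j in a window moved by `shift`."""
    num_samples = templates.shape[1]
    src = templates[i, num_shifts : num_samples - num_shifts]
    tgt = templates[j, num_shifts + shift : num_samples - num_shifts + shift]
    return l1_distance(src, tgt)


rng = np.random.default_rng(0)
num_templates, num_samples, num_channels, num_shifts = 4, 40, 3, 5

# Assumption of this sketch: waveforms are zero near both edges of the time axis.
templates = np.zeros((num_templates, num_samples, num_channels))
core = slice(2 * num_shifts, num_samples - 2 * num_shifts)
templates[:, core] = rng.normal(size=(num_templates, num_samples - 4 * num_shifts, num_channels))

# Brute force: every lag and every (i, j) pair.
lags = list(range(-num_shifts, num_shifts + 1))
full = np.array(
    [
        [[shifted_distance(templates, i, j, s, num_shifts) for j in range(num_templates)] for i in range(num_templates)]
        for s in lags
    ]
)

# Shortcut: keep only lags <= 0 and fill the positive lags by transposition.
mirrored = np.ones_like(full)
for count, shift in enumerate(range(-num_shifts, 1)):
    mirrored[count] = full[count]
    if shift != 0:
        mirrored[2 * num_shifts - count] = full[count].T

print(np.allclose(full, mirrored))
```

The `shift != 0` guard matters: at lag 0 the mirrored index equals `count` itself, so transposing there would overwrite the slice that was just computed.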

@samuelgarcia
Member

Thanks Theo.
Sorry for the needless intrusion!!
I'm merging.

@samuelgarcia samuelgarcia merged commit 3fa0779 into SpikeInterface:main Mar 13, 2026
15 checks passed
@tayheau
Contributor Author

tayheau commented Mar 13, 2026

No worries!
That's also what git is for, hehe.

Labels

postprocessing Related to postprocessing module

5 participants