From 68d143add708dc935cdeb67e2de9e86ead97ab28 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 13 Sep 2024 10:55:23 +0200 Subject: [PATCH 01/14] WIP --- .../postprocessing/template_similarity.py | 207 +++++++++++++----- 1 file changed, 152 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 0e70b1f494..1f1325cf10 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -7,6 +7,13 @@ from ..core.template_tools import get_dense_templates_array from ..core.sparsity import ChannelSparsity +try: + import numba + + HAVE_NUMBA = True +except ImportError: + HAVE_NUMBA = False + class ComputeTemplateSimilarity(AnalyzerExtension): """Compute similarity between templates with several methods. @@ -147,10 +154,149 @@ def _get_data(self): compute_template_similarity = ComputeTemplateSimilarity.function_factory() +if HAVE_NUMBA: + + from numba import prange + + + @numba.jit(nopython=True, parallel=True, fastmath=True, cache=True, nogil=True) + def _compute_similarity_matrix( + templates_array, other_templates_array, num_shifts, mask, method + ): + num_templates = templates_array.shape[0] + num_samples = templates_array.shape[1] + other_num_templates = other_templates_array.shape[0] + + num_shifts_both_sides = 2 * num_shifts + 1 + distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) + same_array = np.array_equal(templates_array, other_templates_array) + + # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t + # So the matrix can be computed only for negative lags and be transposed + + if same_array: + # optimisation when array are the same because of symetry in shift + shift_loop = range(-num_shifts, 1) + else: + shift_loop = range(-num_shifts, num_shifts + 1) + + for count, shift in enumerate(shift_loop): + 
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] + tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] + for i in prange(num_templates): + src_template = src_sliced_templates[i] + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + tgt_templates = tgt_sliced_templates[overlapping_templates] + for gcount, j in enumerate(overlapping_templates): + + # symmetric values are handled later + if same_array and j < i: + # no need exhaustive looping when same template + continue + src = src_template[:, mask[i, j]].flatten() + tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten() + + norm_i = 0 + norm_j = 0 + + for k in range(len(src)): + if method == "l1": + norm_i += abs(src[k]) + norm_j += abs(tgt[k]) + distances[count, i, j] += abs(src[k] - tgt[k]) + elif method == "l2": + norm_i += src[k]**2 + norm_j += tgt[k]**2 + distances[count, i, j] += (src[k] - tgt[k])**2 + elif method == "cosine": + distances[count, i, j] += src[k]*tgt[k] + norm_i += src[k]**2 + norm_j += tgt[k]**2 + + if method == "l1": + distances[count, i, j] /= (norm_i + norm_j) + elif method == "l2": + norm_i = np.sqrt(norm_i) + norm_j = np.sqrt(norm_j) + distances[count, i, j] = np.sqrt(distances[count, i, j]) + distances[count, i, j] /= (norm_i + norm_j) + elif method == "cosine": + norm_i = np.sqrt(norm_i) + norm_j = np.sqrt(norm_j) + distances[count, i, j] /= (norm_i*norm_j) + + if same_array: + distances[count, j, i] = distances[count, i, j] + + if same_array and num_shifts != 0: + distances[num_shifts_both_sides - count - 1] = distances[count].T + return distances + +else: + def _compute_similarity_matrix( + templates_array, other_templates_array, num_shifts, mask, method + ): + import sklearn.metrics.pairwise + num_templates = templates_array.shape[0] + num_samples = templates_array.shape[1] + other_num_templates = other_templates_array.shape[0] + + num_shifts_both_sides = 2 * num_shifts + 1 + distances = 
np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) + same_array = np.array_equal(templates_array, other_templates_array) + + # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t + # So the matrix can be computed only for negative lags and be transposed + + if same_array: + # optimisation when array are the same because of symetry in shift + shift_loop = range(-num_shifts, 1) + else: + shift_loop = range(-num_shifts, num_shifts + 1) + + for count, shift in enumerate(shift_loop): + src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] + tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] + for i in range(num_templates): + src_template = src_sliced_templates[i] + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + tgt_templates = tgt_sliced_templates[overlapping_templates] + for gcount, j in enumerate(overlapping_templates): + # symmetric values are handled later + if same_array and j < i: + # no need exhaustive looping when same template + continue + src = src_template[:, mask[i, j]].reshape(1, -1) + tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) + + if method == "l1": + norm_i = np.sum(np.abs(src)) + norm_j = np.sum(np.abs(tgt)) + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item() + distances[count, i, j] /= norm_i + norm_j + elif method == "l2": + norm_i = np.linalg.norm(src, ord=2) + norm_j = np.linalg.norm(tgt, ord=2) + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item() + distances[count, i, j] /= norm_i + norm_j + elif method == "cosine": + distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances( + src, tgt, metric="cosine" + ).item() + + if same_array: + distances[count, j, i] = distances[count, i, j] + + if same_array and num_shifts != 0: + distances[num_shifts_both_sides - count - 1] = 
distances[count].T + return distances + + + def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): - import sklearn.metrics.pairwise + if method == "cosine_similarity": method = "cosine" @@ -171,8 +317,6 @@ def compute_similarity_with_templates_array( num_channels = templates_array.shape[2] other_num_templates = other_templates_array.shape[0] - same_array = np.array_equal(templates_array, other_templates_array) - mask = None if sparsity is not None and other_sparsity is not None: if support == "intersection": @@ -182,63 +326,16 @@ def compute_similarity_with_templates_array( units_overlaps = np.sum(mask, axis=2) > 0 mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) mask[~units_overlaps] = False - if mask is not None: - units_overlaps = np.sum(mask, axis=2) > 0 - overlapping_templates = {} - for i in range(num_templates): - overlapping_templates[i] = np.flatnonzero(units_overlaps[i]) else: # here we make a dense mask and overlapping templates - overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)} mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) assert num_shifts < num_samples, "max_lag is too large" - num_shifts_both_sides = 2 * num_shifts + 1 - distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) - - # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t - # So the matrix can be computed only for negative lags and be transposed - - if same_array: - # optimisation when array are the same because of symetry in shift - shift_loop = range(-num_shifts, 1) - else: - shift_loop = range(-num_shifts, num_shifts + 1) - - for count, shift in enumerate(shift_loop): - src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] - tgt_sliced_templates = other_templates_array[:, 
num_shifts + shift : num_samples - num_shifts + shift] - for i in range(num_templates): - src_template = src_sliced_templates[i] - tgt_templates = tgt_sliced_templates[overlapping_templates[i]] - for gcount, j in enumerate(overlapping_templates[i]): - # symmetric values are handled later - if same_array and j < i: - # no need exhaustive looping when same template - continue - src = src_template[:, mask[i, j]].reshape(1, -1) - tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) - - if method == "l1": - norm_i = np.sum(np.abs(src)) - norm_j = np.sum(np.abs(tgt)) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item() - distances[count, i, j] /= norm_i + norm_j - elif method == "l2": - norm_i = np.linalg.norm(src, ord=2) - norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item() - distances[count, i, j] /= norm_i + norm_j - else: - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances( - src, tgt, metric="cosine" - ).item() - - if same_array: - distances[count, j, i] = distances[count, i, j] - - if same_array and num_shifts != 0: - distances[num_shifts_both_sides - count - 1] = distances[count].T + distances = _compute_similarity_matrix(templates_array, + other_templates_array, + num_shifts, + mask, + method) distances = np.min(distances, axis=0) similarity = 1 - distances From 9184a34bfae5717699c7835a94acc5cda0c162f1 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 13 Sep 2024 11:11:43 +0200 Subject: [PATCH 02/14] prange and parallelism in numba --- src/spikeinterface/postprocessing/correlograms.py | 3 ++- src/spikeinterface/postprocessing/isi.py | 1 + src/spikeinterface/postprocessing/template_similarity.py | 5 +---- src/spikeinterface/sortingcomponents/matching/tdc.py | 2 +- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py 
b/src/spikeinterface/postprocessing/correlograms.py index ba12a5c462..88d664f059 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -439,6 +439,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): nopython=True, nogil=True, cache=False, + parallel=True ) def _compute_correlograms_one_segment_numba( correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins @@ -472,7 +473,7 @@ def _compute_correlograms_one_segment_numba( The size of which to bin lags, in samples. """ start_j = 0 - for i in range(spike_times.size): + for i in numba.prange(spike_times.size): for j in range(start_j, spike_times.size): if i == j: diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 542f829f21..b526a54413 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -186,6 +186,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float nopython=True, nogil=True, cache=False, + parallel=True ) def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins): n_units = ISIs.shape[0] diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 1f1325cf10..99824f3f43 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -155,9 +155,6 @@ def _get_data(self): if HAVE_NUMBA: - - from numba import prange - @numba.jit(nopython=True, parallel=True, fastmath=True, cache=True, nogil=True) def _compute_similarity_matrix( @@ -183,7 +180,7 @@ def _compute_similarity_matrix( for count, shift in enumerate(shift_loop): src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] - for i in 
prange(num_templates): + for i in numba.prange(num_templates): src_template = src_sliced_templates[i] overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index e66929e2b1..5c145d1f25 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -348,7 +348,7 @@ def _tdc_find_spikes(traces, d, level=0): if HAVE_NUMBA: - @jit(nopython=True) + @jit(nopython=True, parallel=True) def numba_sparse_dist(wf, templates, union_channels, possible_clusters): """ numba implementation that compute distance from template with sparsity From 90cebc34c985dcbff6ea7de14924454a2c252342 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 13 Sep 2024 11:22:07 +0200 Subject: [PATCH 03/14] WIP --- src/spikeinterface/postprocessing/template_similarity.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 99824f3f43..cf49010efd 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -221,6 +221,7 @@ def _compute_similarity_matrix( norm_i = np.sqrt(norm_i) norm_j = np.sqrt(norm_j) distances[count, i, j] /= (norm_i*norm_j) + distances[count, i, j] = 1 - distances[count, i, j] if same_array: distances[count, j, i] = distances[count, i, j] From 04c7b02480a363c365fc36987d8b08e429daa5a1 Mon Sep 17 00:00:00 2001 From: Sebastien Date: Fri, 13 Sep 2024 11:40:22 +0200 Subject: [PATCH 04/14] Removing dependencies --- .../postprocessing/template_similarity.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py 
index cf49010efd..25173b618c 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -10,7 +10,7 @@ try: import numba - HAVE_NUMBA = True + HAVE_NUMBA = False except ImportError: HAVE_NUMBA = False @@ -156,7 +156,7 @@ def _get_data(self): if HAVE_NUMBA: - @numba.jit(nopython=True, parallel=True, fastmath=True, cache=True, nogil=True) + @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) def _compute_similarity_matrix( templates_array, other_templates_array, num_shifts, mask, method ): @@ -234,7 +234,7 @@ def _compute_similarity_matrix( def _compute_similarity_matrix( templates_array, other_templates_array, num_shifts, mask, method ): - import sklearn.metrics.pairwise + num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] other_num_templates = other_templates_array.shape[0] @@ -270,17 +270,19 @@ def _compute_similarity_matrix( if method == "l1": norm_i = np.sum(np.abs(src)) norm_j = np.sum(np.abs(tgt)) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item() + distances[count, i, j] = np.sum(np.abs(src - tgt)) distances[count, i, j] /= norm_i + norm_j elif method == "l2": norm_i = np.linalg.norm(src, ord=2) norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item() + distances[count, i, j] = np.linalg.norm(src - tgt, ord=2) distances[count, i, j] /= norm_i + norm_j elif method == "cosine": - distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances( - src, tgt, metric="cosine" - ).item() + norm_i = np.linalg.norm(src, ord=2) + norm_j = np.linalg.norm(tgt, ord=2) + distances[count, i, j] = np.sum(src*tgt) + distances[count, i, j] /= norm_i * norm_j + distances[count, i, j] = 1 - distances[count, i, j] if same_array: distances[count, j, i] = distances[count, i, j] From e6e6487d369135304d247760cab84b111c9faaa9 Mon Sep 
17 00:00:00 2001 From: Sebastien Date: Fri, 13 Sep 2024 11:46:01 +0200 Subject: [PATCH 05/14] Default mask for dense case --- src/spikeinterface/postprocessing/template_similarity.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 25173b618c..8540896ac3 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -10,7 +10,7 @@ try: import numba - HAVE_NUMBA = False + HAVE_NUMBA = True except ImportError: HAVE_NUMBA = False @@ -317,7 +317,8 @@ def compute_similarity_with_templates_array( num_channels = templates_array.shape[2] other_num_templates = other_templates_array.shape[0] - mask = None + mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) + if sparsity is not None and other_sparsity is not None: if support == "intersection": mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) @@ -326,9 +327,7 @@ def compute_similarity_with_templates_array( units_overlaps = np.sum(mask, axis=2) > 0 mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) mask[~units_overlaps] = False - else: - # here we make a dense mask and overlapping templates - mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool) + assert num_shifts < num_samples, "max_lag is too large" distances = _compute_similarity_matrix(templates_array, From 2e3a77b927691dc7418f6f72aa42d9ef0ea3a202 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Sep 2024 09:50:01 +0000 Subject: [PATCH 06/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/correlograms.py | 7 +--- src/spikeinterface/postprocessing/isi.py | 7 +--- 
.../postprocessing/template_similarity.py | 40 +++++++------------ 3 files changed, 17 insertions(+), 37 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 88d664f059..8a12e9b853 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -435,12 +435,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): if HAVE_NUMBA: - @numba.jit( - nopython=True, - nogil=True, - cache=False, - parallel=True - ) + @numba.jit(nopython=True, nogil=True, cache=False, parallel=True) def _compute_correlograms_one_segment_numba( correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins ): diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index b526a54413..5c9e5f0346 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -182,12 +182,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float if HAVE_NUMBA: - @numba.jit( - nopython=True, - nogil=True, - cache=False, - parallel=True - ) + @numba.jit(nopython=True, nogil=True, cache=False, parallel=True) def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins): n_units = ISIs.shape[0] diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 8540896ac3..2d94746ce7 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -157,9 +157,7 @@ def _get_data(self): if HAVE_NUMBA: @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) - def _compute_similarity_matrix( - templates_array, other_templates_array, num_shifts, mask, method - ): + def _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method): num_templates = 
templates_array.shape[0] num_samples = templates_array.shape[1] other_num_templates = other_templates_array.shape[0] @@ -176,7 +174,7 @@ def _compute_similarity_matrix( shift_loop = range(-num_shifts, 1) else: shift_loop = range(-num_shifts, num_shifts + 1) - + for count, shift in enumerate(shift_loop): src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] @@ -202,25 +200,25 @@ def _compute_similarity_matrix( norm_j += abs(tgt[k]) distances[count, i, j] += abs(src[k] - tgt[k]) elif method == "l2": - norm_i += src[k]**2 - norm_j += tgt[k]**2 - distances[count, i, j] += (src[k] - tgt[k])**2 + norm_i += src[k] ** 2 + norm_j += tgt[k] ** 2 + distances[count, i, j] += (src[k] - tgt[k]) ** 2 elif method == "cosine": - distances[count, i, j] += src[k]*tgt[k] - norm_i += src[k]**2 - norm_j += tgt[k]**2 + distances[count, i, j] += src[k] * tgt[k] + norm_i += src[k] ** 2 + norm_j += tgt[k] ** 2 if method == "l1": - distances[count, i, j] /= (norm_i + norm_j) + distances[count, i, j] /= norm_i + norm_j elif method == "l2": norm_i = np.sqrt(norm_i) norm_j = np.sqrt(norm_j) distances[count, i, j] = np.sqrt(distances[count, i, j]) - distances[count, i, j] /= (norm_i + norm_j) + distances[count, i, j] /= norm_i + norm_j elif method == "cosine": norm_i = np.sqrt(norm_i) norm_j = np.sqrt(norm_j) - distances[count, i, j] /= (norm_i*norm_j) + distances[count, i, j] /= norm_i * norm_j distances[count, i, j] = 1 - distances[count, i, j] if same_array: @@ -231,9 +229,8 @@ def _compute_similarity_matrix( return distances else: - def _compute_similarity_matrix( - templates_array, other_templates_array, num_shifts, mask, method - ): + + def _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] @@ -280,7 +277,7 @@ def _compute_similarity_matrix( 
elif method == "cosine": norm_i = np.linalg.norm(src, ord=2) norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = np.sum(src*tgt) + distances[count, i, j] = np.sum(src * tgt) distances[count, i, j] /= norm_i * norm_j distances[count, i, j] = 1 - distances[count, i, j] @@ -292,11 +289,9 @@ def _compute_similarity_matrix( return distances - def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): - if method == "cosine_similarity": method = "cosine" @@ -327,14 +322,9 @@ def compute_similarity_with_templates_array( units_overlaps = np.sum(mask, axis=2) > 0 mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :]) mask[~units_overlaps] = False - assert num_shifts < num_samples, "max_lag is too large" - distances = _compute_similarity_matrix(templates_array, - other_templates_array, - num_shifts, - mask, - method) + distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method) distances = np.min(distances, axis=0) similarity = 1 - distances From 1195022cc885458a0cf69de5539cdb437f3d418b Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 13 Sep 2024 14:11:04 +0200 Subject: [PATCH 07/14] WIP --- src/spikeinterface/postprocessing/correlograms.py | 4 ++-- src/spikeinterface/postprocessing/isi.py | 2 +- src/spikeinterface/sortingcomponents/matching/tdc.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 8a12e9b853..27d5703c9e 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -435,7 +435,7 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): if HAVE_NUMBA: - @numba.jit(nopython=True, nogil=True, cache=False, parallel=True) + @numba.jit(nopython=True, nogil=True, 
cache=False) def _compute_correlograms_one_segment_numba( correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins ): @@ -468,7 +468,7 @@ def _compute_correlograms_one_segment_numba( The size of which to bin lags, in samples. """ start_j = 0 - for i in numba.prange(spike_times.size): + for i in range(spike_times.size): for j in range(start_j, spike_times.size): if i == j: diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 5c9e5f0346..1f635a8c84 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -182,7 +182,7 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float if HAVE_NUMBA: - @numba.jit(nopython=True, nogil=True, cache=False, parallel=True) + @numba.jit(nopython=True, nogil=True, cache=False) def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins): n_units = ISIs.shape[0] diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 5c145d1f25..e66929e2b1 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -348,7 +348,7 @@ def _tdc_find_spikes(traces, d, level=0): if HAVE_NUMBA: - @jit(nopython=True, parallel=True) + @jit(nopython=True) def numba_sparse_dist(wf, templates, union_channels, possible_clusters): """ numba implementation that compute distance from template with sparsity From f4ca8fafc162ab0539440ad66f29b430311cf50c Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Fri, 13 Sep 2024 14:11:04 +0200 Subject: [PATCH 08/14] WIP --- src/spikeinterface/postprocessing/correlograms.py | 8 ++++++-- src/spikeinterface/postprocessing/isi.py | 6 +++++- src/spikeinterface/sortingcomponents/matching/tdc.py | 2 +- 3 files changed, 12 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/correlograms.py 
b/src/spikeinterface/postprocessing/correlograms.py index 8a12e9b853..ba12a5c462 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -435,7 +435,11 @@ def _compute_correlograms_numba(sorting, window_size, bin_size): if HAVE_NUMBA: - @numba.jit(nopython=True, nogil=True, cache=False, parallel=True) + @numba.jit( + nopython=True, + nogil=True, + cache=False, + ) def _compute_correlograms_one_segment_numba( correlograms, spike_times, spike_unit_indices, window_size, bin_size, num_half_bins ): @@ -468,7 +472,7 @@ def _compute_correlograms_one_segment_numba( The size of which to bin lags, in samples. """ start_j = 0 - for i in numba.prange(spike_times.size): + for i in range(spike_times.size): for j in range(start_j, spike_times.size): if i == j: diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 5c9e5f0346..542f829f21 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -182,7 +182,11 @@ def compute_isi_histograms_numba(sorting, window_ms: float = 50.0, bin_ms: float if HAVE_NUMBA: - @numba.jit(nopython=True, nogil=True, cache=False, parallel=True) + @numba.jit( + nopython=True, + nogil=True, + cache=False, + ) def _compute_isi_histograms_numba(ISIs, spike_trains, spike_clusters, bins): n_units = ISIs.shape[0] diff --git a/src/spikeinterface/sortingcomponents/matching/tdc.py b/src/spikeinterface/sortingcomponents/matching/tdc.py index 5c145d1f25..e66929e2b1 100644 --- a/src/spikeinterface/sortingcomponents/matching/tdc.py +++ b/src/spikeinterface/sortingcomponents/matching/tdc.py @@ -348,7 +348,7 @@ def _tdc_find_spikes(traces, d, level=0): if HAVE_NUMBA: - @jit(nopython=True, parallel=True) + @jit(nopython=True) def numba_sparse_dist(wf, templates, union_channels, possible_clusters): """ numba implementation that compute distance from template with sparsity From 
7d8ef31d927757fa1c36bbcf18fab16b01d827c4 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 24 Sep 2024 10:18:40 +0200 Subject: [PATCH 09/14] WIP --- .../postprocessing/template_similarity.py | 150 ++++++++++-------- 1 file changed, 81 insertions(+), 69 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 2d94746ce7..f2082150c9 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -153,11 +153,70 @@ def _get_data(self): register_result_extension(ComputeTemplateSimilarity) compute_template_similarity = ComputeTemplateSimilarity.function_factory() +def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): + + num_templates = templates_array.shape[0] + num_samples = templates_array.shape[1] + other_num_templates = other_templates_array.shape[0] + + num_shifts_both_sides = 2 * num_shifts + 1 + distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) + same_array = np.array_equal(templates_array, other_templates_array) + + # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t + # So the matrix can be computed only for negative lags and be transposed + + if same_array: + # optimisation when array are the same because of symetry in shift + shift_loop = range(-num_shifts, 1) + else: + shift_loop = range(-num_shifts, num_shifts + 1) + + for count, shift in enumerate(shift_loop): + src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] + tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] + for i in range(num_templates): + src_template = src_sliced_templates[i] + overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) + tgt_templates = tgt_sliced_templates[overlapping_templates] + for gcount, j in 
enumerate(overlapping_templates): + # symmetric values are handled later + if same_array and j < i: + # no need exhaustive looping when same template + continue + src = src_template[:, mask[i, j]].reshape(1, -1) + tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) + + if method == "l1": + norm_i = np.sum(np.abs(src)) + norm_j = np.sum(np.abs(tgt)) + distances[count, i, j] = np.sum(np.abs(src - tgt)) + distances[count, i, j] /= norm_i + norm_j + elif method == "l2": + norm_i = np.linalg.norm(src, ord=2) + norm_j = np.linalg.norm(tgt, ord=2) + distances[count, i, j] = np.linalg.norm(src - tgt, ord=2) + distances[count, i, j] /= norm_i + norm_j + elif method == "cosine": + norm_i = np.linalg.norm(src, ord=2) + norm_j = np.linalg.norm(tgt, ord=2) + distances[count, i, j] = np.sum(src * tgt) + distances[count, i, j] /= norm_i * norm_j + distances[count, i, j] = 1 - distances[count, i, j] + + if same_array: + distances[count, j, i] = distances[count, i, j] + + if same_array and num_shifts != 0: + distances[num_shifts_both_sides - count - 1] = distances[count].T + return distances + if HAVE_NUMBA: + from math import sqrt @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) - def _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method): + def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): num_templates = templates_array.shape[0] num_samples = templates_array.shape[1] other_num_templates = other_templates_array.shape[0] @@ -175,6 +234,13 @@ def _compute_similarity_matrix(templates_array, other_templates_array, num_shift else: shift_loop = range(-num_shifts, num_shifts + 1) + if method == 'l1': + metric = 0 + elif method == 'l2': + metric = 1 + elif method == 'cosine': + metric = 2 + for count, shift in enumerate(shift_loop): src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] tgt_sliced_templates = other_templates_array[:, num_shifts 
+ shift : num_samples - num_shifts + shift] @@ -195,29 +261,29 @@ def _compute_similarity_matrix(templates_array, other_templates_array, num_shift norm_j = 0 for k in range(len(src)): - if method == "l1": + if metric == 0: norm_i += abs(src[k]) norm_j += abs(tgt[k]) distances[count, i, j] += abs(src[k] - tgt[k]) - elif method == "l2": + elif metric == 1: norm_i += src[k] ** 2 norm_j += tgt[k] ** 2 distances[count, i, j] += (src[k] - tgt[k]) ** 2 - elif method == "cosine": + elif metric == 2: distances[count, i, j] += src[k] * tgt[k] norm_i += src[k] ** 2 norm_j += tgt[k] ** 2 - if method == "l1": + if metric == 0: distances[count, i, j] /= norm_i + norm_j - elif method == "l2": - norm_i = np.sqrt(norm_i) - norm_j = np.sqrt(norm_j) - distances[count, i, j] = np.sqrt(distances[count, i, j]) + elif metric == 1: + norm_i = sqrt(norm_i) + norm_j = sqrt(norm_j) + distances[count, i, j] = sqrt(distances[count, i, j]) distances[count, i, j] /= norm_i + norm_j - elif method == "cosine": - norm_i = np.sqrt(norm_i) - norm_j = np.sqrt(norm_j) + elif metric == 2: + norm_i = sqrt(norm_i) + norm_j = sqrt(norm_j) distances[count, i, j] /= norm_i * norm_j distances[count, i, j] = 1 - distances[count, i, j] @@ -226,67 +292,13 @@ def _compute_similarity_matrix(templates_array, other_templates_array, num_shift if same_array and num_shifts != 0: distances[num_shifts_both_sides - count - 1] = distances[count].T + return distances + _compute_similarity_matrix = _compute_similarity_matrix_numba else: + _compute_similarity_matrix = _compute_similarity_matrix_numpy - def _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method): - - num_templates = templates_array.shape[0] - num_samples = templates_array.shape[1] - other_num_templates = other_templates_array.shape[0] - - num_shifts_both_sides = 2 * num_shifts + 1 - distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32) - same_array = 
np.array_equal(templates_array, other_templates_array) - - # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t - # So the matrix can be computed only for negative lags and be transposed - - if same_array: - # optimisation when array are the same because of symetry in shift - shift_loop = range(-num_shifts, 1) - else: - shift_loop = range(-num_shifts, num_shifts + 1) - - for count, shift in enumerate(shift_loop): - src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] - tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] - for i in range(num_templates): - src_template = src_sliced_templates[i] - overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) - tgt_templates = tgt_sliced_templates[overlapping_templates] - for gcount, j in enumerate(overlapping_templates): - # symmetric values are handled later - if same_array and j < i: - # no need exhaustive looping when same template - continue - src = src_template[:, mask[i, j]].reshape(1, -1) - tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1) - - if method == "l1": - norm_i = np.sum(np.abs(src)) - norm_j = np.sum(np.abs(tgt)) - distances[count, i, j] = np.sum(np.abs(src - tgt)) - distances[count, i, j] /= norm_i + norm_j - elif method == "l2": - norm_i = np.linalg.norm(src, ord=2) - norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = np.linalg.norm(src - tgt, ord=2) - distances[count, i, j] /= norm_i + norm_j - elif method == "cosine": - norm_i = np.linalg.norm(src, ord=2) - norm_j = np.linalg.norm(tgt, ord=2) - distances[count, i, j] = np.sum(src * tgt) - distances[count, i, j] /= norm_i * norm_j - distances[count, i, j] = 1 - distances[count, i, j] - - if same_array: - distances[count, j, i] = distances[count, i, j] - - if same_array and num_shifts != 0: - distances[num_shifts_both_sides - count - 1] = distances[count].T - return distances def compute_similarity_with_templates_array( From 
b77d0b088c545c8ed54a8511cf585a860735c276 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 24 Sep 2024 08:19:02 +0000 Subject: [PATCH 10/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/template_similarity.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index f2082150c9..18f6f88dba 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -153,6 +153,7 @@ def _get_data(self): register_result_extension(ComputeTemplateSimilarity) compute_template_similarity = ComputeTemplateSimilarity.function_factory() + def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): num_templates = templates_array.shape[0] @@ -215,6 +216,7 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num if HAVE_NUMBA: from math import sqrt + @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): num_templates = templates_array.shape[0] @@ -234,11 +236,11 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num else: shift_loop = range(-num_shifts, num_shifts + 1) - if method == 'l1': + if method == "l1": metric = 0 - elif method == 'l2': + elif method == "l2": metric = 1 - elif method == 'cosine': + elif method == "cosine": metric = 2 for count, shift in enumerate(shift_loop): @@ -292,7 +294,7 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num if same_array and num_shifts != 0: distances[num_shifts_both_sides - count - 1] = distances[count].T - + return distances 
_compute_similarity_matrix = _compute_similarity_matrix_numba @@ -300,7 +302,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num _compute_similarity_matrix = _compute_similarity_matrix_numpy - def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ): From e383292dfa2e1bf7d35d73369ad2c5457394febf Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Tue, 24 Sep 2024 10:42:11 +0200 Subject: [PATCH 11/14] WIP --- .../postprocessing/template_similarity.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 18f6f88dba..9bfe899840 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -232,9 +232,9 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num if same_array: # optimisation when array are the same because of symetry in shift - shift_loop = range(-num_shifts, 1) + shift_loop = list(range(-num_shifts, 1)) else: - shift_loop = range(-num_shifts, num_shifts + 1) + shift_loop = list(range(-num_shifts, num_shifts + 1)) if method == "l1": metric = 0 @@ -243,15 +243,17 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num elif method == "cosine": metric = 2 - for count, shift in enumerate(shift_loop): + for count in range(len(shift_loop)): + shift = shift_loop[count] src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts] tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift] for i in numba.prange(num_templates): src_template = src_sliced_templates[i] overlapping_templates = np.flatnonzero(np.sum(mask[i], 1)) tgt_templates = tgt_sliced_templates[overlapping_templates] - for gcount, j in 
enumerate(overlapping_templates): + for gcount in range(len(overlapping_templates)): + j = overlapping_templates[gcount] # symmetric values are handled later if same_array and j < i: # no need exhaustive looping when same template From 1505a213d58fde4ee232adc57c4767654a4f8e32 Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 25 Sep 2024 15:30:11 +0200 Subject: [PATCH 12/14] Adding tests --- .../postprocessing/template_similarity.py | 1 + .../postprocessing/tests/test_correlograms.py | 1 - .../tests/test_template_similarity.py | 42 ++++++++++++++++++- 3 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index 9bfe899840..cfa9d89fea 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -263,6 +263,7 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num norm_i = 0 norm_j = 0 + distances[count, i, j] = 0 for k in range(len(src)): if metric == 0: diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py index 66d84c9565..0431c8d675 100644 --- a/src/spikeinterface/postprocessing/tests/test_correlograms.py +++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py @@ -93,7 +93,6 @@ def test_equal_results_correlograms(window_and_bin_ms): ) assert np.array_equal(result_numpy, result_numba) - assert np.array_equal(result_numpy, result_numba) @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)]) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index cc6797c262..364c54beea 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ 
-7,7 +7,19 @@ ) from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity -from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array +from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array, _compute_similarity_matrix_numba, _compute_similarity_matrix_numpy + +try: + import numba + + HAVE_NUMBA = True +except ModuleNotFoundError as err: + HAVE_NUMBA = False + +import pytest +from pytest import param + +SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite): @@ -71,6 +83,34 @@ def test_compute_similarity_with_templates_array(params): similarity = compute_similarity_with_templates_array(templates_array, other_templates_array, **params) print(similarity.shape) +pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") +@pytest.mark.parametrize( + "params", + [ + dict(method="cosine", num_shifts=8), + dict(method="l1", num_shifts=0), + dict(method="l2", num_shifts=0), + dict(method="cosine", num_shifts=0), + ], +) +def test_equal_results_numba(params): + """ + Test that the 2 methods have same results with some varied time bins + that are not tested in other tests. 
+ """ + + rng = np.random.default_rng(seed=2205) + templates_array = rng.random(size=(4, 20, 5), dtype=np.float32) + other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32) + mask = np.ones((4, 2, 5), dtype=bool) + + result_numpy = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params) + result_numba = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params) + + assert np.allclose(result_numpy, result_numba, 1e-3) + + + if __name__ == "__main__": from spikeinterface.postprocessing.tests.common_extension_tests import get_dataset From 9caa8c9c12324d02afa80d94c617ee39f0b271c2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Sep 2024 13:43:55 +0000 Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../postprocessing/tests/test_template_similarity.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 364c54beea..2b6cd566b3 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -7,7 +7,11 @@ ) from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity -from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array, _compute_similarity_matrix_numba, _compute_similarity_matrix_numpy +from spikeinterface.postprocessing.template_similarity import ( + compute_similarity_with_templates_array, + _compute_similarity_matrix_numba, + _compute_similarity_matrix_numpy, +) try: import numba @@ -83,7 +87,10 @@ def test_compute_similarity_with_templates_array(params): similarity = 
compute_similarity_with_templates_array(templates_array, other_templates_array, **params) print(similarity.shape) + -pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") +@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available") + + @pytest.mark.parametrize( "params", [ @@ -110,8 +117,6 @@ def test_equal_results_numba(params): assert np.allclose(result_numpy, result_numba, 1e-3) - - if __name__ == "__main__": from spikeinterface.postprocessing.tests.common_extension_tests import get_dataset from spikeinterface.core import estimate_sparsity From 1f2e37a2fb5201627c5472c1840e32bbf3efcb0b Mon Sep 17 00:00:00 2001 From: Pierre Yger Date: Wed, 25 Sep 2024 15:48:34 +0200 Subject: [PATCH 14/14] Imports --- .../postprocessing/tests/test_template_similarity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py index 2b6cd566b3..20d8373981 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py +++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py @@ -9,7 +9,6 @@ from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity from spikeinterface.postprocessing.template_similarity import ( compute_similarity_with_templates_array, - _compute_similarity_matrix_numba, _compute_similarity_matrix_numpy, ) @@ -17,6 +16,7 @@ import numba HAVE_NUMBA = True + from spikeinterface.postprocessing.template_similarity import _compute_similarity_matrix_numba except ModuleNotFoundError as err: HAVE_NUMBA = False