diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py index f2082150c9..18f6f88dba 100644 --- a/src/spikeinterface/postprocessing/template_similarity.py +++ b/src/spikeinterface/postprocessing/template_similarity.py @@ -153,6 +153,7 @@ def _get_data(self): register_result_extension(ComputeTemplateSimilarity) compute_template_similarity = ComputeTemplateSimilarity.function_factory() + def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method): num_templates = templates_array.shape[0] @@ -215,6 +216,7 @@ def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num if HAVE_NUMBA: from math import sqrt + @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True) def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method): num_templates = templates_array.shape[0] @@ -234,11 +236,11 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num else: shift_loop = range(-num_shifts, num_shifts + 1) - if method == 'l1': + if method == "l1": metric = 0 - elif method == 'l2': + elif method == "l2": metric = 1 - elif method == 'cosine': + elif method == "cosine": metric = 2 for count, shift in enumerate(shift_loop): @@ -292,7 +294,7 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num if same_array and num_shifts != 0: distances[num_shifts_both_sides - count - 1] = distances[count].T - + return distances _compute_similarity_matrix = _compute_similarity_matrix_numba @@ -300,7 +302,6 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num _compute_similarity_matrix = _compute_similarity_matrix_numpy - def compute_similarity_with_templates_array( templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None ):