diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 0e70b1f494..cfa9d89fea 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -7,6 +7,13 @@
 from ..core.template_tools import get_dense_templates_array
 from ..core.sparsity import ChannelSparsity

+try:
+    import numba
+
+    HAVE_NUMBA = True
+except ImportError:
+    HAVE_NUMBA = False
+

 class ComputeTemplateSimilarity(AnalyzerExtension):
     """Compute similarity between templates with several methods.
@@ -147,54 +154,15 @@ def _get_data(self):
 compute_template_similarity = ComputeTemplateSimilarity.function_factory()


-def compute_similarity_with_templates_array(
-    templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
-):
-    import sklearn.metrics.pairwise
+def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method):

-    if method == "cosine_similarity":
-        method = "cosine"
-
-    all_metrics = ["cosine", "l1", "l2"]
-
-    if method not in all_metrics:
-        raise ValueError(f"compute_template_similarity (method {method}) not exists")
-    assert (
-        templates_array.shape[1] == other_templates_array.shape[1]
-    ), "The number of samples in the templates should be the same for both arrays"
-    assert (
-        templates_array.shape[2] == other_templates_array.shape[2]
-    ), "The number of channels in the templates should be the same for both arrays"
     num_templates = templates_array.shape[0]
     num_samples = templates_array.shape[1]
-    num_channels = templates_array.shape[2]
     other_num_templates = other_templates_array.shape[0]

-    same_array = np.array_equal(templates_array, other_templates_array)
-
-    mask = None
-    if sparsity is not None and other_sparsity is not None:
-        if support == "intersection":
-            mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
-        elif support == "union":
-            mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
-            units_overlaps = np.sum(mask, axis=2) > 0
-            mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
-            mask[~units_overlaps] = False
-    if mask is not None:
-        units_overlaps = np.sum(mask, axis=2) > 0
-        overlapping_templates = {}
-        for i in range(num_templates):
-            overlapping_templates[i] = np.flatnonzero(units_overlaps[i])
-    else:
-        # here we make a dense mask and overlapping templates
-        overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)}
-        mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)
-
-    assert num_shifts < num_samples, "max_lag is too large"
     num_shifts_both_sides = 2 * num_shifts + 1
     distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
+    same_array = np.array_equal(templates_array, other_templates_array)

     # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
     # So the matrix can be computed only for negative lags and be transposed
@@ -210,8 +178,9 @@
         tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
         for i in range(num_templates):
             src_template = src_sliced_templates[i]
-            tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
-            for gcount, j in enumerate(overlapping_templates[i]):
+            overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
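+            # mask[i] has shape (other_num_templates, num_channels): only units
+            # sharing at least one channel with unit i are compared below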
+            tgt_templates = tgt_sliced_templates[overlapping_templates]
+            for gcount, j in enumerate(overlapping_templates):
                 # symmetric values are handled later
                 if same_array and j < i:
                     # no need exhaustive looping when same template
@@ -222,23 +191,156 @@
                 if method == "l1":
                     norm_i = np.sum(np.abs(src))
                     norm_j = np.sum(np.abs(tgt))
-                    distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item()
+                    distances[count, i, j] = np.sum(np.abs(src - tgt))
                     distances[count, i, j] /= norm_i + norm_j
                 elif method == "l2":
                     norm_i = np.linalg.norm(src, ord=2)
                     norm_j = np.linalg.norm(tgt, ord=2)
-                    distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item()
+                    distances[count, i, j] = np.linalg.norm(src - tgt, ord=2)
                     distances[count, i, j] /= norm_i + norm_j
-                else:
-                    distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(
-                        src, tgt, metric="cosine"
-                    ).item()
+                elif method == "cosine":
+                    norm_i = np.linalg.norm(src, ord=2)
+                    norm_j = np.linalg.norm(tgt, ord=2)
+                    distances[count, i, j] = np.sum(src * tgt)
+                    distances[count, i, j] /= norm_i * norm_j
+                    distances[count, i, j] = 1 - distances[count, i, j]

                 if same_array:
                     distances[count, j, i] = distances[count, i, j]

         if same_array and num_shifts != 0:
             distances[num_shifts_both_sides - count - 1] = distances[count].T
+    return distances
+
+
+if HAVE_NUMBA:
+
+    from math import sqrt
+
+    @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
+    def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method):
+        num_templates = templates_array.shape[0]
+        num_samples = templates_array.shape[1]
+        other_num_templates = other_templates_array.shape[0]
+
+        num_shifts_both_sides = 2 * num_shifts + 1
+        distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
+        same_array = np.array_equal(templates_array, other_templates_array)
+
+        # We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
+        # So the matrix can be computed only for negative lags and be transposed
+
+        if same_array:
+            # optimisation when arrays are the same because of symmetry in shifts
+            shift_loop = list(range(-num_shifts, 1))
+        else:
+            shift_loop = list(range(-num_shifts, num_shifts + 1))
+
+        if method == "l1":
+            metric = 0
+        elif method == "l2":
+            metric = 1
+        elif method == "cosine":
+            metric = 2
+
+        for count in range(len(shift_loop)):
+            shift = shift_loop[count]
+            src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
+            tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
+            for i in numba.prange(num_templates):
+                src_template = src_sliced_templates[i]
+                overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
+                tgt_templates = tgt_sliced_templates[overlapping_templates]
+                for gcount in range(len(overlapping_templates)):
+
+                    j = overlapping_templates[gcount]
+                    # symmetric values are handled later
+                    if same_array and j < i:
+                        # no need for exhaustive looping when templates are the same
+                        continue
+                    src = src_template[:, mask[i, j]].flatten()
+                    tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten()
+
+                    norm_i = 0.0
+                    norm_j = 0.0
+                    distances[count, i, j] = 0
+
+                    for k in range(len(src)):
+                        if metric == 0:
+                            norm_i += abs(src[k])
+                            norm_j += abs(tgt[k])
+                            distances[count, i, j] += abs(src[k] - tgt[k])
+                        elif metric == 1:
+                            norm_i += src[k] ** 2
+                            norm_j += tgt[k] ** 2
+                            distances[count, i, j] += (src[k] - tgt[k]) ** 2
+                        elif metric == 2:
+                            distances[count, i, j] += src[k] * tgt[k]
+                            norm_i += src[k] ** 2
+                            norm_j += tgt[k] ** 2
+
+                    if metric == 0:
+                        distances[count, i, j] /= norm_i + norm_j
+                    elif metric == 1:
+                        norm_i = sqrt(norm_i)
+                        norm_j = sqrt(norm_j)
+                        distances[count, i, j] = sqrt(distances[count, i, j])
+                        distances[count, i, j] /= norm_i + norm_j
+                    elif metric == 2:
+                        norm_i = sqrt(norm_i)
+                        norm_j = sqrt(norm_j)
+                        distances[count, i, j] /= norm_i * norm_j
+                        distances[count, i, j] = 1 - distances[count, i, j]
+
+                    if same_array:
+                        distances[count, j, i] = distances[count, i, j]
+
+            if same_array and num_shifts != 0:
+                distances[num_shifts_both_sides - count - 1] = distances[count].T
+
+        return distances
+
+    _compute_similarity_matrix = _compute_similarity_matrix_numba
+else:
+    _compute_similarity_matrix = _compute_similarity_matrix_numpy
+
+
+def compute_similarity_with_templates_array(
+    templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
+):
+
+    if method == "cosine_similarity":
+        method = "cosine"
+
+    all_metrics = ["cosine", "l1", "l2"]
+
+    if method not in all_metrics:
+        raise ValueError(f"compute_template_similarity: method {method} does not exist")
+
+    assert (
+        templates_array.shape[1] == other_templates_array.shape[1]
+    ), "The number of samples in the templates should be the same for both arrays"
+    assert (
+        templates_array.shape[2] == other_templates_array.shape[2]
+    ), "The number of channels in the templates should be the same for both arrays"
+    num_templates = templates_array.shape[0]
+    num_samples = templates_array.shape[1]
+    num_channels = templates_array.shape[2]
+    other_num_templates = other_templates_array.shape[0]
+
+    mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)
+
+    if sparsity is not None and other_sparsity is not None:
+        if support == "intersection":
+            mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
+        elif support == "union":
+            mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
+            units_overlaps = np.sum(mask, axis=2) > 0
+            mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
+            mask[~units_overlaps] = False
+
+    assert num_shifts < num_samples, "max_lag is too large"
+    distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method)

     distances = np.min(distances, axis=0)
     similarity = 1 - distances
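A minimal usage sketch of the refactored entry point (shapes and values here are illustrative; with no sparsity given, a dense all-True mask is built internally, and the trailing context above suggests the function returns the (num_templates, other_num_templates) similarity matrix):

    import numpy as np
    from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array

    rng = np.random.default_rng(42)
    templates = rng.random((10, 90, 4), dtype=np.float32)  # (num_templates, num_samples, num_channels)

    # dispatches to _compute_similarity_matrix_numba when numba is importable,
    # otherwise to _compute_similarity_matrix_numpy
    similarity = compute_similarity_with_templates_array(templates, templates, method="cosine", num_shifts=2)
    print(similarity.shape)  # expected: (10, 10)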
diff --git a/src/spikeinterface/postprocessing/tests/test_correlograms.py b/src/spikeinterface/postprocessing/tests/test_correlograms.py
index 66d84c9565..0431c8d675 100644
--- a/src/spikeinterface/postprocessing/tests/test_correlograms.py
+++ b/src/spikeinterface/postprocessing/tests/test_correlograms.py
@@ -93,7 +93,6 @@ def test_equal_results_correlograms(window_and_bin_ms):
     )

     assert np.array_equal(result_numpy, result_numba)
-    assert np.array_equal(result_numpy, result_numba)


 @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
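The param("numba", marks=SKIP_NUMBA) idiom used above, and the SKIP_NUMBA marker added below, follow pytest's standard optional-dependency pattern; a self-contained sketch of the pattern (the test name is hypothetical):

    import pytest
    from pytest import param

    try:
        import numba

        HAVE_NUMBA = True
    except ModuleNotFoundError:
        HAVE_NUMBA = False

    SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")

    # the "numba" case is reported as skipped, not failed, when numba is missing
    @pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
    def test_method_available(method):
        assert method in ("numpy", "numba")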
diff --git a/src/spikeinterface/postprocessing/tests/test_template_similarity.py b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
index cc6797c262..20d8373981 100644
--- a/src/spikeinterface/postprocessing/tests/test_template_similarity.py
+++ b/src/spikeinterface/postprocessing/tests/test_template_similarity.py
@@ -7,7 +7,23 @@
 )

 from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity
-from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array
+from spikeinterface.postprocessing.template_similarity import (
+    compute_similarity_with_templates_array,
+    _compute_similarity_matrix_numpy,
+)
+
+try:
+    import numba
+
+    HAVE_NUMBA = True
+    from spikeinterface.postprocessing.template_similarity import _compute_similarity_matrix_numba
+except ModuleNotFoundError:
+    HAVE_NUMBA = False
+
+import pytest
+from pytest import param
+
+SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")


 class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite):
@@ -72,6 +88,35 @@ def test_compute_similarity_with_templates_array(params):
     print(similarity.shape)


+@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
+@pytest.mark.parametrize(
+    "params",
+    [
+        dict(method="cosine", num_shifts=8),
+        dict(method="l1", num_shifts=0),
+        dict(method="l2", num_shifts=0),
+        dict(method="cosine", num_shifts=0),
+    ],
+)
+def test_equal_results_numba(params):
+    """
+    Test that the numpy and numba implementations of the similarity matrix
+    return the same results across methods and shifts.
+    """
+
+    rng = np.random.default_rng(seed=2205)
+    templates_array = rng.random(size=(4, 20, 5), dtype=np.float32)
+    other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32)
+    mask = np.ones((4, 2, 5), dtype=bool)
+
+    result_numpy = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params)
+    result_numba = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params)
+
+    assert np.allclose(result_numpy, result_numba, rtol=1e-3)
+
+
 if __name__ == "__main__":
     from spikeinterface.postprocessing.tests.common_extension_tests import get_dataset
     from spikeinterface.core import estimate_sparsity
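For reference, the support logic that builds the pairwise mask consumed by both kernels can be reproduced in isolation; a minimal sketch in which plain boolean arrays stand in for sparsity.mask and other_sparsity.mask (values are illustrative):

    import numpy as np

    # True where a unit is defined on a channel: shape (num_units, num_channels)
    mask_a = np.array([[True, True, False, False],
                       [False, False, True, True]])
    mask_b = np.array([[True, False, False, False],
                       [False, True, True, False]])

    # support="intersection": compare each pair only on channels shared by both units
    intersection = np.logical_and(mask_a[:, np.newaxis, :], mask_b[np.newaxis, :, :])

    # support="union": compare on the union of the two units' channels,
    # but only for pairs that overlap on at least one channel
    units_overlaps = np.sum(intersection, axis=2) > 0
    union = np.logical_or(mask_a[:, np.newaxis, :], mask_b[np.newaxis, :, :])
    union[~units_overlaps] = False  # non-overlapping pairs keep distance 1.0 (similarity 0)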