Skip to content

Commit

Permalink
Merge pull request #3405 from yger/numba_similarity
Browse files Browse the repository at this point in the history
Optimizations for template_similarity (numba and dependencies)
  • Loading branch information
samuelgarcia authored Oct 7, 2024
2 parents 5e13593 + 07f893f commit 80cc888
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 51 deletions.
200 changes: 151 additions & 49 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from ..core.template_tools import get_dense_templates_array
from ..core.sparsity import ChannelSparsity

try:
import numba

HAVE_NUMBA = True
except ImportError:
HAVE_NUMBA = False


class ComputeTemplateSimilarity(AnalyzerExtension):
"""Compute similarity between templates with several methods.
Expand Down Expand Up @@ -147,54 +154,15 @@ def _get_data(self):
compute_template_similarity = ComputeTemplateSimilarity.function_factory()


def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):
import sklearn.metrics.pairwise
def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method):

if method == "cosine_similarity":
method = "cosine"

all_metrics = ["cosine", "l1", "l2"]

if method not in all_metrics:
raise ValueError(f"compute_template_similarity (method {method}) not exists")

assert (
templates_array.shape[1] == other_templates_array.shape[1]
), "The number of samples in the templates should be the same for both arrays"
assert (
templates_array.shape[2] == other_templates_array.shape[2]
), "The number of channels in the templates should be the same for both arrays"
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
num_channels = templates_array.shape[2]
other_num_templates = other_templates_array.shape[0]

same_array = np.array_equal(templates_array, other_templates_array)

mask = None
if sparsity is not None and other_sparsity is not None:
if support == "intersection":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
elif support == "union":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
units_overlaps = np.sum(mask, axis=2) > 0
mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
mask[~units_overlaps] = False
if mask is not None:
units_overlaps = np.sum(mask, axis=2) > 0
overlapping_templates = {}
for i in range(num_templates):
overlapping_templates[i] = np.flatnonzero(units_overlaps[i])
else:
# here we make a dense mask and overlapping templates
overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)}
mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)

assert num_shifts < num_samples, "max_lag is too large"
num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
same_array = np.array_equal(templates_array, other_templates_array)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed
Expand All @@ -210,8 +178,9 @@ def compute_similarity_with_templates_array(
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
for gcount, j in enumerate(overlapping_templates[i]):
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
Expand All @@ -222,23 +191,156 @@ def compute_similarity_with_templates_array(
if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item()
distances[count, i, j] = np.sum(np.abs(src - tgt))
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item()
distances[count, i, j] = np.linalg.norm(src - tgt, ord=2)
distances[count, i, j] /= norm_i + norm_j
else:
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(
src, tgt, metric="cosine"
).item()
elif method == "cosine":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = np.sum(src * tgt)
distances[count, i, j] /= norm_i * norm_j
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
return distances


if HAVE_NUMBA:

    from math import sqrt

    @numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
    def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method):
        """Compute normalized template distances for every shift (numba version).

        Parameters
        ----------
        templates_array : np.ndarray
            Source templates, indexed as [template, sample, channel].
        other_templates_array : np.ndarray
            Target templates, indexed as [template, sample, channel].
        num_shifts : int
            Maximum shift, in samples, applied to the target templates on each side.
        mask : np.ndarray
            Boolean array indexed as [template, other_template, channel] selecting
            the channels used when comparing template i with other template j.
        method : str
            One of "l1", "l2" or "cosine".

        Returns
        -------
        distances : np.ndarray
            float32 array of shape (2 * num_shifts + 1, num_templates, other_num_templates).
            Entries are initialized to 1; pairs that are never visited (no channel
            selected in `mask`) keep that value.
        """
        num_templates = templates_array.shape[0]
        num_samples = templates_array.shape[1]
        other_num_templates = other_templates_array.shape[0]

        num_shifts_both_sides = 2 * num_shifts + 1
        distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
        same_array = np.array_equal(templates_array, other_templates_array)

        # We can use the fact that dist[i, j] at lag t is equal to dist[j, i] at lag -t,
        # so the matrix can be computed only for negative lags and then transposed.

        if same_array:
            # Optimization when both arrays are the same: thanks to the symmetry in
            # shift, only the non-positive shifts are computed explicitly.
            shift_loop = list(range(-num_shifts, 1))
        else:
            shift_loop = list(range(-num_shifts, num_shifts + 1))

        # Encode the method as an integer code that is tested in the inner loops.
        # NOTE(review): `metric` is only assigned for "l1", "l2" and "cosine"; the
        # visible caller validates `method` before calling this function.
        if method == "l1":
            metric = 0
        elif method == "l2":
            metric = 1
        elif method == "cosine":
            metric = 2

        for count in range(len(shift_loop)):
            shift = shift_loop[count]
            # The source window is the central part of the templates; the target
            # window has the same length but is offset by `shift` samples.
            src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
            tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
            for i in numba.prange(num_templates):
                src_template = src_sliced_templates[i]
                # Indices of the other templates having at least one channel selected with template i
                overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
                tgt_templates = tgt_sliced_templates[overlapping_templates]
                for gcount in range(len(overlapping_templates)):

                    j = overlapping_templates[gcount]
                    # symmetric values are handled later
                    if same_array and j < i:
                        # no need for exhaustive looping when the templates are the same
                        continue
                    # Keep only the channels selected by mask[i, j] and flatten to 1D
                    src = src_template[:, mask[i, j]].flatten()
                    tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten()

                    norm_i = 0
                    norm_j = 0
                    distances[count, i, j] = 0

                    # Single pass accumulating the raw distance (or dot product for
                    # cosine) together with the two normalization terms.
                    for k in range(len(src)):
                        if metric == 0:
                            norm_i += abs(src[k])
                            norm_j += abs(tgt[k])
                            distances[count, i, j] += abs(src[k] - tgt[k])
                        elif metric == 1:
                            norm_i += src[k] ** 2
                            norm_j += tgt[k] ** 2
                            distances[count, i, j] += (src[k] - tgt[k]) ** 2
                        elif metric == 2:
                            distances[count, i, j] += src[k] * tgt[k]
                            norm_i += src[k] ** 2
                            norm_j += tgt[k] ** 2

                    if metric == 0:
                        # l1: sum(|src - tgt|) / (sum(|src|) + sum(|tgt|))
                        distances[count, i, j] /= norm_i + norm_j
                    elif metric == 1:
                        # l2: ||src - tgt|| / (||src|| + ||tgt||)
                        norm_i = sqrt(norm_i)
                        norm_j = sqrt(norm_j)
                        distances[count, i, j] = sqrt(distances[count, i, j])
                        distances[count, i, j] /= norm_i + norm_j
                    elif metric == 2:
                        # cosine: 1 - <src, tgt> / (||src|| * ||tgt||)
                        norm_i = sqrt(norm_i)
                        norm_j = sqrt(norm_j)
                        distances[count, i, j] /= norm_i * norm_j
                        distances[count, i, j] = 1 - distances[count, i, j]

                    if same_array:
                        # Fill the entry skipped by the `j < i` shortcut above
                        distances[count, j, i] = distances[count, i, j]

            if same_array and num_shifts != 0:
                # Mirror this shift onto the opposite shift, which was not looped over
                distances[num_shifts_both_sides - count - 1] = distances[count].T

        return distances

    # Prefer the compiled implementation whenever numba could be imported
    _compute_similarity_matrix = _compute_similarity_matrix_numba
else:
    # Pure numpy fallback when numba is not installed
    _compute_similarity_matrix = _compute_similarity_matrix_numpy


def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):

if method == "cosine_similarity":
method = "cosine"

all_metrics = ["cosine", "l1", "l2"]

if method not in all_metrics:
raise ValueError(f"compute_template_similarity (method {method}) not exists")

assert (
templates_array.shape[1] == other_templates_array.shape[1]
), "The number of samples in the templates should be the same for both arrays"
assert (
templates_array.shape[2] == other_templates_array.shape[2]
), "The number of channels in the templates should be the same for both arrays"
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
num_channels = templates_array.shape[2]
other_num_templates = other_templates_array.shape[0]

mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)

if sparsity is not None and other_sparsity is not None:
if support == "intersection":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
elif support == "union":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
units_overlaps = np.sum(mask, axis=2) > 0
mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
mask[~units_overlaps] = False

assert num_shifts < num_samples, "max_lag is too large"
distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method)

distances = np.min(distances, axis=0)
similarity = 1 - distances
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def test_equal_results_correlograms(window_and_bin_ms):
)

assert np.array_equal(result_numpy, result_numba)
assert np.array_equal(result_numpy, result_numba)


@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,23 @@
)

from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity
from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array
from spikeinterface.postprocessing.template_similarity import (
compute_similarity_with_templates_array,
_compute_similarity_matrix_numpy,
)

try:
import numba

HAVE_NUMBA = True
from spikeinterface.postprocessing.template_similarity import _compute_similarity_matrix_numba
except ModuleNotFoundError as err:
HAVE_NUMBA = False

import pytest
from pytest import param

SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")


class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite):
Expand Down Expand Up @@ -72,6 +88,35 @@ def test_compute_similarity_with_templates_array(params):
print(similarity.shape)


# The skip marker must be applied as a decorator: a bare `pytest.mark.skipif(...)`
# expression statement has no effect, and the test would then fail with a
# NameError on `_compute_similarity_matrix_numba` when numba is not installed.
@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
@pytest.mark.parametrize(
    "params",
    [
        dict(method="cosine", num_shifts=8),
        dict(method="l1", num_shifts=0),
        dict(method="l2", num_shifts=0),
        dict(method="cosine", num_shifts=0),
    ],
)
def test_equal_results_numba(params):
    """
    Test that the numpy and numba implementations of the similarity matrix
    give the same results (up to a relative tolerance of 1e-3) for each
    supported method, including one case with a non-zero number of shifts.
    """

    rng = np.random.default_rng(seed=2205)
    templates_array = rng.random(size=(4, 20, 5), dtype=np.float32)
    other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32)
    # Dense mask: every channel is used for every pair of templates
    mask = np.ones((4, 2, 5), dtype=bool)

    result_numpy = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params)
    result_numba = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params)

    assert np.allclose(result_numpy, result_numba, rtol=1e-3)


if __name__ == "__main__":
from spikeinterface.postprocessing.tests.common_extension_tests import get_dataset
from spikeinterface.core import estimate_sparsity
Expand Down

0 comments on commit 80cc888

Please sign in to comment.