Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizations for template_similarity (numba and dependencies) #3405

Merged
merged 22 commits into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 151 additions & 49 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from ..core.template_tools import get_dense_templates_array
from ..core.sparsity import ChannelSparsity

try:
import numba

HAVE_NUMBA = True
except ImportError:
HAVE_NUMBA = False


class ComputeTemplateSimilarity(AnalyzerExtension):
"""Compute similarity between templates with several methods.
Expand Down Expand Up @@ -147,54 +154,15 @@ def _get_data(self):
compute_template_similarity = ComputeTemplateSimilarity.function_factory()


def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):
import sklearn.metrics.pairwise
def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method):

if method == "cosine_similarity":
method = "cosine"

all_metrics = ["cosine", "l1", "l2"]

if method not in all_metrics:
raise ValueError(f"compute_template_similarity (method {method}) not exists")

assert (
templates_array.shape[1] == other_templates_array.shape[1]
), "The number of samples in the templates should be the same for both arrays"
assert (
templates_array.shape[2] == other_templates_array.shape[2]
), "The number of channels in the templates should be the same for both arrays"
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
num_channels = templates_array.shape[2]
other_num_templates = other_templates_array.shape[0]

same_array = np.array_equal(templates_array, other_templates_array)

mask = None
if sparsity is not None and other_sparsity is not None:
if support == "intersection":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
elif support == "union":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
units_overlaps = np.sum(mask, axis=2) > 0
mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
mask[~units_overlaps] = False
if mask is not None:
units_overlaps = np.sum(mask, axis=2) > 0
overlapping_templates = {}
for i in range(num_templates):
overlapping_templates[i] = np.flatnonzero(units_overlaps[i])
else:
# here we make a dense mask and overlapping templates
overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)}
mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)

assert num_shifts < num_samples, "max_lag is too large"
num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
same_array = np.array_equal(templates_array, other_templates_array)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed
Expand All @@ -210,8 +178,9 @@ def compute_similarity_with_templates_array(
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
for gcount, j in enumerate(overlapping_templates[i]):
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
Expand All @@ -222,23 +191,156 @@ def compute_similarity_with_templates_array(
if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item()
distances[count, i, j] = np.sum(np.abs(src - tgt))
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item()
distances[count, i, j] = np.linalg.norm(src - tgt, ord=2)
distances[count, i, j] /= norm_i + norm_j
else:
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(
src, tgt, metric="cosine"
).item()
elif method == "cosine":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = np.sum(src * tgt)
distances[count, i, j] /= norm_i * norm_j
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
return distances


if HAVE_NUMBA:
yger marked this conversation as resolved.
Show resolved Hide resolved

from math import sqrt

@numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method):
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
other_num_templates = other_templates_array.shape[0]

num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
same_array = np.array_equal(templates_array, other_templates_array)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed

if same_array:
# optimisation when array are the same because of symetry in shift
shift_loop = list(range(-num_shifts, 1))
else:
shift_loop = list(range(-num_shifts, num_shifts + 1))

if method == "l1":
metric = 0
elif method == "l2":
metric = 1
elif method == "cosine":
metric = 2

for count in range(len(shift_loop)):
shift = shift_loop[count]
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in numba.prange(num_templates):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would make the prange the outermost loop.
I do not know whether the threads are spawned again for every iteration of the shift loop.

In short, I would swap the prange over templates and the shift loop.

Could you try this and run a benchmark?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried to invert the loops, and it does not make any difference

src_template = src_sliced_templates[i]
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount in range(len(overlapping_templates)):

j = overlapping_templates[gcount]
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
continue
src = src_template[:, mask[i, j]].flatten()
tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten()
yger marked this conversation as resolved.
Show resolved Hide resolved

norm_i = 0
norm_j = 0
distances[count, i, j] = 0

for k in range(len(src)):
if metric == 0:
norm_i += abs(src[k])
norm_j += abs(tgt[k])
distances[count, i, j] += abs(src[k] - tgt[k])
elif metric == 1:
norm_i += src[k] ** 2
norm_j += tgt[k] ** 2
distances[count, i, j] += (src[k] - tgt[k]) ** 2
elif metric == 2:
distances[count, i, j] += src[k] * tgt[k]
norm_i += src[k] ** 2
norm_j += tgt[k] ** 2

if metric == 0:
distances[count, i, j] /= norm_i + norm_j
elif metric == 1:
norm_i = sqrt(norm_i)
norm_j = sqrt(norm_j)
distances[count, i, j] = sqrt(distances[count, i, j])
distances[count, i, j] /= norm_i + norm_j
elif metric == 2:
norm_i = sqrt(norm_i)
norm_j = sqrt(norm_j)
distances[count, i, j] /= norm_i * norm_j
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T

return distances

_compute_similarity_matrix = _compute_similarity_matrix_numba
else:
_compute_similarity_matrix = _compute_similarity_matrix_numpy


def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):

if method == "cosine_similarity":
method = "cosine"

all_metrics = ["cosine", "l1", "l2"]

if method not in all_metrics:
raise ValueError(f"compute_template_similarity (method {method}) not exists")

assert (
templates_array.shape[1] == other_templates_array.shape[1]
), "The number of samples in the templates should be the same for both arrays"
assert (
templates_array.shape[2] == other_templates_array.shape[2]
), "The number of channels in the templates should be the same for both arrays"
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
num_channels = templates_array.shape[2]
other_num_templates = other_templates_array.shape[0]

mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)

if sparsity is not None and other_sparsity is not None:
if support == "intersection":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
elif support == "union":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
units_overlaps = np.sum(mask, axis=2) > 0
mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
mask[~units_overlaps] = False

assert num_shifts < num_samples, "max_lag is too large"
distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method)

distances = np.min(distances, axis=0)
similarity = 1 - distances
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ def test_equal_results_correlograms(window_and_bin_ms):
)

assert np.array_equal(result_numpy, result_numba)
assert np.array_equal(result_numpy, result_numba)


@pytest.mark.parametrize("method", ["numpy", param("numba", marks=SKIP_NUMBA)])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,23 @@
)

from spikeinterface.postprocessing import check_equal_template_with_distribution_overlap, ComputeTemplateSimilarity
from spikeinterface.postprocessing.template_similarity import compute_similarity_with_templates_array
from spikeinterface.postprocessing.template_similarity import (
compute_similarity_with_templates_array,
_compute_similarity_matrix_numpy,
)

try:
import numba

HAVE_NUMBA = True
from spikeinterface.postprocessing.template_similarity import _compute_similarity_matrix_numba
except ModuleNotFoundError as err:
HAVE_NUMBA = False

import pytest
from pytest import param

SKIP_NUMBA = pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")


class TestSimilarityExtension(AnalyzerExtensionCommonTestSuite):
Expand Down Expand Up @@ -72,6 +88,35 @@ def test_compute_similarity_with_templates_array(params):
print(similarity.shape)


# The skip marker must be applied as a decorator: as a bare expression statement
# it has no effect, and the test would fail with a NameError when numba is missing.
@pytest.mark.skipif(not HAVE_NUMBA, reason="Numba not available")
@pytest.mark.parametrize(
    "params",
    [
        dict(method="cosine", num_shifts=8),
        dict(method="l1", num_shifts=0),
        dict(method="l2", num_shifts=0),
        dict(method="cosine", num_shifts=0),
    ],
)
def test_equal_results_numba(params):
    """
    Test that the numpy and numba similarity kernels return the same distances.

    Both implementations are called on the same random templates with a dense
    (all-True) channel mask, for each parametrized ``method`` / ``num_shifts``
    combination, and their outputs are compared with a relative tolerance.
    """

    rng = np.random.default_rng(seed=2205)
    templates_array = rng.random(size=(4, 20, 5), dtype=np.float32)
    other_templates_array = rng.random(size=(2, 20, 5), dtype=np.float32)
    mask = np.ones((4, 2, 5), dtype=bool)

    # Each result is named after the implementation that actually produced it.
    result_numpy = _compute_similarity_matrix_numpy(templates_array, other_templates_array, mask=mask, **params)
    result_numba = _compute_similarity_matrix_numba(templates_array, other_templates_array, mask=mask, **params)

    # The third positional argument of np.allclose is rtol; spell it out.
    assert np.allclose(result_numpy, result_numba, rtol=1e-3)


if __name__ == "__main__":
from spikeinterface.postprocessing.tests.common_extension_tests import get_dataset
from spikeinterface.core import estimate_sparsity
Expand Down