Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimizations for template_similarity (numba and dependencies) #3405

Merged
merged 22 commits into from
Oct 7, 2024
Merged
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
202 changes: 144 additions & 58 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
from ..core.template_tools import get_dense_templates_array
from ..core.sparsity import ChannelSparsity

try:
import numba

HAVE_NUMBA = True
except ImportError:
HAVE_NUMBA = False


class ComputeTemplateSimilarity(AnalyzerExtension):
"""Compute similarity between templates with several methods.
Expand Down Expand Up @@ -147,10 +154,144 @@ def _get_data(self):
compute_template_similarity = ComputeTemplateSimilarity.function_factory()


if HAVE_NUMBA:
yger marked this conversation as resolved.
Show resolved Hide resolved

@numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
def _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method):
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
other_num_templates = other_templates_array.shape[0]

num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
same_array = np.array_equal(templates_array, other_templates_array)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed

if same_array:
# optimisation when array are the same because of symetry in shift
shift_loop = range(-num_shifts, 1)
else:
shift_loop = range(-num_shifts, num_shifts + 1)

for count, shift in enumerate(shift_loop):
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in numba.prange(num_templates):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would put the prange the very first loop.
I do not known if the thread are spawn for every shift loop.

In short I would invert the prange on template and the shift loop.

Could you try this and make benchmark ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tried to invert the loops, and it does not make any difference

src_template = src_sliced_templates[i]
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
yger marked this conversation as resolved.
Show resolved Hide resolved

# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
continue
src = src_template[:, mask[i, j]].flatten()
tgt = (tgt_templates[gcount][:, mask[i, j]]).flatten()
yger marked this conversation as resolved.
Show resolved Hide resolved

norm_i = 0
norm_j = 0

for k in range(len(src)):
if method == "l1":
yger marked this conversation as resolved.
Show resolved Hide resolved
norm_i += abs(src[k])
norm_j += abs(tgt[k])
distances[count, i, j] += abs(src[k] - tgt[k])
elif method == "l2":
norm_i += src[k] ** 2
norm_j += tgt[k] ** 2
distances[count, i, j] += (src[k] - tgt[k]) ** 2
elif method == "cosine":
distances[count, i, j] += src[k] * tgt[k]
norm_i += src[k] ** 2
norm_j += tgt[k] ** 2

if method == "l1":
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.sqrt(norm_i)
yger marked this conversation as resolved.
Show resolved Hide resolved
norm_j = np.sqrt(norm_j)
distances[count, i, j] = np.sqrt(distances[count, i, j])
distances[count, i, j] /= norm_i + norm_j
elif method == "cosine":
norm_i = np.sqrt(norm_i)
norm_j = np.sqrt(norm_j)
distances[count, i, j] /= norm_i * norm_j
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
return distances

else:

def _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method):

num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
other_num_templates = other_templates_array.shape[0]

num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
same_array = np.array_equal(templates_array, other_templates_array)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed

if same_array:
# optimisation when array are the same because of symetry in shift
shift_loop = range(-num_shifts, 1)
else:
shift_loop = range(-num_shifts, num_shifts + 1)

for count, shift in enumerate(shift_loop):
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
continue
src = src_template[:, mask[i, j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)

if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = np.sum(np.abs(src - tgt))
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = np.linalg.norm(src - tgt, ord=2)
distances[count, i, j] /= norm_i + norm_j
elif method == "cosine":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = np.sum(src * tgt)
distances[count, i, j] /= norm_i * norm_j
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
return distances


def compute_similarity_with_templates_array(
templates_array, other_templates_array, method, support="union", num_shifts=0, sparsity=None, other_sparsity=None
):
import sklearn.metrics.pairwise

if method == "cosine_similarity":
method = "cosine"
Expand All @@ -171,9 +312,8 @@ def compute_similarity_with_templates_array(
num_channels = templates_array.shape[2]
other_num_templates = other_templates_array.shape[0]

same_array = np.array_equal(templates_array, other_templates_array)
mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)

mask = None
if sparsity is not None and other_sparsity is not None:
if support == "intersection":
mask = np.logical_and(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
Expand All @@ -182,63 +322,9 @@ def compute_similarity_with_templates_array(
units_overlaps = np.sum(mask, axis=2) > 0
mask = np.logical_or(sparsity.mask[:, np.newaxis, :], other_sparsity.mask[np.newaxis, :, :])
mask[~units_overlaps] = False
if mask is not None:
units_overlaps = np.sum(mask, axis=2) > 0
overlapping_templates = {}
for i in range(num_templates):
overlapping_templates[i] = np.flatnonzero(units_overlaps[i])
else:
# here we make a dense mask and overlapping templates
overlapping_templates = {i: np.arange(other_num_templates) for i in range(num_templates)}
mask = np.ones((num_templates, other_num_templates, num_channels), dtype=bool)

assert num_shifts < num_samples, "max_lag is too large"
num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed

if same_array:
# optimisation when array are the same because of symetry in shift
shift_loop = range(-num_shifts, 1)
else:
shift_loop = range(-num_shifts, num_shifts + 1)

for count, shift in enumerate(shift_loop):
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
for gcount, j in enumerate(overlapping_templates[i]):
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
continue
src = src_template[:, mask[i, j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)

if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l1").item()
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="l2").item()
distances[count, i, j] /= norm_i + norm_j
else:
distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(
src, tgt, metric="cosine"
).item()

if same_array:
distances[count, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
distances = _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method)

distances = np.min(distances, axis=0)
similarity = 1 - distances
Expand Down