Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Sep 24, 2024
1 parent 5bac5ed commit 7d8ef31
Showing 1 changed file with 81 additions and 69 deletions.
150 changes: 81 additions & 69 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,70 @@ def _get_data(self):
register_result_extension(ComputeTemplateSimilarity)
compute_template_similarity = ComputeTemplateSimilarity.function_factory()

def _compute_similarity_matrix_numpy(templates_array, other_templates_array, num_shifts, mask, method):

num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
other_num_templates = other_templates_array.shape[0]

num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
same_array = np.array_equal(templates_array, other_templates_array)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed

if same_array:
# optimisation when array are the same because of symetry in shift
shift_loop = range(-num_shifts, 1)
else:
shift_loop = range(-num_shifts, num_shifts + 1)

for count, shift in enumerate(shift_loop):
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
continue
src = src_template[:, mask[i, j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)

if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = np.sum(np.abs(src - tgt))
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = np.linalg.norm(src - tgt, ord=2)
distances[count, i, j] /= norm_i + norm_j
elif method == "cosine":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = np.sum(src * tgt)
distances[count, i, j] /= norm_i * norm_j
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
return distances


if HAVE_NUMBA:

from math import sqrt
@numba.jit(nopython=True, parallel=True, fastmath=True, nogil=True)
def _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method):
def _compute_similarity_matrix_numba(templates_array, other_templates_array, num_shifts, mask, method):
num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
other_num_templates = other_templates_array.shape[0]
Expand All @@ -175,6 +234,13 @@ def _compute_similarity_matrix(templates_array, other_templates_array, num_shift
else:
shift_loop = range(-num_shifts, num_shifts + 1)

if method == 'l1':
metric = 0
elif method == 'l2':
metric = 1
elif method == 'cosine':
metric = 2

for count, shift in enumerate(shift_loop):
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
Expand All @@ -195,29 +261,29 @@ def _compute_similarity_matrix(templates_array, other_templates_array, num_shift
norm_j = 0

for k in range(len(src)):
if method == "l1":
if metric == 0:
norm_i += abs(src[k])
norm_j += abs(tgt[k])
distances[count, i, j] += abs(src[k] - tgt[k])
elif method == "l2":
elif metric == 1:
norm_i += src[k] ** 2
norm_j += tgt[k] ** 2
distances[count, i, j] += (src[k] - tgt[k]) ** 2
elif method == "cosine":
elif metric == 2:
distances[count, i, j] += src[k] * tgt[k]
norm_i += src[k] ** 2
norm_j += tgt[k] ** 2

if method == "l1":
if metric == 0:
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.sqrt(norm_i)
norm_j = np.sqrt(norm_j)
distances[count, i, j] = np.sqrt(distances[count, i, j])
elif metric == 1:
norm_i = sqrt(norm_i)
norm_j = sqrt(norm_j)
distances[count, i, j] = sqrt(distances[count, i, j])
distances[count, i, j] /= norm_i + norm_j
elif method == "cosine":
norm_i = np.sqrt(norm_i)
norm_j = np.sqrt(norm_j)
elif metric == 2:
norm_i = sqrt(norm_i)
norm_j = sqrt(norm_j)
distances[count, i, j] /= norm_i * norm_j
distances[count, i, j] = 1 - distances[count, i, j]

Expand All @@ -226,67 +292,13 @@ def _compute_similarity_matrix(templates_array, other_templates_array, num_shift

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T

return distances

_compute_similarity_matrix = _compute_similarity_matrix_numba
else:
_compute_similarity_matrix = _compute_similarity_matrix_numpy

def _compute_similarity_matrix(templates_array, other_templates_array, num_shifts, mask, method):

num_templates = templates_array.shape[0]
num_samples = templates_array.shape[1]
other_num_templates = other_templates_array.shape[0]

num_shifts_both_sides = 2 * num_shifts + 1
distances = np.ones((num_shifts_both_sides, num_templates, other_num_templates), dtype=np.float32)
same_array = np.array_equal(templates_array, other_templates_array)

# We can use the fact that dist[i,j] at lag t is equal to dist[j,i] at time -t
# So the matrix can be computed only for negative lags and be transposed

if same_array:
# optimisation when array are the same because of symetry in shift
shift_loop = range(-num_shifts, 1)
else:
shift_loop = range(-num_shifts, num_shifts + 1)

for count, shift in enumerate(shift_loop):
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in range(num_templates):
src_template = src_sliced_templates[i]
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
continue
src = src_template[:, mask[i, j]].reshape(1, -1)
tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)

if method == "l1":
norm_i = np.sum(np.abs(src))
norm_j = np.sum(np.abs(tgt))
distances[count, i, j] = np.sum(np.abs(src - tgt))
distances[count, i, j] /= norm_i + norm_j
elif method == "l2":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = np.linalg.norm(src - tgt, ord=2)
distances[count, i, j] /= norm_i + norm_j
elif method == "cosine":
norm_i = np.linalg.norm(src, ord=2)
norm_j = np.linalg.norm(tgt, ord=2)
distances[count, i, j] = np.sum(src * tgt)
distances[count, i, j] /= norm_i * norm_j
distances[count, i, j] = 1 - distances[count, i, j]

if same_array:
distances[count, j, i] = distances[count, i, j]

if same_array and num_shifts != 0:
distances[num_shifts_both_sides - count - 1] = distances[count].T
return distances


def compute_similarity_with_templates_array(
Expand Down

0 comments on commit 7d8ef31

Please sign in to comment.