Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Sep 24, 2024
1 parent b77d0b0 commit e383292
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/spikeinterface/postprocessing/template_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,9 +232,9 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num

if same_array:
# optimisation when array are the same because of symetry in shift
shift_loop = range(-num_shifts, 1)
shift_loop = list(range(-num_shifts, 1))
else:
shift_loop = range(-num_shifts, num_shifts + 1)
shift_loop = list(range(-num_shifts, num_shifts + 1))

if method == "l1":
metric = 0
Expand All @@ -243,15 +243,17 @@ def _compute_similarity_matrix_numba(templates_array, other_templates_array, num
elif method == "cosine":
metric = 2

for count, shift in enumerate(shift_loop):
for count in range(len(shift_loop)):
shift = shift_loop[count]
src_sliced_templates = templates_array[:, num_shifts : num_samples - num_shifts]
tgt_sliced_templates = other_templates_array[:, num_shifts + shift : num_samples - num_shifts + shift]
for i in numba.prange(num_templates):
src_template = src_sliced_templates[i]
overlapping_templates = np.flatnonzero(np.sum(mask[i], 1))
tgt_templates = tgt_sliced_templates[overlapping_templates]
for gcount, j in enumerate(overlapping_templates):
for gcount in range(len(overlapping_templates)):

j = overlapping_templates[gcount]
# symmetric values are handled later
if same_array and j < i:
# no need exhaustive looping when same template
Expand Down

0 comments on commit e383292

Please sign in to comment.