Commit: Bug with new template similarity
yger committed Jun 26, 2024
1 parent 531897e commit 84812d5
Showing 3 changed files with 18 additions and 14 deletions.
3 changes: 3 additions & 0 deletions src/spikeinterface/core/analyzer_extension_core.py
@@ -597,6 +597,9 @@ def _merge_extension_data(
                         weights[count] = counts[id]
                     weights /= weights.sum()
                     new_data[key][unit_ind] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum(0)
+                    chan_ids = new_sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id]
+                    mask = ~np.isin(np.arange(arr.shape[2]), chan_ids)
+                    new_data[key][unit_ind][:, mask] = 0
 
         return new_data
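The three added lines force the merged template to respect the merged unit's sparsity mask: any channel outside the unit's sparsity is zeroed after the weighted average. A minimal numpy sketch of that masking step, with illustrative shapes and names (only `unit_id_to_channel_indices` comes from the diff above):

import numpy as np

# Stand-ins for the merged template and its sparsity (illustrative values).
num_samples, num_channels = 90, 8
merged_template = np.random.randn(num_samples, num_channels)   # plays the role of new_data[key][unit_ind]
unit_channel_indices = np.array([1, 2, 3])                     # plays the role of sparsity.unit_id_to_channel_indices[unit_id]

# Zero every channel outside the merged unit's sparsity, as in the diff:
mask = ~np.isin(np.arange(num_channels), unit_channel_indices)
merged_template[:, mask] = 0

assert np.allclose(merged_template[:, 0], 0)   # channel 0 is not in the unit's sparsity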
2 changes: 1 addition & 1 deletion src/spikeinterface/core/sortinganalyzer.py
@@ -650,7 +650,7 @@ def _save_or_select_or_merge(
             from spikeinterface.core.sorting_tools import get_ids_after_merging
 
             new_unit_ids = get_ids_after_merging(self.sorting, units_to_merge, new_unit_ids=unit_ids)
 
         if self.has_recording():
             recording = self._recording
         elif self.has_temporary_recording():
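For context, `get_ids_after_merging` is expected to return the complete unit-id list once each merge group has collapsed into its new id. A toy stand-in of that expected behavior (illustrative only, not the spikeinterface implementation):

def toy_ids_after_merging(old_unit_ids, units_to_merge, new_unit_ids):
    # Drop every id that takes part in a merge, then append the new ids.
    merged = {uid for group in units_to_merge for uid in group}
    kept = [uid for uid in old_unit_ids if uid not in merged]
    return kept + list(new_unit_ids)

assert toy_ids_after_merging([0, 1, 2, 3], [[1, 2]], ["1-2"]) == [0, 3, "1-2"]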
27 changes: 14 additions & 13 deletions src/spikeinterface/postprocessing/template_similarity.py
@@ -68,13 +68,14 @@ def _merge_extension_data(
         self, units_to_merge, new_unit_ids, new_sorting_analyzer, kept_indices=None, verbose=False, **job_kwargs
     ):
         num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
-        templates_array = get_dense_templates_array(new_sorting_analyzer)
+        templates_array = get_dense_templates_array(new_sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled)
         arr = self.data["similarity"]
         sparsity = new_sorting_analyzer.sparsity
         all_new_unit_ids = new_sorting_analyzer.unit_ids
         new_similarity = np.zeros((len(all_new_unit_ids), len(all_new_unit_ids)), dtype=arr.dtype)
 
-        for unit_ind1, unit_id1 in enumerate(all_new_unit_ids):
+        for unit_id1 in all_new_unit_ids:
+            unit_ind1 = new_sorting_analyzer.sorting.id_to_index(unit_id1)
             template1 = templates_array[unit_ind1][np.newaxis, :]
             sparsity1 = ChannelSparsity(sparsity.mask[unit_ind1][np.newaxis, :], [unit_id1], sparsity.channel_ids)
 
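Forwarding `return_scaled` matters because l1/l2 distances, unlike cosine, are not scale-invariant: comparing scaled (µV) templates from the merged analyzer against the unscaled (raw) ones used for the original similarity would skew the result. A toy illustration with an assumed gain value:

import numpy as np

raw = np.array([1.0, -2.0, 1.0])
gain = 0.195                # hypothetical µV-per-bit gain
scaled = raw * gain

# Cosine similarity ignores the scale difference, but l2 does not:
cos = raw @ scaled / (np.linalg.norm(raw) * np.linalg.norm(scaled))
assert np.isclose(cos, 1.0)
assert np.linalg.norm(raw - scaled) > 0   # nonzero l2 distance from scaling alone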
@@ -84,9 +85,9 @@ def _merge_extension_data(
                 new_spk1 = False
                 i = self.sorting_analyzer.sorting.id_to_index(unit_id1)
 
-            for unit_ind2, unit_id2 in enumerate(all_new_unit_ids[unit_ind1:]):
-                position = unit_ind1 + unit_ind2
-                template2 = templates_array[position][np.newaxis, :]
+            for unit_id2 in all_new_unit_ids[unit_ind1:]:
+                unit_ind2 = new_sorting_analyzer.sorting.id_to_index(unit_id2)
+                template2 = templates_array[unit_ind2][np.newaxis, :]
                 sparsity2 = ChannelSparsity(sparsity.mask[unit_ind2][np.newaxis, :], [unit_id2], sparsity.channel_ids)
                 if unit_id2 in new_unit_ids:
                     new_spk2 = True
@@ -95,7 +96,7 @@ def _merge_extension_data(
                     j = self.sorting_analyzer.sorting.id_to_index(unit_id2)
 
                 if new_spk1 or new_spk2:
-                    new_similarity[unit_ind1, position] = compute_similarity_with_templates_array(
+                    new_similarity[unit_ind1, unit_ind2] = compute_similarity_with_templates_array(
                         template1,
                         template2,
                         method=self.params["method"],
@@ -105,9 +106,9 @@ def _merge_extension_data(
                         other_sparsity=sparsity2,
                     )
                 else:
-                    new_similarity[unit_ind1, position] = arr[i, j]
+                    new_similarity[unit_ind1, unit_ind2] = arr[i, j]
 
-                new_similarity[position, unit_ind1] = new_similarity[unit_ind1, position]
+                new_similarity[unit_ind2, unit_ind1] = new_similarity[unit_ind1, unit_ind2]
 
         return dict(similarity=new_similarity)
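The rewritten loops resolve each unit id to its row with `sorting.id_to_index` instead of deriving positions from `enumerate` offsets, and mirror each computed value so the matrix stays symmetric. A toy version of that indexing pattern (the dict stands in for `id_to_index`; ids and values are illustrative):

import numpy as np

all_new_unit_ids = ["u2", "u5", "u9"]          # unit ids after merging, not 0..N-1
id_to_index = {uid: ind for ind, uid in enumerate(all_new_unit_ids)}

new_similarity = np.zeros((3, 3))
for unit_id1 in all_new_unit_ids:
    unit_ind1 = id_to_index[unit_id1]
    for unit_id2 in all_new_unit_ids[unit_ind1:]:
        unit_ind2 = id_to_index[unit_id2]      # absolute row, no offset arithmetic
        new_similarity[unit_ind1, unit_ind2] = 1.0   # placeholder for the similarity value
        new_similarity[unit_ind2, unit_ind1] = new_similarity[unit_ind1, unit_ind2]

assert np.allclose(new_similarity, new_similarity.T)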

@@ -195,7 +196,7 @@ def compute_similarity_with_templates_array(
             tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
             for gcount, j in enumerate(overlapping_templates[i]):
                 # symmetric values are handled later
-                if num_templates == other_num_templates and j < i:
+                if j < i:
                     continue
                 src = src_template[:, mask[i, j]].reshape(1, -1)
                 tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)
@@ -212,11 +213,11 @@ def compute_similarity_with_templates_array(
                     distances[count, i, j] /= norm_i + norm_j
                 else:
                     distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="cosine")
-                if num_templates == other_num_templates:
-                    distances[count, j, i] = distances[count, i, j]
-
-            if num_shifts != 0:
-                distances[num_shifts_both_sides - count - 1] = distances[count].T
+                distances[count, j, i] = distances[count, i, j]
+
+        if num_shifts != 0:
+            distances[num_shifts_both_sides - count - 1] = distances[count].T
 
     distances = np.min(distances, axis=0)
     similarity = 1 - distances
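After the per-lag loop, the slice for each negative lag is filled by transposing the matching positive-lag slice, and the best lag wins before distance is converted to similarity. A self-contained sketch of that pattern (shapes and the lag loop are illustrative):

import numpy as np

num_shifts = 2
num_shifts_both_sides = 2 * num_shifts + 1
num_templates = 4
distances = np.ones((num_shifts_both_sides, num_templates, num_templates))

for count in range(num_shifts + 1):            # non-negative lags only
    d = np.random.rand(num_templates, num_templates)
    d = np.triu(d) + np.triu(d, 1).T           # symmetric fill, as in the diff
    distances[count] = d
    if num_shifts != 0:
        # mirror to the matching negative lag
        distances[num_shifts_both_sides - count - 1] = distances[count].T

distances = np.min(distances, axis=0)          # best lag per template pair
similarity = 1 - distances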
