From 84812d54e667594f77327bc7c1dc7347a04d5d4e Mon Sep 17 00:00:00 2001
From: Pierre Yger
Date: Wed, 26 Jun 2024 22:32:57 +0200
Subject: [PATCH] Bug with new template similarity

---
 .../core/analyzer_extension_core.py           |  3 +++
 src/spikeinterface/core/sortinganalyzer.py    |  2 +-
 .../postprocessing/template_similarity.py     | 27 ++++++++++---------
 3 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py
index 91382d1d77..523e048621 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -597,6 +597,9 @@ def _merge_extension_data(
                         weights[count] = counts[id]
                     weights /= weights.sum()
                     new_data[key][unit_ind] = (arr[keep_unit_indices, :, :] * weights[:, np.newaxis, np.newaxis]).sum(0)
+                    chan_ids = new_sorting_analyzer.sparsity.unit_id_to_channel_indices[unit_id]
+                    mask = ~np.isin(np.arange(arr.shape[2]), chan_ids)
+                    new_data[key][unit_ind][:, mask] = 0
 
         return new_data
 
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index fc47955814..d3514a7d21 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -650,7 +650,7 @@ def _save_or_select_or_merge(
             from spikeinterface.core.sorting_tools import get_ids_after_merging
 
             new_unit_ids = get_ids_after_merging(self.sorting, units_to_merge, new_unit_ids=unit_ids)
-            
+
         if self.has_recording():
             recording = self._recording
         elif self.has_temporary_recording():
diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 3d778c98b8..2ace0945d5 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -68,13 +68,14 @@ def _merge_extension_data(
         self, units_to_merge, new_unit_ids, new_sorting_analyzer, kept_indices=None, verbose=False, **job_kwargs
     ):
         num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
-        templates_array = get_dense_templates_array(new_sorting_analyzer)
+        templates_array = get_dense_templates_array(new_sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled)
         arr = self.data["similarity"]
         sparsity = new_sorting_analyzer.sparsity
         all_new_unit_ids = new_sorting_analyzer.unit_ids
         new_similarity = np.zeros((len(all_new_unit_ids), len(all_new_unit_ids)), dtype=arr.dtype)
 
-        for unit_ind1, unit_id1 in enumerate(all_new_unit_ids):
+        for unit_id1 in all_new_unit_ids:
+            unit_ind1 = new_sorting_analyzer.sorting.id_to_index(unit_id1)
             template1 = templates_array[unit_ind1][np.newaxis, :]
             sparsity1 = ChannelSparsity(sparsity.mask[unit_ind1][np.newaxis, :], [unit_id1], sparsity.channel_ids)
 
@@ -84,9 +85,9 @@ def _merge_extension_data(
                 new_spk1 = False
                 i = self.sorting_analyzer.sorting.id_to_index(unit_id1)
 
-            for unit_ind2, unit_id2 in enumerate(all_new_unit_ids[unit_ind1:]):
-                position = unit_ind1 + unit_ind2
-                template2 = templates_array[position][np.newaxis, :]
+            for unit_id2 in all_new_unit_ids[unit_ind1:]:
+                unit_ind2 = new_sorting_analyzer.sorting.id_to_index(unit_id2)
+                template2 = templates_array[unit_ind2][np.newaxis, :]
                 sparsity2 = ChannelSparsity(sparsity.mask[unit_ind2][np.newaxis, :], [unit_id2], sparsity.channel_ids)
                 if unit_id2 in new_unit_ids:
                     new_spk2 = True
@@ -95,7 +96,7 @@ def _merge_extension_data(
                     j = self.sorting_analyzer.sorting.id_to_index(unit_id2)
 
                 if new_spk1 or new_spk2:
-                    new_similarity[unit_ind1, position] = compute_similarity_with_templates_array(
+                    new_similarity[unit_ind1, unit_ind2] = compute_similarity_with_templates_array(
                         template1,
                         template2,
                         method=self.params["method"],
@@ -105,9 +106,9 @@ def _merge_extension_data(
                         other_sparsity=sparsity2,
                     )
                 else:
-                    new_similarity[unit_ind1, position] = arr[i, j]
+                    new_similarity[unit_ind1, unit_ind2] = arr[i, j]
 
-                new_similarity[position, unit_ind1] = new_similarity[unit_ind1, position]
+                new_similarity[unit_ind2, unit_ind1] = new_similarity[unit_ind1, unit_ind2]
 
         return dict(similarity=new_similarity)
 
@@ -195,7 +196,7 @@ def compute_similarity_with_templates_array(
             tgt_templates = tgt_sliced_templates[overlapping_templates[i]]
             for gcount, j in enumerate(overlapping_templates[i]):
                 # symmetric values are handled later
-                if num_templates == other_num_templates and j < i:
+                if j < i:
                     continue
                 src = src_template[:, mask[i, j]].reshape(1, -1)
                 tgt = (tgt_templates[gcount][:, mask[i, j]]).reshape(1, -1)
@@ -212,11 +213,11 @@ def compute_similarity_with_templates_array(
                     distances[count, i, j] /= norm_i + norm_j
                 else:
                     distances[count, i, j] = sklearn.metrics.pairwise.pairwise_distances(src, tgt, metric="cosine")
-                if num_templates == other_num_templates:
-                    distances[count, j, i] = distances[count, i, j]
+
+                distances[count, j, i] = distances[count, i, j]
 
-        if num_shifts != 0:
-            distances[num_shifts_both_sides - count - 1] = distances[count].T
+            if num_shifts != 0:
+                distances[num_shifts_both_sides - count - 1] = distances[count].T
 
     distances = np.min(distances, axis=0)
     similarity = 1 - distances
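
Note (reviewer sketch, not part of the patch; ids and masks below are made up):
a minimal illustration of the indexing bug fixed in template_similarity.py. The
old inner loop enumerated the slice all_new_unit_ids[unit_ind1:], so unit_ind2
was a slice-relative offset: adding it to unit_ind1 (as position) gave the
correct absolute index into templates_array, but passing it directly to
sparsity.mask[unit_ind2] picked the wrong unit's channel mask whenever
unit_ind1 > 0. The fix resolves both units to absolute indices through
sorting.id_to_index.

import numpy as np

# Made-up example: three units, unit k is sparse on channel k only.
all_new_unit_ids = np.array(["a", "b", "c"])
sparsity_mask = np.eye(3, dtype=bool)

unit_ind1 = 1  # unit "b"
for unit_ind2, unit_id2 in enumerate(all_new_unit_ids[unit_ind1:]):
    position = unit_ind1 + unit_ind2  # absolute index (1 then 2): correct template
    # Old code indexed the sparsity with the offset (0 then 1), i.e. it took
    # unit "a"'s channels while comparing unit "b"'s template.
    print(unit_id2, "offset mask:", sparsity_mask[unit_ind2], "absolute mask:", sparsity_mask[position])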