From 1d21b68619605151d1571402fa89d5c71bcc1c05 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 9 Oct 2023 11:16:20 +0200 Subject: [PATCH 1/3] wip tdc2 merge with template. --- .../sorters/internal/tridesclous2.py | 39 +++--- .../sortingcomponents/clustering/merge.py | 115 +++++++++++++++++- 2 files changed, 135 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 054596e9b3..11be2c3580 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -229,24 +229,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording, features_folder, radius_um=merge_radius_um, - method="project_distribution", + # method="project_distribution", + # method_kwargs=dict( + # waveforms_sparse_mask=sparse_mask, + # feature_name="sparse_wfs", + # projection="centroid", + # criteria="distrib_overlap", + # threshold_overlap=0.3, + # min_cluster_size=min_cluster_size + 1, + # num_shift=5, + # ), + method="normalized_template_diff", method_kwargs=dict( - # neighbours_mask=neighbours_mask, waveforms_sparse_mask=sparse_mask, - # feature_name="sparse_tsvd", - feature_name="sparse_wfs", - # projection='lda', - projection="centroid", - # criteria='diptest', - # threshold_diptest=0.5, - # criteria="percentile", - # threshold_percentile=80., - criteria="distrib_overlap", - threshold_overlap=0.3, + threshold_diff=0.2, min_cluster_size=min_cluster_size + 1, - # num_shift=0 num_shift=5, - ), + ), **job_kwargs, ) @@ -255,10 +254,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): new_peaks = peaks.copy() new_peaks["sample_index"] -= peak_shifts + # clean very small cluster before peeler + minimum_cluster_size = 25 + labels_set, count = np.unique(post_merge_label, return_counts=True) + to_remove = labels_set[count < minimum_cluster_size] + print(to_remove) + mask = np.isin(post_merge_label, 
to_remove) + post_merge_label[mask] = -1 + + # final label sets labels_set = np.unique(post_merge_label) labels_set = labels_set[labels_set >= 0] - mask = post_merge_label >= 0 + + mask = post_merge_label >= 0 sorting_temp = NumpySorting.from_times_labels( new_peaks["sample_index"][mask], post_merge_label[mask], diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 45090452dc..24cbedfb8c 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -256,7 +256,7 @@ def find_merge_pairs( sparse_wfs, sparse_mask, radius_um=70, - method="waveforms_lda", + method="project_distribution", method_kwargs={}, **job_kwargs # n_jobs=1, @@ -308,7 +308,8 @@ def find_merge_pairs( max_workers=n_jobs, initializer=find_pair_worker_init, mp_context=get_context(mp_context), - initargs=(recording, features_dict_or_folder, peak_labels, method, method_kwargs, max_threads_per_process), + initargs=(recording, features_dict_or_folder, peak_labels, labels_set, templates, + method, method_kwargs, max_threads_per_process), ) as pool: jobs = [] for ind0, ind1 in zip(indices0, indices1): @@ -338,13 +339,16 @@ def find_merge_pairs( def find_pair_worker_init( - recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process + recording, features_dict_or_folder, original_labels, + labels_set, templates, method, method_kwargs, max_threads_per_process ): global _ctx _ctx = {} _ctx["recording"] = recording _ctx["original_labels"] = original_labels + _ctx["labels_set"] = labels_set + _ctx["templates"] = templates _ctx["method"] = method _ctx["method_kwargs"] = method_kwargs _ctx["method_class"] = find_pair_method_dict[method] @@ -364,8 +368,10 @@ def find_pair_function_wrapper(label0, label1): global _ctx with threadpool_limits(limits=_ctx["max_threads_per_process"]): is_merge, label0, label1, shift, merge_value = 
_ctx["method_class"].merge( - label0, label1, _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + label0, label1, _ctx["labels_set"], _ctx["templates"], + _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] ) + return is_merge, label0, label1, shift, merge_value @@ -388,6 +394,8 @@ class ProjectDistribution: def merge( label0, label1, + labels_set, + templates, original_labels, peaks, features, @@ -578,7 +586,106 @@ def merge( return is_merge, label0, label1, final_shift, merge_value +class NormalizedTemplateDiff: + """ + Compute the normalized (some kind of) template differences. + And merge if below a threshold. + Do this at several shifts. + + """ + + name = "normalized_template_diff" + + @staticmethod + def merge( + label0, + label1, + labels_set, + templates, + original_labels, + peaks, + features, + waveforms_sparse_mask=None, + threshold_diff=0.05, + min_cluster_size=50, + num_shift=5, + ): + + assert waveforms_sparse_mask is not None + + (inds0,) = np.nonzero(original_labels == label0) + chans0 = np.unique(peaks["channel_index"][inds0]) + target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0)) + + (inds1,) = np.nonzero(original_labels == label1) + chans1 = np.unique(peaks["channel_index"][inds1]) + target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0)) + + # if inds0.size < min_cluster_size or inds1.size < min_cluster_size: + # is_merge = False + # merge_value = 0 + # final_shift = 0 + # return is_merge, label0, label1, final_shift, merge_value + + target_chans = np.intersect1d(target_chans0, target_chans1) + union_chans = np.union1d(target_chans0, target_chans1) + + ind0 = list(labels_set).index(label0) + template0 = templates[ind0, :, target_chans] + + ind1 = list(labels_set).index(label1) + template1 = templates[ind1, :, target_chans] + + + num_samples = template0.shape[0] + # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) + norm
= np.mean(np.abs(template0) + np.abs(template1)) + all_shift_diff = [] + for shift in range(-num_shift, num_shift + 1): + temp0 = template0[num_shift : num_samples - num_shift, :] + temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :] + d = np.mean(np.abs(temp0 - temp1)) / (norm) + all_shift_diff.append(d) + normed_diff = np.min(all_shift_diff) + + is_merge = normed_diff < threshold_diff + if is_merge: + merge_value = normed_diff + final_shift = np.argmin(all_shift_diff) - num_shift + else: + final_shift = 0 + merge_value = np.nan + + + # DEBUG = False + DEBUG = True + if DEBUG and normed_diff < 0.2: + # if DEBUG: + + import matplotlib.pyplot as plt + + fig, ax = plt.subplots() + + m0 = template0.flatten() + m1 = template1.flatten() + + ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") + ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") + + ax.set_title(f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}") + ax.legend() + plt.show() + + + + + + return is_merge, label0, label1, final_shift, merge_value + + + find_pair_method_list = [ ProjectDistribution, + NormalizedTemplateDiff, ] find_pair_method_dict = {e.name: e for e in find_pair_method_list} From 64d507c7374a609955c69ef61df4e9cde5a7a04d Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Mon, 9 Oct 2023 22:19:45 +0200 Subject: [PATCH 2/3] remove print --- src/spikeinterface/sorters/internal/tridesclous2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 11be2c3580..ddabd46657 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -258,7 +258,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): minimum_cluster_size = 25 labels_set, count = np.unique(post_merge_label, return_counts=True) to_remove = labels_set[count < 
minimum_cluster_size] - print(to_remove) + mask = np.isin(post_merge_label, to_remove) post_merge_label[mask] = -1 From 0fd84922dd9d4ae54bcc0183a98d7a50a1e9f50c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Oct 2023 07:21:34 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../sorters/internal/tridesclous2.py | 3 +- .../sortingcomponents/clustering/merge.py | 46 ++++++++++++------- 2 files changed, 31 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index ddabd46657..e256915fa6 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -245,7 +245,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): threshold_diff=0.2, min_cluster_size=min_cluster_size + 1, num_shift=5, - ), + ), **job_kwargs, ) @@ -266,7 +266,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): labels_set = np.unique(post_merge_label) labels_set = labels_set[labels_set >= 0] - mask = post_merge_label >= 0 sorting_temp = NumpySorting.from_times_labels( new_peaks["sample_index"][mask], diff --git a/src/spikeinterface/sortingcomponents/clustering/merge.py b/src/spikeinterface/sortingcomponents/clustering/merge.py index 24cbedfb8c..c46f214192 100644 --- a/src/spikeinterface/sortingcomponents/clustering/merge.py +++ b/src/spikeinterface/sortingcomponents/clustering/merge.py @@ -308,8 +308,16 @@ def find_merge_pairs( max_workers=n_jobs, initializer=find_pair_worker_init, mp_context=get_context(mp_context), - initargs=(recording, features_dict_or_folder, peak_labels, labels_set, templates, - method, method_kwargs, max_threads_per_process), + initargs=( + recording, + features_dict_or_folder, + peak_labels, + labels_set, + templates, + method, + 
method_kwargs, + max_threads_per_process, + ), ) as pool: jobs = [] for ind0, ind1 in zip(indices0, indices1): @@ -339,8 +347,14 @@ def find_merge_pairs( def find_pair_worker_init( - recording, features_dict_or_folder, original_labels, - labels_set, templates, method, method_kwargs, max_threads_per_process + recording, + features_dict_or_folder, + original_labels, + labels_set, + templates, + method, + method_kwargs, + max_threads_per_process, ): global _ctx _ctx = {} @@ -368,8 +382,14 @@ def find_pair_function_wrapper(label0, label1): global _ctx with threadpool_limits(limits=_ctx["max_threads_per_process"]): is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge( - label0, label1, _ctx["labels_set"], _ctx["templates"], - _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"] + label0, + label1, + _ctx["labels_set"], + _ctx["templates"], + _ctx["original_labels"], + _ctx["peaks"], + _ctx["features"], + **_ctx["method_kwargs"], ) return is_merge, label0, label1, shift, merge_value @@ -610,7 +630,6 @@ def merge( min_cluster_size=50, num_shift=5, ): - assert waveforms_sparse_mask is not None (inds0,) = np.nonzero(original_labels == label0) @@ -636,7 +655,6 @@ def merge( ind1 = list(labels_set).index(label1) template1 = templates[ind1, :, target_chans] - num_samples = template0.shape[0] # norm = np.mean(np.abs(template0)) + np.mean(np.abs(template1)) norm = np.mean(np.abs(template0) + np.abs(template1)) @@ -656,11 +674,10 @@ def merge( final_shift = 0 merge_value = np.nan - # DEBUG = False DEBUG = True if DEBUG and normed_diff < 0.2: - # if DEBUG: + # if DEBUG: import matplotlib.pyplot as plt @@ -672,18 +689,15 @@ def merge( ax.plot(m0, color="C0", label=f"{label0} {inds0.size}") ax.plot(m1, color="C1", label=f"{label1} {inds1.size}") - ax.set_title(f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}") + ax.set_title( + f"union{union_chans.size} intersect{target_chans.size} 
\n {normed_diff:.3f} {final_shift} {is_merge}" + ) ax.legend() plt.show() - - - - return is_merge, label0, label1, final_shift, merge_value - find_pair_method_list = [ ProjectDistribution, NormalizedTemplateDiff,