Skip to content

Commit

Permalink
Merge pull request #2087 from samuelgarcia/tdc_2
Browse files Browse the repository at this point in the history
tridesclous2 update
  • Loading branch information
samuelgarcia authored Oct 12, 2023
2 parents 5400abb + 51fd0e8 commit 728351a
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 18 deletions.
36 changes: 22 additions & 14 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,22 +229,21 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
recording,
features_folder,
radius_um=merge_radius_um,
method="project_distribution",
# method="project_distribution",
# method_kwargs=dict(
# waveforms_sparse_mask=sparse_mask,
# feature_name="sparse_wfs",
# projection="centroid",
# criteria="distrib_overlap",
# threshold_overlap=0.3,
# min_cluster_size=min_cluster_size + 1,
# num_shift=5,
# ),
method="normalized_template_diff",
method_kwargs=dict(
# neighbours_mask=neighbours_mask,
waveforms_sparse_mask=sparse_mask,
# feature_name="sparse_tsvd",
feature_name="sparse_wfs",
# projection='lda',
projection="centroid",
# criteria='diptest',
# threshold_diptest=0.5,
# criteria="percentile",
# threshold_percentile=80.,
criteria="distrib_overlap",
threshold_overlap=0.3,
threshold_diff=0.2,
min_cluster_size=min_cluster_size + 1,
# num_shift=0
num_shift=5,
),
**job_kwargs,
Expand All @@ -255,10 +254,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
new_peaks = peaks.copy()
new_peaks["sample_index"] -= peak_shifts

# clean very small cluster before peeler
minimum_cluster_size = 25
labels_set, count = np.unique(post_merge_label, return_counts=True)
to_remove = labels_set[count < minimum_cluster_size]

mask = np.isin(post_merge_label, to_remove)
post_merge_label[mask] = -1

# final label sets
labels_set = np.unique(post_merge_label)
labels_set = labels_set[labels_set >= 0]
mask = post_merge_label >= 0

mask = post_merge_label >= 0
sorting_temp = NumpySorting.from_times_labels(
new_peaks["sample_index"][mask],
post_merge_label[mask],
Expand Down
129 changes: 125 additions & 4 deletions src/spikeinterface/sortingcomponents/clustering/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def find_merge_pairs(
sparse_wfs,
sparse_mask,
radius_um=70,
method="waveforms_lda",
method="project_distribution",
method_kwargs={},
**job_kwargs
# n_jobs=1,
Expand Down Expand Up @@ -308,7 +308,16 @@ def find_merge_pairs(
max_workers=n_jobs,
initializer=find_pair_worker_init,
mp_context=get_context(mp_context),
initargs=(recording, features_dict_or_folder, peak_labels, method, method_kwargs, max_threads_per_process),
initargs=(
recording,
features_dict_or_folder,
peak_labels,
labels_set,
templates,
method,
method_kwargs,
max_threads_per_process,
),
) as pool:
jobs = []
for ind0, ind1 in zip(indices0, indices1):
Expand Down Expand Up @@ -338,13 +347,22 @@ def find_merge_pairs(


def find_pair_worker_init(
recording, features_dict_or_folder, original_labels, method, method_kwargs, max_threads_per_process
recording,
features_dict_or_folder,
original_labels,
labels_set,
templates,
method,
method_kwargs,
max_threads_per_process,
):
global _ctx
_ctx = {}

_ctx["recording"] = recording
_ctx["original_labels"] = original_labels
_ctx["labels_set"] = labels_set
_ctx["templates"] = templates
_ctx["method"] = method
_ctx["method_kwargs"] = method_kwargs
_ctx["method_class"] = find_pair_method_dict[method]
Expand All @@ -364,8 +382,16 @@ def find_pair_function_wrapper(label0, label1):
global _ctx
with threadpool_limits(limits=_ctx["max_threads_per_process"]):
is_merge, label0, label1, shift, merge_value = _ctx["method_class"].merge(
label0, label1, _ctx["original_labels"], _ctx["peaks"], _ctx["features"], **_ctx["method_kwargs"]
label0,
label1,
_ctx["labels_set"],
_ctx["templates"],
_ctx["original_labels"],
_ctx["peaks"],
_ctx["features"],
**_ctx["method_kwargs"],
)

return is_merge, label0, label1, shift, merge_value


Expand All @@ -388,6 +414,8 @@ class ProjectDistribution:
def merge(
label0,
label1,
labels_set,
templates,
original_labels,
peaks,
features,
Expand Down Expand Up @@ -578,7 +606,100 @@ def merge(
return is_merge, label0, label1, final_shift, merge_value


class NormalizedTemplateDiff:
    """
    Merge criterion based on a normalized difference between two templates.

    The mean absolute difference between the two templates is computed on the
    channels shared by both clusters, for every lag in
    ``[-num_shift, num_shift]``, and divided by the mean of the summed absolute
    amplitudes. The pair is merged when the smallest normalized difference is
    below ``threshold_diff``.
    """

    name = "normalized_template_diff"

    @staticmethod
    def merge(
        label0,
        label1,
        labels_set,
        templates,
        original_labels,
        peaks,
        features,
        waveforms_sparse_mask=None,
        threshold_diff=0.05,
        min_cluster_size=50,
        num_shift=5,
    ):
        """
        Decide whether the clusters ``label0`` and ``label1`` should be merged.

        Parameters
        ----------
        label0, label1
            Labels of the two candidate clusters; both must be in ``labels_set``.
        labels_set
            Sequence of labels; the position of a label gives its index along
            the first axis of ``templates``.
        templates
            Array indexed as ``templates[label_index, sample, channel]``
            (assumed layout -- TODO confirm against the caller).
        original_labels
            Label of every peak.
        peaks
            Structured array with a ``"channel_index"`` field.
        features
            Unused by this method; kept for signature parity with the other
            merge methods.
        waveforms_sparse_mask
            Boolean array indexed as ``[peak_channel, channel]``. Required.
        threshold_diff
            Merge when the best normalized difference is strictly below it.
        min_cluster_size
            Currently unused (the size check is disabled).
        num_shift
            Maximum lag, in samples, explored on each side.

        Returns
        -------
        tuple
            ``(is_merge, label0, label1, final_shift, merge_value)``. When no
            merge is proposed ``final_shift`` is 0 and ``merge_value`` is NaN.
        """
        assert waveforms_sparse_mask is not None, "waveforms_sparse_mask is required"

        # Channels present in the sparsity mask of every peak channel of each cluster.
        (inds0,) = np.nonzero(original_labels == label0)
        chans0 = np.unique(peaks["channel_index"][inds0])
        target_chans0 = np.flatnonzero(np.all(waveforms_sparse_mask[chans0, :], axis=0))

        (inds1,) = np.nonzero(original_labels == label1)
        chans1 = np.unique(peaks["channel_index"][inds1])
        target_chans1 = np.flatnonzero(np.all(waveforms_sparse_mask[chans1, :], axis=0))

        target_chans = np.intersect1d(target_chans0, target_chans1)

        no_merge = (False, label0, label1, 0, np.nan)
        if target_chans.size == 0:
            # No shared channel: nothing to compare, so never merge.
            return no_merge

        labels_list = list(labels_set)
        # Index in two steps. ``templates[i, :, chans]`` combines a scalar and
        # an index array separated by a slice, which makes NumPy move the
        # channel axis first and yields (n_chans, n_samples) instead of the
        # (n_samples, n_chans) layout expected below.
        template0 = templates[labels_list.index(label0)][:, target_chans]
        template1 = templates[labels_list.index(label1)][:, target_chans]

        num_samples = template0.shape[0]
        norm = np.mean(np.abs(template0) + np.abs(template1))
        if not norm > 0:
            # Flat (or NaN) templates cannot be normalized.
            return no_merge

        # template0 stays fixed on its central window while template1 slides.
        temp0 = template0[num_shift : num_samples - num_shift, :]
        all_shift_diff = []
        for shift in range(-num_shift, num_shift + 1):
            temp1 = template1[num_shift + shift : num_samples - num_shift + shift, :]
            all_shift_diff.append(np.mean(np.abs(temp0 - temp1)) / norm)
        normed_diff = np.min(all_shift_diff)

        is_merge = normed_diff < threshold_diff
        if is_merge:
            merge_value = normed_diff
            final_shift = np.argmin(all_shift_diff) - num_shift
        else:
            final_shift = 0
            merge_value = np.nan

        # Developer-only visual check. Must stay disabled in committed code:
        # plt.show() blocks, and this function runs inside worker processes.
        DEBUG = False
        if DEBUG and normed_diff < 0.2:
            import matplotlib.pyplot as plt

            union_chans = np.union1d(target_chans0, target_chans1)

            fig, ax = plt.subplots()

            m0 = template0.flatten()
            m1 = template1.flatten()

            ax.plot(m0, color="C0", label=f"{label0} {inds0.size}")
            ax.plot(m1, color="C1", label=f"{label1} {inds1.size}")

            ax.set_title(
                f"union{union_chans.size} intersect{target_chans.size} \n {normed_diff:.3f} {final_shift} {is_merge}"
            )
            ax.legend()
            plt.show()

        return is_merge, label0, label1, final_shift, merge_value


# Registry of the available pair-merging strategies, plus a lookup table
# keyed by each class's ``name`` attribute.
find_pair_method_list = [ProjectDistribution, NormalizedTemplateDiff]
find_pair_method_dict = {method_class.name: method_class for method_class in find_pair_method_list}

0 comments on commit 728351a

Please sign in to comment.