Merge pull request #2811 from samuelgarcia/tridesclous2
Benchmark sorting components + Tridesclous2 improvement
samuelgarcia authored May 13, 2024
2 parents 3e9cff3 + c2dbfe2 commit 3864afd
Showing 23 changed files with 630 additions and 465 deletions.
3 changes: 3 additions & 0 deletions src/spikeinterface/comparison/groundtruthstudy.py
@@ -167,6 +167,8 @@ def remove_sorting(self, key):
sorting_folder = self.folder / "sortings" / self.key_to_str(key)
log_file = self.folder / "sortings" / "run_logs" / f"{self.key_to_str(key)}.json"
comparison_file = self.folder / "comparisons" / self.key_to_str(key)
self.sortings[key] = None
self.comparisons[key] = None
if sorting_folder.exists():
shutil.rmtree(sorting_folder)
for f in (log_file, comparison_file):
@@ -381,6 +383,7 @@ def get_performance_by_unit(self, case_keys=None):

perf_by_unit = pd.concat(perf_by_unit)
perf_by_unit = perf_by_unit.set_index(self.levels)
perf_by_unit = perf_by_unit.sort_index()
return perf_by_unit

def get_count_units(self, case_keys=None, well_detected_score=None, redundant_score=None, overmerged_score=None):
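The cache reset in remove_sorting() and the sorted index in get_performance_by_unit() show up directly in study workflows. A minimal sketch (not part of the diff, names illustrative), assuming an existing study folder and a valid case key:

from spikeinterface.comparison import GroundTruthStudy

study = GroundTruthStudy(study_folder)   # study_folder: path to an existing study
study.remove_sorting(case_key)           # now also resets the in-memory sortings/comparisons entries
perf = study.get_performance_by_unit()   # the returned DataFrame index is now sorted by case levels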
2 changes: 1 addition & 1 deletion src/spikeinterface/core/node_pipeline.py
@@ -255,7 +255,7 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin):
return (local_peaks,)


def sorting_to_peaks(sorting, extremum_channel_inds, dtype):
def sorting_to_peaks(sorting, extremum_channel_inds, dtype=spike_peak_dtype):
spikes = sorting.to_spike_vector()
peaks = np.zeros(spikes.size, dtype=dtype)
peaks["sample_index"] = spikes["sample_index"]
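With the new default, callers can drop the explicit dtype argument. A minimal sketch (assuming a Templates object and a matching BaseSorting; the variable names are illustrative):

from spikeinterface.core import get_template_extremum_channel
from spikeinterface.core.node_pipeline import sorting_to_peaks

extremum_channel_inds = get_template_extremum_channel(templates, outputs="index")
peaks = sorting_to_peaks(sorting, extremum_channel_inds)  # dtype now defaults to spike_peak_dtype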
3 changes: 2 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
@@ -330,7 +330,8 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scale
json.dump(check_json(info), f, indent=4)

# save a copy of the sorting
NumpyFolderSorting.write_sorting(sorting, folder / "sorting")
# NumpyFolderSorting.write_sorting(sorting, folder / "sorting")
sorting.save(folder=folder / "sorting")

# save recording and sorting provenance
if recording.check_serializability("json"):
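The switch to the generic save() keeps format dispatch in one place instead of calling the writer class directly. A sketch of the equivalent call, assuming sorting is any serializable BaseSorting:

# assumption: save(folder=...) writes the same numpy-folder layout that
# NumpyFolderSorting.write_sorting produced, and returns the saved copy
sorting_saved = sorting.save(folder=folder / "sorting")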
36 changes: 33 additions & 3 deletions src/spikeinterface/generation/drifting_generator.py
@@ -97,7 +97,7 @@ def make_one_displacement_vector(
start_drift_index = int(t_start_drift * displacement_sampling_frequency)
end_drift_index = int(t_end_drift * displacement_sampling_frequency)

num_samples = int(displacement_sampling_frequency * duration)
num_samples = int(np.ceil(displacement_sampling_frequency * duration))
displacement_vector = np.zeros(num_samples, dtype="float32")

if drift_mode == "zigzag":
@@ -286,6 +286,7 @@ def generate_drifting_recording(
),
generate_sorting_kwargs=dict(firing_rates=(2.0, 8.0), refractory_period_ms=4.0),
generate_noise_kwargs=dict(noise_levels=(12.0, 15.0), spatial_decay=25.0),
extra_outputs=False,
seed=None,
):
"""
Expand Down Expand Up @@ -314,6 +315,8 @@ def generate_drifting_recording(
Parameters given to generate_sorting().
generate_noise_kwargs: dict
Parameters given to generate_noise().
extra_outputs: bool, default False
If True, also return a dict with extra variables.
seed: None or int
A unique seed for all steps.
@@ -326,7 +329,14 @@
sorting: Sorting
The ground truth sorting object.
Same for both recordings.
extra_infos:
If extra_outputs=True, a dict with additional information is also returned, including:
* displacement_vectors
* displacement_sampling_frequency
* unit_locations
* displacement_unit_factor
* unit_displacements
This can be helpful for motion benchmarks.
"""

rng = np.random.default_rng(seed=seed)
@@ -356,6 +366,14 @@
generate_displacement_vector(duration, unit_locations[:, :2], seed=seed, **generate_displacement_vector_kwargs)
)

# unit_displacements is the sum of all displacements (times, units, direction_x_y)
unit_displacements = np.zeros((displacement_vectors.shape[0], num_units, 2))
for direction in (0, 1):
# x and y
for i in range(displacement_vectors.shape[2]):
m = displacement_vectors[:, direction, i][:, np.newaxis] * displacement_unit_factor[:, i][np.newaxis, :]
unit_displacements[:, :, direction] += m
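# equivalently (a vectorized form, assuming the shapes above):
# unit_displacements[:, :, direction] = displacement_vectors[:, direction, :] @ displacement_unit_factor.T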

# unit_params need to be fixed before the displacement steps
generate_templates_kwargs = generate_templates_kwargs.copy()
unit_params = _ensure_unit_params(generate_templates_kwargs.get("unit_params", {}), num_units, seed)
@@ -400,6 +418,8 @@
seed=seed,
)

sorting.set_property("gt_unit_locations", unit_locations)

## Important: precomputed displacements do not work on borders, and so do not work for tetrodes.
# here we bypass the interpolation and regenerate templates at several positions.
## drifting_templates.precompute_displacements(displacements_steps)
@@ -437,4 +457,14 @@
amplitude_factor=None,
)

return static_recording, drifting_recording, sorting
if extra_outputs:
extra_infos = dict(
displacement_vectors=displacement_vectors,
displacement_sampling_frequency=displacement_sampling_frequency,
unit_locations=unit_locations,
displacement_unit_factor=displacement_unit_factor,
unit_displacements=unit_displacements,
)
return static_recording, drifting_recording, sorting, extra_infos
else:
return static_recording, drifting_recording, sorting
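The new extra_outputs flag is what the motion benchmarks consume. A minimal usage sketch (assuming generate_drifting_recording is exported from spikeinterface.generation and the defaults above are acceptable):

from spikeinterface.generation import generate_drifting_recording

static_rec, drifting_rec, sorting, extra_infos = generate_drifting_recording(extra_outputs=True, seed=2205)
unit_displacements = extra_infos["unit_displacements"]  # shape (num_displacement_samples, num_units, 2)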
178 changes: 27 additions & 151 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -33,6 +33,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):

_default_params = {
"apply_preprocessing": True,
"apply_motion_correction": False,
"cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True},
"waveforms": {
"ms_before": 0.5,
@@ -52,10 +53,12 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"ms_before": 2.0,
"ms_after": 3.0,
"max_spikes_per_unit": 400,
"sparsity_threshold": 2.0,
# "peak_shift_ms": 0.2,
},
# "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}},
"matching": {"method": "circus-omp-svd", "method_kwargs": {}},
# "matching": {"method": "circus-omp-svd", "method_kwargs": {}},
"matching": {"method": "wobble", "method_kwargs": {}},
"job_kwargs": {"n_jobs": -1},
"save_array": True,
}
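The default matching engine switches here from "circus-omp-svd" to "wobble". A hedged sketch of restoring the previous behavior at call time, assuming run_sorter forwards nested sorter params as keyword arguments:

from spikeinterface.sorters import run_sorter

sorting = run_sorter(
    "tridesclous2", recording,
    matching={"method": "circus-omp-svd", "method_kwargs": {}},  # override the new "wobble" default
)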
@@ -102,6 +105,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
from spikeinterface.sortingcomponents.clustering.split import split_clusters
from spikeinterface.sortingcomponents.clustering.merge import merge_clusters
from spikeinterface.sortingcomponents.clustering.tools import compute_template_from_sparse
from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks
from spikeinterface.sortingcomponents.tools import remove_empty_templates

from sklearn.decomposition import TruncatedSVD

@@ -115,10 +120,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# preprocessing
if params["apply_preprocessing"]:
recording = bandpass_filter(recording_raw, **params["filtering"])
# TODO: which is better, zscore before common_reference or the reverse?
recording = common_reference(recording)
recording = zscore(recording, dtype="float32")
# recording = whiten(recording, dtype="float32")
recording = whiten(recording, dtype="float32")

# used only if "folder" or "zarr"
cache_folder = sorter_output_folder / "cache_preprocessing"
@@ -148,152 +152,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if verbose:
print("We kept %d peaks for clustering" % len(peaks))

ms_before = params["waveforms"]["ms_before"]
ms_after = params["waveforms"]["ms_after"]
clustering_kwargs = {}
clustering_kwargs["folder"] = sorter_output_folder
clustering_kwargs["waveforms"] = params["waveforms"].copy()
clustering_kwargs["clustering"] = params["clustering"].copy()

# SVD for time compression
few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000)
few_wfs = extract_waveform_at_max_channel(
recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
labels_set, post_clean_label, extra_out = find_cluster_from_peaks(
recording, peaks, method="tdc_clustering", method_kwargs=clustering_kwargs, extra_outputs=True, **job_kwargs
)

wfs = few_wfs[:, :, 0]
tsvd = TruncatedSVD(params["svd"]["n_components"])
tsvd.fit(wfs)

model_folder = sorter_output_folder / "tsvd_model"

model_folder.mkdir(exist_ok=True)
with open(model_folder / "pca_model.pkl", "wb") as f:
pickle.dump(tsvd, f)

model_params = {
"ms_before": ms_before,
"ms_after": ms_after,
"sampling_frequency": float(sampling_frequency),
}
with open(model_folder / "params.json", "w") as f:
json.dump(model_params, f)

# features

features_folder = sorter_output_folder / "features"
node0 = PeakRetriever(recording, peaks)

radius_um = params["waveforms"]["radius_um"]
node1 = ExtractSparseWaveforms(
recording,
parents=[node0],
return_output=True,
ms_before=ms_before,
ms_after=ms_after,
radius_um=radius_um,
)

model_folder_path = sorter_output_folder / "tsvd_model"

node2 = TemporalPCAProjection(
recording, parents=[node0, node1], return_output=True, model_folder_path=model_folder_path
)

pipeline_nodes = [node0, node1, node2]

output = run_node_pipeline(
recording,
pipeline_nodes,
job_kwargs,
gather_mode="npy",
gather_kwargs=dict(exist_ok=True),
folder=features_folder,
names=["sparse_wfs", "sparse_tsvd"],
)

# TODO make this generic in GatherNPY ???
sparse_mask = node1.neighbours_mask
np.save(features_folder / "sparse_mask.npy", sparse_mask)
np.save(features_folder / "peaks.npy", peaks)

# Clustering: channel index > split > merge
split_radius_um = params["clustering"]["split_radius_um"]
neighbours_mask = get_channel_distances(recording) < split_radius_um

original_labels = peaks["channel_index"]

min_cluster_size = 50

post_split_label, split_count = split_clusters(
original_labels,
recording,
features_folder,
method="local_feature_clustering",
method_kwargs=dict(
# clusterer="hdbscan",
clusterer="isocut5",
feature_name="sparse_tsvd",
# feature_name="sparse_wfs",
neighbours_mask=neighbours_mask,
waveforms_sparse_mask=sparse_mask,
min_size_split=min_cluster_size,
clusterer_kwargs={"min_cluster_size": min_cluster_size},
n_pca_features=3,
),
recursive=True,
recursive_depth=3,
returns_split_count=True,
**job_kwargs,
)

merge_radius_um = params["clustering"]["merge_radius_um"]
threshold_diff = params["clustering"]["threshold_diff"]

post_merge_label, peak_shifts = merge_clusters(
peaks,
post_split_label,
recording,
features_folder,
radius_um=merge_radius_um,
# method="project_distribution",
# method_kwargs=dict(
# waveforms_sparse_mask=sparse_mask,
# feature_name="sparse_wfs",
# projection="centroid",
# criteria="distrib_overlap",
# threshold_overlap=0.3,
# min_cluster_size=min_cluster_size + 1,
# num_shift=5,
# ),
method="normalized_template_diff",
method_kwargs=dict(
waveforms_sparse_mask=sparse_mask,
threshold_diff=threshold_diff,
min_cluster_size=min_cluster_size + 1,
num_shift=5,
),
**job_kwargs,
)

# sparse_wfs = np.load(features_folder / "sparse_wfs.npy", mmap_mode="r")

peak_shifts = extra_out["peak_shifts"]
new_peaks = peaks.copy()
new_peaks["sample_index"] -= peak_shifts

# clean very small cluster before peeler
post_clean_label = post_merge_label.copy()

minimum_cluster_size = 25
labels_set, count = np.unique(post_clean_label, return_counts=True)
to_remove = labels_set[count < minimum_cluster_size]
mask = np.isin(post_clean_label, to_remove)
post_clean_label[mask] = -1

# final label sets
labels_set = np.unique(post_clean_label)
labels_set = labels_set[labels_set >= 0]

mask = post_clean_label >= 0
sorting_pre_peeler = NumpySorting.from_times_labels(
new_peaks["sample_index"][mask],
post_merge_label[mask],
post_clean_label[mask],
sampling_frequency,
unit_ids=labels_set,
)
@@ -303,6 +177,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

nbefore = int(params["templates"]["ms_before"] * sampling_frequency / 1000.0)
nafter = int(params["templates"]["ms_after"] * sampling_frequency / 1000.0)
sparsity_threshold = params["templates"]["sparsity_threshold"]
templates_array = estimate_templates_with_accumulator(
recording_w,
sorting_pre_peeler.to_spike_vector(),
@@ -320,8 +195,9 @@
)
# TODO : try other methods for sparsity
# sparsity = compute_sparsity(templates_dense, method="radius", radius_um=120.)
sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=1.0)
sparsity = compute_sparsity(templates_dense, noise_levels=noise_levels, threshold=sparsity_threshold)
templates = templates_dense.to_sparse(sparsity)
templates = remove_empty_templates(templates)

# snrs = compute_snrs(we, peak_sign=params["detection"]["peak_sign"], peak_mode="extremum")
# print(snrs)
@@ -350,12 +226,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# )
# )

if matching_method == "circus-omp-svd":
job_kwargs = job_kwargs.copy()
for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
if value in job_kwargs:
job_kwargs.pop(value)
job_kwargs["chunk_duration"] = "100ms"
# if matching_method == "circus-omp-svd":
# job_kwargs = job_kwargs.copy()
# for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
# if value in job_kwargs:
# job_kwargs.pop(value)
# job_kwargs["chunk_duration"] = "100ms"

spikes = find_spikes_from_templates(
recording_w, method=matching_method, method_kwargs=matching_params, **job_kwargs
@@ -366,9 +242,9 @@

np.save(sorter_output_folder / "noise_levels.npy", noise_levels)
np.save(sorter_output_folder / "all_peaks.npy", all_peaks)
np.save(sorter_output_folder / "post_split_label.npy", post_split_label)
np.save(sorter_output_folder / "split_count.npy", split_count)
np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label)
# np.save(sorter_output_folder / "post_split_label.npy", post_split_label)
# np.save(sorter_output_folder / "split_count.npy", split_count)
# np.save(sorter_output_folder / "post_merge_label.npy", post_merge_label)
np.save(sorter_output_folder / "spikes.npy", spikes)

final_spikes = np.zeros(spikes.size, dtype=minimum_spike_dtype)
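The refactor above replaces the inline SVD/split/merge pipeline (roughly 150 lines) with the shared clustering entry point. A condensed sketch of the new call path, assuming recording, peaks, job_kwargs and sorter_output_folder are defined as in the sorter (the clustering values are illustrative):

from spikeinterface.sortingcomponents.clustering.main import find_cluster_from_peaks

clustering_kwargs = {
    "folder": sorter_output_folder,
    "waveforms": {"ms_before": 0.5, "ms_after": 1.5},
    "clustering": {},  # split/merge options, e.g. split_radius_um, merge_radius_um, threshold_diff
}
labels_set, post_clean_label, extra_out = find_cluster_from_peaks(
    recording, peaks,
    method="tdc_clustering",
    method_kwargs=clustering_kwargs,
    extra_outputs=True,
    **job_kwargs,
)
peak_shifts = extra_out["peak_shifts"]  # used to re-align peak sample indices before the peeler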
4 changes: 2 additions & 2 deletions src/spikeinterface/sorters/tests/test_launcher.py
@@ -176,11 +176,11 @@ def test_run_sorter_by_property():
setup_module()
job_list = get_job_list()

test_run_sorter_jobs_loop(job_list)
# test_run_sorter_jobs_loop(job_list)
# test_run_sorter_jobs_joblib(job_list)
# test_run_sorter_jobs_processpoolexecutor(job_list)
# test_run_sorter_jobs_multiprocessing(job_list)
# test_run_sorter_jobs_dask(job_list)
# test_run_sorter_jobs_slurm(job_list)

# test_run_sorter_by_property()
test_run_sorter_by_property()