
Commit

Merge pull request #2205 from yger/legacy_mode
Legacy mode for circus 2 to ease comparison with circus 1
samuelgarcia authored Nov 21, 2023
2 parents efc042e + 20cc70e commit 38d82d4
Showing 8 changed files with 283 additions and 337 deletions.
2 changes: 1 addition & 1 deletion src/spikeinterface/comparison/groundtruthstudy.py
@@ -21,7 +21,7 @@


# This is to separate names when the keys are tuples when saving folders
-_key_separator = " ## "
+_key_separator = "_##_"


class GroundTruthStudy:
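For context on the separator change: folder names built from tuple keys must round-trip through the filesystem, and the old " ## " value embedded spaces in folder names. Below is a minimal sketch of that round trip with hypothetical helper names (the actual GroundTruthStudy methods are not shown in this diff):

_key_separator = "_##_"

def key_to_folder_name(key):
    # A plain-string key maps to itself; a tuple key is joined with the separator.
    if isinstance(key, str):
        return key
    return _key_separator.join(str(part) for part in key)

def folder_name_to_key(name):
    # Inverse of key_to_folder_name; single-part names map back to plain strings.
    parts = name.split(_key_separator)
    return tuple(parts) if len(parts) > 1 else parts[0]

# Round trip: ("rec0", "sorterA") -> "rec0_##_sorterA" -> ("rec0", "sorterA")
assert folder_name_to_key(key_to_folder_name(("rec0", "sorterA"))) == ("rec0", "sorterA")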
35 changes: 28 additions & 7 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -21,12 +21,17 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):

    _default_params = {
        "general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
-        "waveforms": {"max_spikes_per_unit": 200, "overwrite": True, "sparse": True, "method": "ptp", "threshold": 1},
+        "waveforms": {
+            "max_spikes_per_unit": 200,
+            "overwrite": True,
+            "sparse": True,
+            "method": "energy",
+            "threshold": 0.25,
+        },
        "filtering": {"freq_min": 150, "dtype": "float32"},
-        "detection": {"peak_sign": "neg", "detect_threshold": 5},
+        "detection": {"peak_sign": "neg", "detect_threshold": 4},
        "selection": {"n_peaks_per_channel": 5000, "min_n_peaks": 20000},
        "localization": {},
-        "clustering": {},
+        "clustering": {"legacy": False},
        "matching": {},
        "apply_preprocessing": True,
        "shared_memory": True,
@@ -66,6 +71,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
            recording_f = common_reference(recording_f)
        else:
            recording_f = recording
+        recording_f.annotate(is_filtered=True)

        # recording_f = whiten(recording_f, dtype="float32")
        recording_f = zscore(recording_f, dtype="float32")
@@ -111,8 +117,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_params["tmp_folder"] = sorter_output_folder / "clustering"
clustering_params.update({"noise_levels": noise_levels})

if "legacy" in clustering_params:
legacy = clustering_params.pop("legacy")
else:
legacy = False

if legacy:
clustering_method = "circus"
else:
clustering_method = "random_projections"

labels, peak_labels = find_cluster_from_peaks(
recording_f, selected_peaks, method="random_projections", method_kwargs=clustering_params
recording_f, selected_peaks, method=clustering_method, method_kwargs=clustering_params
)

## We get the labels for our peaks
@@ -140,13 +156,18 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
        waveforms_folder = sorter_output_folder / "waveforms"

        we = extract_waveforms(
-            recording_f, sorting, waveforms_folder, mode=mode, **waveforms_params, return_scaled=False
+            recording_f,
+            sorting,
+            waveforms_folder,
+            return_scaled=False,
+            precompute_template=["median"],
+            mode=mode,
+            **waveforms_params,
        )

        ## We launch an OMP matching pursuit by full convolution of the templates and the raw traces
        matching_params = params["matching"].copy()
        matching_params["waveform_extractor"] = we
        matching_params.update({"noise_levels": noise_levels})

        matching_job_params = job_kwargs.copy()
        for value in ["chunk_size", "chunk_memory", "total_memory", "chunk_duration"]:
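Note on the new "legacy" flag above: it is consumed (popped) from clustering_params before that dict is forwarded as method_kwargs, so find_cluster_from_peaks never sees it. A minimal sketch of the dispatch in isolation; the pop-with-default form is equivalent to the if/else in the diff:

def select_clustering_method(clustering_params):
    # Consume the "legacy" flag and map it to a clustering method name.
    legacy = clustering_params.pop("legacy", False)
    return "circus" if legacy else "random_projections"

kwargs = {"legacy": True, "noise_levels": None}
assert select_clustering_method(kwargs) == "circus"
assert "legacy" not in kwargs  # flag removed before kwargs become method_kwargs
assert select_clustering_method({}) == "random_projections"  # default: new behavior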
49 changes: 9 additions & 40 deletions src/spikeinterface/sorters/internal/tridesclous2.py
@@ -9,12 +9,14 @@
    NumpySorting,
    get_channel_distances,
)
-from spikeinterface.core.waveform_tools import extract_waveforms_to_single_buffer

from spikeinterface.core.job_tools import fix_job_kwargs

from spikeinterface.preprocessing import bandpass_filter, common_reference, zscore
from spikeinterface.core.basesorting import minimum_spike_dtype

+from spikeinterface.sortingcomponents.tools import extract_waveform_at_max_channel

import numpy as np

import pickle
@@ -115,9 +117,14 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
        if verbose:
            print("We kept %d peaks for clustering" % len(peaks))

+        ms_before = params["waveforms"]["ms_before"]
+        ms_after = params["waveforms"]["ms_after"]
+
        # SVD for time compression
        few_peaks = select_peaks(peaks, method="uniform", n_peaks=5000)
-        few_wfs = extract_waveform_at_max_channel(recording, few_peaks, **job_kwargs)
+        few_wfs = extract_waveform_at_max_channel(
+            recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs
+        )

        wfs = few_wfs[:, :, 0]
        tsvd = TruncatedSVD(params["svd"]["n_components"])
@@ -129,8 +136,6 @@
        with open(model_folder / "pca_model.pkl", "wb") as f:
            pickle.dump(tsvd, f)

-        ms_before = params["waveforms"]["ms_before"]
-        ms_after = params["waveforms"]["ms_after"]
        model_params = {
            "ms_before": ms_before,
            "ms_after": ms_after,
@@ -319,39 +324,3 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
        sorting = sorting.save(folder=sorter_output_folder / "sorting")

        return sorting
-
-
-def extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **job_kwargs):
-    """
-    Helper function to extractor waveforms at max channel from a peak list
-    """
-    n = rec.get_num_channels()
-    unit_ids = np.arange(n, dtype="int64")
-    sparsity_mask = np.eye(n, dtype="bool")
-
-    spikes = np.zeros(
-        peaks.size, dtype=[("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
-    )
-    spikes["sample_index"] = peaks["sample_index"]
-    spikes["unit_index"] = peaks["channel_index"]
-    spikes["segment_index"] = peaks["segment_index"]
-
-    nbefore = int(ms_before * rec.sampling_frequency / 1000.0)
-    nafter = int(ms_after * rec.sampling_frequency / 1000.0)
-
-    all_wfs = extract_waveforms_to_single_buffer(
-        rec,
-        spikes,
-        unit_ids,
-        nbefore,
-        nafter,
-        mode="shared_memory",
-        return_scaled=False,
-        sparsity_mask=sparsity_mask,
-        copy=True,
-        **job_kwargs,
-    )
-
-    return all_wfs
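The deleted helper now lives in spikeinterface.sortingcomponents.tools (see the new import above), and the caller passes ms_before/ms_after from params["waveforms"] explicitly instead of relying on the helper's 0.5/1.5 ms defaults. For reference, a standalone restatement of the millisecond-to-sample conversion the helper performs:

def window_in_samples(ms_before, ms_after, sampling_frequency):
    # Same conversion as nbefore/nafter in the removed helper above.
    nbefore = int(ms_before * sampling_frequency / 1000.0)
    nafter = int(ms_after * sampling_frequency / 1000.0)
    return nbefore, nafter

# At 30 kHz, the old default window (0.5 ms, 1.5 ms) spans (15, 45) samples,
# i.e. 60 samples per extracted waveform.
assert window_in_samples(0.5, 1.5, 30000.0) == (15, 45)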
