Fixes for the clustering
yger committed Nov 15, 2024
1 parent 472022d commit d3bee9a
Showing 3 changed files with 44 additions and 20 deletions.
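In brief: this commit lowers SpykingCircus2's default whitening radius_um from 150 to 100, turns the legacy clustering path off by default, and hardens the two clustering components. Template means and standard deviations are now estimated in one pass with estimate_templates_with_accumulator, a new noise_threshold parameter (default 1) discards templates whose peak SNR sits too close to the channel noise floor, and peaks belonging to discarded or empty templates are relabeled -1.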
6 changes: 3 additions & 3 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -27,7 +27,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 50},
"sparsity": {"method": "snr", "amplitude_mode": "peak_to_peak", "threshold": 0.25},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
"whitening": {"mode": "local", "regularize": False, "radius_um": 150},
"whitening": {"mode": "local", "regularize": False, "radius_um": 100},
"detection": {"peak_sign": "neg", "detect_threshold": 4},
"selection": {
"method": "uniform",
@@ -46,7 +46,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"corr_diff_thresh": 0.25,
},
},
"clustering": {"legacy": True},
"clustering": {"legacy": False},
"matching": {"method": "circus-omp-svd"},
"apply_preprocessing": True,
"matched_filtering": True,
@@ -286,7 +286,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sparsity = compute_sparsity(templates, noise_levels, **params["sparsity"])
templates = templates.to_sparse(sparsity)
templates = remove_empty_templates(templates)

if params["debug"]:
templates.to_zarr(folder_path=clustering_folder / "templates")
sorting = sorting.save(folder=clustering_folder / "sorting")
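Both default changes above can be overridden per run, so the old behavior stays reachable. A minimal sketch, assuming spikeinterface's run_sorter entry point and its generate_ground_truth_recording helper; the toy recording and output folder name are illustrative only:

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sorters import run_sorter

# Toy recording just to make the sketch self-contained.
recording, _ = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=42)

# Restore the pre-commit defaults: 150 um whitening radius, legacy clustering.
sorting = run_sorter(
    "spykingcircus2",
    recording,
    folder="sc2_output",  # hypothetical output folder
    whitening={"mode": "local", "regularize": False, "radius_um": 150},
    clustering={"legacy": True},
)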
28 changes: 19 additions & 9 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -15,7 +15,7 @@
import random, string
from spikeinterface.core import get_global_tmp_folder
from spikeinterface.core.basesorting import minimum_spike_dtype
-from spikeinterface.core.waveform_tools import estimate_templates
+from spikeinterface.core.waveform_tools import estimate_templates, estimate_templates_with_accumulator
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
from spikeinterface.core.job_tools import fix_job_kwargs
@@ -60,6 +60,7 @@ class CircusClustering:
"n_svd": [5, 2],
"ms_before": 0.5,
"ms_after": 0.5,
"noise_threshold" : 1,
"rank": 5,
"noise_levels": None,
"tmp_folder": None,
@@ -226,36 +227,45 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)

-templates_array = estimate_templates(
-    recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs
+if params["noise_levels"] is None:
+    params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
+
+templates_array, templates_array_std = estimate_templates_with_accumulator(
+    recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, return_std=True, job_name=None, **job_kwargs
)

+peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
+valid_templates = np.linalg.norm(peak_snrs, axis=1) / np.linalg.norm(params["noise_levels"])
+valid_templates = valid_templates > params["noise_threshold"]

if d["rank"] is not None:
from spikeinterface.sortingcomponents.matching.circus import compress_templates

_, _, _, templates_array = compress_templates(templates_array, d["rank"])

templates = Templates(
-templates_array=templates_array,
+templates_array=templates_array[valid_templates],
sampling_frequency=fs,
nbefore=nbefore,
sparsity_mask=None,
channel_ids=recording.channel_ids,
-unit_ids=unit_ids,
+unit_ids=unit_ids[valid_templates],
probe=recording.get_probe(),
is_scaled=False,
)

if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False)


sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
templates = templates.to_sparse(sparsity)
empty_templates = templates.sparsity_mask.sum(axis=1) == 0
templates = remove_empty_templates(templates)

mask = np.isin(peak_labels, np.where(empty_templates)[0])
peak_labels[mask] = -1

+mask = np.isin(peak_labels, np.where(~valid_templates)[0])
+peak_labels[mask] = -1

if verbose:
print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))

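The core of the fix is the new validity test: a unit's template is kept only when the vector of its per-channel SNRs at the peak sample, normalized by the norm of the channel noise levels, exceeds noise_threshold. A self-contained numpy sketch of that arithmetic on made-up data; the shapes and the default threshold mirror the diff, everything else is illustrative:

import numpy as np

rng = np.random.default_rng(42)
n_units, n_samples, n_channels = 8, 60, 16
nbefore = 30                      # index of the peak sample, as in the diff
noise_levels = np.ones(n_channels)
noise_threshold = 1               # default added by this commit

# Made-up stand-ins for the accumulator outputs: mean and std waveforms per unit.
templates_array = rng.normal(0.0, 0.2, (n_units, n_samples, n_channels))
templates_array[:4, nbefore, :] -= 10.0   # give the first four units a strong trough
templates_array_std = np.ones((n_units, n_samples, n_channels))

# Per-channel SNR at the peak sample, reduced to one score per unit.
peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
scores = np.linalg.norm(peak_snrs, axis=1) / np.linalg.norm(noise_levels)
valid_templates = scores > noise_threshold

# Peaks that belonged to a discarded unit become unassigned (-1), assuming
# labels are positional unit indices, as in the diff.
peak_labels = rng.integers(0, n_units, size=200)
mask = np.isin(peak_labels, np.where(~valid_templates)[0])
peak_labels[mask] = -1

print(valid_templates)            # first four units True, last four False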
30 changes: 22 additions & 8 deletions src/spikeinterface/sortingcomponents/clustering/random_projections.py
@@ -14,7 +14,7 @@
HAVE_HDBSCAN = False

from spikeinterface.core.basesorting import minimum_spike_dtype
-from spikeinterface.core.waveform_tools import estimate_templates
+from spikeinterface.core.waveform_tools import estimate_templates, estimate_templates_with_accumulator
from .clustering_tools import remove_duplicates_via_matching
from spikeinterface.core.recording_tools import get_noise_levels, get_channel_distances
from spikeinterface.sortingcomponents.waveforms.savgol_denoiser import SavGolDenoiser
@@ -53,6 +53,7 @@ class RandomProjectionClustering:
"random_seed": 42,
"noise_levels": None,
"smoothing_kwargs": {"window_length_ms": 0.25},
"noise_threshold" : 1,
"tmp_folder": None,
"verbose": True,
}
@@ -129,25 +130,38 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
nbefore = int(params["waveforms"]["ms_before"] * fs / 1000.0)
nafter = int(params["waveforms"]["ms_after"] * fs / 1000.0)

-templates_array = estimate_templates(
-    recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, job_name=None, **job_kwargs
+if params["noise_levels"] is None:
+    params["noise_levels"] = get_noise_levels(recording, return_scaled=False, **job_kwargs)
+
+templates_array, templates_array_std = estimate_templates_with_accumulator(
+    recording, spikes, unit_ids, nbefore, nafter, return_scaled=False, return_std=True, job_name=None, **job_kwargs
)

+peak_snrs = np.abs(templates_array[:, nbefore, :]) / templates_array_std[:, nbefore, :]
+valid_templates = np.linalg.norm(peak_snrs, axis=1) / np.linalg.norm(params["noise_levels"])
+valid_templates = valid_templates > params["noise_threshold"]

templates = Templates(
-templates_array=templates_array,
+templates_array=templates_array[valid_templates],
sampling_frequency=fs,
nbefore=nbefore,
sparsity_mask=None,
channel_ids=recording.channel_ids,
-unit_ids=unit_ids,
+unit_ids=unit_ids[valid_templates],
probe=recording.get_probe(),
is_scaled=False,
)
if params["noise_levels"] is None:
params["noise_levels"] = get_noise_levels(recording, return_scaled=False)
sparsity = compute_sparsity(templates, params["noise_levels"], **params["sparsity"])

sparsity = compute_sparsity(templates, noise_levels=params["noise_levels"], **params["sparsity"])
templates = templates.to_sparse(sparsity)
+empty_templates = templates.sparsity_mask.sum(axis=1) == 0
templates = remove_empty_templates(templates)

+mask = np.isin(peak_labels, np.where(empty_templates)[0])
+peak_labels[mask] = -1

+mask = np.isin(peak_labels, np.where(~valid_templates)[0])
+peak_labels[mask] = -1

if verbose:
print("We found %d raw clusters, starting to clean with matching..." % (len(templates.unit_ids)))
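Since both components read noise_threshold from their params dict, it can also be tuned when the clustering step is run on its own. A hedged sketch, assuming spikeinterface's detect_peaks and find_cluster_from_peaks entry points; the method name and parameter come from this diff, the threshold value and toy data are illustrative, and keyword plumbing may differ between versions:

from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.clustering import find_cluster_from_peaks

# Toy data so the sketch runs end to end.
recording, _ = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=42)
peaks = detect_peaks(recording, method="locally_exclusive", detect_threshold=5)

# Raise the new SNR cut-off above its default of 1 (illustrative value).
labels, peak_labels = find_cluster_from_peaks(
    recording, peaks, method="random_projections", method_kwargs={"noise_threshold": 2}
)
# peak_labels == -1 marks peaks whose template was dropped as empty or too noisy.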
