Skip to content

Commit

Permalink
Merge pull request #2720 from zm711/ks-ms
Browse files Browse the repository at this point in the history
Port bug fix PRs from Main over to 0.100-bug-fixes
  • Loading branch information
alejoe91 authored Apr 25, 2024
2 parents 7c2656c + 7ad9bbb commit 449e219
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 27 deletions.
2 changes: 2 additions & 0 deletions src/spikeinterface/core/waveform_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,8 @@ def has_exceeding_spikes(recording, sorting):
if len(spike_vector_seg) > 0:
if spike_vector_seg["sample_index"][-1] > recording.get_num_samples(segment_index=segment_index) - 1:
return True
if spike_vector_seg["sample_index"][0] < 0:
return True
return False


Expand Down
6 changes: 4 additions & 2 deletions src/spikeinterface/curation/remove_excess_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def _custom_cache_spike_vector(self) -> None:
for segment_index in range(num_segments):
spike_vector = parent_spike_vector[segments_bounds[segment_index] : segments_bounds[segment_index + 1]]
end = np.searchsorted(spike_vector["sample_index"], self._num_samples[segment_index])
list_spike_vectors.append(spike_vector[:end])
start = np.searchsorted(spike_vector["sample_index"], 0)
list_spike_vectors.append(spike_vector[start:end])

spike_vector = np.concatenate(list_spike_vectors)
self._cached_spike_vector = spike_vector
Expand All @@ -78,8 +79,9 @@ def get_unit_spike_train(
) -> np.ndarray:
spike_train = self._parent_segment.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame)
max_spike = np.searchsorted(spike_train, self._num_samples, side="left")
min_spike = np.searchsorted(spike_train, 0, side="left")

return spike_train[:max_spike]
return spike_train[min_spike:max_spike]


def remove_excess_spikes(sorting, recording):
Expand Down
11 changes: 9 additions & 2 deletions src/spikeinterface/curation/tests/test_remove_excess_spikes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,23 @@ def test_remove_excess_spikes():
num_spikes = 100
num_num_samples_spikes_per_segment = 5
num_excess_spikes_per_segment = 5
num_neg_spike_times_per_segment = 2
times = []
labels = []
for segment_index in range(recording.get_num_segments()):
num_samples = recording.get_num_samples(segment_index=segment_index)
times_segment = np.array([], dtype=int)
labels_segment = np.array([], dtype=int)
for unit in range(num_units):
neg_spike_times = np.random.randint(-50, 0, num_neg_spike_times_per_segment)
spike_times = np.random.randint(0, num_samples, num_spikes)
last_samples_spikes = (num_samples - 1) * np.ones(num_num_samples_spikes_per_segment, dtype=int)
num_samples_spike_times = num_samples * np.ones(num_num_samples_spikes_per_segment, dtype=int)
excess_spikes = np.random.randint(num_samples, num_samples + 100, num_excess_spikes_per_segment)
spike_times = np.sort(
np.concatenate((spike_times, last_samples_spikes, num_samples_spike_times, excess_spikes))
np.concatenate(
(neg_spike_times, spike_times, last_samples_spikes, num_samples_spike_times, excess_spikes)
)
)
spike_labels = unit * np.ones_like(spike_times)
times_segment = np.concatenate((times_segment, spike_times))
Expand All @@ -47,7 +51,10 @@ def test_remove_excess_spikes():

assert (
len(spike_train_corrected)
== len(spike_train_excess) - num_num_samples_spikes_per_segment - num_excess_spikes_per_segment
== len(spike_train_excess)
- num_num_samples_spikes_per_segment
- num_excess_spikes_per_segment
- num_neg_spike_times_per_segment
)


Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class Kilosort4Sorter(BaseSorter):
"sig_interp": 20,
"nt0min": None,
"dmin": None,
"dminx": None,
"dminx": 32,
"min_template_size": 10,
"template_sizes": 5,
"nearest_chans": 10,
Expand Down Expand Up @@ -68,7 +68,7 @@ class Kilosort4Sorter(BaseSorter):
"sig_interp": "For drift correction, sigma for interpolation (spatial standard deviation). Approximate smoothness scale in units of microns. Default value: 20.",
"nt0min": "Sample index for aligning waveforms, so that their minimum or maximum value happens here. Default of 20. Default value: None.",
"dmin": "Vertical spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.",
"dminx": "Horizontal spacing of template centers used for spike detection, in microns. Determined automatically by default. Default value: None.",
"dminx": "Horizontal spacing of template centers used for spike detection, in microns. Default value: 32.",
"min_template_size": "Standard deviation of the smallest, spatial envelope Gaussian used for universal templates. Default value: 10.",
"template_sizes": "Number of sizes for universal spike templates (multiples of the min_template_size). Default value: 5.",
"nearest_chans": "Number of nearest channels to consider when finding local maxima during spike detection. Default value: 10.",
Expand Down
46 changes: 25 additions & 21 deletions src/spikeinterface/sorters/external/mountainsort5.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from pathlib import Path
from packaging.version import parse

from tempfile import TemporaryDirectory

import shutil
import numpy as np
from warnings import warn

from spikeinterface.preprocessing import bandpass_filter, whiten

from spikeinterface.core.baserecording import BaseRecording
from ..basesorter import BaseSorter
from ..basesorter import BaseSorter, get_job_kwargs

from spikeinterface.extractors import NpzSortingExtractor

Expand Down Expand Up @@ -43,8 +44,7 @@ class Mountainsort5Sorter(BaseSorter):
"freq_max": 6000,
"filter": True,
"whiten": True, # Important to do whitening
"temporary_base_dir": None,
"n_jobs_for_preprocessing": -1,
"delete_temporary_recording": True,
}

_params_description = {
Expand All @@ -68,8 +68,7 @@ class Mountainsort5Sorter(BaseSorter):
"freq_max": "Low-pass filter cutoff frequency",
"filter": "Enable or disable filter",
"whiten": "Enable or disable whitening",
"temporary_base_dir": "Temporary directory base directory for storing cached recording",
"n_jobs_for_preprocessing": "Number of parallel jobs for creating the cached recording",
"delete_temporary_recording": "If True, the temporary recording file is deleted after sorting (this may fail on Windows requiring the end-user to delete the file themselves later)",
}

sorter_description = "MountainSort5 uses Isosplit clustering. It is an updated version of MountainSort4. See https://doi.org/10.1016/j.neuron.2017.08.030"
Expand Down Expand Up @@ -186,21 +185,26 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
block_sorting_parameters=scheme2_sorting_parameters, block_duration_sec=p["scheme3_block_duration_sec"]
)

with TemporaryDirectory(dir=p["temporary_base_dir"]) as tmpdir:
# cache the recording to a temporary directory for efficient reading (so we don't have to re-filter)
recording_cached = create_cached_recording(
recording=recording, folder=tmpdir, n_jobs=p["n_jobs_for_preprocessing"]
)

scheme = p["scheme"]
if scheme == "1":
sorting = ms5.sorting_scheme1(recording=recording_cached, sorting_parameters=scheme1_sorting_parameters)
elif p["scheme"] == "2":
sorting = ms5.sorting_scheme2(recording=recording_cached, sorting_parameters=scheme2_sorting_parameters)
elif p["scheme"] == "3":
sorting = ms5.sorting_scheme3(recording=recording_cached, sorting_parameters=scheme3_sorting_parameters)
else:
raise ValueError(f"Invalid scheme: {scheme} given. scheme must be one of '1', '2' or '3'")
if not recording.is_binary_compatible():
recording_cached = recording.save(folder=sorter_output_folder / "recording", **get_job_kwargs(p, verbose))
else:
recording_cached = recording

if p["scheme"] == "1":
sorting = ms5.sorting_scheme1(recording=recording_cached, sorting_parameters=scheme1_sorting_parameters)
elif p["scheme"] == "2":
sorting = ms5.sorting_scheme2(recording=recording_cached, sorting_parameters=scheme2_sorting_parameters)
elif p["scheme"] == "3":
sorting = ms5.sorting_scheme3(recording=recording_cached, sorting_parameters=scheme3_sorting_parameters)
else:
raise ValueError(f"Invalid scheme: {p['scheme']} given. scheme must be one of '1', '2' or '3'")

if p["delete_temporary_recording"]:
if not recording.is_binary_compatible():
del recording_cached
shutil.rmtree(sorter_output_folder / "recording", ignore_errors=True)
if Path(sorter_output_folder / "recording").is_dir():
warn("cleanup failed, please remove file yourself if desired")

NpzSortingExtractor.write_sorting(sorting, str(sorter_output_folder / "firings.npz"))

Expand Down

0 comments on commit 449e219

Please sign in to comment.