Merge branch 'main' into deepinterp
alejoe91 authored Oct 19, 2023
2 parents d4e0824 + a733fd0 commit 3fe5d07
Showing 26 changed files with 244 additions and 276 deletions.
22 changes: 17 additions & 5 deletions src/spikeinterface/curation/curation_tools.py
@@ -1,3 +1,4 @@
from __future__ import annotations
from typing import Optional
import numpy as np

@@ -9,9 +10,15 @@
except ModuleNotFoundError as err:
HAVE_NUMBA = False

_methods = ("keep_first", "random", "keep_last", "keep_first_iterative", "keep_last_iterative")
_methods_numpy = ("keep_first", "random", "keep_last")


def _find_duplicated_spikes_numpy(
spike_train: np.ndarray, censored_period: int, seed: Optional[int] = None, method: str = "keep_first"
spike_train: np.ndarray,
censored_period: int,
seed: Optional[int] = None,
method: "keep_first" | "random" | "keep_last" = "keep_first",
) -> np.ndarray:
(indices_of_duplicates,) = np.where(np.diff(spike_train) <= censored_period)

@@ -29,7 +36,9 @@ def _find_duplicated_spikes_numpy(

(indices_of_duplicates,) = np.where(~mask)
elif method != "keep_last":
raise ValueError(f"Method '{method}' isn't a valid method for _find_duplicated_spikes_numpy.")
raise ValueError(
f"Method '{method}' isn't a valid method for _find_duplicated_spikes_numpy use one of {_methods_numpy}."
)

return indices_of_duplicates

@@ -84,7 +93,10 @@ def _find_duplicated_spikes_keep_last_iterative(spike_train, censored_period):


def find_duplicated_spikes(
spike_train, censored_period: int, method: str = "random", seed: Optional[int] = None
spike_train,
censored_period: int,
method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random" = "random",
seed: Optional[int] = None,
) -> np.ndarray:
"""
Finds the indices where spikes should be considered duplicates.
@@ -97,7 +109,7 @@ def find_duplicated_spikes(
The spike train on which to look for duplicated spikes.
censored_period: int
The censored period for duplicates (in sample time).
method: str in ("keep_first", "keep_last", "keep_first_iterative', 'keep_last_iterative", random")
method: "keep_first" |"keep_last" | "keep_first_iterative' | 'keep_last_iterative" |random"
Method used to remove the duplicated spikes.
seed: int | None
The seed to use if method="random".
@@ -120,4 +132,4 @@
assert HAVE_NUMBA, "'keep_last_iterative' method requires numba. Install it with >>> pip install numba"
return _find_duplicated_spikes_keep_last_iterative(spike_train.astype(np.int64), censored_period)
else:
raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes.")
raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes. Use one of {_methods}")
12 changes: 6 additions & 6 deletions src/spikeinterface/curation/curationsorting.py
@@ -148,24 +148,24 @@ def remove_empty_units(self):
edges = None
self._add_new_stage(new_sorting, edges)

def redo_avaiable(self):
def redo_available(self):
# useful function for a gui
return self._sorting_stages_i < len(self._sorting_stages)

def undo_avaiable(self):
def undo_available(self):
# useful function for a gui
return self._sorting_stages_i > 0

def undo(self):
if self.undo_avaiable():
if self.undo_available():
self._sorting_stages_i -= 1

def redo(self):
if self.redo_avaiable():
if self.redo_available():
self._sorting_stages_i += 1

def draw_graph(self, **kwargs):
assert self._make_graph, "to make a graph make_graph=True"
assert self._make_graph, "to make a graph use make_graph=True"
graph = self.graph
ids = [c.unit_id for c in graph.nodes]
pos = {n: (n.stage_id, -ids.index(n.unit_id)) for n in graph.nodes}
@@ -174,7 +174,7 @@

@property
def graph(self):
assert self._make_graph, "to have a graph make_graph=True"
assert self._make_graph, "to have a graph use make_graph=True"
return self._graphs[self._sorting_stages_i]

@property
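A short sketch of the renamed undo/redo helpers in use (hedged: `sorting` is a placeholder for any parent BaseSorting):

from spikeinterface.curation import CurationSorting

cs = CurationSorting(sorting)  # `sorting` is an existing BaseSorting (placeholder)
cs.remove_empty_units()
if cs.undo_available():  # previously misspelled as undo_avaiable()
    cs.undo()  # step back one curation stage
if cs.redo_available():  # previously misspelled as redo_avaiable()
    cs.redo()  # re-apply the undone stage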
12 changes: 8 additions & 4 deletions src/spikeinterface/curation/remove_redundant.py
@@ -1,3 +1,4 @@
from __future__ import annotations
import numpy as np

from spikeinterface import WaveformExtractor
@@ -6,6 +7,9 @@
from ..postprocessing import align_sorting


_remove_strategies = ("minimum_shift", "highest_amplitude", "max_spikes")


def remove_redundant_units(
sorting_or_waveform_extractor,
align=True,
@@ -42,15 +46,15 @@ def remove_redundant_units(
duplicate_threshold : float, optional
Final threshold on the portion of coincident events over the number of spikes above which the
unit is removed, by default 0.8
remove_strategy: str
remove_strategy: 'minimum_shift' | 'highest_amplitude' | 'max_spikes', default: 'minimum_shift'
Which strategy to remove one of the two duplicated units:
* 'minimum_shift': keep the unit with best peak alignment (minimum shift)
If shifts are equal then the 'highest_amplitude' is used
* 'highest_amplitude': keep the unit with the best amplitude on unshifted max.
* 'max_spikes': keep the unit with more spikes
peak_sign: str ('neg', 'pos', 'both')
peak_sign: 'neg' | 'pos' | 'both', default: 'neg'
Used when remove_strategy='highest_amplitude'
extra_outputs: bool
If True, will return the redundant pairs.
@@ -93,7 +97,7 @@ def remove_redundant_units(
peak_values = {unit_id: np.max(np.abs(values)) for unit_id, values in peak_values.items()}

if remove_strategy == "minimum_shift":
assert align, "remove_strategy with minimum_shift need align=True"
assert align, "remove_strategy with minimum_shift needs align=True"
for u1, u2 in redundant_unit_pairs:
if np.abs(unit_peak_shifts[u1]) > np.abs(unit_peak_shifts[u2]):
remove_unit_ids.append(u1)
@@ -125,7 +129,7 @@
# this will be implemented in a future PR by the first person who needs it!
raise NotImplementedError()
else:
raise ValueError(f"remove_strategy : {remove_strategy} is not implemented!")
raise ValueError(f"remove_strategy : {remove_strategy} is not implemented! Options are {_remove_strategies}")

sorting_clean = sorting.remove_units(remove_unit_ids)

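A hedged usage sketch of the documented strategies (`we` stands in for an existing WaveformExtractor):

from spikeinterface.curation import remove_redundant_units

# "minimum_shift" requires align=True (the default), per the assert above
clean_sorting = remove_redundant_units(
    we,
    duplicate_threshold=0.8,
    remove_strategy="minimum_shift",  # or "highest_amplitude" / "max_spikes"
    peak_sign="neg",  # only consulted when remove_strategy="highest_amplitude"
)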
11 changes: 5 additions & 6 deletions src/spikeinterface/curation/splitunitsorting.py
@@ -21,11 +21,10 @@ class SplitUnitSorting(BaseSorting):
be the same length as the spike train (for each segment)
new_unit_ids: int
Unit ids of the new units to be created.
properties_policy: str
properties_policy: 'keep' | 'remove', default: 'keep'
Policy used to propagate properties. If 'keep' the properties will be passed to the new units
(if the units_to_merge have the same value). If 'remove' the new units will have an empty
value for all the properties of the new unit.
Default: 'keep'
Returns
-------
sorting: Sorting
@@ -48,19 +47,19 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non
new_unit_ids = np.array([u + new_unit_ids for u in range(tot_splits)], dtype=parents_unit_ids.dtype)
else:
new_unit_ids = np.array(new_unit_ids, dtype=parents_unit_ids.dtype)
assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids should be unique"
assert len(new_unit_ids) <= tot_splits, "indices_list have more ids indices than the length of new_unit_ids"
assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids must be unique"
assert len(new_unit_ids) <= tot_splits, "indices_list has more indices than the length of new_unit_ids"

assert parent_sorting.get_num_segments() == len(
indices_list
), "The length of indices_list must be the same as parent_sorting.get_num_segments"
assert split_unit_id in parents_unit_ids, "Unit to split should be in parent sorting"
assert split_unit_id in parents_unit_ids, "Unit to split must be in parent sorting"
assert properties_policy == "keep" or properties_policy == "remove", (
"properties_policy must be " "keep" " or " "remove" ""
)
assert not any(
np.isin(new_unit_ids, unchanged_units)
), "new_unit_ids should be new units or one could be equal to split_unit_id"
), "new_unit_ids should be new unit ids or no more than one unit id can be found in split_unit_id"

sampling_frequency = parent_sorting.get_sampling_frequency()
units_ids = np.concatenate([unchanged_units, new_unit_ids])
4 changes: 3 additions & 1 deletion src/spikeinterface/exporters/to_phy.py
@@ -78,7 +78,9 @@ def export_to_phy(
), "waveform_extractor must be a WaveformExtractor object"
sorting = waveform_extractor.sorting

assert waveform_extractor.get_num_segments() == 1, "Export to phy only works with one segment"
assert (
waveform_extractor.get_num_segments() == 1
), f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments"
num_chans = waveform_extractor.get_num_channels()
fs = waveform_extractor.sampling_frequency

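For context, a minimal sketch of the single-segment requirement (hedged: `waveform_extractor` and the output folder name are placeholders):

from spikeinterface.exporters import export_to_phy

# The assert in export_to_phy now reports the actual segment count on failure
assert waveform_extractor.get_num_segments() == 1
export_to_phy(waveform_extractor, output_folder="phy_output")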
10 changes: 1 addition & 9 deletions src/spikeinterface/extractors/neoextractors/neuroscope.py
@@ -9,12 +9,6 @@

from .neobaseextractor import NeoBaseRecordingExtractor

try:
from lxml import etree as et

HAVE_LXML = True
except ImportError:
HAVE_LXML = False

PathType = Union[str, Path]
OptionalPathType = Optional[PathType]
@@ -108,8 +102,6 @@ class NeuroScopeSortingExtractor(BaseSorting):
"""

extractor_name = "NeuroscopeSortingExtractor"
installed = HAVE_LXML
installation_mesg = "Please install lxml to use this extractor!"
name = "neuroscope"

def __init__(
@@ -121,7 +113,7 @@ def __init__(
exclude_shanks: Optional[list] = None,
xml_file_path: OptionalPathType = None,
):
assert self.installed, self.installation_mesg
from lxml import etree as et

assert not (
folder_path is None and resfile_path is None and clufile_path is None
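With the module-level try/except gone, a missing lxml now surfaces when the extractor is constructed; a sketch under that assumption (the path is a placeholder):

from spikeinterface.extractors import NeuroScopeSortingExtractor

try:
    sorting = NeuroScopeSortingExtractor(folder_path="path/to/session")
except ImportError:
    print("lxml is required for this extractor: pip install lxml")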
2 changes: 2 additions & 0 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -3,6 +3,8 @@
https://github.com/AllenInstitute/ecephys_spike_sorting/blob/master/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py
22/04/2020
"""
from __future__ import annotations

import numpy as np
import warnings
from typing import Optional
6 changes: 3 additions & 3 deletions src/spikeinterface/sorters/basesorter.py
@@ -103,7 +103,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
)

if not isinstance(recording, BaseRecordingSnippets):
raise ValueError("recording must be a Recording or Snippets!!")
raise ValueError("recording must be a Recording or a Snippets!!")

if cls.requires_locations:
locations = recording.get_channel_locations()
@@ -133,7 +133,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
if recording.get_num_segments() > 1:
if not cls.handle_multi_segment:
raise ValueError(
f"This sorter {cls.sorter_name} do not handle multi segment, use si.concatenate_recordings(...)"
f"This sorter {cls.sorter_name} does not handle multi-segment recordings, use si.concatenate_recordings(...)"
)

rec_file = output_folder / "spikeinterface_recording.json"
@@ -299,7 +299,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_
# check errors in log file
log_file = output_folder / "spikeinterface_log.json"
if not log_file.is_file():
raise SpikeSortingError("get result error: the folder does not contain the `spikeinterface_log.json` file")
raise SpikeSortingError("Get result error: the folder does not contain the `spikeinterface_log.json` file")

with log_file.open("r", encoding="utf8") as f:
log = json.load(f)
6 changes: 4 additions & 2 deletions src/spikeinterface/sorters/external/kilosort.py
@@ -42,7 +42,7 @@ class KilosortSorter(KilosortBase, BaseSorter):
"Nfilt": None,
"NT": None,
"wave_length": 61,
"delete_tmp_files": True,
"delete_tmp_files": ("matlab_files",),
"delete_recording_dat": False,
}

@@ -56,7 +56,9 @@ class KilosortSorter(KilosortBase, BaseSorter):
"Nfilt": "Number of clusters to use (if None it is automatically computed)",
"NT": "Batch size (if None it is automatically computed)",
"wave_length": "size of the waveform extracted around each detected peak, (Default 61, maximum 81)",
"delete_tmp_files": "Whether to delete all temporary files after a successful run",
"delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that "
"contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) "
"or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')",
"delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run",
}

6 changes: 4 additions & 2 deletions src/spikeinterface/sorters/external/kilosort2.py
@@ -50,7 +50,7 @@ class Kilosort2Sorter(KilosortBase, BaseSorter):
"skip_kilosort_preprocessing": False,
"scaleproc": None,
"save_rez_to_mat": False,
"delete_tmp_files": True,
"delete_tmp_files": ("matlab_files",),
"delete_recording_dat": False,
}

@@ -73,7 +73,9 @@ class Kilosort2Sorter(KilosortBase, BaseSorter):
"skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing",
"scaleproc": "int16 scaling of whitened data, if None set to 200.",
"save_rez_to_mat": "Save the full rez internal struc to mat file",
"delete_tmp_files": "Whether to delete all temporary files after a successful run",
"delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that "
"contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) "
"or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')",
"delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run",
}

6 changes: 4 additions & 2 deletions src/spikeinterface/sorters/external/kilosort2_5.py
@@ -57,7 +57,7 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter):
"skip_kilosort_preprocessing": False,
"scaleproc": None,
"save_rez_to_mat": False,
"delete_tmp_files": True,
"delete_tmp_files": ("matlab_files",),
"delete_recording_dat": False,
}

@@ -83,7 +83,9 @@ class Kilosort2_5Sorter(KilosortBase, BaseSorter):
"skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing",
"scaleproc": "int16 scaling of whitened data, if None set to 200.",
"save_rez_to_mat": "Save the full rez internal struc to mat file",
"delete_tmp_files": "Whether to delete all temporary files after a successful run",
"delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that "
"contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) "
"or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files') ",
"delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run",
}

6 changes: 4 additions & 2 deletions src/spikeinterface/sorters/external/kilosort3.py
@@ -54,7 +54,7 @@ class Kilosort3Sorter(KilosortBase, BaseSorter):
"skip_kilosort_preprocessing": False,
"scaleproc": None,
"save_rez_to_mat": False,
"delete_tmp_files": True,
"delete_tmp_files": ("matlab_files",),
"delete_recording_dat": False,
}

@@ -80,7 +80,9 @@ class Kilosort3Sorter(KilosortBase, BaseSorter):
"skip_kilosort_preprocessing": "Can optionaly skip the internal kilosort preprocessing",
"scaleproc": "int16 scaling of whitened data, if None set to 200.",
"save_rez_to_mat": "Save the full rez internal struc to mat file",
"delete_tmp_files": "Whether to delete all temporary files after a successful run",
"delete_tmp_files": "Delete temporary files created during sorting (matlab files and the `temp_wh.dat` file that "
"contains kilosort-preprocessed data). Accepts `False` (deletes no files), `True` (deletes all files) "
"or a Tuple containing the files to delete. Options are: ('temp_wh.dat', 'matlab_files')",
"delete_recording_dat": "Whether to delete the 'recording.dat' file after a successful run",
}

33 changes: 26 additions & 7 deletions src/spikeinterface/sorters/external/kilosortbase.py
@@ -215,16 +215,35 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
raise Exception(f"{cls.sorter_name} returned a non-zero exit code")

# Clean-up temporary files
if params["delete_tmp_files"]:
for temp_file in sorter_output_folder.glob("*.m"):
temp_file.unlink()
for temp_file in sorter_output_folder.glob("*.mat"):
temp_file.unlink()
if (sorter_output_folder / "temp_wh.dat").exists():
(sorter_output_folder / "temp_wh.dat").unlink()
if params["delete_recording_dat"] and (recording_file := sorter_output_folder / "recording.dat").exists():
recording_file.unlink()

all_tmp_files = ("matlab_files", "temp_wh.dat")

if isinstance(params["delete_tmp_files"], bool):
if params["delete_tmp_files"]:
tmp_files_to_remove = all_tmp_files
else:
tmp_files_to_remove = ()
else:
assert isinstance(
params["delete_tmp_files"], (tuple, list)
), "`delete_tmp_files` must be a `Bool`, `Tuple` or `List`."

for name in params["delete_tmp_files"]:
assert name in all_tmp_files, f"{name} is not a valid option, must be one of: {all_tmp_files}"

tmp_files_to_remove = params["delete_tmp_files"]

if "temp_wh.dat" in tmp_files_to_remove:
if (temp_wh_file := sorter_output_folder / "temp_wh.dat").exists():
temp_wh_file.unlink()

if "matlab_files" in tmp_files_to_remove:
for ext in ["*.m", "*.mat"]:
for temp_file in sorter_output_folder.glob(ext):
temp_file.unlink()

@classmethod
def _get_result_from_folder(cls, sorter_output_folder):
sorter_output_folder = Path(sorter_output_folder)
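A hedged sketch of passing the reworked parameter through run_sorter (`recording` is a placeholder, and the sorter itself must be installed):

from spikeinterface.sorters import run_sorter

sorting = run_sorter(
    "kilosort2_5",
    recording,  # an existing BaseRecording (placeholder)
    delete_tmp_files=("matlab_files",),  # False, True, or a tuple from ("temp_wh.dat", "matlab_files")
    delete_recording_dat=False,
)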