From 285343af338c0337fcc17312632c94c8179d8a14 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 17 Oct 2023 08:55:27 -0400 Subject: [PATCH 1/4] add additional assert info --- src/spikeinterface/curation/curation_tools.py | 22 ++++++++++++++----- .../curation/curationsorting.py | 12 +++++----- .../curation/remove_redundant.py | 12 ++++++---- .../curation/splitunitsorting.py | 11 +++++----- 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/curation/curation_tools.py b/src/spikeinterface/curation/curation_tools.py index 38ff1f62c5..ddf7d4dc9d 100644 --- a/src/spikeinterface/curation/curation_tools.py +++ b/src/spikeinterface/curation/curation_tools.py @@ -1,3 +1,4 @@ +from __future__ import annotations from typing import Optional import numpy as np @@ -9,9 +10,15 @@ except ModuleNotFoundError as err: HAVE_NUMBA = False +_methods = ("keep_first", "random", "keep_last", "keep_first_iterative", "keep_last_iterative") +_methods_numpy = ("keep_first", "random", "keep_last") + def _find_duplicated_spikes_numpy( - spike_train: np.ndarray, censored_period: int, seed: Optional[int] = None, method: str = "keep_first" + spike_train: np.ndarray, + censored_period: int, + seed: Optional[int] = None, + method: "keep_first" | "random" | "keep_last" = "keep_first", ) -> np.ndarray: (indices_of_duplicates,) = np.where(np.diff(spike_train) <= censored_period) @@ -29,7 +36,9 @@ def _find_duplicated_spikes_numpy( (indices_of_duplicates,) = np.where(~mask) elif method != "keep_last": - raise ValueError(f"Method '{method}' isn't a valid method for _find_duplicated_spikes_numpy.") + raise ValueError( + f"Method '{method}' isn't a valid method for _find_duplicated_spikes_numpy use one of {_methods_numpy}." 
+ ) return indices_of_duplicates @@ -84,7 +93,10 @@ def _find_duplicated_spikes_keep_last_iterative(spike_train, censored_period): def find_duplicated_spikes( - spike_train, censored_period: int, method: str = "random", seed: Optional[int] = None + spike_train, + censored_period: int, + method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random" = "random", + seed: Optional[int] = None, ) -> np.ndarray: """ Finds the indices where spikes should be considered duplicates. @@ -97,7 +109,7 @@ def find_duplicated_spikes( The spike train on which to look for duplicated spikes. censored_period: int The censored period for duplicates (in sample time). - method: str in ("keep_first", "keep_last", "keep_first_iterative', 'keep_last_iterative", random") + method: "keep_first" | "keep_last" | "keep_first_iterative" | "keep_last_iterative" | "random" Method used to remove the duplicated spikes. seed: int | None The seed to use if method="random". @@ -120,4 +132,4 @@ def find_duplicated_spikes( assert HAVE_NUMBA, "'keep_last' method requires numba. Install it with >>> pip install numba" return _find_duplicated_spikes_keep_last_iterative(spike_train.astype(np.int64), censored_period) else: - raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes.") + raise ValueError(f"Method '{method}' isn't a valid method for find_duplicated_spikes. 
Use one of {_methods}") diff --git a/src/spikeinterface/curation/curationsorting.py b/src/spikeinterface/curation/curationsorting.py index f2776bafe6..bdb33e9eb1 100644 --- a/src/spikeinterface/curation/curationsorting.py +++ b/src/spikeinterface/curation/curationsorting.py @@ -148,24 +148,24 @@ def remove_empty_units(self): edges = None self._add_new_stage(new_sorting, edges) - def redo_avaiable(self): + def redo_available(self): # useful function for a gui return self._sorting_stages_i < len(self._sorting_stages) - def undo_avaiable(self): + def undo_available(self): # useful function for a gui return self._sorting_stages_i > 0 def undo(self): - if self.undo_avaiable(): + if self.undo_available(): self._sorting_stages_i -= 1 def redo(self): - if self.redo_avaiable(): + if self.redo_available(): self._sorting_stages_i += 1 def draw_graph(self, **kwargs): - assert self._make_graph, "to make a graph make_graph=True" + assert self._make_graph, "to make a graph use make_graph=True" graph = self.graph ids = [c.unit_id for c in graph.nodes] pos = {n: (n.stage_id, -ids.index(n.unit_id)) for n in graph.nodes} @@ -174,7 +174,7 @@ def draw_graph(self, **kwargs): @property def graph(self): - assert self._make_graph, "to have a graph make_graph=True" + assert self._make_graph, "to have a graph use make_graph=True" return self._graphs[self._sorting_stages_i] @property diff --git a/src/spikeinterface/curation/remove_redundant.py b/src/spikeinterface/curation/remove_redundant.py index c2617d5b52..e13f83550a 100644 --- a/src/spikeinterface/curation/remove_redundant.py +++ b/src/spikeinterface/curation/remove_redundant.py @@ -1,3 +1,4 @@ +from __future__ import annotations import numpy as np from spikeinterface import WaveformExtractor @@ -6,6 +7,9 @@ from ..postprocessing import align_sorting +_remove_strategies = ("minimum_shift", "highest_amplitude", "max_spikes") + + def remove_redundant_units( sorting_or_waveform_extractor, align=True, @@ -42,7 +46,7 @@ def 
remove_redundant_units( duplicate_threshold : float, optional Final threshold on the portion of coincident events over the number of spikes above which the unit is removed, by default 0.8 - remove_strategy: str + remove_strategy: 'minimum_shift' | 'highest_amplitude' | 'max_spikes', default: 'minimum_shift' Which strategy to remove one of the two duplicated units: * 'minimum_shift': keep the unit with best peak alignment (minimum shift) @@ -50,7 +54,7 @@ def remove_redundant_units( * 'highest_amplitude': keep the unit with the best amplitude on unshifted max. * 'max_spikes': keep the unit with more spikes - peak_sign: str ('neg', 'pos', 'both') + peak_sign: 'neg' | 'pos' | 'both', default: 'neg' Used when remove_strategy='highest_amplitude' extra_outputs: bool If True, will return the redundant pairs. @@ -93,7 +97,7 @@ def remove_redundant_units( peak_values = {unit_id: np.max(np.abs(values)) for unit_id, values in peak_values.items()} if remove_strategy == "minimum_shift": - assert align, "remove_strategy with minimum_shift need align=True" + assert align, "remove_strategy with minimum_shift needs align=True" for u1, u2 in redundant_unit_pairs: if np.abs(unit_peak_shifts[u1]) > np.abs(unit_peak_shifts[u2]): remove_unit_ids.append(u1) @@ -125,7 +129,7 @@ def remove_redundant_units( # this will be implemented in a futur PR by the first who need it! raise NotImplementedError() else: - raise ValueError(f"remove_strategy : {remove_strategy} is not implemented!") + raise ValueError(f"remove_strategy : {remove_strategy} is not implemented! 
Options are {_remove_strategies}") sorting_clean = sorting.remove_units(remove_unit_ids) diff --git a/src/spikeinterface/curation/splitunitsorting.py b/src/spikeinterface/curation/splitunitsorting.py index 816d62cf9f..23863a85e5 100644 --- a/src/spikeinterface/curation/splitunitsorting.py +++ b/src/spikeinterface/curation/splitunitsorting.py @@ -21,11 +21,10 @@ class SplitUnitSorting(BaseSorting): be the same length as the spike train (for each segment) new_unit_ids: int Unit ids of the new units to be created. - properties_policy: str + properties_policy: 'keep' | 'remove', default: 'keep' Policy used to propagate properties. If 'keep' the properties will be passed to the new units (if the units_to_merge have the same value). If 'remove' the new units will have an empty value for all the properties of the new unit. - Default: 'keep' Returns ------- sorting: Sorting @@ -48,19 +47,19 @@ def __init__(self, parent_sorting, split_unit_id, indices_list, new_unit_ids=Non new_unit_ids = np.array([u + new_unit_ids for u in range(tot_splits)], dtype=parents_unit_ids.dtype) else: new_unit_ids = np.array(new_unit_ids, dtype=parents_unit_ids.dtype) - assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids should be unique" - assert len(new_unit_ids) <= tot_splits, "indices_list have more ids indices than the length of new_unit_ids" + assert len(np.unique(new_unit_ids)) == len(new_unit_ids), "Each element in new_unit_ids must be unique" + assert len(new_unit_ids) <= tot_splits, "indices_list has more id indices than the length of new_unit_ids" assert parent_sorting.get_num_segments() == len( indices_list ), "The length of indices_list must be the same as parent_sorting.get_num_segments" - assert split_unit_id in parents_unit_ids, "Unit to split should be in parent sorting" + assert split_unit_id in parents_unit_ids, "Unit to split must be in parent sorting" assert properties_policy == "keep" or properties_policy == "remove", ( "properties_policy 
must be " "keep" " or " "remove" "" ) assert not any( np.isin(new_unit_ids, unchanged_units) - ), "new_unit_ids should be new units or one could be equal to split_unit_id" + ), "new_unit_ids should be new unit ids or no more than one unit id can be found in split_unit_id" sampling_frequency = parent_sorting.get_sampling_frequency() units_ids = np.concatenate([unchanged_units, new_unit_ids]) From 34854aa0cd115618ec3a99ce6503d3f28569cfe9 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 17 Oct 2023 09:00:28 -0400 Subject: [PATCH 2/4] add segment number to assert message --- src/spikeinterface/exporters/to_phy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 31a452f389..2c916d33b5 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -78,7 +78,7 @@ def export_to_phy( ), "waveform_extractor must be a WaveformExtractor object" sorting = waveform_extractor.sorting - assert waveform_extractor.get_num_segments() == 1, "Export to phy only works with one segment" + assert waveform_extractor.get_num_segments() == 1, f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments" num_chans = waveform_extractor.get_num_channels() fs = waveform_extractor.sampling_frequency From 11677223154481d275ff3029af74d00c7d723f1c Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 17 Oct 2023 09:07:27 -0400 Subject: [PATCH 3/4] working on assert messaging --- src/spikeinterface/sorters/basesorter.py | 6 +++--- src/spikeinterface/sorters/launcher.py | 4 ++-- src/spikeinterface/sorters/sorterlist.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index a956f8c811..139f15bf12 100644 --- 
a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -103,7 +103,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo ) if not isinstance(recording, BaseRecordingSnippets): - raise ValueError("recording must be a Recording or Snippets!!") + raise ValueError("recording must be a Recording or a Snippets!!") if cls.requires_locations: locations = recording.get_channel_locations() @@ -133,7 +133,7 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo if recording.get_num_segments() > 1: if not cls.handle_multi_segment: raise ValueError( - f"This sorter {cls.sorter_name} do not handle multi segment, use si.concatenate_recordings(...)" + f"This sorter {cls.sorter_name} does not handle multi-segment recordings, use si.concatenate_recordings(...)" ) rec_file = output_folder / "spikeinterface_recording.json" @@ -299,7 +299,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_ # check errors in log file log_file = output_folder / "spikeinterface_log.json" if not log_file.is_file(): - raise SpikeSortingError("get result error: the folder does not contain the `spikeinterface_log.json` file") + raise SpikeSortingError("Get result error: the folder does not contain the `spikeinterface_log.json` file") with log_file.open("r", encoding="utf8") as f: log = json.load(f) diff --git a/src/spikeinterface/sorters/launcher.py b/src/spikeinterface/sorters/launcher.py index 704f6843f2..e7fdedcfe7 100644 --- a/src/spikeinterface/sorters/launcher.py +++ b/src/spikeinterface/sorters/launcher.py @@ -374,7 +374,7 @@ def run_sorters( mode_if_folder_exists in ("raise", "keep", "overwrite") if mode_if_folder_exists == "raise" and working_folder.is_dir(): - raise Exception("working_folder already exists, please remove it") + raise Exception(f"working_folder {working_folder} already exists, please remove it") assert engine in _implemented_engine, f"engine must be in 
{_implemented_engine}" @@ -390,7 +390,7 @@ def run_sorters( elif isinstance(recording_dict_or_list, dict): recording_dict = recording_dict_or_list else: - raise ValueError("bad recording dict") + raise ValueError("Wrong format for recording_dict_or_list") dtype_rec_name = np.dtype(type(list(recording_dict.keys())[0])) assert dtype_rec_name.kind in ("i", "u", "S", "U"), "Dict keys can only be integers or strings!" diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index 40b5cdebaa..761bb6d716 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -89,7 +89,7 @@ def get_default_sorter_params(sorter_name_or_class): elif sorter_name_or_class in sorter_full_list: SorterClass = sorter_name_or_class else: - raise (ValueError("Unknown sorter")) + raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given")) return SorterClass.default_params() @@ -113,7 +113,7 @@ def get_sorter_params_description(sorter_name_or_class): elif sorter_name_or_class in sorter_full_list: SorterClass = sorter_name_or_class else: - raise (ValueError("Unknown sorter")) + raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given")) return SorterClass.params_description() @@ -137,6 +137,6 @@ def get_sorter_description(sorter_name_or_class): elif sorter_name_or_class in sorter_full_list: SorterClass = sorter_name_or_class else: - raise (ValueError("Unknown sorter")) + raise (ValueError(f"Unknown sorter {sorter_name_or_class} has been given")) return SorterClass.sorter_description From 44ad0ef0f0e29973e6e6c05fc2b992ee755db89b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Oct 2023 13:11:14 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/exporters/to_phy.py | 4 +++- 1 file changed, 3 insertions(+), 1 
deletion(-) diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index 2c916d33b5..0529c99d12 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -78,7 +78,9 @@ def export_to_phy( ), "waveform_extractor must be a WaveformExtractor object" sorting = waveform_extractor.sorting - assert waveform_extractor.get_num_segments() == 1, f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments" + assert ( + waveform_extractor.get_num_segments() == 1 + ), f"Export to phy only works with one segment, your extractor has {waveform_extractor.get_num_segments()} segments" num_chans = waveform_extractor.get_num_channels() fs = waveform_extractor.sampling_frequency