Skip to content

Commit

Permalink
Merge pull request #2078 from zm711/clean-asserts
Browse files Browse the repository at this point in the history
Improve assert messages (preprocessing & core)
  • Loading branch information
alejoe91 authored Oct 6, 2023
2 parents cdc1ccb + 9db087d commit 0bf1b89
Show file tree
Hide file tree
Showing 19 changed files with 68 additions and 57 deletions.
6 changes: 3 additions & 3 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, main_ids: Sequence) -> None:
self._kwargs = {}

# 'main_ids' will either be channel_ids or units_ids
# They is used for properties
# They are used for properties
self._main_ids = np.array(main_ids)

# dict at object level
Expand Down Expand Up @@ -984,7 +984,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor:
class_name = None

if "kwargs" not in dic:
raise Exception(f"This dict cannot be load into extractor {dic}")
raise Exception(f"This dict cannot be loaded into extractor {dic}")

# Create new kwargs to avoid modifying the original dict["kwargs"]
new_kwargs = dict()
Expand All @@ -1005,7 +1005,7 @@ def _load_extractor_from_dict(dic) -> BaseExtractor:
assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class"
if not _check_same_version(class_name, dic["version"]):
warnings.warn(
f"Versions are not the same. This might lead compatibility errors. "
f"Versions are not the same. This might lead to compatibility errors. "
f"Using {class_name.split('.')[0]}=={dic['version']} is recommended"
)

Expand Down
7 changes: 4 additions & 3 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def get_traces(

if not self.has_scaled():
raise ValueError(
"This recording do not support return_scaled=True (need gain_to_uV and offset_" "to_uV properties)"
"This recording does not support return_scaled=True (need gain_to_uV and offset_"
"to_uV properties)"
)
else:
gains = self.get_property("gain_to_uV")
Expand Down Expand Up @@ -416,8 +417,8 @@ def set_times(self, times, segment_index=None, with_warning=True):
if with_warning:
warn(
"Setting times with Recording.set_times() is not recommended because "
"times are not always propagated to across preprocessing"
"Use use this carefully!"
"times are not always propagated across preprocessing"
"Use this carefully!"
)

def sample_index_to_time(self, sample_ind, segment_index=None):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def register_recording(self, recording, check_spike_frames=True):
if check_spike_frames:
if has_exceeding_spikes(recording, self):
warnings.warn(
"Some spikes are exceeding the recording's duration! "
"Some spikes exceed the recording's duration! "
"Removing these excess spikes with `spikeinterface.curation.remove_excess_spikes()` "
"Might be necessary for further postprocessing."
)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/binaryrecordingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
file_path_list = [Path(file_paths)]

if t_starts is not None:
assert len(t_starts) == len(file_path_list), "t_starts must be a list of same size than file_paths"
assert len(t_starts) == len(file_path_list), "t_starts must be a list of the same size as file_paths"
t_starts = [float(t_start) for t_start in t_starts]

dtype = np.dtype(dtype)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/channelsaggregationrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,11 @@ def __init__(self, channel_map, parent_segments):
times_kargs0 = parent_segment0.get_times_kwargs()
if times_kargs0["time_vector"] is None:
for ps in parent_segments:
assert ps.get_times_kwargs()["time_vector"] is None, "All segment should not have times set"
assert ps.get_times_kwargs()["time_vector"] is None, "All segments should not have times set"
else:
for ps in parent_segments:
assert ps.get_times_kwargs()["t_start"] == times_kargs0["t_start"], (
"All segment should have the same " "t_start"
"All segments should have the same " "t_start"
)

BaseRecordingSegment.__init__(self, **times_kargs0)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/channelslice.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None)
), "ChannelSliceRecording: renamed channel_ids must be the same size"
assert (
self._channel_ids.size == np.unique(self._channel_ids).size
), "ChannelSliceRecording : channel_ids not unique"
), "ChannelSliceRecording : channel_ids are not unique"

sampling_frequency = parent_recording.get_sampling_frequency()

Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None):
), "ChannelSliceSnippets: renamed channel_ids must be the same size"
assert (
self._channel_ids.size == np.unique(self._channel_ids).size
), "ChannelSliceSnippets : channel_ids not unique"
), "ChannelSliceSnippets : channel_ids are not unique"

sampling_frequency = parent_snippets.get_sampling_frequency()

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/frameslicerecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class FrameSliceRecording(BaseRecording):
def __init__(self, parent_recording, start_frame=None, end_frame=None):
channel_ids = parent_recording.get_channel_ids()

assert parent_recording.get_num_segments() == 1, "FrameSliceRecording work only with one segment"
assert parent_recording.get_num_segments() == 1, "FrameSliceRecording only works with one segment"

parent_size = parent_recording.get_num_samples(0)
if start_frame is None:
Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/core/frameslicesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class FrameSliceSorting(BaseSorting):
def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike_frames=True):
unit_ids = parent_sorting.get_unit_ids()

assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting work only with one segment"
assert parent_sorting.get_num_segments() == 1, "FrameSliceSorting only works with one segment"

if start_frame is None:
start_frame = 0
Expand All @@ -49,10 +49,10 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
end_frame = parent_n_samples
assert (
end_frame <= parent_n_samples
), "`end_frame` should be smaller than the sortings total number of samples."
), "`end_frame` should be smaller than the sortings' total number of samples."
assert (
start_frame <= parent_n_samples
), "`start_frame` should be smaller than the sortings total number of samples."
), "`start_frame` should be smaller than the sortings' total number of samples."
if check_spike_frames and has_exceeding_spikes(parent_sorting._recording, parent_sorting):
raise ValueError(
"The sorting object has spikes exceeding the recording duration. You have to remove those spikes "
Expand All @@ -67,7 +67,7 @@ def __init__(self, parent_sorting, start_frame=None, end_frame=None, check_spike
end_frame = max_spike_time + 1

assert start_frame < end_frame, (
"`start_frame` should be greater than `end_frame`. "
"`start_frame` should be less than `end_frame`. "
"This may be due to start_frame >= max_spike_time, if the end frame "
"was not specified explicitly."
)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,11 +1101,11 @@ def __init__(
# handle also upsampling and jitter
upsample_factor = templates.shape[3]
elif templates.ndim == 5:
# handle also dirft
# handle also drift
raise NotImplementedError("Drift will be implented soon...")
# upsample_factor = templates.shape[3]
else:
raise ValueError("templates have wring dim should 3 or 4")
raise ValueError("templates have wrong dim should 3 or 4")

if upsample_factor is not None:
assert upsample_vector is not None
Expand Down
48 changes: 28 additions & 20 deletions src/spikeinterface/core/template_tools.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,24 @@
from __future__ import annotations
import numpy as np
import warnings

from .sparsity import compute_sparsity, _sparsity_doc
from .recording_tools import get_channel_distances, get_noise_levels


def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: str = "extremum"):
def get_template_amplitudes(
waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "extremum"
):
"""
Get amplitude per channel for each unit.
Parameters
----------
waveform_extractor: WaveformExtractor
The waveform extractor
peak_sign: str
Sign of the template to compute best channels ('neg', 'pos', 'both')
mode: str
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "extremum"
'extremum': max or min
'at_index': take value at spike index
Expand All @@ -24,8 +27,8 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st
peak_values: dict
Dictionary with unit ids as keys and template amplitudes as values
"""
assert peak_sign in ("both", "neg", "pos")
assert mode in ("extremum", "at_index")
assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'both', 'neg', or 'pos'"
assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'"
unit_ids = waveform_extractor.sorting.unit_ids

before = waveform_extractor.nbefore
Expand Down Expand Up @@ -57,7 +60,10 @@ def get_template_amplitudes(waveform_extractor, peak_sign: str = "neg", mode: st


def get_template_extremum_channel(
waveform_extractor, peak_sign: str = "neg", mode: str = "extremum", outputs: str = "id"
waveform_extractor,
peak_sign: "neg" | "pos" | "both" = "neg",
mode: "extremum" | "at_index" = "extremum",
outputs: "id" | "index" = "id",
):
"""
Compute the channel with the extremum peak for each unit.
Expand All @@ -66,12 +72,12 @@ def get_template_extremum_channel(
----------
waveform_extractor: WaveformExtractor
The waveform extractor
peak_sign: str
Sign of the template to compute best channels ('neg', 'pos', 'both')
mode: str
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "extremum"
'extremum': max or min
'at_index': take value at spike index
outputs: str
outputs: "id" | "index", default: "id"
* 'id': channel id
* 'index': channel index
Expand Down Expand Up @@ -159,7 +165,7 @@ def get_template_channel_sparsity(
get_template_channel_sparsity.__doc__ = get_template_channel_sparsity.__doc__.format(_sparsity_doc)


def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str = "neg"):
def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg"):
"""
In some situations spike sorters could return a spike index with a small shift related to the waveform peak.
This function estimates and return these alignment shifts for the mean template.
Expand All @@ -169,8 +175,8 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str
----------
waveform_extractor: WaveformExtractor
The waveform extractor
peak_sign: str
Sign of the template to compute best channels ('neg', 'pos', 'both')
peak_sign: "neg" | "pos" | "both", default: "neg"
Sign of the template to compute best channels
Returns
-------
Expand Down Expand Up @@ -203,17 +209,19 @@ def get_template_extremum_channel_peak_shift(waveform_extractor, peak_sign: str
return shifts


def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg", mode: str = "at_index"):
def get_template_extremum_amplitude(
waveform_extractor, peak_sign: "neg" | "pos" | "both" = "neg", mode: "extremum" | "at_index" = "at_index"
):
"""
Computes amplitudes on the best channel.
Parameters
----------
waveform_extractor: WaveformExtractor
The waveform extractor
peak_sign: str
Sign of the template to compute best channels ('neg', 'pos', 'both')
mode: str
peak_sign: "neg" | "pos" | "both"
Sign of the template to compute best channels
mode: "extremum" | "at_index", default: "at_index"
Where the amplitude is computed
'extremum': max or min
'at_index': take value at spike index
Expand All @@ -223,8 +231,8 @@ def get_template_extremum_amplitude(waveform_extractor, peak_sign: str = "neg",
amplitudes: dict
Dictionary with unit ids as keys and amplitudes as values
"""
assert peak_sign in ("both", "neg", "pos")
assert mode in ("extremum", "at_index")
assert peak_sign in ("both", "neg", "pos"), "'peak_sign' must be 'neg' or 'pos' or 'both'"
assert mode in ("extremum", "at_index"), "'mode' must be 'extremum' or 'at_index'"
unit_ids = waveform_extractor.sorting.unit_ids

before = waveform_extractor.nbefore
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/unitsaggregationsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, sorting_list, renamed_unit_ids=None):
try:
property_dict[prop_name] = np.concatenate((property_dict[prop_name], values))
except Exception as e:
print(f"Skipping property '{prop_name}' for shape inconsistency")
print(f"Skipping property '{prop_name}' due to shape inconsistency")
del property_dict[prop_name]
break
for prop_name, prop_values in property_dict.items():
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __init__(
chunk_size=500,
seed=0,
):
assert direction in ("upper", "lower", "both")
assert direction in ("upper", "lower", "both"), "'direction' must be 'upper', 'lower', or 'both'"

if fill_value is None or quantile_threshold is not None:
random_data = get_random_data_chunks(
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __init__(
ref_channel_ids = np.asarray(ref_channel_ids)
assert np.all(
[ch in recording.get_channel_ids() for ch in ref_channel_ids]
), "Some wrong 'ref_channel_ids'!"
), "Some 'ref_channel_ids' are wrong!"
elif reference == "local":
assert groups is None, "With 'local' CAR, the group option should not be used."
closest_inds, dist = get_closest_channels(recording)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/detect_bad_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,9 @@ def detect_bad_channels(

if bad_channel_ids.size > recording.get_num_channels() / 3:
warnings.warn(
"Over 1/3 of channels are detected as bad. In the precense of a high"
"Over 1/3 of channels are detected as bad. In the presence of a high"
"number of dead / noisy channels, bad channel detection may fail "
"(erroneously label good channels as dead)."
"(good channels may be erroneously labeled as dead)."
)

elif method == "neighborhood_r2":
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def __init__(
):
import scipy.signal

assert filter_mode in ("sos", "ba")
assert filter_mode in ("sos", "ba"), "'filter' mode must be 'sos' or 'ba'"
fs = recording.get_sampling_frequency()
if coeff is None:
assert btype in ("bandpass", "highpass")
assert btype in ("bandpass", "highpass"), "'bytpe' must be 'bandpass' or 'highpass'"
# coefficient
# self.coeff is 'sos' or 'ab' style
filter_coeff = scipy.signal.iirfilter(
Expand Down Expand Up @@ -258,7 +258,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
if dtype.kind == "u":
raise TypeError(
"The notch filter only supports signed types. Use the 'dtype' argument"
"to specify a signed type (e.g. 'int16', 'float32'"
"to specify a signed type (e.g. 'int16', 'float32')"
)

BasePreprocessor.__init__(self, recording, dtype=dtype)
Expand Down
12 changes: 6 additions & 6 deletions src/spikeinterface/preprocessing/filter_opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def __init__(
margin_ms=5.0,
):
assert HAVE_PYOPENCL, "You need to install pyopencl (and GPU driver!!)"

assert btype in ("bandpass", "lowpass", "highpass", "bandstop")
assert filter_mode in ("sos",)
btype_modes = ("bandpass", "lowpass", "highpass", "bandstop")
assert btype in btype_modes, f"'btype' must be in {btype_modes}"
assert filter_mode in ("sos",), "'filter_mode' must be 'sos'"

# coefficient
sf = recording.get_sampling_frequency()
Expand Down Expand Up @@ -96,8 +96,8 @@ def __init__(self, parent_recording_segment, executor, margin):
self.margin = margin

def get_traces(self, start_frame, end_frame, channel_indices):
assert start_frame is not None, "FilterOpenCLRecording work with fixed chunk_size"
assert end_frame is not None, "FilterOpenCLRecording work with fixed chunk_size"
assert start_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size"
assert end_frame is not None, "FilterOpenCLRecording only works with fixed chunk_size"

chunk_size = end_frame - start_frame
if chunk_size != self.executor.chunk_size:
Expand Down Expand Up @@ -157,7 +157,7 @@ def process(self, traces):

if traces.shape[0] != self.full_size:
if self.full_size is not None:
print(f"Warning : chunk_size have change {self.chunk_size} {traces.shape[0]}, need recompile CL!!!")
print(f"Warning : chunk_size has changed {self.chunk_size} {traces.shape[0]}, need to recompile CL!!!")
self.create_buffers_and_compile()

event = pyopencl.enqueue_copy(self.queue, self.input_cl, traces)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def get_traces(self, start_frame, end_frame, channel_indices):
traces = traces * self.taper[np.newaxis, :]

# apply actual HP filter
import scipy
import scipy.signal

traces = scipy.signal.sosfiltfilt(self.sos_filter, traces, axis=1)

Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/preprocessing/normalize_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __init__(
dtype="float32",
**random_chunk_kwargs,
):
assert mode in ("pool_channel", "by_channel")
assert mode in ("pool_channel", "by_channel"), "'mode' must be 'pool_channel' or 'by_channel'"

random_data = get_random_data_chunks(recording, **random_chunk_kwargs)

Expand Down Expand Up @@ -260,7 +260,7 @@ def __init__(
dtype="float32",
**random_chunk_kwargs,
):
assert mode in ("median+mad", "mean+std")
assert mode in ("median+mad", "mean+std"), "'mode' must be 'median+mad' or 'mean+std'"

# fix dtype
dtype_ = fix_dtype(recording, dtype)
Expand Down
Loading

0 comments on commit 0bf1b89

Please sign in to comment.