Skip to content

Commit

Permalink
Merge pull request #2119 from zm711/asserts-sortingcomp
Browse files Browse the repository at this point in the history
Improve assert messaging (sortingcomponents)
  • Loading branch information
samuelgarcia authored Oct 23, 2023
2 parents 8b31b33 + 3751aec commit c66733b
Show file tree
Hide file tree
Showing 10 changed files with 55 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BenchmarkClustering:
def __init__(self, recording, gt_sorting, method, exhaustive_gt=True, tmp_folder=None, job_kwargs={}, verbose=True):
self.method = method

assert method in clustering_methods, "Clustering method should be in %s" % clustering_methods.keys()
assert method in clustering_methods, f"Clustering method should be in {clustering_methods.keys()}"

self.verbose = verbose
self.recording = recording
Expand Down
12 changes: 6 additions & 6 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ def _check_params(cls, recording, peaks, params):
elif params["waveform_mode"] == "shared_memory":
assert tmp_folder is None, "tmp_folder must be None for shared_memory"
else:
raise ValueError("shared_memory")
raise ValueError("'waveform_mode' must be 'memmap' or 'shared_memory'")

return params2

@classmethod
def main_function(cls, recording, peaks, params):
assert HAVE_HDBSCAN, "twisted clustering need hdbscan to be installed"
assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed"

params = cls._check_params(recording, peaks, params)
d = params
Expand Down Expand Up @@ -110,7 +110,7 @@ def main_function(cls, recording, peaks, params):
if params["waveform_mode"] == "shared_memory":
wf_folder = None
else:
assert params["tmp_folder"] is not None
assert params["tmp_folder"] is not None, "tmp_folder must be supplied"
wf_folder = params["tmp_folder"] / "sparse_snippets"
wf_folder.mkdir()

Expand Down Expand Up @@ -225,13 +225,13 @@ def main_function(cls, recording, peaks, params):
if params["waveform_mode"] == "shared_memory":
wf_folder = None
else:
assert params["tmp_folder"] is not None
assert params["tmp_folder"] is not None, "tmp_folder must be supplied"
wf_folder = params["tmp_folder"] / "dense_snippets"
wf_folder.mkdir()

cleaning_method = params["cleaning_method"]

print("We found %d raw clusters, starting to clean with %s..." % (len(labels), cleaning_method))
print(f"We found {len(labels)} raw clusters, starting to clean with {cleaning_method}...")

if cleaning_method == "cosine":
wfs_arrays = extract_waveforms_to_buffers(
Expand Down Expand Up @@ -288,6 +288,6 @@ def main_function(cls, recording, peaks, params):
labels, peak_labels = remove_duplicates_via_matching(we, peak_labels, job_kwargs=params["job_kwargs"])
shutil.rmtree(tmp_folder)

print("We kept %d non-duplicated clusters..." % len(labels))
print(f"We kept {len(labels)} non-duplicated clusters...")

return labels, peak_labels
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class PositionAndFeaturesClustering:

@classmethod
def main_function(cls, recording, peaks, params):
assert HAVE_HDBSCAN, "twisted clustering need hdbscan to be installed"
assert HAVE_HDBSCAN, "twisted clustering needs hdbscan to be installed"

if "n_jobs" in params["job_kwargs"]:
if params["job_kwargs"]["n_jobs"] == -1:
Expand Down
9 changes: 8 additions & 1 deletion src/spikeinterface/sortingcomponents/features_from_peaks.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ def compute_features_from_peaks(
feature_list: List of features to be computed.
- amplitude
- ptp
- com
- center_of_mass
- energy
- std_ptp
- ptp_lag
- random_projections_ptp
- random_projections_energy
ms_before: float
The duration in ms before the peak for extracting the features (default 1 ms)
ms_after: float
Expand All @@ -61,6 +65,9 @@ def compute_features_from_peaks(
extract_dense_waveforms,
]
for feature_name in feature_list:
assert (
feature_name in _features_class.keys()
), f"Feature {feature_name} in 'feature_list' is not possible. Possible features are {list(_features_class.keys())}"
Class = _features_class[feature_name]
params = feature_params.get(feature_name, {}).copy()
node = Class(recording, parents=[peak_retriever, extract_dense_waveforms], **params)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/matching/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def find_spikes_from_templates(recording, method="naive", method_kwargs={}, extr
"""
from .method_list import matching_methods

assert method in matching_methods, "The method %s is not a valid one" % method
assert method in matching_methods, f"The 'method' {method} is not valid. Use a method from {matching_methods}"

job_kwargs = fix_job_kwargs(job_kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/sortingcomponents/matching/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def initialize_and_check_kwargs(cls, recording, kwargs):
d = cls.default_params.copy()
d.update(kwargs)

assert d["waveform_extractor"] is not None
assert d["waveform_extractor"] is not None, "'waveform_extractor' must be supplied"

we = d["waveform_extractor"]

Expand Down
7 changes: 5 additions & 2 deletions src/spikeinterface/sortingcomponents/matching/tdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@ class TridesclousPeeler(BaseTemplateMatchingEngine):

@classmethod
def initialize_and_check_kwargs(cls, recording, kwargs):
assert HAVE_NUMBA, "TridesclousPeeler need numba to be installed"
assert HAVE_NUMBA, "TridesclousPeeler needs numba to be installed"

d = cls.default_params.copy()
d.update(kwargs)

assert isinstance(d["waveform_extractor"], WaveformExtractor)
assert isinstance(d["waveform_extractor"], WaveformExtractor), (
f"The waveform_extractor supplied is of type {type(d['waveform_extractor'])} "
f"and must be a WaveformExtractor"
)

we = d["waveform_extractor"]
unit_ids = we.unit_ids
Expand Down
23 changes: 16 additions & 7 deletions src/spikeinterface/sortingcomponents/motion_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ class IterativeTemplateRegistration:
See https://www.science.org/doi/abs/10.1126/science.abf4588?cookieSet=1
Ported by Alessio Buccino in SpikeInterface
Ported by Alessio Buccino into SpikeInterface
"""

name = "iterative_template"
Expand Down Expand Up @@ -824,7 +824,7 @@ def compute_pairwise_displacement(
from scipy import sparse
from scipy import linalg

assert conv_engine in ("torch", "numpy")
assert conv_engine in ("torch", "numpy"), f"'conv_engine' must be 'torch' or 'numpy'"
size = motion_hist.shape[0]
pairwise_displacement = np.zeros((size, size), dtype="float32")

Expand Down Expand Up @@ -890,7 +890,7 @@ def compute_pairwise_displacement(
try:
import skimage.registration
except ImportError:
raise ImportError("To use 'phase_cross_correlation' method install scikit-image")
raise ImportError("To use the 'phase_cross_correlation' method install scikit-image")

errors = np.zeros((size, size), dtype="float32")
loop = range(size)
Expand All @@ -906,7 +906,10 @@ def compute_pairwise_displacement(
correlation = 1 - errors

else:
raise ValueError(f"method does not exist for compute_pairwise_displacement {method}")
raise ValueError(
f"method {method} does not exist for compute_pairwise_displacement. Current possible methods are"
f" 'conv' or 'phase_cross_correlation'"
)

if weight_scale == "linear":
# between 0 and 1
Expand All @@ -925,6 +928,9 @@ def compute_pairwise_displacement(
return pairwise_displacement, pairwise_displacement_weight


_possible_convergence_method = ("lsmr", "gradient_descent", "lsqr_robust")


def compute_global_displacement(
pairwise_displacement,
pairwise_displacement_weight=None,
Expand Down Expand Up @@ -1166,7 +1172,10 @@ def jac(p):

displacement = displacement.reshape(B, T).T
else:
raise ValueError(f"Method {convergence_method} doesn't exist for compute_global_displacement")
raise ValueError(
f"Method {convergence_method} doesn't exist for compute_global_displacement"
f" possible values for 'convergence_method' are {_possible_convergence_method}"
)

return np.squeeze(displacement)

Expand Down Expand Up @@ -1371,7 +1380,7 @@ def normxcorr1d(
conv1d = scipy_conv1d
npx = np
else:
raise ValueError(f"Unknown conv_engine {conv_engine}")
raise ValueError(f"Unknown conv_engine {conv_engine}. 'conv_engine' must be 'torch' or 'numpy'")

x = npx.atleast_2d(x)
num_templates, length = template.shape
Expand Down Expand Up @@ -1452,7 +1461,7 @@ def scipy_conv1d(input, weights, padding="valid"):
input = np.pad(input, [*[(0, 0)] * (input.ndim - 1), (padding, padding)])
length_out = length - (kernel_size - 1) + 2 * padding
else:
raise ValueError(f"Unknown padding {padding}")
raise ValueError(f"Unknown 'padding' value of {padding}, 'padding' must be 'same', 'valid' or an integer")

output = np.zeros((n, c_out, length_out), dtype=input.dtype)
for m in range(n):
Expand Down
15 changes: 13 additions & 2 deletions src/spikeinterface/sortingcomponents/peak_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def select_peak_indices(peaks, method, seed, **method_kwargs):
params.update(method_kwargs)

assert params["n_peaks"] is not None, "n_peaks should be defined!"
assert params["peaks_locations"] is not None, "peaks_locations should be d96efined!"
assert params["peaks_locations"] is not None, "peaks_locations should be defined!"

nb_spikes = len(params["peaks_locations"]["x"])

Expand Down Expand Up @@ -252,8 +252,19 @@ def select_peak_indices(peaks, method, seed, **method_kwargs):
selected_indices = [rng.permutation(my_selection)[: params["n_peaks"]]]

else:
raise NotImplementedError(f"No method {method} for peaks selection")
raise NotImplementedError(
f"The 'method' {method} does not exist for peaks selection." f" possible methods are {_possible_methods}"
)

selected_indices = np.concatenate(selected_indices)
selected_indices = selected_indices[np.argsort(peaks[selected_indices]["sample_index"])]
return selected_indices


_possible_methods = (
"uniform",
"uniform_locations",
"smart_sampling_amplitudes",
"smart_sampling_locations",
"smart_sampling_locations_and_time",
)
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def __init__(
return_output=return_output,
parents=parents,
)
assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature"
assert feature in ["ptp", "mean", "energy", "peak_voltage"], (
f"{feature} is not a valid feature" " must be one of 'ptp', 'mean', 'energy'," " or 'peak_voltage'"
)

self.threshold = threshold
self.feature = feature
Expand Down

0 comments on commit c66733b

Please sign in to comment.