Skip to content

Commit

Permalink
Merge pull request #2158 from DradeAW/wvf_extension_typo
Browse files Browse the repository at this point in the history
`WaveformExtractor.is_extension` --> `has_extension`
  • Loading branch information
alejoe91 authored Nov 15, 2023
2 parents e172cad + 4cbf2f8 commit e4c5d11
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 33 deletions.
20 changes: 14 additions & 6 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ def get_recording_property(self, key) -> np.ndarray:
def get_sorting_property(self, key) -> np.ndarray:
return self.sorting.get_property(key)

def get_extension_class(self, extension_name):
def get_extension_class(self, extension_name: str):
"""
Get extension class from name and check if registered.
Expand All @@ -525,7 +525,7 @@ def get_extension_class(self, extension_name):
ext_class = extensions_dict[extension_name]
return ext_class

def is_extension(self, extension_name) -> bool:
def has_extension(self, extension_name: str) -> bool:
"""
Check if the extension exists in memory or in the folder.
Expand Down Expand Up @@ -556,7 +556,15 @@ def is_extension(self, extension_name) -> bool:
and "params" in self._waveforms_root[extension_name].attrs.keys()
)

def load_extension(self, extension_name):
def is_extension(self, extension_name) -> bool:
warn(
"WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.",
DeprecationWarning,
stacklevel=2,
)
return self.has_extension(extension_name)

def load_extension(self, extension_name: str):
"""
Load an extension from its name.
The module of the extension must be loaded and registered.
Expand All @@ -572,7 +580,7 @@ def load_extension(self, extension_name):
The loaded instance of the extension
"""
if self.folder is not None and extension_name not in self._loaded_extensions:
if self.is_extension(extension_name):
if self.has_extension(extension_name):
ext_class = self.get_extension_class(extension_name)
ext = ext_class.load(self.folder, self)
if extension_name not in self._loaded_extensions:
Expand All @@ -588,7 +596,7 @@ def delete_extension(self, extension_name) -> None:
extension_name: str
The extension name.
"""
assert self.is_extension(extension_name), f"The extension {extension_name} is not available"
assert self.has_extension(extension_name), f"The extension {extension_name} is not available"
del self._loaded_extensions[extension_name]
if self.folder is not None and (self.folder / extension_name).is_dir():
shutil.rmtree(self.folder / extension_name)
Expand All @@ -610,7 +618,7 @@ def get_available_extension_names(self):
"""
extension_names_in_folder = []
for extension_class in self.extensions:
if self.is_extension(extension_class.extension_name):
if self.has_extension(extension_class.extension_name):
extension_names_in_folder.append(extension_class.extension_name)
return extension_names_in_folder

Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/exporters/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def export_report(
unit_ids = sorting.unit_ids

# load or compute spike_amplitudes
if we.is_extension("spike_amplitudes"):
if we.has_extension("spike_amplitudes"):
spike_amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit")
elif force_computation:
spike_amplitudes = compute_spike_amplitudes(we, peak_sign=peak_sign, outputs="by_unit", **job_kwargs)
Expand All @@ -62,7 +62,7 @@ def export_report(
)

# load or compute quality_metrics
if we.is_extension("quality_metrics"):
if we.has_extension("quality_metrics"):
metrics = we.load_extension("quality_metrics").get_data()
elif force_computation:
metrics = compute_quality_metrics(we)
Expand All @@ -73,7 +73,7 @@ def export_report(
)

# load or compute correlograms
if we.is_extension("correlograms"):
if we.has_extension("correlograms"):
correlograms, bins = we.load_extension("correlograms").get_data()
elif force_computation:
correlograms, bins = compute_correlograms(we, window_ms=100.0, bin_ms=1.0)
Expand All @@ -84,7 +84,7 @@ def export_report(
)

# pre-compute unit locations if not done
if not we.is_extension("unit_locations"):
if not we.has_extension("unit_locations"):
unit_locations = compute_unit_locations(we)

output_folder = Path(output_folder).absolute()
Expand Down
8 changes: 4 additions & 4 deletions src/spikeinterface/exporters/to_phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def export_to_phy(
templates[unit_ind, :, :][:, : len(chan_inds)] = template
templates_ind[unit_ind, : len(chan_inds)] = chan_inds

if waveform_extractor.is_extension("similarity"):
if waveform_extractor.has_extension("similarity"):
tmc = waveform_extractor.load_extension("similarity")
template_similarity = tmc.get_data()
else:
Expand All @@ -219,7 +219,7 @@ def export_to_phy(
np.save(str(output_folder / "channel_groups.npy"), channel_groups)

if compute_amplitudes:
if waveform_extractor.is_extension("spike_amplitudes"):
if waveform_extractor.has_extension("spike_amplitudes"):
sac = waveform_extractor.load_extension("spike_amplitudes")
amplitudes = sac.get_data(outputs="concatenated")
else:
Expand All @@ -231,7 +231,7 @@ def export_to_phy(
np.save(str(output_folder / "amplitudes.npy"), amplitudes)

if compute_pc_features:
if waveform_extractor.is_extension("principal_components"):
if waveform_extractor.has_extension("principal_components"):
pc = waveform_extractor.load_extension("principal_components")
else:
pc = compute_principal_components(
Expand Down Expand Up @@ -264,7 +264,7 @@ def export_to_phy(
channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups})
channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False)

if waveform_extractor.is_extension("quality_metrics"):
if waveform_extractor.has_extension("quality_metrics"):
qm = waveform_extractor.load_extension("quality_metrics")
qm_data = qm.get_data()
for column_name in qm_data.columns:
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/postprocessing/principal_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,7 @@ def compute_principal_components(
>>> pc.run_for_all_spikes(file_path="all_pca_projections.npy")
"""

if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name):
if load_if_exists and waveform_extractor.has_extension(WaveformPrincipalComponent.extension_name):
pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name)
else:
pc = WaveformPrincipalComponent.create(waveform_extractor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _test_extension_folder(self, we, in_memory=False):

# reload as an extension from we
assert self.extension_class.extension_name in we.get_available_extension_names()
assert we.is_extension(self.extension_class.extension_name)
assert we.has_extension(self.extension_class.extension_name)
ext = we.load_extension(self.extension_class.extension_name)
assert isinstance(ext, self.extension_class)
for ext_name in self.extension_data_names:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test_project_new(self):
from sklearn.decomposition import IncrementalPCA

we = self.we1
if we.is_extension("principal_components"):
if we.has_extension("principal_components"):
we.delete_extension("principal_components")
we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp")

Expand Down
10 changes: 5 additions & 5 deletions src/spikeinterface/qualitymetrics/misc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def compute_snrs(
snrs : dict
Computed signal to noise ratio for each unit.
"""
if waveform_extractor.is_extension("noise_levels"):
if waveform_extractor.has_extension("noise_levels"):
noise_levels = waveform_extractor.load_extension("noise_levels").get_data()
else:
if random_chunk_kwargs_dict is None:
Expand Down Expand Up @@ -687,7 +687,7 @@ def compute_amplitude_cv_metrics(
if unit_ids is None:
unit_ids = sorting.unit_ids

if waveform_extractor.is_extension(amplitude_extension):
if waveform_extractor.has_extension(amplitude_extension):
sac = waveform_extractor.load_extension(amplitude_extension)
amps = sac.get_data(outputs="concatenated")
if amplitude_extension == "spike_amplitudes":
Expand Down Expand Up @@ -803,7 +803,7 @@ def compute_amplitude_cutoffs(

spike_amplitudes = None
invert_amplitudes = False
if waveform_extractor.is_extension("spike_amplitudes"):
if waveform_extractor.has_extension("spike_amplitudes"):
amp_calculator = waveform_extractor.load_extension("spike_amplitudes")
spike_amplitudes = amp_calculator.get_data(outputs="by_unit")
if amp_calculator._params["peak_sign"] == "pos":
Expand Down Expand Up @@ -881,7 +881,7 @@ def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None
extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign)

spike_amplitudes = None
if waveform_extractor.is_extension("spike_amplitudes"):
if waveform_extractor.has_extension("spike_amplitudes"):
amp_calculator = waveform_extractor.load_extension("spike_amplitudes")
spike_amplitudes = amp_calculator.get_data(outputs="by_unit")

Expand Down Expand Up @@ -974,7 +974,7 @@ def compute_drift_metrics(
if unit_ids is None:
unit_ids = sorting.unit_ids

if waveform_extractor.is_extension("spike_locations"):
if waveform_extractor.has_extension("spike_locations"):
locs_calculator = waveform_extractor.load_extension("spike_locations")
spike_locations = locs_calculator.get_data(outputs="concatenated")
spike_locations_by_unit = locs_calculator.get_data(outputs="by_unit")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ def _set_params(
if metric_names is None:
metric_names = list(_misc_metric_name_to_func.keys())
# if PC is available, PC metrics are automatically added to the list
if self.waveform_extractor.is_extension("principal_components"):
if self.waveform_extractor.has_extension("principal_components"):
# by default 'nearest_neightbor' is removed because too slow
pc_metrics = _possible_pc_metric_names.copy()
pc_metrics.remove("nn_isolation")
pc_metrics.remove("nn_noise_overlap")
metric_names += pc_metrics
# if spike_locations are not available, drift is removed from the list
if not self.waveform_extractor.is_extension("spike_locations"):
if not self.waveform_extractor.has_extension("spike_locations"):
if "drift" in metric_names:
metric_names.remove("drift")

Expand Down Expand Up @@ -130,7 +130,7 @@ def _run(self, verbose, **job_kwargs):
# metrics based on PCs
pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names]
if len(pc_metric_names) > 0 and not self._params["skip_pc_metrics"]:
if not self.waveform_extractor.is_extension("principal_components"):
if not self.waveform_extractor.has_extension("principal_components"):
raise ValueError("waveform_principal_component must be provied")
pc_extension = self.waveform_extractor.load_extension("principal_components")
pc_metrics = calculate_pc_metrics(
Expand Down Expand Up @@ -216,7 +216,7 @@ def compute_quality_metrics(
metrics: pandas.DataFrame
Data frame with the computed metrics
"""
if load_if_exists and waveform_extractor.is_extension(QualityMetricCalculator.extension_name):
if load_if_exists and waveform_extractor.has_extension(QualityMetricCalculator.extension_name):
qmc = waveform_extractor.load_extension(QualityMetricCalculator.extension_name)
else:
qmc = QualityMetricCalculator(waveform_extractor)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_metrics(self):
we = self.we_long

# avoid NaNs
if we.is_extension("spike_amplitudes"):
if we.has_extension("spike_amplitudes"):
we.delete_extension("spike_amplitudes")

# without PC
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def check_extensions(waveform_extractor, extensions):
error_msg = ""
raise_error = False
for extension in extensions:
if not waveform_extractor.is_extension(extension):
if not waveform_extractor.has_extension(extension):
raise_error = True
error_msg += (
f"The {extension} waveform extension is required for this widget. "
Expand Down
10 changes: 5 additions & 5 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
fig = self.figure
nrows = 2
ncols = 3
if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"):
if we.has_extension("correlograms") or we.has_extension("spike_amplitudes"):
ncols += 1
if we.is_extension("spike_amplitudes"):
if we.has_extension("spike_amplitudes"):
nrows += 1
gs = fig.add_gridspec(nrows, ncols)

if we.is_extension("unit_locations"):
if we.has_extension("unit_locations"):
ax1 = fig.add_subplot(gs[:2, 0])
# UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1)
w = UnitLocationsWidget(
Expand Down Expand Up @@ -129,7 +129,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
)
ax3.set_ylabel(None)

if we.is_extension("correlograms"):
if we.has_extension("correlograms"):
ax4 = fig.add_subplot(gs[:2, 3])
AutoCorrelogramsWidget(
we,
Expand All @@ -142,7 +142,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
ax4.set_title(None)
ax4.set_yticks([])

if we.is_extension("spike_amplitudes"):
if we.has_extension("spike_amplitudes"):
ax5 = fig.add_subplot(gs[2, :3])
ax6 = fig.add_subplot(gs[2, 3])
axes = np.array([ax5, ax6])
Expand Down

0 comments on commit e4c5d11

Please sign in to comment.