diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index c97a727340..a81d36139d 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -506,7 +506,7 @@ def get_recording_property(self, key) -> np.ndarray: def get_sorting_property(self, key) -> np.ndarray: return self.sorting.get_property(key) - def get_extension_class(self, extension_name): + def get_extension_class(self, extension_name: str): """ Get extension class from name and check if registered. @@ -525,7 +525,7 @@ def get_extension_class(self, extension_name): ext_class = extensions_dict[extension_name] return ext_class - def is_extension(self, extension_name) -> bool: + def has_extension(self, extension_name: str) -> bool: """ Check if the extension exists in memory or in the folder. @@ -556,7 +556,15 @@ def is_extension(self, extension_name) -> bool: and "params" in self._waveforms_root[extension_name].attrs.keys() ) - def load_extension(self, extension_name): + def is_extension(self, extension_name) -> bool: + warn( + "WaveformExtractor.is_extension is deprecated and will be removed in version 0.102.0! Use `has_extension` instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.has_extension(extension_name) + + def load_extension(self, extension_name: str): """ Load an extension from its name. The module of the extension must be loaded and registered. @@ -572,7 +580,7 @@ def load_extension(self, extension_name): The loaded instance of the extension """ if self.folder is not None and extension_name not in self._loaded_extensions: - if self.is_extension(extension_name): + if self.has_extension(extension_name): ext_class = self.get_extension_class(extension_name) ext = ext_class.load(self.folder, self) if extension_name not in self._loaded_extensions: @@ -588,7 +596,7 @@ def delete_extension(self, extension_name) -> None: extension_name: str The extension name. """ - assert self.is_extension(extension_name), f"The extension {extension_name} is not available" + assert self.has_extension(extension_name), f"The extension {extension_name} is not available" del self._loaded_extensions[extension_name] if self.folder is not None and (self.folder / extension_name).is_dir(): shutil.rmtree(self.folder / extension_name) @@ -610,7 +618,7 @@ def get_available_extension_names(self): """ extension_names_in_folder = [] for extension_class in self.extensions: - if self.is_extension(extension_class.extension_name): + if self.has_extension(extension_class.extension_name): extension_names_in_folder.append(extension_class.extension_name) return extension_names_in_folder diff --git a/src/spikeinterface/exporters/report.py b/src/spikeinterface/exporters/report.py index 57a5ab0166..8b14930859 100644 --- a/src/spikeinterface/exporters/report.py +++ b/src/spikeinterface/exporters/report.py @@ -51,7 +51,7 @@ def export_report( unit_ids = sorting.unit_ids # load or compute spike_amplitudes - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): spike_amplitudes = we.load_extension("spike_amplitudes").get_data(outputs="by_unit") elif force_computation: spike_amplitudes = compute_spike_amplitudes(we, peak_sign=peak_sign, outputs="by_unit", **job_kwargs) @@ -62,7 +62,7 @@ def export_report( ) # load or compute quality_metrics - if we.is_extension("quality_metrics"): + if we.has_extension("quality_metrics"): metrics = we.load_extension("quality_metrics").get_data() elif force_computation: metrics = compute_quality_metrics(we) @@ -73,7 +73,7 @@ def export_report( ) # load or compute correlograms - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): correlograms, bins = we.load_extension("correlograms").get_data() elif force_computation: correlograms, bins = compute_correlograms(we, window_ms=100.0, bin_ms=1.0) @@ -84,7 +84,7 @@ def export_report( ) # pre-compute unit locations if not done - if not we.is_extension("unit_locations"): + if not we.has_extension("unit_locations"): unit_locations = compute_unit_locations(we) output_folder = Path(output_folder).absolute() diff --git a/src/spikeinterface/exporters/to_phy.py b/src/spikeinterface/exporters/to_phy.py index ecc5b316ec..607aa3e846 100644 --- a/src/spikeinterface/exporters/to_phy.py +++ b/src/spikeinterface/exporters/to_phy.py @@ -196,7 +196,7 @@ def export_to_phy( templates[unit_ind, :, :][:, : len(chan_inds)] = template templates_ind[unit_ind, : len(chan_inds)] = chan_inds - if waveform_extractor.is_extension("similarity"): + if waveform_extractor.has_extension("similarity"): tmc = waveform_extractor.load_extension("similarity") template_similarity = tmc.get_data() else: @@ -219,7 +219,7 @@ def export_to_phy( np.save(str(output_folder / "channel_groups.npy"), channel_groups) if compute_amplitudes: - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): sac = waveform_extractor.load_extension("spike_amplitudes") amplitudes = sac.get_data(outputs="concatenated") else: @@ -231,7 +231,7 @@ def export_to_phy( np.save(str(output_folder / "amplitudes.npy"), amplitudes) if compute_pc_features: - if waveform_extractor.is_extension("principal_components"): + if waveform_extractor.has_extension("principal_components"): pc = waveform_extractor.load_extension("principal_components") else: pc = compute_principal_components( @@ -264,7 +264,7 @@ def export_to_phy( channel_group = pd.DataFrame({"cluster_id": [i for i in range(len(unit_ids))], "channel_group": unit_groups}) channel_group.to_csv(output_folder / "cluster_channel_group.tsv", sep="\t", index=False) - if waveform_extractor.is_extension("quality_metrics"): + if waveform_extractor.has_extension("quality_metrics"): qm = waveform_extractor.load_extension("quality_metrics") qm_data = qm.get_data() for column_name in qm_data.columns: diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index cf32e79b25..effd87007f 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -750,7 +750,7 @@ def compute_principal_components( >>> pc.run_for_all_spikes(file_path="all_pca_projections.npy") """ - if load_if_exists and waveform_extractor.is_extension(WaveformPrincipalComponent.extension_name): + if load_if_exists and waveform_extractor.has_extension(WaveformPrincipalComponent.extension_name): pc = waveform_extractor.load_extension(WaveformPrincipalComponent.extension_name) else: pc = WaveformPrincipalComponent.create(waveform_extractor) diff --git a/src/spikeinterface/postprocessing/tests/common_extension_tests.py b/src/spikeinterface/postprocessing/tests/common_extension_tests.py index b539bbd5d4..2bef246bc2 100644 --- a/src/spikeinterface/postprocessing/tests/common_extension_tests.py +++ b/src/spikeinterface/postprocessing/tests/common_extension_tests.py @@ -143,7 +143,7 @@ def _test_extension_folder(self, we, in_memory=False): # reload as an extension from we assert self.extension_class.extension_name in we.get_available_extension_names() - assert we.is_extension(self.extension_class.extension_name) + assert we.has_extension(self.extension_class.extension_name) ext = we.load_extension(self.extension_class.extension_name) assert isinstance(ext, self.extension_class) for ext_name in self.extension_data_names: diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 49591d9b89..f5e315b18f 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -135,7 +135,7 @@ def test_project_new(self): from sklearn.decomposition import IncrementalPCA we = self.we1 - if we.is_extension("principal_components"): + if we.has_extension("principal_components"): we.delete_extension("principal_components") we_cp = we.select_units(we.unit_ids, self.cache_folder / "toy_waveforms_1seg_cp") diff --git a/src/spikeinterface/qualitymetrics/misc_metrics.py b/src/spikeinterface/qualitymetrics/misc_metrics.py index 5c734b9100..9dab06124b 100644 --- a/src/spikeinterface/qualitymetrics/misc_metrics.py +++ b/src/spikeinterface/qualitymetrics/misc_metrics.py @@ -201,7 +201,7 @@ def compute_snrs( snrs : dict Computed signal to noise ratio for each unit. """ - if waveform_extractor.is_extension("noise_levels"): + if waveform_extractor.has_extension("noise_levels"): noise_levels = waveform_extractor.load_extension("noise_levels").get_data() else: if random_chunk_kwargs_dict is None: @@ -687,7 +687,7 @@ def compute_amplitude_cv_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.is_extension(amplitude_extension): + if waveform_extractor.has_extension(amplitude_extension): sac = waveform_extractor.load_extension(amplitude_extension) amps = sac.get_data(outputs="concatenated") if amplitude_extension == "spike_amplitudes": @@ -803,7 +803,7 @@ def compute_amplitude_cutoffs( spike_amplitudes = None invert_amplitudes = False - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): amp_calculator = waveform_extractor.load_extension("spike_amplitudes") spike_amplitudes = amp_calculator.get_data(outputs="by_unit") if amp_calculator._params["peak_sign"] == "pos": @@ -881,7 +881,7 @@ def compute_amplitude_medians(waveform_extractor, peak_sign="neg", unit_ids=None extremum_channels_ids = get_template_extremum_channel(waveform_extractor, peak_sign=peak_sign) spike_amplitudes = None - if waveform_extractor.is_extension("spike_amplitudes"): + if waveform_extractor.has_extension("spike_amplitudes"): amp_calculator = waveform_extractor.load_extension("spike_amplitudes") spike_amplitudes = amp_calculator.get_data(outputs="by_unit") @@ -974,7 +974,7 @@ def compute_drift_metrics( if unit_ids is None: unit_ids = sorting.unit_ids - if waveform_extractor.is_extension("spike_locations"): + if waveform_extractor.has_extension("spike_locations"): locs_calculator = waveform_extractor.load_extension("spike_locations") spike_locations = locs_calculator.get_data(outputs="concatenated") spike_locations_by_unit = locs_calculator.get_data(outputs="by_unit") diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 53309db282..90a1f7206e 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -42,14 +42,14 @@ def _set_params( if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) # if PC is available, PC metrics are automatically added to the list - if self.waveform_extractor.is_extension("principal_components"): + if self.waveform_extractor.has_extension("principal_components"): # by default 'nearest_neightbor' is removed because too slow pc_metrics = _possible_pc_metric_names.copy() pc_metrics.remove("nn_isolation") pc_metrics.remove("nn_noise_overlap") metric_names += pc_metrics # if spike_locations are not available, drift is removed from the list - if not self.waveform_extractor.is_extension("spike_locations"): + if not self.waveform_extractor.has_extension("spike_locations"): if "drift" in metric_names: metric_names.remove("drift") @@ -130,7 +130,7 @@ def _run(self, verbose, **job_kwargs): # metrics based on PCs pc_metric_names = [k for k in metric_names if k in _possible_pc_metric_names] if len(pc_metric_names) > 0 and not self._params["skip_pc_metrics"]: - if not self.waveform_extractor.is_extension("principal_components"): + if not self.waveform_extractor.has_extension("principal_components"): raise ValueError("waveform_principal_component must be provied") pc_extension = self.waveform_extractor.load_extension("principal_components") pc_metrics = calculate_pc_metrics( @@ -216,7 +216,7 @@ def compute_quality_metrics( metrics: pandas.DataFrame Data frame with the computed metrics """ - if load_if_exists and waveform_extractor.is_extension(QualityMetricCalculator.extension_name): + if load_if_exists and waveform_extractor.has_extension(QualityMetricCalculator.extension_name): qmc = waveform_extractor.load_extension(QualityMetricCalculator.extension_name) else: qmc = QualityMetricCalculator(waveform_extractor) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index eb8317e4df..b601e5d6d8 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -88,7 +88,7 @@ def test_metrics(self): we = self.we_long # avoid NaNs - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): we.delete_extension("spike_amplitudes") # without PC diff --git a/src/spikeinterface/widgets/base.py b/src/spikeinterface/widgets/base.py index a5d3cb2429..6ff837065b 100644 --- a/src/spikeinterface/widgets/base.py +++ b/src/spikeinterface/widgets/base.py @@ -107,7 +107,7 @@ def check_extensions(waveform_extractor, extensions): error_msg = "" raise_error = False for extension in extensions: - if not waveform_extractor.is_extension(extension): + if not waveform_extractor.has_extension(extension): raise_error = True error_msg += ( f"The {extension} waveform extension is required for this widget. " diff --git a/src/spikeinterface/widgets/unit_summary.py b/src/spikeinterface/widgets/unit_summary.py index 35fde07326..aa280ad658 100644 --- a/src/spikeinterface/widgets/unit_summary.py +++ b/src/spikeinterface/widgets/unit_summary.py @@ -80,13 +80,13 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): fig = self.figure nrows = 2 ncols = 3 - if we.is_extension("correlograms") or we.is_extension("spike_amplitudes"): + if we.has_extension("correlograms") or we.has_extension("spike_amplitudes"): ncols += 1 - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): nrows += 1 gs = fig.add_gridspec(nrows, ncols) - if we.is_extension("unit_locations"): + if we.has_extension("unit_locations"): ax1 = fig.add_subplot(gs[:2, 0]) # UnitLocationsPlotter().do_plot(dp.plot_data_unit_locations, ax=ax1) w = UnitLocationsWidget( @@ -129,7 +129,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ) ax3.set_ylabel(None) - if we.is_extension("correlograms"): + if we.has_extension("correlograms"): ax4 = fig.add_subplot(gs[:2, 3]) AutoCorrelogramsWidget( we, @@ -142,7 +142,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): ax4.set_title(None) ax4.set_yticks([]) - if we.is_extension("spike_amplitudes"): + if we.has_extension("spike_amplitudes"): ax5 = fig.add_subplot(gs[2, :3]) ax6 = fig.add_subplot(gs[2, 3]) axes = np.array([ax5, ax6])