Skip to content

Commit

Permalink
calculate_pc_metrics -> compute_pc_metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed May 29, 2024
1 parent 790715c commit 0e99c4c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 7 deletions.
32 changes: 30 additions & 2 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,14 @@ def get_quality_pca_metric_list():
return deepcopy(_possible_pc_metric_names)


def calculate_pc_metrics(
sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False
def compute_pc_metrics(
sorting_analyzer,
metric_names=None,
qm_params=None,
unit_ids=None,
seed=None,
n_jobs=1,
progress_bar=False,
):
"""Calculate principal component derived metrics.
Expand Down Expand Up @@ -180,6 +186,28 @@ def calculate_pc_metrics(
return pc_metrics


def calculate_pc_metrics(
sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False
):
warnings.warn(
"The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. Please use compute_pc_metrics instead",
category=DeprecationWarning,
stacklevel=2,
)

pc_metrics = compute_pc_metrics(
sorting_analyzer,
metric_names=metric_names,
qm_params=qm_params,
unit_ids=unit_ids,
seed=seed,
n_jobs=n_jobs,
progress_bar=progress_bar,
)

return pc_metrics


#################################################################
# Code from spikemetrics

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension


from .quality_metric_list import calculate_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names
from .quality_metric_list import compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names
from .misc_metrics import _default_params as misc_metrics_params
from .pca_metrics import _default_params as pca_metrics_params

Expand Down Expand Up @@ -143,7 +143,7 @@ def _run(self, verbose=False, **job_kwargs):
if len(pc_metric_names) > 0 and not self.params["skip_pc_metrics"]:
if not self.sorting_analyzer.has_extension("principal_components"):
raise ValueError("waveform_principal_component must be provied")
pc_metrics = calculate_pc_metrics(
pc_metrics = compute_pc_metrics(
self.sorting_analyzer,
unit_ids=non_empty_unit_ids,
metric_names=pc_metric_names,
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/qualitymetrics/tests/test_pca_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from spikeinterface.qualitymetrics.utils import create_ground_truth_pc_distributions

from spikeinterface.qualitymetrics import (
calculate_pc_metrics,
compute_pc_metrics,
nearest_neighbors_isolation,
nearest_neighbors_noise_overlap,
)
Expand Down Expand Up @@ -56,10 +56,10 @@ def sorting_analyzer_simple():

def test_calculate_pc_metrics(sorting_analyzer_simple):
sorting_analyzer = sorting_analyzer_simple
res1 = calculate_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True)
res1 = compute_pc_metrics(sorting_analyzer, n_jobs=1, progress_bar=True)
res1 = pd.DataFrame(res1)

res2 = calculate_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True)
res2 = compute_pc_metrics(sorting_analyzer, n_jobs=2, progress_bar=True)
res2 = pd.DataFrame(res2)

for k in res1.columns:
Expand Down

0 comments on commit 0e99c4c

Please sign in to comment.