
Commit 108f09a
Removed required_metrics arg from prediction class & func
jakeswann1 committed Jun 21, 2024
1 parent b763922 commit 108f09a
Showing 2 changed files with 13 additions and 20 deletions.
19 changes: 7 additions & 12 deletions src/spikeinterface/curation/model_based_curation.py
@@ -16,17 +16,15 @@ class ModelBasedClassification:
         The sorting analyzer object containing the spike sorting data.
     pipeline : Pipeline
         The pipeline object representing the trained classification model.
-    required_metrics : Sequence[str]
-        The list of required metrics for classification.

     Attributes
     ----------
     sorting_analyzer : SortingAnalyzer
         The sorting analyzer object containing the spike sorting data.
-    required_metrics : Sequence[str]
-        The list of required metrics for classification.
     pipeline : Pipeline
         The pipeline object representing the trained classification model.
+    required_metrics : Sequence[str]
+        The list of required metrics for classification, extracted from the pipeline.

     Methods
     -------
@@ -38,15 +36,16 @@ class ModelBasedClassification:
         Checks if the parameters for classification match the training parameters.
     """

-    def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline, required_metrics: Sequence[str]):
+    def __init__(self, sorting_analyzer: SortingAnalyzer, pipeline):
         from sklearn.pipeline import Pipeline

         if not isinstance(pipeline, Pipeline):
             raise ValueError("The pipeline must be an instance of sklearn.pipeline.Pipeline")

         self.sorting_analyzer = sorting_analyzer
-        self.required_metrics = required_metrics
         self.pipeline = pipeline
+        self.required_metrics = pipeline.feature_names_in_
+

     def predict_labels(self):
         """
@@ -73,8 +72,6 @@ def predict_labels(self):
         predictions = self.pipeline.predict(input_data)
         probabilities = self.pipeline.predict_proba(input_data)

-        # TODO: add feature importance?
-
         # Make output dict with {unit_id: (prediction, probability)}
         classified_units = {
             unit_id: (prediction, probability)
@@ -134,7 +131,7 @@ def _check_params_for_classification(self):
     # This would need to account for the fact that these extensions may no longer exist


-def auto_label_units(sorting_analyzer: SortingAnalyzer, pipeline, required_metrics: Sequence[str]):
+def auto_label_units(sorting_analyzer: SortingAnalyzer, pipeline):
     """
     Automatically labels units based on a model-based classification.

@@ -144,8 +141,6 @@ def auto_label_units(sorting_analyzer: SortingAnalyzer, pipeline, required_metrics: Sequence[str]):
         The sorting analyzer object containing the spike sorting results.
     pipeline : Pipeline
         The pipeline object containing the model-based classification pipeline.
-    required_metrics : Sequence[str]
-        The list of required metrics used for classification.

     Returns
     -------
@@ -159,7 +154,7 @@ def auto_label_units(sorting_analyzer: SortingAnalyzer, pipeline, required_metrics: Sequence[str]):
     if not isinstance(pipeline, Pipeline):
         raise ValueError("The pipeline must be an instance of sklearn.pipeline.Pipeline")

-    model_based_classification = ModelBasedClassification(sorting_analyzer, pipeline, required_metrics)
+    model_based_classification = ModelBasedClassification(sorting_analyzer, pipeline)

     classified_units = model_based_classification.predict_labels()

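Why the constructor can drop the argument — a minimal sketch, not from this repository, assuming the model was fitted on a pandas DataFrame whose columns are the quality-metric names (scikit-learn only sets feature_names_in_ in that case):

# Sketch of the scikit-learn behavior the new __init__ relies on; the data
# below is hypothetical and only illustrates the attribute.
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Training table keyed by quality-metric names (made-up values)
X = pd.DataFrame({"num_spikes": [120, 4500, 80], "half_width": [0.30, 0.25, 0.60]})
y = [0, 1, 0]

pipeline = Pipeline([("scale", StandardScaler()), ("classify", LogisticRegression())])
pipeline.fit(X, y)

print(pipeline.feature_names_in_)  # ['num_spikes' 'half_width']

A pipeline fitted on a bare NumPy array never gets feature_names_in_, so the new __init__ would raise AttributeError for such a model; training on named metric columns is an implicit requirement of this change.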
14 changes: 6 additions & 8 deletions src/spikeinterface/curation/tests/test_model_based_curation.py
@@ -35,19 +35,17 @@ def required_metrics():
     return ["num_spikes", "half_width"]


-def test_model_based_classification_init(sorting_analyzer_for_curation, pipeline, required_metrics):
+def test_model_based_classification_init(sorting_analyzer_for_curation, pipeline):
     # Test the initialization of ModelBasedClassification
-    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, pipeline, required_metrics)
+    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, pipeline)
     assert model_based_classification.sorting_analyzer == sorting_analyzer_for_curation
     assert model_based_classification.pipeline == pipeline
-    assert model_based_classification.required_metrics == required_metrics
-


 def test_model_based_classification_get_metrics_for_classification(
     sorting_analyzer_for_curation, pipeline, required_metrics
 ):
     # Test the _get_metrics_for_classification() method of ModelBasedClassification
-    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, pipeline, required_metrics)
+    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, pipeline)
     # Check that ValueError is returned when quality_metrics are not present in sorting_analyzer
     with pytest.raises(ValueError):
@@ -74,7 +72,7 @@ def test_model_based_classification_check_params_for_classification(
     sorting_analyzer_for_curation = make_sorting_analyzer()

     # Test the _check_params_for_classification() method of ModelBasedClassification
-    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, pipeline, required_metrics)
+    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, pipeline)
     # Check that ValueError is raised when required_metrics are not computed
     with pytest.raises(ValueError):
         model_based_classification._check_params_for_classification()
@@ -87,9 +85,9 @@


 # TODO: fix this test
-def test_model_based_classification_predict_labels(sorting_analyzer_for_curation, pipeline, required_metrics):
+def test_model_based_classification_predict_labels(sorting_analyzer_for_curation, pipeline):
     # Test the predict_labels() method of ModelBasedClassification
-    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, pipeline, required_metrics)
+    model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, pipeline)
     classified_units = model_based_classification.predict_labels()
     # TODO: check that classifications match some known set of outputs
     predictions = [classified_units[i][0] for i in classified_units]
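End-to-end, the public call now takes one fewer argument. A usage sketch under the same assumptions (sorting_analyzer and the fitted pipeline come from earlier setup, and the analyzer has the pipeline's required quality metrics computed):

# Hypothetical call pattern after this commit; `sorting_analyzer` and
# `pipeline` are assumed to exist from prior setup.
from spikeinterface.curation.model_based_curation import auto_label_units

# required_metrics is no longer passed; it is read from pipeline.feature_names_in_
classified_units = auto_label_units(sorting_analyzer, pipeline)
# -> {unit_id: (prediction, probability), ...} as built by predict_labels()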
