Skip to content

Commit

Permalink
Make generalised _format_metric_dataframe function
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Dec 11, 2024
1 parent 7b7323e commit 46681f8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
15 changes: 10 additions & 5 deletions src/spikeinterface/curation/model_based_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
from pathlib import Path
import json
import warnings
import re

from spikeinterface.core import SortingAnalyzer
from spikeinterface.curation.train_manual_curation import try_to_get_metrics_from_analyzer, _get_computed_metrics
from spikeinterface.curation.train_manual_curation import (
try_to_get_metrics_from_analyzer,
_get_computed_metrics,
_format_metric_dataframe,
)
from copy import deepcopy


Expand Down Expand Up @@ -96,9 +101,7 @@ def predict_labels(
except:
warnings.warn("Could not find `label_conversion` key in `model_info.json` file")

# Prepare input data
input_data = input_data.map(lambda x: np.nan if np.isinf(x) else x)
input_data = input_data.astype("float32")
input_data = _format_metric_dataframe(input_data)

# Apply classifier
predictions = self.pipeline.predict(input_data)
Expand Down Expand Up @@ -388,7 +391,9 @@ def _load_model_from_folder(model_folder=None, model_name=None, trust_model=Fals
exception_msg = str(e)
# the exception message contains the list of untrusted objects. The following
# search assumes it is the only list in the message.
trusted = re.search(r"\[(.*?)\]", exception_msg).group()
string_list = re.search(r"\[(.*?)\]", exception_msg).group()
trusted = [list_item for list_item in string_list.split("'") if len(list_item) > 2]

model = skio.load(skops_file, trusted=trusted)

model_info_path = folder / "model_info.json"
Expand Down
11 changes: 9 additions & 2 deletions src/spikeinterface/curation/train_manual_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,7 @@ def process_test_data_for_classification(self):
raise KeyError(f"{str(e)}, metrics_list contains invalid metric names")

self.X = self.testing_metrics.reindex(columns=self.metric_names)
self.X = self.X.map(lambda x: np.nan if np.isinf(x) else x)
self.X = self.X.astype("float32")
self.X = _format_metric_dataframe(self.testing_metrics)

def apply_scaling_imputation(self, imputation_strategy, scaling_technique, X_train, X_test, y_train, y_test):
"""Impute and scale the data using the specified techniques."""
Expand Down Expand Up @@ -786,3 +785,11 @@ def check_metric_names_are_the_same(metrics_for_each_analyzer):
if metrics_in_2_but_not_1:
error_message += f"#{i} does not contain {metrics_in_2_but_not_1}, which #{j} does."
raise Exception(error_message)


def _format_metric_dataframe(input_data):

input_data = input_data.map(lambda x: np.nan if np.isinf(x) else x)
input_data = input_data.astype("float32")

return input_data

0 comments on commit 46681f8

Please sign in to comment.