diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py
index b0daa39f36..68832be768 100644
--- a/src/spikeinterface/curation/train_manual_curation.py
+++ b/src/spikeinterface/curation/train_manual_curation.py
@@ -13,7 +13,7 @@ class CurationModelTrainer:
 
     def __init__(
-        self, target_column, output_folder, imputation_strategies=None, scaling_techniques=None, metrics_list=None
+        self, target_column, output_folder, imputation_strategies=None, scaling_techniques=None, metrics_to_use=None
     ):
         from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
 
@@ -30,9 +30,11 @@ def __init__(
         self.imputation_strategies = imputation_strategies
         self.scaling_techniques = scaling_techniques
 
-        if metrics_list is None:
+        if metrics_to_use is None:
             self.metrics_list = self.get_default_metrics_list()
             print("No metrics list provided, using default metrics list (all)")
+        else:
+            self.metrics_list = metrics_to_use
 
         # Check if the output folder exists, and create it if it doesn't
         if not os.path.exists(output_folder):
@@ -67,17 +69,13 @@ def process_test_data_for_classification(self):
         if self.target_column in self.testing_metrics[0].columns:
             # Extract the target variable and features
             self.y = self.testing_metrics[0][self.target_column]
-            self.X = self.testing_metrics[0].drop(columns=self.target_column)
-
-            # Store the initial list of metrics
-            self.metrics_list = list(self.testing_metrics[0].columns)
-            self.metrics_list.remove(self.target_column)
 
-            # Reorder columns to match the initial metrics list
-            self.X = self.X.reindex(columns=self.metrics_list)
+            # Reorder columns to match the initial metrics list;
+            # drops any columns not in the metrics list and fills any missing columns with NaN
+            self.X = self.testing_metrics[0].reindex(columns=self.metrics_list)
 
-            # Ensure that no features are dropped by filling missing values with a placeholder
-            self.X.fillna(0, inplace=True)  # or use a different strategy if appropriate
+            # Fill any NaN values with 0
+            self.X.fillna(0, inplace=True)
         else:
             raise ValueError(f"Target column {self.target_column} not found in testing metrics file")
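
Note: the reindex-then-fillna pattern introduced above relies on standard pandas behaviour: DataFrame.reindex(columns=...) keeps only the listed columns, in order, inserting NaN for any listed column the frame lacks, and fillna(0) then replaces every NaN with 0. Below is a minimal standalone sketch of that behaviour; it is not part of the patch, and the metric names and values are made up for illustration.

import numpy as np
import pandas as pd

metrics_list = ["snr", "isi_violations_ratio", "firing_rate"]  # hypothetical metric names

testing_metrics = pd.DataFrame(
    {
        "label": [1, 0],             # stand-in for the target column; not in metrics_list
        "firing_rate": [5.0, 2.5],
        "snr": [8.1, np.nan],
        "extra_metric": [0.3, 0.7],  # not in metrics_list, so reindex drops it
    }
)

# reindex keeps only the listed columns, in the listed order, and inserts NaN
# for "isi_violations_ratio", which the frame never had
X = testing_metrics.reindex(columns=metrics_list)

# fillna then replaces every NaN (both pre-existing and reindex-created) with 0
X.fillna(0, inplace=True)

print(X)
#    snr  isi_violations_ratio  firing_rate
# 0  8.1                   0.0          5.0
# 1  0.0                   0.0          2.5

This is also why the patched code no longer needs the explicit drop(columns=self.target_column): the target column is simply absent from self.metrics_list, so reindex discards it along with any other non-metric columns.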