Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 21, 2024
1 parent 044983d commit b763922
Showing 1 changed file with 25 additions and 40 deletions.
65 changes: 25 additions & 40 deletions src/spikeinterface/curation/train_manual_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
seed = 42
warnings.filterwarnings("ignore")


class CurationModelTrainer:
def __init__(self, target_column, output_folder, imputation_strategies=None, scaling_techniques=None, metrics_list=None):
def __init__(
self, target_column, output_folder, imputation_strategies=None, scaling_techniques=None, metrics_list=None
):
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler

if imputation_strategies is None:
Expand Down Expand Up @@ -49,12 +52,12 @@ def load_data_file(self, path):

def process_test_data_for_classification(self):
import pandas as pd

if self.target_column in self.testing_metrics[0].columns:
self.y = self.testing_metrics[0][self.target_column]

self.X = self.testing_metrics[0].dropna(subset=[self.target_column])
self.X.drop(columns = self.target_column, inplace=True)
self.X.drop(columns=self.target_column, inplace=True)

self.metrics_list = self.X.columns

Expand All @@ -63,7 +66,9 @@ def process_test_data_for_classification(self):
self.X = self.X[self.metrics_list]
except KeyError:
# Work out and print which metrics are missing
missing_metrics = [metric for metric in self.metrics_list if metric not in self.testing_metrics[0].columns]
missing_metrics = [
metric for metric in self.metrics_list if metric not in self.testing_metrics[0].columns
]
raise ValueError(f"Metrics list contains uncomputed metrics: {missing_metrics}")
else:
raise ValueError(f"Target column {self.target_column} not found in testing metrics file")
Expand Down Expand Up @@ -153,7 +158,7 @@ def get_classifier_search_space(self, classifier):
"solver": ["adam"],
"alpha": [1e-7, 1e-1],
"learning_rate": ["constant", "adaptive"],
"n_iter_no_change": [32]
"n_iter_no_change": [32],
}
model = MLPClassifier(random_state=seed)
else:
Expand All @@ -162,33 +167,16 @@ def get_classifier_search_space(self, classifier):

# TODO: sort out function naming so things are actually clear
# E.g. evaluate_model_config, _evaluate, _train_and_evaluate - what do they all actually do?
def evaluate_model_config(
self, imputation_strategies, scaling_techniques, classifiers):
def evaluate_model_config(self, imputation_strategies, scaling_techniques, classifiers):

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(self.X, self.y, test_size=0.2, random_state=seed, stratify=self.y)
self._evaluate(
imputation_strategies,
scaling_techniques,
classifiers,
X_train,
X_test,
y_train,
y_test

X_train, X_test, y_train, y_test = train_test_split(
self.X, self.y, test_size=0.2, random_state=seed, stratify=self.y
)
self._evaluate(imputation_strategies, scaling_techniques, classifiers, X_train, X_test, y_train, y_test)


def _evaluate(
self,
imputation_strategies,
scaling_techniques,
classifiers,
X_train,
X_test,
y_train,
y_test
):
def _evaluate(self, imputation_strategies, scaling_techniques, classifiers, X_train, X_test, y_train, y_test):

from joblib import Parallel, delayed
from sklearn.pipeline import Pipeline
Expand Down Expand Up @@ -291,18 +279,15 @@ def train_model(metrics_path, output_folder, target_label):
MLPClassifier,
]

trainer = CurationModelTrainer(target_label,
output_folder,
imputation_strategies=imputation_strategies,
scaling_techniques=scaling_techniques,
metrics_list=None
)

trainer = CurationModelTrainer(
target_label,
output_folder,
imputation_strategies=imputation_strategies,
scaling_techniques=scaling_techniques,
metrics_list=None,
)

trainer.load_and_preprocess_full(metrics_path)

trainer.evaluate_model_config(
imputation_strategies,
scaling_techniques,
classifiers
)
trainer.evaluate_model_config(imputation_strategies, scaling_techniques, classifiers)
return trainer

0 comments on commit b763922

Please sign in to comment.