diff --git a/examples/tutorials/curation/plot_2_train_a_model.py b/examples/tutorials/curation/plot_2_train_a_model.py
index 4b79142864..a5186f7794 100644
--- a/examples/tutorials/curation/plot_2_train_a_model.py
+++ b/examples/tutorials/curation/plot_2_train_a_model.py
@@ -77,6 +77,7 @@
     imputation_strategies = ["median"], # Defaults to all
     scaling_techniques = ["standard_scaler"], # Defaults to all
     classifiers = None, # Default to Random Forest only. Other classifiers you can try [ "AdaBoostClassifier","GradientBoostingClassifier","LogisticRegression","MLPClassifier"]
+    overwrite = True, # Whether or not to overwrite `output_folder` if it already exists. Default is False.
 )
 
 best_model = trainer.best_pipeline
diff --git a/examples/tutorials/curation/plot_3_upload_a_model.py b/examples/tutorials/curation/plot_3_upload_a_model.py
index b6afd1c3e4..abccda2c94 100644
--- a/examples/tutorials/curation/plot_3_upload_a_model.py
+++ b/examples/tutorials/curation/plot_3_upload_a_model.py
@@ -22,60 +22,16 @@
 #     metadata.json
 #
 # SpikeInterface doesn't require you to keep this folder structure, we just advise it as
-# best practice. So: let's make these files!
+# best practice.
 #
 # If you've used SpikeInterface to train your model, you have already created such a folder,
-# containing ``my_model_name.skops`` and ``model_info.json``.
-#
-# We'll now export the training data. If we trained the model using the function...
-#
-# .. code-block::
-#
-#     my_model = train_model(
-#         labels=list_of_labels,
-#         analyzers=[sorting_analyzer_1, sorting_analyzer_2],
-#         output_folder = "my_model_folder"
-#     )
-#
-# ...then the training data is contained in your metric extensions. If you've calculated both
-# ``quality_metrics`` and ``template_metrics`` for two `sorting_analyzer`s, you can extract
-# and export the training data as follows
-#
-# .. code-block::
-#
-#     import pandas as pd
-#
-#     training_data_1 = pd.concat([
-#         sorting_analyzer_1.get_extension("quality_metrics").get_data(),
-#         sorting_analyzer_1.get_extension("template_metrics").get_data()
-#     ],axis=1)
-#
-#     training_data_2 = pd.concat([
-#         sorting_analyzer_2.get_extension("quality_metrics").get_data(),
-#         sorting_analyzer_2.get_extension("template_metrics").get_data()
-#     ],axis=1)
-#
-#     training_data = pd.concat([
-#         training_data_1,
-#         training_data_2
-#     ])
-#
-#     training_data.to_csv("my_model_folder/training_data.csv")
-#
-# If you have used different metrics, or use a `.csv` files you will need to modify this code.
-#
-# Similarly, we can save the curated labels we used:
-#
-# .. code-block::
-#
-#     list_of_labels.to_csv("my_model_folder/labels.csv")
-#
-# Finally, we suggest adding any information which shows when a model is applicable
-# (and when it is *not*). Taking a model trained on mouse data and applying it to a primate is
-# likely a bad idea. And a model trained in tetrode data will have limited application on a silcone
-# high density probe. Hence we suggest the following dictionary as a minimal amount of information
-# needed. Note that we format the metadata so that the information in common with the NWB data
-# format is consistent with it,
+# containing everything except the ``metadata.json`` file. In this file, we suggest saving
+# any information which shows when a model is applicable (and when it is *not*). Taking
+# a model trained on mouse data and applying it to a primate is likely a bad idea (or a
+# great research paper!). And a model trained on tetrode data will have limited application
+# on a silicone high density probe. Hence we suggest the following dictionary as a minimal
+# amount of information needed. Note that we format the metadata so that the information
+# in common with the NWB data format is consistent with it,
 #
 # .. code-block::
 #
diff --git a/src/spikeinterface/curation/tests/test_train_manual_curation.py b/src/spikeinterface/curation/tests/test_train_manual_curation.py
index 47a957ec72..7ccdd2635e 100644
--- a/src/spikeinterface/curation/tests/test_train_manual_curation.py
+++ b/src/spikeinterface/curation/tests/test_train_manual_curation.py
@@ -114,5 +114,6 @@ def test_train_model():
         imputation_strategies=["median"],
         scaling_techniques=["standard_scaler"],
         classifiers=["LogisticRegression"],
+        overwrite=True,
     )
     assert isinstance(trainer, CurationModelTrainer)
diff --git a/src/spikeinterface/curation/train_manual_curation.py b/src/spikeinterface/curation/train_manual_curation.py
index 6223fd1c4c..a8dc00b070 100644
--- a/src/spikeinterface/curation/train_manual_curation.py
+++ b/src/spikeinterface/curation/train_manual_curation.py
@@ -480,6 +480,11 @@ def _evaluate(self, imputation_strategies, scaling_techniques, classifiers, X_tr
     def _save(self):
         from skops.io import dump
         import sklearn
+        import pandas as pd
+
+        # export training data and labels
+        pd.DataFrame(self.X).to_csv(os.path.join(self.output_folder, "training_data.csv"), index_label="unit_id")
+        pd.DataFrame(self.y).to_csv(os.path.join(self.output_folder, "labels.csv"), index_label="unit_index")
 
         self.requirements["scikit-learn"] = sklearn.__version__
 
@@ -550,6 +555,7 @@ def train_model(
     imputation_strategies=None,
     scaling_techniques=None,
     classifiers=None,
+    overwrite=False,
     seed=None,
     **job_kwargs,
 ):
@@ -580,6 +586,8 @@
         A list of scaling techniques to apply. If None, default techniques will be used.
     classifiers : list of str | dict | None, default: None
         A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_default_classifier_search_spaces` method for the default search spaces & format for custom spaces.
+    overwrite : bool, default: False
+        Overwrites the `output_folder` if it already exists.
     seed : int | None, default: None
         Random seed for reproducibility. If None, a random seed will be generated.
 
@@ -594,6 +602,11 @@
     and evaluating the models. The evaluation results are saved to the specified output folder.
 
     """
+    if overwrite is False:
+        assert not Path(
+            output_folder
+        ).exists(), f"folder {output_folder} already exists, choose another name or use overwrite=True"
+
     if labels is None:
         raise Exception("You must supply a list of curated labels using `labels = ...`")
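For context, here is a rough usage sketch of the behaviour added above; it is not part of the patch. It assumes `train_model` is importable from `spikeinterface.curation` as used in the tutorials, and `curated_labels` / `sorting_analyzer` are placeholders for your own labels and `SortingAnalyzer`. Only the `labels`, `analyzers`, `output_folder` and `overwrite` arguments shown in the diff are used.

```python
import pandas as pd

# assumed import path, matching the curation tutorials
from spikeinterface.curation import train_model

# First run: "my_model_folder" must not exist yet (overwrite defaults to False).
# Alongside the model and model_info.json, _save() now also writes
# training_data.csv and labels.csv into the folder.
trainer = train_model(
    labels=[curated_labels],          # placeholder: your curated unit labels
    analyzers=[sorting_analyzer],     # placeholder: a SortingAnalyzer with metrics computed
    output_folder="my_model_folder",
)

# Re-running into the same folder without overwrite=True trips the new assertion;
# passing overwrite=True lets the folder be reused.
trainer = train_model(
    labels=[curated_labels],
    analyzers=[sorting_analyzer],
    output_folder="my_model_folder",
    overwrite=True,
)

# The exported training data and labels can be read back with pandas,
# e.g. when assembling the model folder described in plot_3_upload_a_model.py.
training_data = pd.read_csv("my_model_folder/training_data.csv", index_col="unit_id")
label_table = pd.read_csv("my_model_folder/labels.csv", index_col="unit_index")
```

Because the trainer now exports `training_data.csv` and `labels.csv` itself, the manual pandas export previously shown in `plot_3_upload_a_model.py` is no longer needed, which is why that tutorial section is removed here.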