Commit a4f3795

Add autosave of training_data and labels, and add overwrite

chrishalcrow committed Sep 20, 2024
1 parent 2d68bba commit a4f3795
Showing 4 changed files with 23 additions and 52 deletions.
1 change: 1 addition & 0 deletions examples/tutorials/curation/plot_2_train_a_model.py
@@ -77,6 +77,7 @@
    imputation_strategies = ["median"], # Defaults to all
    scaling_techniques = ["standard_scaler"], # Defaults to all
    classifiers = None, # Defaults to Random Forest only. Other classifiers to try: ["AdaBoostClassifier", "GradientBoostingClassifier", "LogisticRegression", "MLPClassifier"]
    overwrite = True, # Whether or not to overwrite `output_folder` if it already exists. Default is False.
)

best_model = trainer.best_pipeline
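
# Note: with the default ``overwrite=False``, re-running ``train_model`` with the
# same ``output_folder`` raises an AssertionError because the folder already
# exists; pass ``overwrite=True`` (as above) to replace it.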
60 changes: 8 additions & 52 deletions examples/tutorials/curation/plot_3_upload_a_model.py
@@ -22,60 +22,16 @@
# metadata.json
#
# SpikeInterface doesn't require you to keep this folder structure; we just advise it as
# best practice. So: let's make these files!
# best practice.
#
# If you've used SpikeInterface to train your model, you have already created such a folder,
# containing ``my_model_name.skops`` and ``model_info.json``.
#
# We'll now export the training data. If we trained the model using the function...
#
# .. code-block::
#
#     my_model = train_model(
#         labels=list_of_labels,
#         analyzers=[sorting_analyzer_1, sorting_analyzer_2],
#         output_folder="my_model_folder",
#     )
#
# ...then the training data is contained in your metric extensions. If you've calculated both
# ``quality_metrics`` and ``template_metrics`` for two sorting analyzers, you can extract
# and export the training data as follows:
#
# .. code-block::
#
#     import pandas as pd
#
#     training_data_1 = pd.concat([
#         sorting_analyzer_1.get_extension("quality_metrics").get_data(),
#         sorting_analyzer_1.get_extension("template_metrics").get_data()
#     ], axis=1)
#
#     training_data_2 = pd.concat([
#         sorting_analyzer_2.get_extension("quality_metrics").get_data(),
#         sorting_analyzer_2.get_extension("template_metrics").get_data()
#     ], axis=1)
#
#     training_data = pd.concat([
#         training_data_1,
#         training_data_2
#     ])
#
#     training_data.to_csv("my_model_folder/training_data.csv")
#
# If you have used different metrics, or stored them in `.csv` files, you will need to adapt this code.
#
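# For instance, if each recording's metrics are already saved in a `.csv` file
# (the file names here are hypothetical), a minimal sketch might be:
#
# .. code-block::
#
#     import pandas as pd
#
#     # assumed layout: one metrics file per recording, rows indexed by unit id
#     training_data = pd.concat([
#         pd.read_csv("metrics_recording_1.csv", index_col=0),
#         pd.read_csv("metrics_recording_2.csv", index_col=0),
#     ])
#     training_data.to_csv("my_model_folder/training_data.csv")
#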
# Similarly, we can save the curated labels we used:
#
# .. code-block::
#
#     pd.DataFrame(list_of_labels).to_csv("my_model_folder/labels.csv")
#
# Finally, we suggest adding any information which shows when a model is applicable
# (and when it is *not*). Taking a model trained on mouse data and applying it to a primate is
# likely a bad idea. And a model trained on tetrode data will have limited application on a silicon
# high density probe. Hence we suggest the following dictionary as a minimal amount of information
# needed. Note that we format the metadata so that the information in common with the NWB data
# format is consistent with it,
# containing everything except the ``metadata.json`` file. In this file, we suggest saving
# any information which shows when a model is applicable (and when it is *not*). Taking
# a model trained on mouse data and applying it to a primate is likely a bad idea (or a
# great research paper!). And a model trained on tetrode data will have limited application
# on a silicon high-density probe. Hence we suggest the following dictionary as a minimal
# amount of information needed. Note that we format the metadata so that the information
# in common with the NWB data format is consistent with it:
#
# .. code-block::
#
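#     # illustrative sketch only: these key names are assumptions chosen to
#     # mirror NWB-style metadata fields, not a fixed schema
#     model_metadata = {
#         "subject_species": "Mus musculus",
#         "brain_areas": ["CA1"],
#         "probe_type": "Neuropixels 1.0",
#     }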
@@ -114,5 +114,6 @@ def test_train_model():
        imputation_strategies=["median"],
        scaling_techniques=["standard_scaler"],
        classifiers=["LogisticRegression"],
        overwrite=True,
    )
    assert isinstance(trainer, CurationModelTrainer)
13 changes: 13 additions & 0 deletions src/spikeinterface/curation/train_manual_curation.py
@@ -480,6 +480,11 @@ def _evaluate(self, imputation_strategies, scaling_techniques, classifiers, X_tr
    def _save(self):
        from skops.io import dump
        import sklearn
        import pandas as pd

        # export training data and labels
        pd.DataFrame(self.X).to_csv(os.path.join(self.output_folder, "training_data.csv"), index_label="unit_id")
        pd.DataFrame(self.y).to_csv(os.path.join(self.output_folder, "labels.csv"), index_label="unit_index")

        self.requirements["scikit-learn"] = sklearn.__version__

@@ -550,6 +555,7 @@ def train_model(
    imputation_strategies=None,
    scaling_techniques=None,
    classifiers=None,
    overwrite=False,
    seed=None,
    **job_kwargs,
):
@@ -580,6 +586,8 @@
        A list of scaling techniques to apply. If None, default techniques will be used.
    classifiers : list of str | dict | None, default: None
        A list of classifiers to evaluate. Optionally, a dictionary of classifiers and their hyperparameter search spaces can be provided. If None, default classifiers will be used. Check the `get_default_classifier_search_spaces` method for the default search spaces & format for custom spaces.
    overwrite : bool, default: False
        If True, overwrite the `output_folder` if it already exists.
    seed : int | None, default: None
        Random seed for reproducibility. If None, a random seed will be generated.
@@ -594,6 +602,11 @@
    and evaluating the models. The evaluation results are saved to the specified output folder.
    """

    if overwrite is False:
        assert not Path(
            output_folder
        ).exists(), f"folder {output_folder} already exists, choose another name or use overwrite=True"

    if labels is None:
        raise Exception("You must supply a list of curated labels using `labels = ...`")

