diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst
index 25d54687a7..8873a3a794 100644
--- a/doc/how_to/auto_curation_prediction.rst
+++ b/doc/how_to/auto_curation_prediction.rst
@@ -7,7 +7,7 @@ Full tutorial for model-based curation can be found `here `
 
-.. code:: 
+.. code::
 
     from pathlib import Path
     import pandas as pd
     from spikeinterface.curation.train_manual_curation import train_model
-    
+
     analyzer_list = [analyzer_1, analyzer_2]
     labels_list = [analyzer_1_labels, analyzer_2_labels]
     output_folder = "/path/to/output_folder"
-    
+
     trainer = train_model(
         mode="analyzers",
         labels=labels_list,
@@ -32,14 +32,14 @@ How to train a model to predict curation labels
         classifiers=None, # Defaults to Random Forest classifier only - we usually find this gives the best results, but a range of classifiers is available
         seed=None, # Set a seed for reproducibility
     )
-    
-    
+
+
     best_model = trainer.best_pipeline
     best_model
 
 Load and display top 5 pipelines and accuracies
 
-.. code:: 
+.. code::
 
     accuracies = pd.read_csv(Path(output_folder) / Path("model_label_accuracies.csv"), index_col = 0)
     accuracies.head()
@@ -48,7 +48,7 @@
 This training function can also be run in “csv” mode if you want to store metrics in a single .csv file.
 If the target labels are stored in the file, you can point to these with the ``target_label`` parameter
 
-.. code:: 
+.. code::
 
     trainer = train_model(
         mode="csv",
diff --git a/examples/tutorials/qualitymetrics/plot_5_automated_curation.py b/examples/tutorials/qualitymetrics/plot_5_automated_curation.py
index be50c94e4c..a572b59757 100644
--- a/examples/tutorials/qualitymetrics/plot_5_automated_curation.py
+++ b/examples/tutorials/qualitymetrics/plot_5_automated_curation.py
@@ -29,15 +29,15 @@
 ##############################################################################
 # Applying a pretrained model to predict curation labels
 # ------------------------------
-# 
+#
 # We supply pretrained machine learning classifiers for predicting spike-sorted clusters with arbitrary labels, in this example single-unit activity ('good'), or noise. This particular approach works as follows:
-# 
+#
 # 1. Create a Spikeinterface 'sorting analyzer `_ and compute `quality metrics `_
 #
 # 2. Load a pretrained classification model from Hugging Face & skops
 #
 # 3. Compare with human-applied curation labels to assess performance, optionally export labels to phy to allow manual checking
-# 
+#
 # Load data and compute quality metrics - test data are used here, but replace with your own recordings!
 # We would recommend to start with some previously labelled data, to allow comparison of the model performance against your labelling
 # First, let's simulate some data and compute quality metrics
@@ -52,7 +52,7 @@
 ##############################################################################
 # Load pretrained models and predict curation labels for unlabelled data
 # ------------------------------
-# 
+#
 # We can download a pretrained model from Hugging Face, and use this to test out prediction.
 # This particular model assumes only two labels ('noise', and 'good') have been used for classification
 # Predictions and prediction confidence are then stored in "label_prediction" and "label_confidence" unit properties
@@ -107,7 +107,7 @@
 
 ##############################################################################
 # We can also see how the model's confidence relates to the probability that the model label matches the human label
-# 
+#
 # This could be used to set a threshold above which you might accept the model's classification, and only manually curate those which it is less sure of
 
 confidences = unlabelled_analyzer.sorting.get_property('label_confidence')
@@ -149,18 +149,18 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
 
 ##############################################################################
 # ------------------------------
-# 
+#
 ##############################################################################
 # Training a new model
-# 
+#
 # **If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier through SpikeInterface!**
-# 
+#
 # First we make a Spikeinterface SortingAnalyzer object, in this case using simulated data, and generate some labels for the units to use as our target
 # Note that these labels are random as written here, so the example model will only perform at chance
-# 
+#
 # Load some of your data and curation labels here and see how it performs!
 # **Note that it is likely that for useful generalisability, you will need to use multiple labelled recordings for training.** To do this, compute metrics as described for multiple SortingAnalyzers, then pass them as a list to the model training function, and pass the labels as a single list in the same order
-# 
+#
 # The set of unique labels used is arbitrary, so this could just as easily be used for any cluster categorisation task as for curation into the standard 'good', 'mua' and 'noise' categories
 
 # Make a simulated SortingAnalyzer with 100 units
@@ -182,9 +182,9 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
 ##############################################################################
 # Now we train the machine learning classifier
-# 
+#
 # By default, this searches a range of possible imputing and scaling strategies, and uses a Random Forest classifier.
 # It then selects the model which most accurately predicts the supplied 'ground-truth' labels
-# 
+#
 # As output, this function saves the best model (as a skops file, similar to a pickle), a csv file containing information about the performance of all tested models (`model_label_accuracies.csv`), and a `model_info.json` file containing the parameters used to compute quality metrics, and the SpikeInterface version, for reproducibility
 
 # Load labelled metrics and train model
@@ -234,8 +234,8 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
 plt.show()
 
 ##############################################################################
-# Apply trained model to unlabelled data 
-# 
+# Apply trained model to unlabelled data
+#
 # This approach is the same as in the first section, but without the need to combine the output of two separate classifiers
 
 unlabelled_recording, unlabelled_sorting = si.generate_ground_truth_recording(durations=[60], num_units=30)
@@ -253,4 +253,4 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
                 label_conversion = label_conversion,
                 export_to_phy = False,
                 pipeline_info_path = Path(output_folder) / Path("model_info.json"))
-unlabelled_analyzer.sorting
\ No newline at end of file
+unlabelled_analyzer.sorting
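
For reference, the documentation touched above describes ``train_model`` saving its best pipeline as a skops file, together with ``model_label_accuracies.csv`` and ``model_info.json``. The sketch below shows one way those outputs could be loaded by hand and applied to a new table of quality metrics; the file names ``best_model.skops`` and ``unlabelled_metrics.csv``, and the assumption that the metric columns match those used during training, are illustrative guesses rather than anything fixed by this patch.

.. code::

    from pathlib import Path

    import pandas as pd
    from skops.io import load  # skops is the format train_model uses to save the best pipeline

    output_folder = Path("/path/to/output_folder")

    # Accuracy of every imputing/scaling/classifier combination that was tested
    accuracies = pd.read_csv(output_folder / "model_label_accuracies.csv", index_col=0)
    print(accuracies.head())

    # Load the best pipeline ("best_model.skops" is an assumed file name);
    # a `trusted` argument may be needed for non-default types, see the skops docs
    model = load(output_folder / "best_model.skops")

    # Hypothetical table of the same quality metrics, computed for unlabelled units
    metrics = pd.read_csv(output_folder / "unlabelled_metrics.csv", index_col=0)

    predicted_labels = model.predict(metrics)               # one label per unit
    confidence = model.predict_proba(metrics).max(axis=1)   # probability of the chosen label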