From fe76d80b58c090bb36980a717727e65e9b353fdc Mon Sep 17 00:00:00 2001
From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com>
Date: Thu, 19 Sep 2024 13:24:45 +0100
Subject: [PATCH] Update how-tos and change property name

---
 doc/how_to/auto_curation_prediction.rst | 50 +++++++++----------
 doc/how_to/auto_curation_training.rst   | 47 ++++++++---------
 .../curation/model_based_curation.py    |  4 +-
 3 files changed, 48 insertions(+), 53 deletions(-)

diff --git a/doc/how_to/auto_curation_prediction.rst b/doc/how_to/auto_curation_prediction.rst
index 8873a3a794..3ec5134df6 100644
--- a/doc/how_to/auto_curation_prediction.rst
+++ b/doc/how_to/auto_curation_prediction.rst
@@ -1,40 +1,40 @@
 How to use a trained model to predict the curation labels
 =========================================================
-This assumes you alrady have your data loaded as a SortingAnalyzer, and have calculated some quality metrics - `tutorial here `_
-Full tutorial for model-based curation can be found `here `_
-- Pre-trained models can be downloaded from `Hugging Face `_, or opened from `skops `_ files
-- The model (``.skops``) file AND the ``pipeline_info.json`` (both produced when training the model) are required for full prediction
+There is a collection of models for automated curation available on the
+`SpikeInterface HuggingFace page `_.
+We'll apply the model ``toy_tetrode_model`` from ``SpikeInterface`` to a SortingAnalyzer
+called ``sorting_analyzer``. We assume that the quality and template metrics have
+already been computed (`full tutorial here `_).
+
+We need to pass the ``sorting_analyzer``, the ``repo_id`` (which is just the part of the
+repo's URL after huggingface.co/) and the trusted types list (see more `here `_):
 
 .. code::
 
-    # Load a model
-    from huggingface_hub import hf_hub_download
-    import skops.io
-    import json
+    from spikeinterface.curation import auto_label_units
 
-    # Download the model and json file from Hugging Face (can also load from local paths)
-    repo = "chrishalcrow/test_automated_curation_3"
-    model_path = hf_hub_download(repo_id=repo, filename="best_model_label.skops")
-    json_path = hf_hub_download(repo_id=repo, filename="model_pipeline.json")
-    model = skops.io.load(model_path, trusted='numpy.dtype')
-    pipeline_info = json.load(open(json_path))
+    labels_and_probabilities = auto_label_units(
+        sorting_analyzer = sorting_analyzer,
+        repo_id = "SpikeInterface/toy_tetrode_model",
+        trusted = ['numpy.dtype']
+    )
 
-Use the model to predict labels on your SortingAnalyzer
+If you have a local directory containing the model in a ``skops`` file, you can use this to
+create the labels:
 
 .. code::
 
-    from spikeinterface.curation.model_based_curation import auto_label_units
+    labels_and_probabilities = auto_label_units(
+        sorting_analyzer = sorting_analyzer,
+        model_folder = "my_folder_with_a_model_in_it",
+    )
 
-    analyzer # Your SortingAnalyzer
+The returned labels are a dictionary of the model's predictions and its confidence. These
+are also saved as a property of your ``sorting_analyzer`` and can be accessed like so:
 
-    label_conversion = pipeline_info['label_conversion']
-    label_dict = auto_label_units(sorting_analyzer=analyzer,
-                                  pipeline=model,
-                                  label_conversion=label_conversion,
-                                  export_to_phy=False,
-                                  pipeline_info_path=None)
+.. code::
 
-    # The labels are stored in the sorting "label_predictions" and "label_confidence" property
-    analyzer.sorting
+    labels = sorting_analyzer.get_property("classifier_label")
+    probabilities = sorting_analyzer.get_property("classifier_probability")
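
To see the renamed properties in action, here is a minimal sketch of downstream use. It
assumes the model was trained with ``good``/``bad`` labels and that ``sorting_analyzer``
exposes the usual ``unit_ids`` and ``select_units`` API; the 0.9 confidence cutoff is
purely illustrative:

.. code::

    import numpy as np

    labels = sorting_analyzer.get_property("classifier_label")
    probabilities = sorting_analyzer.get_property("classifier_probability")

    # Keep only units the classifier labels "good" with high confidence
    keep_mask = (np.asarray(labels) == "good") & (np.asarray(probabilities) > 0.9)
    keep_unit_ids = sorting_analyzer.unit_ids[keep_mask]

    # Restrict the analyzer to the curated units
    curated_analyzer = sorting_analyzer.select_units(keep_unit_ids)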
diff --git a/doc/how_to/auto_curation_training.rst b/doc/how_to/auto_curation_training.rst
index 359cb7c232..7c9bbf0edf 100644
--- a/doc/how_to/auto_curation_training.rst
+++ b/doc/how_to/auto_curation_training.rst
@@ -1,21 +1,21 @@
 How to train a model to predict curation labels
 ===============================================
 
-- This assumes you alrady have your data loaded as one or several
-  SortingAnalyzers, and have calculated some quality metrics -
-  `tutorial
-  here `
-- You also need a list of cluster labels for each SortingAnalyzer,
-  which can be extracted from a SortingAnalyzer, from phy, or loaded
-  from elsewhere
-- Full tutorial for model-based curation can be found
-  `here`
+A full tutorial for model-based curation can be found `here `_.
+
+Here, we assume that you have:
+
+* Two SortingAnalyzers called ``analyzer_1`` and
+  ``analyzer_2``, and have calculated some template and quality metrics for both
+* Manually curated labels for the units in each analyzer, in lists called
+  ``analyzer_1_labels`` and ``analyzer_2_labels``. If you have used phy, the lists can
+  be accessed using ``curated_labels = analyzer.sorting.get_property("quality")``.
+
+With these objects calculated, you can train a model as follows:
 
 .. code::
 
-    from pathlib import Path
-    import pandas as pd
-    from spikeinterface.curation.train_manual_curation import train_model
+    from spikeinterface.curation import train_model
 
     analyzer_list = [analyzer_1, analyzer_2]
     labels_list = [analyzer_1_labels, analyzer_2_labels]
@@ -34,18 +34,18 @@ How to train a model to predict curation labels
     )
 
 
-best_model = trainer.best_pipeline
-best_model
-
-Load and disply top 5 pipelines and accuracies
+The trainer tries several models and chooses the most accurate one. This model and
+some metadata are stored in the ``output_folder``, which can later be loaded using the
+``load_model`` function (`more details `_).
+We can also access the model, which is an sklearn ``Pipeline``, from the trainer object:
 
 .. code::
 
-    accuracies = pd.read_csv(Path(output_folder) / Path("model_label_accuracies.csv"), index_col = 0)
-    accuracies.head()
+    best_model = trainer.best_pipeline
+
 
-This training function can also be run in “csv” mode if you want to
-store metrics in a single .csv file. If the target labels are stored in
+The training function can also be run in “csv” mode if you prefer to
+store metrics in a single .csv file. If the target labels are stored as a column in
 the file, you can point to these with the ``target_label`` parameter
 
 .. code::
@@ -53,11 +53,6 @@ the file, you can point to these with the ``target_label`` parameter
 
     trainer = train_model(
         mode="csv",
         metrics_path = "/path/to/csv",
-        target_label = "label",
+        target_label = "my_label",
         output_folder=output_folder,
-        metric_names=None, # Set if you want to use a subset of metrics, defaults to all calculated quality and template metrics
-        imputation_strategies=None, # Default is all available imputation strategies
-        scaling_techniques=None, # Default is all available scaling techniques
-        classifiers=None, # Defaults to Random Forest classifier only - we usually find this gives the best results, but a range of classifiers is available
-        seed=None, # Set a seed for reproducibility
     )
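
Once training has finished, the saved model can label a new recording just like the
pre-trained models in the prediction how-to. A sketch, reusing ``output_folder`` from
above and a hypothetical third analyzer ``analyzer_3`` with the same metrics computed:

.. code::

    from spikeinterface.curation import auto_label_units

    # Label a held-out analyzer with the freshly trained model
    labels_and_probabilities = auto_label_units(
        sorting_analyzer = analyzer_3,
        model_folder = output_folder,
    )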
diff --git a/src/spikeinterface/curation/model_based_curation.py b/src/spikeinterface/curation/model_based_curation.py
index ebc12458cc..b0090a5d36 100644
--- a/src/spikeinterface/curation/model_based_curation.py
+++ b/src/spikeinterface/curation/model_based_curation.py
@@ -115,8 +115,8 @@ def predict_labels(self, label_conversion=None, input_data=None, export_to_phy=F
         }
 
         # Set predictions and probability as sorting properties
-        self.sorting_analyzer.sorting.set_property("label_prediction", predictions)
-        self.sorting_analyzer.sorting.set_property("label_confidence", probabilities)
+        self.sorting_analyzer.sorting.set_property("classifier_label", predictions)
+        self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities)
 
         if export_to_phy:
             self._export_to_phy(classified_units)
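
For downstream code written against the old property names, the rename maps as follows
(a sketch; ``sorting_analyzer`` is assumed to have been labelled already):

.. code::

    # Old names, before this patch:
    #   sorting_analyzer.sorting.get_property("label_prediction")
    #   sorting_analyzer.sorting.get_property("label_confidence")

    # New names, after this patch:
    predictions = sorting_analyzer.sorting.get_property("classifier_label")
    confidence = sorting_analyzer.sorting.get_property("classifier_probability")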