Update how-tos and change property name
chrishalcrow committed Sep 19, 2024
1 parent 2304b80 commit fe76d80
Showing 3 changed files with 48 additions and 53 deletions.
50 changes: 25 additions & 25 deletions doc/how_to/auto_curation_prediction.rst
@@ -1,40 +1,40 @@
How to use a trained model to predict the curation labels
=========================================================

-- This assumes you alrady have your data loaded as a SortingAnalyzer, and have calculated some quality metrics - `tutorial here <https://spikeinterface.readthedocs.io/en/latest/tutorials/qualitymetrics/plot_3_quality_metrics.html>`_
-- Full tutorial for model-based curation can be found `here <https://spikeinterface.readthedocs.io/en/latest/tutorials/qualitymetrics/plot_5_automated_curation.html>`_
-- Pre-trained models can be downloaded from `Hugging Face <https://huggingface.co/>`_, or opened from `skops <https://skops.readthedocs.io/en/stable/>`_ files
-- The model (``.skops``) file AND the ``pipeline_info.json`` (both produced when training the model) are required for full prediction
+There is a Collection of models for automated curation available on the
+`SpikeInterface HuggingFace page <https://huggingface.co/SpikeInterface>`_.

+We'll apply the model ``toy_tetrode_model`` from ``SpikeInterface`` on a SortingAnalyzer
+called ``sorting_analyzer``. We assume that the quality and template metrics have
+already been computed (`full tutorial here <https://spikeinterface.readthedocs.io/en/latest/tutorials/curation/plot_1_automated_curation.html>`_).

+We need to pass the ``sorting_analyzer``, the ``repo_id`` (which is just the part of the
+repo's URL after huggingface.co/) and the trusted types list (see more `here <https://spikeinterface.readthedocs.io/en/latest/tutorials/curation/plot_1_automated_curation.html#a-more-realistic-example>`_):

.. code::

-    # Load a model
-    from huggingface_hub import hf_hub_download
-    import skops.io
-    import json
-
-    # Download the model and json file from Hugging Face (can also load from local paths)
-    repo = "chrishalcrow/test_automated_curation_3"
-    model_path = hf_hub_download(repo_id=repo, filename="best_model_label.skops")
-    json_path = hf_hub_download(repo_id=repo, filename="model_pipeline.json")
-    model = skops.io.load(model_path, trusted='numpy.dtype')
-    pipeline_info = json.load(open(json_path))
+    from spikeinterface.curation import auto_label_units
+
+    labels_and_probabilities = auto_label_units(
+        sorting_analyzer = sorting_analyzer,
+        repo_id = "SpikeInterface/toy_tetrode_model",
+        trusted = ['numpy.dtype']
+    )
-Use the model to predict labels on your SortingAnalyzer
+If you have a local directory containing the model in a ``skops`` file you can use this to
+create the labels:

.. code::

-    from spikeinterface.curation.model_based_curation import auto_label_units
+    labels_and_probabilities = auto_label_units(
+        sorting_analyzer = sorting_analyzer,
+        model_folder = "my_folder_with_a_model_in_it",
+    )
-    analyzer # Your SortingAnalyzer
+The returned labels are a dictionary of the model's predictions and their confidences. These
+are also saved as properties of your ``sorting_analyzer`` and can be accessed like so:

-    label_conversion = pipeline_info['label_conversion']
-    label_dict = auto_label_units(sorting_analyzer=analyzer,
-                                  pipeline=model,
-                                  label_conversion=label_conversion,
-                                  export_to_phy=False,
-                                  pipeline_info_path=None)
.. code::

-    # The labels are stored in the sorting "label_predictions" and "label_confidence" property
-    analyzer.sorting
+    labels = sorting_analyzer.get_property("classifier_label")
+    probabilities = sorting_analyzer.get_property("classifier_probability")
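
For orientation, the two snippets above fit into this end-to-end flow. This is a
minimal sketch, not part of the diff: it assumes a ``recording`` and ``sorting`` are
already loaded, and the exact list of extensions to compute may vary:

.. code::

    from spikeinterface import create_sorting_analyzer
    from spikeinterface.curation import auto_label_units

    # Build an analyzer and compute the metrics a model needs
    sorting_analyzer = create_sorting_analyzer(sorting=sorting, recording=recording)
    sorting_analyzer.compute(
        ["random_spikes", "noise_levels", "templates", "quality_metrics", "template_metrics"]
    )

    # Predict labels with a pretrained model from Hugging Face
    labels_and_probabilities = auto_label_units(
        sorting_analyzer=sorting_analyzer,
        repo_id="SpikeInterface/toy_tetrode_model",
        trusted=["numpy.dtype"],
    )

    # The same predictions are stored on the analyzer as properties
    labels = sorting_analyzer.get_property("classifier_label")
    probabilities = sorting_analyzer.get_property("classifier_probability")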
47 changes: 21 additions & 26 deletions doc/how_to/auto_curation_training.rst
@@ -1,21 +1,21 @@
How to train a model to predict curation labels
===============================================

-- This assumes you alrady have your data loaded as one or several
-  SortingAnalyzers, and have calculated some quality metrics -
-  `tutorial
-  here <https://spikeinterface.readthedocs.io/en/latest/tutorials/qualitymetrics/plot_3_quality_mertics.html>`
-- You also need a list of cluster labels for each SortingAnalyzer,
-  which can be extracted from a SortingAnalyzer, from phy, or loaded
-  from elsewhere
-- Full tutorial for model-based curation can be found
-  `here<https://spikeinterface.readthedocs.io/en/latest/tutorials/qualitymetrics/plot_5_automated_curation.html>`
+A full tutorial for model-based curation can be found `here <https://spikeinterface.readthedocs.io/en/latest/tutorials/curation/plot_2_train_a_model.html>`_.

+Here, we assume that you have:

+* Two SortingAnalyzers called ``analyzer_1`` and
+  ``analyzer_2``, and have calculated some template and quality metrics for both
+* Manually curated labels for the units in each analyzer, in lists called
+  ``analyzer_1_labels`` and ``analyzer_2_labels``. If you have used phy, the lists can
+  be accessed using ``curated_labels = analyzer.sorting.get_property("quality")``.

With these objects calculated, you can train a model as follows

.. code::

-    from pathlib import Path
-    import pandas as pd
-    from spikeinterface.curation.train_manual_curation import train_model
+    from spikeinterface.curation import train_model
+
    analyzer_list = [analyzer_1, analyzer_2]
    labels_list = [analyzer_1_labels, analyzer_2_labels]
@@ -34,30 +34,25 @@ How to train a model to predict curation labels
    )
-    best_model = trainer.best_pipeline
-    best_model
-Load and disply top 5 pipelines and accuracies
+The trainer tries several models and chooses the most accurate one. This model and
+some metadata are stored in the ``output_folder``, which can later be loaded using the
+``load_model`` function (`more details <https://spikeinterface.readthedocs.io/en/latest/tutorials/curation/plot_1_automated_curation.html#download-a-pretrained-model>`_).
+
+We can also access the model, which is an sklearn ``Pipeline``, from the trainer object:

.. code::

-    accuracies = pd.read_csv(Path(output_folder) / Path("model_label_accuracies.csv"), index_col = 0)
-    accuracies.head()
+    best_model = trainer.best_pipeline
-This training function can also be run in “csv” mode if you want to
-store metrics in a single .csv file. If the target labels are stored in
+The training function can also be run in “csv” mode, if you prefer to
+store metrics in a single .csv file. If the target labels are stored as a column in
the file, you can point to these with the ``target_label`` parameter

.. code::

    trainer = train_model(
        mode="csv",
        metrics_path = "/path/to/csv",
target_label = "label",
target_label = "my_label",
        output_folder=output_folder,
        metric_names=None, # Set if you want to use a subset of metrics, defaults to all calculated quality and template metrics
        imputation_strategies=None, # Default is all available imputation strategies
        scaling_techniques=None, # Default is all available scaling techniques
        classifiers=None, # Defaults to Random Forest classifier only - we usually find this gives the best results, but a range of classifiers is available
        seed=None, # Set a seed for reproducibility
    )
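
The hunk above shows the csv-mode call. For the analyzer-based mode described at the
top of this file, a minimal sketch might look as follows (the ``analyzers`` and
``labels`` parameter names are assumptions, inferred from the ``analyzer_list`` and
``labels_list`` variables above):

.. code::

    from spikeinterface.curation import train_model

    # Train on the computed metrics of two manually labelled analyzers
    trainer = train_model(
        mode="analyzers",
        analyzers=[analyzer_1, analyzer_2],
        labels=[analyzer_1_labels, analyzer_2_labels],
        output_folder="/path/to/output_folder",  # best model + metadata are saved here
    )

    best_model = trainer.best_pipeline  # an sklearn Pipeline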
4 changes: 2 additions & 2 deletions src/spikeinterface/curation/model_based_curation.py
@@ -115,8 +115,8 @@ def predict_labels(self, label_conversion=None, input_data=None, export_to_phy=F
        }

        # Set predictions and probability as sorting properties
self.sorting_analyzer.sorting.set_property("label_prediction", predictions)
self.sorting_analyzer.sorting.set_property("label_confidence", probabilities)
self.sorting_analyzer.sorting.set_property("classifier_label", predictions)
self.sorting_analyzer.sorting.set_property("classifier_probability", probabilities)

        if export_to_phy:
            self._export_to_phy(classified_units)
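
Since this commit renames the stored properties, downstream code should read the new
names. A short sketch of thresholding units by classifier confidence (the "good"
label and the 0.9 cutoff are illustrative assumptions, not part of the change):

.. code::

    import numpy as np

    labels = sorting_analyzer.get_property("classifier_label")
    probabilities = np.asarray(sorting_analyzer.get_property("classifier_probability"))

    # Keep only units labelled "good" with high confidence
    keep_mask = (labels == "good") & (probabilities > 0.9)
    keep_unit_ids = sorting_analyzer.unit_ids[keep_mask]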