-
Notifications
You must be signed in to change notification settings - Fork 191
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update how-tos and change property name
- Loading branch information
1 parent
2304b80
commit fe76d80
Showing
3 changed files
with
48 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,40 @@ | ||
How to use a trained model to predict the curation labels | ||
========================================================= | ||
|
||
This assumes you alrady have your data loaded as a SortingAnalyzer, and have calculated some quality metrics - `tutorial here <https://spikeinterface.readthedocs.io/en/latest/tutorials/qualitymetrics/plot_3_quality_metrics.html>`_ | ||
Full tutorial for model-based curation can be found `here <https://spikeinterface.readthedocs.io/en/latest/tutorials/qualitymetrics/plot_5_automated_curation.html>`_ | ||
- Pre-trained models can be downloaded from `Hugging Face <https://huggingface.co/>`_, or opened from `skops <https://skops.readthedocs.io/en/stable/>`_ files | ||
- The model (``.skops``) file AND the ``pipeline_info.json`` (both produced when training the model) are required for full prediction | ||
There is a Collection of models for automated curation available on the | ||
`SpikeInterface HuggingFace page <https://huggingface.co/SpikeInterface>`_. | ||
|
||
We'll apply the model ``toy_tetrode_model`` from ``SpikeInterface`` on a SortingAnalyzer | ||
called ``sorting_analyzer``. We assume that the quality and template metrics have | ||
already been computed (`full tutorial here <https://spikeinterface.readthedocs.io/en/latest/tutorials/curation/plot_1_automated_curation.html>`_). | ||
|
||
We need to pass the ``sorting_analyzer``, the ``repo_id`` (which is just the part of the | ||
repo's URL after huggingface.co/) and the trusted types list (see more `here <https://spikeinterface.readthedocs.io/en/latest/tutorials/curation/plot_1_automated_curation.html#a-more-realistic-example>`_): | ||
|
||
.. code:: | ||
# Load a model | ||
from huggingface_hub import hf_hub_download | ||
import skops.io | ||
import json | ||
from spikeinterface.curation import auto_label_units | ||
# Download the model and json file from Hugging Face (can also load from local paths) | ||
repo = "chrishalcrow/test_automated_curation_3" | ||
model_path = hf_hub_download(repo_id=repo, filename="best_model_label.skops") | ||
json_path = hf_hub_download(repo_id=repo, filename="model_pipeline.json") | ||
model = skops.io.load(model_path, trusted='numpy.dtype') | ||
pipeline_info = json.load(open(json_path)) | ||
labels_and_probabilities = auto_label_units( | ||
sorting_analyzer = sorting_analyzer, | ||
repo_id = "SpikeInterface/toy_tetrode_model", | ||
trusted = ['numpy.dtype'] | ||
) | ||
Use the model to predict labels on your SortingAnalyzer | ||
If you have a local directory containing the model in a ``skops`` file you can use this to | ||
create the labels: | ||
|
||
.. code:: | ||
from spikeinterface.curation.model_based_curation import auto_label_units | ||
labels_and_probabilities = si.auto_label_units( | ||
sorting_analyzer = sorting_analyzer, | ||
model_folder = "my_folder_with_a_model_in_it", | ||
) | ||
analyzer # Your SortingAnalyzer | ||
The returned labels are a dictionary of model's predictions and it's confidence. These | ||
are also saved as a property of your ``sorting_analyzer`` and can be accessed like so: | ||
|
||
label_conversion = pipeline_info['label_conversion'] | ||
label_dict = auto_label_units(sorting_analyzer=analyzer, | ||
pipeline=model, | ||
label_conversion=label_conversion, | ||
export_to_phy=False, | ||
pipeline_info_path=None) | ||
.. code:: | ||
# The labels are stored in the sorting "label_predictions" and "label_confidence" property | ||
analyzer.sorting | ||
labels = sorting_analyzer.get_property("classifier_label") | ||
probabilities = sorting_analyzer.get_property("classifier_probability") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters