Skip to content

Commit

Permalink
Update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Dec 11, 2024
1 parent 46681f8 commit 52ac6a0
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 71 deletions.
129 changes: 69 additions & 60 deletions examples/tutorials/curation/plot_1_automated_curation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import spikeinterface.curation as sc
import spikeinterface.widgets as sw


# note: you can use more cores using e.g.
# si.set_global_jobs_kwargs(n_jobs = 8)

Expand All @@ -50,15 +49,16 @@


##############################################################################
# This model was trained on artifically generated tetrode data. The model object has a nice html
# representation, which will appear if you're using a Jupyter notebook.
# This model was trained on artificially generated tetrode data. There are also models trained
# on real data, like the one discussed `below <#A-model-trained-on-real-Neuropixels-data>`_.
# Each model object has a nice html representation, which will appear if you're using a Jupyter notebook.

model

##############################################################################
# This tells us more information about the model. The one we've just downloaded was trained using
# a `RandomForestClassifier`. You can also discover this information by running
# `model.get_params()`. The model object (an `sklearn Pipeline <https://scikit-learn.org/1.5/modules/generated/sklearn.pipeline.Pipeline.html>`_) also contains information
# a ``RandomForestClassifier``. You can also discover this information by running
# ``model.get_params()``. The model object (an `sklearn Pipeline <https://scikit-learn.org/1.5/modules/generated/sklearn.pipeline.Pipeline.html>`_) also contains information
# about which metrics were used to compute the model. We can access it from the model (or from the model_info)

print(model.feature_names_in_)
Expand All @@ -81,10 +81,9 @@
print(set(all_metric_names) == set(model.feature_names_in_))

##############################################################################
# Great! We can now use the model to predict labels. You can either pass a HuggingFace repo, or a
# local folder containing the model file and the ``model_info.json`` file. Here, we'll pass
# a repo. The function returns a dictionary containing a label and a confidence for each unit
# contained in the ``sorting_analyzer``.
# Great! We can now use the model to predict labels. Here, we pass the HF repo id directly
# to the ``auto_label_units`` function. This returns a dictionary containing a label and
# a confidence for each unit contained in the ``sorting_analyzer``.

labels = sc.auto_label_units(
sorting_analyzer = sorting_analyzer,
Expand All @@ -96,8 +95,8 @@


##############################################################################
# The model has labelled one unit as bad. Let's look at that one, and the 'good' unit with the highest
# confidence of being 'good'.
# The model has labelled one unit as bad. Let's look at that one, and also the 'good' unit
# with the highest confidence of being 'good'.

sw.plot_unit_templates(sorting_analyzer, unit_ids=[7,9])

Expand Down Expand Up @@ -202,73 +201,83 @@ def calculate_moving_avg(label_df, confidence_label, window_size):
# In this case, you might decide to only trust labels which had confidence above 0.88,
# and manually label the ones the model isn't so confident about.
#
# Using other models, e.g. one trained on Neuropixels data
# ------------------------------------------
# A model trained on real Neuropixels data
# ----------------------------------------
#
# Above, we used a toy model trained on generated data. There are also models on HuggingFace
# trained on real data.
#
# For example, the following classifier is trained on Neuropixels data from 11 mice recorded in
# V1,SC and ALM: https://huggingface.co/AnoushkaJain3/curation_machine_learning_models
#
# Go to that page, and take a look at the ``Files``. The models are contained in the
# `skops files <https://skops.readthedocs.io/en/stable/>`_ and there are *two* in this repo.
# We can choose which to load in the ``load_model`` function as follows:
# For example, the following classifiers are trained on Neuropixels data from 11 mice recorded in
# V1, SC and ALM: https://huggingface.co/AnoushkaJain3/noise_neural_classifier/ and
# https://huggingface.co/AnoushkaJain3/sua_mua_classifier/ . One will classify units into
# ``noise`` or ``not-noise`` and the other will classify the ``not-noise`` units into single
# unit activity (sua) units and multi-unit activity (mua) units.
#
# .. code-block::
# import spikeinterface.curation as sc
# model, model_info = sc.load_model(
# repo_id = "AnoushkaJain3/curation_machine_learning_models",
# model_name= 'noise_neuron_model.skops',
# )
# There is more information about the model on the model's HuggingFace page. Take a look!
# The idea here is to first apply the noise/not-noise classifier, then the sua/mua one.
# We can do so as follows:
#
# If you run this locally you will receive an error:

# Apply the noise/not-noise model. ``auto_label_units`` is called from the
# curation module (imported at the top of this file as ``sc``) and receives the
# analyzer through its ``sorting_analyzer`` keyword, as in the earlier call in
# this tutorial.
noise_neuron_labels = sc.auto_label_units(
    sorting_analyzer=sorting_analyzer,
    repo_id="AnoushkaJain3/noise_neural_classifier",
    trust_model=True,
)

# Keep the rows predicted as noise, then build an analyzer without those units.
noise_units = noise_neuron_labels[noise_neuron_labels['prediction']=='noise']
analyzer_neural = sorting_analyzer.remove_units(noise_units.index)

# Apply the sua/mua model to the remaining (not-noise) units only, so that
# noise units are not re-labelled as sua/mua.
sua_mua_labels = sc.auto_label_units(
    sorting_analyzer=analyzer_neural,
    repo_id="AnoushkaJain3/sua_mua_classifier",
    trust_model=True,
)

# Combine both sets of labels into a single table, sorted by its index.
# NOTE(review): relies on pandas being imported as ``pd`` at the top of the
# file -- confirm.
all_labels = pd.concat([sua_mua_labels, noise_units]).sort_index()
print(all_labels)

##############################################################################
# If you run this without the ``trust_model=True`` parameter, you will receive an error:
#
# .. code-block::
#
# UntrustedTypesFoundException: Untrusted types found in the file: ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']
#
# This is a security warning. Sharing models, with are Python objects, is complicated.
# We have chosen to use the `skops format <https://skops.readthedocs.io/en/stable/>`_, instead
# of the common but insecure ``.pkl`` format (read about ``pickle`` security issues
# `here <https://lwn.net/Articles/964392/>`_). While unpacking the ``.skops`` file, each function
# is checked. Ideally, skops should recognise all `sklearn`, `numpy` and `scipy` functions and
# allow the object to be loaded if it only contains these (and no unkown malicious code). But
# when ``skops`` it's not sure, it raises an error. Here, it doesn't recognise
# ``['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer',
# 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV',
# 'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions
# from `sklearn`, and we can happily add them to the ``trusted`` functions to load:
# This is a security warning, which can be overcome by passing the trusted types list
# ``trusted = ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']``
# or by passing the ``trust_model=True`` keyword.
#
# .. code-block::
#
# model, model_info = sc.load_model(
# model_name = 'noise_neuron_model.skops',
# repo_id = "AnoushkaJain3/curation_machine_learning_models",
# trusted = ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']
# )
#
# As ``skops`` continues to be developed, we hope more of these functions will be :code:`trusted`
# by default.
# .. dropdown:: More about security
#
# If you unequivocally trust the model (e.g. if you have created it), you can bypass this security
# step by passing ``trust_model = True`` to the ``load_model`` function.
#
# In general, you should be cautious when downloading ``.skops`` files and ``.pkl`` files from repos,
# especially from unknown sources.
# Sharing models, which are Python objects, is complicated.
# We have chosen to use the `skops format <https://skops.readthedocs.io/en/stable/>`_, instead
# of the common but insecure ``.pkl`` format (read about ``pickle`` security issues
# `here <https://lwn.net/Articles/964392/>`_). While unpacking the ``.skops`` file, each function
# is checked. Ideally, skops should recognise all ``sklearn``, ``numpy`` and ``scipy`` functions and
# allow the object to be loaded if it only contains these (and no unknown malicious code). But
# when ``skops`` is not sure, it raises an error. Here, it doesn't recognise
# ``['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer',
# 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV',
# 'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions
# from ``sklearn``, and we can happily add them to the ``trusted`` functions to load.
#
# In general, you should be cautious when downloading ``.skops`` files and ``.pkl`` files from repos,
# especially from unknown sources.
#
# Directly applying a sklearn Pipeline
# ------------------------------------
#
# Instead of using ``HuggingFace`` and ``skops``, you might have another way of receiving a sklearn
# pipeline, and want to apply it to your sorted data.

from spikeinterface.curation.model_based_curation import ModelBasedClassification

model_based_classification = ModelBasedClassification(sorting_analyzer, model)
labels = model_based_classification.predict_labels()
labels
# Instead of using ``HuggingFace`` and ``skops``, someone might have given you a model
# in a different way: perhaps by e-mail or a download. If you have the model in a
# folder, you can apply it in a very similar way:
#
# .. code-block::
#
# labels = sc.auto_label_units(
# sorting_analyzer = sorting_analyzer,
# model_folder = "path/to/model/folder",
# )

##############################################################################
# Using this, you lose the advantages of the model metadata: the quality metric parameters
Expand Down
14 changes: 7 additions & 7 deletions examples/tutorials/curation/plot_2_train_a_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@

##############################################################################
# This is as expected: great! (Find out more about plotting using widgets `here <https://spikeinterface.readthedocs.io/en/latest/modules/widgets.html>`_.)
# We've set out system up so that the first five units are 'good' and the next five are 'bad'.
# We've set up our system so that the first five units are 'good' and the next five are 'bad'.
# So we can make a list of labels which contain this information. For real data, you could
# use a manual curation tool to make your own list.

Expand All @@ -72,13 +72,13 @@
# of the units, and then be applied to units from other sortings. The properties we use are the
# `quality metrics <https://spikeinterface.readthedocs.io/en/latest/modules/qualitymetrics.html>`_
# and `template metrics <https://spikeinterface.readthedocs.io/en/latest/modules/postprocessing.html#template-metrics>`_.
# Hence we need to compute these, using some `sorting_analyzer` extensions.
# Hence we need to compute these, using some ``sorting_analyzer`` extensions.

analyzer.compute(['spike_locations','spike_amplitudes','correlograms','principal_components','quality_metrics','template_metrics'])

##############################################################################
# Now that we have metrics and labels, we're ready to train the model using the
# `train_model` function. The trainer will try several classifiers, imputation strategies and
# ``train_model`` function. The trainer will try several classifiers, imputation strategies and
# scaling techniques then save the most accurate. To save time in this tutorial,
# we'll only try one classifier (Random Forest), imputation strategy (median) and scaling
# technique (standard scaler).
Expand Down Expand Up @@ -109,7 +109,7 @@
# `imputation strategies <https://scikit-learn.org/1.5/api/sklearn.impute.html>`_ and
# `scalers <https://scikit-learn.org/1.5/api/sklearn.preprocessing.html>`_, although the
# documentation is quite overwhelming. You can find the classifiers we've tried out
# using the `sc.get_default_classifier_search_spaces` function.
# using the ``sc.get_default_classifier_search_spaces`` function.
#
# The above code saves the model in ``model.skops``, some metadata in
# ``model_info.json`` and the model accuracies in ``model_accuracies.csv``
Expand Down Expand Up @@ -156,10 +156,10 @@

##############################################################################
# Roughly, this means the model is using metrics such as "nn_hit_rate" and "l_ratio"
# but is using "sync_spike_4" and "rp_contanimation". This is a toy model, so don't
# take these results seriously! But using this information, you could retrain another,
# but is not using "sync_spike_4" and "rp_contamination". This is a toy model, so don't
# take these results seriously. But using this information, you could retrain another,
# simpler model using a subset of the metrics, by passing, e.g.,
# `metric_names = ['nn_hit_rate', 'l_ratio',...]` to the `train_model` function.
# ``metric_names = ['nn_hit_rate', 'l_ratio',...]`` to the ``train_model`` function.
#
# Now that you have a model, you can `apply it to another sorting
# <https://spikeinterface.readthedocs.io/en/latest/tutorials/curation/plot_1_automated_curation.html>`_
Expand Down
7 changes: 3 additions & 4 deletions examples/tutorials/curation/plot_3_upload_a_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
# SpikeInterface and HFH don't require you to keep this folder structure, we just advise it as
# best practice.
#
# If you've used SpikeInterface to train your model, the `train_model` function auto-generates
# If you've used SpikeInterface to train your model, the ``train_model`` function auto-generates
# most of this data. The only thing missing is the ``metadata.json`` file. The purpose of this
# file is to detail how the model was trained, which can help prospective users decide if it
# is relevant for them. For example, taking
Expand Down Expand Up @@ -65,8 +65,7 @@
# Upload to HuggingFaceHub
# ------------------------
#
# We'll now upload this folder to HFH using the web interface. (If you don't want to
# use HFH, you could just share this folder with a colleague.)
# We'll now upload this folder to HFH using the web interface.
#
# First, go to https://huggingface.co/ and make an account. Once you've logged in, press
# ``+`` then ``New model`` or find ``+ New Model`` in the user menu. You will be asked
Expand Down Expand Up @@ -128,7 +127,7 @@
#
# labels = auto_label_units(
# sorting_analyzer = sorting_analyzer,
# model_folder = "SpikeInterface/a_folder_for_a_model",
# model_folder = "path/to/a_folder_for_a_model",
# trusted = ['numpy.dtype']
# )
# ` ` `
Expand Down

0 comments on commit 52ac6a0

Please sign in to comment.