Commit acb47c8 (1 parent: 4104d21)
Showing 3 changed files with 331 additions and 256 deletions.

217 changes: 217 additions & 0 deletions
examples/tutorials/curation/plot_1_automated_curation.py
@@ -0,0 +1,217 @@
""" | ||
Model Based Curation Tutorial | ||
============================= | ||
This notebook provides a step-by-step guide on how to use a machine learning classifier for curating spike sorted output. We'll download a toy model from `HuggingFace` and use it to label our sorted data by using Spikeinterface. We start by importing some packages | ||
""" | ||

import warnings
warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import spikeinterface.full as si

# Note: you can use more cores with e.g.
# si.set_global_job_kwargs(n_jobs = 8)

##############################################################################
# Download a pretrained model
# ---------------------------
#
# We can download a pretrained model from `Hugging Face <https://huggingface.co/>`_ and use it
# to label sorted data. The `load_model` function downloads a specific model from
# Hugging Face (or loads one from a local folder), saves it in a temporary
# folder, and returns the model together with some metadata about it.

model, model_info = si.load_model(
    repo_id = "SpikeInterface/toy_tetrode_model",
    trusted = ['numpy.dtype']
)


##############################################################################
# This model was trained on artificially generated tetrode data. The model object has a nice HTML
# representation, which will appear if you're using a Jupyter notebook.

model

##############################################################################
# The repo contains information about which metrics were used to train the model. We can access
# this from the model itself, or from the ``model_info``.

print(model.feature_names_in_)
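
# The metadata returned by ``load_model`` holds related details. For example (assuming the
# repo provides it, as used later in this tutorial), the mapping between the classifier's
# integer classes and the label names:
print(model_info['label_conversion'])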

##############################################################################
# Hence, to use this model we need to create a ``sorting_analyzer`` with all of these metrics computed.
# We'll do this by generating a recording and sorting, creating a sorting analyzer and computing the
# relevant extensions. Follow these links for more info on `recordings <https://spikeinterface.readthedocs.io/en/latest/modules/extractors.html>`_, `sortings <https://spikeinterface.readthedocs.io/en/latest/modules/sorters.html>`_, `sorting analyzers <https://spikeinterface.readthedocs.io/en/latest/tutorials/core/plot_4_sorting_analyzer.html#sphx-glr-tutorials-core-plot-4-sorting-analyzer-py>`_
# and `extensions <https://spikeinterface.readthedocs.io/en/latest/modules/postprocessing.html>`_.

recording, sorting = si.generate_ground_truth_recording(num_channels=4, seed=4, num_units=10)
sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)
sorting_analyzer.compute([
    'noise_levels', 'random_spikes', 'waveforms', 'templates', 'spike_locations',
    'spike_amplitudes', 'correlograms', 'principal_components', 'quality_metrics',
    'template_metrics'
])

##############################################################################
# This ``sorting_analyzer`` now contains the required quality metrics and template metrics.
# We can check that this is true by accessing the extension data.

all_metric_names = (
    list(sorting_analyzer.get_extension('quality_metrics').get_data().keys())
    + list(sorting_analyzer.get_extension('template_metrics').get_data().keys())
)
print(np.all(all_metric_names == model.feature_names_in_))

##############################################################################
# Great! We can now use the model to predict labels. You can either pass a Hugging Face repo, or a
# local folder containing the model file and the ``model_info.json`` file. Here, we'll pass
# a repo. The function returns a dictionary containing a label and a confidence for each unit.

labels = si.auto_label_units(
    sorting_analyzer = sorting_analyzer,
    repo_id = "SpikeInterface/toy_tetrode_model",
    trusted = ['numpy.dtype']
)
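
# As noted above, a local folder containing the model and ``model_info.json`` also works.
# A minimal sketch (the ``model_folder`` argument name is an assumption here; check the
# ``auto_label_units`` docstring for the exact signature):
#
# labels = si.auto_label_units(
#     sorting_analyzer = sorting_analyzer,
#     model_folder = "path/to/my_model",
# )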

labels

##############################################################################
# The model has labelled one unit as 'bad'. Let's look at that unit, along with the 'good' unit
# that has the highest confidence of being 'good':

si.plot_unit_templates(sorting_analyzer, unit_ids=[7, 9])
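
# A small sketch of how you might find these unit ids programmatically, assuming ``labels``
# maps each unit id to a (label, confidence) pair, as the ``predictions`` list below also does:
bad_units = [unit for unit in range(10) if labels[unit][0] == 'bad']
good_units = [unit for unit in range(10) if labels[unit][0] == 'good']
most_confident_good = max(good_units, key=lambda unit: labels[unit][1])
print(bad_units, most_confident_good)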

##############################################################################
# Nice - we can see that unit 9 does look a lot more spikey than unit 7. You might think that unit
# 7 is a real unit. If so, this model isn't good for you.
#
# Assess the model performance
# ----------------------------
#
# To assess the performance of the model relative to human labels, we can load or generate some
# human labels, and plot a confusion matrix of predicted vs human labels for all clusters. Here
# we're being a conservative human who has labelled several units with small amplitudes as 'bad'.

human_labels = ['bad', 'good', 'good', 'bad', 'good', 'bad', 'good', 'bad', 'good', 'good']

# Note: if you labelled using phy, you can load the labels using:
# human_labels = sorting_analyzer.sorting.get_property('quality')

from sklearn.metrics import confusion_matrix, balanced_accuracy_score
import seaborn as sns

label_conversion = model_info['label_conversion']
predictions = [labels[a][0] for a in range(10)]

conf_matrix = confusion_matrix(human_labels, predictions)

# Calculate balanced accuracy for the confusion matrix
balanced_accuracy = balanced_accuracy_score(human_labels, predictions)

sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='viridis')
plt.xlabel('Predicted Label')
plt.ylabel('Human Label')
plt.xticks(ticks=[0.5, 1.5], labels=list(label_conversion.values()))
plt.yticks(ticks=[0.5, 1.5], labels=list(label_conversion.values()))
plt.title('Predicted vs Human Label')
plt.suptitle(f"Balanced Accuracy: {balanced_accuracy}")
plt.show()


##############################################################################
# Here, there are several false positives (if we consider the human labels to be "the truth").
#
# Next, we can also see how the model's confidence relates to the probability that the model
# label matches the human label.
#
# This could be used to set a threshold above which you might accept the model's classification,
# and only manually curate the units it is less sure of.


def calculate_moving_avg(label_df, confidence_label, window_size):

    label_df[f'{confidence_label}_decile'] = pd.cut(label_df[confidence_label], 10, labels=False, duplicates='drop')
    # Group by decile and calculate the proportion of correct labels (agreement)
    p_label_grouped = label_df.groupby(f'{confidence_label}_decile')['model_x_human_agreement'].mean()
    # Convert decile to range 0-1
    p_label_grouped.index = p_label_grouped.index / 10
    # Sort the DataFrame by confidence scores
    label_df_sorted = label_df.sort_values(by=confidence_label)

    p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean()

    return label_df_sorted[confidence_label], p_label_moving_avg

confidences = sorting_analyzer.sorting.get_property('label_confidence')

# Make a dataframe of the human label, model label, and confidence
label_df = pd.DataFrame(data={
    'human_label': human_labels,
    'decoder_label': predictions,
    'confidence': confidences},
    index=sorting_analyzer.sorting.get_unit_ids())

# Calculate the proportion of agreed labels by confidence decile
label_df['model_x_human_agreement'] = label_df['human_label'] == label_df['decoder_label']

p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 3)

# Plot the moving average of agreement
plt.figure(figsize=(6, 6))
plt.plot(p_agreement_sorted, p_agreement_moving_avg, label='Moving Average')
plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance')
plt.xlabel('Confidence')  # plt.xlim(0.5, 1)
plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1)
plt.title('Agreement vs Confidence (Moving Average)')
plt.legend(); plt.grid(True); plt.show()

##############################################################################
# In this case, you might decide to only trust labels with a confidence above 0.86.
#
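# For example, a minimal sketch using the ``label_df`` built above (0.86 is simply the
# value read off the plot for this toy example):
#
# .. code-block::
#
#     confident_units = label_df[label_df['confidence'] > 0.86].index
#     uncertain_units = label_df.index.difference(confident_units)
#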
# A more realistic example
# ------------------------
#
# Here, we used a toy model trained on generated data. There are also models on Hugging Face
# trained on real data.
#
# For example, the following classifier is trained on Neuropixels data from 11 mice recorded in
# V1, SC and ALM: https://huggingface.co/AnoushkaJain3/curation_machine_learning_models
#
# Go to that page and take a look at the `Files`. The models are contained in `skops files <https://skops.readthedocs.io/en/stable/>`_,
# and there are *two* in this repo. We can choose which one to load in the `load_model` function as follows:
#
# .. code-block::
#
#     model, model_info = si.load_model(
#         repo_id = "AnoushkaJain3/curation_machine_learning_models",
#         model_name = 'noise_neuron_model.skops',
#     )
#
# If you run this locally you will receive an error:
#
# .. code-block::
#
#     UntrustedTypesFoundException: Untrusted types found in the file: ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']
#
# This is a security warning. Sharing models, which are Python objects, is a complicated problem.
# We have chosen to use the `skops format <https://skops.readthedocs.io/en/stable/>`_ instead of the common but insecure ``.pkl`` format
# (read about ``pickle`` security issues `here <https://lwn.net/Articles/964392/>`_). While unpacking the ``.skops`` file, each function
# is checked. Ideally, skops should recognise all `sklearn`, `numpy` and `scipy` functions and
# allow the object to be loaded. But when it's not sure, it raises an error. Here, ``skops``
# doesn't recognise ``['sklearn.metrics._classification.balanced_accuracy_score',
# 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV',
# 'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions
# from `sklearn`, and we can happily add them to the ``trusted`` functions.
#
# .. code-block::
#
#     model, model_info = si.load_model(
#         model_name = 'noise_neuron_model.skops',
#         repo_id = "AnoushkaJain3/curation_machine_learning_models",
#         trusted = ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']
#     )
#
# As `skops` continues to be developed, we hope more of these functions will be :code:`trusted`
# by default.
#
# In general, you should be cautious when downloading ``.skops`` and ``.pkl`` files from repos,
# especially from unknown sources.
@@ -0,0 +1,114 @@
""" | ||
Training a model for automated curation | ||
============================= | ||
If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface. | ||
""" | ||
import warnings | ||
warnings.filterwarnings("ignore") | ||
from pathlib import Path | ||
import spikeinterface.full as si | ||
import numpy as np | ||
import pandas as pd | ||
import matplotlib.pyplot as plt | ||

# Note: you can set the number of cores you use with e.g.
# si.set_global_job_kwargs(n_jobs = 8)

##############################################################################
# Step 1: Generate and label data
# -------------------------------
#
# We supply pretrained machine-learning classifiers for predicting spike-sorted clusters with
# arbitrary labels, in this example single-unit activity ('good') or noise. Training your own
# classifier on your own labels works as follows.
#
# For this tutorial, we will use simulated data to create :code:`recording` and :code:`sorting` objects. We'll
# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so will contain good
# units; :code:`sorting_2` is uncoupled, so should be pure noise. We'll combine the two into one sorting
# object using :code:`si.aggregate_units`.
#
# You should `load your own recording <https://spikeinterface.readthedocs.io/en/latest/modules/extractors.html>`_ and `do a sorting <https://spikeinterface.readthedocs.io/en/latest/modules/sorters.html>`_ on your data.

recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, seed=1, num_units=5)
_, sorting_2 = si.generate_ground_truth_recording(num_channels=4, seed=2, num_units=5)

both_sortings = si.aggregate_units([sorting_1, sorting_2])

##############################################################################
# Our model is based on :code:`quality_metrics`, which are computed using a :code:`sorting_analyzer`. So we'll
# now create a sorting analyzer and compute all the extensions needed to get the quality metrics.

analyzer = si.create_sorting_analyzer(sorting=both_sortings, recording=recording)
analyzer.compute([
    'noise_levels', 'random_spikes', 'waveforms', 'templates', 'spike_locations',
    'spike_amplitudes', 'correlograms', 'principal_components', 'quality_metrics',
    'template_metrics'
])

##############################################################################
# Let's plot the templates for the first and sixth units. The first (unit id 0) belonged to
# :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belonged to :code:`sorting_2`
# so should look like noise.

si.plot_unit_templates(analyzer, unit_ids=[0, 5])

##############################################################################
# This is as expected: great! Find out more about plotting using widgets `here <https://spikeinterface.readthedocs.io/en/latest/modules/widgets.html>`_. The labels
# for our units are then easy to put in a list:

labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad']

##############################################################################
# Step 2: Train our model
# -----------------------

# With our labelled data in hand, we can train the model using the :code:`train_model` function.
# Here, the idea is that the trainer will try several classifiers, imputation strategies and
# scaling techniques, then save the most accurate. To save time, we'll only try one classifier
# (Random Forest), one imputation strategy (median) and one scaling technique (standard scaler).

output_folder = "my_model"

# We use a list containing a single analyzer here; we would strongly advise using more than one
# to improve model performance
trainer = si.train_model(
    mode="analyzers",  # You can supply a labelled csv file instead of an analyzer
    labels=[labels],
    analyzers=[analyzer],
    output_folder=output_folder,  # Where to save the model and model_info.json file
    metric_names=None,  # Specify which metrics to use for training; by default uses those already calculated
    imputation_strategies=["median"],  # Defaults to all
    scaling_techniques=["standard_scaler"],  # Defaults to all
    classifiers=None,  # Defaults to Random Forest only. Other classifiers you can try: ["AdaBoostClassifier", "GradientBoostingClassifier", "LogisticRegression", "MLPClassifier"]
)

best_model = trainer.best_pipeline

##############################################################################
#
# The above code saves the model in :code:`model.skops`, some metadata in :code:`model_info.json` and
# the model accuracies in :code:`model_accuracies.csv`, all in the specified :code:`output_folder`.
#
# :code:`skops` is a file format; you can think of it as a more secure ``.pkl`` file. `Read more <https://skops.readthedocs.io/en/stable/index.html>`_.
#
# The :code:`model_accuracies.csv` file contains the accuracy, precision and recall of the tested models.
# Let's take a look:

accuracies = pd.read_csv(Path(output_folder) / "model_accuracies.csv", index_col=0)
accuracies.head()

# Our model is perfect!! This is because the task was *very* easy: we had 10 units, half of
# which were pure noise and half of which were not.
#
# The model also contains some more information, such as which features are most important.
# We can plot these as follows:

# Plot feature importances
importances = best_model.named_steps['classifier'].feature_importances_
indices = np.argsort(importances)[::-1]
features = best_model.feature_names_in_
n_features = best_model.n_features_in_

plt.figure(figsize=(12, 6))
plt.title("Feature Importances")
plt.bar(range(n_features), importances[indices], align="center")
# Label each bar with the feature it corresponds to (sorted by importance)
plt.xticks(range(n_features), features[indices], rotation=90)
plt.xlim([-1, n_features])
plt.show()
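
##############################################################################
# Now that the model is trained and saved in ``my_model``, it can be used to label new
# sortings, just as the pretrained Hugging Face models were used in the previous tutorial.
# A minimal sketch (the ``model_folder`` argument name is an assumption; check the
# ``auto_label_units`` docstring for the exact signature):
#
# .. code-block::
#
#     labels_and_probabilities = si.auto_label_units(
#         sorting_analyzer=analyzer,
#         model_folder="my_model",
#     )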