Update doc tutorials
chrishalcrow committed Sep 19, 2024
1 parent 4104d21 commit acb47c8
Showing 3 changed files with 331 additions and 256 deletions.
217 changes: 217 additions & 0 deletions examples/tutorials/curation/plot_1_automated_curation.py
@@ -0,0 +1,217 @@
"""
Model Based Curation Tutorial
=============================
This notebook provides a step-by-step guide on how to use a machine learning classifier to curate spike sorted output. We'll download a toy model from `HuggingFace` and use it to label our sorted data using SpikeInterface. We start by importing some packages.
"""

import warnings
warnings.filterwarnings("ignore")
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import spikeinterface.full as si

# Note: you can use more cores with e.g.
# si.set_global_job_kwargs(n_jobs = 8)

##############################################################################
# Download a pretrained model
# ---------------------------
#
# We can download a pretrained model from `Hugging Face <https://huggingface.co/>`_ and use it
# to label our sorted data. The `load_model` function downloads a specific model from
# Hugging Face (or loads one from a local folder), saves it in a temporary folder, and
# returns the model along with some metadata about it.

model, model_info = si.load_model(
    repo_id = "SpikeInterface/toy_tetrode_model",
    trusted = ['numpy.dtype']
)


##############################################################################
# This model was trained on artificially generated tetrode data. The model object has a nice HTML
# representation, which will appear if you're using a Jupyter notebook.

model

##############################################################################
# The repo contains information about which metrics were used to train the model. We can access
# this from the model itself, or from the ``model_info``.

print(model.feature_names_in_)
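
##############################################################################
# The ``model_info`` also contains metadata about the model. As a small example (using the
# label conversion that we'll need again later in this tutorial), we can check how the
# model's output classes map to curation labels:

print(model_info['label_conversion'])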

##############################################################################
# Hence, to use this model we need to create a ``sorting_analyzer`` with all these metrics computed.
# We'll do this by generating a recording and sorting, creating a sorting analyzer and computing a
# bunch of extensions. Follow these links for more info on `recordings <https://spikeinterface.readthedocs.io/en/latest/modules/extractors.html>`_, `sortings <https://spikeinterface.readthedocs.io/en/latest/modules/sorters.html>`_, `sorting analyzers <https://spikeinterface.readthedocs.io/en/latest/tutorials/core/plot_4_sorting_analyzer.html#sphx-glr-tutorials-core-plot-4-sorting-analyzer-py>`_
# and `extensions <https://spikeinterface.readthedocs.io/en/latest/modules/postprocessing.html>`_.

recording, sorting = si.generate_ground_truth_recording(num_channels=4, seed=4, num_units=10)
sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)
sorting_analyzer.compute([
    'noise_levels', 'random_spikes', 'waveforms', 'templates', 'spike_locations',
    'spike_amplitudes', 'correlograms', 'principal_components', 'quality_metrics',
    'template_metrics'
])

##############################################################################
# This sorting_analyzer now contains the required quality metrics and template metrics.
# We can check that this is true by accessing the extension data.

quality_metric_names = list(sorting_analyzer.get_extension('quality_metrics').get_data().keys())
template_metric_names = list(sorting_analyzer.get_extension('template_metrics').get_data().keys())
all_metric_names = quality_metric_names + template_metric_names
print(np.all(all_metric_names == model.feature_names_in_))
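
##############################################################################
# If this check fails, a quick set difference shows which metrics the model expects but the
# analyzer is missing (a small sketch):

missing_metrics = set(model.feature_names_in_) - set(all_metric_names)
print(missing_metrics)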

##############################################################################
# Great! We can now use the model to predict labels. You can either pass a HuggingFace repo, or a
# local folder containing the model file and the ``model_info.json`` file. Here, we'll pass
# a repo. The function returns a dictionary containing a label and a confidence for each unit.

labels = si.auto_label_units(
    sorting_analyzer = sorting_analyzer,
    repo_id = "SpikeInterface/toy_tetrode_model",
    trusted = ['numpy.dtype']
)

labels


##############################################################################
# The model has labelled one unit as bad. Let's look at that one and the 'good' unit with the highest
# confidence of being good:

si.plot_unit_templates(sorting_analyzer, unit_ids=[7,9])

##############################################################################
# Nice - we can see that unit 9 does look a lot more spiky than unit 7. You might think that
# unit 7 is a real unit. If so, this model isn't suitable for your data.
#
# Assess the model performance
# ----------------------------
#
# To assess the performance of the model relative to human labels, we can load or generate some
# human labels, and plot a confusion matrix of predicted vs human labels for all clusters. Here
# we're being a conservative human, who has labelled several units with small amplitudes as 'bad'.

human_labels = ['bad', 'good', 'good', 'bad', 'good', 'bad', 'good', 'bad', 'good', 'good']

# Note: if you labelled using phy, you can load the labels using:
# human_labels = sorting_analyzer.sorting.get_property('quality')

from sklearn.metrics import confusion_matrix, balanced_accuracy_score
import seaborn as sns

label_conversion = model_info['label_conversion']
predictions = [labels[a][0] for a in range(10)]

conf_matrix = confusion_matrix(human_labels, predictions)

# Calculate balanced accuracy for the confusion matrix
balanced_accuracy = balanced_accuracy_score(human_labels, predictions)

sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='viridis')
plt.xlabel('Predicted Label')
plt.ylabel('Human Label')
plt.xticks(ticks = [0.5, 1.5], labels = list(label_conversion.values()))
plt.yticks(ticks = [0.5, 1.5], labels = list(label_conversion.values()))
plt.title('Predicted vs Human Label')
plt.suptitle(f"Balanced Accuracy: {balanced_accuracy}")
plt.show()


##############################################################################
# Here, there are several false positives (if we consider the human labels to be "the truth").
#
# Next, we can see how the model's confidence relates to the probability that the model
# label matches the human label.
#
# This could be used to set a threshold above which you might accept the model's classification,
# and only manually curate the units it is less sure of.


def calculate_moving_avg(label_df, confidence_label, window_size):

    label_df[f'{confidence_label}_decile'] = pd.cut(label_df[confidence_label], 10, labels=False, duplicates='drop')
    # Group by decile and calculate the proportion of correct labels (agreement)
    p_label_grouped = label_df.groupby(f'{confidence_label}_decile')['model_x_human_agreement'].mean()
    # Convert decile to range 0-1
    p_label_grouped.index = p_label_grouped.index / 10
    # Sort the DataFrame by confidence scores
    label_df_sorted = label_df.sort_values(by=confidence_label)

    p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean()

    return label_df_sorted[confidence_label], p_label_moving_avg

confidences = sorting_analyzer.sorting.get_property('label_confidence')

# Make dataframe of human label, model label, and confidence
label_df = pd.DataFrame(
    data = {
        'human_label': human_labels,
        'decoder_label': predictions,
        'confidence': confidences},
    index = sorting_analyzer.sorting.get_unit_ids())

# Calculate whether the model and human agree on each unit's label
label_df['model_x_human_agreement'] = label_df['human_label'] == label_df['decoder_label']

p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 3)

# Plot the moving average of agreement
plt.figure(figsize=(6, 6))
plt.plot(p_agreement_sorted, p_agreement_moving_avg, label = 'Moving Average')
plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance')
plt.xlabel('Confidence'); #plt.xlim(0.5, 1)
plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1)
plt.title('Agreement vs Confidence (Moving Average)')
plt.legend(); plt.grid(True); plt.show()
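
##############################################################################
# As a sketch of this thresholding idea (the 0.9 cutoff here is an arbitrary assumption,
# not a recommendation), we could accept the model's labels for high-confidence units and
# flag the remaining units for manual curation:

confident_units = label_df[label_df['confidence'] > 0.9]
uncertain_units = label_df[label_df['confidence'] <= 0.9]
print("Accept model labels for units:", list(confident_units.index))
print("Manually curate units:", list(uncertain_units.index))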

##############################################################################
# In this case, you might decide to only trust labels with a confidence above 0.86.
#
# A more realistic example
# ------------------------
#
# Here, we used a toy model trained on generated data. There are also models on HuggingFace
# trained on real data.
#
# For example, the following classifier is trained on Neuropixels data from 11 mice recorded in
# V1, SC and ALM: https://huggingface.co/AnoushkaJain3/curation_machine_learning_models
#
# Go to that page and take a look at the `Files`. The models are stored as `skops files <https://skops.readthedocs.io/en/stable/>`_,
# and there are *two* in this repo. We can choose which one to load in the `load_model` function as follows:
#
# .. code-block::
#
#    model, model_info = si.load_model(
#        repo_id = "AnoushkaJain3/curation_machine_learning_models",
#        model_name = 'noise_neuron_model.skops',
#    )
#
# If you run this locally you will receive an error:
#
# .. code-block::
#
# UntrustedTypesFoundException: Untrusted types found in the file: ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']
#
# This is a security warning. Sharing models, which are Python objects, is a complicated problem.
# We have chosen to use the `skops format <https://skops.readthedocs.io/en/stable/>`_, instead of the common but insecure ``.pkl`` format
# (read about ``pickle`` security issues `here <https://lwn.net/Articles/964392/>`_). While unpacking the ``.skops`` file, each function
# is checked. Ideally, ``skops`` should recognise all `sklearn`, `numpy` and `scipy` functions and
# allow the object to be loaded. But when it's not sure, it raises an error. Here, ``skops``
# doesn't recognise ``['sklearn.metrics._classification.balanced_accuracy_score',
# 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV',
# 'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions
# from `sklearn`, and we can happily add them to the ``trusted`` functions.
#
# .. code-block::
#
#    model, model_info = si.load_model(
#        model_name = 'noise_neuron_model.skops',
#        repo_id = "AnoushkaJain3/curation_machine_learning_models",
#        trusted = ['sklearn.metrics._classification.balanced_accuracy_score',
#                   'sklearn.metrics._scorer._Scorer',
#                   'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV',
#                   'sklearn.model_selection._split.StratifiedKFold']
#    )
#
# As `skops` continues to be developed, we hope more of these functions will be :code:`trusted`
# by default.
#
# In general, you should be cautious when downloading `.skops` files and `.pkl` files from repos,
# especially from unknown sources.
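#
# If you want to inspect a downloaded ``.skops`` file before trusting it, ``skops`` provides a
# helper to list the unrecognised types (a small sketch; see the `skops docs <https://skops.readthedocs.io/en/stable/>`_
# for details):
#
# .. code-block::
#
#    from skops.io import get_untrusted_types
#    unknown_types = get_untrusted_types(file="path/to/model.skops")
#    print(unknown_types)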
114 changes: 114 additions & 0 deletions examples/tutorials/curation/plot_2_train_a_model.py
@@ -0,0 +1,114 @@
"""
Training a model for automated curation
========================================
If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface.
"""
import warnings
warnings.filterwarnings("ignore")
from pathlib import Path
import spikeinterface.full as si
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Note: you can set the number of cores to use with e.g.
# si.set_global_job_kwargs(n_jobs = 8)

##############################################################################
# Step 1: Generate and label data
# -------------------------------
#
# We supply pretrained machine learning classifiers for labelling spike-sorted clusters with
# arbitrary labels, in this example single-unit activity ('good') or noise ('bad'). To train
# your own classifier, you first need some labelled data.
#
# For the tutorial, we will use simulated data to create :code:`recording` and :code:`sorting` objects. We'll
# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so will contain good
# units; :code:`sorting_2` is uncoupled, so should be pure noise. We'll combine the two into one sorting
# object using :code:`si.aggregate_units`.
#
# In practice, you should `load your own recording <https://spikeinterface.readthedocs.io/en/latest/modules/extractors.html>`_ and `do a sorting <https://spikeinterface.readthedocs.io/en/latest/modules/sorters.html>`_ on your data.

recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, seed=1, num_units=5)
_, sorting_2 = si.generate_ground_truth_recording(num_channels=4, seed=2, num_units=5)

both_sortings = si.aggregate_units([sorting_1, sorting_2])

##############################################################################
# Our model is based on the :code:`quality_metrics` and :code:`template_metrics`, which are computed
# using a :code:`sorting_analyzer`. So we'll now create a sorting analyzer and compute all the
# extensions needed to get these metrics.

analyzer = si.create_sorting_analyzer(sorting = both_sortings, recording=recording)
analyzer.compute([
    'noise_levels', 'random_spikes', 'waveforms', 'templates', 'spike_locations',
    'spike_amplitudes', 'correlograms', 'principal_components', 'quality_metrics',
    'template_metrics'
])

##############################################################################
# Let's plot the templates for the first and sixth units. The first (unit id 0) belonged to
# :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belonged to :code:`sorting_2`
# so should look like noise.

si.plot_unit_templates(analyzer, unit_ids=[0,5])

##############################################################################
# This is as expected: great! Find out more about plotting using widgets `here <https://spikeinterface.readthedocs.io/en/latest/modules/widgets.html>`_. The labels
# for our units are then easy to put in a list:

labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad']
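
##############################################################################
# Equivalently (a small sketch), since we know how many units each sorting contributed,
# we could build the same list programmatically:

labels = ['good'] * sorting_1.get_num_units() + ['bad'] * sorting_2.get_num_units()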

##############################################################################
# Step 2: Train our model
# -----------------------
#
# With our labelled data in hand, we can train the model using the :code:`train_model` function.
# Here, the idea is that the trainer will try several classifiers, imputation strategies and
# scaling techniques then save the most accurate. To save time, we'll only try one classifier
# (Random Forest), imputation strategy (median) and scaling technique (standard scaler).

output_folder = "my_model"

# We will use a list containing only one analyzer here, but we would strongly advise using more
# than one to improve model performance
trainer = si.train_model(
    mode = "analyzers", # You can supply a labelled csv file instead of an analyzer
    labels = [labels],
    analyzers = [analyzer],
    output_folder = output_folder, # Where to save the model and model_info.json file
    metric_names = None, # Specify which metrics to use for training: by default uses those already calculated
    imputation_strategies = ["median"], # Defaults to all
    scaling_techniques = ["standard_scaler"], # Defaults to all
    classifiers = None, # Defaults to Random Forest only. Other classifiers to try: ["AdaBoostClassifier", "GradientBoostingClassifier", "LogisticRegression", "MLPClassifier"]
)

best_model = trainer.best_pipeline

##############################################################################
#
# The above code saves the model in :code:`model.skops`, some metadata in :code:`model_info.json` and
# the model accuracies in :code:`model_accuracies.csv` in the specified :code:`output_folder`.
#
# :code:`skops` is a file format; you can think of it as a more secure ``.pkl`` file. `Read more <https://skops.readthedocs.io/en/stable/index.html>`_.
#
# The :code:`model_accuracies.csv` file contains the accuracy, precision and recall of the tested models.
# Let's take a look:

accuracies = pd.read_csv(Path(output_folder) / "model_accuracies.csv", index_col = 0)
accuracies.head()

##############################################################################
# Our model is perfect!! This is because the task was *very* easy: we had 10 units, of which
# half were pure noise and half were not.
#
# The model also contains some more information, such as which features are most important.
# We can plot these as follows:

# Plot feature importances
importances = best_model.named_steps['classifier'].feature_importances_
indices = np.argsort(importances)[::-1]
features = best_model.feature_names_in_
n_features = best_model.n_features_in_

plt.figure(figsize=(12, 6))
plt.title("Feature Importances")
plt.bar(range(n_features), importances[indices], align="center")
plt.xticks(range(n_features), features, rotation=90)
plt.xlim([-1, n_features])
plt.show()
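
##############################################################################
# Because ``best_model`` is a standard scikit-learn pipeline, you can also apply it directly to a
# table of metrics. This is a minimal sketch (in practice you would use ``si.auto_label_units``
# with the saved model, as in the previous tutorial):

metrics = pd.concat([
    analyzer.get_extension('quality_metrics').get_data(),
    analyzer.get_extension('template_metrics').get_data(),
], axis=1)
model_labels = best_model.predict(metrics[best_model.feature_names_in_])
print(model_labels)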