diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py
new file mode 100644
index 0000000000..8e447dd12a
--- /dev/null
+++ b/examples/tutorials/curation/plot_1_automated_curation.py
@@ -0,0 +1,217 @@
+"""
+Model Based Curation Tutorial
+=============================
+
+This notebook provides a step-by-step guide on how to use a machine learning classifier to curate spike-sorted output. We'll download a toy model from `HuggingFace` and use it to label our sorted data using SpikeInterface. We start by importing some packages.
+"""
+
+import warnings
+warnings.filterwarnings("ignore")
+from pathlib import Path
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+import spikeinterface.full as si
+
+# Note: you can set the number of cores you use with e.g.
+# si.set_global_job_kwargs(n_jobs = 8)
+
+##############################################################################
+# Download a pretrained model
+# ---------------------------
+#
+# We can download a pretrained model from `Hugging Face `_ and use it to label our sorted data.
+# The ``load_model`` function downloads a specific model from Hugging Face (or loads one from a
+# local folder), saves it in a temporary folder, and returns the model together with some metadata
+# about it.
+
+model, model_info = si.load_model(
+ repo_id = "SpikeInterface/toy_tetrode_model",
+ trusted = ['numpy.dtype']
+)
+
+
+##############################################################################
+# This model was trained on artificially generated tetrode data. The model object has a nice html
+# representation, which will appear if you're using a Jupyter notebook.
+
+model
+
+##############################################################################
+# The repo contains information about which metrics were used to train the model. We can access
+# this from the model itself, or from the ``model_info``:
+
+print(model.feature_names_in_)
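+
+# We can also inspect the metadata returned alongside the model. This is read from the repo's
+# ``model_info.json`` file and includes, for example, the label conversion used later in this tutorial.
+print(model_info)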
+
+##############################################################################
+# Hence, to use this model we need to create a ``sorting_analyzer`` with all these metrics computed.
+# We'll do this by generating a recording and sorting, creating a sorting analyzer and computing a
+# bunch of extensions. Follow these links for more info on `recordings `_, `sortings `_, `sorting analyzers `_
+# and `extensions `_.
+
+recording, sorting = si.generate_ground_truth_recording(num_channels=4, seed=4, num_units=10)
+sorting_analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording)
+sorting_analyzer.compute([
+    'noise_levels', 'random_spikes', 'waveforms', 'templates', 'spike_locations',
+    'spike_amplitudes', 'correlograms', 'principal_components', 'quality_metrics', 'template_metrics',
+])
+
+##############################################################################
+# The ``sorting_analyzer`` now contains the required quality metrics and template metrics. We can
+# check that these match the metrics the model expects by accessing the extension data.
+
+quality_metric_names = list(sorting_analyzer.get_extension('quality_metrics').get_data().keys())
+template_metric_names = list(sorting_analyzer.get_extension('template_metrics').get_data().keys())
+all_metric_names = quality_metric_names + template_metric_names
+print(np.all(all_metric_names == model.feature_names_in_))
+
+##############################################################################
+# Great! We can now use the model to predict labels. You can either pass a HuggingFace repo, or a
+# local folder containing the model file and the ``model_info.json`` file. Here, we'll pass
+# a repo. The function returns a dictionary containing a label and a confidence for each unit.
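+#
+# If you have a local copy of a model instead, the call looks roughly like the sketch below. The
+# argument name for the local path is assumed here to be ``model_folder``; check the
+# ``auto_label_units`` docstring in your SpikeInterface version.
+#
+# .. code-block::
+#
+#     labels = si.auto_label_units(
+#         sorting_analyzer = sorting_analyzer,
+#         model_folder = "path/to/folder_containing_model_and_model_info",
+#         trusted = ['numpy.dtype'],
+#     )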
+
+labels = si.auto_label_units(
+ sorting_analyzer = sorting_analyzer,
+ repo_id = "SpikeInterface/toy_tetrode_model",
+ trusted = ['numpy.dtype']
+)
+
+labels
+
+
+##############################################################################
+# The model has labelled one unit as bad. Let's look at that one and the 'good' unit with the highest
+# confidence of being good:
+
+si.plot_unit_templates(sorting_analyzer, unit_ids=[7,9])
+
+##############################################################################
+# Nice - we can see that unit 9 does look a lot more spike-like than unit 7. Of course, you might
+# disagree and think that unit 7 is a real unit. If so, this model isn't a good fit for your data.
+#
+# Assess the model performance
+# ----------------------------
+#
+# To assess the performance of the model relative to human labels, we can load or generate some
+# human labels, and plot a confusion matrix of predicted vs human labels for all clusters. Here
+# we're being a conservative human, who has labelled several units with small amplitudes as 'bad'.
+
+human_labels = ['bad', 'good', 'good', 'bad', 'good', 'bad', 'good', 'bad', 'good', 'good']
+
+# Note: if you labelled using phy, you can load the labels using:
+# human_labels = sorting_analyzer.sorting.get_property('quality')
+
+from sklearn.metrics import confusion_matrix, balanced_accuracy_score
+import seaborn as sns
+
+label_conversion = model_info['label_conversion']
+predictions = [ labels[a][0] for a in range(10) ]
+
+conf_matrix = confusion_matrix(human_labels, predictions)
+
+# Calculate balanced accuracy for the confusion matrix
+balanced_accuracy = balanced_accuracy_score(human_labels, predictions)
+
+sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='viridis')
+plt.xlabel('Predicted Label')
+plt.ylabel('Human Label')
+plt.xticks(ticks = [0.5, 1.5], labels = list(label_conversion.values()))
+plt.yticks(ticks = [0.5, 1.5], labels = list(label_conversion.values()))
+plt.title('Predicted vs Human Label')
+plt.suptitle(f"Balanced Accuracy: {balanced_accuracy}")
+plt.show()
+
+
+##############################################################################
+# Here, there are several false positives (if we consider the human labels to be "the truth").
+#
+# Next, we can also see how the model's confidence relates to the probability that the model
+# label matches the human label.
+#
+# This could be used to set a threshold above which you might accept the model's classification,
+# and only manually curate those units which it is less sure of.
+
+
+def calculate_moving_avg(label_df, confidence_label, window_size):
+    """Return confidences (sorted ascending) and a moving average of model/human agreement."""
+
+    # Sort the DataFrame by confidence score
+    label_df_sorted = label_df.sort_values(by=confidence_label)
+
+    # Compute the rolling proportion of units where the model label agrees with the human label
+    p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean()
+
+    return label_df_sorted[confidence_label], p_label_moving_avg
+
+confidences = sorting_analyzer.sorting.get_property('label_confidence')
+
+# Make dataframe of human label, model label, and confidence
+label_df = pd.DataFrame(data = {
+ 'human_label': human_labels,
+ 'decoder_label': predictions,
+ 'confidence': confidences},
+ index = sorting_analyzer.sorting.get_unit_ids())
+
+# Calculate the proportion of agreed labels by confidence decile
+label_df['model_x_human_agreement'] = label_df['human_label'] == label_df['decoder_label']
+
+p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 3)
+
+# Plot the moving average of agreement
+plt.figure(figsize=(6, 6))
+plt.plot(p_agreement_sorted, p_agreement_moving_avg, label = 'Moving Average')
+plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance')
+plt.xlabel('Confidence')
+plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1)
+plt.title('Agreement vs Confidence (Moving Average)')
+plt.legend(); plt.grid(True); plt.show()
+
+##############################################################################
+# In this case, you might decide to only trust labels which had a confidence above 0.86.
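+#
+# As a rough sketch, using the ``label_df`` we built above and a hypothetical cut-off, you could
+# keep the model's labels for high-confidence units and send the rest for manual curation:
+#
+# .. code-block::
+#
+#     confidence_threshold = 0.86
+#     auto_curated_units = label_df.index[label_df['confidence'] >= confidence_threshold]
+#     to_check_by_hand = label_df.index[label_df['confidence'] < confidence_threshold]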
+#
+# A more realistic example
+# ------------------------
+#
+# Here, we used a toy model trained on generated data. There are also models on HuggingFace
+# trained on real data.
+#
+# For example, the following classifier is trained on Neuropixels data from 11 mice recorded in
+# V1, SC and ALM: https://huggingface.co/AnoushkaJain3/curation_machine_learning_models
+#
+# Go to that page and take a look at the `Files`. The models are contained in the `skops files `_,
+# and there are *two* in this repo. We can choose which one to load in the ``load_model`` function as follows:
+#
+# .. code-block::
+#
+# model, model_info = si.load_model(
+# repo_id = "AnoushkaJain3/curation_machine_learning_models",
+# model_name= 'noise_neuron_model.skops',
+# )
+#
+# If you run this locally you will receive an error:
+#
+# .. code-block::
+#
+# UntrustedTypesFoundException: Untrusted types found in the file: ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']
+#
+# This is a security warning. Sharing models, which are Python objects, is a complicated problem.
+# We have chosen to use the `skops format `_, instead of the common but insecure ``.pkl`` format
+# (read about ``pickle`` security issues `here `_). While unpacking a ``.skops`` file, each function
+# is checked. Ideally, ``skops`` should recognise all `sklearn`, `numpy` and `scipy` functions and
+# allow the object to be loaded. But when it's not sure, it raises an error. Here, ``skops``
+# doesn't recognise ``['sklearn.metrics._classification.balanced_accuracy_score',
+# 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV',
+# 'sklearn.model_selection._split.StratifiedKFold']``. Taking a look, these are all functions
+# from `sklearn`, and we can happily add them to the ``trusted`` functions.
+#
+# .. code-block::
+#
+# model, model_info = si.load_model(
+# model_name = 'noise_neuron_model.skops',
+# repo_id = "AnoushkaJain3/curation_machine_learning_models",
+# trusted = ['sklearn.metrics._classification.balanced_accuracy_score', 'sklearn.metrics._scorer._Scorer', 'sklearn.model_selection._search_successive_halving.HalvingGridSearchCV', 'sklearn.model_selection._split.StratifiedKFold']
+# )
+#
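+# Rather than copying names out of the error message, you can also ask ``skops`` directly which
+# types it does not trust, and inspect them before deciding to load the model. A rough sketch,
+# reusing the same repo and model file as above:
+#
+# .. code-block::
+#
+#     from huggingface_hub import hf_hub_download
+#     import skops.io
+#
+#     model_path = hf_hub_download(
+#         repo_id="AnoushkaJain3/curation_machine_learning_models",
+#         filename="noise_neuron_model.skops",
+#     )
+#     unknown_types = skops.io.get_untrusted_types(file=model_path)
+#     print(unknown_types)  # check these, then pass them as the ``trusted`` argument
+#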
+# As `skops` continues to be developed, we hope more of these functions will be :code:`trusted`
+# by default.
+#
+# In general, you should be cautious when downloading `.skops` files and `.pkl` files from repos,
+# especially from unknown sources.
diff --git a/examples/tutorials/curation/plot_2_train_a_model.py b/examples/tutorials/curation/plot_2_train_a_model.py
new file mode 100644
index 0000000000..1cc90224f4
--- /dev/null
+++ b/examples/tutorials/curation/plot_2_train_a_model.py
@@ -0,0 +1,114 @@
+"""
+Training a model for automated curation
+=======================================
+
+If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier using SpikeInterface.
+"""
+import warnings
+warnings.filterwarnings("ignore")
+from pathlib import Path
+import spikeinterface.full as si
+import numpy as np
+import pandas as pd
+import matplotlib.pyplot as plt
+
+# Note: you can set the number of cores you use with e.g.
+# si.set_global_job_kwargs(n_jobs = 8)
+
+##############################################################################
+# Step 1: Generate and label data
+# -------------------------------
+#
+# A classifier can be trained to predict arbitrary labels for spike-sorted clusters; in this
+# example we'll use two: single-unit activity ('good') and noise ('bad'). The first step is to
+# generate some labelled data to train on.
+#
+# For the tutorial, we will use simulated data to create :code:`recording` and :code:`sorting` objects. We'll
+# create two sorting objects: :code:`sorting_1` is coupled to the real recording, so will contain good
+# units, while :code:`sorting_2` is uncoupled, so should be pure noise. We'll combine the two into one sorting
+# object using :code:`si.aggregate_units`.
+#
+# You should `load your own recording `_ and `do a sorting `_ on your data.
+
+recording, sorting_1 = si.generate_ground_truth_recording(num_channels=4, seed=1, num_units=5)
+_, sorting_2 = si.generate_ground_truth_recording(num_channels=4, seed=2, num_units=5)
+
+both_sortings = si.aggregate_units([sorting_1, sorting_2])
+
+##############################################################################
+# Our model is based on :code:`quality_metrics`, which are computed using a :code:`sorting_analyzer`. So we'll
+# now create a sorting analyzer and compute all the extensions needed to get the quality metrics.
+
+analyzer = si.create_sorting_analyzer(sorting=both_sortings, recording=recording)
+analyzer.compute([
+    'noise_levels', 'random_spikes', 'waveforms', 'templates', 'spike_locations',
+    'spike_amplitudes', 'correlograms', 'principal_components', 'quality_metrics', 'template_metrics',
+])
+
+##############################################################################
+# Let's plot the templates for the first and sixth units. The first (unit id 0) belonged to
+# :code:`sorting_1` so should look like a real unit; the sixth (unit id 5) belonged to :code:`sorting_2`
+# so should look like noise.
+
+si.plot_unit_templates(analyzer, unit_ids=[0,5])
+
+##############################################################################
+# This is as expected: great! Find out more about plotting using widgets `here `_. The labels
+# for our units are then easy to put in a list:
+
+labels = ['good', 'good', 'good', 'good', 'good', 'bad', 'bad', 'bad', 'bad', 'bad']
+
+##############################################################################
+# Step 2: Train our model
+# -----------------------
+#
+# With our labelled data in hand, we can train the model using the :code:`train_model` function.
+# Here, the idea is that the trainer will try several classifiers, imputation strategies and
+# scaling techniques, then save the most accurate. To save time, we'll only try one classifier
+# (Random Forest), one imputation strategy (median) and one scaling technique (standard scaler).
+
+output_folder = "my_model"
+
+# We will use a list containing one analyzer here; in practice we would strongly advise using
+# more than one to improve model performance
+trainer = si.train_model(
+    mode = "analyzers",  # You can supply a labelled csv file instead of an analyzer
+    labels = [labels],
+    analyzers = [analyzer],
+    output_folder = output_folder,  # Where to save the model and model_info.json file
+    metric_names = None,  # Specify which metrics to use for training: by default uses those already calculated
+    imputation_strategies = ["median"],  # Defaults to all
+    scaling_techniques = ["standard_scaler"],  # Defaults to all
+    classifiers = None,  # Defaults to Random Forest only. Other classifiers you can try: ["AdaBoostClassifier", "GradientBoostingClassifier", "LogisticRegression", "MLPClassifier"]
+)
+
+best_model = trainer.best_pipeline
+
+##############################################################################
+#
+# The above code saves the model in :code:`model.skops`, some metadata in :code:`model_info.json` and
+# the model accuracies in :code:`model_accuracies.csv` in the specified :code:`output_folder`.
+#
+# :code:`skops` is a file format; you can think of it as a more secure ``.pkl`` file. `Read more `_.
+#
+# The :code:`model_accuracies.csv` file contains the accuracy, precision and recall of the tested models.
+# Let's take a look:
+
+accuracies = pd.read_csv(Path(output_folder) / "model_accuracies.csv", index_col = 0)
+accuracies.head()
+
+##############################################################################
+# Our model is perfect!! This is because the task was *very* easy: we had 10 units, where
+# half were pure noise and half were not.
+#
+# The model also contains some more information, such as which features are most important.
+# We can plot these as follows:
+
+# Plot feature importances
+importances = best_model.named_steps['classifier'].feature_importances_
+indices = np.argsort(importances)[::-1]
+features = best_model.feature_names_in_
+n_features = best_model.n_features_in_
+
+plt.figure(figsize=(12, 6))
+plt.title("Feature Importances")
+plt.bar(range(n_features), importances[indices], align="center")
+plt.xticks(range(n_features), features, rotation=90)
+plt.xlim([-1, n_features])
+plt.show()
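+
+##############################################################################
+# You can now use your model to auto-label units in other sorting analyzers, just as we did with
+# the pretrained model in the automated-curation tutorial. A rough sketch is shown below; the
+# argument name for a local model folder is assumed here to be ``model_folder``, so check the
+# ``auto_label_units`` docstring in your SpikeInterface version (you may also need to pass a list
+# of ``trusted`` types, as discussed in that tutorial).
+#
+# .. code-block::
+#
+#     labels_and_confidences = si.auto_label_units(
+#         sorting_analyzer = analyzer,
+#         model_folder = output_folder,
+#     )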
diff --git a/examples/tutorials/qualitymetrics/plot_5_automated_curation.py b/examples/tutorials/qualitymetrics/plot_5_automated_curation.py
deleted file mode 100644
index a572b59757..0000000000
--- a/examples/tutorials/qualitymetrics/plot_5_automated_curation.py
+++ /dev/null
@@ -1,256 +0,0 @@
-"""
-Model Based Curation Tutorial
-=============================
-
-This notebook outlines approaches to training a machine learning classifier on Spikeinterface-computed quality metrics, and using such models to predict curation labels for previously uncurated electrophysiology data
- - Code for predicting labels using a trained model can be found in the first section, then code for training your own bespoke model
- - Plots can be generated to assess the performance of the model, both on the training and unseen data
- - Pre-trained models can be downloaded from `Hugging Face `_, or opened from `skops `_ files
-"""
-
-from pathlib import Path
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
-import spikeinterface as si
-import spikeinterface.extractors as se
-from spikeinterface.curation.model_based_curation import compute_all_metrics
-
-from os import cpu_count
-
-# Set the number of CPU cores to be used globally - defaults to all cores -1
-n_cores = cpu_count() -1
-si.set_global_job_kwargs(n_jobs = n_cores)
-print(f"Number of cores set to: {n_cores}")
-
-# SET OUTPUT FOLDER
-output_folder = "/home/jake/Documents/ephys_analysis/code/ephys_analysis/auto_curation/models"
-
-##############################################################################
-# Applying a pretrained model to predict curation labels
-# ------------------------------
-#
-# We supply pretrained machine learning classifiers for predicting spike-sorted clusters with arbitrary labels, in this example single-unit activity ('good'), or noise. This particular approach works as follows:
-#
-# 1. Create a Spikeinterface 'sorting analyzer `_ and compute `quality metrics `_
-#
-# 2. Load a pretrained classification model from Hugging Face & skops
-#
-# 3. Compare with human-applied curation labels to assess performance, optionally export labels to phy to allow manual checking
-#
-# Load data and compute quality metrics - test data are used here, but replace with your own recordings!
-# We would recommend to start with some previously labelled data, to allow comparison of the model performance against your labelling
-# First, let's simulate some data and compute quality metrics
-
-unlabelled_recording, unlabelled_sorting = si.generate_ground_truth_recording(durations=[60], num_units=30)
-
-unlabelled_analyzer = si.create_sorting_analyzer(sorting = unlabelled_sorting, recording = unlabelled_recording, sparse = True)
-
-# Compute all quality metrics
-quality_metrics, template_metrics = compute_all_metrics(unlabelled_analyzer)
-
-##############################################################################
-# Load pretrained models and predict curation labels for unlabelled data
-# ------------------------------
-#
-# We can download a pretrained model from Hugging Face, and use this to test out prediction. This particular model assumes only two labels ('noise', and 'good') have been used for classification
-# Predictions and prediction confidence are then stored in "label_prediction" and "label_confidence" unit properties
-
-##############################################################################
-# Load pretrained noise/neural activity model and predict on unlabelled data
-from spikeinterface.curation.model_based_curation import auto_label_units
-
-from huggingface_hub import hf_hub_download
-import skops.io
-model_path = hf_hub_download(repo_id="chrishalcrow/test_automated_curation_3", filename="skops-_xvuw15v.skops")
-model = skops.io.load(model_path, trusted='numpy.dtype')
-
-label_conversion = {0: 'noise', 1: 'good'}
-
-label_dict = auto_label_units(sorting_analyzer=unlabelled_analyzer,
- pipeline=model,
- label_conversion=label_conversion,
- export_to_phy=False,
- pipeline_info_path=None)
-unlabelled_analyzer.sorting
-
-##############################################################################
-# Assess model performance by comparing with human labels
-
-# To assess the performance of the model relative to human labels, we can load (or here generate randomly) some labels, and plot a confusion matrix of predicted vs human labels for all clusters
-
-from sklearn.metrics import confusion_matrix, balanced_accuracy_score
-import seaborn as sns
-
-# Use 'ground-truth' labels to check prediction accuracy
-# These are assigned randomly here but you could load these from phy 'cluster_group.tsv', from the 'quality' property of the sorting, or similar
-human_labels = np.random.choice(list(label_conversion.values()), unlabelled_analyzer.get_num_units())
-
-# Get labels from phy sorting (if loaded) using:
-# human_labels = unlabelled_analyzer.sorting.get_property('quality')
-
-predictions = unlabelled_analyzer.sorting.get_property('label_prediction')
-
-conf_matrix = confusion_matrix(human_labels, predictions)
-
-# Calculate balanced accuracy for the confusion matrix
-balanced_accuracy = balanced_accuracy_score(human_labels, predictions)
-
-sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='viridis')
-plt.xlabel('Predicted Label')
-plt.ylabel('Human Label')
-plt.xticks(ticks = [0.5, 1.5], labels = list(label_conversion.values()))
-plt.yticks(ticks = [0.5, 1.5], labels = list(label_conversion.values()))
-plt.title('Predicted vs Human Label')
-plt.suptitle(f"Balanced Accuracy: {balanced_accuracy}")
-plt.show()
-
-##############################################################################
-# We can also see how the model's confidence relates to the probability that the model label matches the human label
-#
-# This could be used to set a threshold above which you might accept the model's classification, and only manually curate those which it is less sure of
-
-confidences = unlabelled_analyzer.sorting.get_property('label_confidence')
-
-# Make dataframe of human label, model label, and confidence
-label_df = pd.DataFrame(data = {
- 'phy_label': human_labels,
- 'decoder_label': predictions,
- 'confidence': confidences},
- index = unlabelled_analyzer.sorting.get_unit_ids())
-
-# Calculate the proportion of agreed labels by confidence decile
-label_df['model_x_human_agreement'] = label_df['phy_label'] == label_df['decoder_label']
-
-def calculate_moving_avg(label_df, confidence_label, window_size):
-
- label_df[f'{confidence_label}_decile'] = pd.cut(label_df[confidence_label], 10, labels=False, duplicates='drop')
- # Group by decile and calculate the proportion of correct labels (agreement)
- p_label_grouped = label_df.groupby(f'{confidence_label}_decile')['model_x_human_agreement'].mean()
- # Convert decile to range 0-1
- p_label_grouped.index = p_label_grouped.index / 10
- # Sort the DataFrame by confidence scores
- label_df_sorted = label_df.sort_values(by=confidence_label)
-
- p_label_moving_avg = label_df_sorted['model_x_human_agreement'].rolling(window=window_size).mean()
-
- return label_df_sorted[confidence_label], p_label_moving_avg
-
-p_agreement_sorted, p_agreement_moving_avg = calculate_moving_avg(label_df, 'confidence', 20)
-
-# Plot the moving average of agreement
-plt.figure(figsize=(6, 6))
-plt.plot(p_agreement_sorted, p_agreement_moving_avg, label = 'Moving Average')
-plt.axhline(y=1/len(np.unique(predictions)), color='black', linestyle='--', label='Chance')
-plt.xlabel('Confidence'); plt.xlim(0.5, 1)
-plt.ylabel('Proportion Agreement with Human Label'); plt.ylim(0, 1)
-plt.title('Agreement vs Confidence (Moving Average)')
-plt.legend(); plt.grid(True); plt.show()
-
-##############################################################################
-# ------------------------------
-#
-##############################################################################
-# Training a new model
-#
-# **If the pretrained models do not give satisfactory performance on your data, it is easy to train your own classifier through SpikeInterface!**
-#
-# First we make a Spikeinterface SortingAnalyzer object, in this case using simulated data, and generate some labels for the units to use as our target
-# Note that these labels are random as written here, so the example model will only perform at chance
-#
-# Load some of your data and curation labels here and see how it performs!
-# **Note that it is likely that for useful generalisability, you will need to use multiple labelled recordings for training.** To do this, compute metrics as described for multiple SortingAnalyzers, then pass them as a list to the model training function, and pass the labels as a single list in the same order
-#
-# The set of unique labels used is arbitrary, so this could just as easily be used for any cluster categorisation task as for curation into the standard 'good', 'mua' and 'noise' categories
-
-# Make a simulated SortingAnalyzer with 100 units
-labelled_recording, labelled_sorting = si.generate_ground_truth_recording(durations=[60], num_units=30)
-
-labelled_analyzer = si.create_sorting_analyzer(sorting = labelled_sorting, recording = labelled_recording, sparse = True)
-
-# Compute all quality metrics
-compute_all_metrics(labelled_analyzer)
-
-label_conversion = {'noise': 0, 'mua': 1, 'good': 2}
-
-# These are assigned randomly here but you could load these from phy 'cluster_group.tsv', from the 'quality' property of the sorting, or similar
-human_labels = np.random.choice(list(label_conversion.values()), labelled_analyzer.get_num_units())
-labelled_analyzer.sorting.set_property('quality', human_labels)
-
-# Get labels from phy sorting (if loaded) using:
-# human_labels = unlabelled_analyzer.sorting.get_property('quality')
-
-##############################################################################
-# Now we train the machine learning classifier
-#
-# By default, this searches a range of possible imputing and scaling strategies, and uses a Random Forest classifier. It then selects the model which most accurately predicts the supplied 'ground-truth' labels
-#
-# As output, this function saves the best model (as a skops file, similar to a pickle), a csv file containing information about the performance of all tested models (`model_label_accuracies.csv`), and a `model_info.json` file containing the parameters used to compute quality metrics, and the SpikeInterface version, for reproducibility
-
-# Load labelled metrics and train model
-from spikeinterface.curation.train_manual_curation import train_model
-
-# We will use a list of two (identical) analyzers here, we would advise using more than one to improve model performance
-trainer = train_model(mode = "analyzers",
- labels = np.append(human_labels, human_labels),
- analyzers = [labelled_analyzer, labelled_analyzer],
- output_folder = output_folder, # Optional, can be set to save the model and model_info.json file
- metric_names = None, # Can be set to specify which metrics to use for training
- imputation_strategies = None, # Default to all
- scaling_techniques = None, # Default to all
- classifiers = None, # Default to Random Forest only
- seed = None)
-
-best_model = trainer.best_pipeline
-best_model
-
-# OR load model from file
-# import skops.io
-# pipeline_path = Path(output_folder) / Path("best_model_label.skops")
-# unknown_types = skops.io.get_untrusted_types(file=pipeline_path)
-# best_model = skops.io.load(pipeline_path, trusted=unknown_types)
-
-##############################################################################
-# We can see the performance of each model in this `model_label_accuracies.csv` output file
-
-# Load and disply top 5 pipelines and accuracies
-accuracies = pd.read_csv(Path(output_folder) / Path("model_label_accuracies.csv"), index_col = 0)
-accuracies.head()
-
-##############################################################################
-# We can also see which metrics are most important to our model:
-
-# Plot feature importances
-importances = best_model.named_steps['classifier'].feature_importances_
-indices = np.argsort(importances)[::-1]
-features = best_model.feature_names_in_
-n_features = best_model.n_features_in_
-
-plt.figure(figsize=(12, 6))
-plt.title("Feature Importances")
-plt.bar(range(n_features), importances[indices], align="center")
-plt.xticks(range(n_features), features, rotation=90)
-plt.xlim([-1, n_features])
-plt.show()
-
-##############################################################################
-# Apply trained model to unlabelled data
-#
-# This approach is the same as in the first section, but without the need to combine the output of two separate classifiers
-
-unlabelled_recording, unlabelled_sorting = si.generate_ground_truth_recording(durations=[60], num_units=30)
-unlabelled_analyzer = si.create_sorting_analyzer(sorting = unlabelled_sorting, recording = unlabelled_recording, sparse = True)
-
-compute_all_metrics(unlabelled_analyzer)
-
-##############################################################################
-# Load best model and predict on unlabelled data
-
-from spikeinterface.curation.model_based_curation import auto_label_units
-label_conversion = {0: 'noise', 1: 'mua', 2: 'good'}
-label_dict = auto_label_units(sorting_analyzer = unlabelled_analyzer,
- pipeline = best_model,
- label_conversion = label_conversion,
- export_to_phy = False,
- pipeline_info_path = Path(output_folder) / Path("model_info.json"))
-unlabelled_analyzer.sorting