diff --git a/examples/tutorials/curation/plot_1_automated_curation.py b/examples/tutorials/curation/plot_1_automated_curation.py index 8e447dd12a..f355a48032 100644 --- a/examples/tutorials/curation/plot_1_automated_curation.py +++ b/examples/tutorials/curation/plot_1_automated_curation.py @@ -7,7 +7,6 @@ import warnings warnings.filterwarnings("ignore") -from pathlib import Path import numpy as np import pandas as pd import matplotlib.pyplot as plt @@ -97,7 +96,6 @@ # human_labels = sorting_analyzer.sorting.get_property('quality') from sklearn.metrics import confusion_matrix, balanced_accuracy_score -import seaborn as sns label_conversion = model_info['label_conversion'] predictions = [ labels[a][0] for a in range(10) ] @@ -107,7 +105,9 @@ # Calculate balanced accuracy for the confusion matrix balanced_accuracy = balanced_accuracy_score(human_labels, predictions) -sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='viridis') +plt.imshow(conf_matrix) +for (index, value) in np.ndenumerate(conf_matrix): +plt.annotate( str(value), xy=index, color="white", fontsize="15") plt.xlabel('Predicted Label') plt.ylabel('Human Label') plt.xticks(ticks = [0.5, 1.5], labels = list(label_conversion.values())) diff --git a/pyproject.toml b/pyproject.toml index e51fcb2a4d..6d724a32c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,6 +200,8 @@ docs = [ "numba", # For many postprocessing functions "xarray", # For use of SortingAnalyzer zarr format "networkx", + "skops", # For automated curation + "scikit-learn", # For automated curation # Download data "pooch>=1.8.2", "datalad>=1.0.2",