diff --git a/src/spikeinterface/comparison/groundtruthstudy.py b/src/spikeinterface/comparison/groundtruthstudy.py index cccf269f86..72ab83d455 100644 --- a/src/spikeinterface/comparison/groundtruthstudy.py +++ b/src/spikeinterface/comparison/groundtruthstudy.py @@ -176,12 +176,13 @@ def remove_sorting(self, key): for f in (log_file, comparison_file): if f.exists(): f.unlink() - + def set_colors(self, colors=None, map_name="tab20"): if colors is None: case_keys = list(self.cases.keys()) - self.colors = get_some_colors(case_keys, map_name=map_name, - color_engine = "matplotlib", shuffle=False, margin=0) + self.colors = get_some_colors( + case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0 + ) else: self.colors = colors diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py index 614ae6ba23..b77d9ce1d0 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_clustering.py @@ -239,7 +239,7 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): fig.colorbar(im, ax=axes[0, count]) label = self.cases[key]["label"] axes[0, count].set_title(label) - + return fig def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5), axes=None): @@ -298,7 +298,7 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5 label = self.cases[key]["label"] axes[count].set_title(label) axes[count].legend() - + if fig is not None: return fig @@ -355,7 +355,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs elif metric == "agreement": for found, real in zip(matched_ids2[mask], unit_ids1[mask]): to_plot += [scores.at[real, found]] - elif metric in ['recall', 'precision', 'accuracy']: + elif metric in ["recall", "precision", "accuracy"]: to_plot = result["gt_comparison"].get_performance()[metric].values depth_matched = depth snr_matched = metrics["snr"] @@ -368,14 +368,14 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs label = self.cases[key]["label"] axes[0, count].set_title(label) if count > 0: - axes[0, count].set_ylabel('') + axes[0, count].set_ylabel("") axes[0, count].set_yticks([], []) # axs[0, count].legend() - + fig.subplots_adjust(right=0.85) cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) fig.colorbar(im, cax=cbar_ax, label=metric) - + return fig def plot_unit_losses(self, cases_before, cases_after, metric="agreement", figsize=None): @@ -398,19 +398,17 @@ def plot_unit_losses(self, cases_before, cases_after, metric="agreement", figsiz ax.set_ylabel("depth (um)") ax.set_ylabel("snr") if count > 0: - ax.set_ylabel('') + ax.set_ylabel("") ax.set_yticks([], []) im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm") im.set_clim(-1, 1) - #fig.colorbar(im, ax=ax) - #ax.set_title(k) - + # fig.colorbar(im, ax=ax) + # ax.set_title(k) fig.subplots_adjust(right=0.85) cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) cbar = fig.colorbar(im, cax=cbar_ax, label=metric) - #cbar.set_clim(-1, 1) - + # cbar.set_clim(-1, 1) return fig diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py index e53ec232fb..ab1523d13a 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_matching.py @@ -15,6 +15,7 @@ from spikeinterface.core.recording_tools import get_noise_levels from spikeinterface.core.sparsity import compute_sparsity + class MatchingBenchmark(Benchmark): def __init__(self, recording, gt_sorting, params): @@ -241,19 +242,19 @@ def plot_unit_losses(self, before, after, metric=["precision"], figsize=None): y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values - #if count < 2: - #ax.set_xticks([], []) - #elif count == 2: + # if count < 2: + # ax.set_xticks([], []) + # elif count == 2: ax.set_xlabel("depth (um)") im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm") fig.colorbar(im, ax=ax, label=k) im.set_clim(-1, 1) ax.set_title(k) ax.set_ylabel("snr") - - #fig.subplots_adjust(right=0.85) - #cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) - #cbar = fig.colorbar(im, cax=cbar_ax, label=metric) + + # fig.subplots_adjust(right=0.85) + # cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75]) + # cbar = fig.colorbar(im, cax=cbar_ax, label=metric) # if count == 2: # ax.legend() diff --git a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py index b2825f1ab6..7d862343d2 100644 --- a/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/benchmark/benchmark_peak_detection.py @@ -195,7 +195,7 @@ def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_thres ymin, ymax = ax.get_ylim() abs_threshold = -detect_threshold * noise_levels ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--") - + return fig def plot_deltas_per_cells(self, case_keys=None, figsize=(15, 5)): diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index cb8e0d301d..928979e447 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -140,10 +140,13 @@ def remove_empty_templates(templates): is_scaled=templates.is_scaled, ) + def sigmoid(x, x0, k, b): - return (1 / (1+np.exp(-k*(x-x0)))) + b + return (1 / (1 + np.exp(-k * (x - x0)))) + b + def fit_sigmoid(xdata, ydata, p0=None): from scipy.optimize import curve_fit + popt, pcov = curve_fit(sigmoid, xdata, ydata, p0) - return popt \ No newline at end of file + return popt diff --git a/src/spikeinterface/widgets/gtstudy.py b/src/spikeinterface/widgets/gtstudy.py index eccc6808b3..19d6c0105d 100644 --- a/src/spikeinterface/widgets/gtstudy.py +++ b/src/spikeinterface/widgets/gtstudy.py @@ -50,7 +50,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): label = dp.study.cases[key]["label"] rt = dp.run_times.loc[key] self.ax.bar(i, rt, width=0.8, label=label, facecolor=self.colors[key]) - self.ax.set_ylabel('run time (s)') + self.ax.set_ylabel("run time (s)") self.ax.legend()