Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 3, 2024
1 parent 45fc116 commit b90fda9
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 26 deletions.
7 changes: 4 additions & 3 deletions src/spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,12 +176,13 @@ def remove_sorting(self, key):
for f in (log_file, comparison_file):
if f.exists():
f.unlink()

def set_colors(self, colors=None, map_name="tab20"):
if colors is None:
case_keys = list(self.cases.keys())
self.colors = get_some_colors(case_keys, map_name=map_name,
color_engine = "matplotlib", shuffle=False, margin=0)
self.colors = get_some_colors(
case_keys, map_name=map_name, color_engine="matplotlib", shuffle=False, margin=0
)
else:
self.colors = colors

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):
fig.colorbar(im, ax=axes[0, count])
label = self.cases[key]["label"]
axes[0, count].set_title(label)

return fig

def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5), axes=None):
Expand Down Expand Up @@ -298,7 +298,7 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5
label = self.cases[key]["label"]
axes[count].set_title(label)
axes[count].legend()

if fig is not None:
return fig

Expand Down Expand Up @@ -355,7 +355,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
elif metric == "agreement":
for found, real in zip(matched_ids2[mask], unit_ids1[mask]):
to_plot += [scores.at[real, found]]
elif metric in ['recall', 'precision', 'accuracy']:
elif metric in ["recall", "precision", "accuracy"]:
to_plot = result["gt_comparison"].get_performance()[metric].values
depth_matched = depth
snr_matched = metrics["snr"]
Expand All @@ -368,14 +368,14 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
label = self.cases[key]["label"]
axes[0, count].set_title(label)
if count > 0:
axes[0, count].set_ylabel('')
axes[0, count].set_ylabel("")
axes[0, count].set_yticks([], [])
# axs[0, count].legend()

fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75])
fig.colorbar(im, cax=cbar_ax, label=metric)

return fig

def plot_unit_losses(self, cases_before, cases_after, metric="agreement", figsize=None):
Expand All @@ -398,19 +398,17 @@ def plot_unit_losses(self, cases_before, cases_after, metric="agreement", figsiz
ax.set_ylabel("depth (um)")
ax.set_ylabel("snr")
if count > 0:
ax.set_ylabel('')
ax.set_ylabel("")
ax.set_yticks([], [])
im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm")
im.set_clim(-1, 1)
#fig.colorbar(im, ax=ax)
#ax.set_title(k)

# fig.colorbar(im, ax=ax)
# ax.set_title(k)

fig.subplots_adjust(right=0.85)
cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75])
cbar = fig.colorbar(im, cax=cbar_ax, label=metric)
#cbar.set_clim(-1, 1)

# cbar.set_clim(-1, 1)

return fig

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from spikeinterface.core.recording_tools import get_noise_levels
from spikeinterface.core.sparsity import compute_sparsity


class MatchingBenchmark(Benchmark):

def __init__(self, recording, gt_sorting, params):
Expand Down Expand Up @@ -241,19 +242,19 @@ def plot_unit_losses(self, before, after, metric=["precision"], figsize=None):

y_before = self.get_result(before)["gt_comparison"].get_performance()[k].values
y_after = self.get_result(after)["gt_comparison"].get_performance()[k].values
#if count < 2:
#ax.set_xticks([], [])
#elif count == 2:
# if count < 2:
# ax.set_xticks([], [])
# elif count == 2:
ax.set_xlabel("depth (um)")
im = ax.scatter(positions[:, 1], x, c=(y_after - y_before), cmap="coolwarm")
fig.colorbar(im, ax=ax, label=k)
im.set_clim(-1, 1)
ax.set_title(k)
ax.set_ylabel("snr")
#fig.subplots_adjust(right=0.85)
#cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75])
#cbar = fig.colorbar(im, cax=cbar_ax, label=metric)

# fig.subplots_adjust(right=0.85)
# cbar_ax = fig.add_axes([0.9, 0.1, 0.025, 0.75])
# cbar = fig.colorbar(im, cax=cbar_ax, label=metric)

# if count == 2:
# ax.legend()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def plot_detected_amplitudes(self, case_keys=None, figsize=(15, 5), detect_thres
ymin, ymax = ax.get_ylim()
abs_threshold = -detect_threshold * noise_levels
ax.plot([abs_threshold, abs_threshold], [ymin, ymax], "k--")

return fig

def plot_deltas_per_cells(self, case_keys=None, figsize=(15, 5)):
Expand Down
7 changes: 5 additions & 2 deletions src/spikeinterface/sortingcomponents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,13 @@ def remove_empty_templates(templates):
is_scaled=templates.is_scaled,
)


def sigmoid(x, x0, k, b):
return (1 / (1+np.exp(-k*(x-x0)))) + b
return (1 / (1 + np.exp(-k * (x - x0)))) + b


def fit_sigmoid(xdata, ydata, p0=None):
from scipy.optimize import curve_fit

popt, pcov = curve_fit(sigmoid, xdata, ydata, p0)
return popt
return popt
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/gtstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
label = dp.study.cases[key]["label"]
rt = dp.run_times.loc[key]
self.ax.bar(i, rt, width=0.8, label=label, facecolor=self.colors[key])
self.ax.set_ylabel('run time (s)')
self.ax.set_ylabel("run time (s)")
self.ax.legend()


Expand Down

0 comments on commit b90fda9

Please sign in to comment.