Skip to content

Commit

Permalink
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
Browse files Browse the repository at this point in the history
…tdc_2
  • Loading branch information
samuelgarcia committed Oct 6, 2023
2 parents f5a42e7 + 4cc4777 commit 3b681bf
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
5 changes: 4 additions & 1 deletion src/spikeinterface/comparison/groundtruthstudy.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True
sorting_exists = sorting_folder.exists()

sorter_folder = self.folder / "sorters" / self.key_to_str(key)
sorter_folder_exists = sorting_folder.exists()
sorter_folder_exists = sorter_folder.exists()

if keep:
if sorting_exists:
Expand All @@ -185,6 +185,9 @@ def run_sorters(self, case_keys=None, engine="loop", engine_kwargs={}, keep=True
if log_file.exists():
log_file.unlink()

if sorter_folder_exists:
shutil.rmtree(sorter_folder)

params = self.cases[key]["run_sorter_params"].copy()
# this ensure that sorter_name is given
recording, _ = self.datasets[self.cases[key]["dataset"]]
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/sorters/internal/spyking_circus2.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

## We get the labels for our peaks
mask = peak_labels > -1
sorting = NumpySorting.from_times_labels(selected_peaks["sample_index"][mask], peak_labels[mask], sampling_rate)
sorting = NumpySorting.from_times_labels(
selected_peaks["sample_index"][mask], peak_labels[mask].astype(int), sampling_rate
)
clustering_folder = sorter_output_folder / "clustering"
if clustering_folder.exists():
shutil.rmtree(clustering_folder)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def plot_statistics(self, metric="cosine", annotations=True, detect_threshold=5)
template_real = template_real.reshape(template_real.size, 1).T

if metric == "cosine":
dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real, metric).flatten().tolist()
dist = sklearn.metrics.pairwise.cosine_similarity(template, template_real).flatten().tolist()
else:
dist = sklearn.metrics.pairwise_distances(template, template_real, metric).flatten().tolist()
res += dist
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/widgets/widget_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
plot_study_run_times = StudyRunTimesWidget
plot_study_unit_counts = StudyUnitCountsWidget
plot_study_performances = StudyPerformances
plot_stufy_performances_vs_metrics = StudyPerformancesVsMetrics
plot_study_performances_vs_metrics = StudyPerformancesVsMetrics


def plot_timeseries(*args, **kwargs):
Expand Down

0 comments on commit 3b681bf

Please sign in to comment.