Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tools for Generation of Hybrid recordings #2436

Merged
merged 147 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from 133 commits
Commits
Show all changes
147 commits
Select commit Hold shift + click to select a range
26ac06f
WIP
yger Jan 22, 2024
3f2bcfa
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Jan 22, 2024
4216503
WIP
yger Jan 23, 2024
6275aa5
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Jan 23, 2024
af485ea
Docstrings and cosmetics
yger Jan 23, 2024
ddb1d51
WIP
yger Jan 23, 2024
08ffeef
WIP
yger Jan 23, 2024
0db1597
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
a7b6de2
Merge branch 'estimate_sparsity' of github.com:samuelgarcia/spikeinte…
yger Jan 24, 2024
b89f6a3
WIP
yger Jan 25, 2024
80e6b8d
WIP
yger Jan 25, 2024
84b60ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
f5f3435
WIP
yger Jan 25, 2024
1c8bb72
Polish
yger Jan 25, 2024
a5b12b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
1c7cf33
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Jan 25, 2024
fe6247c
Merge branch 'main' into hybrid_raw_clustering
yger Jan 30, 2024
ae285a4
WIP
yger Jan 30, 2024
495b863
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 1, 2024
05621bb
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 2, 2024
7a96d85
WIP
yger Feb 6, 2024
5f0235f
Merge branch 'main' of github.com:yger/spikeinterface into hybrid_raw…
yger Feb 6, 2024
e6595ee
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 6, 2024
e0251b0
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Feb 6, 2024
69f0674
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 16, 2024
e0037b3
WIP
yger Mar 12, 2024
2b8656f
WIP
yger Mar 12, 2024
b2365d4
WIP
yger Mar 12, 2024
6858a2f
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Mar 27, 2024
1d5b641
WIP
yger Mar 28, 2024
a139762
WIP
yger Mar 28, 2024
6a1673d
WIP
yger Mar 28, 2024
7742882
WIP
yger Mar 28, 2024
36ec32d
WIP
yger Mar 28, 2024
3c59afb
Merge branch 'main' into hybrid_raw_clustering
yger Mar 28, 2024
f6c4c4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
b2e87a0
Docs
yger Mar 28, 2024
5046e53
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 28, 2024
275f87a
Adding tests
yger Mar 28, 2024
4e97eeb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
45653f6
One more test
yger Mar 28, 2024
e05f1d0
WIP
yger Mar 28, 2024
d1ee525
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
c12c746
Adding imports
yger Mar 29, 2024
c62f952
Merge branch 'main' into hybrid_raw_clustering
yger Mar 29, 2024
03a41d2
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
a850950
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
4bcbe5c
Moving functions in localization_tools
yger Mar 29, 2024
d974fe8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
3eedeb4
Imports
yger Mar 29, 2024
37eb54d
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
5f95fa0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
23be272
Hybrid recordings also with given templates
yger Mar 29, 2024
c40c9bd
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
8d23525
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Mar 29, 2024
03fd84f
Extension of relocalization to real templates
yger Mar 30, 2024
c9b9215
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2024
0d2d7f9
Unit_locations could be given
yger Apr 1, 2024
f1b37f4
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 1, 2024
761cef5
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
04c66c6
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
b39bea9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2024
394ecec
Update src/spikeinterface/generation/hybrid_tools.py
yger Apr 2, 2024
396a0ca
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
bf9cf53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2024
d8e67de
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Apr 2, 2024
a33049a
WIP
yger Apr 4, 2024
4c81d1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
ebee96a
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Apr 4, 2024
8127894
Merge branch 'main' into hybrid_raw_clustering
yger Apr 9, 2024
057ad10
Merge branch 'main' into hybrid_raw_clustering
yger Apr 9, 2024
eaba48f
WIP
yger Apr 9, 2024
01ec8ab
Imports
yger Apr 9, 2024
15a9491
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2024
b081091
Imports
yger Apr 9, 2024
88103d8
Merge branch 'main' into hybrid_raw_clustering
yger Apr 10, 2024
095aa06
WIP
yger Apr 12, 2024
82b5dde
WIP
yger Apr 15, 2024
6ba6d96
WIP
yger Apr 15, 2024
7bc2c6f
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 17, 2024
ef46881
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Apr 17, 2024
7cfc227
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Apr 23, 2024
33a426a
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 1, 2024
40fd9ff
Add amplitude scaling std
alejoe91 May 1, 2024
c584a81
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 1, 2024
76ab76d
Bunch of fixes and enable the use of external Templates objects
alejoe91 May 1, 2024
1043577
Fix scaling for drifting templates
alejoe91 May 1, 2024
836c2d4
fix tests
alejoe91 May 2, 2024
007acf7
Extend template database functionality
alejoe91 May 6, 2024
74bb13c
Final touches
alejoe91 May 6, 2024
a0179a1
pandas local import
alejoe91 May 6, 2024
e0b1ee6
Fix typing and SC2 matching
alejoe91 May 6, 2024
1213137
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 6, 2024
38a887e
Fix tests
alejoe91 May 6, 2024
6224d2c
Add hybrid tools and extend plot_unit_templates to Templates
alejoe91 May 10, 2024
1795c9f
Merge branch 'main' into hybrid_raw_clustering
alejoe91 May 10, 2024
5dccdab
Remove unused import
alejoe91 May 10, 2024
289f337
Update src/spikeinterface/widgets/unit_waveforms.py
alejoe91 May 10, 2024
e25bac3
Fix get_unit_colors tests and rename filter_templates to select_templ…
alejoe91 May 13, 2024
ae32e26
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
alejoe91 May 14, 2024
b3816bd
Fix test zarr path in tests
alejoe91 May 14, 2024
f034b39
Fix scaling and add temlate manipulation tests
alejoe91 May 14, 2024
e6d5cb3
Add widen/narrow button and scale bar to plot_unitwaveforms/templates
alejoe91 May 14, 2024
2835350
cleanup ipywidgets utils
alejoe91 May 14, 2024
ff8b29a
Ramon's suggestions
alejoe91 May 15, 2024
91ebec8
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 15, 2024
57af6ac
Revert changes to geT_unit_colors
alejoe91 May 15, 2024
8317ba4
fix tests
alejoe91 May 15, 2024
290701f
Merge branch 'main' into hybrid_raw_clustering
yger May 24, 2024
4bc4f6a
Merge branch 'main' into hybrid_raw_clustering
yger May 27, 2024
3e29d8d
Merge branch 'main' into hybrid_raw_clustering
yger May 31, 2024
bb1b489
Fixing tests
yger May 31, 2024
ec58a8e
Merge branch 'main' into hybrid_raw_clustering
yger Jun 3, 2024
f5e81a0
Fixing tests
yger Jun 3, 2024
8fe4efb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
81d42b8
Fix conflicts
alejoe91 Jun 3, 2024
3047e30
Fixing colons
yger Jun 6, 2024
ffd93c7
Merge branch 'main' into hybrid_raw_clustering
h-mayorquin Jun 6, 2024
070cb87
Fix conflicts
alejoe91 Jun 19, 2024
03fe4d9
from_static -> from_static_templates
alejoe91 Jun 19, 2024
d023420
Remove duplicate select_units/channels in Templates
alejoe91 Jun 19, 2024
93eb5e3
Changes form code review
alejoe91 Jun 19, 2024
c2df514
Fix imports and extend docstring
alejoe91 Jun 19, 2024
534202e
Move compute_* localization functions to localization_tools
alejoe91 Jun 19, 2024
c86cad4
Update save_motion_info and add tests
alejoe91 Jun 19, 2024
1ecda4a
Use save_motion_info in correct_motion
alejoe91 Jun 19, 2024
6f7d082
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 24, 2024
28321d9
User-specified precomputed templates in DriftingTemplates
cwindolf Jun 3, 2024
643ea6a
dataclasses are weird!
cwindolf Jun 3, 2024
866da1e
Incorporate DriftingTemplates.from_precomputed_templates constructor …
alejoe91 Jun 24, 2024
5c3e73b
Start hibrid benchmark docs and add Neuropixel-384 toy probe
alejoe91 Jun 24, 2024
c45f283
wip example
alejoe91 Jun 26, 2024
bc83485
continue how to an lint
alejoe91 Jun 26, 2024
2dd5f8b
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 26, 2024
3971512
Make Motion JSON serializable and add hybrit How to
alejoe91 Jun 27, 2024
f61a3c2
Merge branch 'main' into hybrid_raw_clustering
alejoe91 Jun 27, 2024
3e319dc
Merge branch 'main' into hybrid_raw_clustering
alejoe91 Jun 27, 2024
e6da491
Remove custom pickle
alejoe91 Jun 27, 2024
7f8f194
Remove duplicated line
alejoe91 Jun 27, 2024
d9c36a1
Use official prot_drift_map
alejoe91 Jun 27, 2024
8ffdd98
Add s3fs to test dependencies
alejoe91 Jun 27, 2024
406ebcf
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 27, 2024
5186f74
Sam+Charlie's suggestions: estimate one motion vector per unit (plus …
alejoe91 Jun 28, 2024
8b16006
Update how to and fix displacement_vectors
alejoe91 Jun 28, 2024
b2d3d3b
Update jupytext readme
alejoe91 Jun 28, 2024
efd5ecf
Update how to benchmark with real data
alejoe91 Jun 28, 2024
0015f5f
Remove old displacement_unit_factor
alejoe91 Jun 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# name: python3
# ---

# # Analyse Neuropixels datasets
# # Analyze Neuropixels datasets
#
# This example shows how to perform Neuropixels-specific analysis, including custom pre- and post-processing.

Expand Down
318 changes: 318 additions & 0 deletions examples/how_to/benchmark_with_hybrid_recordings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
# # Benchmark spike sorting with hybrid recordings
#
# This example shows how to use the SpikeInterface hybrid recordings framework to benchmark spike sorting results.
#
# Hybrid recordings are built from existing recordings by injecting units with known spiking activity.
# The template (aka average waveforms) of the injected units can be from previous spike sorted data.
# In this example, we will be using an open database of templates that we have constructed from the International Brain Laboratory - Brain Wide Map (available on [DANDI](https://dandiarchive.org/dandiset/000409?search=IBL&page=2&sortOption=0&sortDir=-1&showDrafts=true&showEmpty=false&pos=9)).
#
# Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. Such drifts have to be taken into account in order to smoothly inject spikes into the recording.

# +
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.comparison as sc
import spikeinterface.generation as sgen
import spikeinterface.widgets as sw

from spikeinterface.sortingcomponents.motion_estimation import estimate_motion

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# -

# %matplotlib inline

si.set_global_job_kwargs(n_jobs=4)

# To make this notebook self-contained, we will simulate a drifting recording similar to the one acquired by Nick Steinmetz and available [here](https://doi.org/10.6084/m9.figshare.14024495.v1), where an triangular motion was imposed to the recording by moving the probe up and down with a micro-manipulator.
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

# +
# sgen.generate_displacement_vector?

# +
generate_displacement_vector_kwargs = {
"motion_list": [
{
"drift_mode": "zigzag",
"non_rigid_gradient": None,
"t_start_drift": 30.0,
"t_end_drift": 210,
"period_s": 60,
}
],
"drift_start_um": [0, 100],
"drift_stop_um": [0, -100],
}

# this generates a "static" and "drifting" recording version
static_recording, drifting_recording, sorting = sgen.generate_drifting_recording(
probe_name="Neuropixel-384", # ,'Neuropixel-384',
seed=23,
duration=240,
num_units=100,
generate_displacement_vector_kwargs=generate_displacement_vector_kwargs,
)

# we sort the channels by depth, to match the hybrid templates
drifting_recording = spre.depth_order(drifting_recording)
# -

# To visualize the drift, we can estimate the motion and plot it:

_, motion_info = spre.correct_motion(
drifting_recording, preset="nonrigid_fast_and_accurate", n_jobs=4, progress_bar=True, output_motion_info=True
)

sw.plot_motion_info(motion_info, recording=drifting_recording)


def plot_drift_map(
peaks=None,
peak_locations=None,
recording=None,
analyzer=None,
direction="y",
sampling_frequency=None,
segment_index=0,
depth_lim=None,
color_amplitude=True,
scatter_decimate=None,
color="k",
cmap="inferno",
clim=None,
alpha=1,
ax=None,
):
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

if ax is None:
fig, ax = plt.subplots()

assert peaks is not None or analyzer is not None
if peaks is not None:
assert peak_locations is not None
if recording is None:
assert sampling_frequency is not None
else:
sampling_frequency = recording.sampling_frequency
peak_amplitudes = peaks["amplitude"]
if analyzer is not None:
if analyzer.has_recording():
recording = analyzer.recording
else:
recording = None
sampling_frequency = analyzer.sampling_frequency
peaks = analyzer.sorting.to_spike_vector()
assert analyzer.has_extension("spike_locations")
peak_locations = analyzer.get_extension("spike_locations").get_data()
if analyzer.has_extension("spike_amplitudes"):
peak_amplitudes = analyzer.get_extension("spike_amplitudes").get_data()
else:
peak_amplitudes = None
times = recording.get_times(segment_index=segment_index) if recording is not None else None

if times is None:
x = peaks["sample_index"] / sampling_frequency
else:
# use real times and adjust temporal bins with t_start
x = times[peaks["sample_index"]]

y = peak_locations[direction]
if scatter_decimate is not None:
x = x[::scatter_decimate]
y = y[::scatter_decimate]
y2 = y2[::scatter_decimate]

if color_amplitude:
assert peak_amplitudes is not None, "To color by amplitudes the 'spike_amplitude' extension is needed"
amps = peak_amplitudes
amps_abs = np.abs(amps)
q_95 = np.quantile(amps_abs, 0.95)
if scatter_decimate is not None:
amps = amps[::scatter_decimate]
amps_abs = amps_abs[::scatter_decimate]
cmap = plt.colormaps[cmap]
if clim is None:
amps = amps_abs
amps /= q_95
c = cmap(amps)
else:
norm_function = Normalize(vmin=clim[0], vmax=clim[1], clip=True)
c = cmap(norm_function(amps))
color_kwargs = dict(
color=None,
c=c,
alpha=alpha,
)
else:
color_kwargs = dict(color=color, c=None, alpha=alpha)

ax.scatter(x, y, s=1, **color_kwargs)
if depth_lim is not None:
ax.set_ylim(*depth_lim)
ax.set_title("Peak depth")
ax.set_xlabel("Times [s]")
ax.set_ylabel("Depth [$\\mu$m]")
return ax


ax = plot_drift_map(
peaks=motion_info["peaks"],
peak_locations=motion_info["peak_locations"],
recording=drifting_recording,
cmap="Greys_r",
)

# +
#
# sw.plot_drift_map(...)
#
# -

# ## Retrieve templates from database

templates_info = sgen.fetch_templates_database_info()

print(len(templates_info))

templates_info.head()

available_brain_areas = np.unique(templates_info.brain_area)
print(f"Available brain areas: {available_brain_areas}")

# let's perform a query: templates from brain region VISp5 and at the "top" of the probe
target_area = ["VISa5", "VISa6a", "VISp5", "VISp6a", "VISrl6b"]
minimum_depth = 1000
templates_selected_info = templates_info.query(f"brain_area in {target_area} and depth_along_probe > {minimum_depth}")
display(templates_selected_info)

# We can now retrieve the selected templates as a `Templates` object
#

templates_selected = sgen.query_templates_from_database(templates_selected_info, verbose=True)
print(templates_selected)

# While we selected templates from a target aread and at certain depths, we can see that the template amplitudes are quite large. This will make spike sorting easy... we can further manipulate the `Templates` by rescaling, relocating, or further selections with the `sgen.scale_template_to_range`, `sgen.relocate_templates`, and `sgen.select_templates` functions.
#
# In our case, let's rescale the amplitudes between 30 and 50 $\mu$V.

templates_scaled = sgen.scale_template_to_range(templates=templates_selected, min_amplitude=30, max_amplitude=50)

# +
# sgen.relocate_templates?
# -

templates_relocated = sgen.relocate_templates(templates=templates_scaled, min_displacement=200, max_displacement=2000)

# Let's plot the selected templates:

sparsity_plot = si.compute_sparsity(templates_relocated)
w = sw.plot_unit_templates(templates_relocated, sparsity=sparsity_plot, ncols=4)
w.figure.subplots_adjust(wspace=0.5, hspace=0.7)

# ## Constructing hybrid recordings

recording_hybrid_no_drift, sorting_hybrid = sgen.generate_hybrid_recording(
recording=drifting_recording, templates=templates_relocated, seed=2308
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
)

recording_hybrid_no_drift

recording_hybrid, sorting_hybrid = sgen.generate_hybrid_recording(
recording=drifting_recording,
templates=templates_relocated,
motion=motion_info["motion"],
sorting=sorting_hybrid,
seed=2308,
)

recording_hybrid

# show spike locations on top of original rastermap
# construct hybrid analyzer for spike locations
analyzer_hybrid = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid)
analyzer_hybrid.compute(["random_spikes", "templates"])
analyzer_hybrid.compute("spike_locations", method="grid_convolution")

# construct hybrid analyzer for spike locations
analyzer_hybrid_no_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_no_drift)
analyzer_hybrid_no_drift.compute(["random_spikes", "templates"])
analyzer_hybrid_no_drift.compute("spike_locations", method="grid_convolution")

fig, axs = plt.subplots(ncols=2, figsize=(10, 7))
_ = plot_drift_map(
peaks=motion_info["peaks"],
peak_locations=motion_info["peak_locations"],
recording=drifting_recording,
cmap="Greys_r",
ax=axs[0],
)
_ = plot_drift_map(analyzer=analyzer_hybrid_no_drift, color_amplitude=False, color="r", ax=axs[0])
_ = plot_drift_map(
peaks=motion_info["peaks"],
peak_locations=motion_info["peak_locations"],
recording=drifting_recording,
cmap="Greys_r",
ax=axs[1],
)
_ = plot_drift_map(analyzer=analyzer_hybrid, color_amplitude=False, color="b", ax=axs[1])
axs[0].set_title("Ignoring drift")
axs[1].set_title("Accounting for drift")

# ## Ground-truth study
#
# In this section ...

# +
datasets = {
"hybrid": (recording_hybrid, sorting_hybrid),
}

cases = {
("kilosort2.5", "hybrid"): {
"label": "KS2.5",
"dataset": "hybrid",
"run_sorter_params": {
"sorter_name": "kilosort2_5",
},
},
("kilosort3", "hybrid"): {
"label": "KS3",
"dataset": "hybrid",
"run_sorter_params": {
"sorter_name": "kilosort3",
},
},
("kilosort4", "hybrid"): {
"label": "KS4",
"dataset": "hybrid",
"run_sorter_params": {"sorter_name": "kilosort4", "nblocks": 5},
},
("sc2", "hybrid"): {
"label": "spykingcircus2",
"dataset": "hybrid",
"run_sorter_params": {
"sorter_name": "spykingcircus2",
},
},
}
# -

study_folder = workdir / "gt_study"
if (workdir / "gt_study").is_dir():
gtstudy = sc.GroundTruthStudy(study_folder)
else:
gtstudy = sc.GroundTruthStudy.create(study_folder=study_folder, datasets=datasets, cases=cases)

gtstudy.run_sorters(verbose=True, keep=True)

gtstudy.run_comparisons(exhaustive_gt=False)

w_run_times = sw.plot_study_run_times(gtstudy)
w_perf = sw.plot_study_performances(gtstudy, figsize=(12, 7))
w_perf.axes[0, 0].legend(loc=4)
3 changes: 1 addition & 2 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings
import numpy as np
from typing import Union, Optional, List, Literal
import warnings
from math import ceil

from .basesorting import SpikeVectorSortingSegment
Expand Down Expand Up @@ -1868,7 +1867,7 @@ def get_traces(
wf = template[start_template:end_template]
if self.amplitude_vector is not None:
wf = wf * self.amplitude_vector[i]
traces[start_traces:end_traces] += wf
traces[start_traces:end_traces] += wf.astype(traces.dtype, copy=False)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

return traces.astype(self.dtype, copy=False)

Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,6 @@ def _init_peak_pipeline(recording, nodes):
worker_ctx["recording"] = recording
worker_ctx["nodes"] = nodes
worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)

return worker_ctx


Expand Down
5 changes: 4 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,10 +967,13 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar
>>> wfs = compute_waveforms(sorting_analyzer, **some_params)

"""
print(extension_name)
extension_class = get_extension_class(extension_name)

for child in _get_children_dependencies(extension_name):
self.delete_extension(child)
if self.has_extension(child):
print(f"Deleting {child}")
self.delete_extension(child)

if extension_class.need_job_kwargs:
params, job_kwargs = split_job_kwargs(kwargs)
Expand Down
10 changes: 5 additions & 5 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,9 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
the `add_templates_to_zarr_group` method.

"""
templates_array = zarr_group["templates_array"]
channel_ids = zarr_group["channel_ids"]
unit_ids = zarr_group["unit_ids"]
templates_array = zarr_group["templates_array"][:]
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
channel_ids = zarr_group["channel_ids"][:]
unit_ids = zarr_group["unit_ids"][:]
sampling_frequency = zarr_group.attrs["sampling_frequency"]
nbefore = zarr_group.attrs["nbefore"]

Expand All @@ -364,7 +364,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":

sparsity_mask = None
if "sparsity_mask" in zarr_group:
sparsity_mask = zarr_group["sparsity_mask"]
sparsity_mask = zarr_group["sparsity_mask"][:]

probe = None
if "probe" in zarr_group:
Expand Down Expand Up @@ -449,7 +449,7 @@ def __eq__(self, other):

return True

def get_channel_locations(self):
def get_channel_locations(self) -> np.ndarray:
assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set"
channel_locations = self.probe.contact_positions
return channel_locations
Loading
Loading