Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tools for Generation of Hybrid recordings #2436

Merged
merged 147 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
147 commits
Select commit Hold shift + click to select a range
26ac06f
WIP
yger Jan 22, 2024
3f2bcfa
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Jan 22, 2024
4216503
WIP
yger Jan 23, 2024
6275aa5
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Jan 23, 2024
af485ea
Docstrings and cosmetics
yger Jan 23, 2024
ddb1d51
WIP
yger Jan 23, 2024
08ffeef
WIP
yger Jan 23, 2024
0db1597
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
a7b6de2
Merge branch 'estimate_sparsity' of github.com:samuelgarcia/spikeinte…
yger Jan 24, 2024
b89f6a3
WIP
yger Jan 25, 2024
80e6b8d
WIP
yger Jan 25, 2024
84b60ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
f5f3435
WIP
yger Jan 25, 2024
1c8bb72
Polish
yger Jan 25, 2024
a5b12b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
1c7cf33
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Jan 25, 2024
fe6247c
Merge branch 'main' into hybrid_raw_clustering
yger Jan 30, 2024
ae285a4
WIP
yger Jan 30, 2024
495b863
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 1, 2024
05621bb
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 2, 2024
7a96d85
WIP
yger Feb 6, 2024
5f0235f
Merge branch 'main' of github.com:yger/spikeinterface into hybrid_raw…
yger Feb 6, 2024
e6595ee
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 6, 2024
e0251b0
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Feb 6, 2024
69f0674
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 16, 2024
e0037b3
WIP
yger Mar 12, 2024
2b8656f
WIP
yger Mar 12, 2024
b2365d4
WIP
yger Mar 12, 2024
6858a2f
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Mar 27, 2024
1d5b641
WIP
yger Mar 28, 2024
a139762
WIP
yger Mar 28, 2024
6a1673d
WIP
yger Mar 28, 2024
7742882
WIP
yger Mar 28, 2024
36ec32d
WIP
yger Mar 28, 2024
3c59afb
Merge branch 'main' into hybrid_raw_clustering
yger Mar 28, 2024
f6c4c4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
b2e87a0
Docs
yger Mar 28, 2024
5046e53
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 28, 2024
275f87a
Adding tests
yger Mar 28, 2024
4e97eeb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
45653f6
One more test
yger Mar 28, 2024
e05f1d0
WIP
yger Mar 28, 2024
d1ee525
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
c12c746
Adding imports
yger Mar 29, 2024
c62f952
Merge branch 'main' into hybrid_raw_clustering
yger Mar 29, 2024
03a41d2
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
a850950
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
4bcbe5c
Moving functions in localization_tools
yger Mar 29, 2024
d974fe8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
3eedeb4
Imports
yger Mar 29, 2024
37eb54d
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
5f95fa0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
23be272
Hybrid recordings also with given templates
yger Mar 29, 2024
c40c9bd
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
8d23525
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Mar 29, 2024
03fd84f
Extension of relocalization to real templates
yger Mar 30, 2024
c9b9215
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2024
0d2d7f9
Unit_locations could be given
yger Apr 1, 2024
f1b37f4
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 1, 2024
761cef5
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
04c66c6
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
b39bea9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2024
394ecec
Update src/spikeinterface/generation/hybrid_tools.py
yger Apr 2, 2024
396a0ca
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
bf9cf53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2024
d8e67de
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Apr 2, 2024
a33049a
WIP
yger Apr 4, 2024
4c81d1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
ebee96a
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Apr 4, 2024
8127894
Merge branch 'main' into hybrid_raw_clustering
yger Apr 9, 2024
057ad10
Merge branch 'main' into hybrid_raw_clustering
yger Apr 9, 2024
eaba48f
WIP
yger Apr 9, 2024
01ec8ab
Imports
yger Apr 9, 2024
15a9491
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2024
b081091
Imports
yger Apr 9, 2024
88103d8
Merge branch 'main' into hybrid_raw_clustering
yger Apr 10, 2024
095aa06
WIP
yger Apr 12, 2024
82b5dde
WIP
yger Apr 15, 2024
6ba6d96
WIP
yger Apr 15, 2024
7bc2c6f
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 17, 2024
ef46881
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Apr 17, 2024
7cfc227
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Apr 23, 2024
33a426a
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 1, 2024
40fd9ff
Add amplitude scaling std
alejoe91 May 1, 2024
c584a81
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 1, 2024
76ab76d
Bunch of fixes and enable the use of external Templates objects
alejoe91 May 1, 2024
1043577
Fix scaling for drifting templates
alejoe91 May 1, 2024
836c2d4
fix tests
alejoe91 May 2, 2024
007acf7
Extend template database functionality
alejoe91 May 6, 2024
74bb13c
Final touches
alejoe91 May 6, 2024
a0179a1
pandas local import
alejoe91 May 6, 2024
e0b1ee6
Fix typing and SC2 matching
alejoe91 May 6, 2024
1213137
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 6, 2024
38a887e
Fix tests
alejoe91 May 6, 2024
6224d2c
Add hybrid tools and extend plot_unit_templates to Templates
alejoe91 May 10, 2024
1795c9f
Merge branch 'main' into hybrid_raw_clustering
alejoe91 May 10, 2024
5dccdab
Remove unused import
alejoe91 May 10, 2024
289f337
Update src/spikeinterface/widgets/unit_waveforms.py
alejoe91 May 10, 2024
e25bac3
Fix get_unit_colors tests and rename filter_templates to select_templ…
alejoe91 May 13, 2024
ae32e26
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
alejoe91 May 14, 2024
b3816bd
Fix test zarr path in tests
alejoe91 May 14, 2024
f034b39
Fix scaling and add temlate manipulation tests
alejoe91 May 14, 2024
e6d5cb3
Add widen/narrow button and scale bar to plot_unitwaveforms/templates
alejoe91 May 14, 2024
2835350
cleanup ipywidgets utils
alejoe91 May 14, 2024
ff8b29a
Ramon's suggestions
alejoe91 May 15, 2024
91ebec8
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 15, 2024
57af6ac
Revert changes to geT_unit_colors
alejoe91 May 15, 2024
8317ba4
fix tests
alejoe91 May 15, 2024
290701f
Merge branch 'main' into hybrid_raw_clustering
yger May 24, 2024
4bc4f6a
Merge branch 'main' into hybrid_raw_clustering
yger May 27, 2024
3e29d8d
Merge branch 'main' into hybrid_raw_clustering
yger May 31, 2024
bb1b489
Fixing tests
yger May 31, 2024
ec58a8e
Merge branch 'main' into hybrid_raw_clustering
yger Jun 3, 2024
f5e81a0
Fixing tests
yger Jun 3, 2024
8fe4efb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
81d42b8
Fix conflicts
alejoe91 Jun 3, 2024
3047e30
Fixing colons
yger Jun 6, 2024
ffd93c7
Merge branch 'main' into hybrid_raw_clustering
h-mayorquin Jun 6, 2024
070cb87
Fix conflicts
alejoe91 Jun 19, 2024
03fe4d9
from_static -> from_static_templates
alejoe91 Jun 19, 2024
d023420
Remove duplicate select_units/channels in Templates
alejoe91 Jun 19, 2024
93eb5e3
Changes form code review
alejoe91 Jun 19, 2024
c2df514
Fix imports and extend docstring
alejoe91 Jun 19, 2024
534202e
Move compute_* localization functions to localization_tools
alejoe91 Jun 19, 2024
c86cad4
Update save_motion_info and add tests
alejoe91 Jun 19, 2024
1ecda4a
Use save_motion_info in correct_motion
alejoe91 Jun 19, 2024
6f7d082
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 24, 2024
28321d9
User-specified precomputed templates in DriftingTemplates
cwindolf Jun 3, 2024
643ea6a
dataclasses are weird!
cwindolf Jun 3, 2024
866da1e
Incorporate DriftingTemplates.from_precomputed_templates constructor …
alejoe91 Jun 24, 2024
5c3e73b
Start hibrid benchmark docs and add Neuropixel-384 toy probe
alejoe91 Jun 24, 2024
c45f283
wip example
alejoe91 Jun 26, 2024
bc83485
continue how to an lint
alejoe91 Jun 26, 2024
2dd5f8b
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 26, 2024
3971512
Make Motion JSON serializable and add hybrit How to
alejoe91 Jun 27, 2024
f61a3c2
Merge branch 'main' into hybrid_raw_clustering
alejoe91 Jun 27, 2024
3e319dc
Merge branch 'main' into hybrid_raw_clustering
alejoe91 Jun 27, 2024
e6da491
Remove custom pickle
alejoe91 Jun 27, 2024
7f8f194
Remove duplicated line
alejoe91 Jun 27, 2024
d9c36a1
Use official prot_drift_map
alejoe91 Jun 27, 2024
8ffdd98
Add s3fs to test dependencies
alejoe91 Jun 27, 2024
406ebcf
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 27, 2024
5186f74
Sam+Charlie's suggestions: estimate one motion vector per unit (plus …
alejoe91 Jun 28, 2024
8b16006
Update how to and fix displacement_vectors
alejoe91 Jun 28, 2024
b2d3d3b
Update jupytext readme
alejoe91 Jun 28, 2024
efd5ecf
Update how to benchmark with real data
alejoe91 Jun 28, 2024
0015f5f
Remove old displacement_unit_factor
alejoe91 Jun 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,552 changes: 2,552 additions & 0 deletions doc/how_to/benchmark_with_hybrid_recordings.rst

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions doc/how_to/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
combine_recordings
process_by_channel_group
load_your_data_into_sorting
benchmark_with_hybrid_recordings
12 changes: 8 additions & 4 deletions examples/how_to/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,21 @@ with `nbconvert`. Here are the steps (in this example for the `get_started`):

```
>>> jupytext --to notebook get_started.py
>>> jupytext --set-formats ipynb,py get_started.ipynb
```

2. Run the notebook

3. Sync the run notebook to the .py file:

3. Convert the notebook to .rst
```
>>> jupytext --sync get_started.ipynb
```

4. Convert the notebook to .rst

```
>>> jupyter nbconvert get_started.ipynb --to rst
>>> jupyter nbconvert analyse_neuropixels.ipynb --to rst
```


4. Move the .rst and associated folder (e.g. `get_started.rst` and `get_started_files` folder) to the `doc/how_to`.
5. Move the .rst and associated folder (e.g. `get_started.rst` and `get_started_files` folder) to the `doc/how_to`.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# name: python3
# ---

# # Analyse Neuropixels datasets
# # Analyze Neuropixels datasets
#
# This example shows how to perform Neuropixels-specific analysis, including custom pre- and post-processing.

Expand Down
293 changes: 293 additions & 0 deletions examples/how_to/benchmark_with_hybrid_recordings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
# ---
# jupyter:
# jupytext:
# cell_metadata_filter: -all
# formats: ipynb,py
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# # Benchmark spike sorting with hybrid recordings
#
# This example shows how to use the SpikeInterface hybrid recordings framework to benchmark spike sorting results.
#
# Hybrid recordings are built from existing recordings by injecting units with known spiking activity.
# The template (aka average waveforms) of the injected units can be from previous spike sorted data.
# In this example, we will be using an open database of templates that we have constructed from the International Brain Laboratory - Brain Wide Map (available on [DANDI](https://dandiarchive.org/dandiset/000409?search=IBL&page=2&sortOption=0&sortDir=-1&showDrafts=true&showEmpty=false&pos=9)).
#
# Importantly, recordings from long-shank probes, such as Neuropixels, usually experience drifts. Such drifts have to be taken into account in order to smoothly inject spikes into the recording.

# +
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.comparison as sc
import spikeinterface.generation as sgen
import spikeinterface.widgets as sw

from spikeinterface.sortingcomponents.motion_estimation import estimate_motion

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
# -

# %matplotlib inline

si.set_global_job_kwargs(n_jobs=16)

# For this notebook, we will use a drifting recording similar to the one acquired by Nick Steinmetz and available [here](https://doi.org/10.6084/m9.figshare.14024495.v1), where an triangular motion was imposed to the recording by moving the probe up and down with a micro-manipulator.

workdir = Path("/ssd980/working/hybrid/steinmetz_imposed_motion")
workdir.mkdir(exist_ok=True)

recording_np1_imposed = se.read_spikeglx("/hdd1/data/spikeglx/nick-steinmetz/dataset1/p1_g0_t0/")
recording_preproc = spre.highpass_filter(recording_np1_imposed)
recording_preproc = spre.common_reference(recording_preproc)

# To visualize the drift, we can estimate the motion and plot it:

# to correct for drift, we need a float dtype
recording_preproc = spre.astype(recording_preproc, "float")
_, motion_info = spre.correct_motion(
recording_preproc, preset="nonrigid_fast_and_accurate", n_jobs=4, progress_bar=True, output_motion_info=True
)

ax = sw.plot_drift_raster_map(
peaks=motion_info["peaks"],
peak_locations=motion_info["peak_locations"],
recording=recording_preproc,
cmap="Greys_r",
scatter_decimate=10,
depth_lim=(-10, 3000)
)

# ## Retrieve templates from database

# +
templates_info = sgen.fetch_templates_database_info()

print(f"Number of templates in database: {len(templates_info)}")
print(f"Template database columns: {templates_info.columns}")
# -

available_brain_areas = np.unique(templates_info.brain_area)
print(f"Available brain areas: {available_brain_areas}")

# Let's perform a query: templates from visual brain regions and at the "top" of the probe

target_area = ["VISa5", "VISa6a", "VISp5", "VISp6a", "VISrl6b"]
minimum_depth = 1500
templates_selected_info = templates_info.query(f"brain_area in {target_area} and depth_along_probe > {minimum_depth}")
len(templates_selected_info)

# We can now retrieve the selected templates as a `Templates` object:

templates_selected = sgen.query_templates_from_database(templates_selected_info, verbose=True)
print(templates_selected)

# While we selected templates from a target aread and at certain depths, we can see that the template amplitudes are quite large. This will make spike sorting easy... we can further manipulate the `Templates` by rescaling, relocating, or further selections with the `sgen.scale_template_to_range`, `sgen.relocate_templates`, and `sgen.select_templates` functions.
#
# In our case, let's rescale the amplitudes between 50 and 150 $\mu$V and relocate them towards the bottom half of the probe, where the activity looks interesting!

# +
min_amplitude = 50
max_amplitude = 150
templates_scaled = sgen.scale_template_to_range(
templates=templates_selected,
min_amplitude=min_amplitude,
max_amplitude=max_amplitude
)

min_displacement = 1000
max_displacement = 3000
templates_relocated = sgen.relocate_templates(
templates=templates_scaled,
min_displacement=min_displacement,
max_displacement=max_displacement
)
# -

# Let's plot the selected templates:

sparsity_plot = si.compute_sparsity(templates_relocated)
fig = plt.figure(figsize=(10, 10))
w = sw.plot_unit_templates(templates_relocated, sparsity=sparsity_plot, ncols=4, figure=fig)
w.figure.subplots_adjust(wspace=0.5, hspace=0.7)

# ## Constructing hybrid recordings
#
# We can construct now hybrid recordings with the selected templates.
#
# We will do this in two ways to show how important it is to account for drifts when injecting hybrid spikes.
#
# - For the first recording we will not pass the estimated motion (`recording_hybrid_ignore_drift`).
# - For the second recording, we will pass and account for the estimated motion (`recording_hybrid_with_drift`).

recording_hybrid_ignore_drift, sorting_hybrid = sgen.generate_hybrid_recording(
recording=recording_preproc, templates=templates_relocated, seed=2308
)
recording_hybrid_ignore_drift

# Note that the `generate_hybrid_recording` is warning us that we might want to account for drift!

# by passing the `sorting_hybrid` object, we make sure that injected spikes are the same
# this will take a bit more time because it's interpolating the templates to account for drifts
recording_hybrid_with_drift, sorting_hybrid = sgen.generate_hybrid_recording(
recording=recording_preproc,
templates=templates_relocated,
motion=motion_info["motion"],
sorting=sorting_hybrid,
seed=2308,
)
recording_hybrid_with_drift

# We can use the `SortingAnalyzer` to estimate spike locations and plot them:

# +
# construct analyzers and compute spike locations
analyzer_hybrid_ignore_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_ignore_drift)
analyzer_hybrid_ignore_drift.compute(["random_spikes", "templates"])
analyzer_hybrid_ignore_drift.compute("spike_locations", method="grid_convolution")

analyzer_hybrid_with_drift = si.create_sorting_analyzer(sorting_hybrid, recording_hybrid_with_drift)
analyzer_hybrid_with_drift.compute(["random_spikes", "templates"])
analyzer_hybrid_with_drift.compute("spike_locations", method="grid_convolution")
# -

# Let's plot the added hybrid spikes using the drift maps:

fig, axs = plt.subplots(ncols=2, figsize=(10, 7), sharex=True, sharey=True)
_ = sw.plot_drift_raster_map(
peaks=motion_info["peaks"],
peak_locations=motion_info["peak_locations"],
recording=recording_preproc,
cmap="Greys_r",
scatter_decimate=10,
ax=axs[0],
)
_ = sw.plot_drift_raster_map(
sorting_analyzer=analyzer_hybrid_ignore_drift,
color_amplitude=False,
color="r",
scatter_decimate=10,
ax=axs[0]
)
_ = sw.plot_drift_raster_map(
peaks=motion_info["peaks"],
peak_locations=motion_info["peak_locations"],
recording=recording_preproc,
cmap="Greys_r",
scatter_decimate=10,
ax=axs[1],
)
_ = sw.plot_drift_raster_map(
sorting_analyzer=analyzer_hybrid_with_drift,
color_amplitude=False,
color="b",
scatter_decimate=10,
ax=axs[1]
)
axs[0].set_title("Hybrid spikes\nIgnoring drift")
axs[1].set_title("Hybrid spikes\nAccounting for drift")
axs[0].set_xlim(1000, 1500)
axs[0].set_ylim(500, 2500)

# We can see that clearly following drift is essential in order to properly blend the hybrid spikes into the recording!

# ## Ground-truth study
#
# In this section we will use the hybrid recording to benchmark a few spike sorters:
#
# - `Kilosort2.5`
# - `Kilosort3`
# - `Kilosort4`
# - `Spyking-CIRCUS 2`

# to speed up computations, let's first dump the recording to binary
recording_hybrid_bin = recording_hybrid_with_drift.save(
folder=workdir / "hybrid_bin",
overwrite=True
)

# +
datasets = {
"hybrid": (recording_hybrid_bin, sorting_hybrid),
}

cases = {
("kilosort2.5", "hybrid"): {
"label": "KS2.5",
"dataset": "hybrid",
"run_sorter_params": {
"sorter_name": "kilosort2_5",
},
},
("kilosort3", "hybrid"): {
"label": "KS3",
"dataset": "hybrid",
"run_sorter_params": {
"sorter_name": "kilosort3",
},
},
("kilosort4", "hybrid"): {
"label": "KS4",
"dataset": "hybrid",
"run_sorter_params": {"sorter_name": "kilosort4", "nblocks": 5},
},
("sc2", "hybrid"): {
"label": "spykingcircus2",
"dataset": "hybrid",
"run_sorter_params": {
"sorter_name": "spykingcircus2",
},
},
}

# +
study_folder = workdir / "gt_study"

gtstudy = sc.GroundTruthStudy(study_folder)

# -

# run the spike sorting jobs
gtstudy.run_sorters(verbose=False, keep=True)

# run the comparisons
gtstudy.run_comparisons(exhaustive_gt=False)

# ## Plot performances
#
# Given that we know the exactly where we injected the hybrid spikes, we can now compute and plot performance metrics: accuracy, precision, and recall.
#
# In the following plot, the x axis is the unit index, while the y axis is the performance metric. The units are sorted by performance.

w_perf = sw.plot_study_performances(gtstudy, figsize=(12, 7))
w_perf.axes[0, 0].legend(loc=4)

# From the performance plots, we can see that there is no clear "winner", but `Kilosort3` definitely performs worse than the other options.
#
# Although non of the sorters find all units perfectly, `Kilosort2.5`, `Kilosort4`, and `SpyKING CIRCUS 2` all find around 10-12 hybrid units with accuracy greater than 80%.
# `Kilosort4` has a better overall curve, being able to find almost all units with an accuracy above 50%. `Kilosort2.5` performs well when looking at precision (finding all spikes in a hybrid unit), at the cost of lower recall (finding spikes when it shouldn't).
#
#
# In this example, we showed how to:
#
# - Access and fetch templates from the SpikeInterface template database
# - Manipulate templates (scaling/relocating)
# - Construct hybrid recordings accounting for drifts
# - Use the `GroundTruthStudy` to benchmark different sorters
#
# The hybrid framework can be extended to target multiple recordings from different brain regions and species and creating recordings of increasing complexity to challenge the existing sorters!
#
# In addition, hybrid studies can also be used to fine-tune spike sorting parameters on specific datasets.
#
# **Are you ready to try it on your data?**
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ test = [
# preprocessing
"ibllib>=2.36.0", # for IBL

# streaming templates
"s3fs",

# tridesclous
"numba",
"hdbscan>=0.8.33", # Previous version had a broken wheel
Expand Down
4 changes: 4 additions & 0 deletions src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class SIJsonEncoder(json.JSONEncoder):

def default(self, obj):
from spikeinterface.core.base import BaseExtractor
from spikeinterface.sortingcomponents.motion_utils import Motion

# Over-write behaviors for datetime object
if isinstance(obj, datetime.datetime):
Expand All @@ -97,6 +98,9 @@ def default(self, obj):
if isinstance(obj, BaseExtractor):
return obj.to_dict()

if isinstance(obj, Motion):
return obj.to_dict()

# The base-class handles the assertion
return super().default(obj)

Expand Down
3 changes: 1 addition & 2 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings
import numpy as np
from typing import Union, Optional, List, Literal
import warnings
from math import ceil

from .basesorting import SpikeVectorSortingSegment
Expand Down Expand Up @@ -1858,7 +1857,7 @@ def get_traces(
wf = template[start_template:end_template]
if self.amplitude_vector is not None:
wf = wf * self.amplitude_vector[i]
traces[start_traces:end_traces] += wf
traces[start_traces:end_traces] += wf.astype(traces.dtype, copy=False)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

return traces.astype(self.dtype, copy=False)

Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,6 @@ def _init_peak_pipeline(recording, nodes):
worker_ctx["recording"] = recording
worker_ctx["nodes"] = nodes
worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)

return worker_ctx


Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,7 +970,9 @@ def compute_one_extension(self, extension_name, save=True, verbose=False, **kwar
extension_class = get_extension_class(extension_name)

for child in _get_children_dependencies(extension_name):
self.delete_extension(child)
if self.has_extension(child):
print(f"Deleting {child}")
self.delete_extension(child)

if extension_class.need_job_kwargs:
params, job_kwargs = split_job_kwargs(kwargs)
Expand Down
Loading
Loading