Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tools for Generation of Hybrid recordings #2436

Merged
merged 147 commits into from
Jun 29, 2024
Merged
Show file tree
Hide file tree
Changes from 116 commits
Commits
Show all changes
147 commits
Select commit Hold shift + click to select a range
26ac06f
WIP
yger Jan 22, 2024
3f2bcfa
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Jan 22, 2024
4216503
WIP
yger Jan 23, 2024
6275aa5
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Jan 23, 2024
af485ea
Docstrings and cosmetics
yger Jan 23, 2024
ddb1d51
WIP
yger Jan 23, 2024
08ffeef
WIP
yger Jan 23, 2024
0db1597
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2024
a7b6de2
Merge branch 'estimate_sparsity' of github.com:samuelgarcia/spikeinte…
yger Jan 24, 2024
b89f6a3
WIP
yger Jan 25, 2024
80e6b8d
WIP
yger Jan 25, 2024
84b60ac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
f5f3435
WIP
yger Jan 25, 2024
1c8bb72
Polish
yger Jan 25, 2024
a5b12b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
1c7cf33
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Jan 25, 2024
fe6247c
Merge branch 'main' into hybrid_raw_clustering
yger Jan 30, 2024
ae285a4
WIP
yger Jan 30, 2024
495b863
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 1, 2024
05621bb
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 2, 2024
7a96d85
WIP
yger Feb 6, 2024
5f0235f
Merge branch 'main' of github.com:yger/spikeinterface into hybrid_raw…
yger Feb 6, 2024
e6595ee
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 6, 2024
e0251b0
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Feb 6, 2024
69f0674
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Feb 16, 2024
e0037b3
WIP
yger Mar 12, 2024
2b8656f
WIP
yger Mar 12, 2024
b2365d4
WIP
yger Mar 12, 2024
6858a2f
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Mar 27, 2024
1d5b641
WIP
yger Mar 28, 2024
a139762
WIP
yger Mar 28, 2024
6a1673d
WIP
yger Mar 28, 2024
7742882
WIP
yger Mar 28, 2024
36ec32d
WIP
yger Mar 28, 2024
3c59afb
Merge branch 'main' into hybrid_raw_clustering
yger Mar 28, 2024
f6c4c4a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
b2e87a0
Docs
yger Mar 28, 2024
5046e53
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 28, 2024
275f87a
Adding tests
yger Mar 28, 2024
4e97eeb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
45653f6
One more test
yger Mar 28, 2024
e05f1d0
WIP
yger Mar 28, 2024
d1ee525
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2024
c12c746
Adding imports
yger Mar 29, 2024
c62f952
Merge branch 'main' into hybrid_raw_clustering
yger Mar 29, 2024
03a41d2
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
a850950
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
4bcbe5c
Moving functions in localization_tools
yger Mar 29, 2024
d974fe8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
3eedeb4
Imports
yger Mar 29, 2024
37eb54d
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
5f95fa0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 29, 2024
23be272
Hybrid recordings also with given templates
yger Mar 29, 2024
c40c9bd
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Mar 29, 2024
8d23525
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Mar 29, 2024
03fd84f
Extension of relocalization to real templates
yger Mar 30, 2024
c9b9215
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 30, 2024
0d2d7f9
Unit_locations could be given
yger Apr 1, 2024
f1b37f4
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 1, 2024
761cef5
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
04c66c6
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
b39bea9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2024
394ecec
Update src/spikeinterface/generation/hybrid_tools.py
yger Apr 2, 2024
396a0ca
Update src/spikeinterface/generation/tests/test_hybrid_tools.py
yger Apr 2, 2024
bf9cf53
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 2, 2024
d8e67de
Merge branch 'main' of https://github.com/SpikeInterface/spikeinterfa…
yger Apr 2, 2024
a33049a
WIP
yger Apr 4, 2024
4c81d1e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2024
ebee96a
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Apr 4, 2024
8127894
Merge branch 'main' into hybrid_raw_clustering
yger Apr 9, 2024
057ad10
Merge branch 'main' into hybrid_raw_clustering
yger Apr 9, 2024
eaba48f
WIP
yger Apr 9, 2024
01ec8ab
Imports
yger Apr 9, 2024
15a9491
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 9, 2024
b081091
Imports
yger Apr 9, 2024
88103d8
Merge branch 'main' into hybrid_raw_clustering
yger Apr 10, 2024
095aa06
WIP
yger Apr 12, 2024
82b5dde
WIP
yger Apr 15, 2024
6ba6d96
WIP
yger Apr 15, 2024
7bc2c6f
Merge branch 'main' of github.com:spikeinterface/spikeinterface into …
yger Apr 17, 2024
ef46881
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
yger Apr 17, 2024
7cfc227
Merge branch 'SpikeInterface:main' into hybrid_raw_clustering
yger Apr 23, 2024
33a426a
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 1, 2024
40fd9ff
Add amplitude scaling std
alejoe91 May 1, 2024
c584a81
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 1, 2024
76ab76d
Bunch of fixes and enable the use of external Templates objects
alejoe91 May 1, 2024
1043577
Fix scaling for drifting templates
alejoe91 May 1, 2024
836c2d4
fix tests
alejoe91 May 2, 2024
007acf7
Extend template database functionality
alejoe91 May 6, 2024
74bb13c
Final touches
alejoe91 May 6, 2024
a0179a1
pandas local import
alejoe91 May 6, 2024
e0b1ee6
Fix typing and SC2 matching
alejoe91 May 6, 2024
1213137
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 6, 2024
38a887e
Fix tests
alejoe91 May 6, 2024
6224d2c
Add hybrid tools and extend plot_unit_templates to Templates
alejoe91 May 10, 2024
1795c9f
Merge branch 'main' into hybrid_raw_clustering
alejoe91 May 10, 2024
5dccdab
Remove unused import
alejoe91 May 10, 2024
289f337
Update src/spikeinterface/widgets/unit_waveforms.py
alejoe91 May 10, 2024
e25bac3
Fix get_unit_colors tests and rename filter_templates to select_templ…
alejoe91 May 13, 2024
ae32e26
Merge branch 'hybrid_raw_clustering' of github.com:yger/spikeinterfac…
alejoe91 May 14, 2024
b3816bd
Fix test zarr path in tests
alejoe91 May 14, 2024
f034b39
Fix scaling and add temlate manipulation tests
alejoe91 May 14, 2024
e6d5cb3
Add widen/narrow button and scale bar to plot_unitwaveforms/templates
alejoe91 May 14, 2024
2835350
cleanup ipywidgets utils
alejoe91 May 14, 2024
ff8b29a
Ramon's suggestions
alejoe91 May 15, 2024
91ebec8
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 May 15, 2024
57af6ac
Revert changes to geT_unit_colors
alejoe91 May 15, 2024
8317ba4
fix tests
alejoe91 May 15, 2024
290701f
Merge branch 'main' into hybrid_raw_clustering
yger May 24, 2024
4bc4f6a
Merge branch 'main' into hybrid_raw_clustering
yger May 27, 2024
3e29d8d
Merge branch 'main' into hybrid_raw_clustering
yger May 31, 2024
bb1b489
Fixing tests
yger May 31, 2024
ec58a8e
Merge branch 'main' into hybrid_raw_clustering
yger Jun 3, 2024
f5e81a0
Fixing tests
yger Jun 3, 2024
8fe4efb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 3, 2024
81d42b8
Fix conflicts
alejoe91 Jun 3, 2024
3047e30
Fixing colons
yger Jun 6, 2024
ffd93c7
Merge branch 'main' into hybrid_raw_clustering
h-mayorquin Jun 6, 2024
070cb87
Fix conflicts
alejoe91 Jun 19, 2024
03fe4d9
from_static -> from_static_templates
alejoe91 Jun 19, 2024
d023420
Remove duplicate select_units/channels in Templates
alejoe91 Jun 19, 2024
93eb5e3
Changes form code review
alejoe91 Jun 19, 2024
c2df514
Fix imports and extend docstring
alejoe91 Jun 19, 2024
534202e
Move compute_* localization functions to localization_tools
alejoe91 Jun 19, 2024
c86cad4
Update save_motion_info and add tests
alejoe91 Jun 19, 2024
1ecda4a
Use save_motion_info in correct_motion
alejoe91 Jun 19, 2024
6f7d082
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 24, 2024
28321d9
User-specified precomputed templates in DriftingTemplates
cwindolf Jun 3, 2024
643ea6a
dataclasses are weird!
cwindolf Jun 3, 2024
866da1e
Incorporate DriftingTemplates.from_precomputed_templates constructor …
alejoe91 Jun 24, 2024
5c3e73b
Start hibrid benchmark docs and add Neuropixel-384 toy probe
alejoe91 Jun 24, 2024
c45f283
wip example
alejoe91 Jun 26, 2024
bc83485
continue how to an lint
alejoe91 Jun 26, 2024
2dd5f8b
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 26, 2024
3971512
Make Motion JSON serializable and add hybrit How to
alejoe91 Jun 27, 2024
f61a3c2
Merge branch 'main' into hybrid_raw_clustering
alejoe91 Jun 27, 2024
3e319dc
Merge branch 'main' into hybrid_raw_clustering
alejoe91 Jun 27, 2024
e6da491
Remove custom pickle
alejoe91 Jun 27, 2024
7f8f194
Remove duplicated line
alejoe91 Jun 27, 2024
d9c36a1
Use official prot_drift_map
alejoe91 Jun 27, 2024
8ffdd98
Add s3fs to test dependencies
alejoe91 Jun 27, 2024
406ebcf
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
alejoe91 Jun 27, 2024
5186f74
Sam+Charlie's suggestions: estimate one motion vector per unit (plus …
alejoe91 Jun 28, 2024
8b16006
Update how to and fix displacement_vectors
alejoe91 Jun 28, 2024
b2d3d3b
Update jupytext readme
alejoe91 Jun 28, 2024
efd5ecf
Update how to benchmark with real data
alejoe91 Jun 28, 2024
0015f5f
Remove old displacement_unit_factor
alejoe91 Jun 28, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings
import numpy as np
from typing import Union, Optional, List, Literal
import warnings
from math import ceil

from .basesorting import SpikeVectorSortingSegment
Expand Down Expand Up @@ -1862,7 +1861,7 @@ def get_traces(
wf = template[start_template:end_template]
if self.amplitude_vector is not None:
wf = wf * self.amplitude_vector[i]
traces[start_traces:end_traces] += wf
traces[start_traces:end_traces] += wf.astype(traces.dtype, copy=False)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

return traces.astype(self.dtype, copy=False)

Expand Down
1 change: 0 additions & 1 deletion src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,6 @@ def _init_peak_pipeline(recording, nodes):
worker_ctx["recording"] = recording
worker_ctx["nodes"] = nodes
worker_ctx["max_margin"] = max(node.get_trace_margin() for node in nodes)

return worker_ctx


Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def get_chunk_with_margin(
case zero padding is used, in the second case np.pad is called
with mod="reflect".
"""
length = rec_segment.get_num_samples()
length = int(rec_segment.get_num_samples())
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

if channel_indices is None:
channel_indices = slice(None)
Expand Down
59 changes: 54 additions & 5 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,9 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":
the `add_templates_to_zarr_group` method.

"""
templates_array = zarr_group["templates_array"]
channel_ids = zarr_group["channel_ids"]
unit_ids = zarr_group["unit_ids"]
templates_array = zarr_group["templates_array"][:]
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
channel_ids = zarr_group["channel_ids"][:]
unit_ids = zarr_group["unit_ids"][:]
sampling_frequency = zarr_group.attrs["sampling_frequency"]
nbefore = zarr_group.attrs["nbefore"]

Expand All @@ -305,7 +305,7 @@ def from_zarr_group(cls, zarr_group: "zarr.Group") -> "Templates":

sparsity_mask = None
if "sparsity_mask" in zarr_group:
sparsity_mask = zarr_group["sparsity_mask"]
sparsity_mask = zarr_group["sparsity_mask"][:]

probe = None
if "probe" in zarr_group:
Expand Down Expand Up @@ -352,6 +352,55 @@ def to_json(self):
def from_json(cls, json_str):
return cls.from_dict(json.loads(json_str))

def select_units(self, unit_ids):
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
"""
Return a new Templates object with only the selected units.

Parameters
----------
unit_ids : list
List of unit IDs to select.
"""
unit_ids_list = list(self.unit_ids)
unit_indices = np.array([unit_ids_list.index(unit_id) for unit_id in unit_ids], dtype=int)
sliced_sparsity_mask = None if self.sparsity_mask is None else self.sparsity_mask[unit_indices]
return Templates(
templates_array=self.templates_array[unit_indices],
sampling_frequency=self.sampling_frequency,
nbefore=self.nbefore,
sparsity_mask=sliced_sparsity_mask,
channel_ids=self.channel_ids,
unit_ids=unit_ids,
probe=self.probe,
check_for_consistent_sparsity=False,
)

def select_channels(self, channel_ids):
"""
Return a new Templates object with only the selected channels.
This operation can be useful to remove bad channels for hybrid recording
generation.

Parameters
----------
channel_ids : list
List of channel IDs to select.
"""
assert not self.are_templates_sparse(), "Cannot select channels on sparse templates"
channel_ids_list = list(self.channel_ids)
channel_indices = np.array([channel_ids_list.index(channel_id) for channel_id in channel_ids])
sliced_sparsity_mask = None if self.sparsity_mask is None else self.sparsity_mask[:, channel_indices]
return Templates(
templates_array=self.templates_array[:, :, channel_indices],
sampling_frequency=self.sampling_frequency,
nbefore=self.nbefore,
sparsity_mask=sliced_sparsity_mask,
channel_ids=channel_ids,
unit_ids=self.unit_ids,
probe=self.probe,
check_for_consistent_sparsity=False,
)

def __eq__(self, other):
"""
Necessary to compare templates because they naturally compare objects by equality of their fields
Expand Down Expand Up @@ -390,7 +439,7 @@ def __eq__(self, other):

return True

def get_channel_locations(self):
def get_channel_locations(self) -> np.ndarray:
assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set"
channel_locations = self.probe.contact_positions
return channel_locations
2 changes: 1 addition & 1 deletion src/spikeinterface/core/template_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _get_nbefore(one_object):
raise ValueError("SortingAnalyzer need extension 'templates' to be computed")
return ext.nbefore
else:
raise ValueError("Input should be Templates or SortingAnalyzer or SortingAnalyzer")
raise ValueError("Input should be Templates or SortingAnalyzer")


def get_template_amplitudes(
Expand Down
8 changes: 8 additions & 0 deletions src/spikeinterface/generation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
InjectDriftingTemplatesRecording,
make_linear_displacement,
)

from .hybrid_tools import (
generate_hybrid_recording,
estimate_templates_from_recording,
select_templates,
scale_templates,
shift_templates,
)
from .noise_tools import generate_noise
from .drifting_generator import (
make_one_displacement_vector,
Expand Down
15 changes: 8 additions & 7 deletions src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ def __init__(self, **kwargs):

@classmethod
def from_static(cls, templates):
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
drifting_teplates = cls(
drifting_templates = cls(
templates_array=templates.templates_array,
sampling_frequency=templates.sampling_frequency,
nbefore=templates.nbefore,
probe=templates.probe,
)
return drifting_teplates
return drifting_templates

def move_one_template(self, unit_index, displacement, **interpolation_kwargs):
"""
Expand Down Expand Up @@ -264,9 +264,9 @@ def __init__(
):
import scipy.spatial

assert isinstance(
drifting_templates, DriftingTemplates
), "drifting_templates must be a DriftingTemplates object"
# assert isinstance(
# drifting_templates, DriftingTemplates
# ), "drifting_templates must be a DriftingTemplates object"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why ?

self.drifting_templates = drifting_templates

if parent_recording is None:
Expand Down Expand Up @@ -442,7 +442,8 @@ def __init__(
# TODO: self.upsample_vector = upsample_vector
self.upsample_vector = None
self.parent_recording = parent_recording_segment
self.num_samples = parent_recording_segment.get_num_frames() if num_samples is None else num_samples
self.num_samples = parent_recording_segment.get_num_samples() if num_samples is None else num_samples
self.num_samples = int(num_samples)

self.displacement_indices = displacement_indices
self.templates_array_moved = templates_array_moved
Expand Down Expand Up @@ -507,7 +508,7 @@ def get_traces(
wf = template[start_template:end_template]
if self.amplitude_vector is not None:
wf *= self.amplitude_vector[i]
traces[start_traces:end_traces] += wf
traces[start_traces:end_traces] += wf.astype(self.dtype, copy=False)

return traces.astype(self.dtype)

Expand Down
Loading
Loading