Skip to content

Commit

Permalink
Merge pull request #2989 from h-mayorquin/add_select_methods_for_temp…
Browse files Browse the repository at this point in the history
…lates

Add select chanel and select unit method for template objects
  • Loading branch information
alejoe91 authored Jun 10, 2024
2 parents d962631 + d3cf9ac commit 17ee785
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 10 deletions.
61 changes: 60 additions & 1 deletion src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np
import json
from dataclasses import dataclass, field, astuple
from dataclasses import dataclass, field, astuple, replace
from probeinterface import Probe
from pathlib import Path
from .sparsity import ChannelSparsity
Expand Down Expand Up @@ -97,8 +97,20 @@ def __post_init__(self):
# Initialize sparsity object
if self.channel_ids is None:
self.channel_ids = np.arange(self.num_channels)
else:
self.channel_ids = np.asarray(self.channel_ids)
assert (
len(self.channel_ids) == self.num_channels
), f"length of channel ids {len(self.channel_ids)} must be equal to the number of channels {self.num_channels}"

if self.unit_ids is None:
self.unit_ids = np.arange(self.num_units)
else:
self.unit_ids = np.asarray(self.unit_ids)
assert (
self.unit_ids.size == self.num_units
), f"length of units ids {self.unit_ids.size} must be equal to the number of units {self.num_units}"

if self.sparsity_mask is not None:
self.sparsity = ChannelSparsity(
mask=self.sparsity_mask,
Expand Down Expand Up @@ -128,6 +140,53 @@ def __repr__(self):

return repr_str

def select_units(self, unit_ids) -> Templates:
"""
Return a new Templates object with only the selected units.
Parameters
----------
unit_ids : list
List of unit IDs to select.
"""
unit_ids_list = list(self.unit_ids)
unit_indices = np.array([unit_ids_list.index(unit_id) for unit_id in unit_ids], dtype=int)
sliced_sparsity_mask = None if self.sparsity_mask is None else self.sparsity_mask[unit_indices]

# Data class method to only change selected fields
return replace(
self,
templates_array=self.templates_array[unit_indices],
sparsity_mask=sliced_sparsity_mask,
unit_ids=unit_ids,
check_for_consistent_sparsity=False,
)

def select_channels(self, channel_ids) -> Templates:
"""
Return a new Templates object with only the selected channels.
This operation can be useful to remove bad channels for hybrid recording
generation.
Parameters
----------
channel_ids : list
List of channel IDs to select.
"""
assert not self.are_templates_sparse(), "Cannot select channels on sparse templates"
channel_ids_list = list(self.channel_ids)
channel_indices = np.array([channel_ids_list.index(channel_id) for channel_id in channel_ids])
sliced_sparsity_mask = None if self.sparsity_mask is None else self.sparsity_mask[:, channel_indices]

# Data class method to only change selected fields
return replace(
self,
templates_array=self.templates_array[:, :, channel_indices],
sparsity_mask=sliced_sparsity_mask,
channel_ids=channel_ids,
check_for_consistent_sparsity=False,
)

def to_sparse(self, sparsity):
# Turn a dense representation of templates into a sparse one, given some sparsity.
# Note that nothing prevent Templates tobe empty after sparsification if the sparse mask have no channels for some units
Expand Down
72 changes: 63 additions & 9 deletions src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@


def generate_test_template(template_type, is_scaled=True) -> Templates:
num_units = 2
num_units = 3
num_samples = 5
num_channels = 3
num_channels = 4
templates_shape = (num_units, num_samples, num_channels)
templates_array = np.arange(num_units * num_samples * num_channels).reshape(templates_shape)

unit_ids = ["unit_a", "unit_b", "unit_c"]
channel_ids = ["channel1", "channel2", "channel3", "channel4"]
sampling_frequency = 30_000
nbefore = 2

Expand All @@ -25,19 +26,25 @@ def generate_test_template(template_type, is_scaled=True) -> Templates:
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
unit_ids=unit_ids,
channel_ids=channel_ids,
is_scaled=is_scaled,
)
elif template_type == "sparse": # sparse with sparse templates
sparsity_mask = np.array([[True, False, True], [False, True, False]])
sparsity_mask = np.array(
[[True, False, True, True], [False, True, False, False], [True, False, True, False]],
)
sparsity = ChannelSparsity(
mask=sparsity_mask, unit_ids=np.arange(num_units), channel_ids=np.arange(num_channels)
mask=sparsity_mask,
unit_ids=unit_ids,
channel_ids=channel_ids,
)

# Create sparse templates
sparse_templates_array = np.zeros(shape=(num_units, num_samples, sparsity.max_num_active_channels))
for unit_index in range(num_units):
for unit_index, unit_id in enumerate(unit_ids):
template = templates_array[unit_index, ...]
sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_index)
sparse_template = sparsity.sparsify_waveforms(waveforms=template, unit_id=unit_id)
sparse_templates_array[unit_index, :, : sparse_template.shape[1]] = sparse_template

return Templates(
Expand All @@ -47,18 +54,23 @@ def generate_test_template(template_type, is_scaled=True) -> Templates:
nbefore=nbefore,
probe=probe,
is_scaled=is_scaled,
unit_ids=unit_ids,
channel_ids=channel_ids,
)

elif template_type == "sparse_with_dense_templates": # sparse with dense templates
sparsity_mask = np.array([[True, False, True], [False, True, False]])

sparsity_mask = np.array(
[[True, False, True, True], [False, True, False, False], [True, False, True, False]],
)
return Templates(
templates_array=templates_array,
sparsity_mask=sparsity_mask,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
is_scaled=is_scaled,
unit_ids=unit_ids,
channel_ids=channel_ids,
)


Expand Down Expand Up @@ -117,6 +129,48 @@ def test_save_and_load_zarr(template_type, is_scaled, tmp_path):
assert original_template == loaded_template


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("template_type", ["dense", "sparse"])
def test_select_units(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)
selected_unit_ids = ["unit_a", "unit_c"]
selected_unit_ids_indices = [0, 2]

selected_template = template.select_units(selected_unit_ids)

# Verify that the selected template has the correct number of units
assert selected_template.num_units == len(selected_unit_ids)
# Verify that the unit ids match
assert np.array_equal(selected_template.unit_ids, selected_unit_ids)
# Verify that the templates data matches
assert np.array_equal(selected_template.templates_array, template.templates_array[selected_unit_ids_indices])

if template.sparsity_mask is not None:
assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[selected_unit_ids_indices])


@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("template_type", ["dense"])
def test_select_channels(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)
selected_channel_ids = ["channel1", "channel3"]
selected_channel_ids_indices = [0, 2]

selected_template = template.select_channels(selected_channel_ids)

# Verify that the selected template has the correct number of channels
assert selected_template.num_channels == len(selected_channel_ids)
# Verify that the channel ids match
assert np.array_equal(selected_template.channel_ids, selected_channel_ids)
# Verify that the templates data matches
assert np.array_equal(
selected_template.templates_array, template.templates_array[:, :, selected_channel_ids_indices]
)

if template.sparsity_mask is not None:
assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[:, selected_channel_ids_indices])


if __name__ == "__main__":
# test_json_serialization("sparse")
test_json_serialization("dense")

0 comments on commit 17ee785

Please sign in to comment.