Skip to content

Commit

Permalink
Adding SharedMemoryTemplates
Browse files Browse the repository at this point in the history
  • Loading branch information
yger committed Dec 13, 2024
1 parent 6fde997 commit 5ed4f03
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 1 deletion.
72 changes: 72 additions & 0 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from probeinterface import Probe
from pathlib import Path
from .sparsity import ChannelSparsity
from multiprocessing.shared_memory import SharedMemory
from .core_tools import make_shared_array


@dataclass
Expand Down Expand Up @@ -453,3 +455,73 @@ def get_channel_locations(self) -> np.ndarray:
assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set"
channel_locations = self.probe.contact_positions
return channel_locations


class SharedMemoryTemplates(Templates):

def __init__(self, shm_name, shape, dtype, sampling_frequency, nbefore, sparsity_mask,
channel_ids, unit_ids, probe, is_scaled, main_shm_owner=True):

assert len(shape) == 3
assert shape[0] > 0, "SharedMemoryTemplates only supported with no empty templates"

self.shm = SharedMemory(shm_name, create=False)
templates_array = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf)

Templates.__init__(self, templates_array=templates_array,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
sparsity_mask=sparsity_mask,
channel_ids=channel_ids,
unit_ids=unit_ids,
probe=probe,
is_scaled=is_scaled)

# self._serializability["memory"] = True
# self._serializability["json"] = False
# self._serializability["pickle"] = False

# this is very important for the shm.unlink()
# only the main instance need to call it
# all other instances that are loaded from dict are not the main owner
self.main_shm_owner = main_shm_owner

self._kwargs = dict(
shm_name=shm_name,
shape=shape,
sampling_frequency=sampling_frequency,
nbefore=self.nbefore,
sparsity_mask=self.sparsity_mask,
channel_ids=self.channel_ids,
unit_ids=self.unit_ids,
probe=self.probe,
is_scaled=self.is_scaled,
# this ensure that all dump/load will not be main shm owner
main_shm_owner=False,
)

def __del__(self):
self.shm.close()
if self.main_shm_owner:
self.shm.unlink()

@staticmethod
def from_templates(templates):
data = templates.get_dense_templates()
shm_templates, shm = make_shared_array(data.shape, data.dtype)
shm_templates[:] = data
shared_templates = SharedMemoryTemplates(
shm.name,
data.shape,
data.dtype,
templates.sampling_frequency,
templates.nbefore,
templates.sparsity_mask,
templates.channel_ids,
templates.unit_ids,
templates.probe,
templates.is_scaled,
main_shm_owner=True,
)
shm.close()
return shared_templates
17 changes: 16 additions & 1 deletion src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
import numpy as np
import pickle
from spikeinterface.core.template import Templates
from spikeinterface.core.template import Templates, SharedMemoryTemplates
from spikeinterface.core.sparsity import ChannelSparsity

from probeinterface import generate_multi_columns_probe
Expand Down Expand Up @@ -170,6 +170,21 @@ def test_select_channels(template_type, is_scaled):
if template.sparsity_mask is not None:
assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[:, selected_channel_ids_indices])

@pytest.mark.parametrize("is_scaled", [True, False])
@pytest.mark.parametrize("template_type", ["dense"])
def test_shm_templates(template_type, is_scaled):
template = generate_test_template(template_type, is_scaled)
shm_templates = SharedMemoryTemplates.from_templates(template)

# Verify that the channel ids match
assert np.array_equal(shm_templates.channel_ids, template.channel_ids)
# Verify that the templates data matches
assert np.array_equal(
shm_templates.templates_array, template.templates_array
)

if template.sparsity_mask is not None:
assert np.array_equal(shm_templates.sparsity_mask, template.sparsity_mask)

if __name__ == "__main__":
# test_json_serialization("sparse")
Expand Down

0 comments on commit 5ed4f03

Please sign in to comment.