diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index b64f0610ea..7aa4bb5b38 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -6,6 +6,8 @@ from probeinterface import Probe from pathlib import Path from .sparsity import ChannelSparsity +from multiprocessing.shared_memory import SharedMemory +from .core_tools import make_shared_array @dataclass @@ -453,3 +455,73 @@ def get_channel_locations(self) -> np.ndarray: assert self.probe is not None, "Templates.get_channel_locations() needs a probe to be set" channel_locations = self.probe.contact_positions return channel_locations + + +class SharedMemoryTemplates(Templates): + + def __init__(self, shm_name, shape, dtype, sampling_frequency, nbefore, sparsity_mask, + channel_ids, unit_ids, probe, is_scaled, main_shm_owner=True): + + assert len(shape) == 3 + assert shape[0] > 0, "SharedMemoryTemplates only supported with no empty templates" + + self.shm = SharedMemory(shm_name, create=False) + templates_array = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf) + + Templates.__init__(self, templates_array=templates_array, + sampling_frequency=sampling_frequency, + nbefore=nbefore, + sparsity_mask=sparsity_mask, + channel_ids=channel_ids, + unit_ids=unit_ids, + probe=probe, + is_scaled=is_scaled) + + # self._serializability["memory"] = True + # self._serializability["json"] = False + # self._serializability["pickle"] = False + + # this is very important for the shm.unlink() + # only the main instance need to call it + # all other instances that are loaded from dict are not the main owner + self.main_shm_owner = main_shm_owner + + self._kwargs = dict( + shm_name=shm_name, + shape=shape, + sampling_frequency=sampling_frequency, + nbefore=self.nbefore, + sparsity_mask=self.sparsity_mask, + channel_ids=self.channel_ids, + unit_ids=self.unit_ids, + probe=self.probe, + is_scaled=self.is_scaled, + # this ensure that all dump/load will not be main shm owner + main_shm_owner=False, + ) + + def __del__(self): + self.shm.close() + if self.main_shm_owner: + self.shm.unlink() + + @staticmethod + def from_templates(templates): + data = templates.get_dense_templates() + shm_templates, shm = make_shared_array(data.shape, data.dtype) + shm_templates[:] = data + shared_templates = SharedMemoryTemplates( + shm.name, + data.shape, + data.dtype, + templates.sampling_frequency, + templates.nbefore, + templates.sparsity_mask, + templates.channel_ids, + templates.unit_ids, + templates.probe, + templates.is_scaled, + main_shm_owner=True, + ) + shm.close() + return shared_templates \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_template_class.py b/src/spikeinterface/core/tests/test_template_class.py index 4e0a0c8567..ee694892a1 100644 --- a/src/spikeinterface/core/tests/test_template_class.py +++ b/src/spikeinterface/core/tests/test_template_class.py @@ -1,7 +1,7 @@ import pytest import numpy as np import pickle -from spikeinterface.core.template import Templates +from spikeinterface.core.template import Templates, SharedMemoryTemplates from spikeinterface.core.sparsity import ChannelSparsity from probeinterface import generate_multi_columns_probe @@ -170,6 +170,21 @@ def test_select_channels(template_type, is_scaled): if template.sparsity_mask is not None: assert np.array_equal(selected_template.sparsity_mask, template.sparsity_mask[:, selected_channel_ids_indices]) +@pytest.mark.parametrize("is_scaled", [True, False]) +@pytest.mark.parametrize("template_type", ["dense"]) +def test_shm_templates(template_type, is_scaled): + template = generate_test_template(template_type, is_scaled) + shm_templates = SharedMemoryTemplates.from_templates(template) + + # Verify that the channel ids match + assert np.array_equal(shm_templates.channel_ids, template.channel_ids) + # Verify that the templates data matches + assert np.array_equal( + shm_templates.templates_array, template.templates_array + ) + + if template.sparsity_mask is not None: + assert np.array_equal(shm_templates.sparsity_mask, template.sparsity_mask) if __name__ == "__main__": # test_json_serialization("sparse")