Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jan 16, 2025
1 parent 2d087f2 commit 69c8782
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
24 changes: 14 additions & 10 deletions src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from multiprocessing.shared_memory import SharedMemory
from spikeinterface.core.core_tools import make_shared_array


def interpolate_templates(templates_array, source_locations, dest_locations, interpolation_method="cubic"):
"""
Interpolate templates_array to new positions.
Expand Down Expand Up @@ -266,10 +267,10 @@ def __init__(
shm_name,
shape,
dtype,
templates_array_moved=None,
templates_array_moved=None,
displacements=None,
main_shm_owner=True,
**static_kwargs
**static_kwargs,
):

assert len(shape) == 4
Expand All @@ -279,10 +280,7 @@ def __init__(
templates_array_moved = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf)
self.static_kwargs = static_kwargs
DriftingTemplates.__init__(
self,
templates_array_moved=templates_array_moved,
displacements=displacements,
**self.static_kwargs
self, templates_array_moved=templates_array_moved, displacements=displacements, **self.static_kwargs
)

# this is very important for the shm.unlink()
Expand All @@ -307,20 +305,26 @@ def __del__(self):

@staticmethod
def from_drifting_templates(drifting_templates):
assert drifting_templates.templates_array_moved is not None, "drifting_templates must have precomputed displacements"
assert (
drifting_templates.templates_array_moved is not None
), "drifting_templates must have precomputed displacements"
data = drifting_templates.templates_array_moved
shm_templates, shm = make_shared_array(data.shape, data.dtype)
shm_templates[:] = data
static_kwargs = drifting_templates.to_dict()
init_kwargs = {
"templates_array": np.asarray(static_kwargs["templates_array"]),
"sparsity_mask": None if static_kwargs["sparsity_mask"] is None else np.asarray(static_kwargs["sparsity_mask"]),
"sparsity_mask": (
None if static_kwargs["sparsity_mask"] is None else np.asarray(static_kwargs["sparsity_mask"])
),
"channel_ids": np.asarray(static_kwargs["channel_ids"]),
"unit_ids": np.asarray(static_kwargs["unit_ids"]),
"sampling_frequency": static_kwargs["sampling_frequency"],
"nbefore": static_kwargs["nbefore"],
"is_scaled": static_kwargs["is_scaled"],
"probe": static_kwargs["probe"] if static_kwargs["probe"] is None else Probe.from_dict(static_kwargs["probe"]),
"probe": (
static_kwargs["probe"] if static_kwargs["probe"] is None else Probe.from_dict(static_kwargs["probe"])
),
}
shared_drifting_templates = SharedMemoryDriftingTemplates(
shm.name,
Expand All @@ -329,7 +333,7 @@ def from_drifting_templates(drifting_templates):
shm_templates,
drifting_templates.displacements,
main_shm_owner=True,
**init_kwargs
**init_kwargs,
)
shm.close()
return shared_drifting_templates
Expand Down
5 changes: 2 additions & 3 deletions src/spikeinterface/generation/tests/test_drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,10 @@ def test_SharedMemoryDriftingTemplates():
drifting_templates.precompute_displacements(displacement)
shm_drifting_templates = SharedMemoryDriftingTemplates.from_drifting_templates(drifting_templates)

assert np.array_equal(
shm_drifting_templates.templates_array_moved, drifting_templates.templates_array_moved
)
assert np.array_equal(shm_drifting_templates.templates_array_moved, drifting_templates.templates_array_moved)
assert np.array_equal(shm_drifting_templates.displacements, drifting_templates.displacements)


def test_InjectDriftingTemplatesRecording(create_cache_folder):
cache_folder = create_cache_folder
templates = make_some_templates()
Expand Down

0 comments on commit 69c8782

Please sign in to comment.