Skip to content

Commit

Permalink
Merge branch 'main' into optimize_motion
Browse files Browse the repository at this point in the history
  • Loading branch information
alejoe91 authored Dec 4, 2023
2 parents cdc8d58 + 7f204c6 commit 8852629
Show file tree
Hide file tree
Showing 48 changed files with 969 additions and 577 deletions.
2 changes: 1 addition & 1 deletion .github/actions/build-test-environment/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ runs:
python -m pip install -U pip # Official recommended way
source ${{ github.workspace }}/test_env/bin/activate
pip install tabulate # This produces summaries at the end
pip install -e .[test,extractors,full]
pip install -e .[test,extractors,streaming_extractors,full]
shell: bash
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ streaming_extractors = [
"aiohttp",
"requests",
"pynwb>=2.3.0",
"remfile"
]

full = [
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class BaseRecordingSnippets(BaseExtractor):

def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype):
BaseExtractor.__init__(self, channel_ids)
self._sampling_frequency = sampling_frequency
self._sampling_frequency = float(sampling_frequency)
self._dtype = np.dtype(dtype)

@property
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class BaseSorting(BaseExtractor):

def __init__(self, sampling_frequency: float, unit_ids: List):
BaseExtractor.__init__(self, unit_ids)
self._sampling_frequency = sampling_frequency
self._sampling_frequency = float(sampling_frequency)
self._sorting_segments: List[BaseSortingSegment] = []
# this weak link is to handle times from a recording object
self._recording = None
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def __init__(

if not channel_from_template:
channel_distance = get_channel_distances(recording)
self.neighbours_mask = channel_distance < radius_um
self.neighbours_mask = channel_distance <= radius_um
self.peak_sign = peak_sign

# precompute segment slice
Expand Down Expand Up @@ -367,7 +367,7 @@ def __init__(
self.radius_um = radius_um
self.contact_locations = recording.get_channel_locations()
self.channel_distance = get_channel_distances(recording)
self.neighbours_mask = self.channel_distance < radius_um
self.neighbours_mask = self.channel_distance <= radius_um
self.max_num_chans = np.max(np.sum(self.neighbours_mask, axis=1))

def get_trace_margin(self):
Expand Down
9 changes: 5 additions & 4 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,15 @@ def __init__(self, spikes, sampling_frequency, unit_ids):
self._kwargs = dict(spikes=spikes, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

@staticmethod
def from_sorting(source_sorting: BaseSorting, with_metadata=False) -> "NumpySorting":
def from_sorting(source_sorting: BaseSorting, with_metadata=False, copy_spike_vector=False) -> "NumpySorting":
"""
Create a numpy sorting from another sorting extractor
"""

sorting = NumpySorting(
source_sorting.to_spike_vector(), source_sorting.get_sampling_frequency(), source_sorting.unit_ids
)
spike_vector = source_sorting.to_spike_vector()
if copy_spike_vector:
spike_vector = spike_vector.copy()
sorting = NumpySorting(spike_vector, source_sorting.get_sampling_frequency(), source_sorting.unit_ids)
if with_metadata:
sorting.copy_metadata(source_sorting)
return sorting
Expand Down
65 changes: 49 additions & 16 deletions src/spikeinterface/core/waveform_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Iterable, Literal, Optional
import json
import os
import weakref

import numpy as np
from copy import deepcopy
Expand Down Expand Up @@ -1850,7 +1851,7 @@ class BaseWaveformExtractorExtension:
handle_sparsity = False

def __init__(self, waveform_extractor):
self.waveform_extractor = waveform_extractor
self._waveform_extractor = weakref.ref(waveform_extractor)

if self.waveform_extractor.folder is not None:
self.folder = self.waveform_extractor.folder
Expand Down Expand Up @@ -1897,8 +1898,20 @@ def __init__(self, waveform_extractor):
# register
self.waveform_extractor._loaded_extensions[self.extension_name] = self

@property
def waveform_extractor(self):
# Important : to avoid the WaveformExtractor referencing a BaseWaveformExtractorExtension
# and BaseWaveformExtractorExtension referencing a WaveformExtractor
# we need a weakref. Otherwise the garbage collector is not working properly
# and so the WaveformExtractor + its recording are still alive even after deleting explicitly
# the WaveformExtractor which makes it impossible to delete the folder!
we = self._waveform_extractor()
if we is None:
raise ValueError(f"The extension {self.extension_name} has lost its WaveformExtractor")
return we

@classmethod
def load(cls, folder, waveform_extractor=None):
def load(cls, folder, waveform_extractor):
folder = Path(folder)
assert folder.is_dir(), "Waveform folder does not exists"
if folder.suffix == ".zarr":
Expand All @@ -1909,8 +1922,8 @@ def load(cls, folder, waveform_extractor=None):
if "sparsity" in params and params["sparsity"] is not None:
params["sparsity"] = ChannelSparsity.from_dict(params["sparsity"])

if waveform_extractor is None:
waveform_extractor = WaveformExtractor.load(folder)
# if waveform_extractor is None:
# waveform_extractor = WaveformExtractor.load(folder)

# make instance with params
ext = cls(waveform_extractor)
Expand Down Expand Up @@ -1964,7 +1977,11 @@ def _load_extension_data(self):
if ext_data_file.suffix == ".json":
ext_data = json.load(ext_data_file.open("r"))
elif ext_data_file.suffix == ".npy":
ext_data = np.load(ext_data_file, mmap_mode="r")
# The lazy loading of an extension is complicated because if we compute again
# and have a link to the old buffer on windows then it fails
# ext_data = np.load(ext_data_file, mmap_mode="r")
# so we go back to full loading
ext_data = np.load(ext_data_file)
elif ext_data_file.suffix == ".csv":
import pandas as pd

Expand Down Expand Up @@ -2004,6 +2021,11 @@ def _save(self, **kwargs):
# Only save if not read only
if self.waveform_extractor.is_read_only():
return

# delete already saved
self._reset_folder()
self._save_params()

if self.format == "binary":
import pandas as pd

Expand Down Expand Up @@ -2054,18 +2076,26 @@ def _save(self, **kwargs):
except:
raise Exception(f"Could not save {ext_data_name} as extension data")

def _reset_folder(self):
"""
Delete the extension in folder (binary or zarr) and create an empty one.
"""
if self.format == "binary" and self.extension_folder is not None:
if self.extension_folder.is_dir():
shutil.rmtree(self.extension_folder)
self.extension_folder.mkdir()
elif self.format == "zarr":
import zarr

zarr_root = zarr.open(self.folder, mode="r+")
self.extension_group = zarr_root.create_group(self.extension_name, overwrite=True)

def reset(self):
"""
Reset the waveform extension.
Delete the sub folder and create a new empty one.
"""
if self.extension_folder is not None:
if self.format == "binary":
if self.extension_folder.is_dir():
shutil.rmtree(self.extension_folder)
self.extension_folder.mkdir()
elif self.format == "zarr":
del self.extension_group
self._reset_folder()

self._params = None
self._extension_data = dict()
Expand Down Expand Up @@ -2098,12 +2128,15 @@ def set_params(self, **params):
if self.waveform_extractor.is_read_only():
return

params_to_save = params.copy()
if "sparsity" in params and params["sparsity"] is not None:
self._save_params()

def _save_params(self):
params_to_save = self._params.copy()
if "sparsity" in params_to_save and params_to_save["sparsity"] is not None:
assert isinstance(
params["sparsity"], ChannelSparsity
params_to_save["sparsity"], ChannelSparsity
), "'sparsity' parameter must be a ChannelSparsity object!"
params_to_save["sparsity"] = params["sparsity"].to_dict()
params_to_save["sparsity"] = params_to_save["sparsity"].to_dict()
if self.format == "binary":
if self.extension_folder is not None:
param_file = self.extension_folder / "params.json"
Expand Down
66 changes: 66 additions & 0 deletions src/spikeinterface/exporters/tests/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import pytest
from pathlib import Path

from spikeinterface.core import generate_ground_truth_recording, extract_waveforms
from spikeinterface.postprocessing import (
compute_spike_amplitudes,
compute_template_similarity,
compute_principal_components,
)
from spikeinterface.qualitymetrics import compute_quality_metrics

if hasattr(pytest, "global_test_folder"):
cache_folder = pytest.global_test_folder / "exporters"
else:
cache_folder = Path("cache_folder") / "exporters"


def make_waveforms_extractor(sparse=True, with_group=False):
recording, sorting = generate_ground_truth_recording(
durations=[30.0],
sampling_frequency=28000.0,
num_channels=8,
num_units=4,
generate_probe_kwargs=dict(
num_columns=2,
xpitch=20,
ypitch=20,
contact_shapes="circle",
contact_shape_params={"radius": 6},
),
generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
noise_kwargs=dict(noise_level=5.0, strategy="on_the_fly"),
seed=2205,
)

if with_group:
recording.set_channel_groups([0, 0, 0, 0, 1, 1, 1, 1])
sorting.set_property("group", [0, 0, 1, 1])

we = extract_waveforms(recording=recording, sorting=sorting, folder=None, mode="memory", sparse=sparse)
compute_principal_components(we)
compute_spike_amplitudes(we)
compute_template_similarity(we)
compute_quality_metrics(we, metric_names=["snr"])

return we


@pytest.fixture(scope="module")
def waveforms_extractor_dense_for_export():
return make_waveforms_extractor(sparse=False)


@pytest.fixture(scope="module")
def waveforms_extractor_with_group_for_export():
return make_waveforms_extractor(sparse=False, with_group=True)


@pytest.fixture(scope="module")
def waveforms_extractor_sparse_for_export():
return make_waveforms_extractor(sparse=True)


if __name__ == "__main__":
we = make_waveforms_extractor(sparse=False)
print(we)
Loading

0 comments on commit 8852629

Please sign in to comment.