Commit 933c160

Merge branch 'main' into tip-bottom-option-rm-channels

alejoe91 authored Nov 30, 2023
2 parents 7f3be40 + 1000aae
Showing 52 changed files with 987 additions and 635 deletions.
2 changes: 1 addition & 1 deletion .github/actions/build-test-environment/action.yml
@@ -21,7 +21,7 @@ runs:
python -m pip install -U pip # Official recommended way
source ${{ github.workspace }}/test_env/bin/activate
pip install tabulate # This produces summaries at the end
-pip install -e .[test,extractors,full]
+pip install -e .[test,extractors,streaming_extractors,full]
shell: bash
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -22,7 +22,6 @@ classifiers = [
dependencies = [
"numpy",
"neo>=0.12.0",
"joblib",
"threadpoolctl",
"tqdm",
"probeinterface>=0.2.19",
@@ -80,6 +79,7 @@ streaming_extractors = [
"aiohttp",
"requests",
"pynwb>=2.3.0",
"remfile"
]

full = [
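Taken together, the two dependency edits above drop joblib from the core requirements and add remfile to the streaming_extractors extra, which the CI workflow now installs as well. A quick sanity check, as a hedged sketch (hypothetical script, assuming an environment created with pip install -e .[test,extractors,streaming_extractors,full]):

import importlib.util

# remfile should now be pulled in by the streaming_extractors extra
assert importlib.util.find_spec("remfile") is not None

# and the core package should import even in an environment without joblib
import spikeinterface.core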
4 changes: 4 additions & 0 deletions src/spikeinterface/core/channelslice.py
@@ -20,6 +20,10 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None)
channel_ids = parent_recording.get_channel_ids()
if renamed_channel_ids is None:
renamed_channel_ids = channel_ids
+else:
+    assert len(renamed_channel_ids) == len(
+        np.unique(renamed_channel_ids)
+    ), "renamed_channel_ids must be unique!"

self._parent_recording = parent_recording
self._channel_ids = np.asarray(channel_ids)
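The new else branch validates renamed ids at construction time. A minimal behavioral sketch (made-up channel ids, mirroring the test added further down in this commit):

from spikeinterface.core import ChannelSliceRecording
from spikeinterface.core.generate import generate_recording

rec = generate_recording(num_channels=4, durations=[1.0], set_probe=False)

# unique renamed ids: fine
sliced = ChannelSliceRecording(rec, channel_ids=[0, 1], renamed_channel_ids=["a", "b"])

# duplicate renamed ids: now raises an AssertionError instead of silently
# creating a recording with ambiguous channel labels
ChannelSliceRecording(rec, channel_ids=[0, 1], renamed_channel_ids=["a", "a"])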
22 changes: 1 addition & 21 deletions src/spikeinterface/core/job_tools.py
@@ -7,7 +7,6 @@
import os
import warnings

-import joblib
import sys
import contextlib
from tqdm.auto import tqdm
@@ -95,25 +94,6 @@ def split_job_kwargs(mixed_kwargs):
return specific_kwargs, job_kwargs


-# from https://stackoverflow.com/questions/24983493/tracking-progress-of-joblib-parallel-execution
-@contextlib.contextmanager
-def tqdm_joblib(tqdm_object):
-    """Context manager to patch joblib to report into tqdm progress bar given as argument"""
-
-    class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
-        def __call__(self, *args, **kwargs):
-            tqdm_object.update(n=self.batch_size)
-            return super().__call__(*args, **kwargs)
-
-    old_batch_callback = joblib.parallel.BatchCompletionCallBack
-    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
-    try:
-        yield tqdm_object
-    finally:
-        joblib.parallel.BatchCompletionCallBack = old_batch_callback
-        tqdm_object.close()


def divide_segment_into_chunks(num_frames, chunk_size):
if chunk_size is None:
chunks = [(0, num_frames)]
@@ -156,7 +136,7 @@ def _mem_to_int(mem):

def ensure_n_jobs(recording, n_jobs=1):
if n_jobs == -1:
-n_jobs = joblib.cpu_count()
+n_jobs = os.cpu_count()
elif n_jobs == 0:
n_jobs = 1
elif n_jobs is None:
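With joblib gone from this module, the tqdm_joblib shim above loses its purpose, and n_jobs=-1 is now resolved with the standard library instead. A standalone sketch of the new dispatch (hypothetical function name; same logic as ensure_n_jobs after this change):

import os

def resolve_n_jobs(n_jobs=1):
    # -1 means "use all cores", now via os.cpu_count() rather than joblib.cpu_count()
    if n_jobs == -1:
        n_jobs = os.cpu_count()
    elif n_jobs == 0 or n_jobs is None:
        n_jobs = 1
    return int(n_jobs)

assert resolve_n_jobs(-1) == os.cpu_count()
assert resolve_n_jobs(0) == 1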
9 changes: 5 additions & 4 deletions src/spikeinterface/core/numpyextractors.py
@@ -148,14 +148,15 @@ def __init__(self, spikes, sampling_frequency, unit_ids):
self._kwargs = dict(spikes=spikes, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

@staticmethod
-def from_sorting(source_sorting: BaseSorting, with_metadata=False) -> "NumpySorting":
+def from_sorting(source_sorting: BaseSorting, with_metadata=False, copy_spike_vector=False) -> "NumpySorting":
"""
Create a numpy sorting from another sorting extractor
"""

-sorting = NumpySorting(
-    source_sorting.to_spike_vector(), source_sorting.get_sampling_frequency(), source_sorting.unit_ids
-)
+spike_vector = source_sorting.to_spike_vector()
+if copy_spike_vector:
+    spike_vector = spike_vector.copy()
+sorting = NumpySorting(spike_vector, source_sorting.get_sampling_frequency(), source_sorting.unit_ids)
if with_metadata:
sorting.copy_metadata(source_sorting)
return sorting
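The new copy_spike_vector flag decides whether the NumpySorting shares the source sorting's spike-vector buffer or owns an independent copy; copying costs memory but decouples the result from the source object. A hedged usage sketch (built with generate_sorting, as the tests in this commit do):

from spikeinterface.core import NumpySorting
from spikeinterface.core.generate import generate_sorting

sorting = generate_sorting(num_units=3, durations=[1.0])

# default: the spike vector may be shared with the source sorting
np_sorting = NumpySorting.from_sorting(sorting)

# copy_spike_vector=True gives the new sorting its own buffer
np_sorting_copy = NumpySorting.from_sorting(sorting, copy_spike_vector=True)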
9 changes: 9 additions & 0 deletions src/spikeinterface/core/tests/test_channelslicerecording.py
@@ -7,6 +7,7 @@
import probeinterface

from spikeinterface.core import ChannelSliceRecording, BinaryRecordingExtractor
+from spikeinterface.core.generate import generate_recording


def test_ChannelSliceRecording():
@@ -73,5 +74,13 @@ def test_ChannelSliceRecording():
assert np.all(traces3[:, 1] == 2)


+def test_failure_with_non_unique_channel_ids():
+    durations = [1.0]
+    seed = 10
+    rec = generate_recording(num_channels=4, durations=durations, set_probe=False, seed=seed)
+    with pytest.raises(AssertionError):
+        rec_sliced = ChannelSliceRecording(rec, channel_ids=[0, 1], renamed_channel_ids=[0, 0])


if __name__ == "__main__":
test_ChannelSliceRecording()
2 changes: 0 additions & 2 deletions src/spikeinterface/core/tests/test_node_pipeline.py
@@ -5,8 +5,6 @@

from spikeinterface import extract_waveforms, get_template_extremum_channel, generate_ground_truth_recording

-# from spikeinterface.extractors import MEArecRecordingExtractor
-from spikeinterface.extractors import read_mearec

# from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.core.node_pipeline import (
33 changes: 11 additions & 22 deletions src/spikeinterface/core/tests/test_unitsselectionsorting.py
@@ -4,34 +4,16 @@

from spikeinterface.core import UnitsSelectionSorting

-from spikeinterface.core import NpzSortingExtractor, load_extractor
-from spikeinterface.core.base import BaseExtractor
+from spikeinterface.core.generate import generate_sorting

-from spikeinterface.core import create_sorting_npz


-if hasattr(pytest, "global_test_folder"):
-    cache_folder = pytest.global_test_folder / "core"
-else:
-    cache_folder = Path("cache_folder") / "core"


-def test_unitsselectionsorting():
-    num_seg = 2
-    file_path = cache_folder / "test_BaseSorting.npz"
-
-    create_sorting_npz(num_seg, file_path)
-
-    sorting = NpzSortingExtractor(file_path)
-    print(sorting)
-    print(sorting.unit_ids)
+def test_basic_functions():
+    sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2])
-print(sorting2.unit_ids)
assert np.array_equal(sorting2.unit_ids, [0, 2])

sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"])
-print(sorting3.unit_ids)
assert np.array_equal(sorting3.unit_ids, ["a", "b"])

assert np.array_equal(
@@ -49,5 +31,12 @@ def test_unitsselectionsorting():
)


+def test_failure_with_non_unique_unit_ids():
+    seed = 10
+    sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed)
+    with pytest.raises(AssertionError):
+        sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"])


if __name__ == "__main__":
-test_unitsselectionsorting()
+test_basic_functions()
1 change: 1 addition & 0 deletions src/spikeinterface/core/unitsselectionsorting.py
@@ -18,6 +18,7 @@ def __init__(self, parent_sorting, unit_ids=None, renamed_unit_ids=None):
unit_ids = parent_sorting.get_unit_ids()
if renamed_unit_ids is None:
renamed_unit_ids = unit_ids
+assert len(renamed_unit_ids) == len(np.unique(renamed_unit_ids)), "renamed_unit_ids must be unique!"

self._parent_sorting = parent_sorting
self._unit_ids = np.asarray(unit_ids)
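This is the same uniqueness guard as in ChannelSliceRecording, applied to renamed unit ids. A short behavioral sketch (made-up unit ids, mirroring the test above):

from spikeinterface.core import UnitsSelectionSorting
from spikeinterface.core.generate import generate_sorting

sorting = generate_sorting(num_units=3, durations=[1.0])

# duplicate renamed ids now raise an AssertionError at construction
UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"])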
71 changes: 53 additions & 18 deletions src/spikeinterface/core/waveform_extractor.py
@@ -1,10 +1,13 @@
from __future__ import annotations

+import math
+import pickle
from pathlib import Path
import shutil
from typing import Iterable, Literal, Optional
import json
import os
+import weakref

import numpy as np
from copy import deepcopy
@@ -1197,7 +1200,7 @@ def precompute_templates(self, modes=("average", "std", "median", "percentile"),
The results is cached in memory as a 3d ndarray (nunits, nsamples, nchans)
and also saved as an npy file in the folder to avoid recomputation each time.
"""
-# TODO : run this in parralel
+# TODO : run this in parallel

unit_ids = self.unit_ids
num_chans = self.get_num_channels()
@@ -1235,7 +1238,7 @@ def precompute_templates(self, modes=("average", "std", "median", "percentile"),

for mode in modes:
templates = self._template_cache[mode_names[mode]]
-if self.folder is not None:
+if self.folder is not None and not self.is_read_only():
template_file = self.folder / f"templates_{mode_names[mode]}.npy"
np.save(template_file, templates)

@@ -1848,7 +1851,7 @@ class BaseWaveformExtractorExtension:
handle_sparsity = False

def __init__(self, waveform_extractor):
-self.waveform_extractor = waveform_extractor
+self._waveform_extractor = weakref.ref(waveform_extractor)

if self.waveform_extractor.folder is not None:
self.folder = self.waveform_extractor.folder
@@ -1895,8 +1898,20 @@ def __init__(self, waveform_extractor):
# register
self.waveform_extractor._loaded_extensions[self.extension_name] = self

+@property
+def waveform_extractor(self):
+    # Important : to avoid the WaveformExtractor referencing a BaseWaveformExtractorExtension
+    # and BaseWaveformExtractorExtension referencing a WaveformExtractor
+    # we need a weakref. Otherwise the garbage collector is not working properly
+    # and so the WaveformExtractor + its recording are still alive even after deleting explicitly
+    # the WaveformExtractor which makes it impossible to delete the folder!
+    we = self._waveform_extractor()
+    if we is None:
+        raise ValueError(f"The extension {self.extension_name} has lost its WaveformExtractor")
+    return we

@classmethod
-def load(cls, folder, waveform_extractor=None):
+def load(cls, folder, waveform_extractor):
folder = Path(folder)
assert folder.is_dir(), "Waveform folder does not exists"
if folder.suffix == ".zarr":
@@ -1907,8 +1922,8 @@ def load(cls, folder, waveform_extractor=None):
if "sparsity" in params and params["sparsity"] is not None:
params["sparsity"] = ChannelSparsity.from_dict(params["sparsity"])

-if waveform_extractor is None:
-    waveform_extractor = WaveformExtractor.load(folder)
+# if waveform_extractor is None:
+#     waveform_extractor = WaveformExtractor.load(folder)

# make instance with params
ext = cls(waveform_extractor)
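The waveform_extractor property added above is the heart of this refactor: an extension now holds only a weak reference to its WaveformExtractor, so dropping the last strong reference actually frees the extractor (and its recording), which in turn makes the folder deletable; for the same reason, load() no longer creates a fresh extractor when none is passed. A self-contained sketch of the pattern (hypothetical Parent/Extension names, not the SpikeInterface classes):

import weakref

class Parent:
    pass

class Extension:
    def __init__(self, parent):
        # weak back-reference: does not keep the parent alive
        self._parent = weakref.ref(parent)

    @property
    def parent(self):
        parent = self._parent()
        if parent is None:
            raise ValueError("the parent has been garbage-collected")
        return parent

p = Parent()
ext = Extension(p)
assert ext.parent is p

del p  # on CPython the parent is collected right away...
# ...so ext.parent would now raise ValueError instead of silently pinning the object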
@@ -1962,7 +1977,11 @@ def _load_extension_data(self):
if ext_data_file.suffix == ".json":
ext_data = json.load(ext_data_file.open("r"))
elif ext_data_file.suffix == ".npy":
-ext_data = np.load(ext_data_file, mmap_mode="r")
+# The lazy loading of an extension is complicated because if we compute again
+# and have a link to the old buffer on windows then it fails
+# ext_data = np.load(ext_data_file, mmap_mode="r")
+# so we go back to full loading
+ext_data = np.load(ext_data_file)
elif ext_data_file.suffix == ".csv":
import pandas as pd

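Dropping mmap_mode="r" is a Windows-driven trade-off: a live memmap keeps the .npy file handle open, so recomputing the extension could not overwrite its own data file. A small sketch of the failure mode (hypothetical file name; the PermissionError is specific to Windows):

import numpy as np

np.save("ext_data.npy", np.arange(4))

data = np.load("ext_data.npy", mmap_mode="r")  # keeps a handle on the file
# on Windows, overwriting at this point fails:
# np.save("ext_data.npy", np.zeros(4))  # PermissionError

del data  # releasing the memmap frees the file; a full np.load avoids the issue entirely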
@@ -2002,6 +2021,11 @@ def _save(self, **kwargs):
# Only save if not read only
if self.waveform_extractor.is_read_only():
return

+# delete already saved
+self._reset_folder()
+self._save_params()

if self.format == "binary":
import pandas as pd

@@ -2052,18 +2076,26 @@
except:
raise Exception(f"Could not save {ext_data_name} as extension data")

+def _reset_folder(self):
+    """
+    Delete the extension in folder (binary or zarr) and create an empty one.
+    """
+    if self.format == "binary" and self.extension_folder is not None:
+        if self.extension_folder.is_dir():
+            shutil.rmtree(self.extension_folder)
+        self.extension_folder.mkdir()
+    elif self.format == "zarr":
+        import zarr
+
+        zarr_root = zarr.open(self.folder, mode="r+")
+        self.extension_group = zarr_root.create_group(self.extension_name, overwrite=True)

def reset(self):
"""
Reset the waveform extension.
Delete the sub folder and create a new empty one.
"""
-if self.extension_folder is not None:
-    if self.format == "binary":
-        if self.extension_folder.is_dir():
-            shutil.rmtree(self.extension_folder)
-        self.extension_folder.mkdir()
-    elif self.format == "zarr":
-        del self.extension_group
+self._reset_folder()

self._params = None
self._extension_data = dict()
@@ -2096,12 +2128,15 @@ def set_params(self, **params):
if self.waveform_extractor.is_read_only():
return

-    params_to_save = params.copy()
-    if "sparsity" in params and params["sparsity"] is not None:
+    self._save_params()
+
+def _save_params(self):
+    params_to_save = self._params.copy()
+    if "sparsity" in params_to_save and params_to_save["sparsity"] is not None:
assert isinstance(
params["sparsity"], ChannelSparsity
params_to_save["sparsity"], ChannelSparsity
), "'sparsity' parameter must be a ChannelSparsity object!"
params_to_save["sparsity"] = params["sparsity"].to_dict()
params_to_save["sparsity"] = params_to_save["sparsity"].to_dict()
if self.format == "binary":
if self.extension_folder is not None:
param_file = self.extension_folder / "params.json"
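The net effect of the _save / _reset_folder / _save_params split is that every save starts from a clean extension folder and re-serializes the parameters, so files from a previous computation can never leak into a new one. A minimal sketch of the flow for the binary format (hypothetical helper names and paths, not the actual methods):

import json
import shutil
from pathlib import Path

import numpy as np

def reset_folder(folder: Path) -> None:
    # analogous to _reset_folder: wipe anything already saved, recreate empty
    if folder.is_dir():
        shutil.rmtree(folder)
    folder.mkdir(parents=True)

def save_extension(folder: Path, params: dict, data: dict) -> None:
    # analogous to _save: reset first, then persist params and data buffers
    reset_folder(folder)
    (folder / "params.json").write_text(json.dumps(params))
    for name, array in data.items():
        np.save(folder / f"{name}.npy", array)

save_extension(Path("my_extension"), {"ms_before": 1.0}, {"amplitudes": np.zeros(5)})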
(Diffs for the remaining changed files were not loaded on this page.)
