Merge branch 'main' into bencmarks_components
yger authored Jun 19, 2024
2 parents f8beb1b + 4c95a2e commit 80133fb
Showing 75 changed files with 2,063 additions and 951 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/installation-tips-test.yml
@@ -28,8 +28,9 @@ jobs:
with:
python-version: '3.10'
- name: Test Conda Environment Creation
uses: conda-incubator/setup-miniconda@v2.2.0
uses: conda-incubator/setup-miniconda@v3
with:
miniconda-version: "latest"
environment-file: ./installation_tips/full_spikeinterface_environment_${{ matrix.label }}.yml
activate-environment: si_env
- name: Check Installation Tips
105 changes: 105 additions & 0 deletions doc/modules/curation.rst
@@ -41,6 +41,111 @@ The merging and splitting operations are handled by the :py:class:`~spikeinterfa
# here is the final clean sorting
clean_sorting = cs.sorting
Manual curation format
----------------------

SpikeInterface internally supports a JSON-based manual curation format.
When manual curation is necessary, modifying a dataset in place is bad practice.
Instead, to ensure the reproducibility of spike sorting pipelines, we have introduced a simple, JSON-based manual curation format.
At the moment, this format defines merges, deletions, and manual tags.
This simple file can be kept alongside the output of a sorter and applied to the result to obtain a "clean" result.

This format has two parts:

* **definition**, with the following keys:

  * "format_version" : format specification
  * "unit_ids" : the list of unit_ids
  * "label_definitions" : list of label categories and possible labels per category.
    Every category can be *exclusive=True* (only one label per unit) or *exclusive=False* (several labels possible)

* **manual output**, the curation itself, with the following keys:

* "manual_labels"
* "merged_unit_groups"
* "removed_units"

Here is the description of the format with a simple example:

.. code-block:: json
{
# the first part of the format is the definition
"format_version": "1",
"unit_ids": [
"u1",
"u2",
"u3",
"u6",
"u10",
"u14",
"u20",
"u31",
"u42"
],
"label_definitions": {
"quality": {
"label_options": [
"good",
"noise",
"MUA",
"artifact"
],
"exclusive": true
},
"putative_type": {
"label_options": [
"excitatory",
"inhibitory",
"pyramidal",
"mitral"
],
"exclusive": false
}
},
# the second part of the format is the manual actions
"manual_labels": [
{
"unit_id": "u1",
"quality": [
"good"
]
},
{
"unit_id": "u2",
"quality": [
"noise"
],
"putative_type": [
"excitatory",
"pyramidal"
]
},
{
"unit_id": "u3",
"putative_type": [
"inhibitory"
]
}
],
"merged_unit_groups": [
[
"u3",
"u6"
],
[
"u10",
"u14",
"u20"
]
],
"removed_units": [
"u31",
"u42"
]
}
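
The ``#`` comments in the example above are explanatory only; a real file must be plain JSON without them. As a hedged illustration of how such a file could be consumed, here is a standalone sketch using only the standard library (the file name ``curation.json`` and the helper ``summarize_curation`` are illustrative, not part of the SpikeInterface API):

.. code-block:: python

    import json

    # load the curation file kept alongside the sorter output (illustrative path)
    with open("curation.json") as f:
        curation = json.load(f)

    def summarize_curation(curation: dict) -> None:
        # walk the two parts of the format: definition, then manual actions
        assert curation["format_version"] == "1"
        print("known units:", curation["unit_ids"])
        for group in curation.get("merged_unit_groups", []):
            print("merge", group)
        for unit_id in curation.get("removed_units", []):
            print("remove", unit_id)
        for entry in curation.get("manual_labels", []):
            labels = {k: v for k, v in entry.items() if k != "unit_id"}
            print("labels for", entry["unit_id"], labels)

    summarize_curation(curation)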
Automatic curation tools
------------------------
22 changes: 9 additions & 13 deletions doc/modules/motion_correction.rst
@@ -163,21 +163,19 @@ The high-level :py:func:`~spikeinterface.preprocessing.correct_motion()` is inte
max_distance_um=150.0, **job_kwargs)
# Step 2: motion inference
motion, temporal_bins, spatial_bins = estimate_motion(recording=rec,
peaks=peaks,
peak_locations=peak_locations,
method="decentralized",
direction="y",
bin_duration_s=2.0,
bin_um=5.0,
win_step_um=50.0,
win_sigma_um=150.0)
motion = estimate_motion(recording=rec,
peaks=peaks,
peak_locations=peak_locations,
method="decentralized",
direction="y",
bin_duration_s=2.0,
bin_um=5.0,
win_step_um=50.0,
win_sigma_um=150.0)
# Step 3: motion interpolation
# this step is lazy
rec_corrected = interpolate_motion(recording=rec, motion=motion,
temporal_bins=temporal_bins,
spatial_bins=spatial_bins,
border_mode="remove_channels",
spatial_interpolation_method="kriging",
sigma_um=30.)
@@ -220,8 +218,6 @@ different preprocessing chains: one for motion correction and one for spike sort
rec_corrected2 = interpolate_motion(
recording=rec2,
motion=motion_info['motion'],
temporal_bins=motion_info['temporal_bins'],
spatial_bins=motion_info['spatial_bins'],
**motion_info['parameters']['interpolate_motion_kwargs'])
sorting = run_sorter(sorter_name="mountainsort5", recording=rec_corrected2)
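
For comparison, the same pipeline can also be run through the one-call high-level wrapper; a minimal sketch, assuming ``rec`` is the preprocessed recording used above and that the "nonrigid_accurate" preset and ``folder`` argument are available:

.. code-block:: python

    from spikeinterface.preprocessing import correct_motion

    # the preset bundles the detection / estimation / interpolation parameters shown step by step above;
    # the folder name is illustrative
    rec_corrected = correct_motion(recording=rec, preset="nonrigid_accurate", folder="motion_folder")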
10 changes: 4 additions & 6 deletions pyproject.toml
@@ -20,7 +20,7 @@ classifiers = [


dependencies = [
"numpy",
"numpy>=1.26, <2.0", # 1.20 np.ptp, 1.26 for avoiding pickling errors when numpy >2.0
"threadpoolctl>=3.0.0",
"tqdm",
"zarr>=2.16,<2.18",
@@ -65,18 +65,16 @@ extractors = [
"pyedflib>=0.1.30",
"sonpy;python_version<'3.10'",
"lxml", # lxml for neuroscope
"scipy<1.13",
"scipy",
"ONE-api>=2.7.0", # alf sorter and streaming IBL
"ibllib>=2.32.5", # streaming IBL
"ibllib>=2.36.0", # streaming IBL
"pymatreader>=0.0.32", # For cell explorer matlab files
"zugbruecke>=0.2; sys_platform!='win32'", # For plexon2
]

streaming_extractors = [
"ONE-api>=2.7.0", # alf sorter and streaming IBL
"ibllib>=2.32.5", # streaming IBL
"scipy<1.13", # ibl has a dependency on scipy but it does not have an upper bound
# Remove this once https://github.com/int-brain-lab/ibllib/issues/753
"ibllib>=2.36.0", # streaming IBL
# Following dependencies are for streaming with nwb files
"pynwb>=2.6.0",
"fsspec",
4 changes: 2 additions & 2 deletions src/spikeinterface/__init__.py
@@ -30,5 +30,5 @@
# This flag must be set to False for release
# This avoids using versioning that contains ".dev0" (and this is a better choice)
# This is mainly useful when using run_sorter in a container and spikeinterface install
# DEV_MODE = True
DEV_MODE = False
DEV_MODE = True
# DEV_MODE = False
24 changes: 24 additions & 0 deletions src/spikeinterface/core/baserecording.py
@@ -679,6 +679,30 @@ def frame_slice(self, start_frame: int, end_frame: int) -> BaseRecording:
sub_recording = FrameSliceRecording(self, start_frame=start_frame, end_frame=end_frame)
return sub_recording

def time_slice(self, start_time: float, end_time: float) -> BaseRecording:
"""
Returns a new recording with sliced time. Note that this operation is not in place.
Parameters
----------
start_time : float
The start time in seconds.
end_time : float
The end time in seconds.
Returns
-------
BaseRecording
The object with sliced time.
"""

assert self.get_num_segments() == 1, "Time slicing is only supported for single segment recordings."

start_frame = self.time_to_sample_index(start_time)
end_frame = self.time_to_sample_index(end_time)

return self.frame_slice(start_frame=start_frame, end_frame=end_frame)

def _select_segments(self, segment_indices):
from .segmentutils import SelectSegmentRecording

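A hedged usage sketch of the new time_slice() method; the generated recording is purely illustrative, not part of the diff:

.. code-block:: python

    # slice out the data between 2 s and 5 s of a single-segment recording
    # (generate_recording is used here only to have something to slice)
    import spikeinterface.core as si

    recording = si.generate_recording(durations=[10.0], sampling_frequency=30000.0)
    sub_recording = recording.time_slice(start_time=2.0, end_time=5.0)
    print(sub_recording.get_total_duration())  # expected to be ~3.0 seconds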
7 changes: 6 additions & 1 deletion src/spikeinterface/core/core_tools.py
@@ -83,7 +83,12 @@ def default(self, obj):
if isinstance(obj, np.generic):
return obj.item()

if np.issctype(obj): # Cast numpy datatypes to their names
# Standard numpy dtypes like np.dtype("int32") are transformed this way
if isinstance(obj, np.dtype):
return np.dtype(obj).name

# This will transform to a string canonical representation of the dtype (e.g. np.int32 -> 'int32')
if isinstance(obj, type) and issubclass(obj, np.generic):
return np.dtype(obj).name

if isinstance(obj, np.ndarray):
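The two new branches can be illustrated in isolation; this is a standalone sketch of the same checks, not the encoder itself:

.. code-block:: python

    import numpy as np

    def dtype_to_name(obj):
        # same two checks as the encoder branches above
        if isinstance(obj, np.dtype):                              # np.dtype("int32") -> "int32"
            return np.dtype(obj).name
        if isinstance(obj, type) and issubclass(obj, np.generic):  # np.int32 -> "int32"
            return np.dtype(obj).name
        raise TypeError(f"{obj!r} is not a numpy dtype or scalar type")

    print(dtype_to_name(np.dtype("int32")))  # "int32"
    print(dtype_to_name(np.float64))         # "float64"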
33 changes: 32 additions & 1 deletion src/spikeinterface/core/recording_tools.py
@@ -709,7 +709,7 @@ def get_chunk_with_margin(
case zero padding is used, in the second case np.pad is called
with mode="reflect".
"""
length = rec_segment.get_num_samples()
length = int(rec_segment.get_num_samples())

if channel_indices is None:
channel_indices = slice(None)
@@ -917,3 +917,34 @@ def get_rec_attributes(recording):
dtype=recording.get_dtype(),
)
return rec_attributes


def do_recording_attributes_match(recording1, recording2_attributes) -> bool:
"""
Check if two recordings have the same attributes
Parameters
----------
recording1 : BaseRecording
The first recording object
recording2_attributes : dict
The recording attributes to test against
Returns
-------
bool
True if the recordings have the same attributes
"""
recording1_attributes = get_rec_attributes(recording1)
recording2_attributes = deepcopy(recording2_attributes)
recording1_attributes.pop("properties")
recording2_attributes.pop("properties")

return (
np.array_equal(recording1_attributes["channel_ids"], recording2_attributes["channel_ids"])
and recording1_attributes["sampling_frequency"] == recording2_attributes["sampling_frequency"]
and recording1_attributes["num_channels"] == recording2_attributes["num_channels"]
and recording1_attributes["num_samples"] == recording2_attributes["num_samples"]
and recording1_attributes["is_filtered"] == recording2_attributes["is_filtered"]
and recording1_attributes["dtype"] == recording2_attributes["dtype"]
)
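
A hedged usage sketch; the generated recordings are illustrative, and the helper is assumed to be importable from spikeinterface.core.recording_tools once this change lands:

.. code-block:: python

    import spikeinterface.core as si
    from spikeinterface.core.recording_tools import get_rec_attributes, do_recording_attributes_match

    rec_a = si.generate_recording(durations=[5.0], num_channels=4, seed=0)
    rec_b = si.generate_recording(durations=[5.0], num_channels=4, seed=1)

    saved_attributes = get_rec_attributes(rec_a)
    # same channel ids, sampling frequency, num_samples, etc. -> expected True
    print(do_recording_attributes_match(rec_b, saved_attributes))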
2 changes: 1 addition & 1 deletion src/spikeinterface/core/segmentutils.py
@@ -156,7 +156,7 @@ def __init__(self, parent_segments, sampling_frequency, ignore_times=True):
BaseRecordingSegment.__init__(self, **time_kwargs)
self.parent_segments = parent_segments
self.all_length = [rec_seg.get_num_samples() for rec_seg in self.parent_segments]
self.cumsum_length = np.cumsum([0] + self.all_length)
self.cumsum_length = [0] + [sum(self.all_length[: i + 1]) for i in range(len(self.all_length))]
self.total_length = int(np.sum(self.all_length))

def get_num_samples(self):
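The replaced line trades np.cumsum for a pure-Python prefix sum so the stored offsets stay plain ints; a quick standalone check of the new expression (the sample lengths are made up):

.. code-block:: python

    all_length = [100, 250, 50]  # hypothetical per-segment sample counts
    cumsum_length = [0] + [sum(all_length[: i + 1]) for i in range(len(all_length))]
    print(cumsum_length)           # [0, 100, 350, 400]
    print(type(cumsum_length[1]))  # <class 'int'>, no numpy integer types involved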
37 changes: 33 additions & 4 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -22,7 +22,7 @@
from .basesorting import BaseSorting

from .base import load_extractor
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes
from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match
from .core_tools import check_json, retrieve_importing_provenance
from .job_tools import split_job_kwargs
from .numpyextractors import NumpySorting
@@ -203,6 +203,8 @@ def __init__(
self.format = format
self.sparsity = sparsity
self.return_scaled = return_scaled
# this is used to store temporary recording
self._temporary_recording = None

# extensions are not loaded at init
self.extensions = dict()
@@ -605,13 +607,37 @@ def load_from_zarr(cls, folder, recording=None):

return sorting_analyzer

def set_temporary_recording(self, recording: BaseRecording):
"""
Sets a temporary recording object. This function can be useful to temporarily set
a "cached" recording object that is not saved in the SortingAnalyzer object to speed up
computations. Upon reloading, the SortingAnalyzer object will try to reload the recording
from the original location in a lazy way.
Parameters
----------
recording : BaseRecording
The recording object to set as temporary recording.
"""
# check that recording is compatible
assert do_recording_attributes_match(recording, self.rec_attributes), "Recording attributes do not match."
assert np.array_equal(
recording.get_channel_locations(), self.get_channel_locations()
), "Recording channel locations do not match."
if self._recording is not None:
warnings.warn("SortingAnalyzer recording is already set. The current recording is temporarily replaced.")
self._temporary_recording = recording

def _save_or_select(self, format="binary_folder", folder=None, unit_ids=None) -> "SortingAnalyzer":
"""
Internal used by both save_as(), copy() and select_units() which are more or less the same.
"""

if self.has_recording():
recording = self.recording
recording = self._recording
elif self.has_temporary_recording():
recording = self._temporary_recording
else:
recording = None

@@ -728,9 +754,9 @@ def is_read_only(self) -> bool:

@property
def recording(self) -> BaseRecording:
if not self.has_recording():
if not self.has_recording() and not self.has_temporary_recording():
raise ValueError("SortingAnalyzer could not load the recording")
return self._recording
return self._temporary_recording or self._recording

@property
def channel_ids(self) -> np.ndarray:
@@ -747,6 +773,9 @@ def unit_ids(self) -> np.ndarray:
def has_recording(self) -> bool:
return self._recording is not None

def has_temporary_recording(self) -> bool:
return self._temporary_recording is not None

def is_sparse(self) -> bool:
return self.sparsity is not None

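A hedged sketch of the new temporary-recording workflow; generate_ground_truth_recording and create_sorting_analyzer are assumed to be available in spikeinterface.core, and the "preprocessed" recording below is just a stand-in:

.. code-block:: python

    import spikeinterface.core as si

    recording, sorting = si.generate_ground_truth_recording(durations=[10.0], seed=0)
    analyzer = si.create_sorting_analyzer(sorting=sorting, recording=recording, format="memory")

    # attach a (e.g. preprocessed or cached) recording without persisting it in the analyzer
    preprocessed = recording  # stand-in for a lazily preprocessed version with matching attributes
    analyzer.set_temporary_recording(preprocessed)
    print(analyzer.has_temporary_recording())  # True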