diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index ba18cf09b6..ad31b97d8e 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -47,6 +47,10 @@ def __init__(self, main_ids: Sequence) -> None: # 'main_ids' will either be channel_ids or units_ids # They are used for properties self._main_ids = np.array(main_ids) + if len(self._main_ids) > 0: + assert ( + self._main_ids.dtype.kind in "uiSU" + ), f"Main IDs can only be integers (signed/unsigned) or strings, not {self._main_ids.dtype}" # dict at object level self._annotations = {} diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index affde8a75e..d411f38d2a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,4 +1,4 @@ -from typing import List +from __future__ import annotations from pathlib import Path import numpy as np @@ -19,7 +19,7 @@ class BaseRecordingSnippets(BaseExtractor): has_default_locations = False - def __init__(self, sampling_frequency: float, channel_ids: List, dtype): + def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype: np.dtype): BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = sampling_frequency self._dtype = np.dtype(dtype) diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index f35bc2b266..b4e3c11f55 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -1,10 +1,8 @@ from typing import List, Union -from pathlib import Path from .base import BaseSegment from .baserecordingsnippets import BaseRecordingSnippets import numpy as np from warnings import warn -from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes # snippets segments? diff --git a/src/spikeinterface/core/npysnippetsextractor.py b/src/spikeinterface/core/npysnippetsextractor.py index 80979ce6c9..69c48356e5 100644 --- a/src/spikeinterface/core/npysnippetsextractor.py +++ b/src/spikeinterface/core/npysnippetsextractor.py @@ -27,6 +27,9 @@ def __init__( num_segments = len(file_paths) data = np.load(file_paths[0], mmap_mode="r") + if channel_ids is None: + channel_ids = np.arange(data["snippet"].shape[2]) + BaseSnippets.__init__( self, sampling_frequency, @@ -84,7 +87,7 @@ def write_snippets(snippets, file_paths, dtype=None): arr = np.empty(n, dtype=snippets_t, order="F") arr["frame"] = snippets.get_frames(segment_index=i) arr["snippet"] = snippets.get_snippets(segment_index=i).astype(dtype, copy=False) - + file_paths[i].parent.mkdir(parents=True, exist_ok=True) np.save(file_paths[i], arr) diff --git a/src/spikeinterface/extractors/cellexplorersortingextractor.py b/src/spikeinterface/extractors/cellexplorersortingextractor.py index 31241a4147..0980e89f1c 100644 --- a/src/spikeinterface/extractors/cellexplorersortingextractor.py +++ b/src/spikeinterface/extractors/cellexplorersortingextractor.py @@ -118,7 +118,7 @@ def __init__( spike_times = spikes_data["times"] # CellExplorer reports spike times in units seconds; SpikeExtractors uses time units of sampling frames - unit_ids = unit_ids[:].tolist() + unit_ids = [str(unit_id) for unit_id in unit_ids] spiketrains_dict = {unit_id: spike_times[index] for index, unit_id in enumerate(unit_ids)} for unit_id in unit_ids: spiketrains_dict[unit_id] = (sampling_frequency * spiketrains_dict[unit_id]).round().astype(np.int64)