diff --git a/doc/modules/core.rst b/doc/modules/core.rst
index fdc4d71fe7..976a82a4a3 100644
--- a/doc/modules/core.rst
+++ b/doc/modules/core.rst
@@ -547,8 +547,7 @@ workflow.
 In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`,
 :py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and
 :py:class:`~spikeinterface.core.NumpySnippets`. These objects behave exactly like normal SpikeInterface objects,
-but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported.
-In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`).
+but they are not bound to a file.

 Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting`, which is very similar to
 :py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory, which is useful for
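The doc lines removed above claimed that in-memory objects cannot be processed in parallel; with the per-format flags introduced below, `Numpy*` objects generally remain memory serializable even when they cannot be dumped to json. A minimal sketch of the in-memory workflow (assuming this patch is applied; the folder path is hypothetical, and the `NumpyRecording` signature is taken from the numpyextractors.py hunk further down):

```python
import numpy as np
from spikeinterface.core import NumpyRecording

# 1 s of noise at 30 kHz on 4 channels, wrapped as an in-memory recording
traces = np.random.randn(30000, 4).astype("float32")
recording = NumpyRecording(traces_list=[traces], sampling_frequency=30000.0)

# in-memory recordings are bound to no file, so they cannot be dumped to json
print(recording.check_serializablility("json"))  # False with this patch

# saving to a folder returns a file-backed, fully serializable copy
recording = recording.save(folder="my_recording_folder")
```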
diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py
index af410255b9..e0c98cd772 100644
--- a/src/spikeinterface/comparison/hybrid.py
+++ b/src/spikeinterface/comparison/hybrid.py
@@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording):
         The refractory period of the injected spike train (in ms).
     injected_sorting_folder: str | Path | None
         If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.

     Returns
     -------
@@ -84,7 +84,8 @@ def __init__(
         )
         # save injected sorting if necessary
         self.injected_sorting = injected_sorting
-        if not self.injected_sorting.check_if_json_serializable():
+        if not self.injected_sorting.check_serializablility("json"):
+            # TODO later: also use pickle
             assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
             self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

@@ -137,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording):
         this refractory period.
     injected_sorting_folder: str | Path | None
         If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.

     Returns
     -------
@@ -180,7 +181,8 @@ def __init__(
         self.injected_sorting = injected_sorting

         # save injected sorting if necessary
-        if not self.injected_sorting.check_if_json_serializable():
+        if not self.injected_sorting.check_serializablility("json"):
+            # TODO later: also use pickle
             assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
             self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py
index d418b92ab8..f44e14c4c4 100644
--- a/src/spikeinterface/comparison/multicomparisons.py
+++ b/src/spikeinterface/comparison/multicomparisons.py
@@ -189,8 +189,8 @@ def save_to_folder(self, save_folder):
             stacklevel=2,
         )
         for sorting in self.object_list:
-            assert (
-                sorting.check_if_json_serializable()
+            assert sorting.check_serializablility(
+                "json"
             ), "MultiSortingComparison.save_to_folder() needs json serializable sortings"

         save_folder = Path(save_folder)
@@ -259,7 +259,8 @@ def __init__(

         BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

-        self._is_json_serializable = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = True

         if len(unit_ids) > 0:
             for k in ("agreement_number", "avg_agreement", "unit_ids"):
""" - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_dumpable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_dumpable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_dumpable() for k, v in value.items()]) - return self._is_dumpable + return self.check_serializablility("memory") def check_if_json_serializable(self): """ @@ -499,16 +507,13 @@ def check_if_json_serializable(self): bool True if the object is json serializable, False otherwise. """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_json_serializable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_json_serializable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_json_serializable() for k, v in value.items()]) - return self._is_json_serializable + # we keep this for backward compatilibity or not ???? + # is this needed ??? I think no. + return self.check_serializablility("json") + + def check_if_pickle_serializable(self): + # is this needed ??? I think no. + return self.check_serializablility("pickle") @staticmethod def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path: @@ -557,7 +562,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata) + self.dump_to_pickle(file_path, folder_metadata=folder_metadata) else: raise ValueError("Dump: file must .json or .pkl") @@ -576,7 +581,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_json_serializable(), "The extractor is not json serializable" + assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict if relative_to: @@ -616,7 +621,7 @@ def dump_to_pickle( folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. 
""" - assert self.check_if_dumpable(), "The extractor is not dumpable" + assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle" dump_dict = self.to_dict( include_annotations=True, @@ -653,8 +658,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") - if "warning" in d and "not dumpable" in d["warning"]: - print("The extractor was not dumpable") + if "warning" in d: + print("The extractor was not serializable to file") return None extractor = BaseExtractor.from_dict(d, base_folder=base_folder) return extractor @@ -814,10 +819,12 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - if self.check_if_json_serializable(): + if self.check_serializablility("json"): self.dump(provenance_file) else: - provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8") + provenance_file.write_text( + json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8" + ) self.save_metadata_to_folder(folder) @@ -911,7 +918,7 @@ def save_to_zarr( zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options) - if self.check_if_dumpable(): + if self.check_if_json_serializable(): zarr_root.attrs["provenance"] = check_json(self.to_dict()) else: zarr_root.attrs["provenance"] = None diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 07837bcef7..eeb1e8af60 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1056,6 +1056,8 @@ def __init__( dtype = parent_recording.dtype if parent_recording is not None else templates.dtype BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) + # Important : self._serializablility is not change here because it will depend on the sorting parents itself. + n_units = len(sorting.unit_ids) assert len(templates) == n_units self.spike_vector = sorting.to_spike_vector() @@ -1431,5 +1433,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) + recording.set_channel_gains(1.0) + recording.set_channel_offsets(0.0) return recording, sorting diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c0ee77d2fd..84ee502c14 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -167,11 +167,11 @@ def ensure_n_jobs(recording, n_jobs=1): print(f"Python {sys.version} does not support parallel processing") n_jobs = 1 - if not recording.check_if_dumpable(): + if not recording.check_if_memory_serializable(): if n_jobs != 1: raise RuntimeError( - "Recording is not dumpable and can't be processed in parallel. " - "You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1." + "Recording is not serializable to memory and can't be processed in parallel. " + "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1." 
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index d5663156c7..3d7ec6cd1a 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -64,7 +64,8 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=None):
             assert len(t_starts) == len(traces_list), "t_starts must be a list of the same size as traces_list"
             t_starts = [float(t_start) for t_start in t_starts]

-        self._is_json_serializable = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         for i, traces in enumerate(traces_list):
             if t_starts is None:
@@ -126,8 +127,10 @@ def __init__(self, spikes, sampling_frequency, unit_ids):
         """ """
        BaseSorting.__init__(self, sampling_frequency, unit_ids)

-        self._is_dumpable = True
-        self._is_json_serializable = False
+        self._serializablility["memory"] = True
+        self._serializablility["json"] = False
+        # theoretically this should be False, but we keep it True to keep the generators simple
+        self._serializablility["pickle"] = True

         if spikes.size == 0:
             nseg = 1
@@ -357,8 +360,10 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_spike_dtype):
         assert shape[0] > 0, "SharedMemorySorting only supported with non-empty sorting"
         BaseSorting.__init__(self, sampling_frequency, unit_ids)

-        self._is_dumpable = True
-        self._is_json_serializable = False
+
+        self._serializablility["memory"] = True
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         self.shm = SharedMemory(shm_name, create=False)
         self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf)
@@ -516,8 +521,9 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore
             dtype=dtype,
         )

-        self._is_dumpable = False
-        self._is_json_serializable = False
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         for snippets, spikesframes in zip(snippets_list, spikesframes_list):
             snp_segment = NumpySnippetsSegment(snippets, spikesframes)
diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py
index 1ff31127f4..879700cc15 100644
--- a/src/spikeinterface/core/old_api_utils.py
+++ b/src/spikeinterface/core/old_api_utils.py
@@ -181,9 +181,10 @@ def __init__(self, oldapi_recording_extractor):
             dtype=oldapi_recording_extractor.get_dtype(return_scaled=False),
         )

-        # set _is_dumpable to False to use dumping mechanism of old extractor
-        self._is_dumpable = False
-        self._is_json_serializable = False
+        # set to False to use the dumping mechanism of the old extractor
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         self.annotate(is_filtered=oldapi_recording_extractor.is_filtered)

@@ -268,8 +269,9 @@ def __init__(self, oldapi_sorting_extractor):
         sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor)
         self.add_sorting_segment(sorting_segment)

-        self._is_dumpable = False
-        self._is_json_serializable = False
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         # add old properties
         copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self)
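Taken together, the hunks above define a small flag matrix: `NumpySorting` stays pickle serializable (so it can still cross process boundaries), while `NumpyRecording`, `SharedMemorySorting`, `NumpySnippets`, and the old-API wrappers are either memory-only or not serializable at all. A quick check of the `NumpySorting` row (a sketch, assuming this patch; the expected values are read off the hunks above):

```python
from spikeinterface.core import NumpySorting, generate_sorting

# wrap a generated sorting as an in-memory NumpySorting
sorting = NumpySorting.from_sorting(generate_sorting(num_units=3, durations=[2.0]))

print(sorting.check_if_memory_serializable())   # True  -> can be sent to workers
print(sorting.check_serializablility("json"))   # False -> no json dump
print(sorting.check_if_pickle_serializable())   # True  -> pickle dump still works
```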
diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py
index ea1a9cf0d2..a944be3da0 100644
--- a/src/spikeinterface/core/tests/test_base.py
+++ b/src/spikeinterface/core/tests/test_base.py
@@ -31,39 +31,39 @@ def make_nested_extractors(extractor):
     )


-def test_check_if_dumpable():
+def test_check_if_memory_serializable():
     test_extractor = generate_recording(seed=0, durations=[2])

-    # make a list of dumpable objects
-    extractors_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_dumpable:
-        assert extractor.check_if_dumpable()
+    # make a list of memory serializable objects
+    extractors_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_mem_serializable:
+        assert extractor.check_if_memory_serializable()

-    # make not dumpable
-    test_extractor._is_dumpable = False
-    extractors_not_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_not_dumpable:
-        assert not extractor.check_if_dumpable()
+    # make a list of non memory serializable objects
+    test_extractor._serializablility["memory"] = False
+    extractors_not_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_not_mem_serializable:
+        assert not extractor.check_if_memory_serializable()


-def test_check_if_json_serializable():
+def test_check_if_serializable():
     test_extractor = generate_recording(seed=0, durations=[2])

-    # make a list of dumpable objects
-    test_extractor._is_json_serializable = True
+    # make a list of json serializable objects
+    test_extractor._serializablility["json"] = True
     extractors_json_serializable = make_nested_extractors(test_extractor)
     for extractor in extractors_json_serializable:
         print(extractor)
-        assert extractor.check_if_json_serializable()
+        assert extractor.check_serializablility("json")

-    # make not dumpable
-    test_extractor._is_json_serializable = False
+    # make a list of non json serializable objects
+    test_extractor._serializablility["json"] = False
     extractors_not_json_serializable = make_nested_extractors(test_extractor)
     for extractor in extractors_not_json_serializable:
         print(extractor)
-        assert not extractor.check_if_json_serializable()
+        assert not extractor.check_serializablility("json")


 if __name__ == "__main__":
-    test_check_if_dumpable()
-    test_check_if_json_serializable()
+    test_check_if_memory_serializable()
+    test_check_if_serializable()
diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py
index a3cd0caa92..223b2a8a3a 100644
--- a/src/spikeinterface/core/tests/test_core_tools.py
+++ b/src/spikeinterface/core/tests/test_core_tools.py
@@ -142,7 +142,6 @@ def test_write_memory_recording():
     recording = NoiseGeneratorRecording(
         num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
     )
-    # make dumpable
     recording = recording.save()

     # write with loop
diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py
index 7d7af6025b..a904e4dd32 100644
--- a/src/spikeinterface/core/tests/test_job_tools.py
+++ b/src/spikeinterface/core/tests/test_job_tools.py
@@ -36,7 +36,7 @@ def test_ensure_n_jobs():
     n_jobs = ensure_n_jobs(recording, n_jobs=1)
     assert n_jobs == 1

-    # dumpable
+    # check serializable
     n_jobs = ensure_n_jobs(recording.save(), n_jobs=-1)
     assert n_jobs > 1

@@ -45,7 +45,7 @@ def test_ensure_chunk_size():
     recording = generate_recording(num_channels=2)
     dtype = recording.get_dtype()
     assert dtype == "float32"
-    # make dumpable
+    # make serializable
     recording = recording.save()

     chunk_size = ensure_chunk_size(recording, total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2)
@@ -90,7 +90,7 @@ def init_func(arg1, arg2, arg3):

 def test_ChunkRecordingExecutor():
     recording = generate_recording(num_channels=2)
-    # make dumpable
+    # make serializable
     recording = recording.save()

     init_args = "a", 120, "yep"
diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py
index 473648c5ec..1c491bd7a6 100644
--- a/src/spikeinterface/core/tests/test_jsonification.py
+++ b/src/spikeinterface/core/tests/test_jsonification.py
@@ -142,9 +142,11 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extractor_dict=None):
         self.extractor_list = extractor_list
         self.extractor_dict = extractor_dict

+        BaseExtractor.__init__(self, main_ids=["1", "2"])
         # this is already the case by default
-        self._is_dumpable = True
-        self._is_json_serializable = True
+        self._serializablility["memory"] = True
+        self._serializablility["json"] = True
+        self._serializablility["pickle"] = True

         self._kwargs = {
             "attribute": attribute,
@@ -195,3 +197,8 @@ def test_encoding_numpy_scalars_within_nested_extractors_list(nested_extractor_list):

 def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_dict):
     json.dumps(nested_extractor_dict, cls=SIJsonEncoder)
+
+
+if __name__ == "__main__":
+    nested_extractor_instance = nested_extractor()
+    test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_instance)
diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py
index 107ef5f180..2bbf5e9b0f 100644
--- a/src/spikeinterface/core/tests/test_waveform_extractor.py
+++ b/src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -6,7 +6,13 @@
 import zarr

-from spikeinterface.core import generate_recording, generate_sorting, NumpySorting, ChannelSparsity
+from spikeinterface.core import (
+    generate_recording,
+    generate_sorting,
+    NumpySorting,
+    ChannelSparsity,
+    generate_ground_truth_recording,
+)
 from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms
 from spikeinterface.core.waveform_extractor import precompute_sparsity

@@ -309,7 +315,7 @@ def test_recordingless():
     recording = recording.save(folder=cache_folder / "recording1")
     sorting = sorting.save(folder=cache_folder / "sorting1")

-    # recording and sorting are not dumpable
+    # recording and sorting are not serializable
     wf_folder = cache_folder / "wf_recordingless"

     # save with relative paths
@@ -510,10 +516,44 @@ def test_compute_sparsity():
     print(sparsity)


+def test_non_json_object():
+    recording, sorting = generate_ground_truth_recording(
+        durations=[30, 40],
+        sampling_frequency=30000.0,
+        num_channels=32,
+        num_units=5,
+    )
+
+    # the recording is not saved, so it stays in memory
+    sorting = sorting.save()
+
+    wf_folder = cache_folder / "test_waveform_extractor"
+    if wf_folder.is_dir():
+        shutil.rmtree(wf_folder)
+
+    we = extract_waveforms(
+        recording,
+        sorting,
+        wf_folder,
+        mode="folder",
+        sparsity=None,
+        sparse=False,
+        ms_before=1.0,
+        ms_after=1.6,
+        max_spikes_per_unit=50,
+        n_jobs=4,
+        chunk_size=30000,
+        progress_bar=True,
+    )
+
+    # this used to fail because of json
+    we = load_waveforms(wf_folder)
+
+
 if __name__ == "__main__":
-    test_WaveformExtractor()
+    # test_WaveformExtractor()
     # test_extract_waveforms()
-    # test_sparsity()
     # test_portability()
     # test_recordingless()
     # test_compute_sparsity()
+    test_non_json_object()
diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py
index 6881ab3ec5..2710ff1338 100644
--- a/src/spikeinterface/core/waveform_extractor.py
+++ b/src/spikeinterface/core/waveform_extractor.py
@@ -159,11 +159,20 @@ def load_from_folder(
             else:
                 rec_attributes["probegroup"] = None
         else:
-            try:
-                recording = load_extractor(folder / "recording.json", base_folder=folder)
-                rec_attributes = None
-            except:
+            recording = None
+            if (folder / "recording.json").exists():
+                try:
+                    recording = load_extractor(folder / "recording.json", base_folder=folder)
+                except:
+                    pass
+            elif (folder / "recording.pickle").exists():
+                try:
+                    recording = load_extractor(folder / "recording.pickle")
+                except:
+                    pass
+            if recording is None:
                 raise Exception("The recording could not be loaded. You can use the `with_recording=False` argument")
+            rec_attributes = None

         if sorting is None:
             sorting = load_extractor(folder / "sorting.json", base_folder=folder)
@@ -271,14 +280,22 @@ def create(
         else:
             relative_to = None

-        if recording.check_if_json_serializable():
+        if recording.check_serializablility("json"):
             recording.dump(folder / "recording.json", relative_to=relative_to)
-        if sorting.check_if_json_serializable():
+        elif recording.check_serializablility("pickle"):
+            # in this case we lose the relative_to!
+            recording.dump(folder / "recording.pickle")
+
+        if sorting.check_serializablility("json"):
             sorting.dump(folder / "sorting.json", relative_to=relative_to)
+        elif sorting.check_serializablility("pickle"):
+            # in this case we lose the relative_to!
+            # TODO later: dump the dictionary to pickle so that relative_to can be restored
+            sorting.dump(folder / "sorting.pickle")
         else:
             warn(
-                "Sorting object is not dumpable, which might result in downstream errors for "
-                "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                "Sorting object is not serializable to file, which might result in downstream errors for "
+                "parallel processing. To make the sorting serializable, use `sorting = sorting.save()`."
             )

         # dump some attributes of the recording for the mode with_recording=False at next load
@@ -879,14 +896,19 @@ def save(
             (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8")

             if self.has_recording():
-                if self.recording.check_if_json_serializable():
+                if self.recording.check_serializablility("json"):
                     self.recording.dump(folder / "recording.json", relative_to=relative_to)
-            if self.sorting.check_if_json_serializable():
+                elif self.recording.check_serializablility("pickle"):
+                    self.recording.dump(folder / "recording.pickle")
+
+            if self.sorting.check_serializablility("json"):
                 self.sorting.dump(folder / "sorting.json", relative_to=relative_to)
+            elif self.sorting.check_serializablility("pickle"):
+                self.sorting.dump(folder / "sorting.pickle")
             else:
                 warn(
-                    "Sorting object is not dumpable, which might result in downstream errors for "
-                    "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                    "Sorting object is not serializable to file, which might result in downstream errors for "
+                    "parallel processing. To make the sorting serializable, use `sorting = sorting.save()`."
                 )

             # dump some attributes of the recording for the mode with_recording=False at next load
@@ -931,16 +953,16 @@ def save(
             # write metadata
             zarr_root.attrs["params"] = check_json(self._params)
             if self.has_recording():
-                if self.recording.check_if_json_serializable():
+                if self.recording.check_serializablility("json"):
                     rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True)
                     zarr_root.attrs["recording"] = check_json(rec_dict)
-            if self.sorting.check_if_json_serializable():
+            if self.sorting.check_serializablility("json"):
                 sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True)
                 zarr_root.attrs["sorting"] = check_json(sort_dict)
             else:
                 warn(
-                    "Sorting object is not dumpable, which might result in downstream errors for "
-                    "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                    "Sorting object is not json serializable, which might result in downstream errors for "
+                    "parallel processing. To make the sorting serializable, use `sorting = sorting.save()`."
                 )
             recording_info = zarr_root.create_group("recording_info")
             recording_info.attrs["recording_attributes"] = check_json(rec_attributes)
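Both `create()` and `save()` above repeat the same dispatch: json when possible, pickle as a fallback (losing `relative_to`), and a warning otherwise. The pattern, extracted as a standalone sketch (the helper name is hypothetical, not part of the patch):

```python
from pathlib import Path

def dump_with_fallback(extractor, folder: Path, name: str, relative_to=None):
    # prefer json, which supports relative paths
    if extractor.check_serializablility("json"):
        extractor.dump(folder / f"{name}.json", relative_to=relative_to)
    # otherwise fall back to pickle, which does not support relative_to
    elif extractor.check_serializablility("pickle"):
        extractor.dump(folder / f"{name}.pickle")
    # otherwise the object cannot be dumped at all; the callers warn instead
```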
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index 38cb714d59..ccd2121174 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -73,12 +73,6 @@ def _run(self, **job_kwargs):
         func = _spike_amplitudes_chunk
         init_func = _init_worker_spike_amplitudes
         n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None))
-        if n_jobs != 1:
-            # TODO: avoid dumping sorting and use spike vector and peak pipeline instead
-            assert sorting.check_if_dumpable(), (
-                "The sorting object is not dumpable and cannot be processed in parallel. You can use the "
-                "`sorting.save()` function to make it dumpable"
-            )
         init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled)
         processor = ChunkRecordingExecutor(
             recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs
diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py
index e2ef6e6794..6ab1a9afce 100644
--- a/src/spikeinterface/preprocessing/motion.py
+++ b/src/spikeinterface/preprocessing/motion.py
@@ -333,7 +333,7 @@ def correct_motion(
         )
         (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8")
         (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8")
-        if recording.check_if_json_serializable():
+        if recording.check_serializablility("json"):
             recording.dump_to_json(folder / "recording.json")

         np.save(folder / "peaks.npy", peaks)
diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py
index c7581ba1e1..8d87558191 100644
--- a/src/spikeinterface/sorters/basesorter.py
+++ b/src/spikeinterface/sorters/basesorter.py
@@ -137,8 +137,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_folder):
             )

         rec_file = output_folder / "spikeinterface_recording.json"
-        if recording.check_if_json_serializable():
-            recording.dump_to_json(rec_file, relative_to=output_folder)
+        if recording.check_serializablility("json"):
+            recording.dump(rec_file, relative_to=output_folder)
+        elif recording.check_serializablility("pickle"):
+            recording.dump(output_folder / "spikeinterface_recording.pickle")
         else:
             d = {"warning": "The recording is not serializable to json"}
             rec_file.write_text(json.dumps(d, indent=4), encoding="utf8")

@@ -185,6 +187,26 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose):

         return params

+    @classmethod
+    def load_recording_from_folder(cls, output_folder, with_warnings=False):
+        json_file = output_folder / "spikeinterface_recording.json"
+        pickle_file = output_folder / "spikeinterface_recording.pickle"
+
+        recording = None
+        if json_file.exists():
+            with json_file.open("r", encoding="utf8") as f:
+                recording_dict = json.load(f)
+            if "warning" in recording_dict.keys():
+                if with_warnings:
+                    warnings.warn(
+                        "The recording that has been sorted is not JSON serializable: "
+                        "it cannot be registered to the sorting object."
+                    )
+            else:
+                recording = load_extractor(json_file, base_folder=output_folder)
+        elif pickle_file.exists():
+            recording = load_extractor(pickle_file)
+
+        return recording
+
     @classmethod
     def _dump_params(cls, recording, output_folder, sorter_params, verbose):
         with (output_folder / "spikeinterface_params.json").open(mode="w", encoding="utf8") as f:
@@ -271,7 +293,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose):
         return run_time

     @classmethod
-    def get_result_from_folder(cls, output_folder):
+    def get_result_from_folder(cls, output_folder, register_recording=True, sorting_info=True):
         output_folder = Path(output_folder)
         sorter_output_folder = output_folder / "sorter_output"
         # check errors in log file
@@ -294,27 +316,21 @@ def get_result_from_folder(cls, output_folder):
         # back-compatibility
         sorting = cls._get_result_from_folder(output_folder)

-        # register recording to Sorting object
-        # check if not json serializable
-        with (output_folder / "spikeinterface_recording.json").open("r", encoding="utf8") as f:
-            recording_dict = json.load(f)
-        if "warning" in recording_dict.keys():
-            warnings.warn(
-                "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object."
-            )
-        else:
-            recording = load_extractor(output_folder / "spikeinterface_recording.json", base_folder=output_folder)
+        if register_recording:
+            # register recording to Sorting object
+            recording = cls.load_recording_from_folder(output_folder, with_warnings=False)
             if recording is not None:
-                # can be None when not dumpable
                 sorting.register_recording(recording)
-        # set sorting info to Sorting object
-        with open(output_folder / "spikeinterface_recording.json", "r") as f:
-            rec_dict = json.load(f)
-        with open(output_folder / "spikeinterface_params.json", "r") as f:
-            params_dict = json.load(f)
-        with open(output_folder / "spikeinterface_log.json", "r") as f:
-            log_dict = json.load(f)
-        sorting.set_sorting_info(rec_dict, params_dict, log_dict)
+
+        if sorting_info:
+            # set sorting info to Sorting object
+            with open(output_folder / "spikeinterface_recording.json", "r") as f:
+                rec_dict = json.load(f)
+            with open(output_folder / "spikeinterface_params.json", "r") as f:
+                params_dict = json.load(f)
+            with open(output_folder / "spikeinterface_log.json", "r") as f:
+                log_dict = json.load(f)
+            sorting.set_sorting_info(rec_dict, params_dict, log_dict)

         return sorting
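A usage sketch for the new helper and the extended `get_result_from_folder` signature (assuming this patch; the sorter name and folder are hypothetical):

```python
from pathlib import Path
from spikeinterface.sorters import sorter_dict

SorterClass = sorter_dict["tridesclous2"]
output_folder = Path("my_sorting_output")

# loads spikeinterface_recording.json when present, otherwise the .pickle
recording = SorterClass.load_recording_from_folder(output_folder, with_warnings=True)

# skip recording registration and sorting info when they are not needed
sorting = SorterClass.get_result_from_folder(
    output_folder, register_recording=False, sorting_info=False
)
```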
diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py
index a8d702ebe9..5180e6f1cc 100644
--- a/src/spikeinterface/sorters/external/herdingspikes.py
+++ b/src/spikeinterface/sorters/external/herdingspikes.py
@@ -147,9 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         else:
             new_api = False

-        recording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

         p = params

diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py
index 69f97fd11c..f6f0b3eaeb 100644
--- a/src/spikeinterface/sorters/external/mountainsort4.py
+++ b/src/spikeinterface/sorters/external/mountainsort4.py
@@ -89,9 +89,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
         import mountainsort4

-        recording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

         # alias to params
         p = params
diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py
index df6d276bf5..a88c59d688 100644
--- a/src/spikeinterface/sorters/external/mountainsort5.py
+++ b/src/spikeinterface/sorters/external/mountainsort5.py
@@ -115,9 +115,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
         import mountainsort5 as ms5

-        recording: BaseRecording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

         # alias to params
         p = params
diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py
index 2a41d793d5..1962d56206 100644
--- a/src/spikeinterface/sorters/external/pykilosort.py
+++ b/src/spikeinterface/sorters/external/pykilosort.py
@@ -148,9 +148,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose):

     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
-        recording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

         if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
             # saved by setup recording
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index db3d88f116..66c008c3fc 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -52,9 +52,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         job_kwargs["verbose"] = verbose
         job_kwargs["progress_bar"] = verbose

-        recording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
+
         sampling_rate = recording.get_sampling_frequency()
         num_channels = recording.get_num_channels()

diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index 42f51d3a77..ed327e0f3c 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -49,9 +49,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

         import hdbscan

-        recording_raw = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

         num_chans = recording_raw.get_num_channels()
         sampling_frequency = recording_raw.get_sampling_frequency()
diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index 6e6ccc0358..bd5667b15f 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -624,10 +624,20 @@ def run_sorter_container(
     )


-def read_sorter_folder(output_folder, raise_error=True):
+def read_sorter_folder(output_folder, register_recording=True, sorting_info=True, raise_error=True):
     """
     Load a sorting object from a spike sorting output folder.
     The 'output_folder' must contain a valid 'spikeinterface_log.json' file.
+
+    Parameters
+    ----------
+    output_folder: Path or str
+        The sorter folder.
+    register_recording: bool, default: True
+        Attach the recording (when json or pickle serializable) to the sorting.
+    sorting_info: bool, default: True
+        Attach the sorting info to the sorting.
     """
     output_folder = Path(output_folder)
     log_file = output_folder / "spikeinterface_log.json"
@@ -647,7 +657,9 @@ def read_sorter_folder(output_folder, raise_error=True):
         sorter_name = log["sorter_name"]
         SorterClass = sorter_dict[sorter_name]
-        sorting = SorterClass.get_result_from_folder(output_folder)
+        sorting = SorterClass.get_result_from_folder(
+            output_folder, register_recording=register_recording, sorting_info=sorting_info
+        )
         return sorting
diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py
index 14c938f8ba..a5e29c8fd9 100644
--- a/src/spikeinterface/sorters/tests/test_launcher.py
+++ b/src/spikeinterface/sorters/tests/test_launcher.py
@@ -178,7 +178,7 @@ def test_run_sorters_with_list():
     if working_folder.is_dir():
         shutil.rmtree(working_folder)

-    # make dumpable
+    # make serializable
     rec0 = load_extractor(cache_folder / "toy_rec_0")
     rec1 = load_extractor(cache_folder / "toy_rec_1")
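A usage sketch for the extended loader (assuming this patch; the folder path is hypothetical):

```python
from spikeinterface.sorters import read_sorter_folder

# default: attach the recording (json or pickle) and the sorting info
sorting = read_sorter_folder("my_sorting_output")

# lightweight load: the sorting alone, without recording registration or info
sorting = read_sorter_folder("my_sorting_output", register_recording=False, sorting_info=False)
```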