From 9a97e68f848d1126126bfecd819f456e12113813 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 20 Sep 2023 10:52:05 +0200
Subject: [PATCH 01/10] Improve the concept of check_if_json_serializable to support more serialization engines, like pickle.

---
 src/spikeinterface/comparison/hybrid.py       |  6 ++-
 .../comparison/multicomparisons.py            |  7 ++-
 src/spikeinterface/core/base.py               | 50 +++++++++++++------
 src/spikeinterface/core/generate.py           |  2 +
 src/spikeinterface/core/numpyextractors.py    | 16 ++++--
 src/spikeinterface/core/old_api_utils.py      |  8 ++-
 src/spikeinterface/core/tests/test_base.py    | 12 +++--
 .../core/tests/test_jsonification.py          | 10 +++-
 .../core/tests/test_waveform_extractor.py     | 41 +++++++++++++--
 src/spikeinterface/core/waveform_extractor.py | 34 ++++++++++---
 src/spikeinterface/preprocessing/motion.py    |  3 +-
 src/spikeinterface/sorters/basesorter.py      |  3 +-
 12 files changed, 150 insertions(+), 42 deletions(-)

diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py
index af410255b9..c48ce70147 100644
--- a/src/spikeinterface/comparison/hybrid.py
+++ b/src/spikeinterface/comparison/hybrid.py
@@ -84,7 +84,8 @@ def __init__(
         )
         # save injected sorting if necessary
         self.injected_sorting = injected_sorting
-        if not self.injected_sorting.check_if_json_serializable():
+        # if not self.injected_sorting.check_if_json_serializable():
+        if not self.injected_sorting.check_serializablility("json"):
             assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
             self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)
 
@@ -180,7 +181,8 @@ def __init__(
         self.injected_sorting = injected_sorting
 
         # save injected sorting if necessary
-        if not self.injected_sorting.check_if_json_serializable():
+        # if not self.injected_sorting.check_if_json_serializable():
+        if not self.injected_sorting.check_serializablility("json"):
             assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
             self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)
 
diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py
index 9e02fd5b2d..3a7075905e 100644
--- a/src/spikeinterface/comparison/multicomparisons.py
+++ b/src/spikeinterface/comparison/multicomparisons.py
@@ -182,7 +182,8 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou
     def save_to_folder(self, save_folder):
         for sorting in self.object_list:
             assert (
-                sorting.check_if_json_serializable()
+                # sorting.check_if_json_serializable()
+                sorting.check_serializablility("json")
             ), "MultiSortingComparison.save_to_folder() need json serializable sortings"
 
         save_folder = Path(save_folder)
@@ -244,7 +245,9 @@ def __init__(
 
         BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)
 
-        self._is_json_serializable = False
+        # self._is_json_serializable = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = True
 
         if len(unit_ids) > 0:
             for k in ("agreement_number", "avg_agreement", "unit_ids"):
diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 87c0805630..d87bd617c4 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -58,7 +58,8 @@ def __init__(self, main_ids: Sequence) -> None:
         self._properties = {}
 
         self._is_dumpable = True
-        self._is_json_serializable = True
+        # self._is_json_serializable = True
+        self._serializablility
= {'json': True, 'pickle': True} # extractor specific list of pip extra requirements self.extra_requirements = [] @@ -490,6 +491,18 @@ def check_if_dumpable(self): return all([v.check_if_dumpable() for k, v in value.items()]) return self._is_dumpable + def check_serializablility(self, type="json"): + kwargs = self._kwargs + for value in kwargs.values(): + # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors + if isinstance(value, BaseExtractor): + return value.check_serializablility(type=type) + elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): + return all([v.check_serializablility(type=type) for v in value]) + elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): + return all([v.check_serializablility(type=type) for k, v in value.items()]) + return self._serializablility[type] + def check_if_json_serializable(self): """ Check if the object is json serializable, including nested objects. @@ -499,16 +512,23 @@ def check_if_json_serializable(self): bool True if the object is json serializable, False otherwise. """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_json_serializable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_json_serializable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_json_serializable() for k, v in value.items()]) - return self._is_json_serializable + # we keep this for backward compatilibity or not ???? + return self.check_serializablility("json") + + # kwargs = self._kwargs + # for value in kwargs.values(): + # # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors + # if isinstance(value, BaseExtractor): + # return value.check_if_json_serializable() + # elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): + # return all([v.check_if_json_serializable() for v in value]) + # elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): + # return all([v.check_if_json_serializable() for k, v in value.items()]) + # return self._is_json_serializable + + def check_if_pickle_serializable(self): + # is this needed + return self.check_serializablility("pickle") @staticmethod def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path: @@ -557,7 +577,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata) + self.dump_to_pickle(file_path, folder_metadata=folder_metadata) else: raise ValueError("Dump: file must .json or .pkl") @@ -576,7 +596,8 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. 
""" - assert self.check_if_json_serializable(), "The extractor is not json serializable" + # assert self.check_if_json_serializable(), "The extractor is not json serializable" + assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict if relative_to: @@ -814,7 +835,8 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - if self.check_if_json_serializable(): + # if self.check_if_json_serializable(): + if self.check_serializablility("json"): self.dump(provenance_file) else: provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8") diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 07837bcef7..706054c957 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1431,5 +1431,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) + recording.set_property("gain_to_uV", np.ones(num_channels)) + recording.set_property("offset_to_uV", np.zeros(num_channels)) return recording, sorting diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index d5663156c7..f55b975ddb 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -64,7 +64,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" t_starts = [float(t_start) for t_start in t_starts] - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False for i, traces in enumerate(traces_list): if t_starts is None: @@ -127,7 +129,9 @@ def __init__(self, spikes, sampling_frequency, unit_ids): BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False if spikes.size == 0: nseg = 1 @@ -358,7 +362,9 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False self.shm = SharedMemory(shm_name, create=False) self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf) @@ -517,7 +523,9 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore ) self._is_dumpable = False - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False for snippets, spikesframes in zip(snippets_list, spikesframes_list): snp_segment = NumpySnippetsSegment(snippets, spikesframes) diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index 1ff31127f4..38fbef1547 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -183,7 +183,9 @@ def __init__(self, oldapi_recording_extractor): # set _is_dumpable to False to use dumping 
mechanism of old extractor self._is_dumpable = False - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False self.annotate(is_filtered=oldapi_recording_extractor.is_filtered) @@ -269,7 +271,9 @@ def __init__(self, oldapi_sorting_extractor): self.add_sorting_segment(sorting_segment) self._is_dumpable = False - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False # add old properties copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self) diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index ea1a9cf0d2..77a5d7d9bf 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -50,18 +50,22 @@ def test_check_if_json_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) # make a list of dumpable objects - test_extractor._is_json_serializable = True + # test_extractor._is_json_serializable = True + test_extractor._serializablility["json"] = True extractors_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_json_serializable: print(extractor) - assert extractor.check_if_json_serializable() + # assert extractor.check_if_json_serializable() + assert extractor.check_serializablility("json") # make not dumpable - test_extractor._is_json_serializable = False + # test_extractor._is_json_serializable = False + test_extractor._serializablility["json"] = False extractors_not_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_not_json_serializable: print(extractor) - assert not extractor.check_if_json_serializable() + # assert not extractor.check_if_json_serializable() + assert not extractor.check_serializablility("json") if __name__ == "__main__": diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index 473648c5ec..8572cda23e 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -142,9 +142,12 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract self.extractor_list = extractor_list self.extractor_dict = extractor_dict + BaseExtractor.__init__(self, main_ids=['1', '2']) # this already the case by default self._is_dumpable = True - self._is_json_serializable = True + # self._is_json_serializable = True + self._serializablility["json"] = True + self._serializablility["pickle"] = True self._kwargs = { "attribute": attribute, @@ -195,3 +198,8 @@ def test_encoding_numpy_scalars_within_nested_extractors_list(nested_extractor_l def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_dict): json.dumps(nested_extractor_dict, cls=SIJsonEncoder) + + +if __name__ == '__main__': + nested_extractor = nested_extractor() + test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_) \ No newline at end of file diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 107ef5f180..f53b9cf18d 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -6,7 +6,7 @@ import zarr -from spikeinterface.core import generate_recording, generate_sorting, 
NumpySorting, ChannelSparsity +from spikeinterface.core import generate_recording, generate_sorting, NumpySorting, ChannelSparsity, generate_ground_truth_recording from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms from spikeinterface.core.waveform_extractor import precompute_sparsity @@ -509,11 +509,46 @@ def test_compute_sparsity(): ) print(sparsity) +def test_non_json_object(): + recording, sorting = generate_ground_truth_recording( + durations=[30, 40], + sampling_frequency=30000.0, + num_channels=32, + num_units=5, + ) + + # recording is not save to keep it in memory + sorting = sorting.save() + + wf_folder = cache_folder / "test_waveform_extractor" + if wf_folder.is_dir(): + shutil.rmtree(wf_folder) + + + we = extract_waveforms( + recording, + sorting, + wf_folder, + mode="folder", + sparsity=None, + sparse=False, + ms_before=1.0, + ms_after=1.6, + max_spikes_per_unit=50, + n_jobs=4, + chunk_size=30000, + progress_bar=True, + ) + + # This used to fail because of json + we = load_waveforms(wf_folder) + if __name__ == "__main__": - test_WaveformExtractor() + # test_WaveformExtractor() # test_extract_waveforms() - # test_sparsity() # test_portability() # test_recordingless() # test_compute_sparsity() + test_non_json_object() + diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6881ab3ec5..53852bf319 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -159,11 +159,20 @@ def load_from_folder( else: rec_attributes["probegroup"] = None else: - try: - recording = load_extractor(folder / "recording.json", base_folder=folder) - rec_attributes = None - except: + recording = None + if (folder / "recording.json").exists(): + try: + recording = load_extractor(folder / "recording.json", base_folder=folder) + except: + pass + elif (folder / "recording.pickle").exists(): + try: + recording = load_extractor(folder / "recording.pickle") + except: + pass + if recording is None: raise Exception("The recording could not be loaded. You can use the `with_recording=False` argument") + rec_attributes = None if sorting is None: sorting = load_extractor(folder / "sorting.json", base_folder=folder) @@ -271,9 +280,16 @@ def create( else: relative_to = None - if recording.check_if_json_serializable(): + # if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump(folder / "recording.json", relative_to=relative_to) - if sorting.check_if_json_serializable(): + elif recording.check_serializablility("pickle"): + # In this case we loose the relative_to!! + # TODO make sure that we do not dump to pickle a NumpyRecording!!!!! 
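+            # note: the pickle dump keeps absolute paths, so the folder may be less portable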
+ recording.dump(folder / "recording.pickle") + + # if sorting.check_if_json_serializable(): + if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) else: warn( @@ -879,9 +895,11 @@ def save( (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") if self.has_recording(): - if self.recording.check_if_json_serializable(): + # if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) - if self.sorting.check_if_json_serializable(): + # if self.sorting.check_if_json_serializable(): + if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) else: warn( diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index e2ef6e6794..0054fb94d4 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -333,7 +333,8 @@ def correct_motion( ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - if recording.check_if_json_serializable(): + # if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump_to_json(folder / "recording.json") np.save(folder / "peaks.npy", peaks) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index c7581ba1e1..da20506965 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -137,7 +137,8 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo ) rec_file = output_folder / "spikeinterface_recording.json" - if recording.check_if_json_serializable(): + # if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump_to_json(rec_file, relative_to=output_folder) else: d = {"warning": "The recording is not serializable to json"} From 3f4e182380995f56d458163356a70a813af6b146 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Sep 2023 14:01:50 +0200 Subject: [PATCH 02/10] More check and clean for check_if_serializable() --- src/spikeinterface/comparison/hybrid.py | 4 +- .../comparison/multicomparisons.py | 2 - src/spikeinterface/core/base.py | 46 +++++++++---------- src/spikeinterface/core/generate.py | 2 + src/spikeinterface/core/numpyextractors.py | 8 ++-- src/spikeinterface/core/old_api_utils.py | 2 - src/spikeinterface/core/tests/test_base.py | 7 +-- .../core/tests/test_waveform_extractor.py | 2 + src/spikeinterface/core/waveform_extractor.py | 18 +++++--- src/spikeinterface/preprocessing/motion.py | 1 - 10 files changed, 44 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index c48ce70147..3b8e9e0a72 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -84,8 +84,8 @@ def __init__( ) # save injected sorting if necessary self.injected_sorting = injected_sorting - # if not self.injected_sorting.check_if_json_serializable(): if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = 
self.injected_sorting.save(folder=injected_sorting_folder) @@ -181,8 +181,8 @@ def __init__( self.injected_sorting = injected_sorting # save injected sorting if necessary - # if not self.injected_sorting.check_if_json_serializable(): if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 3a7075905e..09a8c8aed1 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -182,7 +182,6 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou def save_to_folder(self, save_folder): for sorting in self.object_list: assert ( - # sorting.check_if_json_serializable() sorting.check_serializablility("json") ), "MultiSortingComparison.save_to_folder() need json serializable sortings" @@ -245,7 +244,6 @@ def __init__( BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids) - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = True diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index d87bd617c4..63cf8e894f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -484,11 +484,16 @@ def check_if_dumpable(self): for value in kwargs.values(): # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors if isinstance(value, BaseExtractor): - return value.check_if_dumpable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_dumpable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_dumpable() for k, v in value.items()]) + if not value.check_if_dumpable(): + return False + elif isinstance(value, list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_if_dumpable(): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_if_dumpable(): + return False return self._is_dumpable def check_serializablility(self, type="json"): @@ -496,11 +501,16 @@ def check_serializablility(self, type="json"): for value in kwargs.values(): # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors if isinstance(value, BaseExtractor): - return value.check_serializablility(type=type) - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_serializablility(type=type) for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_serializablility(type=type) for k, v in value.items()]) + if not value.check_serializablility(type=type): + return False + elif isinstance(value, list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False return self._serializablility[type] def check_if_json_serializable(self): @@ 
-513,21 +523,11 @@ def check_if_json_serializable(self): True if the object is json serializable, False otherwise. """ # we keep this for backward compatilibity or not ???? + # is this needed ??? I think no. return self.check_serializablility("json") - # kwargs = self._kwargs - # for value in kwargs.values(): - # # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - # if isinstance(value, BaseExtractor): - # return value.check_if_json_serializable() - # elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - # return all([v.check_if_json_serializable() for v in value]) - # elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - # return all([v.check_if_json_serializable() for k, v in value.items()]) - # return self._is_json_serializable - def check_if_pickle_serializable(self): - # is this needed + # is this needed ??? I think no. return self.check_serializablility("pickle") @staticmethod @@ -596,7 +596,6 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - # assert self.check_if_json_serializable(), "The extractor is not json serializable" assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict @@ -835,7 +834,6 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - # if self.check_if_json_serializable(): if self.check_serializablility("json"): self.dump(provenance_file) else: diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 706054c957..362b598b0b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1056,6 +1056,8 @@ def __init__( dtype = parent_recording.dtype if parent_recording is not None else templates.dtype BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) + # Important : self._serializablility is not change here because it will depend on the sorting parents itself. + n_units = len(sorting.unit_ids) assert len(templates) == n_units self.spike_vector = sorting.to_spike_vector() diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index f55b975ddb..5ef955a6eb 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -64,7 +64,6 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" t_starts = [float(t_start) for t_start in t_starts] - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -129,9 +128,9 @@ def __init__(self, spikes, sampling_frequency, unit_ids): BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - # self._is_json_serializable = False self._serializablility["json"] = False - self._serializablility["pickle"] = False + # theorically this should be False but for simplicity make generators simples we still need this. 
+ self._serializablility["pickle"] = True if spikes.size == 0: nseg = 1 @@ -362,7 +361,7 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - # self._is_json_serializable = False + self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -523,7 +522,6 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore ) self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index 38fbef1547..a31edb0dd7 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -183,7 +183,6 @@ def __init__(self, oldapi_recording_extractor): # set _is_dumpable to False to use dumping mechanism of old extractor self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -271,7 +270,6 @@ def __init__(self, oldapi_sorting_extractor): self.add_sorting_segment(sorting_segment) self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index 77a5d7d9bf..b716f6b1dd 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -46,16 +46,14 @@ def test_check_if_dumpable(): assert not extractor.check_if_dumpable() -def test_check_if_json_serializable(): +def test_check_if_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) # make a list of dumpable objects - # test_extractor._is_json_serializable = True test_extractor._serializablility["json"] = True extractors_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_json_serializable: print(extractor) - # assert extractor.check_if_json_serializable() assert extractor.check_serializablility("json") # make not dumpable @@ -64,10 +62,9 @@ def test_check_if_json_serializable(): extractors_not_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_not_json_serializable: print(extractor) - # assert not extractor.check_if_json_serializable() assert not extractor.check_serializablility("json") if __name__ == "__main__": test_check_if_dumpable() - test_check_if_json_serializable() + test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index f53b9cf18d..3972c9186c 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -517,6 +517,8 @@ def test_non_json_object(): num_units=5, ) + + print(recording.check_serializablility("pickle")) # recording is not save to keep it in memory sorting = sorting.save() diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 53852bf319..3de1429feb 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -280,17 +280,17 @@ def create( else: relative_to = None - # if recording.check_if_json_serializable(): if recording.check_serializablility("json"): recording.dump(folder / 
"recording.json", relative_to=relative_to) elif recording.check_serializablility("pickle"): # In this case we loose the relative_to!! - # TODO make sure that we do not dump to pickle a NumpyRecording!!!!! recording.dump(folder / "recording.pickle") - # if sorting.check_if_json_serializable(): if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif sorting.check_serializablility("pickle"): + # In this case we loose the relative_to!! + sorting.dump(folder / "sorting.pickle") else: warn( "Sorting object is not dumpable, which might result in downstream errors for " @@ -895,12 +895,16 @@ def save( (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") if self.has_recording(): - # if self.recording.check_if_json_serializable(): if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) - # if self.sorting.check_if_json_serializable(): + elif self.recording.check_serializablility("pickle"): + self.recording.dump(folder / "recording.pickle") + + if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif self.sorting.check_serializablility("pickle"): + self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to) else: warn( "Sorting object is not dumpable, which might result in downstream errors for " @@ -949,10 +953,10 @@ def save( # write metadata zarr_root.attrs["params"] = check_json(self._params) if self.has_recording(): - if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["recording"] = check_json(rec_dict) - if self.sorting.check_if_json_serializable(): + if self.sorting.check_serializablility("json"): sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["sorting"] = check_json(sort_dict) else: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 0054fb94d4..6ab1a9afce 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -333,7 +333,6 @@ def correct_motion( ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - # if recording.check_if_json_serializable(): if recording.check_serializablility("json"): recording.dump_to_json(folder / "recording.json") From 615c5d9cd219e4016e7149f1ce170f043d507333 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Sep 2023 14:19:46 +0200 Subject: [PATCH 03/10] Make pickle possible to dump in run sorter when json is not possible. 
--- src/spikeinterface/sorters/basesorter.py | 61 ++++++++++++------- .../sorters/external/herdingspikes.py | 4 +- .../sorters/external/mountainsort4.py | 4 +- .../sorters/external/mountainsort5.py | 4 +- .../sorters/external/pykilosort.py | 4 +- .../sorters/internal/spyking_circus2.py | 5 +- .../sorters/internal/tridesclous2.py | 4 +- src/spikeinterface/sorters/runsorter.py | 15 ++++- 8 files changed, 59 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index da20506965..bbcde31eed 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -137,9 +137,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo ) rec_file = output_folder / "spikeinterface_recording.json" - # if recording.check_if_json_serializable(): if recording.check_serializablility("json"): - recording.dump_to_json(rec_file, relative_to=output_folder) + recording.dump(rec_file, relative_to=output_folder) + elif recording.check_serializablility("pickle"): + recording.dump(output_folder / "spikeinterface_recording.pickle") else: d = {"warning": "The recording is not serializable to json"} rec_file.write_text(json.dumps(d, indent=4), encoding="utf8") @@ -186,6 +187,28 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose): return params + @classmethod + def load_recording_from_folder(cls, output_folder, with_warnings=False): + + json_file = output_folder / "spikeinterface_recording.json" + pickle_file = output_folder / "spikeinterface_recording.pickle" + + + if json_file.exists(): + with (json_file).open("r", encoding="utf8") as f: + recording_dict = json.load(f) + if "warning" in recording_dict.keys() and with_warnings: + warnings.warn( + "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object." + ) + recording = None + else: + recording = load_extractor(json_file, base_folder=output_folder) + elif pickle_file.exits(): + recording = load_extractor(pickle_file) + + return recording + @classmethod def _dump_params(cls, recording, output_folder, sorter_params, verbose): with (output_folder / "spikeinterface_params.json").open(mode="w", encoding="utf8") as f: @@ -272,7 +295,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose): return run_time @classmethod - def get_result_from_folder(cls, output_folder): + def get_result_from_folder(cls, output_folder, register_recording=True, sorting_info=True): output_folder = Path(output_folder) sorter_output_folder = output_folder / "sorter_output" # check errors in log file @@ -295,27 +318,21 @@ def get_result_from_folder(cls, output_folder): # back-compatibility sorting = cls._get_result_from_folder(output_folder) - # register recording to Sorting object - # check if not json serializable - with (output_folder / "spikeinterface_recording.json").open("r", encoding="utf8") as f: - recording_dict = json.load(f) - if "warning" in recording_dict.keys(): - warnings.warn( - "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object." 
- ) - else: - recording = load_extractor(output_folder / "spikeinterface_recording.json", base_folder=output_folder) + if register_recording: + # register recording to Sorting object + recording = cls.load_recording_from_folder( output_folder, with_warnings=False) if recording is not None: - # can be None when not dumpable sorting.register_recording(recording) - # set sorting info to Sorting object - with open(output_folder / "spikeinterface_recording.json", "r") as f: - rec_dict = json.load(f) - with open(output_folder / "spikeinterface_params.json", "r") as f: - params_dict = json.load(f) - with open(output_folder / "spikeinterface_log.json", "r") as f: - log_dict = json.load(f) - sorting.set_sorting_info(rec_dict, params_dict, log_dict) + + if sorting_info: + # set sorting info to Sorting object + with open(output_folder / "spikeinterface_recording.json", "r") as f: + rec_dict = json.load(f) + with open(output_folder / "spikeinterface_params.json", "r") as f: + params_dict = json.load(f) + with open(output_folder / "spikeinterface_log.json", "r") as f: + log_dict = json.load(f) + sorting.set_sorting_info(rec_dict, params_dict, log_dict) return sorting diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index a8d702ebe9..5180e6f1cc 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -147,9 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: new_api = False - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) p = params diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index 69f97fd11c..f6f0b3eaeb 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ -89,9 +89,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort4 - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index df6d276bf5..a88c59d688 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -115,9 +115,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - recording: BaseRecording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index 2a41d793d5..1962d56206 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -148,9 +148,7 @@ def 
_setup_recording(cls, recording, sorter_output_folder, params, verbose):
 
     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
-        recording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
 
         if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
             # saved by setup recording
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 9de2762562..86cce1959b 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -54,9 +54,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         job_kwargs["verbose"] = verbose
         job_kwargs["progress_bar"] = verbose
 
-        recording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
+
         sampling_rate = recording.get_sampling_frequency()
         num_channels = recording.get_num_channels()
 
diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index 42f51d3a77..ed327e0f3c 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -49,9 +49,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         import hdbscan
 
-        recording_raw = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
 
         num_chans = recording_raw.get_num_channels()
         sampling_frequency = recording_raw.get_sampling_frequency()
diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index 6e6ccc0358..e930ec7f79 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -624,10 +624,20 @@ def run_sorter_container(
     )
 
 
-def read_sorter_folder(output_folder, raise_error=True):
+def read_sorter_folder(output_folder, register_recording=True, sorting_info=True, raise_error=True):
     """
     Load a sorting object from a spike sorting output folder.
     The 'output_folder' must contain a valid 'spikeinterface_log.json' file
+
+
+    Parameters
+    ----------
+    output_folder: Path or str
+        The sorter folder
+    register_recording: bool, default: True
+        Attach recording (when json or pickle) to the sorting
+    sorting_info: bool, default: True
+        Attach sorting info to the sorting.
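+    raise_error: bool, default: True
+        Raise an error if the log file indicates that the spike sorting failed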
""" output_folder = Path(output_folder) log_file = output_folder / "spikeinterface_log.json" @@ -647,7 +657,8 @@ def read_sorter_folder(output_folder, raise_error=True): sorter_name = log["sorter_name"] SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) + sorting = SorterClass.get_result_from_folder(output_folder, register_recording=register_recording, + sorting_info=sorting_info) return sorting From f2188266647d7faf721d89089b6f9c0bd1d9e637 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 16:22:01 +0200 Subject: [PATCH 04/10] feedback from Ramon --- src/spikeinterface/core/generate.py | 4 ++-- src/spikeinterface/core/tests/test_waveform_extractor.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 362b598b0b..05d63f3c8d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1433,7 +1433,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) - recording.set_property("gain_to_uV", np.ones(num_channels)) - recording.set_property("offset_to_uV", np.zeros(num_channels)) + recording.set_channel_gains(1.) + recording.set_channel_offsets(0.) return recording, sorting diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 3972c9186c..f53b9cf18d 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -517,8 +517,6 @@ def test_non_json_object(): num_units=5, ) - - print(recording.check_serializablility("pickle")) # recording is not save to keep it in memory sorting = sorting.save() From 3c3451ecf6452419ebf83dd6dd2d9454ba7e6419 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 10:00:35 +0200 Subject: [PATCH 05/10] replace is_dumpable() by a more explicit naming : is_memory_serializable() --- src/spikeinterface/core/base.py | 53 +++++++++---------- src/spikeinterface/core/job_tools.py | 2 +- src/spikeinterface/core/numpyextractors.py | 6 +-- src/spikeinterface/core/old_api_utils.py | 6 +-- src/spikeinterface/core/tests/test_base.py | 10 ++-- .../core/tests/test_jsonification.py | 3 +- .../postprocessing/spike_amplitudes.py | 2 +- 7 files changed, 38 insertions(+), 44 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 63cf8e894f..3b8765a398 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -57,9 +57,7 @@ def __init__(self, main_ids: Sequence) -> None: # * number of units for sorting self._properties = {} - self._is_dumpable = True - # self._is_json_serializable = True - self._serializablility = {'json': True, 'pickle': True} + self._serializablility = {'memory': True, 'json': True, 'pickle': True} # extractor specific list of pip extra requirements self.extra_requirements = [] @@ -472,31 +470,8 @@ def clone(self) -> "BaseExtractor": clone = BaseExtractor.from_dict(d) return clone - def check_if_dumpable(self): - """Check if the object is dumpable, including nested objects. - - Returns - ------- - bool - True if the object is dumpable, False otherwise. 
- """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - if not value.check_if_dumpable(): - return False - elif isinstance(value, list): - for v in value: - if isinstance(v, BaseExtractor) and not v.check_if_dumpable(): - return False - elif isinstance(value, dict): - for v in value.values(): - if isinstance(v, BaseExtractor) and not v.check_if_dumpable(): - return False - return self._is_dumpable - def check_serializablility(self, type="json"): + def check_serializablility(self, type): kwargs = self._kwargs for value in kwargs.values(): # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors @@ -512,6 +487,26 @@ def check_serializablility(self, type="json"): if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): return False return self._serializablility[type] + + + def check_if_dumpable(self): + warnings.warn( + "check_if_dumpable() is replace by is_memory_serializable()", DeprecationWarning, stacklevel=2 + ) + return self.check_serializablility("memory") + + def is_memory_serializable(self): + """ + Check if the object is serializable to memory with pickle, including nested objects. + + Returns + ------- + bool + True if the object is json serializable, False otherwise. + """ + return self.check_serializablility("memory") + + def check_if_json_serializable(self): """ @@ -636,7 +631,7 @@ def dump_to_pickle( folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. """ - assert self.check_if_dumpable(), "The extractor is not dumpable" + assert self.check_if_pickle_serializable(), "The extractor is not dumpable" dump_dict = self.to_dict( include_annotations=True, @@ -931,7 +926,7 @@ def save_to_zarr( zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options) - if self.check_if_dumpable(): + if self.check_if_json_serializable(): zarr_root.attrs["provenance"] = check_json(self.to_dict()) else: zarr_root.attrs["provenance"] = None diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index c0ee77d2fd..0535872ca6 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -167,7 +167,7 @@ def ensure_n_jobs(recording, n_jobs=1): print(f"Python {sys.version} does not support parallel processing") n_jobs = 1 - if not recording.check_if_dumpable(): + if not recording.is_memory_serializable(): if n_jobs != 1: raise RuntimeError( "Recording is not dumpable and can't be processed in parallel. " diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 5ef955a6eb..d09016c8f1 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -127,7 +127,7 @@ def __init__(self, spikes, sampling_frequency, unit_ids): """ """ BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = True + self._serializablility["memory"] = True self._serializablility["json"] = False # theorically this should be False but for simplicity make generators simples we still need this. 
self._serializablility["pickle"] = True @@ -360,8 +360,8 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting" BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = True + self._serializablility["memory"] = True self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -521,7 +521,7 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore dtype=dtype, ) - self._is_dumpable = False + self._serializablility["memory"] = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index a31edb0dd7..879700cc15 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -181,8 +181,8 @@ def __init__(self, oldapi_recording_extractor): dtype=oldapi_recording_extractor.get_dtype(return_scaled=False), ) - # set _is_dumpable to False to use dumping mechanism of old extractor - self._is_dumpable = False + # set to False to use dumping mechanism of old extractor + self._serializablility["memory"] = False self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -269,7 +269,7 @@ def __init__(self, oldapi_sorting_extractor): sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor) self.add_sorting_segment(sorting_segment) - self._is_dumpable = False + self._serializablility["memory"] = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index b716f6b1dd..28dbd166ec 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -31,19 +31,19 @@ def make_nested_extractors(extractor): ) -def test_check_if_dumpable(): +def test_is_memory_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) # make a list of dumpable objects extractors_dumpable = make_nested_extractors(test_extractor) for extractor in extractors_dumpable: - assert extractor.check_if_dumpable() + assert extractor.is_memory_serializable() # make not dumpable - test_extractor._is_dumpable = False + test_extractor._serializablility["memory"] = False extractors_not_dumpable = make_nested_extractors(test_extractor) for extractor in extractors_not_dumpable: - assert not extractor.check_if_dumpable() + assert not extractor.is_memory_serializable() def test_check_if_serializable(): @@ -66,5 +66,5 @@ def test_check_if_serializable(): if __name__ == "__main__": - test_check_if_dumpable() + test_is_memory_serializable() test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index 8572cda23e..026e676966 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -144,8 +144,7 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract BaseExtractor.__init__(self, main_ids=['1', '2']) # this already the case by default - self._is_dumpable = True - # self._is_json_serializable = True + self._serializablility["memory"] = True self._serializablility["json"] = True self._serializablility["pickle"] = True diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py 
b/src/spikeinterface/postprocessing/spike_amplitudes.py index 38cb714d59..aa99f7fc5e 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -75,7 +75,7 @@ def _run(self, **job_kwargs): n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None)) if n_jobs != 1: # TODO: avoid dumping sorting and use spike vector and peak pipeline instead - assert sorting.check_if_dumpable(), ( + assert sorting.is_memory_serializable(), ( "The sorting object is not dumpable and cannot be processed in parallel. You can use the " "`sorting.save()` function to make it dumpable" ) From 9d3dceaacc77158487c47972a2d949a71bb3c65a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 08:52:23 +0000 Subject: [PATCH 06/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/comparison/multicomparisons.py | 4 ++-- src/spikeinterface/core/base.py | 10 ++-------- src/spikeinterface/core/generate.py | 4 ++-- src/spikeinterface/core/numpyextractors.py | 2 +- .../core/tests/test_jsonification.py | 8 ++++---- .../core/tests/test_waveform_extractor.py | 15 ++++++++++----- src/spikeinterface/core/waveform_extractor.py | 1 - src/spikeinterface/sorters/basesorter.py | 6 ++---- src/spikeinterface/sorters/runsorter.py | 7 ++++--- 9 files changed, 27 insertions(+), 30 deletions(-) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 6fe474822b..f44e14c4c4 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -189,8 +189,8 @@ def save_to_folder(self, save_folder): stacklevel=2, ) for sorting in self.object_list: - assert ( - sorting.check_serializablility("json") + assert sorting.check_serializablility( + "json" ), "MultiSortingComparison.save_to_folder() need json serializable sortings" save_folder = Path(save_folder) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3b8765a398..6e91cedcb5 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -57,7 +57,7 @@ def __init__(self, main_ids: Sequence) -> None: # * number of units for sorting self._properties = {} - self._serializablility = {'memory': True, 'json': True, 'pickle': True} + self._serializablility = {"memory": True, "json": True, "pickle": True} # extractor specific list of pip extra requirements self.extra_requirements = [] @@ -470,7 +470,6 @@ def clone(self) -> "BaseExtractor": clone = BaseExtractor.from_dict(d) return clone - def check_serializablility(self, type): kwargs = self._kwargs for value in kwargs.values(): @@ -488,11 +487,8 @@ def check_serializablility(self, type): return False return self._serializablility[type] - def check_if_dumpable(self): - warnings.warn( - "check_if_dumpable() is replace by is_memory_serializable()", DeprecationWarning, stacklevel=2 - ) + warnings.warn("check_if_dumpable() is replace by is_memory_serializable()", DeprecationWarning, stacklevel=2) return self.check_serializablility("memory") def is_memory_serializable(self): @@ -506,8 +502,6 @@ def is_memory_serializable(self): """ return self.check_serializablility("memory") - - def check_if_json_serializable(self): """ Check if the object is json serializable, including nested objects. 
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py
index 05d63f3c8d..eeb1e8af60 100644
--- a/src/spikeinterface/core/generate.py
+++ b/src/spikeinterface/core/generate.py
@@ -1433,7 +1433,7 @@ def generate_ground_truth_recording(
    )
    recording.annotate(is_filtered=True)
    recording.set_probe(probe, in_place=True)
-    recording.set_channel_gains(1.)
-    recording.set_channel_offsets(0.)
+    recording.set_channel_gains(1.0)
+    recording.set_channel_offsets(0.0)

    return recording, sorting

diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index d09016c8f1..3d7ec6cd1a 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -523,7 +523,7 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore
        self._serializablility["memory"] = False
        self._serializablility["json"] = False
-        self._serializablility["pickle"] = False
+        self._serializablility["pickle"] = False

        for snippets, spikesframes in zip(snippets_list, spikesframes_list):
            snp_segment = NumpySnippetsSegment(snippets, spikesframes)

diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py
index 026e676966..1c491bd7a6 100644
--- a/src/spikeinterface/core/tests/test_jsonification.py
+++ b/src/spikeinterface/core/tests/test_jsonification.py
@@ -142,11 +142,11 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract
        self.extractor_list = extractor_list
        self.extractor_dict = extractor_dict

-        BaseExtractor.__init__(self, main_ids=['1', '2'])
+        BaseExtractor.__init__(self, main_ids=["1", "2"])
        # this is already the case by default
        self._serializablility["memory"] = True
        self._serializablility["json"] = True
-        self._serializablility["pickle"] = True
+        self._serializablility["pickle"] = True

        self._kwargs = {
            "attribute": attribute,
@@ -199,6 +199,6 @@ def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_d
    json.dumps(nested_extractor_dict, cls=SIJsonEncoder)


-if __name__ == '__main__':
-    nested_extractor = nested_extractor()
-    test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_)
\ No newline at end of file
+if __name__ == "__main__":
+    nested_extractor_dict_ = nested_extractor_dict()
+    test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_dict_)

diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py
index f53b9cf18d..12dac52d43 100644
--- a/src/spikeinterface/core/tests/test_waveform_extractor.py
+++ b/src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -6,7 +6,13 @@
 import zarr

-from spikeinterface.core import generate_recording, generate_sorting, NumpySorting, ChannelSparsity, generate_ground_truth_recording
+from spikeinterface.core import (
+    generate_recording,
+    generate_sorting,
+    NumpySorting,
+    ChannelSparsity,
+    generate_ground_truth_recording,
+)
 from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms
 from spikeinterface.core.waveform_extractor import precompute_sparsity
@@ -509,14 +515,15 @@ def test_compute_sparsity():
    )
    print(sparsity)

+
 def test_non_json_object():
    recording, sorting = generate_ground_truth_recording(
        durations=[30, 40],
        sampling_frequency=30000.0,
        num_channels=32,
        num_units=5,
-    )
-
+    )
+
    # recording is not saved, to keep it in memory
    sorting = sorting.save()

    wf_folder = cache_folder / "wf_non_json"
    if wf_folder.is_dir():
        shutil.rmtree(wf_folder)

-    we = extract_waveforms(
        recording,
        sorting,

@@ -551,4 +557,3 @@
    # test_recordingless()
    # test_compute_sparsity()
    test_non_json_object()
-

diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py
index 3de1429feb..cd8a62f5bc 100644
--- a/src/spikeinterface/core/waveform_extractor.py
+++ b/src/spikeinterface/core/waveform_extractor.py
@@ -900,7 +900,6 @@ def save(
        elif self.recording.check_serializablility("pickle"):
            self.recording.dump(folder / "recording.pickle")

-
        if self.sorting.check_serializablility("json"):
            self.sorting.dump(folder / "sorting.json", relative_to=relative_to)
        elif self.sorting.check_serializablility("pickle"):

diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py
index bbcde31eed..8d87558191 100644
--- a/src/spikeinterface/sorters/basesorter.py
+++ b/src/spikeinterface/sorters/basesorter.py
@@ -189,11 +189,9 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose):

    @classmethod
    def load_recording_from_folder(cls, output_folder, with_warnings=False):
-
        json_file = output_folder / "spikeinterface_recording.json"
        pickle_file = output_folder / "spikeinterface_recording.pickle"
-
        if json_file.exists():
            with (json_file).open("r", encoding="utf8") as f:
                recording_dict = json.load(f)
@@ -206,7 +204,7 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False):
                recording = load_extractor(json_file, base_folder=output_folder)
        elif pickle_file.exists():
            recording = load_extractor(pickle_file)
-
+
        return recording

    @classmethod
@@ -320,7 +318,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_
        if register_recording:
            # register recording to Sorting object
-            recording = cls.load_recording_from_folder( output_folder, with_warnings=False)
+            recording = cls.load_recording_from_folder(output_folder, with_warnings=False)
            if recording is not None:
                sorting.register_recording(recording)

diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index e930ec7f79..bd5667b15f 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -629,7 +629,7 @@ def read_sorter_folder(output_folder, register_recording=True, sorting_info=True
    Load a sorting object from a spike sorting output folder.
    The 'output_folder' must contain a valid 'spikeinterface_log.json' file
-
+
    Parameters
    ----------
    output_folder: Path or str
@@ -657,8 +657,9 @@ def read_sorter_folder(output_folder, register_recording=True, sorting_info=True
    sorter_name = log["sorter_name"]
    SorterClass = sorter_dict[sorter_name]
-    sorting = SorterClass.get_result_from_folder(output_folder, register_recording=register_recording,
-                                                 sorting_info=sorting_info)
+    sorting = SorterClass.get_result_from_folder(
+        output_folder, register_recording=register_recording, sorting_info=sorting_info
+    )
    return sorting

From 7329927cfb3035d764648a2175d617aa8999c67b Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 27 Sep 2023 10:54:57 +0200
Subject: [PATCH 07/10] rename to check_if_memory_serializable

---
 src/spikeinterface/core/base.py                       | 6 +-----
 src/spikeinterface/core/job_tools.py                  | 2 +-
 src/spikeinterface/core/tests/test_base.py            | 8 ++++----
 src/spikeinterface/postprocessing/spike_amplitudes.py | 2 +-
 4 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 6e91cedcb5..b1b5065339 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -487,11 +487,7 @@ def check_serializablility(self, type):
            return False
        return self._serializablility[type]

-    def check_if_dumpable(self):
-        warnings.warn("check_if_dumpable() is replaced by is_memory_serializable()", DeprecationWarning, stacklevel=2)
-        return self.check_serializablility("memory")
-
-    def is_memory_serializable(self):
+    def check_if_memory_serializable(self):
        """
        Check if the object is serializable to memory with pickle, including nested objects.

diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index 0535872ca6..9369ad0b61 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -167,7 +167,7 @@ def ensure_n_jobs(recording, n_jobs=1):
            print(f"Python {sys.version} does not support parallel processing")
            n_jobs = 1

-    if not recording.is_memory_serializable():
+    if not recording.check_if_memory_serializable():
        if n_jobs != 1:
            raise RuntimeError(
                "Recording is not dumpable and can't be processed in parallel. "
" diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index 28dbd166ec..8d0907c700 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -31,19 +31,19 @@ def make_nested_extractors(extractor): ) -def test_is_memory_serializable(): +def test_check_if_memory_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) # make a list of dumpable objects extractors_dumpable = make_nested_extractors(test_extractor) for extractor in extractors_dumpable: - assert extractor.is_memory_serializable() + assert extractor.check_if_memory_serializable() # make not dumpable test_extractor._serializablility["memory"] = False extractors_not_dumpable = make_nested_extractors(test_extractor) for extractor in extractors_not_dumpable: - assert not extractor.is_memory_serializable() + assert not extractor.check_if_memory_serializable() def test_check_if_serializable(): @@ -66,5 +66,5 @@ def test_check_if_serializable(): if __name__ == "__main__": - test_is_memory_serializable() + test_check_if_memory_serializable() test_check_if_serializable() diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index aa99f7fc5e..9eb5a815d4 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -75,7 +75,7 @@ def _run(self, **job_kwargs): n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None)) if n_jobs != 1: # TODO: avoid dumping sorting and use spike vector and peak pipeline instead - assert sorting.is_memory_serializable(), ( + assert sorting.check_if_memory_serializable(), ( "The sorting object is not dumpable and cannot be processed in parallel. You can use the " "`sorting.save()` function to make it dumpable" ) From b9c6a38e99430fc7b734e0751871e6d08eb5aea1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 10:56:28 +0200 Subject: [PATCH 08/10] oups --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index b1b5065339..e3b88588e2 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -494,7 +494,7 @@ def check_if_memory_serializable(self): Returns ------- bool - True if the object is json serializable, False otherwise. + True if the object is memory serializable, False otherwise. """ return self.check_serializablility("memory") From 331379a3f441e2691eb15985b60254fcc9e3f887 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 11:13:29 +0200 Subject: [PATCH 09/10] Remove "dumpable" naming also in doc and warnings. --- doc/modules/core.rst | 3 +-- src/spikeinterface/comparison/hybrid.py | 4 ++-- src/spikeinterface/core/base.py | 8 ++++---- src/spikeinterface/core/job_tools.py | 4 ++-- src/spikeinterface/core/tests/test_base.py | 17 ++++++++--------- .../core/tests/test_core_tools.py | 1 - src/spikeinterface/core/tests/test_job_tools.py | 6 +++--- .../core/tests/test_waveform_extractor.py | 2 +- src/spikeinterface/core/waveform_extractor.py | 15 ++++++++------- .../postprocessing/spike_amplitudes.py | 6 ------ .../sorters/tests/test_launcher.py | 2 +- 11 files changed, 30 insertions(+), 38 deletions(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index fdc4d71fe7..976a82a4a3 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -547,8 +547,7 @@ workflow. 
 In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`,
 :py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and
 :py:class:`~spikeinterface.core.NumpySnippets`. These objects behave exactly like normal SpikeInterface objects,
-but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported.
-In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`).
+but they are not bound to a file.

 Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting` which is similar to
 :py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory which is useful for

diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py
index 3b8e9e0a72..e0c98cd772 100644
--- a/src/spikeinterface/comparison/hybrid.py
+++ b/src/spikeinterface/comparison/hybrid.py
@@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording):
        The refractory period of the injected spike train (in ms).
    injected_sorting_folder: str | Path | None
        If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.

    Returns
    -------
@@ -138,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording):
        this refractory period.
    injected_sorting_folder: str | Path | None
        If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.

    Returns
    -------

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index e3b88588e2..73f8619348 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -621,7 +621,7 @@ def dump_to_pickle(
    folder_metadata: str, Path, or None
        Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
""" - assert self.check_if_pickle_serializable(), "The extractor is not dumpable" + assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle" dump_dict = self.to_dict( include_annotations=True, @@ -658,8 +658,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") - if "warning" in d and "not dumpable" in d["warning"]: - print("The extractor was not dumpable") + if "warning" in d: + print("The extractor was not serializable to file") return None extractor = BaseExtractor.from_dict(d, base_folder=base_folder) return extractor @@ -822,7 +822,7 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): if self.check_serializablility("json"): self.dump(provenance_file) else: - provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8") + provenance_file.write_text(json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8") self.save_metadata_to_folder(folder) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 9369ad0b61..84ee502c14 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -170,8 +170,8 @@ def ensure_n_jobs(recording, n_jobs=1): if not recording.check_if_memory_serializable(): if n_jobs != 1: raise RuntimeError( - "Recording is not dumpable and can't be processed in parallel. " - "You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1." + "Recording is not serializable to memory and can't be processed in parallel. " + "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1." 
            )

    return n_jobs

diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py
index 8d0907c700..a944be3da0 100644
--- a/src/spikeinterface/core/tests/test_base.py
+++ b/src/spikeinterface/core/tests/test_base.py
@@ -34,30 +34,29 @@ def make_nested_extractors(extractor):
 def test_check_if_memory_serializable():
    test_extractor = generate_recording(seed=0, durations=[2])

-    # make a list of dumpable objects
-    extractors_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_dumpable:
+    # make a list of memory serializable objects
+    extractors_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_mem_serializable:
        assert extractor.check_if_memory_serializable()

-    # make not dumpable
+    # make not memory serializable
    test_extractor._serializablility["memory"] = False
-    extractors_not_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_not_dumpable:
+    extractors_not_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_not_mem_serializable:
        assert not extractor.check_if_memory_serializable()


 def test_check_if_serializable():
    test_extractor = generate_recording(seed=0, durations=[2])

-    # make a list of dumpable objects
+    # make a list of json serializable objects
    test_extractor._serializablility["json"] = True
    extractors_json_serializable = make_nested_extractors(test_extractor)
    for extractor in extractors_json_serializable:
        print(extractor)
        assert extractor.check_serializablility("json")

-    # make not dumpable
-    # test_extractor._is_json_serializable = False
+    # make a list of non json serializable objects
    test_extractor._serializablility["json"] = False
    extractors_not_json_serializable = make_nested_extractors(test_extractor)
    for extractor in extractors_not_json_serializable:

diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py
index a3cd0caa92..223b2a8a3a 100644
--- a/src/spikeinterface/core/tests/test_core_tools.py
+++ b/src/spikeinterface/core/tests/test_core_tools.py
@@ -142,7 +142,6 @@ def test_write_memory_recording():
    recording = NoiseGeneratorRecording(
        num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
    )
-    # make dumpable
    recording = recording.save()

    # write with loop

diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py
index 7d7af6025b..a904e4dd32 100644
--- a/src/spikeinterface/core/tests/test_job_tools.py
+++ b/src/spikeinterface/core/tests/test_job_tools.py
@@ -36,7 +36,7 @@ def test_ensure_n_jobs():
    n_jobs = ensure_n_jobs(recording, n_jobs=1)
    assert n_jobs == 1

-    # dumpable
+    # check serializable
    n_jobs = ensure_n_jobs(recording.save(), n_jobs=-1)
    assert n_jobs > 1

@@ -45,7 +45,7 @@ def test_ensure_chunk_size():
    recording = generate_recording(num_channels=2)
    dtype = recording.get_dtype()
    assert dtype == "float32"
-    # make dumpable
+    # make serializable
    recording = recording.save()

    chunk_size = ensure_chunk_size(recording, total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2)
@@ -90,7 +90,7 @@ def init_func(arg1, arg2, arg3):
 def test_ChunkRecordingExecutor():
    recording = generate_recording(num_channels=2)
-    # make dumpable
+    # make serializable
    recording = recording.save()

    init_args = "a", 120, "yep"

diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py
index 12dac52d43..2bbf5e9b0f
100644
--- a/src/spikeinterface/core/tests/test_waveform_extractor.py
+++ b/src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -315,7 +315,7 @@ def test_recordingless():
    recording = recording.save(folder=cache_folder / "recording1")
    sorting = sorting.save(folder=cache_folder / "sorting1")

-    # recording and sorting are not dumpable
+    # recording and sorting are not serializable
    wf_folder = cache_folder / "wf_recordingless"

    # save with relative paths

diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py
index cd8a62f5bc..2710ff1338 100644
--- a/src/spikeinterface/core/waveform_extractor.py
+++ b/src/spikeinterface/core/waveform_extractor.py
@@ -290,11 +290,12 @@ def create(
                sorting.dump(folder / "sorting.json", relative_to=relative_to)
            elif sorting.check_serializablility("pickle"):
                # In this case we lose the relative_to!!
+                # TODO: later, the dump to pickle should dump the dictionary so that relative_to can be put back
                sorting.dump(folder / "sorting.pickle")
            else:
                warn(
-                    "Sorting object is not dumpable, which might result in downstream errors for "
-                    "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                    "Sorting object is not serializable to file, which might result in downstream errors for "
+                    "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
                )

            # dump some attributes of the recording for the mode with_recording=False at next load
@@ -903,11 +904,11 @@ def save(
            if self.sorting.check_serializablility("json"):
                self.sorting.dump(folder / "sorting.json", relative_to=relative_to)
            elif self.sorting.check_serializablility("pickle"):
-                self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to)
+                self.sorting.dump(folder / "sorting.pickle")
            else:
                warn(
-                    "Sorting object is not dumpable, which might result in downstream errors for "
-                    "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                    "Sorting object is not serializable to file, which might result in downstream errors for "
+                    "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
                )

            # dump some attributes of the recording for the mode with_recording=False at next load
@@ -960,8 +961,8 @@
                zarr_root.attrs["sorting"] = check_json(sort_dict)
            else:
                warn(
-                    "Sorting object is not dumpable, which might result in downstream errors for "
-                    "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                    "Sorting object is not json serializable, which might result in downstream errors for "
+                    "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
                )
            recording_info = zarr_root.create_group("recording_info")
            recording_info.attrs["recording_attributes"] = check_json(rec_attributes)

diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index 9eb5a815d4..ccd2121174 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -73,12 +73,6 @@ def _run(self, **job_kwargs):
        func = _spike_amplitudes_chunk
        init_func = _init_worker_spike_amplitudes
        n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None))
-        if n_jobs != 1:
-            # TODO: avoid dumping sorting and use spike vector and peak pipeline instead
-            assert sorting.check_if_memory_serializable(), (
-                "The sorting object is not dumpable and cannot be processed in parallel. You can use the "
-                "`sorting.save()` function to make it dumpable"
-            )
        init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled)
        processor = ChunkRecordingExecutor(
            recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs

diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py
index 14c938f8ba..a5e29c8fd9 100644
--- a/src/spikeinterface/sorters/tests/test_launcher.py
+++ b/src/spikeinterface/sorters/tests/test_launcher.py
@@ -178,7 +178,7 @@ def test_run_sorters_with_list():
    if working_folder.is_dir():
        shutil.rmtree(working_folder)

-    # make dumpable
+    # make serializable
    rec0 = load_extractor(cache_folder / "toy_rec_0")
    rec1 = load_extractor(cache_folder / "toy_rec_1")

From 0ea10e3baf97fbcedc8c25c2745754cacabb7b5c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 27 Sep 2023 09:13:52 +0000
Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 73f8619348..e8b3232e13 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -822,7 +822,9 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs):
        if self.check_serializablility("json"):
            self.dump(provenance_file)
        else:
-            provenance_file.write_text(json.dumps({"warning": "the provenance is not json serializable!!!"}), encoding="utf8")
+            provenance_file.write_text(
+                json.dumps({"warning": "the provenance is not json serializable!!!"}), encoding="utf8"
+            )

        self.save_metadata_to_folder(folder)
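
Taken together, the series replaces the single "dumpable" notion with explicit per-engine checks at every dump site. The following is a condensed usage sketch, not code from the patches: `dump_sorting` is a hypothetical helper, and it assumes only the `check_serializablility()` / `dump()` / `save()` API shown in the diffs above.

from pathlib import Path

def dump_sorting(sorting, folder: Path):
    # hypothetical helper mirroring the dispatch in waveform_extractor.py
    if sorting.check_serializablility("json"):
        sorting.dump(folder / "sorting.json")    # preferred: human-readable, relocatable
    elif sorting.check_serializablility("pickle"):
        sorting.dump(folder / "sorting.pickle")  # fallback: binary, loses relative_to
    else:
        # last resort: persist a real copy to disk, which returns
        # a serializable object (the `sorting = sorting.save()` advice above)
        sorting = sorting.save(folder=folder / "saved_sorting")
    return sorting

This mirrors the design choice visible throughout the series: callers never ask "is this dumpable?" in the abstract, they ask whether a specific engine (memory, json, or pickle) can round-trip the object, and degrade gracefully from json to pickle to an on-disk save.
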