From 9add5def54fe63fd23f32d3cde5c2177f7eb1d09 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 19 Sep 2023 13:19:33 +0200
Subject: [PATCH 01/25] Deprecate multicomparison save/load functions in favor
 of pickle

---
 src/spikeinterface/comparison/multicomparisons.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py
index 9e02fd5b2d..d1193907eb 100644
--- a/src/spikeinterface/comparison/multicomparisons.py
+++ b/src/spikeinterface/comparison/multicomparisons.py
@@ -1,6 +1,7 @@
 from pathlib import Path
 import json
 import pickle
+import warnings

 import numpy as np

@@ -180,6 +181,11 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou
         return sorting

     def save_to_folder(self, save_folder):
+        warnings.warn(
+            "save_to_folder() is deprecated. You should save and load the multi sorting comparison object using pickle.\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         for sorting in self.object_list:
             assert (
                 sorting.check_if_json_serializable()
@@ -205,6 +211,11 @@ def save_to_folder(self, save_folder):

     @staticmethod
     def load_from_folder(folder_path):
+        warnings.warn(
+            "load_from_folder() is deprecated. You should save and load the multi sorting comparison object using pickle.\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         folder_path = Path(folder_path)
         with (folder_path / "kwargs.json").open() as f:
             kwargs = json.load(f)

From 3d2f41c0773a1e3b499f42918d582619e1fd0dba Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Tue, 19 Sep 2023 13:20:48 +0200
Subject: [PATCH 02/25] Formatting

---
 src/spikeinterface/comparison/multicomparisons.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py
index d1193907eb..d418b92ab8 100644
--- a/src/spikeinterface/comparison/multicomparisons.py
+++ b/src/spikeinterface/comparison/multicomparisons.py
@@ -182,7 +182,9 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou

     def save_to_folder(self, save_folder):
         warnings.warn(
-            "save_to_folder() is deprecated. You should save and load the multi sorting comparison object using pickle.\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))",
+            "save_to_folder() is deprecated. "
+            "You should save and load the multi sorting comparison object using pickle."
+            "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -212,7 +214,9 @@ def save_to_folder(self, save_folder):
     @staticmethod
     def load_from_folder(folder_path):
         warnings.warn(
-            "load_from_folder() is deprecated. You should save and load the multi sorting comparison object using pickle.\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))",
+            "load_from_folder() is deprecated. "
+            "You should save and load the multi sorting comparison object using pickle."
+            "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))",
             DeprecationWarning,
             stacklevel=2,
         )
         folder_path = Path(folder_path)
         with (folder_path / "kwargs.json").open() as f:
             kwargs = json.load(f)

From fac98233b84fa440b374d944d1c27b9d200cd0c1 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 19 Sep 2023 15:31:10 +0200
Subject: [PATCH 03/25] add tutorial to load matlab data

---
 doc/how_to/index.rst            |  1 +
 doc/how_to/load_matalb_data.rst | 66 +++++++++++++++++++++++++++++++++
 2 files changed, 67 insertions(+)
 create mode 100644 doc/how_to/load_matalb_data.rst

diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst
index dabad818f9..fa7210d4f0 100644
--- a/doc/how_to/index.rst
+++ b/doc/how_to/index.rst
@@ -7,3 +7,4 @@ How to guides
     get_started
     analyse_neuropixels
     handle_drift
+    load_matalb_data

diff --git a/doc/how_to/load_matalb_data.rst b/doc/how_to/load_matalb_data.rst
new file mode 100644
index 0000000000..39b9a48d65
--- /dev/null
+++ b/doc/how_to/load_matalb_data.rst
@@ -0,0 +1,66 @@
+Exporting MATLAB Data to Binary & Loading in SpikeInterface
+===========================================================
+
+In this tutorial, we'll go through the process of exporting your data from MATLAB in a binary format and then loading it using SpikeInterface in Python. Let's break down the steps.
+
+Exporting Data from MATLAB
+--------------------------
+
+First, ensure your data is structured correctly. The data matrix should be organized such that the first dimension corresponds to samples/time and the second dimension to channels.
+
+.. code-block:: matlab
+
+    % Define the size of your data
+    num_samples = 1000;
+    num_channels = 384;
+
+    % Generate random data as an example
+    data = rand(num_samples, num_channels);
+
+    % Write the data to a binary file
+    fileID = fopen('your_data_as_a_binary.bin', 'wb');
+    fwrite(fileID, data, 'double');
+    fclose(fileID);
+
+.. note::
+
+   In a real-world scenario, replace the random data generation with your actual data.
+
+Loading Data in SpikeInterface
+-----------------------------
+
+This should produce a binary file called `your_data_as_a_binary.bin` in your current MATLAB directory.
+You will need the complete path (i.e. its location on your computer) to load it in Python.
+
+Once you have your data in a binary format, you can seamlessly load it into SpikeInterface using the following script:
+
+.. code-block:: python
+
+    from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor
+    from pathlib import Path
+
+    # Define the path to your binary file
+    file_path = Path("/The/Path/To/Your/Data/your_data_as_a_binary.bin")
+
+    # Ensure the file exists
+    assert file_path.is_file()
+
+    # Specify the parameters of your recording
+    sampling_frequency = 30_000.0  # in Hz, adjust as per your matlab dataset
+    num_channels = 384  # adjust as per your matlab dataset
+    dtype = "float64"
+
+    # Load the data using SpikeInterface
+    recording = BinaryRecordingExtractor(file_path, sampling_frequency=sampling_frequency,
+                                         num_channels=num_channels, dtype=dtype, gain_to_uV=1, offset_to_uV=0)
+
+    # Verify the shape of your data
+    assert recording.get_traces().shape == (num_samples, num_channels)
+
+Common Pitfalls & Tips
+----------------------
+
+1. **Data Shape**: Always ensure that your MATLAB data matrix's first dimension corresponds to samples/time and the second to channels.
+2. **File Path**: Double-check the file path in Python to ensure you're pointing to the right directory.
+3. **Data Type**: When moving data between MATLAB and Python, it's crucial to keep the data type consistent.
In our example, we used `double` in MATLAB, which corresponds to `float64` in Python. +4. **Sampling Frequency**: Ensure you set the correct sampling frequency when loading data into SpikeInterface. From a395c3c7253cd7dadd813b25a4862610221f9cf4 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 10:24:30 +0200 Subject: [PATCH 04/25] suggestions --- doc/how_to/index.rst | 2 +- ...d_matalb_data.rst => load_matlab_data.rst} | 26 +++++++++++-------- 2 files changed, 16 insertions(+), 12 deletions(-) rename doc/how_to/{load_matalb_data.rst => load_matlab_data.rst} (70%) diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst index fa7210d4f0..da94cf549c 100644 --- a/doc/how_to/index.rst +++ b/doc/how_to/index.rst @@ -7,4 +7,4 @@ How to guides get_started analyse_neuropixels handle_drift - load_matalb_data + load_matlab_data diff --git a/doc/how_to/load_matalb_data.rst b/doc/how_to/load_matlab_data.rst similarity index 70% rename from doc/how_to/load_matalb_data.rst rename to doc/how_to/load_matlab_data.rst index 39b9a48d65..cca579036a 100644 --- a/doc/how_to/load_matalb_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -7,15 +7,16 @@ Exporting Data from MATLAB -------------------------- First, ensure your data is structured correctly. The data matrix should be organized such that the first dimension corresponds to samples/time and the second dimension to channels. +In the following MATLAB code, we generate random data as an example and then write it to a binary file. .. code-block:: matlab % Define the size of your data - num_samples = 1000; - num_channels = 384; + numSamples = 1000; + numChannels = 384; % Generate random data as an example - data = rand(num_samples, num_channels); + data = rand(numSamples, numChannels); % Write the data to a binary file fileID = fopen('your_data_as_a_binary.bin', 'wb'); @@ -36,22 +37,24 @@ Once you have your data in a binary format, you can seamlessly load it into Spik .. code-block:: python - from spikeinterface.core.binaryrecordingextractor import BinaryRecordingExtractor + import spikeinterface as si from pathlib import Path - # Define the path to your binary file + # In linux or mac file_path = Path("/The/Path/To/Your/Data/your_data_as_a_binary.bin") + # or for Windows + # file_path = Path(r"c:\path\to\your\data\your_data_as_a_binary.bin") # Ensure the file exists assert file_path.is_file() # Specify the parameters of your recording - sampling_frequency = 30_000.0 # in Hz, adjust as per your matlab dataset - num_channels = 384 # adjust as per your matlab dataset - dtype = "float64" + sampling_frequency = 30_000.0 # in Hz, adjust as per your MATLAB dataset + num_channels = 384 # adjust as per your MATLAB dataset + dtype = "float64" # equivalent of MATLAB double # Load the data using SpikeInterface - recording = BinaryRecordingExtractor(file_path, sampling_frequency=sampling_frequency, + recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype, gain_to_uV=1, offset_to_uV=0) # Verify the shape of your data @@ -61,6 +64,7 @@ Common Pitfalls & Tips ---------------------- 1. **Data Shape**: Always ensure that your MATLAB data matrix's first dimension corresponds to samples/time and the second to channels. -2. **File Path**: Double-check the file path in Python to ensure you're pointing to the right directory. +2. **File Path**: Double-check the file path in Python to ensure you are pointing to the right directory. 3. 
**Data Type**: When moving data between MATLAB and Python, it's crucial to keep the data type consistent. In our example, we used `double` in MATLAB, which corresponds to `float64` in Python.
-4. **Sampling Frequency**: Ensure you set the correct sampling frequency when loading data into SpikeInterface.
+4. **Sampling Frequency**: Ensure you set the correct sampling frequency in Hz when loading data into SpikeInterface.
+5. **Working on Python**: Matlab to python can feel like a big jump. If you are new to Python, we recommend checking out numpy's [Python for MATLAB Users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html) guide.

From 6130e5bad0c8d825a4c44da881b5473e691a8712 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Wed, 20 Sep 2023 10:27:17 +0200
Subject: [PATCH 05/25] add an assertion

---
 doc/how_to/load_matlab_data.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst
index cca579036a..0a8345b792 100644
--- a/doc/how_to/load_matlab_data.rst
+++ b/doc/how_to/load_matlab_data.rst
@@ -46,7 +46,7 @@ Once you have your data in a binary format, you can seamlessly load it into Spik

     # Ensure the file exists
-    assert file_path.is_file()
+    assert file_path.is_file(), f"Your path {file_path} is not a file, you probably have a typo or got the wrong path."

     # Specify the parameters of your recording
     sampling_frequency = 30_000.0  # in Hz, adjust as per your MATLAB dataset

From 9a97e68f848d1126126bfecd819f456e12113813 Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 20 Sep 2023 10:52:05 +0200
Subject: [PATCH 06/25] Improve the concept of check_if_json_serializable to
 more serialization engines like pickle.
--- src/spikeinterface/comparison/hybrid.py | 6 ++- .../comparison/multicomparisons.py | 7 ++- src/spikeinterface/core/base.py | 50 +++++++++++++------ src/spikeinterface/core/generate.py | 2 + src/spikeinterface/core/numpyextractors.py | 16 ++++-- src/spikeinterface/core/old_api_utils.py | 8 ++- src/spikeinterface/core/tests/test_base.py | 12 +++-- .../core/tests/test_jsonification.py | 10 +++- .../core/tests/test_waveform_extractor.py | 41 +++++++++++++-- src/spikeinterface/core/waveform_extractor.py | 34 ++++++++++--- src/spikeinterface/preprocessing/motion.py | 3 +- src/spikeinterface/sorters/basesorter.py | 3 +- 12 files changed, 150 insertions(+), 42 deletions(-) diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index af410255b9..c48ce70147 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -84,7 +84,8 @@ def __init__( ) # save injected sorting if necessary self.injected_sorting = injected_sorting - if not self.injected_sorting.check_if_json_serializable(): + # if not self.injected_sorting.check_if_json_serializable(): + if not self.injected_sorting.check_serializablility("json"): assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) @@ -180,7 +181,8 @@ def __init__( self.injected_sorting = injected_sorting # save injected sorting if necessary - if not self.injected_sorting.check_if_json_serializable(): + # if not self.injected_sorting.check_if_json_serializable(): + if not self.injected_sorting.check_serializablility("json"): assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 9e02fd5b2d..3a7075905e 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -182,7 +182,8 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou def save_to_folder(self, save_folder): for sorting in self.object_list: assert ( - sorting.check_if_json_serializable() + # sorting.check_if_json_serializable() + sorting.check_serializablility("json") ), "MultiSortingComparison.save_to_folder() need json serializable sortings" save_folder = Path(save_folder) @@ -244,7 +245,9 @@ def __init__( BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids) - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = True if len(unit_ids) > 0: for k in ("agreement_number", "avg_agreement", "unit_ids"): diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 87c0805630..d87bd617c4 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -58,7 +58,8 @@ def __init__(self, main_ids: Sequence) -> None: self._properties = {} self._is_dumpable = True - self._is_json_serializable = True + # self._is_json_serializable = True + self._serializablility = {'json': True, 'pickle': True} # extractor specific list of pip extra requirements self.extra_requirements = [] @@ -490,6 +491,18 @@ def check_if_dumpable(self): return all([v.check_if_dumpable() for k, v in value.items()]) return 
self._is_dumpable + def check_serializablility(self, type="json"): + kwargs = self._kwargs + for value in kwargs.values(): + # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors + if isinstance(value, BaseExtractor): + return value.check_serializablility(type=type) + elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): + return all([v.check_serializablility(type=type) for v in value]) + elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): + return all([v.check_serializablility(type=type) for k, v in value.items()]) + return self._serializablility[type] + def check_if_json_serializable(self): """ Check if the object is json serializable, including nested objects. @@ -499,16 +512,23 @@ def check_if_json_serializable(self): bool True if the object is json serializable, False otherwise. """ - kwargs = self._kwargs - for value in kwargs.values(): - # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - if isinstance(value, BaseExtractor): - return value.check_if_json_serializable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_json_serializable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_json_serializable() for k, v in value.items()]) - return self._is_json_serializable + # we keep this for backward compatilibity or not ???? + return self.check_serializablility("json") + + # kwargs = self._kwargs + # for value in kwargs.values(): + # # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors + # if isinstance(value, BaseExtractor): + # return value.check_if_json_serializable() + # elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): + # return all([v.check_if_json_serializable() for v in value]) + # elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): + # return all([v.check_if_json_serializable() for k, v in value.items()]) + # return self._is_json_serializable + + def check_if_pickle_serializable(self): + # is this needed + return self.check_serializablility("pickle") @staticmethod def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path: @@ -557,7 +577,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No if str(file_path).endswith(".json"): self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata) elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"): - self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata) + self.dump_to_pickle(file_path, folder_metadata=folder_metadata) else: raise ValueError("Dump: file must .json or .pkl") @@ -576,7 +596,8 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. 
""" - assert self.check_if_json_serializable(), "The extractor is not json serializable" + # assert self.check_if_json_serializable(), "The extractor is not json serializable" + assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict if relative_to: @@ -814,7 +835,8 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - if self.check_if_json_serializable(): + # if self.check_if_json_serializable(): + if self.check_serializablility("json"): self.dump(provenance_file) else: provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8") diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 07837bcef7..706054c957 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1431,5 +1431,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) + recording.set_property("gain_to_uV", np.ones(num_channels)) + recording.set_property("offset_to_uV", np.zeros(num_channels)) return recording, sorting diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index d5663156c7..f55b975ddb 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -64,7 +64,9 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" t_starts = [float(t_start) for t_start in t_starts] - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False for i, traces in enumerate(traces_list): if t_starts is None: @@ -127,7 +129,9 @@ def __init__(self, spikes, sampling_frequency, unit_ids): BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False if spikes.size == 0: nseg = 1 @@ -358,7 +362,9 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False self.shm = SharedMemory(shm_name, create=False) self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf) @@ -517,7 +523,9 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore ) self._is_dumpable = False - self._is_json_serializable = False + # self._is_json_serializable = False + self._serializablility["json"] = False + self._serializablility["pickle"] = False for snippets, spikesframes in zip(snippets_list, spikesframes_list): snp_segment = NumpySnippetsSegment(snippets, spikesframes) diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index 1ff31127f4..38fbef1547 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -183,7 +183,9 @@ def __init__(self, oldapi_recording_extractor): # set _is_dumpable to False to use dumping 
mechanism of old extractor
         self._is_dumpable = False
-        self._is_json_serializable = False
+        # self._is_json_serializable = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         self.annotate(is_filtered=oldapi_recording_extractor.is_filtered)

@@ -269,7 +271,9 @@ def __init__(self, oldapi_sorting_extractor):
             self.add_sorting_segment(sorting_segment)

         self._is_dumpable = False
-        self._is_json_serializable = False
+        # self._is_json_serializable = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         # add old properties
         copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self)

diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py
index ea1a9cf0d2..77a5d7d9bf 100644
--- a/src/spikeinterface/core/tests/test_base.py
+++ b/src/spikeinterface/core/tests/test_base.py
@@ -50,18 +50,22 @@ def test_check_if_json_serializable():
     test_extractor = generate_recording(seed=0, durations=[2])

     # make a list of dumpable objects
-    test_extractor._is_json_serializable = True
+    # test_extractor._is_json_serializable = True
+    test_extractor._serializablility["json"] = True
     extractors_json_serializable = make_nested_extractors(test_extractor)
     for extractor in extractors_json_serializable:
         print(extractor)
-        assert extractor.check_if_json_serializable()
+        # assert extractor.check_if_json_serializable()
+        assert extractor.check_serializablility("json")

     # make not dumpable
-    test_extractor._is_json_serializable = False
+    # test_extractor._is_json_serializable = False
+    test_extractor._serializablility["json"] = False
     extractors_not_json_serializable = make_nested_extractors(test_extractor)
     for extractor in extractors_not_json_serializable:
         print(extractor)
-        assert not extractor.check_if_json_serializable()
+        # assert not extractor.check_if_json_serializable()
+        assert not extractor.check_serializablility("json")


 if __name__ == "__main__":
diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py
index 473648c5ec..8572cda23e 100644
--- a/src/spikeinterface/core/tests/test_jsonification.py
+++ b/src/spikeinterface/core/tests/test_jsonification.py
@@ -142,9 +142,12 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract
         self.extractor_list = extractor_list
         self.extractor_dict = extractor_dict

+        BaseExtractor.__init__(self, main_ids=['1', '2'])
         # this already the case by default
         self._is_dumpable = True
-        self._is_json_serializable = True
+        # self._is_json_serializable = True
+        self._serializablility["json"] = True
+        self._serializablility["pickle"] = True

         self._kwargs = {
             "attribute": attribute,
@@ -195,3 +198,8 @@ def test_encoding_numpy_scalars_within_nested_extractors_list(nested_extractor_l

 def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_dict):
     json.dumps(nested_extractor_dict, cls=SIJsonEncoder)
+
+
+if __name__ == '__main__':
+    nested_extractor = nested_extractor()
+    test_encoding_numpy_scalars_within_nested_extractors(nested_extractor)
\ No newline at end of file
diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py
index 107ef5f180..f53b9cf18d 100644
--- a/src/spikeinterface/core/tests/test_waveform_extractor.py
+++ b/src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -6,7 +6,7 @@
 import zarr


-from spikeinterface.core import generate_recording, generate_sorting,
NumpySorting, ChannelSparsity +from spikeinterface.core import generate_recording, generate_sorting, NumpySorting, ChannelSparsity, generate_ground_truth_recording from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms from spikeinterface.core.waveform_extractor import precompute_sparsity @@ -509,11 +509,46 @@ def test_compute_sparsity(): ) print(sparsity) +def test_non_json_object(): + recording, sorting = generate_ground_truth_recording( + durations=[30, 40], + sampling_frequency=30000.0, + num_channels=32, + num_units=5, + ) + + # recording is not save to keep it in memory + sorting = sorting.save() + + wf_folder = cache_folder / "test_waveform_extractor" + if wf_folder.is_dir(): + shutil.rmtree(wf_folder) + + + we = extract_waveforms( + recording, + sorting, + wf_folder, + mode="folder", + sparsity=None, + sparse=False, + ms_before=1.0, + ms_after=1.6, + max_spikes_per_unit=50, + n_jobs=4, + chunk_size=30000, + progress_bar=True, + ) + + # This used to fail because of json + we = load_waveforms(wf_folder) + if __name__ == "__main__": - test_WaveformExtractor() + # test_WaveformExtractor() # test_extract_waveforms() - # test_sparsity() # test_portability() # test_recordingless() # test_compute_sparsity() + test_non_json_object() + diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 6881ab3ec5..53852bf319 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -159,11 +159,20 @@ def load_from_folder( else: rec_attributes["probegroup"] = None else: - try: - recording = load_extractor(folder / "recording.json", base_folder=folder) - rec_attributes = None - except: + recording = None + if (folder / "recording.json").exists(): + try: + recording = load_extractor(folder / "recording.json", base_folder=folder) + except: + pass + elif (folder / "recording.pickle").exists(): + try: + recording = load_extractor(folder / "recording.pickle") + except: + pass + if recording is None: raise Exception("The recording could not be loaded. You can use the `with_recording=False` argument") + rec_attributes = None if sorting is None: sorting = load_extractor(folder / "sorting.json", base_folder=folder) @@ -271,9 +280,16 @@ def create( else: relative_to = None - if recording.check_if_json_serializable(): + # if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump(folder / "recording.json", relative_to=relative_to) - if sorting.check_if_json_serializable(): + elif recording.check_serializablility("pickle"): + # In this case we loose the relative_to!! + # TODO make sure that we do not dump to pickle a NumpyRecording!!!!! 
+ recording.dump(folder / "recording.pickle") + + # if sorting.check_if_json_serializable(): + if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) else: warn( @@ -879,9 +895,11 @@ def save( (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") if self.has_recording(): - if self.recording.check_if_json_serializable(): + # if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) - if self.sorting.check_if_json_serializable(): + # if self.sorting.check_if_json_serializable(): + if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) else: warn( diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index e2ef6e6794..0054fb94d4 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -333,7 +333,8 @@ def correct_motion( ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - if recording.check_if_json_serializable(): + # if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump_to_json(folder / "recording.json") np.save(folder / "peaks.npy", peaks) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index c7581ba1e1..da20506965 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -137,7 +137,8 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo ) rec_file = output_folder / "spikeinterface_recording.json" - if recording.check_if_json_serializable(): + # if recording.check_if_json_serializable(): + if recording.check_serializablility("json"): recording.dump_to_json(rec_file, relative_to=output_folder) else: d = {"warning": "The recording is not serializable to json"} From 0842509422d8498fab0c506d6ed2839b4f4d0a74 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 12:29:11 +0200 Subject: [PATCH 07/25] my final version --- doc/how_to/load_matlab_data.rst | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index 0a8345b792..3e602012a1 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -55,16 +55,42 @@ Once you have your data in a binary format, you can seamlessly load it into Spik # Load the data using SpikeInterface recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, - num_channels=num_channels, dtype=dtype, gain_to_uV=1, offset_to_uV=0) + num_channels=num_channels, dtype=dtype) # Verify the shape of your data assert recording.get_traces().shape == (num_samples, num_channels) +This should be enough to get you started with loading your MATLAB data into SpikeInterface. You can use all the Spikeinterface machinery to process your data, including filtering, spike sorting, and more. + Common Pitfalls & Tips ---------------------- -1. **Data Shape**: Always ensure that your MATLAB data matrix's first dimension corresponds to samples/time and the second to channels. +1. 
**Data Shape**: Always ensure that your MATLAB data matrix's first dimension corresponds to samples/time and the second to channels. If the time happens to be in the second dimension, you can use `time_axis=1` as an argument in `si.read_binary()` to account for this. 2. **File Path**: Double-check the file path in Python to ensure you are pointing to the right directory. 3. **Data Type**: When moving data between MATLAB and Python, it's crucial to keep the data type consistent. In our example, we used `double` in MATLAB, which corresponds to `float64` in Python. 4. **Sampling Frequency**: Ensure you set the correct sampling frequency in Hz when loading data into SpikeInterface. 5. **Working on Python**: Matlab to python can feel like a big jump. If you are new to Python, we recommend checking out numpy's [Python for MATLAB Users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html) guide. + + +Using gains and offsets for integer data +---------------------------------------- + +A common technique used in raw formats is to store data as integer values, which provides a memory-efficient representation (i.e. lower ram) and use a gain and offset to convert it to float values that represent meaningful physical units. +In SpikeInterface this is done using the `gain_to_uV` and `offset_to_uV` parameters as the we handle traces in microvolts. Both values can be passed to `read_binary` when loading the data: + +.. code-block:: python + + sampling_frequency = 30_000.0 # in Hz, adjust as per your MATLAB dataset + num_channels = 384 # adjust as per your MATLAB dataset + dtype_int = 'int16' # adjust as per your MATLAB dataset + gain_to_uV = 0.195 # adjust as per your MATLAB dataset + offset_to_uV = 0 # adjust as per your MATLAB dataset + + recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, + num_channels=num_channels, dtype=dtype_int, + gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV) + + recording.get_traces(start) + + +This will equip your recording object with capabilities to convert the data to float values in uV using the `get_traces()` method with the `return_scaled` parameter set to True. From 1ead6a33e658bf5a0365d21506a90dd9bd32e67c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 12:45:06 +0200 Subject: [PATCH 08/25] final review --- doc/how_to/load_matlab_data.rst | 72 +++++++++++++++++---------------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index 3e602012a1..0a80f1fdf9 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -1,13 +1,13 @@ Exporting MATLAB Data to Binary & Loading in SpikeInterface =========================================================== -In this tutorial, we'll go through the process of exporting your data from MATLAB in a binary format and then loading it using SpikeInterface in Python. Let's break down the steps. +In this tutorial, we will walk through the process of exporting data from MATLAB in a binary format and subsequently loading it using SpikeInterface in Python. Exporting Data from MATLAB -------------------------- -First, ensure your data is structured correctly. The data matrix should be organized such that the first dimension corresponds to samples/time and the second dimension to channels. -In the following MATLAB code, we generate random data as an example and then write it to a binary file. +Begin by ensuring your data structure is correct. 
Organize your data matrix so that the first dimension corresponds to samples/time and the second to channels. +Here, we present a MATLAB code that creates a random dataset and writes it to a binary file as an illustration. .. code-block:: matlab @@ -25,72 +25,76 @@ In the following MATLAB code, we generate random data as an example and then wri .. note:: - In a real-world scenario, replace the random data generation with your actual data. + In your own script, replace the random data generation with your actual dataset. Loading Data in SpikeInterface ----------------------------- -This should produce a binary file called `your_data_as_a_binary.bin` in your current MATLAB directory. -You will need the complete path (i.e. its location on your computer) to load it in Python. +After executing the above MATLAB code, a binary file named `your_data_as_a_binary.bin` will be created in your MATLAB directory. To load this file in Python, you'll need its full path. -Once you have your data in a binary format, you can seamlessly load it into SpikeInterface using the following script: +Use the following Python script to load the binary data into SpikeInterface: .. code-block:: python import spikeinterface as si from pathlib import Path - # In linux or mac + # Define file path + # For Linux or macOS: file_path = Path("/The/Path/To/Your/Data/your_data_as_a_binary.bin") - # or for Windows + # For Windows: # file_path = Path(r"c:\path\to\your\data\your_data_as_a_binary.bin") - # Ensure the file exists - assert file_path.is_file(), f"Your path {file_path} is not a file, you probably have a typo or got the wrong path." + # Confirm file existence + assert file_path.is_file(), f"Error: {file_path} is not a valid file. Please check the path." - # Specify the parameters of your recording - sampling_frequency = 30_000.0 # in Hz, adjust as per your MATLAB dataset - num_channels = 384 # adjust as per your MATLAB dataset - dtype = "float64" # equivalent of MATLAB double + # Define recording parameters + sampling_frequency = 30_000.0 # Adjust according to your MATLAB dataset + num_channels = 384 # Adjust according to your MATLAB dataset + dtype = "float64" # MATLAB's double corresponds to Python's float64 - # Load the data using SpikeInterface + # Load data using SpikeInterface recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype) - # Verify the shape of your data - assert recording.get_traces().shape == (num_samples, num_channels) + # Confirm the data shape + assert recording.get_traces().shape == (numSamples, num_channels) -This should be enough to get you started with loading your MATLAB data into SpikeInterface. You can use all the Spikeinterface machinery to process your data, including filtering, spike sorting, and more. +Follow the steps above to seamlessly import your MATLAB data into SpikeInterface. Once loaded, you can harness the full power of SpikeInterface for data processing, including filtering, spike sorting, and more. Common Pitfalls & Tips ---------------------- -1. **Data Shape**: Always ensure that your MATLAB data matrix's first dimension corresponds to samples/time and the second to channels. If the time happens to be in the second dimension, you can use `time_axis=1` as an argument in `si.read_binary()` to account for this. -2. **File Path**: Double-check the file path in Python to ensure you are pointing to the right directory. -3. 
**Data Type**: When moving data between MATLAB and Python, it's crucial to keep the data type consistent. In our example, we used `double` in MATLAB, which corresponds to `float64` in Python. -4. **Sampling Frequency**: Ensure you set the correct sampling frequency in Hz when loading data into SpikeInterface. -5. **Working on Python**: Matlab to python can feel like a big jump. If you are new to Python, we recommend checking out numpy's [Python for MATLAB Users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html) guide. - +1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use `time_axis=1` in `si.read_binary()`. +2. **File Path**: Always double-check the Python file path. +3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's `double` is equivalent to nUMPY's `float64`. +4. **Sampling Frequency**: Set the appropriate sampling frequency in Hz for SpikeInterface. +5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing numpy's [Numpy for MATLAB Users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html) guide. Using gains and offsets for integer data ---------------------------------------- -A common technique used in raw formats is to store data as integer values, which provides a memory-efficient representation (i.e. lower ram) and use a gain and offset to convert it to float values that represent meaningful physical units. -In SpikeInterface this is done using the `gain_to_uV` and `offset_to_uV` parameters as the we handle traces in microvolts. Both values can be passed to `read_binary` when loading the data: +Raw data formats often store data as integer values for memory efficiency. To give these integers meaningful physical units, you can apply a gain and an offset. +In SpikeInterface, you can use the `gain_to_uV` and `offset_to_uV` parameters, since traces are handled in microvolts (uV). Both parameters can be integrated into the `read_binary` function. +If your data in MATLAB is stored as `int16`, and you know the gain and offset, you can use the following code to load the data: .. code-block:: python - sampling_frequency = 30_000.0 # in Hz, adjust as per your MATLAB dataset - num_channels = 384 # adjust as per your MATLAB dataset - dtype_int = 'int16' # adjust as per your MATLAB dataset - gain_to_uV = 0.195 # adjust as per your MATLAB dataset - offset_to_uV = 0 # adjust as per your MATLAB dataset + sampling_frequency = 30_000.0 # Adjust according to your MATLAB dataset + num_channels = 384 # Adjust according to your MATLAB dataset + dtype_int = 'int16' # Adjust according to your MATLAB dataset + gain_to_uV = 0.195 # Adjust according to your MATLAB dataset + offset_to_uV = 0 # Adjust according to your MATLAB dataset recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype_int, gain_to_uV=gain_to_uV, offset_to_uV=offset_to_uV) - recording.get_traces(start) + recording.get_traces(return_scaled=True) # Return traces in micro volts (uV) + +This will equip your recording object with capabilities to convert the data to float values in uV using the `get_traces()` method with the `return_scaled` parameter set to `True`. + +.. 
note:: -This will equip your recording object with capabilities to convert the data to float values in uV using the `get_traces()` method with the `return_scaled` parameter set to True. + The gain and offset parameters are usually format depend and you will need to find out the correct values for your data format. You can load your data without gain and offset but then the traces will be in integer values and not in uV. From e31978ce8355dda2d87a713c2495ec915b805f92 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 12:53:47 +0200 Subject: [PATCH 09/25] typo --- doc/how_to/load_matlab_data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index 0a80f1fdf9..ca543ba43a 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -67,7 +67,7 @@ Common Pitfalls & Tips 1. **Data Shape**: Make sure your MATLAB data matrix's first dimension is samples/time and the second is channels. If your time is in the second dimension, use `time_axis=1` in `si.read_binary()`. 2. **File Path**: Always double-check the Python file path. -3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's `double` is equivalent to nUMPY's `float64`. +3. **Data Type Consistency**: Ensure data types between MATLAB and Python are consistent. MATLAB's `double` is equivalent to Numpy's `float64`. 4. **Sampling Frequency**: Set the appropriate sampling frequency in Hz for SpikeInterface. 5. **Transition to Python**: Moving from MATLAB to Python can be challenging. For newcomers to Python, consider reviewing numpy's [Numpy for MATLAB Users](https://numpy.org/doc/stable/user/numpy-for-matlab-users.html) guide. From 5aba5e0f65532165488303203d7739e188fe6e0c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 12:57:44 +0200 Subject: [PATCH 10/25] Update doc/how_to/load_matlab_data.rst Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- doc/how_to/load_matlab_data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index ca543ba43a..7f90684701 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -97,4 +97,4 @@ This will equip your recording object with capabilities to convert the data to f .. note:: - The gain and offset parameters are usually format depend and you will need to find out the correct values for your data format. You can load your data without gain and offset but then the traces will be in integer values and not in uV. + The gain and offset parameters are usually format dependent and you will need to find out the correct values for your data format. You can load your data without gain and offset but then the traces will be in integer values and not in uV. 
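To make the preceding note concrete, here is a minimal sketch of the difference between raw and scaled traces. The gain/offset values are the hypothetical ones from the tutorial above, and `read_binary` / `get_traces(return_scaled=...)` are the same SpikeInterface calls the tutorial already uses:

.. code-block:: python

    import spikeinterface as si

    # int16 data registered with a gain and offset, as in the tutorial above
    recording = si.read_binary("your_data_as_a_binary.bin", sampling_frequency=30_000.0,
                               num_channels=384, dtype="int16",
                               gain_to_uV=0.195, offset_to_uV=0)

    raw_traces = recording.get_traces(return_scaled=False)    # integer ADC values
    scaled_traces = recording.get_traces(return_scaled=True)  # floats in uV
    # Elementwise, scaled = raw * gain_to_uV + offset_to_uV,
    # e.g. a raw value of 100 becomes 100 * 0.195 + 0 = 19.5 uV
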
From 3f4e182380995f56d458163356a70a813af6b146 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Sep 2023 14:01:50 +0200 Subject: [PATCH 11/25] More check and clean for check_if_serializable() --- src/spikeinterface/comparison/hybrid.py | 4 +- .../comparison/multicomparisons.py | 2 - src/spikeinterface/core/base.py | 46 +++++++++---------- src/spikeinterface/core/generate.py | 2 + src/spikeinterface/core/numpyextractors.py | 8 ++-- src/spikeinterface/core/old_api_utils.py | 2 - src/spikeinterface/core/tests/test_base.py | 7 +-- .../core/tests/test_waveform_extractor.py | 2 + src/spikeinterface/core/waveform_extractor.py | 18 +++++--- src/spikeinterface/preprocessing/motion.py | 1 - 10 files changed, 44 insertions(+), 48 deletions(-) diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py index c48ce70147..3b8e9e0a72 100644 --- a/src/spikeinterface/comparison/hybrid.py +++ b/src/spikeinterface/comparison/hybrid.py @@ -84,8 +84,8 @@ def __init__( ) # save injected sorting if necessary self.injected_sorting = injected_sorting - # if not self.injected_sorting.check_if_json_serializable(): if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) @@ -181,8 +181,8 @@ def __init__( self.injected_sorting = injected_sorting # save injected sorting if necessary - # if not self.injected_sorting.check_if_json_serializable(): if not self.injected_sorting.check_serializablility("json"): + # TODO later : also use pickle assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object" self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder) diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 3a7075905e..09a8c8aed1 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -182,7 +182,6 @@ def get_agreement_sorting(self, minimum_agreement_count=1, minimum_agreement_cou def save_to_folder(self, save_folder): for sorting in self.object_list: assert ( - # sorting.check_if_json_serializable() sorting.check_serializablility("json") ), "MultiSortingComparison.save_to_folder() need json serializable sortings" @@ -245,7 +244,6 @@ def __init__( BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids) - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = True diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index d87bd617c4..63cf8e894f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -484,11 +484,16 @@ def check_if_dumpable(self): for value in kwargs.values(): # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors if isinstance(value, BaseExtractor): - return value.check_if_dumpable() - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_if_dumpable() for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_if_dumpable() for k, v in value.items()]) + if not value.check_if_dumpable(): + return False + elif isinstance(value, 
list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_if_dumpable(): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_if_dumpable(): + return False return self._is_dumpable def check_serializablility(self, type="json"): @@ -496,11 +501,16 @@ def check_serializablility(self, type="json"): for value in kwargs.values(): # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors if isinstance(value, BaseExtractor): - return value.check_serializablility(type=type) - elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - return all([v.check_serializablility(type=type) for v in value]) - elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - return all([v.check_serializablility(type=type) for k, v in value.items()]) + if not value.check_serializablility(type=type): + return False + elif isinstance(value, list): + for v in value: + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False + elif isinstance(value, dict): + for v in value.values(): + if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type): + return False return self._serializablility[type] def check_if_json_serializable(self): @@ -513,21 +523,11 @@ def check_if_json_serializable(self): True if the object is json serializable, False otherwise. """ # we keep this for backward compatilibity or not ???? + # is this needed ??? I think no. return self.check_serializablility("json") - # kwargs = self._kwargs - # for value in kwargs.values(): - # # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors - # if isinstance(value, BaseExtractor): - # return value.check_if_json_serializable() - # elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor): - # return all([v.check_if_json_serializable() for v in value]) - # elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor): - # return all([v.check_if_json_serializable() for k, v in value.items()]) - # return self._is_json_serializable - def check_if_pickle_serializable(self): - # is this needed + # is this needed ??? I think no. return self.check_serializablility("pickle") @staticmethod @@ -596,7 +596,6 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non folder_metadata: str, Path, or None Folder with files containing additional information (e.g. probe in BaseRecording) and properties. 
""" - # assert self.check_if_json_serializable(), "The extractor is not json serializable" assert self.check_serializablility("json"), "The extractor is not json serializable" # Writing paths as relative_to requires recursively expanding the dict @@ -835,7 +834,6 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): # dump provenance provenance_file = folder / f"provenance.json" - # if self.check_if_json_serializable(): if self.check_serializablility("json"): self.dump(provenance_file) else: diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 706054c957..362b598b0b 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1056,6 +1056,8 @@ def __init__( dtype = parent_recording.dtype if parent_recording is not None else templates.dtype BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype) + # Important : self._serializablility is not change here because it will depend on the sorting parents itself. + n_units = len(sorting.unit_ids) assert len(templates) == n_units self.spike_vector = sorting.to_spike_vector() diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index f55b975ddb..5ef955a6eb 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -64,7 +64,6 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list" t_starts = [float(t_start) for t_start in t_starts] - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -129,9 +128,9 @@ def __init__(self, spikes, sampling_frequency, unit_ids): BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - # self._is_json_serializable = False self._serializablility["json"] = False - self._serializablility["pickle"] = False + # theorically this should be False but for simplicity make generators simples we still need this. 
+ self._serializablility["pickle"] = True if spikes.size == 0: nseg = 1 @@ -362,7 +361,7 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ BaseSorting.__init__(self, sampling_frequency, unit_ids) self._is_dumpable = True - # self._is_json_serializable = False + self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -523,7 +522,6 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore ) self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index 38fbef1547..a31edb0dd7 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -183,7 +183,6 @@ def __init__(self, oldapi_recording_extractor): # set _is_dumpable to False to use dumping mechanism of old extractor self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -271,7 +270,6 @@ def __init__(self, oldapi_sorting_extractor): self.add_sorting_segment(sorting_segment) self._is_dumpable = False - # self._is_json_serializable = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index 77a5d7d9bf..b716f6b1dd 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -46,16 +46,14 @@ def test_check_if_dumpable(): assert not extractor.check_if_dumpable() -def test_check_if_json_serializable(): +def test_check_if_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) # make a list of dumpable objects - # test_extractor._is_json_serializable = True test_extractor._serializablility["json"] = True extractors_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_json_serializable: print(extractor) - # assert extractor.check_if_json_serializable() assert extractor.check_serializablility("json") # make not dumpable @@ -64,10 +62,9 @@ def test_check_if_json_serializable(): extractors_not_json_serializable = make_nested_extractors(test_extractor) for extractor in extractors_not_json_serializable: print(extractor) - # assert not extractor.check_if_json_serializable() assert not extractor.check_serializablility("json") if __name__ == "__main__": test_check_if_dumpable() - test_check_if_json_serializable() + test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index f53b9cf18d..3972c9186c 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -517,6 +517,8 @@ def test_non_json_object(): num_units=5, ) + + print(recording.check_serializablility("pickle")) # recording is not save to keep it in memory sorting = sorting.save() diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py index 53852bf319..3de1429feb 100644 --- a/src/spikeinterface/core/waveform_extractor.py +++ b/src/spikeinterface/core/waveform_extractor.py @@ -280,17 +280,17 @@ def create( else: relative_to = None - # if recording.check_if_json_serializable(): if recording.check_serializablility("json"): recording.dump(folder / 
"recording.json", relative_to=relative_to) elif recording.check_serializablility("pickle"): # In this case we loose the relative_to!! - # TODO make sure that we do not dump to pickle a NumpyRecording!!!!! recording.dump(folder / "recording.pickle") - # if sorting.check_if_json_serializable(): if sorting.check_serializablility("json"): sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif sorting.check_serializablility("pickle"): + # In this case we loose the relative_to!! + sorting.dump(folder / "sorting.pickle") else: warn( "Sorting object is not dumpable, which might result in downstream errors for " @@ -895,12 +895,16 @@ def save( (folder / "params.json").write_text(json.dumps(check_json(self._params), indent=4), encoding="utf8") if self.has_recording(): - # if self.recording.check_if_json_serializable(): if self.recording.check_serializablility("json"): self.recording.dump(folder / "recording.json", relative_to=relative_to) - # if self.sorting.check_if_json_serializable(): + elif self.recording.check_serializablility("pickle"): + self.recording.dump(folder / "recording.pickle") + + if self.sorting.check_serializablility("json"): self.sorting.dump(folder / "sorting.json", relative_to=relative_to) + elif self.sorting.check_serializablility("pickle"): + self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to) else: warn( "Sorting object is not dumpable, which might result in downstream errors for " @@ -949,10 +953,10 @@ def save( # write metadata zarr_root.attrs["params"] = check_json(self._params) if self.has_recording(): - if self.recording.check_if_json_serializable(): + if self.recording.check_serializablility("json"): rec_dict = self.recording.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["recording"] = check_json(rec_dict) - if self.sorting.check_if_json_serializable(): + if self.sorting.check_serializablility("json"): sort_dict = self.sorting.to_dict(relative_to=relative_to, recursive=True) zarr_root.attrs["sorting"] = check_json(sort_dict) else: diff --git a/src/spikeinterface/preprocessing/motion.py b/src/spikeinterface/preprocessing/motion.py index 0054fb94d4..6ab1a9afce 100644 --- a/src/spikeinterface/preprocessing/motion.py +++ b/src/spikeinterface/preprocessing/motion.py @@ -333,7 +333,6 @@ def correct_motion( ) (folder / "parameters.json").write_text(json.dumps(parameters, indent=4, cls=SIJsonEncoder), encoding="utf8") (folder / "run_times.json").write_text(json.dumps(run_times, indent=4), encoding="utf8") - # if recording.check_if_json_serializable(): if recording.check_serializablility("json"): recording.dump_to_json(folder / "recording.json") From 615c5d9cd219e4016e7149f1ce170f043d507333 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 20 Sep 2023 14:19:46 +0200 Subject: [PATCH 12/25] Make pickle possible to dump in run sorter when json is not possible. 
---
 src/spikeinterface/sorters/basesorter.py      | 61 ++++++++++++-------
 .../sorters/external/herdingspikes.py         |  4 +-
 .../sorters/external/mountainsort4.py         |  4 +-
 .../sorters/external/mountainsort5.py         |  4 +-
 .../sorters/external/pykilosort.py            |  4 +-
 .../sorters/internal/spyking_circus2.py       |  5 +-
 .../sorters/internal/tridesclous2.py          |  4 +-
 src/spikeinterface/sorters/runsorter.py       | 15 ++++-
 8 files changed, 59 insertions(+), 42 deletions(-)

diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py
index da20506965..bbcde31eed 100644
--- a/src/spikeinterface/sorters/basesorter.py
+++ b/src/spikeinterface/sorters/basesorter.py
@@ -137,9 +137,10 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo
         )
 
         rec_file = output_folder / "spikeinterface_recording.json"
-        # if recording.check_if_json_serializable():
         if recording.check_serializablility("json"):
-            recording.dump_to_json(rec_file, relative_to=output_folder)
+            recording.dump(rec_file, relative_to=output_folder)
+        elif recording.check_serializablility("pickle"):
+            recording.dump(output_folder / "spikeinterface_recording.pickle")
         else:
             d = {"warning": "The recording is not serializable to json"}
             rec_file.write_text(json.dumps(d, indent=4), encoding="utf8")
@@ -186,6 +187,28 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose):
 
         return params
 
+    @classmethod
+    def load_recording_from_folder(cls, output_folder, with_warnings=False):
+
+        json_file = output_folder / "spikeinterface_recording.json"
+        pickle_file = output_folder / "spikeinterface_recording.pickle"
+
+
+        if json_file.exists():
+            with (json_file).open("r", encoding="utf8") as f:
+                recording_dict = json.load(f)
+            if "warning" in recording_dict.keys() and with_warnings:
+                warnings.warn(
+                    "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object."
+                )
+                recording = None
+            else:
+                recording = load_extractor(json_file, base_folder=output_folder)
+        elif pickle_file.exists():
+            recording = load_extractor(pickle_file)
+
+        return recording
+
     @classmethod
     def _dump_params(cls, recording, output_folder, sorter_params, verbose):
         with (output_folder / "spikeinterface_params.json").open(mode="w", encoding="utf8") as f:
@@ -272,7 +295,7 @@ def run_from_folder(cls, output_folder, raise_error, verbose):
         return run_time
 
     @classmethod
-    def get_result_from_folder(cls, output_folder):
+    def get_result_from_folder(cls, output_folder, register_recording=True, sorting_info=True):
         output_folder = Path(output_folder)
         sorter_output_folder = output_folder / "sorter_output"
         # check errors in log file
@@ -295,27 +318,21 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_
         # back-compatibility
         sorting = cls._get_result_from_folder(output_folder)
-        # register recording to Sorting object
-        # check if not json serializable
-        with (output_folder / "spikeinterface_recording.json").open("r", encoding="utf8") as f:
-            recording_dict = json.load(f)
-        if "warning" in recording_dict.keys():
-            warnings.warn(
-                "The recording that has been sorted is not JSON serializable: it cannot be registered to the sorting object."
- ) - else: - recording = load_extractor(output_folder / "spikeinterface_recording.json", base_folder=output_folder) + if register_recording: + # register recording to Sorting object + recording = cls.load_recording_from_folder( output_folder, with_warnings=False) if recording is not None: - # can be None when not dumpable sorting.register_recording(recording) - # set sorting info to Sorting object - with open(output_folder / "spikeinterface_recording.json", "r") as f: - rec_dict = json.load(f) - with open(output_folder / "spikeinterface_params.json", "r") as f: - params_dict = json.load(f) - with open(output_folder / "spikeinterface_log.json", "r") as f: - log_dict = json.load(f) - sorting.set_sorting_info(rec_dict, params_dict, log_dict) + + if sorting_info: + # set sorting info to Sorting object + with open(output_folder / "spikeinterface_recording.json", "r") as f: + rec_dict = json.load(f) + with open(output_folder / "spikeinterface_params.json", "r") as f: + params_dict = json.load(f) + with open(output_folder / "spikeinterface_log.json", "r") as f: + log_dict = json.load(f) + sorting.set_sorting_info(rec_dict, params_dict, log_dict) return sorting diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index a8d702ebe9..5180e6f1cc 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -147,9 +147,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): else: new_api = False - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) p = params diff --git a/src/spikeinterface/sorters/external/mountainsort4.py b/src/spikeinterface/sorters/external/mountainsort4.py index 69f97fd11c..f6f0b3eaeb 100644 --- a/src/spikeinterface/sorters/external/mountainsort4.py +++ b/src/spikeinterface/sorters/external/mountainsort4.py @@ -89,9 +89,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort4 - recording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/mountainsort5.py b/src/spikeinterface/sorters/external/mountainsort5.py index df6d276bf5..a88c59d688 100644 --- a/src/spikeinterface/sorters/external/mountainsort5.py +++ b/src/spikeinterface/sorters/external/mountainsort5.py @@ -115,9 +115,7 @@ def _setup_recording(cls, recording, sorter_output_folder, params, verbose): def _run_from_folder(cls, sorter_output_folder, params, verbose): import mountainsort5 as ms5 - recording: BaseRecording = load_extractor( - sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent - ) + recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False) # alias to params p = params diff --git a/src/spikeinterface/sorters/external/pykilosort.py b/src/spikeinterface/sorters/external/pykilosort.py index 2a41d793d5..1962d56206 100644 --- a/src/spikeinterface/sorters/external/pykilosort.py +++ b/src/spikeinterface/sorters/external/pykilosort.py @@ -148,9 +148,7 @@ def 
_setup_recording(cls, recording, sorter_output_folder, params, verbose):
 
     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
-        recording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
 
         if not recording.binary_compatible_with(time_axis=0, file_paths_lenght=1):
             # saved by setup recording
diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py
index 9de2762562..86cce1959b 100644
--- a/src/spikeinterface/sorters/internal/spyking_circus2.py
+++ b/src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -54,9 +54,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         job_kwargs["verbose"] = verbose
         job_kwargs["progress_bar"] = verbose
 
-        recording = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
+
         sampling_rate = recording.get_sampling_frequency()
         num_channels = recording.get_num_channels()
 
diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py
index 42f51d3a77..ed327e0f3c 100644
--- a/src/spikeinterface/sorters/internal/tridesclous2.py
+++ b/src/spikeinterface/sorters/internal/tridesclous2.py
@@ -49,9 +49,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
 
         import hdbscan
 
-        recording_raw = load_extractor(
-            sorter_output_folder.parent / "spikeinterface_recording.json", base_folder=sorter_output_folder.parent
-        )
+        recording_raw = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)
 
         num_chans = recording_raw.get_num_channels()
         sampling_frequency = recording_raw.get_sampling_frequency()
diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index 6e6ccc0358..e930ec7f79 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -624,10 +624,20 @@ def run_sorter_container(
     )
 
 
-def read_sorter_folder(output_folder, raise_error=True):
+def read_sorter_folder(output_folder, register_recording=True, sorting_info=True, raise_error=True):
     """
     Load a sorting object from a spike sorting output folder.
     The 'output_folder' must contain a valid 'spikeinterface_log.json' file
+
+
+    Parameters
+    ----------
+    output_folder: Path or str
+        The sorter folder
+    register_recording: bool, default: True
+        Attach recording (when json or pickle) to the sorting
+    sorting_info: bool, default: True
+        Attach sorting info to the sorting.
""" output_folder = Path(output_folder) log_file = output_folder / "spikeinterface_log.json" @@ -647,7 +657,8 @@ def read_sorter_folder(output_folder, raise_error=True): sorter_name = log["sorter_name"] SorterClass = sorter_dict[sorter_name] - sorting = SorterClass.get_result_from_folder(output_folder) + sorting = SorterClass.get_result_from_folder(output_folder, register_recording=register_recording, + sorting_info=sorting_info) return sorting From b231e2dade552413bdd68e18aad95881a047f4cb Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 14:47:14 +0200 Subject: [PATCH 13/25] correction --- doc/how_to/load_matlab_data.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index 7f90684701..0186ecf72b 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -57,8 +57,8 @@ Use the following Python script to load the binary data into SpikeInterface: recording = si.read_binary(file_path, sampling_frequency=sampling_frequency, num_channels=num_channels, dtype=dtype) - # Confirm the data shape - assert recording.get_traces().shape == (numSamples, num_channels) + # Confirm that the data was loaded correctly by comparing the data shapes and see they match the MATLAB data + print(recording.get_num_frames(), recording.get_num_channels()) Follow the steps above to seamlessly import your MATLAB data into SpikeInterface. Once loaded, you can harness the full power of SpikeInterface for data processing, including filtering, spike sorting, and more. From fb7681520e74a01be0fd4e56740936a4f6de4e25 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Wed, 20 Sep 2023 16:40:43 +0200 Subject: [PATCH 14/25] Update doc/how_to/load_matlab_data.rst Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- doc/how_to/load_matlab_data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index 0186ecf72b..3943fbd30f 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -28,7 +28,7 @@ Here, we present a MATLAB code that creates a random dataset and writes it to a In your own script, replace the random data generation with your actual dataset. Loading Data in SpikeInterface ------------------------------ +------------------------------ After executing the above MATLAB code, a binary file named `your_data_as_a_binary.bin` will be created in your MATLAB directory. To load this file in Python, you'll need its full path. From 9ba6fc6cbf0b0fd3d7bfa0b22108c48a05770b67 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 21 Sep 2023 14:01:25 +0200 Subject: [PATCH 15/25] Update doc/how_to/load_matlab_data.rst Co-authored-by: Alessio Buccino --- doc/how_to/load_matlab_data.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/how_to/load_matlab_data.rst b/doc/how_to/load_matlab_data.rst index 3943fbd30f..aaca718096 100644 --- a/doc/how_to/load_matlab_data.rst +++ b/doc/how_to/load_matlab_data.rst @@ -93,7 +93,7 @@ If your data in MATLAB is stored as `int16`, and you know the gain and offset, y recording.get_traces(return_scaled=True) # Return traces in micro volts (uV) -This will equip your recording object with capabilities to convert the data to float values in uV using the `get_traces()` method with the `return_scaled` parameter set to `True`. 
+This will equip your recording object with capabilities to convert the data to float values in uV using the :code:`get_traces()` method with the :code:`return_scaled` parameter set to :code:`True`. .. note:: From c33f7233b54ccce797a903f8f495d8dbb30f0b2a Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 22 Sep 2023 10:21:53 -0400 Subject: [PATCH 16/25] test reorganize folders --- doc/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index 15cb65d46a..b120393911 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -118,11 +118,11 @@ 'examples_dirs': ['../examples/modules_gallery'], 'gallery_dirs': ['modules_gallery', ], # path where to save gallery generated examples 'subsection_order': ExplicitOrder([ - '../examples/modules_gallery/core/', '../examples/modules_gallery/extractors/', '../examples/modules_gallery/qualitymetrics', '../examples/modules_gallery/comparison', '../examples/modules_gallery/widgets', + '../examples/modules_gallery/core/', ]), 'within_subsection_order': FileNameSortKey, 'ignore_pattern': '/generate_', From f2188266647d7faf721d89089b6f9c0bd1d9e637 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Fri, 22 Sep 2023 16:22:01 +0200 Subject: [PATCH 17/25] feedback from Ramon --- src/spikeinterface/core/generate.py | 4 ++-- src/spikeinterface/core/tests/test_waveform_extractor.py | 2 -- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 362b598b0b..05d63f3c8d 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1433,7 +1433,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) - recording.set_property("gain_to_uV", np.ones(num_channels)) - recording.set_property("offset_to_uV", np.zeros(num_channels)) + recording.set_channel_gains(1.) + recording.set_channel_offsets(0.) 
return recording, sorting diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index 3972c9186c..f53b9cf18d 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -517,8 +517,6 @@ def test_non_json_object(): num_units=5, ) - - print(recording.check_serializablility("pickle")) # recording is not save to keep it in memory sorting = sorting.save() From 96be72e5ac05ec7f3bd63f866783b733fca22ab8 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 22 Sep 2023 10:43:48 -0400 Subject: [PATCH 18/25] try removing extra slash from some sections --- doc/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index b120393911..eb8bee5f9a 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -118,11 +118,11 @@ 'examples_dirs': ['../examples/modules_gallery'], 'gallery_dirs': ['modules_gallery', ], # path where to save gallery generated examples 'subsection_order': ExplicitOrder([ - '../examples/modules_gallery/extractors/', + '../examples/modules_gallery/core', + '../examples/modules_gallery/extractors', '../examples/modules_gallery/qualitymetrics', '../examples/modules_gallery/comparison', '../examples/modules_gallery/widgets', - '../examples/modules_gallery/core/', ]), 'within_subsection_order': FileNameSortKey, 'ignore_pattern': '/generate_', From c4fec2f135f5166bc3dfe4ebbd1a3ccdff8ddd63 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:00:17 -0400 Subject: [PATCH 19/25] try setting nested_sections false --- doc/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/conf.py b/doc/conf.py index eb8bee5f9a..13d1ef4e65 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -126,6 +126,7 @@ ]), 'within_subsection_order': FileNameSortKey, 'ignore_pattern': '/generate_', + 'nested_sections': False, } intersphinx_mapping = { From 3c3451ecf6452419ebf83dd6dd2d9454ba7e6419 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 10:00:35 +0200 Subject: [PATCH 20/25] replace is_dumpable() by a more explicit naming : is_memory_serializable() --- src/spikeinterface/core/base.py | 53 +++++++++---------- src/spikeinterface/core/job_tools.py | 2 +- src/spikeinterface/core/numpyextractors.py | 6 +-- src/spikeinterface/core/old_api_utils.py | 6 +-- src/spikeinterface/core/tests/test_base.py | 10 ++-- .../core/tests/test_jsonification.py | 3 +- .../postprocessing/spike_amplitudes.py | 2 +- 7 files changed, 38 insertions(+), 44 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 63cf8e894f..3b8765a398 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -57,9 +57,7 @@ def __init__(self, main_ids: Sequence) -> None: # * number of units for sorting self._properties = {} - self._is_dumpable = True - # self._is_json_serializable = True - self._serializablility = {'json': True, 'pickle': True} + self._serializablility = {'memory': True, 'json': True, 'pickle': True} # extractor specific list of pip extra requirements self.extra_requirements = [] @@ -472,31 +470,8 @@ def clone(self) -> "BaseExtractor": clone = BaseExtractor.from_dict(d) return clone - def check_if_dumpable(self): - """Check if the object is dumpable, including nested objects. - - Returns - ------- - bool - True if the object is dumpable, False otherwise. 
-        """
-        kwargs = self._kwargs
-        for value in kwargs.values():
-            # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
-            if isinstance(value, BaseExtractor):
-                if not value.check_if_dumpable():
-                    return False
-            elif isinstance(value, list):
-                for v in value:
-                    if isinstance(v, BaseExtractor) and not v.check_if_dumpable():
-                        return False
-            elif isinstance(value, dict):
-                for v in value.values():
-                    if isinstance(v, BaseExtractor) and not v.check_if_dumpable():
-                        return False
-        return self._is_dumpable
 
-    def check_serializablility(self, type="json"):
+    def check_serializablility(self, type):
         kwargs = self._kwargs
         for value in kwargs.values():
             # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
@@ -512,6 +487,26 @@ def check_serializablility(self, type):
                 if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
                     return False
         return self._serializablility[type]
+
+
+    def check_if_dumpable(self):
+        warnings.warn(
+            "check_if_dumpable() is replaced by is_memory_serializable()", DeprecationWarning, stacklevel=2
+        )
+        return self.check_serializablility("memory")
+
+    def is_memory_serializable(self):
+        """
+        Check if the object is serializable to memory with pickle, including nested objects.
+
+        Returns
+        -------
+        bool
+            True if the object is json serializable, False otherwise.
+        """
+        return self.check_serializablility("memory")
+
+
 
     def check_if_json_serializable(self):
         """
@@ -636,7 +631,7 @@ def dump_to_pickle(
         folder_metadata: str, Path, or None
             Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
         """
-        assert self.check_if_dumpable(), "The extractor is not dumpable"
+        assert self.check_if_pickle_serializable(), "The extractor is not dumpable"
 
         dump_dict = self.to_dict(
             include_annotations=True,
@@ -931,7 +926,7 @@ def save_to_zarr(
 
         zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options)
 
-        if self.check_if_dumpable():
+        if self.check_if_json_serializable():
             zarr_root.attrs["provenance"] = check_json(self.to_dict())
         else:
             zarr_root.attrs["provenance"] = None
diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index c0ee77d2fd..0535872ca6 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -167,7 +167,7 @@ def ensure_n_jobs(recording, n_jobs=1):
             print(f"Python {sys.version} does not support parallel processing")
             n_jobs = 1
 
-    if not recording.check_if_dumpable():
+    if not recording.is_memory_serializable():
         if n_jobs != 1:
             raise RuntimeError(
                 "Recording is not dumpable and can't be processed in parallel. "
diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py
index 5ef955a6eb..d09016c8f1 100644
--- a/src/spikeinterface/core/numpyextractors.py
+++ b/src/spikeinterface/core/numpyextractors.py
@@ -127,7 +127,7 @@ def __init__(self, spikes, sampling_frequency, unit_ids):
         """ """
         BaseSorting.__init__(self, sampling_frequency, unit_ids)
 
-        self._is_dumpable = True
+        self._serializablility["memory"] = True
         self._serializablility["json"] = False
        # theoretically this should be False, but to keep generators simple we still need this.
self._serializablility["pickle"] = True @@ -360,8 +360,8 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_ assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting" BaseSorting.__init__(self, sampling_frequency, unit_ids) - self._is_dumpable = True + self._serializablility["memory"] = True self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -521,7 +521,7 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore dtype=dtype, ) - self._is_dumpable = False + self._serializablility["memory"] = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/old_api_utils.py b/src/spikeinterface/core/old_api_utils.py index a31edb0dd7..879700cc15 100644 --- a/src/spikeinterface/core/old_api_utils.py +++ b/src/spikeinterface/core/old_api_utils.py @@ -181,8 +181,8 @@ def __init__(self, oldapi_recording_extractor): dtype=oldapi_recording_extractor.get_dtype(return_scaled=False), ) - # set _is_dumpable to False to use dumping mechanism of old extractor - self._is_dumpable = False + # set to False to use dumping mechanism of old extractor + self._serializablility["memory"] = False self._serializablility["json"] = False self._serializablility["pickle"] = False @@ -269,7 +269,7 @@ def __init__(self, oldapi_sorting_extractor): sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor) self.add_sorting_segment(sorting_segment) - self._is_dumpable = False + self._serializablility["memory"] = False self._serializablility["json"] = False self._serializablility["pickle"] = False diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index b716f6b1dd..28dbd166ec 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -31,19 +31,19 @@ def make_nested_extractors(extractor): ) -def test_check_if_dumpable(): +def test_is_memory_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) # make a list of dumpable objects extractors_dumpable = make_nested_extractors(test_extractor) for extractor in extractors_dumpable: - assert extractor.check_if_dumpable() + assert extractor.is_memory_serializable() # make not dumpable - test_extractor._is_dumpable = False + test_extractor._serializablility["memory"] = False extractors_not_dumpable = make_nested_extractors(test_extractor) for extractor in extractors_not_dumpable: - assert not extractor.check_if_dumpable() + assert not extractor.is_memory_serializable() def test_check_if_serializable(): @@ -66,5 +66,5 @@ def test_check_if_serializable(): if __name__ == "__main__": - test_check_if_dumpable() + test_is_memory_serializable() test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index 8572cda23e..026e676966 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -144,8 +144,7 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract BaseExtractor.__init__(self, main_ids=['1', '2']) # this already the case by default - self._is_dumpable = True - # self._is_json_serializable = True + self._serializablility["memory"] = True self._serializablility["json"] = True self._serializablility["pickle"] = True diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py 
b/src/spikeinterface/postprocessing/spike_amplitudes.py
index 38cb714d59..aa99f7fc5e 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -75,7 +75,7 @@ def _run(self, **job_kwargs):
         n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None))
         if n_jobs != 1:
             # TODO: avoid dumping sorting and use spike vector and peak pipeline instead
-            assert sorting.check_if_dumpable(), (
+            assert sorting.is_memory_serializable(), (
                 "The sorting object is not dumpable and cannot be processed in parallel. You can use the "
                 "`sorting.save()` function to make it dumpable"
             )

From 9d3dceaacc77158487c47972a2d949a71bb3c65a Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 27 Sep 2023 08:52:23 +0000
Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/comparison/multicomparisons.py |  4 ++--
 src/spikeinterface/core/base.py                   | 10 ++--------
 src/spikeinterface/core/generate.py               |  4 ++--
 src/spikeinterface/core/numpyextractors.py        |  2 +-
 .../core/tests/test_jsonification.py              |  8 ++++----
 .../core/tests/test_waveform_extractor.py         | 15 ++++++++++-----
 src/spikeinterface/core/waveform_extractor.py     |  1 -
 src/spikeinterface/sorters/basesorter.py          |  6 ++----
 src/spikeinterface/sorters/runsorter.py           |  7 ++++---
 9 files changed, 27 insertions(+), 30 deletions(-)

diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py
index 6fe474822b..f44e14c4c4 100644
--- a/src/spikeinterface/comparison/multicomparisons.py
+++ b/src/spikeinterface/comparison/multicomparisons.py
@@ -189,8 +189,8 @@ def save_to_folder(self, save_folder):
             stacklevel=2,
         )
         for sorting in self.object_list:
-            assert (
-                sorting.check_serializablility("json")
+            assert sorting.check_serializablility(
+                "json"
             ), "MultiSortingComparison.save_to_folder() need json serializable sortings"
 
         save_folder = Path(save_folder)
diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 3b8765a398..6e91cedcb5 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -57,7 +57,7 @@ def __init__(self, main_ids: Sequence) -> None:
         # * number of units for sorting
         self._properties = {}
 
-        self._serializablility = {'memory': True, 'json': True, 'pickle': True}
+        self._serializablility = {"memory": True, "json": True, "pickle": True}
 
         # extractor specific list of pip extra requirements
         self.extra_requirements = []
@@ -470,7 +470,6 @@ def clone(self) -> "BaseExtractor":
         clone = BaseExtractor.from_dict(d)
         return clone
 
-
     def check_serializablility(self, type):
         kwargs = self._kwargs
         for value in kwargs.values():
@@ -488,11 +487,8 @@ def check_serializablility(self, type):
                 return False
         return self._serializablility[type]
 
-    def check_if_dumpable(self):
-        warnings.warn(
-            "check_if_dumpable() is replaced by is_memory_serializable()", DeprecationWarning, stacklevel=2
-        )
+    def check_if_dumpable(self):
+        warnings.warn("check_if_dumpable() is replaced by is_memory_serializable()", DeprecationWarning, stacklevel=2)
         return self.check_serializablility("memory")
 
     def is_memory_serializable(self):
@@ -506,8 +502,6 @@ def is_memory_serializable(self):
         """
         return self.check_serializablility("memory")
 
-
-
     def check_if_json_serializable(self):
         """
         Check if the object is json serializable, including nested objects.
diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 05d63f3c8d..eeb1e8af60 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1433,7 +1433,7 @@ def generate_ground_truth_recording( ) recording.annotate(is_filtered=True) recording.set_probe(probe, in_place=True) - recording.set_channel_gains(1.) - recording.set_channel_offsets(0.) + recording.set_channel_gains(1.0) + recording.set_channel_offsets(0.0) return recording, sorting diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index d09016c8f1..3d7ec6cd1a 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -523,7 +523,7 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore self._serializablility["memory"] = False self._serializablility["json"] = False - self._serializablility["pickle"] = False + self._serializablility["pickle"] = False for snippets, spikesframes in zip(snippets_list, spikesframes_list): snp_segment = NumpySnippetsSegment(snippets, spikesframes) diff --git a/src/spikeinterface/core/tests/test_jsonification.py b/src/spikeinterface/core/tests/test_jsonification.py index 026e676966..1c491bd7a6 100644 --- a/src/spikeinterface/core/tests/test_jsonification.py +++ b/src/spikeinterface/core/tests/test_jsonification.py @@ -142,11 +142,11 @@ def __init__(self, attribute, other_extractor=None, extractor_list=None, extract self.extractor_list = extractor_list self.extractor_dict = extractor_dict - BaseExtractor.__init__(self, main_ids=['1', '2']) + BaseExtractor.__init__(self, main_ids=["1", "2"]) # this already the case by default self._serializablility["memory"] = True self._serializablility["json"] = True - self._serializablility["pickle"] = True + self._serializablility["pickle"] = True self._kwargs = { "attribute": attribute, @@ -199,6 +199,6 @@ def test_encoding_numpy_scalars_within_nested_extractors_dict(nested_extractor_d json.dumps(nested_extractor_dict, cls=SIJsonEncoder) -if __name__ == '__main__': +if __name__ == "__main__": nested_extractor = nested_extractor() - test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_) \ No newline at end of file + test_encoding_numpy_scalars_within_nested_extractors(nested_extractor_) diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py index f53b9cf18d..12dac52d43 100644 --- a/src/spikeinterface/core/tests/test_waveform_extractor.py +++ b/src/spikeinterface/core/tests/test_waveform_extractor.py @@ -6,7 +6,13 @@ import zarr -from spikeinterface.core import generate_recording, generate_sorting, NumpySorting, ChannelSparsity, generate_ground_truth_recording +from spikeinterface.core import ( + generate_recording, + generate_sorting, + NumpySorting, + ChannelSparsity, + generate_ground_truth_recording, +) from spikeinterface import WaveformExtractor, BaseRecording, extract_waveforms, load_waveforms from spikeinterface.core.waveform_extractor import precompute_sparsity @@ -509,14 +515,15 @@ def test_compute_sparsity(): ) print(sparsity) + def test_non_json_object(): recording, sorting = generate_ground_truth_recording( durations=[30, 40], sampling_frequency=30000.0, num_channels=32, num_units=5, - ) - + ) + # recording is not save to keep it in memory sorting = sorting.save() @@ -524,7 +531,6 @@ def test_non_json_object(): if wf_folder.is_dir(): shutil.rmtree(wf_folder) - we = 
extract_waveforms(
         recording,
         sorting,
@@ -551,4 +557,3 @@ def test_non_json_object():
     # test_recordingless()
     # test_compute_sparsity()
     test_non_json_object()
-
diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py
index 3de1429feb..cd8a62f5bc 100644
--- a/src/spikeinterface/core/waveform_extractor.py
+++ b/src/spikeinterface/core/waveform_extractor.py
@@ -900,7 +900,6 @@ def save(
             elif self.recording.check_serializablility("pickle"):
                 self.recording.dump(folder / "recording.pickle")
 
-
         if self.sorting.check_serializablility("json"):
             self.sorting.dump(folder / "sorting.json", relative_to=relative_to)
         elif self.sorting.check_serializablility("pickle"):
diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py
index bbcde31eed..8d87558191 100644
--- a/src/spikeinterface/sorters/basesorter.py
+++ b/src/spikeinterface/sorters/basesorter.py
@@ -189,11 +189,9 @@ def set_params_to_folder(cls, recording, output_folder, new_params, verbose):
 
     @classmethod
     def load_recording_from_folder(cls, output_folder, with_warnings=False):
-
         json_file = output_folder / "spikeinterface_recording.json"
         pickle_file = output_folder / "spikeinterface_recording.pickle"
 
-
         if json_file.exists():
             with (json_file).open("r", encoding="utf8") as f:
                 recording_dict = json.load(f)
@@ -206,7 +204,7 @@ def load_recording_from_folder(cls, output_folder, with_warnings=False):
             recording = load_extractor(json_file, base_folder=output_folder)
         elif pickle_file.exists():
             recording = load_extractor(pickle_file)
-
+
         return recording
 
     @classmethod
@@ -320,7 +318,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_
 
         if register_recording:
             # register recording to Sorting object
-            recording = cls.load_recording_from_folder( output_folder, with_warnings=False)
+            recording = cls.load_recording_from_folder(output_folder, with_warnings=False)
             if recording is not None:
                 sorting.register_recording(recording)
 
diff --git a/src/spikeinterface/sorters/runsorter.py b/src/spikeinterface/sorters/runsorter.py
index e930ec7f79..bd5667b15f 100644
--- a/src/spikeinterface/sorters/runsorter.py
+++ b/src/spikeinterface/sorters/runsorter.py
@@ -629,7 +629,7 @@ def read_sorter_folder(output_folder, register_recording=True, sorting_info=True
     Load a sorting object from a spike sorting output folder.
    The 'output_folder' must contain a valid 'spikeinterface_log.json' file
-
+
 
     Parameters
     ----------
     output_folder: Path or str
         The sorter folder
     register_recording: bool, default: True
         Attach recording (when json or pickle) to the sorting
     sorting_info: bool, default: True
         Attach sorting info to the sorting.
     """
     output_folder = Path(output_folder)
     log_file = output_folder / "spikeinterface_log.json"
@@ -657,8 +657,9 @@ def read_sorter_folder(output_folder, register_recording=True, sorting_info=True
 
     sorter_name = log["sorter_name"]
     SorterClass = sorter_dict[sorter_name]
-    sorting = SorterClass.get_result_from_folder(output_folder, register_recording=register_recording,
-        sorting_info=sorting_info)
+    sorting = SorterClass.get_result_from_folder(
+        output_folder, register_recording=register_recording, sorting_info=sorting_info
+    )
 
     return sorting
 

From 7329927cfb3035d764648a2175d617aa8999c67b Mon Sep 17 00:00:00 2001
From: Samuel Garcia
Date: Wed, 27 Sep 2023 10:54:57 +0200
Subject: [PATCH 22/25] rename to check_if_memory_serializable
---
 src/spikeinterface/core/base.py                       | 6 +-----
 src/spikeinterface/core/job_tools.py                  | 2 +-
 src/spikeinterface/core/tests/test_base.py            | 8 ++++----
 src/spikeinterface/postprocessing/spike_amplitudes.py | 2 +-
 4 files changed, 7 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 6e91cedcb5..b1b5065339 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -487,11 +487,7 @@ def check_serializablility(self, type):
                 return False
         return self._serializablility[type]
 
-    def check_if_dumpable(self):
-        warnings.warn("check_if_dumpable() is replaced by is_memory_serializable()", DeprecationWarning, stacklevel=2)
-        return self.check_serializablility("memory")
-
-    def is_memory_serializable(self):
+    def check_if_memory_serializable(self):
         """
         Check if the object is serializable to memory with pickle, including nested objects.
 
diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index 0535872ca6..9369ad0b61 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -167,7 +167,7 @@ def ensure_n_jobs(recording, n_jobs=1):
         print(f"Python {sys.version} does not support parallel processing")
         n_jobs = 1
 
-    if not recording.is_memory_serializable():
+    if not recording.check_if_memory_serializable():
         if n_jobs != 1:
             raise RuntimeError(
                 "Recording is not dumpable and can't be processed in parallel. 
" diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index 28dbd166ec..8d0907c700 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -31,19 +31,19 @@ def make_nested_extractors(extractor): ) -def test_is_memory_serializable(): +def test_check_if_memory_serializable(): test_extractor = generate_recording(seed=0, durations=[2]) # make a list of dumpable objects extractors_dumpable = make_nested_extractors(test_extractor) for extractor in extractors_dumpable: - assert extractor.is_memory_serializable() + assert extractor.check_if_memory_serializable() # make not dumpable test_extractor._serializablility["memory"] = False extractors_not_dumpable = make_nested_extractors(test_extractor) for extractor in extractors_not_dumpable: - assert not extractor.is_memory_serializable() + assert not extractor.check_if_memory_serializable() def test_check_if_serializable(): @@ -66,5 +66,5 @@ def test_check_if_serializable(): if __name__ == "__main__": - test_is_memory_serializable() + test_check_if_memory_serializable() test_check_if_serializable() diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index aa99f7fc5e..9eb5a815d4 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -75,7 +75,7 @@ def _run(self, **job_kwargs): n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None)) if n_jobs != 1: # TODO: avoid dumping sorting and use spike vector and peak pipeline instead - assert sorting.is_memory_serializable(), ( + assert sorting.check_if_memory_serializable(), ( "The sorting object is not dumpable and cannot be processed in parallel. You can use the " "`sorting.save()` function to make it dumpable" ) From b9c6a38e99430fc7b734e0751871e6d08eb5aea1 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 10:56:28 +0200 Subject: [PATCH 23/25] oups --- src/spikeinterface/core/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index b1b5065339..e3b88588e2 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -494,7 +494,7 @@ def check_if_memory_serializable(self): Returns ------- bool - True if the object is json serializable, False otherwise. + True if the object is memory serializable, False otherwise. """ return self.check_serializablility("memory") From 331379a3f441e2691eb15985b60254fcc9e3f887 Mon Sep 17 00:00:00 2001 From: Samuel Garcia Date: Wed, 27 Sep 2023 11:13:29 +0200 Subject: [PATCH 24/25] Remove "dumpable" naming also in doc and warnings. --- doc/modules/core.rst | 3 +-- src/spikeinterface/comparison/hybrid.py | 4 ++-- src/spikeinterface/core/base.py | 8 ++++---- src/spikeinterface/core/job_tools.py | 4 ++-- src/spikeinterface/core/tests/test_base.py | 17 ++++++++--------- .../core/tests/test_core_tools.py | 1 - src/spikeinterface/core/tests/test_job_tools.py | 6 +++--- .../core/tests/test_waveform_extractor.py | 2 +- src/spikeinterface/core/waveform_extractor.py | 15 ++++++++------- .../postprocessing/spike_amplitudes.py | 6 ------ .../sorters/tests/test_launcher.py | 2 +- 11 files changed, 30 insertions(+), 38 deletions(-) diff --git a/doc/modules/core.rst b/doc/modules/core.rst index fdc4d71fe7..976a82a4a3 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -547,8 +547,7 @@ workflow. 
 In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`,
 :py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and :py:class:`~spikeinterface.core.NumpySnippets`.
 These objects behave exactly like normal SpikeInterface objects,
-but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported.
-In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`).
+but they are not bound to a file.
 
 Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting` which is very similar to
 :py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory which is useful for
diff --git a/src/spikeinterface/comparison/hybrid.py b/src/spikeinterface/comparison/hybrid.py
index 3b8e9e0a72..e0c98cd772 100644
--- a/src/spikeinterface/comparison/hybrid.py
+++ b/src/spikeinterface/comparison/hybrid.py
@@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording):
         The refractory period of the injected spike train (in ms).
     injected_sorting_folder: str | Path | None
         If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.
 
     Returns
     -------
@@ -138,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording):
         this refractory period.
     injected_sorting_folder: str | Path | None
         If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.
 
     Returns
     -------
diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index e3b88588e2..73f8619348 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -621,7 +621,7 @@ def dump_to_pickle(
         folder_metadata: str, Path, or None
             Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
""" - assert self.check_if_pickle_serializable(), "The extractor is not dumpable" + assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle" dump_dict = self.to_dict( include_annotations=True, @@ -658,8 +658,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo d = pickle.load(f) else: raise ValueError(f"Impossible to load {file_path}") - if "warning" in d and "not dumpable" in d["warning"]: - print("The extractor was not dumpable") + if "warning" in d: + print("The extractor was not serializable to file") return None extractor = BaseExtractor.from_dict(d, base_folder=base_folder) return extractor @@ -822,7 +822,7 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs): if self.check_serializablility("json"): self.dump(provenance_file) else: - provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8") + provenance_file.write_text(json.dumps({"warning": "the provenace is not json serializable!!!"}), encoding="utf8") self.save_metadata_to_folder(folder) diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py index 9369ad0b61..84ee502c14 100644 --- a/src/spikeinterface/core/job_tools.py +++ b/src/spikeinterface/core/job_tools.py @@ -170,8 +170,8 @@ def ensure_n_jobs(recording, n_jobs=1): if not recording.check_if_memory_serializable(): if n_jobs != 1: raise RuntimeError( - "Recording is not dumpable and can't be processed in parallel. " - "You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1." + "Recording is not serializable to memory and can't be processed in parallel. " + "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1." 
            )
 
     return n_jobs
diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py
index 8d0907c700..a944be3da0 100644
--- a/src/spikeinterface/core/tests/test_base.py
+++ b/src/spikeinterface/core/tests/test_base.py
@@ -34,30 +34,29 @@ def make_nested_extractors(extractor):
 def test_check_if_memory_serializable():
     test_extractor = generate_recording(seed=0, durations=[2])
 
-    # make a list of dumpable objects
-    extractors_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_dumpable:
+    # make a list of memory serializable objects
+    extractors_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_mem_serializable:
         assert extractor.check_if_memory_serializable()
 
-    # make not dumpable
+    # make not memory serializable
     test_extractor._serializablility["memory"] = False
-    extractors_not_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_not_dumpable:
+    extractors_not_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_not_mem_serializable:
         assert not extractor.check_if_memory_serializable()
 
 
 def test_check_if_serializable():
     test_extractor = generate_recording(seed=0, durations=[2])
 
-    # make a list of dumpable objects
+    # make a list of json serializable objects
     test_extractor._serializablility["json"] = True
     extractors_json_serializable = make_nested_extractors(test_extractor)
     for extractor in extractors_json_serializable:
         print(extractor)
         assert extractor.check_serializablility("json")
 
-    # make not dumpable
-    # test_extractor._is_json_serializable = False
+    # make a list of not json serializable objects
     test_extractor._serializablility["json"] = False
     extractors_not_json_serializable = make_nested_extractors(test_extractor)
     for extractor in extractors_not_json_serializable:
diff --git a/src/spikeinterface/core/tests/test_core_tools.py b/src/spikeinterface/core/tests/test_core_tools.py
index a3cd0caa92..223b2a8a3a 100644
--- a/src/spikeinterface/core/tests/test_core_tools.py
+++ b/src/spikeinterface/core/tests/test_core_tools.py
@@ -142,7 +142,6 @@ def test_write_memory_recording():
     recording = NoiseGeneratorRecording(
         num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
     )
-    # make dumpable
     recording = recording.save()
 
     # write with loop
diff --git a/src/spikeinterface/core/tests/test_job_tools.py b/src/spikeinterface/core/tests/test_job_tools.py
index 7d7af6025b..a904e4dd32 100644
--- a/src/spikeinterface/core/tests/test_job_tools.py
+++ b/src/spikeinterface/core/tests/test_job_tools.py
@@ -36,7 +36,7 @@ def test_ensure_n_jobs():
     n_jobs = ensure_n_jobs(recording, n_jobs=1)
     assert n_jobs == 1
 
-    # dumpable
+    # check serializable
    n_jobs = ensure_n_jobs(recording.save(), n_jobs=-1)
    assert n_jobs > 1
 
@@ -45,7 +45,7 @@ def test_ensure_chunk_size():
     recording = generate_recording(num_channels=2)
     dtype = recording.get_dtype()
     assert dtype == "float32"
-    # make dumpable
+    # make serializable
     recording = recording.save()
 
     chunk_size = ensure_chunk_size(recording, total_memory="512M", chunk_size=None, chunk_memory=None, n_jobs=2)
@@ -90,7 +90,7 @@ def init_func(arg1, arg2, arg3):
 def test_ChunkRecordingExecutor():
     recording = generate_recording(num_channels=2)
-    # make dumpable
+    # make serializable
     recording = recording.save()
 
     init_args = "a", 120, "yep"
diff --git a/src/spikeinterface/core/tests/test_waveform_extractor.py b/src/spikeinterface/core/tests/test_waveform_extractor.py
index 12dac52d43..2bbf5e9b0f
100644
--- a/src/spikeinterface/core/tests/test_waveform_extractor.py
+++ b/src/spikeinterface/core/tests/test_waveform_extractor.py
@@ -315,7 +315,7 @@ def test_recordingless():
     recording = recording.save(folder=cache_folder / "recording1")
     sorting = sorting.save(folder=cache_folder / "sorting1")
 
-    # recording and sorting are not dumpable
+    # recording and sorting are not serializable
     wf_folder = cache_folder / "wf_recordingless"
 
     # save with relative paths
diff --git a/src/spikeinterface/core/waveform_extractor.py b/src/spikeinterface/core/waveform_extractor.py
index cd8a62f5bc..2710ff1338 100644
--- a/src/spikeinterface/core/waveform_extractor.py
+++ b/src/spikeinterface/core/waveform_extractor.py
@@ -290,11 +290,12 @@ def create(
             sorting.dump(folder / "sorting.json", relative_to=relative_to)
         elif sorting.check_serializablility("pickle"):
             # In this case we lose the relative_to!!
+            # TODO later the dump to pickle should dump the dictionary and so relative_to could be put back
             sorting.dump(folder / "sorting.pickle")
         else:
             warn(
-                "Sorting object is not dumpable, which might result in downstream errors for "
-                "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                "Sorting object is not serializable to file, which might result in downstream errors for "
+                "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
             )
 
         # dump some attributes of the recording for the mode with_recording=False at next load
@@ -903,11 +904,11 @@ def save(
         if self.sorting.check_serializablility("json"):
             self.sorting.dump(folder / "sorting.json", relative_to=relative_to)
         elif self.sorting.check_serializablility("pickle"):
-            self.sorting.dump(folder / "sorting.pickle", relative_to=relative_to)
+            self.sorting.dump(folder / "sorting.pickle")
         else:
             warn(
-                "Sorting object is not dumpable, which might result in downstream errors for "
-                "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                "Sorting object is not serializable to file, which might result in downstream errors for "
+                "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
             )
 
         # dump some attributes of the recording for the mode with_recording=False at next load
@@ -960,8 +961,8 @@ def save(
             zarr_root.attrs["sorting"] = check_json(sort_dict)
         else:
             warn(
-                "Sorting object is not dumpable, which might result in downstream errors for "
-                "parallel processing. To make the sorting dumpable, use the `sorting.save()` function."
+                "Sorting object is not json serializable, which might result in downstream errors for "
+                "parallel processing. To make the sorting serializable, use the `sorting = sorting.save()` function."
            )
         recording_info = zarr_root.create_group("recording_info")
         recording_info.attrs["recording_attributes"] = check_json(rec_attributes)
diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py
index 9eb5a815d4..ccd2121174 100644
--- a/src/spikeinterface/postprocessing/spike_amplitudes.py
+++ b/src/spikeinterface/postprocessing/spike_amplitudes.py
@@ -73,12 +73,6 @@ def _run(self, **job_kwargs):
         func = _spike_amplitudes_chunk
         init_func = _init_worker_spike_amplitudes
         n_jobs = ensure_n_jobs(recording, job_kwargs.get("n_jobs", None))
-        if n_jobs != 1:
-            # TODO: avoid dumping sorting and use spike vector and peak pipeline instead
-            assert sorting.check_if_memory_serializable(), (
-                "The sorting object is not dumpable and cannot be processed in parallel. You can use the "
-                "`sorting.save()` function to make it dumpable"
-            )
         init_args = (recording, sorting.to_multiprocessing(n_jobs), extremum_channels_index, peak_shifts, return_scaled)
         processor = ChunkRecordingExecutor(
             recording, func, init_func, init_args, handle_returns=True, job_name="extract amplitudes", **job_kwargs
diff --git a/src/spikeinterface/sorters/tests/test_launcher.py b/src/spikeinterface/sorters/tests/test_launcher.py
index 14c938f8ba..a5e29c8fd9 100644
--- a/src/spikeinterface/sorters/tests/test_launcher.py
+++ b/src/spikeinterface/sorters/tests/test_launcher.py
@@ -178,7 +178,7 @@ def test_run_sorters_with_list():
     if working_folder.is_dir():
         shutil.rmtree(working_folder)
 
-    # make dumpable
+    # make serializable
     rec0 = load_extractor(cache_folder / "toy_rec_0")
     rec1 = load_extractor(cache_folder / "toy_rec_1")
 

From 0ea10e3baf97fbcedc8c25c2745754cacabb7b5c Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 27 Sep 2023 09:13:52 +0000
Subject: [PATCH 25/25] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py
index 73f8619348..e8b3232e13 100644
--- a/src/spikeinterface/core/base.py
+++ b/src/spikeinterface/core/base.py
@@ -822,7 +822,9 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs):
         if self.check_serializablility("json"):
             self.dump(provenance_file)
         else:
-            provenance_file.write_text(json.dumps({"warning": "the provenance is not json serializable!!!"}), encoding="utf8")
+            provenance_file.write_text(
+                json.dumps({"warning": "the provenance is not json serializable!!!"}), encoding="utf8"
+            )
 
         self.save_metadata_to_folder(folder)
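
Taken together, these patches leave `BaseExtractor` with a single
`_serializablility` dict (keys `"memory"`, `"json"`, `"pickle"`) queried through
`check_serializablility()`, plus the convenience wrapper
`check_if_memory_serializable()` that `ensure_n_jobs()` uses to gate parallel
processing. A minimal sketch of the resulting API, assuming a stock generated
recording (return values depend on the extractor at hand):

>>> from spikeinterface.core import generate_recording
>>> rec = generate_recording(durations=[2.0])
>>> rec.check_if_memory_serializable()    # must be True for n_jobs > 1
>>> rec.check_serializablility("json")    # preferred on-disk format
>>> rec.check_serializablility("pickle")  # fallback used by run_sorter and WaveformExtractor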