Improve serialization concept : memory/json/pickle #2027

Merged · 11 commits · Sep 27, 2023
3 changes: 1 addition & 2 deletions doc/modules/core.rst
@@ -547,8 +547,7 @@ workflow.
 In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`,
 :py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and
 :py:class:`~spikeinterface.core.NumpySnippets`. These objects behave exactly like normal SpikeInterface objects,
-but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported.
-In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`).
+but they are not bound to a file.

 Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting`, which is very similar to
 :py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory, which is useful for
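A minimal sketch of the workflow this doc change refers to (assuming the NumpyRecording constructor shown later in this diff; the folder name is hypothetical):

    import numpy as np
    from spikeinterface.core import NumpyRecording

    # a toy in-memory recording: 2 seconds of noise on 4 channels at 30 kHz
    traces = np.random.randn(60_000, 4).astype("float32")
    recording = NumpyRecording([traces], sampling_frequency=30_000.0)

    # in-memory objects can be pickled in memory but not serialized to file
    print(recording.check_if_memory_serializable())  # True
    print(recording.check_serializablility("json"))  # False

    # save() returns a new recording bound to files, which is serializable
    recording_saved = recording.save(folder="my_recording")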
10 changes: 6 additions & 4 deletions src/spikeinterface/comparison/hybrid.py
@@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording):
         The refractory period of the injected spike train (in ms).
     injected_sorting_folder: str | Path | None
         If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.

     Returns
     -------
@@ -84,7 +84,8 @@ def __init__(
         )
         # save injected sorting if necessary
         self.injected_sorting = injected_sorting
-        if not self.injected_sorting.check_if_json_serializable():
+        if not self.injected_sorting.check_serializablility("json"):
+            # TODO later: also use pickle
             assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
             self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

@@ -137,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording):
         this refractory period.
     injected_sorting_folder: str | Path | None
         If given, the injected sorting is saved to this folder.
-        It must be specified if injected_sorting is None or not dumpable.
+        It must be specified if injected_sorting is None or not serializable to file.

     Returns
     -------
@@ -180,7 +181,8 @@ def __init__(
         self.injected_sorting = injected_sorting

         # save injected sorting if necessary
-        if not self.injected_sorting.check_if_json_serializable():
+        if not self.injected_sorting.check_serializablility("json"):
+            # TODO later: also use pickle
             assert injected_sorting_folder is not None, "Provide injected_sorting_folder to injected sorting object"
             self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

7 changes: 4 additions & 3 deletions src/spikeinterface/comparison/multicomparisons.py
@@ -189,8 +189,8 @@ def save_to_folder(self, save_folder):
             stacklevel=2,
         )
         for sorting in self.object_list:
-            assert (
-                sorting.check_if_json_serializable()
+            assert sorting.check_serializablility(
+                "json"
             ), "MultiSortingComparison.save_to_folder() needs json serializable sortings"

         save_folder = Path(save_folder)
@@ -259,7 +259,8 @@ def __init__(

         BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

-        self._is_json_serializable = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = True

         if len(unit_ids) > 0:
             for k in ("agreement_number", "avg_agreement", "unit_ids"):
73 changes: 40 additions & 33 deletions src/spikeinterface/core/base.py
@@ -57,8 +57,7 @@ def __init__(self, main_ids: Sequence) -> None:
         # * number of units for sorting
         self._properties = {}

-        self._is_dumpable = True
-        self._is_json_serializable = True
+        self._serializablility = {"memory": True, "json": True, "pickle": True}

         # extractor specific list of pip extra requirements
         self.extra_requirements = []
@@ -471,24 +470,33 @@ def clone(self) -> "BaseExtractor":
         clone = BaseExtractor.from_dict(d)
         return clone

-    def check_if_dumpable(self):
-        """Check if the object is dumpable, including nested objects.
+    def check_serializablility(self, type):
+        kwargs = self._kwargs
+        for value in kwargs.values():
+            # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
+            if isinstance(value, BaseExtractor):
+                if not value.check_serializablility(type=type):
+                    return False
+            elif isinstance(value, list):
+                for v in value:
+                    if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
+                        return False
+            elif isinstance(value, dict):
+                for v in value.values():
+                    if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
+                        return False
+        return self._serializablility[type]
+
+    def check_if_memory_serializable(self):
+        """
+        Check if the object is serializable to memory with pickle, including nested objects.

         Returns
         -------
         bool
-            True if the object is dumpable, False otherwise.
+            True if the object is memory serializable, False otherwise.
         """
-        kwargs = self._kwargs
-        for value in kwargs.values():
-            # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
-            if isinstance(value, BaseExtractor):
-                return value.check_if_dumpable()
-            elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
-                return all([v.check_if_dumpable() for v in value])
-            elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
-                return all([v.check_if_dumpable() for k, v in value.items()])
-        return self._is_dumpable
+        return self.check_serializablility("memory")

     def check_if_json_serializable(self):
         """
@@ -499,16 +507,13 @@ def check_if_json_serializable(self):
         bool
             True if the object is json serializable, False otherwise.
         """
-        kwargs = self._kwargs
-        for value in kwargs.values():
-            # here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
-            if isinstance(value, BaseExtractor):
-                return value.check_if_json_serializable()
-            elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
-                return all([v.check_if_json_serializable() for v in value])
-            elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
-                return all([v.check_if_json_serializable() for k, v in value.items()])
-        return self._is_json_serializable
+        # we keep this for backward compatibility or not ????
+        # is this needed ??? I think no.
+        return self.check_serializablility("json")
+
+    def check_if_pickle_serializable(self):
+        # is this needed ??? I think no.
+        return self.check_serializablility("pickle")

     @staticmethod
     def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path:
@@ -557,7 +562,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=None):
         if str(file_path).endswith(".json"):
             self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata)
         elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
-            self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata)
+            self.dump_to_pickle(file_path, folder_metadata=folder_metadata)
         else:
             raise ValueError("Dump: file must .json or .pkl")

@@ -576,7 +581,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=None
         folder_metadata: str, Path, or None
             Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
         """
-        assert self.check_if_json_serializable(), "The extractor is not json serializable"
+        assert self.check_serializablility("json"), "The extractor is not json serializable"

         # Writing paths as relative_to requires recursively expanding the dict
         if relative_to:
@@ -616,7 +621,7 @@ def dump_to_pickle(
         folder_metadata: str, Path, or None
             Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
         """
-        assert self.check_if_dumpable(), "The extractor is not dumpable"
+        assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle"

         dump_dict = self.to_dict(
             include_annotations=True,
@@ -653,8 +658,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, bool]] = None):
                 d = pickle.load(f)
         else:
             raise ValueError(f"Impossible to load {file_path}")
-        if "warning" in d and "not dumpable" in d["warning"]:
-            print("The extractor was not dumpable")
+        if "warning" in d:
+            print("The extractor was not serializable to file")
             return None
         extractor = BaseExtractor.from_dict(d, base_folder=base_folder)
         return extractor
@@ -814,10 +819,12 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs):

         # dump provenance
         provenance_file = folder / f"provenance.json"
-        if self.check_if_json_serializable():
+        if self.check_serializablility("json"):
             self.dump(provenance_file)
         else:
-            provenance_file.write_text(json.dumps({"warning": "the provenace is not dumpable!!!"}), encoding="utf8")
+            provenance_file.write_text(
+                json.dumps({"warning": "the provenance is not json serializable!!!"}), encoding="utf8"
+            )

         self.save_metadata_to_folder(folder)

@@ -911,7 +918,7 @@ def save_to_zarr(

         zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options)

-        if self.check_if_dumpable():
+        if self.check_if_json_serializable():
             zarr_root.attrs["provenance"] = check_json(self.to_dict())
         else:
             zarr_root.attrs["provenance"] = None
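Taken together, the base.py changes replace the two booleans _is_dumpable / _is_json_serializable with a single dict and one recursive check. A short usage sketch of the new API (assuming generate_recording from spikeinterface.core, as used in the tests below; the file names are hypothetical):

    from spikeinterface.core import generate_recording

    recording = generate_recording(seed=0, durations=[2])

    # the new fine-grained check, which recurses into nested extractors
    print(recording.check_serializablility("memory"))
    print(recording.check_serializablility("json"))
    print(recording.check_serializablility("pickle"))

    # convenience wrappers kept by this PR
    print(recording.check_if_memory_serializable())
    print(recording.check_if_json_serializable())

    # dumping to file still goes through dump / dump_to_json / dump_to_pickle
    if recording.check_serializablility("json"):
        recording.dump_to_json("recording.json")
    if recording.check_serializablility("pickle"):
        recording.dump_to_pickle("recording.pkl")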
4 changes: 4 additions & 0 deletions src/spikeinterface/core/generate.py
@@ -1056,6 +1056,8 @@ def __init__(
         dtype = parent_recording.dtype if parent_recording is not None else templates.dtype
         BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype)

+        # Important: self._serializablility is not changed here because it depends on the parent sorting itself.
+
         n_units = len(sorting.unit_ids)
         assert len(templates) == n_units
         self.spike_vector = sorting.to_spike_vector()
@@ -1431,5 +1433,7 @@ def generate_ground_truth_recording(
     )
     recording.annotate(is_filtered=True)
     recording.set_probe(probe, in_place=True)
+    recording.set_channel_gains(1.0)
+    recording.set_channel_offsets(0.0)

     return recording, sorting
6 changes: 3 additions & 3 deletions src/spikeinterface/core/job_tools.py
@@ -167,11 +167,11 @@ def ensure_n_jobs(recording, n_jobs=1):
         print(f"Python {sys.version} does not support parallel processing")
         n_jobs = 1

-    if not recording.check_if_dumpable():
+    if not recording.check_if_memory_serializable():
         if n_jobs != 1:
             raise RuntimeError(
-                "Recording is not dumpable and can't be processed in parallel. "
-                "You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1."
+                "Recording is not serializable to memory and can't be processed in parallel. "
+                "You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1."
             )

     return n_jobs
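A sketch of what this guard does in practice (assuming generate_recording, and flipping the private _serializablility flag only to simulate a recording that cannot be shipped to worker processes):

    from spikeinterface.core import generate_recording
    from spikeinterface.core.job_tools import ensure_n_jobs

    recording = generate_recording(seed=0, durations=[2])
    ensure_n_jobs(recording, n_jobs=2)  # ok: the recording is memory serializable

    # simulate a non-serializable recording (private flag, for illustration only)
    recording._serializablility["memory"] = False
    try:
        ensure_n_jobs(recording, n_jobs=2)
    except RuntimeError as err:
        print(err)  # suggests recording.save(folder=...) or n_jobs=1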
20 changes: 13 additions & 7 deletions src/spikeinterface/core/numpyextractors.py
@@ -64,7 +64,8 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=None):
             assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list"
             t_starts = [float(t_start) for t_start in t_starts]

-        self._is_json_serializable = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         for i, traces in enumerate(traces_list):
             if t_starts is None:
@@ -126,8 +127,10 @@ def __init__(self, spikes, sampling_frequency, unit_ids):
         """ """
         BaseSorting.__init__(self, sampling_frequency, unit_ids)

-        self._is_dumpable = True
-        self._is_json_serializable = False
+        self._serializablility["memory"] = True
+        self._serializablility["json"] = False
+        # theoretically this should be False, but to keep generators simple we still need this

    [inline review discussion]

    Collaborator:
        Yeah, you guys were talking about this today. I am not sure I got it.
        You don't want stuff to be picklable because it is heavy. You put all of the spike_vector onto disk. Is that correct?

    Member (author):
        I wanted to avoid NumpySorting also being pickled into a file. But finally it is too hard!!
        So let's keep self._serializablility["pickle"] = True for NumpySorting, so the generated recording can be saved into pickle, which is the most urgent need.

    Collaborator:
        But right now the generator recording can be pickled, can't it?
        This works for me:

            import spikeinterface.core as si

            rec, sorting = si.generate_ground_truth_recording(durations=[3600., 1200.], sampling_frequency=30000.,
                                                              num_channels=32,
                                                              num_units=64,
                                                              )

            from spikeinterface.core import load_extractor

            rec.dump_to_pickle('generated_recording.pickle')
            load_extractor('generated_recording.pickle')

            sorting.dump_to_pickle('generated_sorting.pickle')
            load_extractor('generated_sorting.pickle')

+        self._serializablility["pickle"] = True

         if spikes.size == 0:
             nseg = 1
@@ -357,8 +360,10 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_
         assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting"

         BaseSorting.__init__(self, sampling_frequency, unit_ids)
-        self._is_dumpable = True
-        self._is_json_serializable = False
+
+        self._serializablility["memory"] = True
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         self.shm = SharedMemory(shm_name, create=False)
         self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf)
@@ -516,8 +521,9 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore
             dtype=dtype,
         )

-        self._is_dumpable = False
-        self._is_json_serializable = False
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         for snippets, spikesframes in zip(snippets_list, spikesframes_list):
             snp_segment = NumpySnippetsSegment(snippets, spikesframes)
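The per-class flags above can be checked directly. A small sketch (assuming NumpySorting.from_times_labels, which exists in SpikeInterface but is not part of this diff):

    import numpy as np
    from spikeinterface.core import NumpySorting

    # toy spike train: 10 spikes, all from unit 0, within one second at 30 kHz
    rng = np.random.default_rng(0)
    times = np.sort(rng.integers(0, 30_000, size=10))
    labels = np.zeros(10, dtype="int64")
    sorting = NumpySorting.from_times_labels(times, labels, sampling_frequency=30_000.0)

    print(sorting.check_serializablility("memory"))  # True: lives in memory
    print(sorting.check_serializablility("json"))    # False: raw arrays are not JSON friendly
    print(sorting.check_serializablility("pickle"))  # True: kept so generated objects can be pickled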
12 changes: 7 additions & 5 deletions src/spikeinterface/core/old_api_utils.py
@@ -181,9 +181,10 @@ def __init__(self, oldapi_recording_extractor):
             dtype=oldapi_recording_extractor.get_dtype(return_scaled=False),
         )

-        # set _is_dumpable to False to use dumping mechanism of old extractor
-        self._is_dumpable = False
-        self._is_json_serializable = False
+        # set to False to use dumping mechanism of old extractor
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         self.annotate(is_filtered=oldapi_recording_extractor.is_filtered)

@@ -268,8 +269,9 @@ def __init__(self, oldapi_sorting_extractor):
         sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor)
         self.add_sorting_segment(sorting_segment)

-        self._is_dumpable = False
-        self._is_json_serializable = False
+        self._serializablility["memory"] = False
+        self._serializablility["json"] = False
+        self._serializablility["pickle"] = False

         # add old properties
         copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self)
38 changes: 19 additions & 19 deletions src/spikeinterface/core/tests/test_base.py
@@ -31,39 +31,39 @@ def make_nested_extractors(extractor):
     )


-def test_check_if_dumpable():
+def test_check_if_memory_serializable():
     test_extractor = generate_recording(seed=0, durations=[2])

-    # make a list of dumpable objects
-    extractors_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_dumpable:
-        assert extractor.check_if_dumpable()
+    # make a list of memory serializable objects
+    extractors_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_mem_serializable:
+        assert extractor.check_if_memory_serializable()

-    # make not dumpable
-    test_extractor._is_dumpable = False
-    extractors_not_dumpable = make_nested_extractors(test_extractor)
-    for extractor in extractors_not_dumpable:
-        assert not extractor.check_if_dumpable()
+    # make not memory serializable
+    test_extractor._serializablility["memory"] = False
+    extractors_not_mem_serializable = make_nested_extractors(test_extractor)
+    for extractor in extractors_not_mem_serializable:
+        assert not extractor.check_if_memory_serializable()


-def test_check_if_json_serializable():
+def test_check_if_serializable():
     test_extractor = generate_recording(seed=0, durations=[2])

-    # make a list of dumpable objects
-    test_extractor._is_json_serializable = True
+    # make a list of json serializable objects
+    test_extractor._serializablility["json"] = True
     extractors_json_serializable = make_nested_extractors(test_extractor)
     for extractor in extractors_json_serializable:
         print(extractor)
-        assert extractor.check_if_json_serializable()
+        assert extractor.check_serializablility("json")

-    # make not dumpable
-    test_extractor._is_json_serializable = False
+    # make a list of not json serializable objects
+    test_extractor._serializablility["json"] = False
     extractors_not_json_serializable = make_nested_extractors(test_extractor)
     for extractor in extractors_not_json_serializable:
         print(extractor)
-        assert not extractor.check_if_json_serializable()
+        assert not extractor.check_serializablility("json")


 if __name__ == "__main__":
-    test_check_if_dumpable()
-    test_check_if_json_serializable()
+    test_check_if_memory_serializable()
+    test_check_if_serializable()
1 change: 0 additions & 1 deletion src/spikeinterface/core/tests/test_core_tools.py
@@ -142,7 +142,6 @@ def test_write_memory_recording():
     recording = NoiseGeneratorRecording(
         num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
     )
-    # make dumpable
     recording = recording.save()

     # write with loop