Merge pull request #2027 from samuelgarcia/fix_non_json_rec
Improve serialization concept: memory/json/pickle
samuelgarcia authored Sep 27, 2023
2 parents ceaebfa + 0ea10e3 commit c5bafd1
Showing 25 changed files with 252 additions and 152 deletions.
doc/modules/core.rst (3 changes: 1 addition & 2 deletions)
@@ -547,8 +547,7 @@ workflow.
In order to do this, one can use the :code:`Numpy*` classes, :py:class:`~spikeinterface.core.NumpyRecording`,
:py:class:`~spikeinterface.core.NumpySorting`, :py:class:`~spikeinterface.core.NumpyEvent`, and
:py:class:`~spikeinterface.core.NumpySnippets`. These objects behave exactly like normal SpikeInterface objects,
-but they are not bound to a file. This makes these objects *not dumpable*, so parallel processing is not supported.
-In order to make them *dumpable*, one can simply :code:`save()` them (see :ref:`save_load`).
+but they are not bound to a file.

Also note the class :py:class:`~spikeinterface.core.SharedMemorySorting`, which is similar to
:py:class:`~spikeinterface.core.NumpySorting` but with an underlying SharedMemory, which is useful for
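The doc change above drops the claim that in-memory objects cannot be processed in parallel: under the new scheme they are serializable to memory (so they can be shipped to worker processes) even when they cannot be dumped to json. A minimal sketch of the distinction using the API introduced in this PR (the folder path is illustrative):

```python
import numpy as np
from spikeinterface.core import NumpyRecording

# a purely in-memory recording, not bound to any file
traces = np.random.randn(30_000, 4).astype("float32")
recording = NumpyRecording(traces, sampling_frequency=30_000.0)

# memory serialization still works, so parallel processing is allowed;
# json dumping is not
print(recording.check_serializablility("memory"))  # True
print(recording.check_serializablility("json"))    # False

# saving to a folder returns a file-backed, fully serializable object
# recording_saved = recording.save(folder="my_recording_folder")
```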
src/spikeinterface/comparison/hybrid.py (10 changes: 6 additions & 4 deletions)
@@ -39,7 +39,7 @@ class HybridUnitsRecording(InjectTemplatesRecording):
The refractory period of the injected spike train (in ms).
injected_sorting_folder: str | Path | None
If given, the injected sorting is saved to this folder.
-It must be specified if injected_sorting is None or not dumpable.
+It must be specified if injected_sorting is None or not serializable to file.
Returns
-------
@@ -84,7 +84,8 @@ def __init__(
)
# save injected sorting if necessary
self.injected_sorting = injected_sorting
-if not self.injected_sorting.check_if_json_serializable():
+if not self.injected_sorting.check_serializablility("json"):
+# TODO later: also use pickle
assert injected_sorting_folder is not None, "Provide injected_sorting_folder to save the injected sorting object"
self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

@@ -137,7 +138,7 @@ class HybridSpikesRecording(InjectTemplatesRecording):
this refractory period.
injected_sorting_folder: str | Path | None
If given, the injected sorting is saved to this folder.
-It must be specified if injected_sorting is None or not dumpable.
+It must be specified if injected_sorting is None or not serializable to file.
Returns
-------
@@ -180,7 +181,8 @@ def __init__(
self.injected_sorting = injected_sorting

# save injected sorting if necessary
-if not self.injected_sorting.check_if_json_serializable():
+if not self.injected_sorting.check_serializablility("json"):
+# TODO later: also use pickle
assert injected_sorting_folder is not None, "Provide injected_sorting_folder to save the injected sorting object"
self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

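Both constructors above use the same check-then-save fallback. A standalone sketch of the pattern (the helper name is hypothetical, not part of this PR):

```python
def ensure_file_serializable(sorting, fallback_folder):
    """Return `sorting` unchanged if it can be dumped to json,
    otherwise save it to `fallback_folder` and return the file-backed copy."""
    if sorting.check_serializablility("json"):
        return sorting
    assert fallback_folder is not None, "a folder is required for non-serializable sortings"
    return sorting.save(folder=fallback_folder)
```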
src/spikeinterface/comparison/multicomparisons.py (7 changes: 4 additions & 3 deletions)
@@ -189,8 +189,8 @@ def save_to_folder(self, save_folder):
stacklevel=2,
)
for sorting in self.object_list:
-assert (
-sorting.check_if_json_serializable()
+assert sorting.check_serializablility(
+"json"
), "MultiSortingComparison.save_to_folder() needs json serializable sortings"

save_folder = Path(save_folder)
@@ -259,7 +259,8 @@ def __init__(

BaseSorting.__init__(self, sampling_frequency=sampling_frequency, unit_ids=unit_ids)

-self._is_json_serializable = False
+self._serializablility["json"] = False
+self._serializablility["pickle"] = True

if len(unit_ids) > 0:
for k in ("agreement_number", "avg_agreement", "unit_ids"):
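With the flags split per format, the sorting above can still be dumped with pickle even though json is ruled out. A short sketch of the consequence (assuming an `agreement_sorting` instance carrying the flags set above; file names are illustrative):

```python
# pickle dumping only checks the "pickle" flag, which is True here
agreement_sorting.dump("agreement_sorting.pkl")

# json dumping would trip the assert inside dump_to_json()
# agreement_sorting.dump("agreement_sorting.json")  # AssertionError
```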
src/spikeinterface/core/base.py (73 changes: 40 additions & 33 deletions)
@@ -57,8 +57,7 @@ def __init__(self, main_ids: Sequence) -> None:
# * number of units for sorting
self._properties = {}

-self._is_dumpable = True
-self._is_json_serializable = True
+self._serializablility = {"memory": True, "json": True, "pickle": True}

# extractor specific list of pip extra requirements
self.extra_requirements = []
@@ -471,24 +470,33 @@ def clone(self) -> "BaseExtractor":
clone = BaseExtractor.from_dict(d)
return clone

-def check_if_dumpable(self):
-"""Check if the object is dumpable, including nested objects.
+def check_serializablility(self, type):
+kwargs = self._kwargs
+for value in kwargs.values():
+# here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
+if isinstance(value, BaseExtractor):
+if not value.check_serializablility(type=type):
+return False
+elif isinstance(value, list):
+for v in value:
+if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
+return False
+elif isinstance(value, dict):
+for v in value.values():
+if isinstance(v, BaseExtractor) and not v.check_serializablility(type=type):
+return False
+return self._serializablility[type]
+
+def check_if_memory_serializable(self):
+"""
+Check if the object is serializable to memory with pickle, including nested objects.
Returns
-------
bool
-True if the object is dumpable, False otherwise.
+True if the object is memory serializable, False otherwise.
"""
-kwargs = self._kwargs
-for value in kwargs.values():
-# here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
-if isinstance(value, BaseExtractor):
-return value.check_if_dumpable()
-elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
-return all([v.check_if_dumpable() for v in value])
-elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
-return all([v.check_if_dumpable() for k, v in value.items()])
-return self._is_dumpable
+return self.check_serializablility("memory")

def check_if_json_serializable(self):
"""
@@ -499,16 +507,13 @@ def check_if_json_serializable(self):
bool
True if the object is json serializable, False otherwise.
"""
-kwargs = self._kwargs
-for value in kwargs.values():
-# here we check if the value is a BaseExtractor, a list of BaseExtractors, or a dict of BaseExtractors
-if isinstance(value, BaseExtractor):
-return value.check_if_json_serializable()
-elif isinstance(value, list) and (len(value) > 0) and isinstance(value[0], BaseExtractor):
-return all([v.check_if_json_serializable() for v in value])
-elif isinstance(value, dict) and isinstance(value[list(value.keys())[0]], BaseExtractor):
-return all([v.check_if_json_serializable() for k, v in value.items()])
-return self._is_json_serializable
+# kept for backward compatibility; is this still needed?
+return self.check_serializablility("json")
+
+def check_if_pickle_serializable(self):
+# kept for backward compatibility; is this still needed?
+return self.check_serializablility("pickle")

@staticmethod
def _get_file_path(file_path: Union[str, Path], extensions: Sequence) -> Path:
@@ -557,7 +562,7 @@ def dump(self, file_path: Union[str, Path], relative_to=None, folder_metadata=No
if str(file_path).endswith(".json"):
self.dump_to_json(file_path, relative_to=relative_to, folder_metadata=folder_metadata)
elif str(file_path).endswith(".pkl") or str(file_path).endswith(".pickle"):
-self.dump_to_pickle(file_path, relative_to=relative_to, folder_metadata=folder_metadata)
+self.dump_to_pickle(file_path, folder_metadata=folder_metadata)
else:
raise ValueError("Dump: file must be .json or .pkl")
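`dump` dispatches on the file extension, so the two branches above map onto the two file flags. A round-trip sketch (assuming the generated recording keeps the default flags, all three True; paths are illustrative):

```python
from spikeinterface.core import BaseExtractor, generate_recording

recording = generate_recording(durations=[1.0])

recording.dump("rec.json")  # goes through dump_to_json, needs the "json" flag
recording.dump("rec.pkl")   # goes through dump_to_pickle, needs the "pickle" flag

# load() rebuilds the extractor from the dumped dict
same_recording = BaseExtractor.load("rec.json")
```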

@@ -576,7 +581,7 @@ def dump_to_json(self, file_path: Union[str, Path, None] = None, relative_to=Non
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
"""
-assert self.check_if_json_serializable(), "The extractor is not json serializable"
+assert self.check_serializablility("json"), "The extractor is not json serializable"

# Writing paths as relative_to requires recursively expanding the dict
if relative_to:
@@ -616,7 +621,7 @@ def dump_to_pickle(
folder_metadata: str, Path, or None
Folder with files containing additional information (e.g. probe in BaseRecording) and properties.
"""
-assert self.check_if_dumpable(), "The extractor is not dumpable"
+assert self.check_if_pickle_serializable(), "The extractor is not serializable to file with pickle"

dump_dict = self.to_dict(
include_annotations=True,
@@ -653,8 +658,8 @@ def load(file_path: Union[str, Path], base_folder: Optional[Union[Path, str, boo
d = pickle.load(f)
else:
raise ValueError(f"Impossible to load {file_path}")
-if "warning" in d and "not dumpable" in d["warning"]:
-print("The extractor was not dumpable")
+if "warning" in d:
+print("The extractor was not serializable to file")
return None
extractor = BaseExtractor.from_dict(d, base_folder=base_folder)
return extractor
@@ -814,10 +819,12 @@ def save_to_folder(self, name=None, folder=None, verbose=True, **save_kwargs):

# dump provenance
provenance_file = folder / f"provenance.json"
-if self.check_if_json_serializable():
+if self.check_serializablility("json"):
self.dump(provenance_file)
else:
-provenance_file.write_text(json.dumps({"warning": "the provenance is not dumpable!!!"}), encoding="utf8")
+provenance_file.write_text(
+json.dumps({"warning": "the provenance is not json serializable!!!"}), encoding="utf8"
+)

self.save_metadata_to_folder(folder)

@@ -911,7 +918,7 @@ def save_to_zarr(

zarr_root = zarr.open(zarr_path_init, mode="w", storage_options=storage_options)

-if self.check_if_dumpable():
+if self.check_if_json_serializable():
zarr_root.attrs["provenance"] = check_json(self.to_dict())
else:
zarr_root.attrs["provenance"] = None
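The net effect of the base-class changes: serializability is tracked per kind ("memory", "json", "pickle"), and the check recurses into nested extractors through their kwargs. A sketch using `generate_recording` as in the tests further down (flipping the private flag dict mirrors what those tests do):

```python
from spikeinterface.core import generate_recording

recording = generate_recording(seed=0, durations=[2])
sliced = recording.frame_slice(start_frame=0, end_frame=30_000)  # keeps `recording` in its kwargs

print(sliced.check_serializablility("json"))  # True while the parent is json serializable

recording._serializablility["json"] = False   # flip the parent's flag
print(sliced.check_serializablility("json"))  # False: the check walks into nested extractors
```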
src/spikeinterface/core/generate.py (4 changes: 4 additions & 0 deletions)
@@ -1056,6 +1056,8 @@ def __init__(
dtype = parent_recording.dtype if parent_recording is not None else templates.dtype
BaseRecording.__init__(self, sorting.get_sampling_frequency(), channel_ids, dtype)

+# Important: self._serializablility is not changed here because it depends on the parent sorting itself.

n_units = len(sorting.unit_ids)
assert len(templates) == n_units
self.spike_vector = sorting.to_spike_vector()
@@ -1431,5 +1433,7 @@ def generate_ground_truth_recording(
)
recording.annotate(is_filtered=True)
recording.set_probe(probe, in_place=True)
+recording.set_channel_gains(1.0)
+recording.set_channel_offsets(0.0)

return recording, sorting
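The two added lines give the ground-truth recording explicit gains and offsets, so scaled-trace accessors have well-defined values. A quick usage sketch (parameter values are illustrative):

```python
from spikeinterface.core import generate_ground_truth_recording

recording, sorting = generate_ground_truth_recording(durations=[5.0], num_units=5, seed=42)
print(recording.get_channel_gains())    # [1. 1. 1. 1.]
print(recording.get_channel_offsets())  # [0. 0. 0. 0.]
```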
src/spikeinterface/core/job_tools.py (6 changes: 3 additions & 3 deletions)
@@ -167,11 +167,11 @@ def ensure_n_jobs(recording, n_jobs=1):
print(f"Python {sys.version} does not support parallel processing")
n_jobs = 1

-if not recording.check_if_dumpable():
+if not recording.check_if_memory_serializable():
if n_jobs != 1:
raise RuntimeError(
-"Recording is not dumpable and can't be processed in parallel. "
-"You can use the `recording.save()` function to make it dumpable or set 'n_jobs' to 1."
+"Recording is not serializable to memory and can't be processed in parallel. "
+"You can use the `rec = recording.save(folder=...)` function or set 'n_jobs' to 1."
)

return n_jobs
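This guard is the heart of the fix: parallel processing now requires only memory serializability rather than full dumpability, so purely in-memory recordings are no longer forced to n_jobs=1. A sketch under this PR's API (import path as in this diff):

```python
import numpy as np
from spikeinterface.core import NumpyRecording
from spikeinterface.core.job_tools import ensure_n_jobs

recording = NumpyRecording(np.zeros((30_000, 4), dtype="float32"), sampling_frequency=30_000.0)

# memory serializable even though it is neither json nor pickle serializable,
# so more than one job is accepted
print(ensure_n_jobs(recording, n_jobs=2))  # 2
```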
src/spikeinterface/core/numpyextractors.py (20 changes: 13 additions & 7 deletions)
@@ -64,7 +64,8 @@ def __init__(self, traces_list, sampling_frequency, t_starts=None, channel_ids=N
assert len(t_starts) == len(traces_list), "t_starts must be a list of same size than traces_list"
t_starts = [float(t_start) for t_start in t_starts]

-self._is_json_serializable = False
+self._serializablility["json"] = False
+self._serializablility["pickle"] = False

for i, traces in enumerate(traces_list):
if t_starts is None:
@@ -126,8 +127,10 @@ def __init__(self, spikes, sampling_frequency, unit_ids):
""" """
BaseSorting.__init__(self, sampling_frequency, unit_ids)

-self._is_dumpable = True
-self._is_json_serializable = False
+self._serializablility["memory"] = True
+self._serializablility["json"] = False
+# theoretically this should be False, but to keep generators simple we still need it
+self._serializablility["pickle"] = True

if spikes.size == 0:
nseg = 1
@@ -357,8 +360,10 @@ def __init__(self, shm_name, shape, sampling_frequency, unit_ids, dtype=minimum_
assert shape[0] > 0, "SharedMemorySorting only supported with no empty sorting"

BaseSorting.__init__(self, sampling_frequency, unit_ids)
-self._is_dumpable = True
-self._is_json_serializable = False
+self._serializablility["memory"] = True
+self._serializablility["json"] = False
+self._serializablility["pickle"] = False

self.shm = SharedMemory(shm_name, create=False)
self.shm_spikes = np.ndarray(shape=shape, dtype=dtype, buffer=self.shm.buf)
@@ -516,8 +521,9 @@ def __init__(self, snippets_list, spikesframes_list, sampling_frequency, nbefore
dtype=dtype,
)

-self._is_dumpable = False
-self._is_json_serializable = False
+self._serializablility["memory"] = False
+self._serializablility["json"] = False
+self._serializablility["pickle"] = False

for snippets, spikesframes in zip(snippets_list, spikesframes_list):
snp_segment = NumpySnippetsSegment(snippets, spikesframes)
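Taken together, each in-memory class now advertises its own profile: NumpyRecording is memory-only, NumpySorting adds pickle (to keep spike-vector generators simple), SharedMemorySorting is memory-only, and NumpySnippets supports none of the three. A sketch probing one of them (`from_unit_dict` is assumed from the existing NumpySorting API; unit id and frames are illustrative):

```python
import numpy as np
from spikeinterface.core import NumpySorting

sorting = NumpySorting.from_unit_dict(
    {"unit0": np.array([100, 200, 300])}, sampling_frequency=30_000.0
)

print(sorting.check_serializablility("memory"))  # True
print(sorting.check_serializablility("json"))    # False
print(sorting.check_serializablility("pickle"))  # True
```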
src/spikeinterface/core/old_api_utils.py (12 changes: 7 additions & 5 deletions)
@@ -181,9 +181,10 @@ def __init__(self, oldapi_recording_extractor):
dtype=oldapi_recording_extractor.get_dtype(return_scaled=False),
)

-# set _is_dumpable to False to use dumping mechanism of old extractor
-self._is_dumpable = False
-self._is_json_serializable = False
+# set to False to use the dumping mechanism of the old extractor
+self._serializablility["memory"] = False
+self._serializablility["json"] = False
+self._serializablility["pickle"] = False

self.annotate(is_filtered=oldapi_recording_extractor.is_filtered)

@@ -268,8 +269,9 @@ def __init__(self, oldapi_sorting_extractor):
sorting_segment = OldToNewSortingSegment(oldapi_sorting_extractor)
self.add_sorting_segment(sorting_segment)

-self._is_dumpable = False
-self._is_json_serializable = False
+self._serializablility["memory"] = False
+self._serializablility["json"] = False
+self._serializablility["pickle"] = False

# add old properties
copy_properties(oldapi_extractor=oldapi_sorting_extractor, new_extractor=self)
src/spikeinterface/core/tests/test_base.py (38 changes: 19 additions & 19 deletions)
@@ -31,39 +31,39 @@ def make_nested_extractors(extractor):
)


-def test_check_if_dumpable():
+def test_check_if_memory_serializable():
test_extractor = generate_recording(seed=0, durations=[2])

-# make a list of dumpable objects
-extractors_dumpable = make_nested_extractors(test_extractor)
-for extractor in extractors_dumpable:
-assert extractor.check_if_dumpable()
+# make a list of memory serializable objects
+extractors_mem_serializable = make_nested_extractors(test_extractor)
+for extractor in extractors_mem_serializable:
+assert extractor.check_if_memory_serializable()

-# make not dumpable
-test_extractor._is_dumpable = False
-extractors_not_dumpable = make_nested_extractors(test_extractor)
-for extractor in extractors_not_dumpable:
-assert not extractor.check_if_dumpable()
+# make not memory serializable
+test_extractor._serializablility["memory"] = False
+extractors_not_mem_serializable = make_nested_extractors(test_extractor)
+for extractor in extractors_not_mem_serializable:
+assert not extractor.check_if_memory_serializable()


-def test_check_if_json_serializable():
+def test_check_if_serializable():
test_extractor = generate_recording(seed=0, durations=[2])

-# make a list of dumpable objects
-test_extractor._is_json_serializable = True
+# make a list of json serializable objects
+test_extractor._serializablility["json"] = True
extractors_json_serializable = make_nested_extractors(test_extractor)
for extractor in extractors_json_serializable:
print(extractor)
-assert extractor.check_if_json_serializable()
+assert extractor.check_serializablility("json")

-# make not dumpable
-test_extractor._is_json_serializable = False
+# make a list of not json serializable objects
+test_extractor._serializablility["json"] = False
extractors_not_json_serializable = make_nested_extractors(test_extractor)
for extractor in extractors_not_json_serializable:
print(extractor)
-assert not extractor.check_if_json_serializable()
+assert not extractor.check_serializablility("json")


if __name__ == "__main__":
-test_check_if_dumpable()
-test_check_if_json_serializable()
+test_check_if_memory_serializable()
+test_check_if_serializable()
src/spikeinterface/core/tests/test_core_tools.py (1 change: 0 additions & 1 deletion)
@@ -142,7 +142,6 @@ def test_write_memory_recording():
recording = NoiseGeneratorRecording(
num_channels=2, durations=[10.325, 3.5], sampling_frequency=30_000, strategy="tile_pregenerated"
)
-# make dumpable
recording = recording.save()

# write with loop
(Diffs for the remaining 15 changed files are not shown.)
