
Merge pull request #3443 from alejoe91/fix-save-as-recordingless
Allow saving a recording-less SortingAnalyzer with save_as
samuelgarcia authored Oct 1, 2024
2 parents 0ae32a3 + c338658 commit b0c2bae
Showing 2 changed files with 58 additions and 46 deletions.
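
In short: an analyzer whose recording is no longer attached (analyzer.recording is None) previously could not be copied, because create_binary_folder and create_zarr asserted that a recording was present. A minimal sketch of the now-working flow, assuming a generated toy dataset and illustrative folder names (none of these come from the commit itself):

import spikeinterface.full as si

# Toy data; the generator parameters are arbitrary assumptions.
recording, sorting = si.generate_ground_truth_recording(durations=[10.0], num_units=5)

# Persist an analyzer to disk; the recording provenance is dumped alongside
# whenever the recording is serializable.
analyzer = si.create_sorting_analyzer(sorting, recording, format="binary_folder", folder="analyzer_folder")

# A folder loaded where the original recording is unavailable gives a
# "recordingless" analyzer: recording is None, but rec_attributes survive.
loaded = si.load_sorting_analyzer("analyzer_folder")

# Before this merge, save_as failed on the missing recording; now the writers
# fall back to the stored rec_attributes.
copied = loaded.save_as(format="zarr", folder="analyzer_copy.zarr")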
src/spikeinterface/core/sortinganalyzer.py (96 changes: 54 additions & 42 deletions)
@@ -11,6 +11,7 @@
 import shutil
 import warnings
 import importlib
+from copy import copy
 from packaging.version import parse
 from time import perf_counter
 
@@ -254,6 +255,7 @@ def create(
         sparsity=None,
         return_scaled=True,
     ):
+        assert recording is not None, "To create a SortingAnalyzer you need to specify the recording"
         # some checks
         if sorting.sampling_frequency != recording.sampling_frequency:
             if math.isclose(sorting.sampling_frequency, recording.sampling_frequency, abs_tol=1e-2, rel_tol=1e-5):
@@ -352,8 +354,6 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attributes):
     def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes):
         # used by create and save_as
 
-        assert recording is not None, "To create a SortingAnalyzer you need to specify the recording"
-
         folder = Path(folder)
         if folder.is_dir():
             raise ValueError(f"Folder already exists {folder}")
@@ -369,26 +369,34 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes):
             json.dump(check_json(info), f, indent=4)
 
         # save a copy of the sorting
-        # NumpyFolderSorting.write_sorting(sorting, folder / "sorting")
         sorting.save(folder=folder / "sorting")
 
-        # save recording and sorting provenance
-        if recording.check_serializability("json"):
-            recording.dump(folder / "recording.json", relative_to=folder)
-        elif recording.check_serializability("pickle"):
-            recording.dump(folder / "recording.pickle", relative_to=folder)
+        if recording is not None:
+            # save recording and sorting provenance
+            if recording.check_serializability("json"):
+                recording.dump(folder / "recording.json", relative_to=folder)
+            elif recording.check_serializability("pickle"):
+                recording.dump(folder / "recording.pickle", relative_to=folder)
+            else:
+                warnings.warn("The Recording is not serializable! The recording link will be lost for future load")
+        else:
+            assert rec_attributes is not None, "recording or rec_attributes must be provided"
+            warnings.warn("Recording not provided, instantiating SortingAnalyzer in recordingless mode.")
 
         if sorting.check_serializability("json"):
             sorting.dump(folder / "sorting_provenance.json", relative_to=folder)
         elif sorting.check_serializability("pickle"):
             sorting.dump(folder / "sorting_provenance.pickle", relative_to=folder)
+        else:
+            warnings.warn(
+                "The sorting provenance is not serializable! The sorting provenance link will be lost for future load"
+            )
 
         # dump recording attributes
         probegroup = None
         rec_attributes_file = folder / "recording_info" / "recording_attributes.json"
         rec_attributes_file.parent.mkdir()
         if rec_attributes is None:
             assert recording is not None
             rec_attributes = get_rec_attributes(recording)
             rec_attributes_file.write_text(json.dumps(check_json(rec_attributes), indent=4), encoding="utf8")
             probegroup = recording.get_probegroup()
@@ -519,20 +527,21 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes):
         zarr_root.attrs["settings"] = check_json(settings)
 
         # the recording
-        rec_dict = recording.to_dict(relative_to=folder, recursive=True)
-
-        if recording.check_serializability("json"):
-            # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON())
-            zarr_rec = np.array([check_json(rec_dict)], dtype=object)
-            zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON())
-        elif recording.check_serializability("pickle"):
-            # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle())
-            zarr_rec = np.array([rec_dict], dtype=object)
-            zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle())
+        if recording is not None:
+            rec_dict = recording.to_dict(relative_to=folder, recursive=True)
+            if recording.check_serializability("json"):
+                # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON())
+                zarr_rec = np.array([check_json(rec_dict)], dtype=object)
+                zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON())
+            elif recording.check_serializability("pickle"):
+                # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle())
+                zarr_rec = np.array([rec_dict], dtype=object)
+                zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle())
+            else:
+                warnings.warn("The Recording is not serializable! The recording link will be lost for future load")
         else:
-            warnings.warn(
-                "SortingAnalyzer with zarr : the Recording is not json serializable, the recording link will be lost for future load"
-            )
+            assert rec_attributes is not None, "recording or rec_attributes must be provided"
+            warnings.warn("Recording not provided, instantiating SortingAnalyzer in recordingless mode.")
 
         # sorting provenance
         sort_dict = sorting.to_dict(relative_to=folder, recursive=True)
@@ -542,14 +551,14 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes):
         elif sorting.check_serializability("pickle"):
             zarr_sort = np.array([sort_dict], dtype=object)
             zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle())
-
-        # else:
-        # warnings.warn("SortingAnalyzer with zarr : the sorting provenance is not json serializable, the sorting provenance link will be lost for futur load")
+        else:
+            warnings.warn(
+                "The sorting provenance is not serializable! The sorting provenance link will be lost for future load"
+            )
 
         recording_info = zarr_root.create_group("recording_info")
 
         if rec_attributes is None:
             assert recording is not None
             rec_attributes = get_rec_attributes(recording)
             probegroup = recording.get_probegroup()
         else:
@@ -605,11 +614,13 @@ def load_from_zarr(cls, folder, recording=None, storage_options=None):
 
         # load recording if possible
         if recording is None:
-            rec_dict = zarr_root["recording"][0]
-            try:
-                recording = load_extractor(rec_dict, base_folder=folder)
-            except:
-                recording = None
+            rec_field = zarr_root.get("recording")
+            if rec_field is not None:
+                rec_dict = rec_field[0]
+                try:
+                    recording = load_extractor(rec_dict, base_folder=folder)
+                except:
+                    recording = None
         else:
             # TODO maybe maybe not??? : do we need to check attributes match internal rec_attributes
             # Note this will make the loading too slow
@@ -2015,7 +2026,7 @@ def copy(self, new_sorting_analyzer, unit_ids=None):
             new_extension.data = self.data
         else:
             new_extension.data = self._select_extension_data(unit_ids)
-        new_extension.run_info = self.run_info.copy()
+        new_extension.run_info = copy(self.run_info)
         new_extension.save()
         return new_extension

@@ -2033,7 +2044,7 @@ def merge(
         new_extension.data = self._merge_extension_data(
             merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs
         )
-        new_extension.run_info = self.run_info.copy()
+        new_extension.run_info = copy(self.run_info)
         new_extension.save()
         return new_extension

@@ -2251,15 +2262,16 @@ def _save_importing_provenance(self):
         extension_group.attrs["info"] = info
 
     def _save_run_info(self):
-        run_info = self.run_info.copy()
-
-        if self.format == "binary_folder":
-            extension_folder = self._get_binary_extension_folder()
-            run_info_file = extension_folder / "run_info.json"
-            run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8")
-        elif self.format == "zarr":
-            extension_group = self._get_zarr_extension_group(mode="r+")
-            extension_group.attrs["run_info"] = run_info
+        if self.run_info is not None:
+            run_info = self.run_info.copy()
+
+            if self.format == "binary_folder":
+                extension_folder = self._get_binary_extension_folder()
+                run_info_file = extension_folder / "run_info.json"
+                run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8")
+            elif self.format == "zarr":
+                extension_group = self._get_zarr_extension_group(mode="r+")
+                extension_group.attrs["run_info"] = run_info
 
     def get_pipeline_nodes(self):
         assert (
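A note on the copy change above: run_info may legitimately be None (the new guard in _save_run_info handles exactly that case), and None has no .copy() method, while the standard-library copy() passes it through. A minimal standard-library check:

from copy import copy

run_info = None
print(copy(run_info))  # None: copy.copy passes None through unchanged
# run_info.copy()      # AttributeError: 'NoneType' object has no attribute 'copy'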
src/spikeinterface/preprocessing/tests/test_filter.py (8 changes: 4 additions & 4 deletions)
@@ -46,7 +46,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data):
 
         filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()
 
-        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-4)
+        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2)
 
         # Then, change all kwargs to ensure they are propagated
         # and check the backwards version.
@@ -66,7 +66,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data):
 
         filt_data = causal_filter(recording, direction="backward", **options, margin_ms=0).get_traces()
 
-        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-4)
+        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2)
 
     def test_causal_filter_custom_coeff(self, recording_and_data):
         """
@@ -89,7 +89,7 @@ def test_causal_filter_custom_coeff(self, recording_and_data):
 
         filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()
 
-        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-4, equal_nan=True)
+        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2, equal_nan=True)
 
         # Next, in "sos" mode
         options["filter_mode"] = "sos"
@@ -100,7 +100,7 @@ def test_causal_filter_custom_coeff(self, recording_and_data):
 
         filt_data = causal_filter(recording, direction="forward", **options, margin_ms=0).get_traces()
 
-        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-4, equal_nan=True)
+        assert np.allclose(test_data, filt_data, rtol=0, atol=1e-2, equal_nan=True)
 
     def test_causal_kwarg_error_raised(self, recording_and_data):
         """
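The test file only loosens the comparison tolerance from atol=1e-4 to atol=1e-2, presumably to absorb float32 round-off; the reference values come from scipy. A sketch of the comparison pattern these tests use, with assumed filter parameters (the real options live in the test file):

import numpy as np
from scipy.signal import butter, lfilter

fs = 30_000.0
data = np.random.randn(int(fs), 4).astype("float32")  # 1 s of 4-channel noise

# Reference forward (causal) band-pass computed with scipy "ba" coefficients.
b, a = butter(5, Wn=(300.0, 6000.0), btype="bandpass", fs=fs)
expected = lfilter(b, a, data, axis=0)

# The tests assert that causal_filter(recording, direction="forward", ...)
# matches such a reference with np.allclose(rtol=0, atol=1e-2).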
