diff --git a/src/MEArec/tools.py b/src/MEArec/tools.py index b058b67..f9bfe6f 100755 --- a/src/MEArec/tools.py +++ b/src/MEArec/tools.py @@ -560,7 +560,7 @@ def save_template_generator(tempgen, filename=None, verbose=True): print("\nSaved templates in", filename, "\n") -def save_recording_generator(recgen, filename=None, verbose=False): +def save_recording_generator(recgen, filename=None, verbose=False, include_spike_traces: bool = True): """ Save recordings to disk. @@ -572,6 +572,8 @@ def save_recording_generator(recgen, filename=None, verbose=False): Path to .h5 file verbose : bool If True output is verbose + include_spike_traces: bool, default=True + If True, will include the spike traces (can be heavy) """ filename = Path(filename) if not filename.parent.is_dir(): @@ -580,12 +582,12 @@ def save_recording_generator(recgen, filename=None, verbose=False): with h5py.File(filename, "w") as f: f.attrs["mearec_version"] = mearec_version f.attrs["date"] = datetime.now().strftime("%y-%m-%d %H:%M:%S") - save_recording_to_file(recgen, f) + save_recording_to_file(recgen, f, include_spike_traces=include_spike_traces) if verbose: print("\nSaved recordings in", filename, "\n") -def save_recording_to_file(recgen, f, path=""): +def save_recording_to_file(recgen, f, path="", include_spike_traces: bool = True): """ Save recordings to file handler. @@ -595,6 +597,8 @@ def save_recording_to_file(recgen, f, path=""): RecordingGenerator object to be saved filename : _io.TextIOWrapper File handler + include_spike_traces: bool, default=True + If True, will include the spike traces (can be heavy) """ save_dict_to_hdf5(recgen.info, f, path + "info/") if len(recgen.voltage_peaks) > 0: @@ -605,7 +609,7 @@ def save_recording_to_file(recgen, f, path=""): f.create_dataset(path + "recordings", data=recgen.recordings) if recgen.gain_to_uV is not None: f["recordings"].attrs["gain_to_uV"] = recgen.gain_to_uV - if len(recgen.spike_traces) > 0: + if len(recgen.spike_traces) > 0 and include_spike_traces: f.create_dataset(path + "spike_traces", data=recgen.spike_traces) if len(recgen.spiketrains) > 0: for ii in range(len(recgen.spiketrains)):