Skip to content

Commit

Permalink
Merge branch 'main' into quality_metrics_update
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow authored Sep 11, 2024
2 parents c1f0b2a + 73f4d58 commit 715ce54
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 34 deletions.
11 changes: 7 additions & 4 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,15 +497,17 @@ def set_times(self, times, segment_index=None, with_warning=True):

def reset_times(self):
"""
Reset times in-memory for all segments that have a time vector.
Reset time information in-memory for all segments that have a time vector.
If the timestamps come from a file, the files won't be modified; only the in-memory
attributes of the recording objects are deleted.
attributes of the recording objects are deleted. Also `t_start` is set to None and the
segment's sampling frequency is set to the recording's sampling frequency.
"""
for segment_index in range(self.get_num_segments()):
rs = self._recording_segments[segment_index]
if self.has_time_vector(segment_index):
rs = self._recording_segments[segment_index]
rs.t_start = None
rs.time_vector = None
rs.t_start = None
rs.sampling_frequency = self.sampling_frequency

def sample_index_to_time(self, sample_ind, segment_index=None):
"""
Expand Down Expand Up @@ -565,6 +567,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
channel_ids=self.get_channel_ids(),
time_axis=0,
file_offset=0,
is_filtered=self.is_filtered(),
gain_to_uV=self.get_channel_gains(),
offset_to_uV=self.get_channel_offsets(),
)
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/core/binaryrecordingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ def __init__(
num_chan=None,
):
# This assigns num_channels if num_channels is not None, otherwise num_chan is assigned
# num_chan needs to be kept for backward compatibility but should not be used by the
# end user
num_channels = num_channels or num_chan
assert num_channels is not None, "You must provide num_channels or num_chan"
if num_chan is not None:
warnings.warn("`num_chan` is to be deprecated in version 0.100, please use `num_channels` instead")
warnings.warn("`num_chan` is to be deprecated as of version 0.100, please use `num_channels` instead")

if channel_ids is None:
channel_ids = list(range(num_channels))
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/core_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def is_dict_extractor(d: dict) -> bool:
extractor_dict_element = namedtuple(typename="extractor_dict_element", field_names=["value", "name", "access_path"])


def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element]:
def extractor_dict_iterator(extractor_dict: dict) -> Generator[extractor_dict_element, None, None]:
"""
Iterator for recursive traversal of a dictionary.
This function explores the dictionary recursively and yields the path to each value along with the value itself.
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/recording_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _init_binary_worker(recording, file_path_dict, dtype, byte_offest, cast_unsi
def write_binary_recording(
recording: "BaseRecording",
file_paths: list[Path | str] | Path | str,
dtype: np.ndtype = None,
dtype: np.typing.DTypeLike = None,
add_file_extension: bool = True,
byte_offset: int = 0,
auto_cast_uint: bool = True,
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/segmentutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def __init__(self, parent_segments, sampling_frequency, ignore_times=True):
self.parent_segments = parent_segments
self.all_length = [rec_seg.get_num_samples() for rec_seg in self.parent_segments]
self.cumsum_length = [0] + [sum(self.all_length[: i + 1]) for i in range(len(self.all_length))]
self.total_length = int(np.sum(self.all_length))
self.total_length = int(sum(self.all_length))

def get_num_samples(self):
return self.total_length
Expand Down Expand Up @@ -450,7 +450,7 @@ def __init__(self, parent_segments, parent_num_samples, sampling_frequency):
self.parent_segments = parent_segments
self.parent_num_samples = parent_num_samples
self.cumsum_length = np.cumsum([0] + self.parent_num_samples)
self.total_num_samples = np.sum(self.parent_num_samples)
self.total_num_samples = int(sum(self.parent_num_samples))

def get_num_samples(self):
return self.total_num_samples
Expand Down
90 changes: 83 additions & 7 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import shutil
import warnings
import importlib
from time import perf_counter

import numpy as np

Expand Down Expand Up @@ -1336,6 +1337,7 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job

job_name = "Compute : " + " + ".join(extensions_with_pipeline.keys())

t_start = perf_counter()
results = run_node_pipeline(
self.recording,
all_nodes,
Expand All @@ -1345,10 +1347,15 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job
squeeze_output=False,
verbose=verbose,
)
t_end = perf_counter()
# for pipeline node extensions we can only track the runtime of the run_node_pipeline
runtime_s = t_end - t_start

for r, result in enumerate(results):
extension_name, variable_name = result_routage[r]
extension_instances[extension_name].data[variable_name] = result
extension_instances[extension_name].run_info["runtime_s"] = runtime_s
extension_instances[extension_name].run_info["run_completed"] = True

for extension_name, extension_instance in extension_instances.items():
self.extensions[extension_name] = extension_instance
Expand Down Expand Up @@ -1738,8 +1745,12 @@ def __init__(self, sorting_analyzer):
self._sorting_analyzer = weakref.ref(sorting_analyzer)

self.params = None
self.run_info = self._default_run_info_dict()
self.data = dict()

def _default_run_info_dict(self):
return dict(run_completed=False, runtime_s=None)

#######
# These 3 methods must be implemented in the subclass!!!
# See DummyAnalyzerExtension in test_sortinganalyzer.py as a simple example
Expand Down Expand Up @@ -1851,11 +1862,42 @@ def _get_zarr_extension_group(self, mode="r+"):
def load(cls, sorting_analyzer):
ext = cls(sorting_analyzer)
ext.load_params()
ext.load_data()
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
ext.load_run_info()
if ext.run_info is not None:
if ext.run_info["run_completed"]:
ext.load_data()
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
if len(ext.data) > 0:
return ext
else:
# this is for back-compatibility of old analyzers
ext.load_data()
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
if len(ext.data) > 0:
return ext
# If extension run not completed, or data has gone missing,
# return None to indicate that the extension should be (re)computed.
return None

def load_run_info(self):
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
run_info_file = extension_folder / "run_info.json"
if run_info_file.is_file():
with open(str(run_info_file), "r") as f:
run_info = json.load(f)
else:
warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
run_info = None

return ext
elif self.format == "zarr":
extension_group = self._get_zarr_extension_group(mode="r")
run_info = extension_group.attrs.get("run_info", None)
if run_info is None:
warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
self.run_info = run_info

def load_params(self):
if self.format == "binary_folder":
Expand All @@ -1873,12 +1915,17 @@ def load_params(self):
self.params = params

def load_data(self):
ext_data = None
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
for ext_data_file in extension_folder.iterdir():
# patch for https://github.com/SpikeInterface/spikeinterface/issues/3041
# maybe add a check for version number from the info.json during loading only
if ext_data_file.name == "params.json" or ext_data_file.name == "info.json":
if (
ext_data_file.name == "params.json"
or ext_data_file.name == "info.json"
or ext_data_file.name == "run_info.json"
):
continue
ext_data_name = ext_data_file.stem
if ext_data_file.suffix == ".json":
Expand Down Expand Up @@ -1919,6 +1966,9 @@ def load_data(self):
ext_data = np.array(ext_data_)
self.data[ext_data_name] = ext_data

if len(self.data) == 0:
warnings.warn(f"Found no data for {self.extension_name}, extension should be re-computed.")

def copy(self, new_sorting_analyzer, unit_ids=None):
# alessio : please note that this also replaces the old select_units!!!
new_extension = self.__class__(new_sorting_analyzer)
Expand All @@ -1927,6 +1977,7 @@ def copy(self, new_sorting_analyzer, unit_ids=None):
new_extension.data = self.data
else:
new_extension.data = self._select_extension_data(unit_ids)
new_extension.run_info = self.run_info.copy()
new_extension.save()
return new_extension

Expand All @@ -1944,24 +1995,33 @@ def merge(
new_extension.data = self._merge_extension_data(
merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask, verbose=verbose, **job_kwargs
)
new_extension.run_info = self.run_info.copy()
new_extension.save()
return new_extension

def run(self, save=True, **kwargs):
if save and not self.sorting_analyzer.is_read_only():
# this also reset the folder or zarr group
# NB: this call to _save_params() also resets the folder or zarr group
self._save_params()
self._save_importing_provenance()
self._save_run_info()

t_start = perf_counter()
self._run(**kwargs)
t_end = perf_counter()
self.run_info["runtime_s"] = t_end - t_start

if save and not self.sorting_analyzer.is_read_only():
self._save_data(**kwargs)

self.run_info["run_completed"] = True
self._save_run_info()

def save(self, **kwargs):
self._save_params()
self._save_importing_provenance()
self._save_data(**kwargs)
self._save_run_info()

def _save_data(self, **kwargs):
if self.format == "memory":
Expand Down Expand Up @@ -2060,6 +2120,7 @@ def reset(self):
"""
self._reset_extension_folder()
self.params = None
self.run_info = self._default_run_info_dict()
self.data = dict()

def set_params(self, save=True, **params):
Expand All @@ -2080,6 +2141,7 @@ def set_params(self, save=True, **params):
if save:
self._save_params()
self._save_importing_provenance()
self._save_run_info()

def _save_params(self):
params_to_save = self.params.copy()
Expand Down Expand Up @@ -2117,14 +2179,28 @@ def _save_importing_provenance(self):
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["info"] = info

def _save_run_info(self):
run_info = self.run_info.copy()

if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
run_info_file = extension_folder / "run_info.json"
run_info_file.write_text(json.dumps(run_info, indent=4), encoding="utf8")
elif self.format == "zarr":
extension_group = self._get_zarr_extension_group(mode="r+")
extension_group.attrs["run_info"] = run_info

def get_pipeline_nodes(self):
assert (
self.use_nodepipeline
), "AnalyzerExtension.get_pipeline_nodes() must be called only when use_nodepipeline=True"
return self._get_pipeline_nodes()

def get_data(self, *args, **kwargs):
assert len(self.data) > 0, f"You must run the extension {self.extension_name} before retrieving data"
assert self.run_info[
"run_completed"
], f"You must run the extension {self.extension_name} before retrieving data"
assert len(self.data) > 0, "Extension has been run but no data found."
return self._get_data(*args, **kwargs)


Expand Down
7 changes: 7 additions & 0 deletions src/spikeinterface/core/tests/test_baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,14 @@ def test_BaseRecording(create_cache_folder):
# reset times
rec.reset_times()
for segm in range(num_seg):
time_info = rec.get_time_info(segment_index=segm)
assert not rec.has_time_vector(segment_index=segm)
assert time_info["t_start"] is None
assert time_info["time_vector"] is None
assert time_info["sampling_frequency"] == rec.sampling_frequency

# resetting time again should be ok
rec.reset_times()

# test 3d probe
rec_3d = generate_recording(ndim=3, num_channels=30)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def test_sequential_reading_of_small_traces(folder_with_binary_files):
dtype = "float32"

file_paths = [folder / "traces_cached_seg0.raw"]
# `num_chan` is kept for backward compatibility so including it in at least one test
# run is good to ensure that it is appropriately accepted as an argument
recording = BinaryRecordingExtractor(
num_chan=num_channels,
file_paths=file_paths,
Expand Down
33 changes: 33 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,39 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset):
)


def test_load_without_runtime_info(tmp_path, dataset):
recording, sorting = dataset

folder = tmp_path / "test_SortingAnalyzer_run_info"

extensions = ["random_spikes", "templates"]
# binary_folder
sorting_analyzer = create_sorting_analyzer(
sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None
)
sorting_analyzer.compute(extensions)
# remove run_info.json to mimic a previous version of spikeinterface
for ext in extensions:
(folder / "extensions" / ext / "run_info.json").unlink()
# should raise a warning for missing run_info
with pytest.warns(UserWarning):
sorting_analyzer = load_sorting_analyzer(folder, format="auto")

# zarr
folder = tmp_path / "test_SortingAnalyzer_run_info.zarr"
sorting_analyzer = create_sorting_analyzer(
sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None
)
sorting_analyzer.compute(extensions)
# remove run_info from attrs to mimic a previous version of spikeinterface
root = sorting_analyzer._get_zarr_root(mode="r+")
for ext in extensions:
del root["extensions"][ext].attrs["run_info"]
# should raise a warning for missing run_info
with pytest.warns(UserWarning):
sorting_analyzer = load_sorting_analyzer(folder, format="auto")


def test_SortingAnalyzer_tmp_recording(dataset):
recording, sorting = dataset
recording_cached = recording.save(mode="memory")
Expand Down
8 changes: 6 additions & 2 deletions src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ def interpolate_templates(templates_array, source_locations, dest_locations, int
source_locations = np.asarray(source_locations)
dest_locations = np.asarray(dest_locations)
if dest_locations.ndim == 2:
new_shape = templates_array.shape
new_shape = (*templates_array.shape[:2], len(dest_locations))
elif dest_locations.ndim == 3:
new_shape = (dest_locations.shape[0],) + templates_array.shape
new_shape = (
dest_locations.shape[0],
*templates_array.shape[:2],
dest_locations.shape[1],
)
else:
raise ValueError(f"Incorrect dimensions for dest_locations: {dest_locations.ndim}. Dimensions can be 2 or 3. ")

Expand Down
Loading

0 comments on commit 715ce54

Please sign in to comment.