From a71eff4109c945bcfb1f384c39f6c4127629fa2f Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Fri, 21 Jun 2024 18:25:22 -0400 Subject: [PATCH] add typing --- src/spikeinterface/core/base.py | 6 +++--- src/spikeinterface/core/baserecordingsnippets.py | 8 ++++---- src/spikeinterface/core/basesorting.py | 10 ++++------ src/spikeinterface/core/binaryfolder.py | 2 +- src/spikeinterface/core/binaryrecordingextractor.py | 2 +- src/spikeinterface/core/core_tools.py | 13 ++++++++++--- src/spikeinterface/core/frameslicerecording.py | 2 +- src/spikeinterface/core/generate.py | 2 +- src/spikeinterface/core/numpyextractors.py | 2 +- src/spikeinterface/core/recording_tools.py | 12 +++++++++++- src/spikeinterface/core/sortinganalyzer.py | 2 +- src/spikeinterface/core/waveform_tools.py | 2 +- src/spikeinterface/sorters/basesorter.py | 6 +++--- 13 files changed, 42 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 6fbc5ac289..4922707b35 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -550,7 +550,7 @@ def check_serializability(self, type): return False return self._serializability[type] - def check_if_memory_serializable(self): + def check_if_memory_serializable(self) -> bool: """ Check if the object is serializable to memory with pickle, including nested objects. @@ -561,7 +561,7 @@ def check_if_memory_serializable(self): """ return self.check_serializability("memory") - def check_if_json_serializable(self): + def check_if_json_serializable(self) -> bool: """ Check if the object is json serializable, including nested objects. @@ -574,7 +574,7 @@ def check_if_json_serializable(self): # is this needed ??? I think no. return self.check_serializability("json") - def check_if_pickle_serializable(self): + def check_if_pickle_serializable(self) -> bool: # is this needed ??? I think no. return self.check_serializability("pickle") diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 2a9f075954..428472bf93 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -48,7 +48,7 @@ def get_num_channels(self): def get_dtype(self): return self._dtype - def has_scaleable_traces(self): + def has_scaleable_traces(self) -> bool: if self.get_property("gain_to_uV") is None or self.get_property("offset_to_uV") is None: return False else: @@ -62,10 +62,10 @@ def has_scaled(self): ) return self.has_scaleable_traces() - def has_probe(self): + def has_probe(self) -> bool: return "contact_vector" in self.get_property_keys() - def has_channel_location(self): + def has_channel_location(self) -> bool: return self.has_probe() or "location" in self.get_property_keys() def is_filtered(self): @@ -366,7 +366,7 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy"): locations = np.asarray(locations)[channel_indices] return select_axes(locations, axes) - def has_3d_locations(self): + def has_3d_locations(self) -> bool: return self.get_property("location").shape[1] == 3 def clear_channel_locations(self, channel_ids=None): diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 7214d2780e..fd68df9dda 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import List, Optional, Union +from typing import Optional, Union import numpy as np @@ -73,7 +73,7 @@ def unit_ids(self): def sampling_frequency(self): return self._sampling_frequency - def get_unit_ids(self) -> List: + def get_unit_ids(self) -> list: return self._main_ids def get_num_units(self) -> int: @@ -121,7 +121,7 @@ def get_total_samples(self) -> int: s += self.get_num_samples(segment_index) return s - def get_total_duration(self): + def get_total_duration(self) -> float: """Returns the total duration in s of the associated recording. Returns @@ -219,7 +219,7 @@ def set_sorting_info(self, recording_dict, params_dict, log_dict): def has_recording(self): return self._recording is not None - def has_time_vector(self, segment_index=None): + def has_time_vector(self, segment_index=None) -> bool: """ Check if the segment of the registered recording has a time vector. """ @@ -515,8 +515,6 @@ def precompute_spike_trains(self, from_spike_vector=None): """ Pre-computes and caches all spike trains for this sorting - - Parameters ---------- from_spike_vector : None | bool, default: None diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index ec9bdfcc5e..546ac85f93 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -53,7 +53,7 @@ def __init__(self, folder_path): assert "num_chan" in self._bin_kwargs, "Cannot find num_channels or num_chan in binary.json" self._bin_kwargs["num_channels"] = self._bin_kwargs["num_chan"] - def is_binary_compatible(self): + def is_binary_compatible(self) -> bool: return True def get_binary_description(self): diff --git a/src/spikeinterface/core/binaryrecordingextractor.py b/src/spikeinterface/core/binaryrecordingextractor.py index 5d72532704..8fb9a78f2a 100644 --- a/src/spikeinterface/core/binaryrecordingextractor.py +++ b/src/spikeinterface/core/binaryrecordingextractor.py @@ -147,7 +147,7 @@ def write_recording(recording, file_paths, dtype=None, **job_kwargs): """ write_binary_recording(recording, file_paths=file_paths, dtype=dtype, **job_kwargs) - def is_binary_compatible(self): + def is_binary_compatible(self) -> bool: return True def get_binary_description(self): diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index 664eac169f..f3d8b3df7f 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -168,9 +168,14 @@ def make_shared_array(shape, dtype): return arr, shm -def is_dict_extractor(d): +def is_dict_extractor(d: dict) -> bool: """ - Check if a dict describe an extractor. + Check if a dict describes an extractor. + + Returns + ------- + is_extractor : bool + Whether the dict describes an extractor """ if not isinstance(d, dict): return False @@ -283,6 +288,7 @@ def check_paths_relative(input_dict, relative_folder) -> bool: Returns ------- relative_possible: bool + Whether the given input can be made relative to the relative_folder """ path_list = _get_paths_list(input_dict) relative_folder = Path(relative_folder).resolve().absolute() @@ -513,7 +519,8 @@ def normal_pdf(x, mu: float = 0.0, sigma: float = 1.0): def retrieve_importing_provenance(a_class): """ - Retrieve the import provenance of a class, including its import name (that consists of the class name and the module), the top-level module, and the module version. + Retrieve the import provenance of a class, including its import name (that consists of the class name and the module), + the top-level module, and the module version. Parameters ---------- diff --git a/src/spikeinterface/core/frameslicerecording.py b/src/spikeinterface/core/frameslicerecording.py index 133cbf886c..5c91d3cae1 100644 --- a/src/spikeinterface/core/frameslicerecording.py +++ b/src/spikeinterface/core/frameslicerecording.py @@ -82,7 +82,7 @@ def __init__(self, parent_recording_segment, start_frame, end_frame): self.start_frame = start_frame self.end_frame = end_frame - def get_num_samples(self): + def get_num_samples(self) -> int: return self.end_frame - self.start_frame def get_traces(self, start_frame, end_frame, channel_indices): diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 251678e675..370f5b42c6 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -1134,7 +1134,7 @@ def __init__( elif self.strategy == "on_the_fly": pass - def get_num_samples(self): + def get_num_samples(self) -> int: return self.num_samples def get_traces( diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 62cd2fe2cf..0ba1c05417 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -110,7 +110,7 @@ def __init__(self, traces, sampling_frequency, t_start): self._traces = traces self.num_samples = traces.shape[0] - def get_num_samples(self): + def get_num_samples(self) -> int: return self.num_samples def get_traces(self, start_frame, end_frame, channel_indices): diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 8b1b293543..b4c07e77c9 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -862,7 +862,17 @@ def order_channels_by_depth(recording, channel_ids=None, dimensions=("x", "y"), def check_probe_do_not_overlap(probes): """ When several probes this check that that they do not overlap in space - and so channel positions can be safly concatenated. + and so channel positions can be safely concatenated. + + Raises + ------ + Exception : + If probes are overlapping + + Returns + ------- + None : None + If the check is successful """ for i in range(len(probes)): probe_i = probes[i] diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 0094012013..e439ddf1ed 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1229,7 +1229,7 @@ def get_computable_extensions(self): """ return get_available_analyzer_extensions() - def get_default_extension_params(self, extension_name: str): + def get_default_extension_params(self, extension_name: str) -> dict: """ Get the default params for an extension. diff --git a/src/spikeinterface/core/waveform_tools.py b/src/spikeinterface/core/waveform_tools.py index acc368b2e5..befc49d034 100644 --- a/src/spikeinterface/core/waveform_tools.py +++ b/src/spikeinterface/core/waveform_tools.py @@ -679,7 +679,7 @@ def split_waveforms_by_units(unit_ids, spikes, all_waveforms, sparsity_mask=None return waveforms_by_units -def has_exceeding_spikes(recording, sorting): +def has_exceeding_spikes(recording, sorting) -> bool: """ Check if the sorting objects has spikes exceeding the recording number of samples, for all segments diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 8c52626703..a9513f9f5a 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -343,7 +343,7 @@ def get_result_from_folder(cls, output_folder, register_recording=True, sorting_ return sorting @classmethod - def check_compiled(cls): + def check_compiled(cls) -> bool: """ Checks if the sorter is running inside an image with matlab-compiled version @@ -370,7 +370,7 @@ def check_compiled(cls): return True @classmethod - def use_gpu(cls, params): + def use_gpu(cls, params) -> bool: return cls.gpu_capability != "not-supported" ############################################# @@ -436,7 +436,7 @@ def get_job_kwargs(params, verbose): return job_kwargs -def is_log_ok(output_folder): +def is_log_ok(output_folder) -> bool: # log is OK when run_time is not None if (output_folder / "spikeinterface_log.json").is_file(): with open(output_folder / "spikeinterface_log.json", mode="r", encoding="utf8") as logfile: