Merge branch 'main' of github.com:spikeinterface/spikeinterface into numba_similarity
yger committed Sep 24, 2024
2 parents e3b28dd + b9f50e3 commit 5bac5ed
Showing 10 changed files with 71 additions and 36 deletions.
8 changes: 6 additions & 2 deletions .github/scripts/determine_testing_environment.py
@@ -31,6 +31,7 @@
sortingcomponents_changed = False
generation_changed = False
stream_extractors_changed = False
github_actions_changed = False


for changed_file in changed_files_in_the_pull_request_paths:
@@ -78,9 +79,12 @@
sorters_internal_changed = True
else:
sorters_changed = True
elif ".github" in changed_file.parts:
if "workflows" in changed_file.parts:
github_actions_changed = True


run_everything = core_changed or pyproject_toml_changed or neobaseextractor_changed
run_everything = core_changed or pyproject_toml_changed or neobaseextractor_changed or github_actions_changed
run_generation_tests = run_everything or generation_changed
run_extractor_tests = run_everything or extractors_changed or plexon2_changed
run_preprocessing_tests = run_everything or preprocessing_changed
@@ -96,7 +100,7 @@
run_sorters_test = run_everything or sorters_changed
run_internal_sorters_test = run_everything or run_sortingcomponents_tests or sorters_internal_changed

run_streaming_extractors_test = stream_extractors_changed
run_streaming_extractors_test = stream_extractors_changed or github_actions_changed

install_plexon_dependencies = plexon2_changed

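The hunks above make edits under `.github/workflows` trigger the full suite and the streaming-extractor tests. A minimal standalone sketch of the same pathlib-based classification (not the project script itself; the file list is hypothetical, for illustration):

```python
# Standalone sketch: classify changed paths by their pathlib parts,
# mirroring the github_actions_changed logic added above.
from pathlib import Path

# Hypothetical changed files, for illustration only.
changed_files = [Path(".github/workflows/all-tests.yml"), Path("src/spikeinterface/core/job_tools.py")]

github_actions_changed = False
core_changed = False

for changed_file in changed_files:
    if ".github" in changed_file.parts and "workflows" in changed_file.parts:
        github_actions_changed = True
    elif "core" in changed_file.parts:
        core_changed = True

# Workflow edits now force the full run, as in the run_everything line above.
run_everything = core_changed or github_actions_changed
print(run_everything)  # True
```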
3 changes: 2 additions & 1 deletion .github/workflows/all-tests.yml
@@ -12,6 +12,7 @@ on:
env:
KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }}
KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }}
KACHERY_ZONE: ${{ secrets.KACHERY_ZONE }}

concurrency: # Cancel previous workflows on the same pull request
group: ${{ github.workflow }}-${{ github.ref }}
@@ -25,7 +26,7 @@ jobs:
fail-fast: false
matrix:
python-version: ["3.9", "3.12"] # Lower and higher versions we support
os: [macos-13, windows-latest, ubuntu-latest]
os: [macos-latest, windows-latest, ubuntu-latest]
steps:
- uses: actions/checkout@v4
- name: Setup Python ${{ matrix.python-version }}
1 change: 1 addition & 0 deletions .github/workflows/full-test-with-codecov.yml
@@ -8,6 +8,7 @@ on:
env:
KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }}
KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }}
KACHERY_ZONE: ${{ secrets.KACHERY_ZONE }}

jobs:
full-tests-with-codecov:
34 changes: 29 additions & 5 deletions src/spikeinterface/core/baserecording.py
@@ -608,11 +608,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
probegroup = self.get_probegroup()
cached.set_probegroup(probegroup)

time_vectors = self._get_time_vectors()
if time_vectors is not None:
for segment_index, time_vector in enumerate(time_vectors):
if time_vector is not None:
cached.set_times(time_vector, segment_index=segment_index)
for segment_index in range(self.get_num_segments()):
if self.has_time_vector(segment_index):
# the use of get_times is preferred since timestamps are converted to array
time_vector = self.get_times(segment_index=segment_index)
cached.set_times(time_vector, segment_index=segment_index)

return cached

@@ -746,6 +746,30 @@ def _select_segments(self, segment_indices):

return SelectSegmentRecording(self, segment_indices=segment_indices)

def get_channel_locations(
self,
channel_ids: list | np.ndarray | tuple | None = None,
axes: "xy" | "yz" | "xz" | "xyz" = "xy",
) -> np.ndarray:
"""
Get the physical locations of specified channels.
Parameters
----------
channel_ids : array-like, optional
The IDs of the channels for which to retrieve locations. If None, retrieves locations
for all available channels. Default is None.
axes : "xy" | "yz" | "xz" | "xyz", default: "xy"
The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy".
Returns
-------
np.ndarray
A 2D or 3D array of shape (n_channels, n_dimensions) containing the locations of the channels.
The number of dimensions depends on the `axes` argument (e.g., 2 for "xy", 3 for "xyz").
"""
return super().get_channel_locations(channel_ids=channel_ids, axes=axes)

def is_binary_compatible(self) -> bool:
"""
Checks if the recording is "binary" compatible.
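The new `get_channel_locations` override mainly adds the docstring and return annotation while delegating to the base class. A hedged usage sketch; the synthetic recording from `generate_recording` (which carries a dummy probe in recent versions) is just a convenient stand-in and not part of this diff:

```python
# Hedged usage sketch of get_channel_locations; the generated recording is a
# toy stand-in, assuming generate_recording attaches a dummy probe.
import spikeinterface.core as si

recording = si.generate_recording(num_channels=4, durations=[1.0])

# All channels on the default "xy" axes -> shape (num_channels, 2)
locations = recording.get_channel_locations()
print(locations.shape)

# A subset of channel ids keeps the same column layout
subset = recording.get_channel_locations(channel_ids=recording.channel_ids[:2])
print(subset.shape)
```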
2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecordingsnippets.py
@@ -349,7 +349,7 @@ def set_channel_locations(self, locations, channel_ids=None):
raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)")
self.set_property("location", locations, ids=channel_ids)

def get_channel_locations(self, channel_ids=None, axes: str = "xy"):
def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray:
if channel_ids is None:
channel_ids = self.get_channel_ids()
channel_indices = self.ids_to_indices(channel_ids)
7 changes: 2 additions & 5 deletions src/spikeinterface/core/job_tools.py
@@ -136,11 +136,8 @@ def divide_segment_into_chunks(num_frames, chunk_size):
else:
n = num_frames // chunk_size

frame_starts = np.arange(n) * chunk_size
frame_stops = frame_starts + chunk_size

frame_starts = frame_starts.tolist()
frame_stops = frame_stops.tolist()
frame_starts = [i * chunk_size for i in range(n)]
frame_stops = [frame_start + chunk_size for frame_start in frame_starts]

if (num_frames % chunk_size) > 0:
frame_starts.append(n * chunk_size)
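A standalone sketch of the chunking logic after this change, using plain list comprehensions instead of the previous numpy arange/tolist round trip; the helper name and the example call are illustrative only, not the real `divide_segment_into_chunks` signature:

```python
# Standalone sketch of the chunking logic shown above.
def divide_into_chunks(num_frames, chunk_size):
    n = num_frames // chunk_size
    frame_starts = [i * chunk_size for i in range(n)]
    frame_stops = [start + chunk_size for start in frame_starts]
    if num_frames % chunk_size > 0:
        # trailing partial chunk
        frame_starts.append(n * chunk_size)
        frame_stops.append(num_frames)
    return list(zip(frame_starts, frame_stops))

print(divide_into_chunks(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```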
24 changes: 14 additions & 10 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -232,6 +232,8 @@ def __repr__(self) -> str:
txt += " - sparse"
if self.has_recording():
txt += " - has recording"
if self.has_temporary_recording():
txt += " - has temporary recording"
ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys())
txt += "\n" + ext_txt
return txt
@@ -350,7 +352,7 @@ def create_memory(cls, sorting, recording, sparsity, return_scaled, rec_attribut
def create_binary_folder(cls, folder, sorting, recording, sparsity, return_scaled, rec_attributes):
# used by create and save_as

assert recording is not None, "To create a SortingAnalyzer you need recording not None"
assert recording is not None, "To create a SortingAnalyzer you need to specify the recording"

folder = Path(folder)
if folder.is_dir():
@@ -1221,7 +1223,7 @@ def compute(self, input, save=True, extension_params=None, verbose=False, **kwar
extensions[ext_name] = ext_params
self.compute_several_extensions(extensions=extensions, save=save, verbose=verbose, **job_kwargs)
else:
raise ValueError("SortingAnalyzer.compute() need str, dict or list")
raise ValueError("SortingAnalyzer.compute() needs a str, dict or list")

def compute_one_extension(self, extension_name, save=True, verbose=False, **kwargs) -> "AnalyzerExtension":
"""
@@ -1355,7 +1357,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job

for extension_name, extension_params in extensions_with_pipeline.items():
extension_class = get_extension_class(extension_name)
assert self.has_recording(), f"Extension {extension_name} need the recording"
assert (
self.has_recording() or self.has_temporary_recording()
), f"Extension {extension_name} requires the recording"

for variable_name in extension_class.nodepipeline_variables:
result_routage.append((extension_name, variable_name))
@@ -1603,17 +1607,17 @@ def _sort_extensions_by_dependency(extensions):
def _get_children_dependencies(extension_name):
"""
Extension classes have a `depend_on` attribute to declare on which class they
depend. For instance "templates" depend on "waveforms". "waveforms depends on "random_spikes".
depend on. For instance "templates" depends on "waveforms". "waveforms" depends on "random_spikes".
This function is making the reverse way : get all children that depend of a
This function is going the opposite way: it finds all children that depend on a
particular extension.
This is recursive so this includes : children and so grand children and great grand children
The implementation is recursive so that the output includes children, grand children, great grand children, etc.
This function is usefull for deleting on recompute.
For instance recompute the "waveforms" need to delete "template"
This make sens if "ms_before" is change in "waveforms" because the template also depends
on this parameters.
This function is useful for deleting existing extensions on recompute.
For instance, recomputing the "waveforms" needs to delete the "templates", since the latter depends on the former.
For this particular example, if we change the "ms_before" parameter of the "waveforms", also the "templates" will
require recomputation as this parameter is inherited.
"""
names = []
children = _extension_children[extension_name]
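A toy sketch of the recursive child-dependency lookup the reworded docstring describes; the dependency map below is a stand-in for the real extension registry, not the project's `_extension_children`:

```python
# Toy sketch of the recursive child-dependency traversal described above.
_children = {
    "random_spikes": ["waveforms"],
    "waveforms": ["templates"],
    "templates": [],
}

def get_children_dependencies(extension_name):
    names = []
    for child in _children.get(extension_name, []):
        names.append(child)
        names.extend(get_children_dependencies(child))
    return names

# Recomputing "random_spikes" would invalidate its whole subtree:
print(get_children_dependencies("random_spikes"))  # ['waveforms', 'templates']
```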
src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py
@@ -536,6 +536,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
ext = ComputeRandomSpikes(sorting_analyzer)
ext.params = dict()
ext.data = dict(random_spikes_indices=random_spikes_indices)
ext.run_info = None
sorting_analyzer.extensions["random_spikes"] = ext

ext = ComputeWaveforms(sorting_analyzer)
@@ -545,6 +546,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
dtype=params["dtype"],
)
ext.data["waveforms"] = waveforms
ext.run_info = None
sorting_analyzer.extensions["waveforms"] = ext

# templates saved dense
@@ -559,6 +561,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
ext.params = dict(ms_before=params["ms_before"], ms_after=params["ms_after"], operators=list(templates.keys()))
for mode, arr in templates.items():
ext.data[mode] = arr
ext.run_info = None
sorting_analyzer.extensions["templates"] = ext

for old_name, new_name in old_extension_to_new_class_map.items():
@@ -631,6 +634,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
ext.set_params(**updated_params, save=False)
if ext.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
ext.run_info = None

sorting_analyzer.extensions[new_name] = ext

7 changes: 4 additions & 3 deletions src/spikeinterface/extractors/tests/test_neoextractors.py
@@ -234,7 +234,7 @@ class BlackrockSortingTest(SortingCommonTestSuite, unittest.TestCase):
ExtractorClass = BlackrockSortingExtractor
downloads = ["blackrock"]
entities = [
"blackrock/FileSpec2.3001.nev",
dict(file_path=local_folder / "blackrock/FileSpec2.3001.nev", sampling_frequency=30_000.0),
dict(file_path=local_folder / "blackrock/blackrock_2_1/l101210-001.nev", sampling_frequency=30_000.0),
]

@@ -278,8 +278,8 @@ class Spike2RecordingTest(RecordingCommonTestSuite, unittest.TestCase):


@pytest.mark.skipif(
version.parse(platform.python_version()) >= version.parse("3.10"),
reason="Sonpy only testing with Python < 3.10!",
version.parse(platform.python_version()) >= version.parse("3.10") or platform.system() == "Darwin",
reason="Sonpy only testing with Python < 3.10 and not supported on macOS!",
)
class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
ExtractorClass = CedRecordingExtractor
@@ -293,6 +293,7 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
]


@pytest.mark.skipif(platform.system() == "Darwin", reason="Maxwell plugin not supported on macOS")
class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
ExtractorClass = MaxwellRecordingExtractor
downloads = ["maxwell"]
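A minimal sketch of the combined skip condition added above (Python >= 3.10 or macOS), assuming the `packaging` library is available; the decorated test is a placeholder, not part of the project's suite:

```python
# Minimal sketch of the combined skipif condition used in the diff above.
import platform

import pytest
from packaging import version

requires_sonpy = pytest.mark.skipif(
    version.parse(platform.python_version()) >= version.parse("3.10")
    or platform.system() == "Darwin",
    reason="Sonpy only tested with Python < 3.10 and not supported on macOS",
)

@requires_sonpy
def test_placeholder():
    assert True
```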
17 changes: 8 additions & 9 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -359,12 +359,12 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):

job_kwargs = fix_job_kwargs(job_kwargs)
p = self.params
we = self.sorting_analyzer
sorting = we.sorting
sorting_analyzer = self.sorting_analyzer
sorting = sorting_analyzer.sorting
assert (
we.has_recording()
), "To compute PCA projections for all spikes, the waveform extractor needs the recording"
recording = we.recording
sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording()
), "To compute PCA projections for all spikes, the sorting analyzer needs the recording"
recording = sorting_analyzer.recording

# assert sorting.get_num_segments() == 1
assert p["mode"] in ("by_channel_local", "by_channel_global")
@@ -374,8 +374,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):

sparsity = self.sorting_analyzer.sparsity
if sparsity is None:
sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids}
max_channels_per_template = we.get_num_channels()
num_channels = recording.get_num_channels()
sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids}
max_channels_per_template = num_channels
else:
sparse_channels_indices = sparsity.unit_id_to_channel_indices
max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()])
@@ -449,9 +450,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
return pca_models

def _fit_by_channel_global(self, progress_bar):
# we = self.sorting_analyzer
p = self.params
# unit_ids = we.unit_ids
unit_ids = self.sorting_analyzer.unit_ids

# there is one unique PCA accross channels
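A hedged sketch of the sparsity fallback used in `run_for_all_spikes`: with no sparsity every unit maps to all channel indices, otherwise the per-unit sparse indices are taken. The values below are toy stand-ins, not the real analyzer objects:

```python
# Hedged sketch of the sparsity fallback shown in the diff above.
import numpy as np

num_channels = 4
unit_ids = ["u0", "u1"]
sparsity = None  # stand-in for a spikeinterface ChannelSparsity, or None

if sparsity is None:
    sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in unit_ids}
    max_channels_per_template = num_channels
else:
    sparse_channels_indices = sparsity.unit_id_to_channel_indices
    max_channels_per_template = max(inds.size for inds in sparse_channels_indices.values())

print(max_channels_per_template)  # 4
```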
