From 2c34c91ff1b6f6be6e6e695538d8ff12b454c5dc Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Thu, 12 Sep 2024 10:49:24 -0600
Subject: [PATCH 01/17] add channel locations to the base recording api

---
 src/spikeinterface/core/baserecording.py  | 24 +++++++++++++++++++
 .../core/baserecordingsnippets.py          |  2 +-
 2 files changed, 25 insertions(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py
index 082afd880b..c772a669ea 100644
--- a/src/spikeinterface/core/baserecording.py
+++ b/src/spikeinterface/core/baserecording.py
@@ -746,6 +746,30 @@ def _select_segments(self, segment_indices):
 
         return SelectSegmentRecording(self, segment_indices=segment_indices)
 
+    def get_channel_locations(
+        self,
+        channel_ids: list | np.ndarray | tuple | None = None,
+        axes: "xy" | "yz" | "xz" = "xy",
+    ) -> np.ndarray:
+        """
+        Get the physical locations of specified channels.
+
+        Parameters
+        ----------
+        channel_ids : array-like, optional
+            The IDs of the channels for which to retrieve locations. If None, retrieves locations
+            for all available channels. Default is None.
+        axes : str, optional
+            The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy".
+
+        Returns
+        -------
+        np.ndarray
+            A 2D or 3D array of shape (n_channels, n_dimensions) containing the locations of the channels.
+            The number of dimensions depends on the `axes` argument (e.g., 2 for "xy", 3 for "xyz").
+        """
+        return super().get_channel_locations(channel_ids=channel_ids, axes=axes)
+
     def is_binary_compatible(self) -> bool:
         """
         Checks if the recording is "binary" compatible.
diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py
index 428472bf93..3953c1f058 100644
--- a/src/spikeinterface/core/baserecordingsnippets.py
+++ b/src/spikeinterface/core/baserecordingsnippets.py
@@ -344,7 +344,7 @@ def set_channel_locations(self, locations, channel_ids=None):
             raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)")
         self.set_property("location", locations, ids=channel_ids)
 
-    def get_channel_locations(self, channel_ids=None, axes: str = "xy"):
+    def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray:
         if channel_ids is None:
             channel_ids = self.get_channel_ids()
         channel_indices = self.ids_to_indices(channel_ids)

From 895288c778ff59888e53704e51dc681bdf1d1929 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Thu, 12 Sep 2024 14:15:50 -0600
Subject: [PATCH 02/17] Update src/spikeinterface/core/baserecording.py

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 src/spikeinterface/core/baserecording.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py
index c772a669ea..1b783a8fe4 100644
--- a/src/spikeinterface/core/baserecording.py
+++ b/src/spikeinterface/core/baserecording.py
@@ -759,7 +759,7 @@ def get_channel_locations(
         channel_ids : array-like, optional
             The IDs of the channels for which to retrieve locations. If None, retrieves locations
             for all available channels. Default is None.
-        axes : str, optional
+        axes : "xy" | "yz" | "xz" | "xyz", default: "xy"
             The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy".
 
         Returns
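Functionally, the override in patch 01 only forwards to the inherited `BaseRecordingSnippets` implementation; it exists so that `BaseRecording` exposes its own typed signature and docstring. A minimal usage sketch, assuming the public `generate_recording` and `set_dummy_probe_from_locations` helpers behave as in current spikeinterface:

```python
import numpy as np
from spikeinterface.core import generate_recording

# small synthetic recording, with explicit channel locations attached via a dummy probe
recording = generate_recording(num_channels=4, durations=[1.0])
locations = np.array([[0.0, 0.0], [0.0, 20.0], [0.0, 40.0], [0.0, 60.0]])
recording.set_dummy_probe_from_locations(locations)

all_locs = recording.get_channel_locations()  # shape (4, 2) for the default "xy" axes
two_locs = recording.get_channel_locations(channel_ids=recording.channel_ids[:2])
print(all_locs.shape, two_locs.shape)
```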
From 9e21100bbbe74c53ce50457f0ca31403439483dd Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Thu, 12 Sep 2024 14:15:55 -0600
Subject: [PATCH 03/17] Update src/spikeinterface/core/baserecording.py

Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com>
---
 src/spikeinterface/core/baserecording.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py
index 1b783a8fe4..03001ae47e 100644
--- a/src/spikeinterface/core/baserecording.py
+++ b/src/spikeinterface/core/baserecording.py
@@ -749,7 +749,7 @@ def _select_segments(self, segment_indices):
     def get_channel_locations(
         self,
         channel_ids: list | np.ndarray | tuple | None = None,
-        axes: "xy" | "yz" | "xz" = "xy",
+        axes: "xy" | "yz" | "xz" | "xyz" = "xy",
     ) -> np.ndarray:
         """
         Get the physical locations of specified channels.

From af2dd1d1943155b152d1fd81ddf53ffbc7ae3047 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 16 Sep 2024 16:59:36 +0200
Subject: [PATCH 04/17] Add kachery_zone secret

---
 .github/workflows/all-tests.yml              | 1 +
 .github/workflows/full-test-with-codecov.yml | 1 +
 2 files changed, 2 insertions(+)

diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml
index e12cf6805d..bc663675a9 100644
--- a/.github/workflows/all-tests.yml
+++ b/.github/workflows/all-tests.yml
@@ -12,6 +12,7 @@ on:
 env:
   KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }}
   KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }}
+  KACHERY_ZONE: ${{ secrets.KACHERY_ZONE }}
 
 concurrency:  # Cancel previous workflows on the same pull request
   group: ${{ github.workflow }}-${{ github.ref }}
diff --git a/.github/workflows/full-test-with-codecov.yml b/.github/workflows/full-test-with-codecov.yml
index 6a222f5e25..407c614ebf 100644
--- a/.github/workflows/full-test-with-codecov.yml
+++ b/.github/workflows/full-test-with-codecov.yml
@@ -8,6 +8,7 @@ on:
 env:
   KACHERY_CLOUD_CLIENT_ID: ${{ secrets.KACHERY_CLOUD_CLIENT_ID }}
   KACHERY_CLOUD_PRIVATE_KEY: ${{ secrets.KACHERY_CLOUD_PRIVATE_KEY }}
+  KACHERY_ZONE: ${{ secrets.KACHERY_ZONE }}
 
 jobs:
   full-tests-with-codecov:

From 6aae2177195e8ed334b190aac290eba63e871d18 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 16 Sep 2024 17:05:13 +0200
Subject: [PATCH 05/17] Trigger widgets tests

---
 src/spikeinterface/widgets/tests/test_widgets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index debcd52085..13838fd21a 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -690,7 +690,7 @@ def test_plot_motion_info(self):
     # mytest.test_plot_multicomparison()
     # mytest.test_plot_sorting_summary()
     # mytest.test_plot_motion()
-    mytest.test_plot_motion_info()
+    # mytest.test_plot_motion_info()
     plt.show()
 
     # TestWidgets.tearDownClass()
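Patch 04 only wires a new repository secret into the job environment of the two workflows. The sketch below shows the kind of guard a test module can build on top of such variables; the variable names come from the workflow above, while the skip logic itself is illustrative rather than spikeinterface's actual code:

```python
import os

# names taken from the workflow `env:` block above; the guard is illustrative
REQUIRED_KACHERY_VARS = ("KACHERY_CLOUD_CLIENT_ID", "KACHERY_CLOUD_PRIVATE_KEY", "KACHERY_ZONE")
HAVE_KACHERY = all(os.getenv(name) for name in REQUIRED_KACHERY_VARS)

if not HAVE_KACHERY:
    # in a real test suite this would be a pytest.mark.skipif condition
    print("kachery credentials missing; cloud-backed widget tests should be skipped")
```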
From 9deae601f3bbc590b23181e18e21434d17ba268c Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 17 Sep 2024 20:21:22 -0600
Subject: [PATCH 06/17] add macos latest to test

---
 .github/workflows/all-tests.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml
index e12cf6805d..80bada7bb4 100644
--- a/.github/workflows/all-tests.yml
+++ b/.github/workflows/all-tests.yml
@@ -25,7 +25,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version: ["3.9", "3.12"]  # Lower and higher versions we support
-        os: [macos-13, windows-latest, ubuntu-latest]
+        os: [macos-latest, windows-latest, ubuntu-latest]
     steps:
       - uses: actions/checkout@v4
       - name: Setup Python ${{ matrix.python-version }}

From 1462f9251873c2a94f0c2fb0832b4746f4555fa7 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 17 Sep 2024 21:03:02 -0600
Subject: [PATCH 07/17] add condition to run everything if workflow files are
 changed

---
 .github/scripts/determine_testing_environment.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/.github/scripts/determine_testing_environment.py b/.github/scripts/determine_testing_environment.py
index 95ad0afc49..518591fb9c 100644
--- a/.github/scripts/determine_testing_environment.py
+++ b/.github/scripts/determine_testing_environment.py
@@ -31,6 +31,7 @@
 sortingcomponents_changed = False
 generation_changed = False
 stream_extractors_changed = False
+github_actions_changed = False
 
 
 for changed_file in changed_files_in_the_pull_request_paths:
@@ -78,9 +79,12 @@
             sorters_internal_changed = True
         else:
             sorters_changed = True
+    elif ".github" in changed_file.parts:
+        if "workflows" in changed_file.parts:
+            github_actions_changed = True
 
 
-run_everything = core_changed or pyproject_toml_changed or neobaseextractor_changed
+run_everything = core_changed or pyproject_toml_changed or neobaseextractor_changed or github_actions_changed
 run_generation_tests = run_everything or generation_changed
 run_extractor_tests = run_everything or extractors_changed or plexon2_changed
 run_preprocessing_tests = run_everything or preprocessing_changed
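The condition added in patch 07 relies on `pathlib.Path.parts` to detect workflow edits anywhere under `.github/workflows/`. A minimal standalone re-implementation of that check (illustrative; the real script keeps the result in module-level flags):

```python
from pathlib import Path

def touches_github_workflows(changed_file: str) -> bool:
    # mirrors the `.parts` membership test added in this patch
    parts = Path(changed_file).parts
    return ".github" in parts and "workflows" in parts

print(touches_github_workflows(".github/workflows/all-tests.yml"))  # True
print(touches_github_workflows("src/spikeinterface/core/baserecording.py"))  # False
```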
From 63310759d931d01767acb66a574c9a769eb31f1b Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 17 Sep 2024 21:22:13 -0600
Subject: [PATCH 08/17] add checks

---
 .github/scripts/determine_testing_environment.py          | 2 +-
 src/spikeinterface/extractors/tests/test_neoextractors.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/.github/scripts/determine_testing_environment.py b/.github/scripts/determine_testing_environment.py
index 518591fb9c..aa85aa2b91 100644
--- a/.github/scripts/determine_testing_environment.py
+++ b/.github/scripts/determine_testing_environment.py
@@ -100,7 +100,7 @@
 run_sorters_test = run_everything or sorters_changed
 run_internal_sorters_test = run_everything or run_sortingcomponents_tests or sorters_internal_changed
 
-run_streaming_extractors_test = stream_extractors_changed
+run_streaming_extractors_test = stream_extractors_changed or github_actions_changed
 
 install_plexon_dependencies = plexon2_changed
 
diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py
index acd7ebe8ad..ed588149ba 100644
--- a/src/spikeinterface/extractors/tests/test_neoextractors.py
+++ b/src/spikeinterface/extractors/tests/test_neoextractors.py
@@ -278,8 +278,8 @@ class Spike2RecordingTest(RecordingCommonTestSuite, unittest.TestCase):
 
 
 @pytest.mark.skipif(
-    version.parse(platform.python_version()) >= version.parse("3.10"),
-    reason="Sonpy only testing with Python < 3.10!",
+    version.parse(platform.python_version()) >= version.parse("3.10") or platform.system() == "Darwin",
+    reason="Sonpy only testing with Python < 3.10 and not supported on macOS!",
 )
 class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
     ExtractorClass = CedRecordingExtractor
@@ -293,6 +293,7 @@ class CedRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
     ]
 
 
+@pytest.mark.skipif(platform.system() == "Darwin", reason="Maxwell plugin not supported on macOS")
 class MaxwellRecordingTest(RecordingCommonTestSuite, unittest.TestCase):
     ExtractorClass = MaxwellRecordingExtractor
     downloads = ["maxwell"]

From c3e52520d61ef0ffb468aa7167c68e040681f9e3 Mon Sep 17 00:00:00 2001
From: Heberto Mayorquin
Date: Tue, 17 Sep 2024 21:26:19 -0600
Subject: [PATCH 09/17] add sampling frequency to blackrock to avoid warning

---
 src/spikeinterface/extractors/tests/test_neoextractors.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/extractors/tests/test_neoextractors.py b/src/spikeinterface/extractors/tests/test_neoextractors.py
index ed588149ba..3f73161218 100644
--- a/src/spikeinterface/extractors/tests/test_neoextractors.py
+++ b/src/spikeinterface/extractors/tests/test_neoextractors.py
@@ -234,7 +234,7 @@ class BlackrockSortingTest(SortingCommonTestSuite, unittest.TestCase):
     ExtractorClass = BlackrockSortingExtractor
     downloads = ["blackrock"]
     entities = [
-        "blackrock/FileSpec2.3001.nev",
+        dict(file_path=local_folder / "blackrock/FileSpec2.3001.nev", sampling_frequency=30_000.0),
         dict(file_path=local_folder / "blackrock/blackrock_2_1/l101210-001.nev", sampling_frequency=30_000.0),
     ]
 

From b63221ed1419dcdf08ee66c724a1d2815acce246 Mon Sep 17 00:00:00 2001
From: Yue Huang <806628409@qq.com>
Date: Thu, 19 Sep 2024 02:07:02 +0800
Subject: [PATCH 10/17] Update job_tools.py

---
 src/spikeinterface/core/job_tools.py | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index a5279247f5..55870cd688 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -136,11 +136,8 @@ def divide_segment_into_chunks(num_frames, chunk_size):
     else:
         n = num_frames // chunk_size
 
-        frame_starts = np.arange(n) * chunk_size
-        frame_stops = frame_starts + chunk_size
-
-        frame_starts = frame_starts.tolist()
-        frame_stops = frame_stops.tolist()
+        frame_starts = [i * chunk_size for i in range(n)]
+        frame_stops = [(i+1) * chunk_size for i in range(n)]
 
         if (num_frames % chunk_size) > 0:
             frame_starts.append(n * chunk_size)

From 0781a39082639b87c0239d709a06000f38fdd1ba Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 18 Sep 2024 18:16:37 +0000
Subject: [PATCH 11/17] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/spikeinterface/core/job_tools.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index 55870cd688..3cd7313b76 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -137,7 +137,7 @@ def divide_segment_into_chunks(num_frames, chunk_size):
         n = num_frames // chunk_size
 
         frame_starts = [i * chunk_size for i in range(n)]
-        frame_stops = [(i+1) * chunk_size for i in range(n)]
+        frame_stops = [(i + 1) * chunk_size for i in range(n)]
 
         if (num_frames % chunk_size) > 0:
             frame_starts.append(n * chunk_size)
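Patches 10–11 replace the NumPy array round-trip with plain list comprehensions. A self-contained sketch of the resulting chunking arithmetic; the early-exit branch for a missing `chunk_size` is an assumption, since it lies outside the hunk:

```python
def divide_segment_into_chunks(num_frames, chunk_size):
    # guard for "no chunking requested" is assumed; the hunk only shows the else branch
    if chunk_size is None:
        chunks = [(0, num_frames)]
    else:
        n = num_frames // chunk_size
        frame_starts = [i * chunk_size for i in range(n)]
        frame_stops = [(i + 1) * chunk_size for i in range(n)]
        # a shorter trailing chunk picks up the remainder
        if (num_frames % chunk_size) > 0:
            frame_starts.append(n * chunk_size)
            frame_stops.append(num_frames)
        chunks = list(zip(frame_starts, frame_stops))
    return chunks

print(divide_segment_into_chunks(10, 4))  # [(0, 4), (4, 8), (8, 10)]
```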
From 41f4c311398cb909207535e3fcf2a373c850dafd Mon Sep 17 00:00:00 2001
From: Yue Huang <806628409@qq.com>
Date: Thu, 19 Sep 2024 04:35:21 +0800
Subject: [PATCH 12/17] Update job_tools.py

---
 src/spikeinterface/core/job_tools.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/core/job_tools.py b/src/spikeinterface/core/job_tools.py
index 3cd7313b76..5240edcee7 100644
--- a/src/spikeinterface/core/job_tools.py
+++ b/src/spikeinterface/core/job_tools.py
@@ -137,7 +137,7 @@ def divide_segment_into_chunks(num_frames, chunk_size):
         n = num_frames // chunk_size
 
         frame_starts = [i * chunk_size for i in range(n)]
-        frame_stops = [(i + 1) * chunk_size for i in range(n)]
+        frame_stops = [frame_start + chunk_size for frame_start in frame_starts]
 
         if (num_frames % chunk_size) > 0:
             frame_starts.append(n * chunk_size)

From 7009487948adf02340026e0a2a15d0ac2b1dcf6b Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Thu, 19 Sep 2024 10:49:45 +0200
Subject: [PATCH 13/17] Update src/spikeinterface/widgets/tests/test_widgets.py

---
 src/spikeinterface/widgets/tests/test_widgets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/spikeinterface/widgets/tests/test_widgets.py b/src/spikeinterface/widgets/tests/test_widgets.py
index 4649667518..80f58f5ad9 100644
--- a/src/spikeinterface/widgets/tests/test_widgets.py
+++ b/src/spikeinterface/widgets/tests/test_widgets.py
@@ -690,7 +690,7 @@ def test_plot_motion_info(self):
     # mytest.test_plot_multicomparison()
     # mytest.test_plot_sorting_summary()
     # mytest.test_plot_motion()
-    # mytest.test_plot_motion_info()
+    mytest.test_plot_motion_info()
     plt.show()
 
     # TestWidgets.tearDownClass()

From dc3a026056d6b9b89117f9550500ddba099edb3d Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Fri, 20 Sep 2024 21:00:05 +0200
Subject: [PATCH 14/17] Set run_info to None for load_waveforms

---
 .../core/waveforms_extractor_backwards_compatibility.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py
index a50a56bf85..5c7584ecd8 100644
--- a/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py
+++ b/src/spikeinterface/core/waveforms_extractor_backwards_compatibility.py
@@ -536,6 +536,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
     ext = ComputeRandomSpikes(sorting_analyzer)
     ext.params = dict()
     ext.data = dict(random_spikes_indices=random_spikes_indices)
+    ext.run_info = None
     sorting_analyzer.extensions["random_spikes"] = ext
 
     ext = ComputeWaveforms(sorting_analyzer)
@@ -545,6 +546,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
         dtype=params["dtype"],
     )
     ext.data["waveforms"] = waveforms
+    ext.run_info = None
     sorting_analyzer.extensions["waveforms"] = ext
 
     # templates saved dense
@@ -559,6 +561,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
     ext.params = dict(ms_before=params["ms_before"], ms_after=params["ms_after"], operators=list(templates.keys()))
     for mode, arr in templates.items():
         ext.data[mode] = arr
+    ext.run_info = None
     sorting_analyzer.extensions["templates"] = ext
 
     for old_name, new_name in old_extension_to_new_class_map.items():
@@ -631,6 +634,7 @@ def _read_old_waveforms_extractor_binary(folder, sorting):
         ext.set_params(**updated_params, save=False)
         if ext.need_backward_compatibility_on_load:
             ext._handle_backward_compatibility_on_load()
+        ext.run_info = None
 
         sorting_analyzer.extensions[new_name] = ext
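Patch 14 guards consumers that read `ext.run_info` on extensions reconstructed from legacy `WaveformExtractor` folders, which never ran through the new compute pipeline. A toy illustration of the failure mode; the class is a stand-in, not the real `AnalyzerExtension`:

```python
class LegacyLoadedExtension:  # stand-in for an extension rebuilt from an old folder
    def __init__(self):
        self.params = {}
        self.data = {}

ext = LegacyLoadedExtension()
# without the patch the attribute would simply not exist:
# ext.run_info  ->  AttributeError
ext.run_info = None  # explicit "never run under the new pipeline" marker

# downstream code can now branch instead of crashing
run_time = ext.run_info["run_time"] if ext.run_info is not None else None
print(run_time)  # None
```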
From d9f53d04f99f78ecc4fbdab343e7d8e1161faf07 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 23 Sep 2024 09:40:52 +0200
Subject: [PATCH 15/17] Fix compute analyzer pipeline with tmp recording

---
 src/spikeinterface/core/sortinganalyzer.py    |  6 ++++--
 .../postprocessing/principal_component.py     | 17 ++++++++---------
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 177188f21d..0b4d959604 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -230,7 +230,7 @@ def __repr__(self) -> str:
         txt = f"{clsname}: {nchan} channels - {nunits} units - {nseg} segments - {self.format}"
         if self.is_sparse():
             txt += " - sparse"
-        if self.has_recording():
+        if self.has_recording() or self.has_temporary_recording():
             txt += " - has recording"
         ext_txt = f"Loaded {len(self.extensions)} extensions: " + ", ".join(self.extensions.keys())
         txt += "\n" + ext_txt
@@ -1355,7 +1355,9 @@ def compute_several_extensions(self, extensions, save=True, verbose=False, **job
 
         for extension_name, extension_params in extensions_with_pipeline.items():
             extension_class = get_extension_class(extension_name)
-            assert self.has_recording(), f"Extension {extension_name} need the recording"
+            assert (
+                self.has_recording() or self.has_temporary_recording()
+            ), f"Extension {extension_name} need the recording"
 
             for variable_name in extension_class.nodepipeline_variables:
                 result_routage.append((extension_name, variable_name))
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index f1f89403c7..1871c11b85 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -359,12 +359,12 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
 
         job_kwargs = fix_job_kwargs(job_kwargs)
         p = self.params
-        we = self.sorting_analyzer
-        sorting = we.sorting
+        sorting_analyzer = self.sorting_analyzer
+        sorting = sorting_analyzer.sorting
         assert (
-            we.has_recording()
-        ), "To compute PCA projections for all spikes, the waveform extractor needs the recording"
-        recording = we.recording
+            sorting_analyzer.has_recording() or sorting_analyzer.has_temporary_recording()
+        ), "To compute PCA projections for all spikes, the sorting analyzer needs the recording"
+        recording = sorting_analyzer.recording
 
         # assert sorting.get_num_segments() == 1
         assert p["mode"] in ("by_channel_local", "by_channel_global")
@@ -374,8 +374,9 @@ def run_for_all_spikes(self, file_path=None, verbose=False, **job_kwargs):
 
         sparsity = self.sorting_analyzer.sparsity
         if sparsity is None:
-            sparse_channels_indices = {unit_id: np.arange(we.get_num_channels()) for unit_id in we.unit_ids}
-            max_channels_per_template = we.get_num_channels()
+            num_channels = recording.get_num_channels()
+            sparse_channels_indices = {unit_id: np.arange(num_channels) for unit_id in sorting_analyzer.unit_ids}
+            max_channels_per_template = num_channels
         else:
             sparse_channels_indices = sparsity.unit_id_to_channel_indices
             max_channels_per_template = max([chan_inds.size for chan_inds in sparse_channels_indices.values()])
@@ -449,9 +450,7 @@ def _fit_by_channel_local(self, n_jobs, progress_bar):
         return pca_models
 
     def _fit_by_channel_global(self, progress_bar):
-        # we = self.sorting_analyzer
         p = self.params
-        # unit_ids = we.unit_ids
         unit_ids = self.sorting_analyzer.unit_ids
 
         # there is one unique PCA accross channels
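Patch 15 lets pipeline-based extensions run when the recording is only attached in memory rather than persisted with the analyzer. A sketch of the intended workflow; `set_temporary_recording`, `create_sorting_analyzer`, `load_sorting_analyzer`, and `generate_ground_truth_recording` are assumed from the public spikeinterface API:

```python
import tempfile
from pathlib import Path

from spikeinterface.core import (
    create_sorting_analyzer,
    generate_ground_truth_recording,
    load_sorting_analyzer,
)

recording, sorting = generate_ground_truth_recording(durations=[2.0], num_units=5, seed=0)

folder = Path(tempfile.mkdtemp()) / "analyzer"
create_sorting_analyzer(sorting, recording, format="binary_folder", folder=folder)

# if the analyzer comes back without a usable recording (e.g. the raw file moved),
# it can be attached temporarily instead of being persisted again
analyzer = load_sorting_analyzer(folder)
if not analyzer.has_recording():
    analyzer.set_temporary_recording(recording)

analyzer.compute(["random_spikes", "waveforms"])
print(analyzer)
```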
From 778b77343cefd2295396c2f5097a9859946fb4db Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Mon, 23 Sep 2024 09:51:54 +0200
Subject: [PATCH 16/17] Fix bug in saving zarr recordings

---
 src/spikeinterface/core/baserecording.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py
index e44ed9b948..d0b6ab0092 100644
--- a/src/spikeinterface/core/baserecording.py
+++ b/src/spikeinterface/core/baserecording.py
@@ -608,11 +608,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
             probegroup = self.get_probegroup()
             cached.set_probegroup(probegroup)
 
-        time_vectors = self._get_time_vectors()
-        if time_vectors is not None:
-            for segment_index, time_vector in enumerate(time_vectors):
-                if time_vector is not None:
-                    cached.set_times(time_vector, segment_index=segment_index)
+        for segment_index in range(self.get_num_segments()):
+            if self.has_time_vector(segment_index):
+                # the use of get_times is preferred since timestamps are converted to array
+                time_vector = self.get_times(segment_index=segment_index)
+                cached.set_times(time_vector, segment_index=segment_index)
 
         return cached
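Patch 16 routes saving through `get_times`, which (per the in-diff comment) converts stored timestamps to a NumPy array before they reach the zarr writer. A short sketch of the accessors involved, assuming `generate_recording` from `spikeinterface.core`:

```python
import numpy as np
from spikeinterface.core import generate_recording

recording = generate_recording(num_channels=2, durations=[1.0], sampling_frequency=10_000.0)

# attach an explicit time vector (e.g. hardware timestamps) to segment 0
times = 100.0 + np.arange(recording.get_num_samples(segment_index=0)) / 10_000.0
recording.set_times(times, segment_index=0)

assert recording.has_time_vector(segment_index=0)
# get_times always returns a numpy array, which is what _save now relies on
round_trip = recording.get_times(segment_index=0)
print(type(round_trip), round_trip[:3])
```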
"waveforms depends on "random_spikes". + depend on. For instance "templates" depends on "waveforms". "waveforms" depends on "random_spikes". - This function is making the reverse way : get all children that depend of a + This function is going the opposite way: it finds all children that depend on a particular extension. - This is recursive so this includes : children and so grand children and great grand children + The implementation is recursive so that the output includes children, grand children, great grand children, etc. - This function is usefull for deleting on recompute. - For instance recompute the "waveforms" need to delete "template" - This make sens if "ms_before" is change in "waveforms" because the template also depends - on this parameters. + This function is useful for deleting existing extensions on recompute. + For instance, recomputing the "waveforms" needs to delete the "templates", since the latter depends on the former. + For this particular example, if we change the "ms_before" parameter of the "waveforms", also the "templates" will + require recomputation as this parameter is inherited. """ names = [] children = _extension_children[extension_name]