From 2c34c91ff1b6f6be6e6e695538d8ff12b454c5dc Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 10:49:24 -0600 Subject: [PATCH 1/3] add channel recording to the base recording api --- src/spikeinterface/core/baserecording.py | 24 +++++++++++++++++++ .../core/baserecordingsnippets.py | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 082afd880b..c772a669ea 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -746,6 +746,30 @@ def _select_segments(self, segment_indices): return SelectSegmentRecording(self, segment_indices=segment_indices) + def get_channel_locations( + self, + channel_ids: list | np.ndarray | tuple | None = None, + axes: "xy" | "yz" | "xz" = "xy", + ) -> np.ndarray: + """ + Get the physical locations of specified channels. + + Parameters + ---------- + channel_ids : array-like, optional + The IDs of the channels for which to retrieve locations. If None, retrieves locations + for all available channels. Default is None. + axes : str, optional + The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy". + + Returns + ------- + np.ndarray + A 2D or 3D array of shape (n_channels, n_dimensions) containing the locations of the channels. + The number of dimensions depends on the `axes` argument (e.g., 2 for "xy", 3 for "xyz"). + """ + return super().get_channel_locations(channel_ids=channel_ids, axes=axes) + def is_binary_compatible(self) -> bool: """ Checks if the recording is "binary" compatible. diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 428472bf93..3953c1f058 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -344,7 +344,7 @@ def set_channel_locations(self, locations, channel_ids=None): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) - def get_channel_locations(self, channel_ids=None, axes: str = "xy"): + def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) From 9e21100bbbe74c53ce50457f0ca31403439483dd Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 14:15:50 -0600 Subject: [PATCH 2/3] Update src/spikeinterface/core/baserecording.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/baserecording.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index c772a669ea..1b783a8fe4 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -759,7 +759,7 @@ def get_channel_locations( channel_ids : array-like, optional The IDs of the channels for which to retrieve locations. If None, retrieves locations for all available channels. Default is None. - axes : str, optional + axes : "xy" | "yz" | "xz" | "xyz", default: "xy" The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy". Returns From 895288c778ff59888e53704e51dc681bdf1d1929 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 12 Sep 2024 14:15:55 -0600 Subject: [PATCH 3/3] Update src/spikeinterface/core/baserecording.py Co-authored-by: Zach McKenzie <92116279+zm711@users.noreply.github.com> --- src/spikeinterface/core/baserecording.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 1b783a8fe4..03001ae47e 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -749,7 +749,7 @@ def _select_segments(self, segment_indices): def get_channel_locations( self, channel_ids: list | np.ndarray | tuple | None = None, - axes: "xy" | "yz" | "xz" = "xy", + axes: "xy" | "yz" | "xz" | "xyz" = "xy", ) -> np.ndarray: """ Get the physical locations of specified channels.