diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 082afd880b..03001ae47e 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -746,6 +746,30 @@ def _select_segments(self, segment_indices): return SelectSegmentRecording(self, segment_indices=segment_indices) + def get_channel_locations( + self, + channel_ids: list | np.ndarray | tuple | None = None, + axes: "xy" | "yz" | "xz" | "xyz" = "xy", + ) -> np.ndarray: + """ + Get the physical locations of specified channels. + + Parameters + ---------- + channel_ids : array-like, optional + The IDs of the channels for which to retrieve locations. If None, retrieves locations + for all available channels. Default is None. + axes : "xy" | "yz" | "xz" | "xyz", default: "xy" + The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy". + + Returns + ------- + np.ndarray + A 2D or 3D array of shape (n_channels, n_dimensions) containing the locations of the channels. + The number of dimensions depends on the `axes` argument (e.g., 2 for "xy", 3 for "xyz"). + """ + return super().get_channel_locations(channel_ids=channel_ids, axes=axes) + def is_binary_compatible(self) -> bool: """ Checks if the recording is "binary" compatible. diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 428472bf93..3953c1f058 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -344,7 +344,7 @@ def set_channel_locations(self, locations, channel_ids=None): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) - def get_channel_locations(self, channel_ids=None, axes: str = "xy"): + def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids)