Skip to content

Commit

Permalink
Merge pull request #3403 from h-mayorquin/add_get_channel_locations_t…
Browse files Browse the repository at this point in the history
…o_the_api

Add `get_channel_locations` to the base recording api
  • Loading branch information
alejoe91 authored Sep 23, 2024
2 parents 9d7832c + 895288c commit b9f50e3
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
24 changes: 24 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,30 @@ def _select_segments(self, segment_indices):

return SelectSegmentRecording(self, segment_indices=segment_indices)

def get_channel_locations(
self,
channel_ids: list | np.ndarray | tuple | None = None,
axes: "xy" | "yz" | "xz" | "xyz" = "xy",
) -> np.ndarray:
"""
Get the physical locations of specified channels.
Parameters
----------
channel_ids : array-like, optional
The IDs of the channels for which to retrieve locations. If None, retrieves locations
for all available channels. Default is None.
axes : "xy" | "yz" | "xz" | "xyz", default: "xy"
The spatial axes to return, specified as a string (e.g., "xy", "xyz"). Default is "xy".
Returns
-------
np.ndarray
A 2D or 3D array of shape (n_channels, n_dimensions) containing the locations of the channels.
The number of dimensions depends on the `axes` argument (e.g., 2 for "xy", 3 for "xyz").
"""
return super().get_channel_locations(channel_ids=channel_ids, axes=axes)

def is_binary_compatible(self) -> bool:
"""
Checks if the recording is "binary" compatible.
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def set_channel_locations(self, locations, channel_ids=None):
raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)")
self.set_property("location", locations, ids=channel_ids)

def get_channel_locations(self, channel_ids=None, axes: str = "xy"):
def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray:
if channel_ids is None:
channel_ids = self.get_channel_ids()
channel_indices = self.ids_to_indices(channel_ids)
Expand Down

0 comments on commit b9f50e3

Please sign in to comment.