From d9b169d1337cfd8ead75d3e8bf842045707c3a13 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 28 Sep 2024 16:58:30 +0200 Subject: [PATCH] Improve IBL recording extractors by PID --- .../extractors/iblextractors.py | 31 ++++++++++++------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 5dd549347d..317ea21cce 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -105,6 +105,8 @@ def get_stream_names(eid: str, cache_folder: Optional[Union[Path, str]] = None, An instance of the ONE API to use for data loading. If not provided, a default instance is created using the default parameters. If you need to use a specific instance, you can create it using the ONE API and pass it here. + stream_type : "ap" | "lf" | None, default: None + The stream type to load, required when pid is provided and stream_name is not. Returns ------- @@ -140,6 +142,7 @@ def __init__( remove_cached: bool = True, stream: bool = True, one: "one.api.OneAlyx" = None, + stream_type: str | None = None, ): try: from brainbox.io.one import SpikeSortingLoader @@ -154,20 +157,24 @@ def __init__( one = IblRecordingExtractor._get_default_one(cache_folder=cache_folder) if pid is not None: + assert stream_type is not None, "When providing a PID, you must also provide a stream type." eid, _ = one.pid2eid(pid) - - stream_names = IblRecordingExtractor.get_stream_names(eid=eid, cache_folder=cache_folder, one=one) - if len(stream_names) > 1: - assert ( - stream_name is not None - ), f"Multiple streams found for session. Please specify a stream name from {stream_names}." - assert stream_name in stream_names, ( - f"The `stream_name` '{stream_name}' is not available for this experiment {eid}! " - f"Please choose one of {stream_names}." - ) + pids, probes = one.eid2pid(eid) + pname = probes[pids.index(pid)] + stream_name = f"{pname}.{stream_type}" else: - stream_name = stream_names[0] - pname, stream_type = stream_name.split(".") + stream_names = IblRecordingExtractor.get_stream_names(eid=eid, cache_folder=cache_folder, one=one) + if len(stream_names) > 1: + assert ( + stream_name is not None + ), f"Multiple streams found for session. Please specify a stream name from {stream_names}." + assert stream_name in stream_names, ( + f"The `stream_name` '{stream_name}' is not available for this experiment {eid}! " + f"Please choose one of {stream_names}." + ) + else: + stream_name = stream_names[0] + pname, stream_type = stream_name.split(".") self.ssl = SpikeSortingLoader(one=one, eid=eid, pid=pid, pname=pname) if pid is None: