Skip to content

Commit

Permalink
Merge pull request NeuralEnsemble#1541 from h-mayorquin/fixate_plexon…
Browse files Browse the repository at this point in the history
…2_streams

Fix-ate plexon 2 streams
  • Loading branch information
zm711 authored Sep 5, 2024
2 parents 0aa596e + 058bce4 commit 60b26d4
Showing 1 changed file with 57 additions and 32 deletions.
89 changes: 57 additions & 32 deletions neo/rawio/plexon2rawio/plexon2rawio.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import pathlib
import warnings
import platform
import re

from collections import namedtuple
from urllib.request import urlopen
from datetime import datetime

Expand All @@ -53,6 +53,10 @@ class Plexon2RawIO(BaseRawIO):
pl2_dll_file_path: str | Path | None, default: None
The path to the necessary dll for loading pl2 files
If None will find correct dll for architecture and if it does not exist will download it
reading_attempts: int, default: 15
Number of attempts to read the file before raising an error
This opening process is somewhat unreliable and might fail occasionally. Adjust this higher
if you encounter problems in opening the file.
Notes
-----
Expand Down Expand Up @@ -88,7 +92,7 @@ class Plexon2RawIO(BaseRawIO):
extensions = ["pl2"]
rawmode = "one-file"

def __init__(self, filename, pl2_dll_file_path=None):
def __init__(self, filename, pl2_dll_file_path=None, reading_attempts=15):

# signals, event and spiking data will be cached
# cached signal data can be cleared using `clear_analogsignal_cache()()`
Expand Down Expand Up @@ -128,7 +132,6 @@ def __init__(self, filename, pl2_dll_file_path=None):

self.pl2reader = PyPL2FileReader(pl2_dll_file_path=pl2_dll_file_path)

reading_attempts = 10
for attempt in range(reading_attempts):
self.pl2reader.pl2_open_file(self.filename)

Expand All @@ -152,46 +155,72 @@ def _parse_header(self):
# Scanning sources and populating signal channels at the same time. Sources have to have
# same sampling rate and number of samples to belong to one stream.
signal_channels = []
source_characteristics = {}
Source = namedtuple("Source", "id name sampling_rate n_samples")
for c in range(self.pl2reader.pl2_file_info.m_TotalNumberOfAnalogChannels):
achannel_info = self.pl2reader.pl2_get_analog_channel_info(c)
channel_num_samples = []

# We will build the stream ids based on the channel prefixes
# The channel prefixes are the first characters of the channel names which have the following format:
# WB{number}, FPX{number}, SPKCX{number}, AI{number}, etc
# We will extract the prefix and use it as stream id
regex_prefix_pattern = r"^\D+" # Match any non-digit character at the beginning of the string

for channel_index in range(self.pl2reader.pl2_file_info.m_TotalNumberOfAnalogChannels):
achannel_info = self.pl2reader.pl2_get_analog_channel_info(channel_index)
# only consider active channels
if not (achannel_info.m_ChannelEnabled and achannel_info.m_ChannelRecordingEnabled):
continue

# assign to matching stream or create new stream based on signal characteristics
rate = achannel_info.m_SamplesPerSecond
n_samples = achannel_info.m_NumberOfValues
source_id = str(achannel_info.m_Source)

channel_source = Source(source_id, f"stream@{rate}Hz", rate, n_samples)
existing_source = source_characteristics.setdefault(source_id, channel_source)

# ensure that stream of this channel and existing stream have same properties
if channel_source != existing_source:
raise ValueError(
f"The channel source {channel_source} must be the same as the existing source {existing_source}"
)
num_samples = achannel_info.m_NumberOfValues
channel_num_samples.append(num_samples)

ch_name = achannel_info.m_Name.decode()
chan_id = f"source{achannel_info.m_Source}.{achannel_info.m_Channel}"
dtype = "int16"
units = achannel_info.m_Units.decode()
gain = achannel_info.m_CoeffToConvertToUnits
offset = 0.0 # PL2 files don't contain information on signal offset
stream_id = source_id

channel_prefix = re.match(regex_prefix_pattern, ch_name).group(0)
stream_id = channel_prefix
signal_channels.append((ch_name, chan_id, rate, dtype, units, gain, offset, stream_id))

signal_channels = np.array(signal_channels, dtype=_signal_channel_dtype)
self.signal_stream_characteristics = source_characteristics

# create signal streams from source information
channel_num_samples = np.array(channel_num_samples)

# We are using channel prefixes as stream_ids
# The meaning of the channel prefixes was provided by a Plexon Engineer, see here:
# https://github.com/NeuralEnsemble/python-neo/pull/1495#issuecomment-2184256894
stream_id_to_stream_name = {
"WB": "WB-Wideband",
"FP": "FPl-Low Pass Filtered",
"SP": "SPKC-High Pass Filtered",
"AI": "AI-Auxiliary Input",
}

unique_stream_ids = np.unique(signal_channels["stream_id"])
signal_streams = []
for stream_idx, source in source_characteristics.items():
signal_streams.append((source.name, str(source.id)))
for stream_id in unique_stream_ids:
# We are using the channel prefixes as ids
# The users of plexon can modify the prefix of the channel names (e.g. `my_prefix` instead of `WB`).
# In that case we use the channel prefix both as stream id and name
stream_name = stream_id_to_stream_name.get(stream_id, stream_id)
signal_streams.append((stream_name, stream_id))

signal_streams = np.array(signal_streams, dtype=_signal_stream_dtype)

self.stream_id_samples = {}
self.stream_index_to_stream_id = {}
for stream_index, stream_id in enumerate(signal_streams["id"]):
# Keep a mapping from stream_index to stream_id
self.stream_index_to_stream_id[stream_index] = stream_id

# We extract the number of samples for each stream
mask = signal_channels["stream_id"] == stream_id
signal_num_samples = np.unique(channel_num_samples[mask])
assert signal_num_samples.size == 1, "All channels in a stream must have the same number of samples"
self.stream_id_samples[stream_id] = signal_num_samples[0]

# pre-loading spike channel_data for later usage
self._spike_channel_cache = {}
spike_channels = []
Expand Down Expand Up @@ -354,16 +383,12 @@ def _segment_t_stop(self, block_index, seg_index):
end_time = (
self.pl2reader.pl2_file_info.m_StartRecordingTime + self.pl2reader.pl2_file_info.m_DurationOfRecording
)
return end_time / self.pl2reader.pl2_file_info.m_TimestampFrequency
return float(end_time / self.pl2reader.pl2_file_info.m_TimestampFrequency)

def _get_signal_size(self, block_index, seg_index, stream_index):
# this must return an integer value (the number of samples)

stream_id = self.header["signal_streams"][stream_index]["id"]
stream_characteristic = list(self.signal_stream_characteristics.values())[stream_index]
if stream_id != stream_characteristic.id:
raise ValueError(f"The `stream_id` must be {stream_characteristic.id}")
return int(stream_characteristic.n_samples) # Avoids returning a numpy.int64 scalar
stream_id = self.stream_index_to_stream_id[stream_index]
num_samples = int(self.stream_id_samples[stream_id])
return num_samples

def _get_signal_t_start(self, block_index, seg_index, stream_index):
# This returns the t_start of signals as a float value in seconds
Expand Down

0 comments on commit 60b26d4

Please sign in to comment.