Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LFP subsampling #2

Merged
merged 10 commits into from
Jul 9, 2024
Merged
167 changes: 134 additions & 33 deletions code/run_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
from pathlib import Path
import numpy as np
import json

import spikeinterface as si
import spikeinterface.extractors as se
Expand Down Expand Up @@ -38,10 +39,7 @@
lfp_sampling_rate = 2500

# default compressors
default_electrical_series_compressors = dict(
hdf5="gzip",
zarr=Blosc(cname="zstd", clevel=9, shuffle=Blosc.BITSHUFFLE)
)
default_electrical_series_compressors = dict(hdf5="gzip", zarr=Blosc(cname="zstd", clevel=9, shuffle=Blosc.BITSHUFFLE))

# default event line from open ephys
data_folder = Path("../data/")
Expand All @@ -56,31 +54,61 @@
# positional arguments
stub_group = parser.add_mutually_exclusive_group()
stub_help = "Write a stub version for testing"
stub_group.add_argument('--stub', action='store_true', help=stub_help)
stub_group.add_argument('static_stub', nargs='?', default="false", help=stub_help)
stub_group.add_argument("--stub", action="store_true", help=stub_help)
stub_group.add_argument("static_stub", nargs="?", default="false", help=stub_help)

stub_seconds_group = parser.add_mutually_exclusive_group()
stub_seconds_help = "Duration of stub recording"
stub_seconds_group.add_argument('--stub-seconds', default=10, help=stub_seconds_help)
stub_seconds_group.add_argument('static_stub_seconds', nargs='?', default="10", help=stub_help)
stub_seconds_group.add_argument("--stub-seconds", default=10, help=stub_seconds_help)
stub_seconds_group.add_argument("static_stub_seconds", nargs="?", default="10", help=stub_help)

write_lfp_group = parser.add_mutually_exclusive_group()
write_lfp_help = "Whether to write LFP electrical series"
write_lfp_group.add_argument('--skip-lfp', action='store_true', help=write_lfp_help)
write_lfp_group.add_argument('static_write_lfp', nargs='?', default="true", help=write_lfp_help)
write_lfp_group.add_argument("--skip-lfp", action="store_true", help=write_lfp_help)
write_lfp_group.add_argument("static_write_lfp", nargs="?", default="true", help=write_lfp_help)

write_raw_group = parser.add_mutually_exclusive_group()
write_raw_help = "Whether to write RAW electrical series"
write_raw_group.add_argument('--skip-raw', action='store_true', help=write_raw_help)
write_raw_group.add_argument('static_write_raw', nargs='?', default="true", help=write_raw_help)
write_raw_group.add_argument("--skip-raw", action="store_true", help=write_raw_help)
write_raw_group.add_argument("static_write_raw", nargs="?", default="true", help=write_raw_help)

write_nidq_group = parser.add_mutually_exclusive_group()
write_nidq_help = "Whether to write NIDQ stream"
write_nidq_group.add_argument('--write-nidq', action='store_true', help=write_nidq_help)
write_nidq_group.add_argument('static_write_nidq', nargs='?', default="false", help=write_nidq_help)
write_nidq_group.add_argument("--write-nidq", action="store_true", help=write_nidq_help)
write_nidq_group.add_argument("static_write_nidq", nargs="?", default="false", help=write_nidq_help)

if __name__ == "__main__":
lfp_temporal_subsampling_group = parser.add_mutually_exclusive_group()
lfp_temporal_subsampling_help = (
"Ratio of input samples to output samples in time. Use 0 or 1 to keep all samples. Default is 2."
)
lfp_temporal_subsampling_group.add_argument("--lfp_temporal_factor", default=2, help=lfp_temporal_subsampling_help)
lfp_temporal_subsampling_group.add_argument("static_lfp_temporal_factor", nargs="?", help=lfp_temporal_subsampling_help)

lfp_spatial_subsampling_group = parser.add_mutually_exclusive_group()
lfp_spatial_subsampling_help = (
"Controls number of channels to skip in spatial subsampling. Use 0 or 1 to keep all channels. Default is 4."
)
lfp_spatial_subsampling_group.add_argument("--lfp_spatial_factor", default=4, help=lfp_spatial_subsampling_help)
lfp_spatial_subsampling_group.add_argument("static_lfp_spatial_factor", nargs="?", help=lfp_spatial_subsampling_help)

lfp_highpass_filter_group = parser.add_mutually_exclusive_group()
lfp_highpass_filter_help = (
"Cutoff frequency for highpass filter to apply to the LFP recorsings. Default is 0.1 Hz. Use 0 to skip filtering."
)
lfp_highpass_filter_group.add_argument("--lfp_highpass_freq_min", default=0.1, help=lfp_highpass_filter_help)
lfp_highpass_filter_group.add_argument("static_lfp_highpass_freq_min", nargs="?", help=lfp_highpass_filter_help)

# common median referencing for probes in agar
lfp_surface_channel_agar_group = parser.add_mutually_exclusive_group()
lfp_surface_channel_help = "Index of surface channel (e.g. index 0 corresponds to channel 1) of probe for common median referencing for probes in agar. Pass in as JSON string where key is probe and value is surface channel (e.g. \"{'ProbeA': 350, 'ProbeB': 360}\")"
lfp_surface_channel_agar_group.add_argument(
"--surface_channel_agar_probes_indices", help=lfp_surface_channel_help, default="", type=str
)
lfp_surface_channel_agar_group.add_argument(
"static_surface_channel_agar_probes_indices", help=lfp_surface_channel_help, nargs="?", type=str
)

if __name__ == "__main__":
args = parser.parse_args()

stub = args.stub or args.static_stub
Expand All @@ -104,12 +132,37 @@
else:
WRITE_NIDQ = True if args.static_write_nidq == "true" else False

print(
f"Stub test: {STUB_TEST} - Stub seconds: {STUB_SECONDS} - Write lfp: {WRITE_LFP} - Write raw: {WRITE_RAW} - Write NIDQ: {WRITE_NIDQ}"
TEMPORAL_SUBSAMPLING_FACTOR = args.static_lfp_temporal_factor or args.lfp_temporal_factor
TEMPORAL_SUBSAMPLING_FACTOR = int(TEMPORAL_SUBSAMPLING_FACTOR)
SPATIAL_CHANNEL_SUBSAMPLING_FACTOR = args.static_lfp_spatial_factor or args.lfp_spatial_factor
SPATIAL_CHANNEL_SUBSAMPLING_FACTOR = int(SPATIAL_CHANNEL_SUBSAMPLING_FACTOR)
HIGHPASS_FILTER_FREQ_MIN = args.static_lfp_highpass_freq_min or args.lfp_highpass_freq_min
HIGHPASS_FILTER_FREQ_MIN = float(HIGHPASS_FILTER_FREQ_MIN)
SURFACE_CHANNEL_AGAR_PROBES_INDICES = (
args.static_surface_channel_agar_probes_indices or args.surface_channel_agar_probes_indices
)
if SURFACE_CHANNEL_AGAR_PROBES_INDICES != "":
SURFACE_CHANNEL_AGAR_PROBES_INDICES = json.loads(SURFACE_CHANNEL_AGAR_PROBES_INDICES)
else:
SURFACE_CHANNEL_AGAR_PROBES_INDICES = None

print(f"Running NWB conversion with the following parameters:")
print(f"Stub test: {STUB_TEST}")
print(f"Stub seconds: {STUB_SECONDS}")
print(f"Write LFP: {WRITE_LFP}")
print(f"Write RAW: {WRITE_RAW}")
print(f"Write NIDQ: {WRITE_NIDQ}")
print(f"Temporal subsampling factor: {TEMPORAL_SUBSAMPLING_FACTOR}")
print(f"Spatial subsampling factor: {SPATIAL_CHANNEL_SUBSAMPLING_FACTOR}")
print(f"Highpass filter frequency: {HIGHPASS_FILTER_FREQ_MIN}")
print(f"Surface channel indices for agar probes: {SURFACE_CHANNEL_AGAR_PROBES_INDICES}")

# find ecephys session
sessions = [p.stem for p in data_folder.iterdir() if ("ecephys" in p.stem or "behavior" in p.stem) and "sorted" not in p.stem and "nwb" not in p.name]
sessions = [
p.stem
for p in data_folder.iterdir()
if ("ecephys" in p.stem or "behavior" in p.stem) and "sorted" not in p.stem and "nwb" not in p.name
]
assert len(sessions) == 1, "Attach one session (raw data) at a time"
session = sessions[0]
ecephys_raw_folder = data_folder / session
Expand All @@ -124,7 +177,7 @@
NWB_BACKEND = "zarr"
NWB_SUFFIX = ".nwb.zarr"
io_class = NWBZarrIO
else:
else:
NWB_BACKEND = "hdf5"
NWB_SUFFIX = ".nwb"
io_class = NWBHDF5IO
Expand All @@ -138,8 +191,12 @@
compressed_folder = data_folder / session / "ecephys_compressed"
else:
assert (data_folder / session / "ecephys").is_dir()
oe_folder = data_folder / session / "ecephys"
compressed_folder = None
if (data_folder / session / "ecephys" / "ecephys_compressed").is_dir():
oe_folder = data_folder / session / "ecephys" / "ecephys_clipped"
compressed_folder = data_folder / session / "ecephys" / "ecephys_compressed"
else:
oe_folder = data_folder / session / "ecephys"
compressed_folder = None

# Read Open Ephys Folder structure with NEO
neo_io = OpenEphysBinaryRawIO(oe_folder)
Expand All @@ -151,7 +208,8 @@
experiment_ids = [eid for eid in neo_io.folder_structure[record_nodes[0]]["experiments"].keys()]
experiment_names = [e["name"] for eid, e in neo_io.folder_structure[record_nodes[0]]["experiments"].items()]
recording_names = [
r["name"] for rid, r in neo_io.folder_structure[record_nodes[0]]["experiments"][experiment_ids[0]]["recordings"].items()
r["name"]
for rid, r in neo_io.folder_structure[record_nodes[0]]["experiments"][experiment_ids[0]]["recordings"].items()
]

streams_to_process = []
Expand Down Expand Up @@ -188,20 +246,24 @@
# write 1 new nwb file per segment
with io_class(str(nwbfile_input_path), "r") as read_io:
nwbfile = read_io.read()

for stream_name in streams_to_process:
record_node, oe_stream_name = stream_name.split("#")
recording_folder_name = f"{experiment_name}_{stream_name}_{recording_name}"
settings_file = neo_io.folder_structure[record_node]["experiments"][experiment_ids[block_index]]["settings_file"]
settings_file = neo_io.folder_structure[record_node]["experiments"][experiment_ids[block_index]][
"settings_file"
]

# Add devices
added_devices, target_locations = get_devices_from_metadata(session_folder, segment_index=segment_index)
added_devices, target_locations = get_devices_from_metadata(
session_folder, segment_index=segment_index
)

# if devices not found in metadata, instantiate using probeinterface
if added_devices:
for device_name, device in added_devices.items():
if device_name not in nwbfile.devices:
nwbfile.add_device(device)
nwbfile.add_device(device)
for device_name, targeted_location in target_locations.items():
probe_no_spaces = device_name.replace(" ", "")
if probe_no_spaces in oe_stream_name:
Expand Down Expand Up @@ -229,7 +291,9 @@
stream_name_zarr = f"{experiment_name}_{stream_name}"
recording_multi_segment = si.read_zarr(compressed_folder / f"{stream_name_zarr}.zarr")
else:
recording_multi_segment = se.read_openephys(oe_folder, stream_name=stream_name, block_index=block_index)
recording_multi_segment = se.read_openephys(
oe_folder, stream_name=stream_name, block_index=block_index
)

recording = si.split_recording(recording_multi_segment)[segment_index]

Expand Down Expand Up @@ -257,8 +321,7 @@
}
electrode_metadata["Ecephys"].update(electrical_series_metadata)
add_electrical_series_kwargs = dict(
es_key=f"ElectricalSeries{probe_device_name}",
write_as="raw"
es_key=f"ElectricalSeries{probe_device_name}", write_as="raw"
)
# Add channel properties (group_name property to associate electrodes with group)
recording.set_channel_groups([probe_device_name] * recording.get_num_channels())
Expand Down Expand Up @@ -293,16 +356,22 @@

if "AP" not in stream_name:
# Wide-band NP recording: filter and resample LFP
print(f"\tAdding LFP data for stream {stream_name} from wide-band signal - segment {segment_index}")
print(
f"\tAdding LFP data for stream {stream_name} from wide-band signal - segment {segment_index}"
)
recording_lfp = spre.bandpass_filter(recording, **lfp_filter_kwargs)
recording_lfp = spre.resample(recording_lfp, lfp_sampling_rate)
recording_lfp = spre.scale(recording_lfp, dtype="int16")

# there is a bug with sample mismatches in the last chunk if num_samples is not divisible by chunk_size
# the workaround is to discard the last samples to make it "even"
if recording.get_num_segments() == 1:
recording_lfp = recording_lfp.frame_slice(start_frame=0,
end_frame=int(recording_lfp.get_num_samples() // lfp_sampling_rate * lfp_sampling_rate))
recording_lfp = recording_lfp.frame_slice(
start_frame=0,
end_frame=int(
recording_lfp.get_num_samples() // lfp_sampling_rate * lfp_sampling_rate
),
)
lfp_period = 1.0 / lfp_sampling_rate
for segment_index in range(recording.get_num_segments()):
ts_lfp = (
Expand All @@ -325,8 +394,40 @@
recording_lfp = se.read_openephys(
oe_folder, stream_name=lfp_stream_name, block_index=block_index
)

channel_ids = recording_lfp.get_channel_ids()

# re-reference only for agar - subtract median of channels out of brain using surface channel index arg
# similar processing to allensdk
if SURFACE_CHANNEL_AGAR_PROBES_INDICES is not None:
if probe.name in SURFACE_CHANNEL_AGAR_PROBES_INDICES:
surface_channel_index = SURFACE_CHANNEL_AGAR_PROBES_INDICES[probe.name]
# get indices of channels out of brain including surface channel
reference_channel_indices = np.arange(surface_channel_index, len(channel_ids))
reference_channel_ids = channel_ids[reference_channel_indices]
# common median reference to channels out of brain
recording_lfp = spre.common_reference(
recording_lfp,
reference="global",
ref_channel_ids=reference_channel_ids,
)

# spatial subsampling from allensdk - keep every nth channel
if SPATIAL_CHANNEL_SUBSAMPLING_FACTOR > 1:
channel_ids_to_keep = channel_ids[0 : len(channel_ids) : SPATIAL_CHANNEL_SUBSAMPLING_FACTOR]
recording_lfp = recording_lfp.channel_slice(channel_ids_to_keep)

# time subsampling/decimate
if TEMPORAL_SUBSAMPLING_FACTOR > 1:
recording_lfp = spre.decimate(recording_lfp, TEMPORAL_SUBSAMPLING_FACTOR)

# high pass filter from allensdk
if HIGHPASS_FILTER_FREQ_MIN > 0:
recording_lfp = spre.highpass_filter(recording_lfp, freq_min=HIGHPASS_FILTER_FREQ_MIN)

# Assign to the correct channel group
recording_lfp.set_channel_groups([probe_device_name] * recording_lfp.get_num_channels())

if STUB_TEST:
end_frame = int(STUB_SECONDS * recording_lfp.sampling_frequency)
recording_lfp = recording_lfp.frame_slice(start_frame=0, end_frame=end_frame)
Expand All @@ -347,7 +448,7 @@
print(f"Added {len(streams_to_process)} streams")

print(f"Configuring {NWB_BACKEND} backend")
backend_configuration = get_default_backend_configuration(nwbfile=nwbfile, backend=NWB_BACKEND)
backend_configuration = get_default_backend_configuration(nwbfile=nwbfile, backend=NWB_BACKEND)
es_compressor = default_electrical_series_compressors[NWB_BACKEND]

for key in backend_configuration.dataset_configurations.keys():
Expand Down
4 changes: 2 additions & 2 deletions environment/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ COPY git-askpass /

RUN pip install -U --no-cache-dir \
hdmf-zarr==0.6.0 \
neuroconv==0.4.8 \
neuroconv==0.4.9 \
nwbwidgets==0.11.3 \
pynwb==2.6.0 \
spikeinterface[full,widgets]==0.100.4 \
spikeinterface[full,widgets]==0.100.8 \
wavpack-numcodecs==0.1.5

COPY postInstall /
Expand Down
8 changes: 4 additions & 4 deletions environment/postInstall
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
set -e

# Install from source until 0.4.9 release
pip uninstall -y neuroconv
git clone https://github.com/catalystneuro/neuroconv.git
cd neuroconv
git checkout 6b7b654adf84ff7d68bca6ddc0b9f9d2f007224d
pip uninstall -y spikeinterface
git clone https://github.com/SpikeInterface/spikeinterface.git
cd spikeinterface
git checkout 1a101590ab32559e2af06edd67f2199df0b17112
pip install --ignore-installed --no-cache-dir .
cd ..