Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LFP subsampling #2

Merged
merged 10 commits into from
Jul 9, 2024
169 changes: 136 additions & 33 deletions code/run_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
from pathlib import Path
import numpy as np
import json

import spikeinterface as si
import spikeinterface.extractors as se
Expand Down Expand Up @@ -38,10 +39,7 @@
lfp_sampling_rate = 2500

# default compressors
default_electrical_series_compressors = dict(
hdf5="gzip",
zarr=Blosc(cname="zstd", clevel=9, shuffle=Blosc.BITSHUFFLE)
)
default_electrical_series_compressors = dict(hdf5="gzip", zarr=Blosc(cname="zstd", clevel=9, shuffle=Blosc.BITSHUFFLE))

# default event line from open ephys
data_folder = Path("../data/")
Expand All @@ -56,31 +54,61 @@
# positional arguments
stub_group = parser.add_mutually_exclusive_group()
stub_help = "Write a stub version for testing"
stub_group.add_argument('--stub', action='store_true', help=stub_help)
stub_group.add_argument('static_stub', nargs='?', default="false", help=stub_help)
stub_group.add_argument("--stub", action="store_true", help=stub_help)
stub_group.add_argument("static_stub", nargs="?", default="false", help=stub_help)

stub_seconds_group = parser.add_mutually_exclusive_group()
stub_seconds_help = "Duration of stub recording"
stub_seconds_group.add_argument('--stub-seconds', default=10, help=stub_seconds_help)
stub_seconds_group.add_argument('static_stub_seconds', nargs='?', default="10", help=stub_help)
stub_seconds_group.add_argument("--stub-seconds", default=10, help=stub_seconds_help)
stub_seconds_group.add_argument("static_stub_seconds", nargs="?", default="10", help=stub_help)

write_lfp_group = parser.add_mutually_exclusive_group()
write_lfp_help = "Whether to write LFP electrical series"
write_lfp_group.add_argument('--skip-lfp', action='store_true', help=write_lfp_help)
write_lfp_group.add_argument('static_write_lfp', nargs='?', default="true", help=write_lfp_help)
write_lfp_group.add_argument("--skip-lfp", action="store_true", help=write_lfp_help)
write_lfp_group.add_argument("static_write_lfp", nargs="?", default="true", help=write_lfp_help)

write_raw_group = parser.add_mutually_exclusive_group()
write_raw_help = "Whether to write RAW electrical series"
write_raw_group.add_argument('--skip-raw', action='store_true', help=write_raw_help)
write_raw_group.add_argument('static_write_raw', nargs='?', default="true", help=write_raw_help)
write_raw_group.add_argument("--skip-raw", action="store_true", help=write_raw_help)
write_raw_group.add_argument("static_write_raw", nargs="?", default="true", help=write_raw_help)

write_nidq_group = parser.add_mutually_exclusive_group()
write_nidq_help = "Whether to write NIDQ stream"
write_nidq_group.add_argument('--write-nidq', action='store_true', help=write_nidq_help)
write_nidq_group.add_argument('static_write_nidq', nargs='?', default="false", help=write_nidq_help)
write_nidq_group.add_argument("--write-nidq", action="store_true", help=write_nidq_help)
write_nidq_group.add_argument("static_write_nidq", nargs="?", default="false", help=write_nidq_help)

if __name__ == "__main__":
lfp_temporal_subsampling_group = parser.add_mutually_exclusive_group()
lfp_temporal_subsampling_help = (
"Ratio of input samples to output samples in time. Use 0 or 1 to keep all samples. Default is 2."
)
lfp_temporal_subsampling_group.add_argument("--lfp_temporal_factor", default=2, help=lfp_temporal_subsampling_help)
lfp_temporal_subsampling_group.add_argument("static_lfp_temporal_factor", nargs="?", help=lfp_temporal_subsampling_help)

lfp_spatial_subsampling_group = parser.add_mutually_exclusive_group()
lfp_spatial_subsampling_help = (
"Controls number of channels to skip in spatial subsampling. Use 0 or 1 to keep all channels. Default is 4."
)
lfp_spatial_subsampling_group.add_argument("--lfp_spatial_factor", default=4, help=lfp_spatial_subsampling_help)
lfp_spatial_subsampling_group.add_argument("static_lfp_spatial_factor", nargs="?", help=lfp_spatial_subsampling_help)

lfp_highpass_filter_group = parser.add_mutually_exclusive_group()
lfp_highpass_filter_help = (
"Cutoff frequency for highpass filter to apply to the LFP recorsings. Default is 0.1 Hz. Use 0 to skip filtering."
)
lfp_highpass_filter_group.add_argument("--lfp_highpass_freq_min", default=0.1, help=lfp_highpass_filter_help)
lfp_highpass_filter_group.add_argument("static_lfp_highpass_freq_min", nargs="?", help=lfp_highpass_filter_help)

# common median referencing for probes in agar
lfp_surface_channel_agar_group = parser.add_mutually_exclusive_group()
lfp_surface_channel_help = "Index of surface channel (e.g. index 0 corresponds to channel 1) of probe for common median referencing for probes in agar. Pass in as JSON string where key is probe and value is surface channel (e.g. \"{'ProbeA': 350, 'ProbeB': 360}\")"
lfp_surface_channel_agar_group.add_argument(
"--surface_channel_agar_probes_indices", help=lfp_surface_channel_help, default="", type=str
)
lfp_surface_channel_agar_group.add_argument(
"static_surface_channel_agar_probes_indices", help=lfp_surface_channel_help, nargs="?", type=str
)

if __name__ == "__main__":
args = parser.parse_args()

stub = args.stub or args.static_stub
Expand All @@ -104,12 +132,37 @@
else:
WRITE_NIDQ = True if args.static_write_nidq == "true" else False

print(
f"Stub test: {STUB_TEST} - Stub seconds: {STUB_SECONDS} - Write lfp: {WRITE_LFP} - Write raw: {WRITE_RAW} - Write NIDQ: {WRITE_NIDQ}"
TEMPORAL_SUBSAMPLING_FACTOR = args.static_lfp_temporal_factor or args.lfp_temporal_factor
TEMPORAL_SUBSAMPLING_FACTOR = int(TEMPORAL_SUBSAMPLING_FACTOR)
SPATIAL_CHANNEL_SUBSAMPLING_FACTOR = args.static_lfp_spatial_factor or args.lfp_spatial_factor
SPATIAL_CHANNEL_SUBSAMPLING_FACTOR = int(SPATIAL_CHANNEL_SUBSAMPLING_FACTOR)
HIGHPASS_FILTER_FREQ_MIN = args.static_lfp_highpass_freq_min or args.lfp_highpass_freq_min
HIGHPASS_FILTER_FREQ_MIN = float(HIGHPASS_FILTER_FREQ_MIN)
SURFACE_CHANNEL_AGAR_PROBES_INDICES = (
args.static_surface_channel_agar_probes_indices or args.surface_channel_agar_probes_indices
)
if SURFACE_CHANNEL_AGAR_PROBES_INDICES != "":
SURFACE_CHANNEL_AGAR_PROBES_INDICES = json.loads(SURFACE_CHANNEL_AGAR_PROBES_INDICES)
else:
SURFACE_CHANNEL_AGAR_PROBES_INDICES = None

print(f"Running NWB conversion with the following parameters:")
print(f"Stub test: {STUB_TEST}")
print(f"Stub seconds: {STUB_SECONDS}")
print(f"Write LFP: {WRITE_LFP}")
print(f"Write RAW: {WRITE_RAW}")
print(f"Write NIDQ: {WRITE_NIDQ}")
print(f"Temporal subsampling factor: {TEMPORAL_SUBSAMPLING_FACTOR}")
print(f"Spatial subsampling factor: {SPATIAL_CHANNEL_SUBSAMPLING_FACTOR}")
print(f"Highpass filter frequency: {HIGHPASS_FILTER_FREQ_MIN}")
print(f"Surface channel indices for agar probes: {SURFACE_CHANNEL_AGAR_PROBES_INDICES}")

# find ecephys session
sessions = [p.stem for p in data_folder.iterdir() if ("ecephys" in p.stem or "behavior" in p.stem) and "sorted" not in p.stem and "nwb" not in p.name]
sessions = [
p.stem
for p in data_folder.iterdir()
if ("ecephys" in p.stem or "behavior" in p.stem) and "sorted" not in p.stem and "nwb" not in p.name
]
assert len(sessions) == 1, "Attach one session (raw data) data at a time"
session = sessions[0]
ecephys_raw_folder = data_folder / session
Expand All @@ -124,7 +177,7 @@
NWB_BACKEND = "zarr"
NWB_SUFFIX = ".nwb.zarr"
io_class = NWBZarrIO
else:
else:
NWB_BACKEND = "hdf5"
NWB_SUFFIX = ".nwb"
io_class = NWBHDF5IO
Expand All @@ -138,8 +191,12 @@
compressed_folder = data_folder / session / "ecephys_compressed"
else:
assert (data_folder / session / "ecephys").is_dir()
oe_folder = data_folder / session / "ecephys"
compressed_folder = None
if (data_folder / session / "ecephys" / "ecephys_compressed").is_dir():
oe_folder = data_folder / session / "ecephys" / "ecephys_clipped"
compressed_folder = data_folder / session / "ecephys" / "ecephys_compressed"
else:
oe_folder = data_folder / session / "ecephys"
compressed_folder = None

# Read Open Ephys Folder structure with NEO
neo_io = OpenEphysBinaryRawIO(oe_folder)
Expand All @@ -151,7 +208,8 @@
experiment_ids = [eid for eid in neo_io.folder_structure[record_nodes[0]]["experiments"].keys()]
experiment_names = [e["name"] for eid, e in neo_io.folder_structure[record_nodes[0]]["experiments"].items()]
recording_names = [
r["name"] for rid, r in neo_io.folder_structure[record_nodes[0]]["experiments"][experiment_ids[0]]["recordings"].items()
r["name"]
for rid, r in neo_io.folder_structure[record_nodes[0]]["experiments"][experiment_ids[0]]["recordings"].items()
]

streams_to_process = []
Expand Down Expand Up @@ -188,20 +246,24 @@
# write 1 new nwb file per segment
with io_class(str(nwbfile_input_path), "r") as read_io:
nwbfile = read_io.read()

for stream_name in streams_to_process:
record_node, oe_stream_name = stream_name.split("#")
recording_folder_name = f"{experiment_name}_{stream_name}_{recording_name}"
settings_file = neo_io.folder_structure[record_node]["experiments"][experiment_ids[block_index]]["settings_file"]
settings_file = neo_io.folder_structure[record_node]["experiments"][experiment_ids[block_index]][
"settings_file"
]

# Add devices
added_devices, target_locations = get_devices_from_metadata(session_folder, segment_index=segment_index)
added_devices, target_locations = get_devices_from_metadata(
session_folder, segment_index=segment_index
)

# if devices not found in metadata, instantiate using probeinterface
if added_devices:
for device_name, device in added_devices.items():
if device_name not in nwbfile.devices:
nwbfile.add_device(device)
nwbfile.add_device(device)
for device_name, targeted_location in target_locations.items():
probe_no_spaces = device_name.replace(" ", "")
if probe_no_spaces in oe_stream_name:
Expand Down Expand Up @@ -229,7 +291,9 @@
stream_name_zarr = f"{experiment_name}_{stream_name}"
recording_multi_segment = si.read_zarr(compressed_folder / f"{stream_name_zarr}.zarr")
else:
recording_multi_segment = se.read_openephys(oe_folder, stream_name=stream_name, block_index=block_index)
recording_multi_segment = se.read_openephys(
oe_folder, stream_name=stream_name, block_index=block_index
)

recording = si.split_recording(recording_multi_segment)[segment_index]

Expand Down Expand Up @@ -257,8 +321,7 @@
}
electrode_metadata["Ecephys"].update(electrical_series_metadata)
add_electrical_series_kwargs = dict(
es_key=f"ElectricalSeries{probe_device_name}",
write_as="raw"
es_key=f"ElectricalSeries{probe_device_name}", write_as="raw"
)
# Add channel properties (group_name property to associate electrodes with group)
recording.set_channel_groups([probe_device_name] * recording.get_num_channels())
Expand Down Expand Up @@ -293,16 +356,22 @@

if "AP" not in stream_name:
# Wide-band NP recording: filter and resample LFP
print(f"\tAdding LFP data for stream {stream_name} from wide-band signal - segment {segment_index}")
print(
f"\tAdding LFP data for stream {stream_name} from wide-band signal - segment {segment_index}"
)
recording_lfp = spre.bandpass_filter(recording, **lfp_filter_kwargs)
recording_lfp = spre.resample(recording_lfp, lfp_sampling_rate)
recording_lfp = spre.scale(recording_lfp, dtype="int16")

# there is a bug in with sample mismatches for the last chunk if num_samples not divisible by chunk_size
# the workaround is to discard the last samples to make it "even"
if recording.get_num_segments() == 1:
recording_lfp = recording_lfp.frame_slice(start_frame=0,
end_frame=int(recording_lfp.get_num_samples() // lfp_sampling_rate * lfp_sampling_rate))
recording_lfp = recording_lfp.frame_slice(
start_frame=0,
end_frame=int(
recording_lfp.get_num_samples() // lfp_sampling_rate * lfp_sampling_rate
),
)
lfp_period = 1.0 / lfp_sampling_rate
for segment_index in range(recording.get_num_segments()):
ts_lfp = (
Expand All @@ -325,8 +394,42 @@
recording_lfp = se.read_openephys(
oe_folder, stream_name=lfp_stream_name, block_index=block_index
)

channel_ids = recording_lfp.get_channel_ids()

# re-reference only for agar - subtract median of channels out of brain using surface channel index arg
# similar processing to allensdk
if SURFACE_CHANNEL_AGAR_PROBES_INDICES is not None:
if probe.name in SURFACE_CHANNEL_AGAR_PROBES_INDICES:
surface_channel_index = SURFACE_CHANNEL_AGAR_PROBES_INDICES[probe.name]
# get indices of channels out of brain including surface channel
reference_channel_indices = np.arange(surface_channel_index, len(channel_ids), 1)
reference_channel_ids = channel_ids[reference_channel_indices]
groups = [[channel_id] for channel_id in reference_channel_ids]
# common median reference to channels out of brain
recording_lfp = spre.common_reference(
recording_lfp,
reference="single",
groups=groups,
ref_channel_ids=reference_channel_ids,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arjunsridhar12345 not sure this is correct. If I understand correctly, you want to subtract the median of the channnels in agar (from surface_channel_index to the last channel) to the remaining channels, correct?

The current solution will subtract to each channel group (so each channel in your case), the corresponding reference_channel_id, so all channels out of the brain will be zeroed out, with no effect on the other channels. Is this the wanted behavior?

Copy link
Collaborator Author

@arjunsridhar12345 arjunsridhar12345 Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, want to subtract the median of channels from surface index to last to the remaining channels. Then, should the reference be "global" with only ref_channel_ids set? Or what parameters should be passed to get the desired effect?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mmm I don't think we have that option yet, but we could cook in a workaround if needed.

See this issue: SpikeInterface/spikeinterface#2985

Copy link
Collaborator Author

@arjunsridhar12345 arjunsridhar12345 Jun 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ok cool. that seems close, basically just need to subtract the remaining traces from the median of those channels. is there any sort of set traces type functionality or another way you know of?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really...I'll think about a smart way to do it! We will probbaly need some custom recording functions

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@arjunsridhar12345 fixed in this PR: SpikeInterface/spikeinterface#3139 (and SpikeInterface/spikeinterface#3140)

Until we release, I added a custom installation in the postinstall

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok great, thanks!


# spatial subsampling from allensdk - keep every nth channel
if SPATIAL_CHANNEL_SUBSAMPLING_FACTOR > 1:
channel_ids_to_keep = channel_ids[0 : len(channel_ids) : SPATIAL_CHANNEL_SUBSAMPLING_FACTOR]
recording_lfp = recording_lfp.channel_slice(channel_ids_to_keep)

# time subsampling/decimate
if TEMPORAL_SUBSAMPLING_FACTOR > 1:
recording_lfp = spre.decimate(recording_lfp, TEMPORAL_SUBSAMPLING_FACTOR)

# high pass filter from allensdk
if HIGHPASS_FILTER_FREQ_MIN > 0:
recording_lfp = spre.highpass_filter(recording_lfp, freq_min=HIGHPASS_FILTER_FREQ_MIN)

# Assign to the correct channel group
recording_lfp.set_channel_groups([probe_device_name] * recording_lfp.get_num_channels())

if STUB_TEST:
end_frame = int(STUB_SECONDS * recording_lfp.sampling_frequency)
recording_lfp = recording_lfp.frame_slice(start_frame=0, end_frame=end_frame)
Expand All @@ -347,7 +450,7 @@
print(f"Added {len(streams_to_process)} streams")

print(f"Configuring {NWB_BACKEND} backend")
backend_configuration = get_default_backend_configuration(nwbfile=nwbfile, backend=NWB_BACKEND)
backend_configuration = get_default_backend_configuration(nwbfile=nwbfile, backend=NWB_BACKEND)
es_compressor = default_electrical_series_compressors[NWB_BACKEND]

for key in backend_configuration.dataset_configurations.keys():
Expand Down