Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LFP subsampling #2

Merged
merged 10 commits into from
Jul 9, 2024
55 changes: 55 additions & 0 deletions code/run_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
from pathlib import Path
import numpy as np
import json

import spikeinterface as si
import spikeinterface.extractors as se
Expand Down Expand Up @@ -79,6 +80,26 @@
write_nidq_group.add_argument('--write-nidq', action='store_true', help=write_nidq_help)
write_nidq_group.add_argument('static_write_nidq', nargs='?', default="false", help=write_nidq_help)

lfp_temporal_subsampling_group = parser.add_mutually_exclusive_group()
lfp_temporal_subsampling_help = "Ratio of input samples to output samples in time "
lfp_temporal_subsampling_group.add_argument('--lfp_temporal_factor', action='store_true', help=lfp_temporal_subsampling_help)
lfp_temporal_subsampling_group.add_argument('static_lfp_temporal_factor', nargs='?', default="2", help=lfp_temporal_subsampling_help)

lfp_spatial_subsampling_group = parser.add_mutually_exclusive_group()
lfp_spatial_subsampling_help = "Distance between channels to keep"
lfp_spatial_subsampling_group.add_argument('--lfp_spatial_factor', action='store_true', help=lfp_spatial_subsampling_help)
lfp_spatial_subsampling_group.add_argument('static_lfp_spatial_factor', nargs='?', default="4", help=lfp_spatial_subsampling_help)

# common median referencing for probes in agar
lfp_surface_channel_agar_group = parser.add_mutually_exclusive_group()
lfp_surface_channel_agar_group.add_argument('--surface_channel_agar_probes_indices', help='Index of surface channel (e.g. index 0 corresponds to channel 1) of probe for common median referencing for probes in agar. Pass in as dict where key is probe and value is surface channel (e.g. """{"ProbeA": 350}"""',
type=str)

lfp_highpass_filter_group = parser.add_mutually_exclusive_group()
lfp_highpass_filter_help = "Cutoff frequency for highpass filter"
lfp_highpass_filter_group.add_argument('--lfp_highpass_freq_min', action='store_true', help=lfp_highpass_filter_help)
lfp_highpass_filter_group.add_argument('static_lfp_highpass_freq_min', nargs='?', default="0.1", help=lfp_highpass_filter_help)

if __name__ == "__main__":

args = parser.parse_args()
Expand All @@ -104,6 +125,9 @@
else:
WRITE_NIDQ = True if args.static_write_nidq == "true" else False

TEMPORAL_SUBSAMPLING_FACTOR = int(args.lfp_temporal_factor) or int(args.static_lfp_temporal_factor)
SPATIAL_CHANNEL_SUBSAMPLING_FACTOR = int(args.lfp_spatial_factor) or int(args.static_lfp_spatial_factor)
HIGHPASS_FILTER_FREQ_MIN = float(args.lfp_highpass_freq_min) or float(args.static_lfp_highpass_freq_min)
print(
f"Stub test: {STUB_TEST} - Stub seconds: {STUB_SECONDS} - Write lfp: {WRITE_LFP} - Write raw: {WRITE_RAW} - Write NIDQ: {WRITE_NIDQ}"
)
Expand Down Expand Up @@ -325,8 +349,39 @@
recording_lfp = se.read_openephys(
oe_folder, stream_name=lfp_stream_name, block_index=block_index
)

channel_ids = recording_lfp.get_channel_ids()

# re-reference only for agar - subtract median of channels out of brain using surface channel index arg, from allensdk
if args.surface_channel_agar_probes_indices:
surface_channel_probes = json.loads(args.surface_channel_agar_probes_indices)
if probe.name in surface_channel_probes:
surface_channel_index = surface_channel_probes[probe.name]
# get indices of channels out of brain including surface channel
reference_channel_indices = np.arange(surface_channel_index, len(channel_ids), 1)
reference_channel_ids = channel_ids[reference_channel_indices]
groups = [[channel_id] for channel_id in reference_channel_ids]
# common median reference to channels out of brain
recording_lfp = spre.common_reference(recording_lfp, reference="single", groups=groups,
ref_channel_ids=reference_channel_ids)

# spatial subsampling from allensdk - keep every nth channel
channel_ids_to_keep = channel_ids[0:len(channel_ids):SPATIAL_CHANNEL_SUBSAMPLING_FACTOR]
recording_lfp_spatial_subsampled = recording_lfp.channel_slice(channel_ids_to_keep)
assert (recording_lfp_spatial_subsampled.get_num_channels() == int(recording_lfp.get_num_channels() / SPATIAL_CHANNEL_SUBSAMPLING_FACTOR)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
), f"Mismatch when downsampling spatially. Got {recording_lfp_spatial_subsampled.get_num_channels()} number of channels given {SPATIAL_CHANNEL_SUBSAMPLING_FACTOR} channel stride and {recording_lfp.get_num_channels()} original channels"

# time subsampling/decimate
recording_lfp = spre.decimate(recording_lfp_spatial_subsampled, TEMPORAL_SUBSAMPLING_FACTOR)
assert(recording_lfp.get_num_samples() == np.ceil(recording_lfp_spatial_subsampled.get_num_samples() / TEMPORAL_SUBSAMPLING_FACTOR)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
), f"Mismatch when downsampling temporally. Got {recording_lfp.get_num_samples()} samples given {TEMPORAL_SUBSAMPLING_FACTOR} factor and {recording_lfp_spatial_subsampled.get_num_samples()} original samples"

# high pass filter from allensdk
recording_lfp = spre.highpass_filter(recording_lfp, freq_min=HIGHPASS_FILTER_FREQ_MIN)

# Assign to the correct channel group
recording_lfp.set_channel_groups([probe_device_name] * recording_lfp.get_num_channels())

if STUB_TEST:
end_frame = int(STUB_SECONDS * recording_lfp.sampling_frequency)
recording_lfp = recording_lfp.frame_slice(start_frame=0, end_frame=end_frame)
Expand Down