Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LFP subsampling #2

Merged
merged 10 commits into from
Jul 9, 2024
44 changes: 44 additions & 0 deletions code/run_capsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import argparse
from pathlib import Path
import numpy as np
import json

import spikeinterface as si
import spikeinterface.extractors as se
Expand Down Expand Up @@ -79,6 +80,13 @@
write_nidq_group.add_argument('--write-nidq', action='store_true', help=write_nidq_help)
write_nidq_group.add_argument('static_write_nidq', nargs='?', default="false", help=write_nidq_help)

lfp_subsampling_args = parser.add_mutually_exclusive_group()
lfp_subsampling_args.add_argument('--temporal_factor', default=2, help='Ratio of input samples to output samples in time')
lfp_subsampling_args.add_argument('--spatial_factor', default=4, help='Distance between channels to keep')
# common median referencing for probes in agar
lfp_subsampling_args.add_argument('--surface_channel_agar_probes_indices', help='Index of surface channel (e.g. index 0 corresponds to channel 1) of probe for common median referencing for probes in agar. Pass in as dict where key is probe and value is surface channel (e.g. """{"ProbeA": 350}"""',
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
type=str)

if __name__ == "__main__":

args = parser.parse_args()
Expand All @@ -104,6 +112,9 @@
else:
WRITE_NIDQ = True if args.static_write_nidq == "true" else False

TEMPORAL_SUBSAMPLING_FACTOR = args.temporal_factor
SPATIAL_CHANNEL_SUBSAMPLING_FACTOR = args.spatial_factor

print(
f"Stub test: {STUB_TEST} - Stub seconds: {STUB_SECONDS} - Write lfp: {WRITE_LFP} - Write raw: {WRITE_RAW} - Write NIDQ: {WRITE_NIDQ}"
)
Expand Down Expand Up @@ -325,8 +336,41 @@
recording_lfp = se.read_openephys(
oe_folder, stream_name=lfp_stream_name, block_index=block_index
)

channel_ids = recording_lfp.get_channel_ids()

# re-reference only for agar - subtract median of channels out of brain using surface channel index arg, from allensdk
if args.surface_channel_agar_probes_indices:
surface_channel_probes = json.loads(args.surface_channel_agar_probes_indices)
if probe.name in surface_channel_probes:
surface_channel_index = surface_channel_probes[probe.name]
# get indices of channels out of brain including surface channel
reference_channel_indices = np.arange(surface_channel_index, len(channel_ids), 1)
reference_channel_ids = channel_ids[reference_channel_indices]
groups = [[channel_id] for channel_id in reference_channel_ids]
# common median reference to channels out of brain
recording_lfp = spre.common_reference(recording_lfp, reference="single", groups=groups,
ref_channel_ids=reference_channel_ids)

# spatial subsampling from allensdk - keep every nth channel
channel_ids_to_keep = channel_ids[0:len(channel_ids):SPATIAL_CHANNEL_SUBSAMPLING_FACTOR]
channel_ids_to_remove = [channel_id for channel_id in channel_ids if channel_id not in channel_ids_to_keep]
recording_lfp_spatial_subsampled = recording_lfp.remove_channels(channel_ids_to_remove)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can also use recording_lfp_spatial_subsampled = recording_lfp.channel_slice(channel_ids_to_keep)

assert (recording_lfp_spatial_subsampled.get_num_channels() == int(recording_lfp.get_num_channels() / SPATIAL_CHANNEL_SUBSAMPLING_FACTOR)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
), f"Mismatch when downsampling spatially. Got {recording_lfp_spatial_subsampled.get_num_channels()} number of channels given {SPATIAL_CHANNEL_SUBSAMPLING_FACTOR} channel stride and {recording_lfp.get_num_channels()} original channels"

# time subsampling/decimate
recording_lfp = spre.resample(recording_lfp_spatial_subsampled,
int(recording_lfp.sampling_frequency / TEMPORAL_SUBSAMPLING_FACTOR))
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved
assert(recording_lfp.get_num_samples() == int(recording_lfp_spatial_subsampled.get_num_samples() / TEMPORAL_SUBSAMPLING_FACTOR)
), f"Mismatch when downsampling temporally. Got {recording_lfp.get_num_samples()} samples given {TEMPORAL_SUBSAMPLING_FACTOR} factor and {recording_lfp_spatial_subsampled.get_num_samples()} original samples"

# high pass filter from allensdk
recording_lfp = spre.highpass_filter(recording_lfp, freq_min=0.1)
alejoe91 marked this conversation as resolved.
Show resolved Hide resolved

# Assign to the correct channel group
recording_lfp.set_channel_groups([probe_device_name] * recording_lfp.get_num_channels())

if STUB_TEST:
end_frame = int(STUB_SECONDS * recording_lfp.sampling_frequency)
recording_lfp = recording_lfp.frame_slice(start_frame=0, end_frame=end_frame)
Expand Down