Skip to content

Commit

Permalink
change behavior of load_TS
Browse files Browse the repository at this point in the history
  • Loading branch information
mtorabi59 committed Apr 29, 2024
1 parent 344d3f1 commit db699bd
Showing 1 changed file with 33 additions and 30 deletions.
63 changes: 33 additions & 30 deletions pydfc/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import os
from copy import deepcopy

import h5py
import numpy as np
Expand All @@ -16,10 +17,10 @@
################################# DATA_LOADER functions ######################################


def find_subj_list(data_root, sessions):
def find_subj_list(data_root):
"""
find the list of subjects in data_root
the files must follow the format: subjectID_sessionID
the files must follow the format: sub-subjectID
only these files should be in the data_root
"""
if data_root[-1] != "/":
Expand All @@ -31,32 +32,13 @@ def find_subj_list(data_root, sessions):
FOLDERS.sort()
SUBJECTS = list()
for s in FOLDERS:
num = s[: s.find("_")]
SUBJECTS.append(num)
# the subjects might be repeated because of different sessions
SUBJECTS = list(set(SUBJECTS))
if "sub-" in s:
SUBJECTS.append(s)
SUBJECTS.sort()

print(str(len(SUBJECTS)) + " subjects were found. ")

failed_subjs = []
kept_subjs = []
for subj in SUBJECTS:
kept_subjs.append(subj)
for session in sessions:
if not os.path.exists(f"{data_root}{subj}_{session}"):
failed_subjs.append(subj)
kept_subjs.remove(subj)
break

print(
str(len(failed_subjs))
+ " subjects had missing sessions. "
+ str(len(kept_subjs))
+ " subjects were kept. "
)
print(f"{len(SUBJECTS)} subjects were found. ")

return kept_subjs
return SUBJECTS


def load_from_array(subj_id2load=None, **params):
Expand All @@ -83,7 +65,7 @@ def load_from_array(subj_id2load=None, **params):

SESSIONs = params["SESSIONs"] # list of sessions
if subj_id2load is None:
SUBJECTS = find_subj_list(params["data_root"], sessions=SESSIONs)
SUBJECTS = find_subj_list(params["data_root"])
else:
SUBJECTS = [subj_id2load]

Expand Down Expand Up @@ -337,12 +319,22 @@ def multi_nifti2timeseries(
return BOLD_multi


def load_TS(data_root, file_name, SESSIONs, subj_id2load=None):
def load_TS(
data_root,
file_name,
SESSIONs,
subj_id2load=None,
task=None,
run=None,
):
"""
load a TIME_SERIES object from a .npy file
if SESSIONs is a list, it will load all the sessions,
if it is a string, it will load that session
if subj_id2load is None, it will load all the subjects
file_name: name of the file to load
format example: {subj_id}_{task}_{run}_time-series.npy
(keep the {} for the variables)
"""
# check if SESSIONs is a list or a string
flag = False
Expand All @@ -351,17 +343,28 @@ def load_TS(data_root, file_name, SESSIONs, subj_id2load=None):
flag = True

if subj_id2load is None:
SUBJECTS = find_subj_list(data_root, sessions=SESSIONs)
SUBJECTS = find_subj_list(data_root)
else:
assert "sub-" in subj_id2load, "subj_id2load must start with 'sub-'"
SUBJECTS = [subj_id2load]

TS = {}
for session in SESSIONs:
TS[session] = None
for subj in SUBJECTS:
subj_fldr = f"{subj}_{session}"
subj_fldr = subj
# make the file_name
TS_file = deepcopy(file_name)
if "{subj_id}" in file_name:
TS_file = TS_file.replace("{subj_id}", subj)
if "{task}" in file_name:
assert task is not None, "task must be provided"
TS_file = TS_file.replace("{task}", task)
if "{run}" in file_name:
assert run is not None, "run must be provided"
TS_file = TS_file.replace("{run}", run)
time_series = np.load(
f"{data_root}/{subj_fldr}/{file_name}", allow_pickle="True"
f"{data_root}/{subj_fldr}/{TS_file}", allow_pickle="True"
).item()
if TS[session] is None:
TS[session] = time_series
Expand Down

0 comments on commit db699bd

Please sign in to comment.