diff --git a/pydfc/data_loader.py b/pydfc/data_loader.py index baa654e..84e7713 100644 --- a/pydfc/data_loader.py +++ b/pydfc/data_loader.py @@ -6,6 +6,7 @@ """ import os +from copy import deepcopy import h5py import numpy as np @@ -16,10 +17,10 @@ ################################# DATA_LOADER functions ###################################### -def find_subj_list(data_root, sessions): +def find_subj_list(data_root): """ find the list of subjects in data_root - the files must follow the format: subjectID_sessionID + the files must follow the format: sub-subjectID only these files should be in the data_root """ if data_root[-1] != "/": @@ -31,32 +32,13 @@ def find_subj_list(data_root, sessions): FOLDERS.sort() SUBJECTS = list() for s in FOLDERS: - num = s[: s.find("_")] - SUBJECTS.append(num) - # the subjects might be repeated because of different sessions - SUBJECTS = list(set(SUBJECTS)) + if "sub-" in s: + SUBJECTS.append(s) SUBJECTS.sort() - print(str(len(SUBJECTS)) + " subjects were found. ") - - failed_subjs = [] - kept_subjs = [] - for subj in SUBJECTS: - kept_subjs.append(subj) - for session in sessions: - if not os.path.exists(f"{data_root}{subj}_{session}"): - failed_subjs.append(subj) - kept_subjs.remove(subj) - break - - print( - str(len(failed_subjs)) - + " subjects had missing sessions. " - + str(len(kept_subjs)) - + " subjects were kept. " - ) + print(f"{len(SUBJECTS)} subjects were found. ") - return kept_subjs + return SUBJECTS def load_from_array(subj_id2load=None, **params): @@ -83,7 +65,7 @@ def load_from_array(subj_id2load=None, **params): SESSIONs = params["SESSIONs"] # list of sessions if subj_id2load is None: - SUBJECTS = find_subj_list(params["data_root"], sessions=SESSIONs) + SUBJECTS = find_subj_list(params["data_root"]) else: SUBJECTS = [subj_id2load] @@ -337,12 +319,22 @@ def multi_nifti2timeseries( return BOLD_multi -def load_TS(data_root, file_name, SESSIONs, subj_id2load=None): +def load_TS( + data_root, + file_name, + SESSIONs, + subj_id2load=None, + task=None, + run=None, +): """ load a TIME_SERIES object from a .npy file if SESSIONs is a list, it will load all the sessions, if it is a string, it will load that session if subj_id2load is None, it will load all the subjects + file_name: name of the file to load + format example: {subj_id}_{task}_{run}_time-series.npy + (keep the {} for the variables) """ # check if SESSIONs is a list or a string flag = False @@ -351,17 +343,28 @@ def load_TS(data_root, file_name, SESSIONs, subj_id2load=None): flag = True if subj_id2load is None: - SUBJECTS = find_subj_list(data_root, sessions=SESSIONs) + SUBJECTS = find_subj_list(data_root) else: + assert "sub-" in subj_id2load, "subj_id2load must start with 'sub-'" SUBJECTS = [subj_id2load] TS = {} for session in SESSIONs: TS[session] = None for subj in SUBJECTS: - subj_fldr = f"{subj}_{session}" + subj_fldr = subj + # make the file_name + TS_file = deepcopy(file_name) + if "{subj_id}" in file_name: + TS_file = TS_file.replace("{subj_id}", subj) + if "{task}" in file_name: + assert task is not None, "task must be provided" + TS_file = TS_file.replace("{task}", task) + if "{run}" in file_name: + assert run is not None, "run must be provided" + TS_file = TS_file.replace("{run}", run) time_series = np.load( - f"{data_root}/{subj_fldr}/{file_name}", allow_pickle="True" + f"{data_root}/{subj_fldr}/{TS_file}", allow_pickle="True" ).item() if TS[session] is None: TS[session] = time_series