diff --git a/narps_open/pipelines/team_O6R6.py b/narps_open/pipelines/team_O6R6.py index c703e00a..4bd9b8ab 100644 --- a/narps_open/pipelines/team_O6R6.py +++ b/narps_open/pipelines/team_O6R6.py @@ -35,8 +35,8 @@ def __init__(self): self.team_id = 'O6R6' self.contrast_list = ['1', '2'] self.run_level_contrasts = [ - ('effect_of_gain', 'T', ['gain', 'loss'], [1, 0]), - ('effect_of_loss', 'T', ['gain', 'loss'], [0, 1]) + ('effect_of_gain', 'T', ['gain_trial', 'loss_trial'], [1, 0]), + ('effect_of_loss', 'T', ['gain_trial', 'loss_trial'], [0, 1]) ] def get_preprocessing(self): @@ -75,7 +75,7 @@ def get_subject_information(event_file, group): if float(info[4]) > 0.0: # Response time exists onsets_trial.append(float(info[0])) durations_trial.append(float(info[4])) - weights_trial.append(1.0) + weights_trial.append(1.0) gain_amount = float(info[2]) loss_amount = float(info[3]) @@ -107,6 +107,22 @@ def get_subject_information(event_file, group): ) ] + def get_subject_group(subject_id: str): + """ + Return the group of the subject (either 'equalRange' or 'equalIndifference'). + + Parameters : + - subject_id : str, the subject identifier + + Returns : + - group : str, the group to which belong the subject + """ + from narps_open.data.participants import get_group + + if subject_id in get_group('equalRange'): + return 'equalRange' + return 'equalIndifference' + def get_run_level_analysis(self): """ Create the run level analysis workflow. @@ -146,13 +162,25 @@ def get_run_level_analysis(self): smoothing_func.inputs.fwhm = self.fwhm run_level.connect(select_files, 'func', smoothing_func, 'in_file') + # Function Node get_subject_group + # This returns the name of the subject's group + subject_group = Node(Function( + function = self.get_subject_group, + input_names = ['subject_id'], + output_names = ['group'] + ), + name = 'subject_group' + ) + run_level.connect(information_source, 'subject_id', subject_group, 'input_str') + # Get Subject Info - get subject specific condition information subject_information = Node(Function( function = self.get_subject_information, - input_names = ['event_file'], + input_names = ['event_file', 'group'], output_names = ['subject_info'] ), name = 'subject_information') run_level.connect(select_files, 'events', subject_information, 'event_file') + run_level.connect(subject_group, 'group', subject_information, 'group') # SpecifyModel Node - Generate run level model specify_model = Node(SpecifyModel(), name = 'specify_model') @@ -164,7 +192,7 @@ def get_run_level_analysis(self): # Level1Design Node - Generate files for run level computation model_design = Node(Level1Design(), name = 'model_design') - model_design.inputs.bases = {'dgamma' : {'derivs' : False }} + model_design.inputs.bases = {'dgamma' : {'derivs' : True }} model_design.inputs.interscan_interval = TaskInformation()['RepetitionTime'] model_design.inputs.model_serial_correlations = True model_design.inputs.contrasts = self.run_level_contrasts @@ -502,11 +530,16 @@ def get_group_level_analysis_sub_workflow(self, method): group_level.connect(specify_model, 'design_con', estimate_model, 't_con_file') group_level.connect(specify_model, 'design_grp', estimate_model, 'cov_split_file') - # Cluster Node - Perform clustering on statistical output - cluster = Node(Cluster(), name = 'cluster') - # TODO : add parameters - group_level.connect(estimate_model, 'zstats', cluster, 'in_file') - group_level.connect(estimate_model, 'copes', cluster, 'cope_file') + # Randomise Node - Perform clustering on statistical output + randomise = Node(Randomise(), + name = 'randomise', + synchronize = True) + randomise.inputs.tfce = True + randomise.inputs.num_perm = 5000 + randomise.inputs.c_thresh = 0.05 + group_level.connect(mask_intersection, 'out_file', randomise, 'mask') + group_level.connect(estimate_model, 'zstats', randomise, 'in_file') + group_level.connect(estimate_model, 'copes', randomise, 'tcon') # Datasink Node - Save important files data_sink = Node(DataSink(), name = 'data_sink') diff --git a/tests/pipelines/test_team_O6R6.py b/tests/pipelines/test_team_O6R6.py index f46cc930..ef13926b 100644 --- a/tests/pipelines/test_team_O6R6.py +++ b/tests/pipelines/test_team_O6R6.py @@ -51,11 +51,11 @@ def test_outputs(): # 1 - 1 subject outputs pipeline.subject_list = ['001'] - helpers.test_pipeline_outputs(pipeline, [0, 4*1*2*4, 4*2*1 + 2*1, 8*4*2 + 4*4, 18]) + helpers.test_pipeline_outputs(pipeline, [0, 2*1*4*4, 2*4*1 + 2*1, 8*4*2 + 4*4, 18]) # 2 - 4 subjects outputs pipeline.subject_list = ['001', '002', '003', '004'] - helpers.test_pipeline_outputs(pipeline, [0, 4*4*2*4, 4*2*4 + 2*4, 8*4*2 + 4*4, 18]) + helpers.test_pipeline_outputs(pipeline, [0, 2*4*4*4, 2*4*4 + 2*4, 8*4*2 + 4*4, 18]) @staticmethod @mark.unit_test @@ -75,18 +75,18 @@ def test_subject_information(): assert bunch.conditions == ['trial', 'gain_trial', 'loss_trial'] helpers.compare_float_2d_arrays(bunch.onsets, [ [4.071, 11.834, 27.535, 36.435], - [4.071, 11.834, 27.535, 36.435], - [4.071, 11.834, 27.535, 36.435] - ]) + [4.071, 11.834], + [27.535, 36.435] + ]) helpers.compare_float_2d_arrays(bunch.durations, [ [2.388, 2.289, 2.08, 2.288], - [2.388, 2.289, 2.08, 2.288], - [2.388, 2.289, 2.08, 2.288] + [2.388, 2.289], + [2.08, 2.288] ]) helpers.compare_float_2d_arrays(bunch.amplitudes, [ - [1.0, 1.0, 1.0, 1.0, 1.0], - [14.0, 34.0, 38.0, 10.0, 16.0], - [6.0, 14.0, 19.0, 15.0, 17.0] + [1.0, 1.0, 1.0, 1.0], + [3.0, 13.0], + [3.5, 4.5] ]) # Compare bunches to expected @@ -95,18 +95,18 @@ def test_subject_information(): assert bunch.conditions == ['trial', 'gain_trial', 'loss_trial'] helpers.compare_float_2d_arrays(bunch.onsets, [ [4.071, 11.834, 27.535, 36.435], - [4.071, 11.834, 27.535, 36.435], - [4.071, 11.834, 27.535, 36.435] + [4.071, 11.834], + [27.535, 36.435] ]) helpers.compare_float_2d_arrays(bunch.durations, [ [2.388, 2.289, 2.08, 2.288], - [2.388, 2.289, 2.08, 2.288], - [2.388, 2.289, 2.08, 2.288] + [2.388, 2.289], + [2.08, 2.288] ]) helpers.compare_float_2d_arrays(bunch.amplitudes, [ - [1.0, 1.0, 1.0, 1.0, 1.0], - [14.0, 34.0, 38.0, 10.0, 16.0], - [6.0, 14.0, 19.0, 15.0, 17.0] + [1.0, 1.0, 1.0, 1.0], + [10.0, 30.0], + [11.0, 13.0] ]) @staticmethod