Commit

Group level of O6R6

bclenet committed Apr 15, 2024
1 parent 01aa0c8 commit 1278530
Showing 2 changed files with 60 additions and 27 deletions.
53 changes: 43 additions & 10 deletions narps_open/pipelines/team_O6R6.py
@@ -35,8 +35,8 @@ def __init__(self):
self.team_id = 'O6R6'
self.contrast_list = ['1', '2']
self.run_level_contrasts = [
('effect_of_gain', 'T', ['gain', 'loss'], [1, 0]),
('effect_of_loss', 'T', ['gain', 'loss'], [0, 1])
('effect_of_gain', 'T', ['gain_trial', 'loss_trial'], [1, 0]),
('effect_of_loss', 'T', ['gain_trial', 'loss_trial'], [0, 1])
]
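
Note on the tuples above: Nipype's FSL Level1Design interface expects contrasts as (name, statistic, condition names, weights). A minimal sketch of how such a list is passed to the interface, assuming the standard Nipype FSL API (node and variable names here are illustrative, not part of the commit):

    from nipype import Node
    from nipype.interfaces.fsl import Level1Design

    # Each tuple defines one T-contrast over the named regressors
    run_level_contrasts = [
        ('effect_of_gain', 'T', ['gain_trial', 'loss_trial'], [1, 0]),
        ('effect_of_loss', 'T', ['gain_trial', 'loss_trial'], [0, 1])
    ]

    model_design = Node(Level1Design(), name = 'model_design')
    model_design.inputs.contrasts = run_level_contrasts  # one cope/tstat/zstat set per contrast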

def get_preprocessing(self):
@@ -75,7 +75,7 @@ def get_subject_information(event_file, group):
if float(info[4]) > 0.0: # Response time exists
onsets_trial.append(float(info[0]))
durations_trial.append(float(info[4]))
weights_trial.append(1.0)
weights_trial.append(1.0)

gain_amount = float(info[2])
loss_amount = float(info[3])
@@ -107,6 +107,22 @@ def get_subject_information(event_file, group):
)
]
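
For context, a minimal sketch of how get_subject_information builds a SpecifyModel-compatible Bunch from an events file. Only the unmodulated 'trial' condition is shown; the split into 'gain_trial' and 'loss_trial' and their amplitudes depend on the subject's group and are defined in the pipeline itself (see the diff above):

    from nipype.interfaces.base import Bunch

    def sketch_subject_information(event_file):
        # NARPS events files hold columns: onset, duration, gain, loss, response time, ...
        onsets, durations = [], []
        with open(event_file, 'rt') as file:
            next(file)  # skip the header line
            for line in file:
                info = line.strip().split()
                if float(info[4]) > 0.0:  # keep trials with a recorded response time
                    onsets.append(float(info[0]))     # trial onset
                    durations.append(float(info[4]))  # response time used as duration
        return Bunch(
            conditions = ['trial'],
            onsets = [onsets],
            durations = [durations],
            amplitudes = [[1.0] * len(onsets)]
            )
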

def get_subject_group(subject_id: str):
"""
Return the group of the subject (either 'equalRange' or 'equalIndifference').
Parameters :
- subject_id : str, the subject identifier
Returns :
- group : str, the group to which the subject belongs
"""
from narps_open.data.participants import get_group

if subject_id in get_group('equalRange'):
return 'equalRange'
return 'equalIndifference'
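
Illustrative use of the helper above (the subject identifier is hypothetical): narps_open.data.participants.get_group returns the list of subject IDs in the requested NARPS group, and any subject not listed in 'equalRange' falls back to 'equalIndifference':

    from narps_open.data.participants import get_group

    # Subject identifiers assigned to the equalRange group
    equal_range_subjects = get_group('equalRange')
    group = 'equalRange' if '001' in equal_range_subjects else 'equalIndifference'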

def get_run_level_analysis(self):
"""
Create the run level analysis workflow.
@@ -146,13 +162,25 @@ def get_run_level_analysis(self):
smoothing_func.inputs.fwhm = self.fwhm
run_level.connect(select_files, 'func', smoothing_func, 'in_file')

# Function Node get_subject_group
# This returns the name of the subject's group
subject_group = Node(Function(
function = self.get_subject_group,
input_names = ['subject_id'],
output_names = ['group']
),
name = 'subject_group'
)
run_level.connect(information_source, 'subject_id', subject_group, 'subject_id')

# Get Subject Info - get subject specific condition information
subject_information = Node(Function(
function = self.get_subject_information,
input_names = ['event_file'],
input_names = ['event_file', 'group'],
output_names = ['subject_info']
), name = 'subject_information')
run_level.connect(select_files, 'events', subject_information, 'event_file')
run_level.connect(subject_group, 'group', subject_information, 'group')

# SpecifyModel Node - Generate run level model
specify_model = Node(SpecifyModel(), name = 'specify_model')
@@ -164,7 +192,7 @@ def get_run_level_analysis(self):

# Level1Design Node - Generate files for run level computation
model_design = Node(Level1Design(), name = 'model_design')
model_design.inputs.bases = {'dgamma' : {'derivs' : False }}
model_design.inputs.bases = {'dgamma' : {'derivs' : True }}
model_design.inputs.interscan_interval = TaskInformation()['RepetitionTime']
model_design.inputs.model_serial_correlations = True
model_design.inputs.contrasts = self.run_level_contrasts
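
With 'derivs' : True, each regressor is convolved with the double-gamma HRF and also gets a temporal-derivative regressor. Downstream of Level1Design, FSL run-level pipelines typically chain FEATModel and FILMGLS; a hedged sketch of that chain (not shown in this diff, connections left commented because the workflow object is defined elsewhere):

    from nipype import Node
    from nipype.interfaces.fsl import FEATModel, FILMGLS

    # FEATModel converts the fsf/ev files from Level1Design into design matrices;
    # FILMGLS estimates the GLM with prewhitening and writes copes/varcopes/zstats
    model_generation = Node(FEATModel(), name = 'model_generation')
    model_estimate = Node(FILMGLS(), name = 'model_estimate')
    # run_level.connect(model_design, 'fsf_files', model_generation, 'fsf_file')
    # run_level.connect(model_design, 'ev_files', model_generation, 'ev_files')
    # run_level.connect(model_generation, 'design_file', model_estimate, 'design_file')
    # run_level.connect(model_generation, 'con_file', model_estimate, 'tcon_file')
    # run_level.connect(smoothing_func, 'smoothed_file', model_estimate, 'in_file')
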
@@ -502,11 +530,16 @@ def get_group_level_analysis_sub_workflow(self, method):
group_level.connect(specify_model, 'design_con', estimate_model, 't_con_file')
group_level.connect(specify_model, 'design_grp', estimate_model, 'cov_split_file')

# Cluster Node - Perform clustering on statistical output
cluster = Node(Cluster(), name = 'cluster')
# TODO : add parameters
group_level.connect(estimate_model, 'zstats', cluster, 'in_file')
group_level.connect(estimate_model, 'copes', cluster, 'cope_file')
# Randomise Node - Perform permutation testing on statistical output
randomise = Node(Randomise(),
name = 'randomise',
synchronize = True)
randomise.inputs.tfce = True
randomise.inputs.num_perm = 5000
randomise.inputs.c_thresh = 0.05
group_level.connect(mask_intersection, 'out_file', randomise, 'mask')
group_level.connect(estimate_model, 'zstats', randomise, 'in_file')
group_level.connect(estimate_model, 'copes', randomise, 'tcon')
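
With tfce = True, randomise writes TFCE-corrected 1-p images, one per contrast. A hedged sketch of thresholding such an output at corrected p < 0.05 (the file name below is hypothetical):

    import nibabel as nib
    import numpy as np

    # Values in a *_corrp_* image are 1 - p; keep voxels with corrected p < 0.05
    corrp_image = nib.load('randomise_tfce_corrp_tstat1.nii.gz')  # hypothetical output name
    significant_voxels = np.asarray(corrp_image.dataobj) > 0.95
    print(f'{int(significant_voxels.sum())} voxels survive corrected p < 0.05')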

# Datasink Node - Save important files
data_sink = Node(DataSink(), name = 'data_sink')
34 changes: 17 additions & 17 deletions tests/pipelines/test_team_O6R6.py
@@ -51,11 +51,11 @@ def test_outputs():

# 1 - 1 subject outputs
pipeline.subject_list = ['001']
helpers.test_pipeline_outputs(pipeline, [0, 4*1*2*4, 4*2*1 + 2*1, 8*4*2 + 4*4, 18])
helpers.test_pipeline_outputs(pipeline, [0, 2*1*4*4, 2*4*1 + 2*1, 8*4*2 + 4*4, 18])

# 2 - 4 subjects outputs
pipeline.subject_list = ['001', '002', '003', '004']
helpers.test_pipeline_outputs(pipeline, [0, 4*4*2*4, 4*2*4 + 2*4, 8*4*2 + 4*4, 18])
helpers.test_pipeline_outputs(pipeline, [0, 2*4*4*4, 2*4*4 + 2*4, 8*4*2 + 4*4, 18])
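
A possible reading of the run-level count (an interpretation, not stated in the test itself): 2*4*4*4 = 2 contrasts x 4 subjects x 4 runs x 4 output files per contrast (cope, varcope, tstat, zstat).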

@staticmethod
@mark.unit_test
@@ -75,18 +75,18 @@ def test_subject_information():
assert bunch.conditions == ['trial', 'gain_trial', 'loss_trial']
helpers.compare_float_2d_arrays(bunch.onsets, [
[4.071, 11.834, 27.535, 36.435],
[4.071, 11.834, 27.535, 36.435],
[4.071, 11.834, 27.535, 36.435]
])
[4.071, 11.834],
[27.535, 36.435]
])
helpers.compare_float_2d_arrays(bunch.durations, [
[2.388, 2.289, 2.08, 2.288],
[2.388, 2.289, 2.08, 2.288],
[2.388, 2.289, 2.08, 2.288]
[2.388, 2.289],
[2.08, 2.288]
])
helpers.compare_float_2d_arrays(bunch.amplitudes, [
[1.0, 1.0, 1.0, 1.0, 1.0],
[14.0, 34.0, 38.0, 10.0, 16.0],
[6.0, 14.0, 19.0, 15.0, 17.0]
[1.0, 1.0, 1.0, 1.0],
[3.0, 13.0],
[3.5, 4.5]
])

# Compare bunches to expected
@@ -95,18 +95,18 @@ def test_subject_information():
assert bunch.conditions == ['trial', 'gain_trial', 'loss_trial']
helpers.compare_float_2d_arrays(bunch.onsets, [
[4.071, 11.834, 27.535, 36.435],
[4.071, 11.834, 27.535, 36.435],
[4.071, 11.834, 27.535, 36.435]
[4.071, 11.834],
[27.535, 36.435]
])
helpers.compare_float_2d_arrays(bunch.durations, [
[2.388, 2.289, 2.08, 2.288],
[2.388, 2.289, 2.08, 2.288],
[2.388, 2.289, 2.08, 2.288]
[2.388, 2.289],
[2.08, 2.288]
])
helpers.compare_float_2d_arrays(bunch.amplitudes, [
[1.0, 1.0, 1.0, 1.0, 1.0],
[14.0, 34.0, 38.0, 10.0, 16.0],
[6.0, 14.0, 19.0, 15.0, 17.0]
[1.0, 1.0, 1.0, 1.0],
[10.0, 30.0],
[11.0, 13.0]
])

@staticmethod
