From fcc95d42cefb41379d87a7045c9f0732f6943ea5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Boris=20Cl=C3=A9net?= Date: Thu, 18 Jan 2024 10:58:39 +0100 Subject: [PATCH] [BUG] session information --- narps_open/pipelines/team_T54A.py | 8 +- tests/pipelines/test_team_T54A.py | 97 +++++++++++++++-------- tests/test_data/pipelines/events_resp.tsv | 5 ++ 3 files changed, 70 insertions(+), 40 deletions(-) create mode 100644 tests/test_data/pipelines/events_resp.tsv diff --git a/narps_open/pipelines/team_T54A.py b/narps_open/pipelines/team_T54A.py index 56156d46..66da848f 100644 --- a/narps_open/pipelines/team_T54A.py +++ b/narps_open/pipelines/team_T54A.py @@ -55,7 +55,7 @@ def get_subject_information(event_file): condition_names = ['trial', 'gain', 'loss', 'difficulty', 'response', 'missed'] onsets = {} durations = {} - amplitude = {} + amplitudes = {} for condition in condition_names: # Create dictionary items with empty lists @@ -72,7 +72,7 @@ def get_subject_information(event_file): if info[5] != 'NoResp': onsets['trial'].append(float(info[0])) durations['trial'].append(float(info[4])) - amplitudes['trial'].append(float(1)) + amplitudes['trial'].append(1.0) onsets['gain'].append(float(info[0])) durations['gain'].append(float(info[4])) amplitudes['gain'].append(float(info[2])) @@ -258,9 +258,7 @@ def get_run_level_analysis(self): ('ev_files', 'ev_files'), ('fsf_files', 'fsf_file')]), (smoothing_func, model_estimate, [('out_file', 'in_file')]), - (model_generation, model_estimate, [ - ('con_file', 'tcon_file'), - ('design_file', 'design_file')]), + (model_generation, model_estimate, [('design_file', 'design_file')]), (smoothing_func, remove_smoothed_files, [('out_file', 'file_name')]), (model_estimate, remove_smoothed_files, [('results_dir', '_')]), (model_estimate, data_sink, [('results_dir', 'run_level_analysis.@results')]), diff --git a/tests/pipelines/test_team_T54A.py b/tests/pipelines/test_team_T54A.py index b585d9b0..569580b1 100644 --- a/tests/pipelines/test_team_T54A.py +++ b/tests/pipelines/test_team_T54A.py @@ -33,6 +33,14 @@ def remove_test_dir(): yield # test runs here rmtree(TEMPORARY_DIR, ignore_errors = True) +def compare_float_2d_arrays(array_1, array_2): + """ Assert array_1 and array_2 are close enough """ + + assert len(array_1) == len(array_2) + for reference_array, test_array in zip(array_1, array_2): + assert len(reference_array) == len(test_array) + assert isclose(reference_array, test_array).all() + class TestPipelinesTeamT54A: """ A class that contains all the unit tests for the PipelineTeamT54A class.""" @@ -88,53 +96,72 @@ def test_outputs(): def test_subject_information(): """ Test the get_subject_information method """ - event_file_path = join( - Configuration()['directories']['test_data'], 'pipelines', 'events.tsv') + # Get test files + test_file = join(Configuration()['directories']['test_data'], 'pipelines', 'events.tsv') + test_file_2 = join(Configuration()['directories']['test_data'], + 'pipelines', 'events_resp.tsv') - information = PipelineTeamT54A.get_subject_information(event_file_path)[0] + # Prepare several scenarii + info_missed = PipelineTeamT54A.get_subject_information(test_file) + info_ok = PipelineTeamT54A.get_subject_information(test_file_2) - assert isinstance(information, Bunch) - assert information.conditions == [ - 'trial', - 'gain', - 'loss', - 'difficulty', - 'response', - 'missed' - ] - - reference_amplitudes = [ - [1.0, 1.0, 1.0, 1.0], - [14.0, 34.0, 10.0, 16.0], - [6.0, 14.0, 15.0, 17.0], - [1.0, 3.0, 10.0, 9.0], - [1.0, 1.0, 1.0, 1.0], - [1.0] - ] - for reference_array, test_array in zip(reference_amplitudes, information.amplitudes): - assert isclose(reference_array, test_array).all() - - reference_durations = [ + # Compare bunches to expected + bunch = info_missed[0] + assert isinstance(bunch, Bunch) + assert bunch.conditions == ['trial', 'gain', 'loss', 'difficulty', 'response', 'missed'] + compare_float_2d_arrays(bunch.onsets, [ + [4.071, 11.834, 27.535, 36.435], + [4.071, 11.834, 27.535, 36.435], + [4.071, 11.834, 27.535, 36.435], + [4.071, 11.834, 27.535, 36.435], + [6.459, 14.123, 29.615, 38.723], + [19.535] + ]) + compare_float_2d_arrays(bunch.durations, [ [2.388, 2.289, 2.08, 2.288], [2.388, 2.289, 2.08, 2.288], [2.388, 2.289, 2.08, 2.288], [2.388, 2.289, 2.08, 2.288], [0.0, 0.0, 0.0, 0.0], [0.0] - ] - for reference_array, test_array in zip(reference_durations, information.durations): - assert isclose(reference_array, test_array).all() - - reference_onsets = [ + ]) + compare_float_2d_arrays(bunch.amplitudes, [ + [1.0, 1.0, 1.0, 1.0], + [14.0, 34.0, 10.0, 16.0], + [6.0, 14.0, 15.0, 17.0], + [1.0, 3.0, 10.0, 9.0], + [1.0, 1.0, 1.0, 1.0], + [1.0] + ]) + assert bunch.regressor_names == None + assert bunch.regressors == None + + bunch = info_ok[0] + assert isinstance(bunch, Bunch) + assert bunch.conditions == ['trial', 'gain', 'loss', 'difficulty', 'response'] + compare_float_2d_arrays(bunch.onsets, [ [4.071, 11.834, 27.535, 36.435], [4.071, 11.834, 27.535, 36.435], [4.071, 11.834, 27.535, 36.435], [4.071, 11.834, 27.535, 36.435], - [6.459, 14.123, 29.615, 38.723], - [19.535] - ] - for reference_array, test_array in zip(reference_onsets, information.onsets): - assert isclose(reference_array, test_array).all() + [6.459, 14.123, 29.615, 38.723] + ]) + compare_float_2d_arrays(bunch.durations, [ + [2.388, 2.289, 2.08, 2.288], + [2.388, 2.289, 2.08, 2.288], + [2.388, 2.289, 2.08, 2.288], + [2.388, 2.289, 2.08, 2.288], + [0.0, 0.0, 0.0, 0.0] + ]) + compare_float_2d_arrays(bunch.amplitudes, [ + [1.0, 1.0, 1.0, 1.0], + [14.0, 34.0, 10.0, 16.0], + [6.0, 14.0, 15.0, 17.0], + [1.0, 3.0, 10.0, 9.0], + [1.0, 1.0, 1.0, 1.0] + ]) + assert bunch.regressor_names == None + assert bunch.regressors == None @staticmethod @mark.unit_test diff --git a/tests/test_data/pipelines/events_resp.tsv b/tests/test_data/pipelines/events_resp.tsv new file mode 100644 index 00000000..dd5ea1a5 --- /dev/null +++ b/tests/test_data/pipelines/events_resp.tsv @@ -0,0 +1,5 @@ +onset duration gain loss RT participant_response +4.071 4 14 6 2.388 weakly_accept +11.834 4 34 14 2.289 strongly_accept +27.535 4 10 15 2.08 strongly_reject +36.435 4 16 17 2.288 weakly_reject \ No newline at end of file