From f0393bcab2274ca31262e82e3fa2e53634bb2ade Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Boris=20Cl=C3=A9net?= Date: Mon, 20 Nov 2023 10:14:58 +0100 Subject: [PATCH] [TEST] for the core functions + adding get_group to narps_open.data.participants --- narps_open/core/common.py | 89 ++---- narps_open/data/participants.py | 8 + tests/core/test_common.py | 260 +++++++++++++++++- tests/data/test_participants.py | 51 +++- .../data/participants/participants.tsv | 5 + 5 files changed, 341 insertions(+), 72 deletions(-) create mode 100644 tests/test_data/data/participants/participants.tsv diff --git a/narps_open/core/common.py b/narps_open/core/common.py index f364d9a8..e40d4e9a 100644 --- a/narps_open/core/common.py +++ b/narps_open/core/common.py @@ -3,7 +3,7 @@ """ Common functions to write pipelines """ -def remove_file(_, file_name): +def remove_file(_, file_name: str) -> None: """ Fully remove files generated by a Node, once they aren't needed anymore. This function is meant to be used in a Nipype Function Node. @@ -20,75 +20,46 @@ def remove_file(_, file_name): except OSError as error: print(error) -def file_in_group(file: str, group: list) -> list: +def elements_in_string(input_str: str, elements: list) -> str: #| None: """ - Return the name of the file if it contains one element of the group list. + Return input_str if it contains one element of the elements list. Return None otherwise. This function is meant to be used in a Nipype Function Node. Parameters: - - file_name: str, a single filename of the file analyse - - group: list of str, elements to be searched in file_name + - input_str: str + - elements: list of str, elements to be searched in input_str """ - if [e in file for e in group] + if any(e in input_str for e in elements): + return input_str + return None +def clean_list(input_list: list, element = None) -> list: + """ + Remove elements of input_list that are equal to element and return the resultant list. + This function is meant to be used in a Nipype Function Node. It can be used inside a + nipype.Workflow.connect call as well. + Parameters: + - input_list: list + - element: any -def get_subgroups_contrasts(copes, varcopes, subject_list: list, participants_file: str): + Returns: + - input_list with elements equal to element removed """ - Return the file list containing only the files belonging to subject in the wanted group. -contrast of parameter estimate - Parameters : - - copes: original file list selected by select_files node - - varcopes: original file list selected by select_files node - - subject_list: list of subject IDs that are analyzed - - participants_file: file containing participants characteristics + return [f for f in input_list if f != element] - Returns : - - copes_equal_indifference : a subset of copes corresponding to subjects - in the equalIndifference group - - copes_equal_range : a subset of copes corresponding to subjects - in the equalRange group - - varcopes_equal_indifference : a subset of varcopes corresponding to subjects - in the equalIndifference group - - varcopes_equal_range : a subset of varcopes corresponding to subjects - in the equalRange group - - equal_indifference_ids : a list of subject ids in the equalIndifference group - - equal_range_ids : a list of subject ids in the equalRange group +def list_intersection(list_1: list, list_2: list) -> list: """ + Returns the intersection of two lists. + This function is meant to be used in a Nipype Function Node. It can be used inside a + nipype.Workflow.connect call as well. - subject_list_sub_ids = [] # ids as written in the participants file - equal_range_ids = [] # ids as 3-digit string - equal_indifference_ids = [] # ids as 3-digit string - equal_range_sub_ids = [] # ids as written in the participants file - equal_indifference_sub_ids = [] # ids as written in the participants file - - # Reading file containing participants IDs and groups - with open(participants_file, 'rt') as file: - next(file) # skip the header - - for line in file: - info = line.strip().split() - subject_id = info[0][-3:] - subject_group = info[1] - - # Check if the participant ID was selected and sort depending on group - if subject_id in subject_list: - subject_list_sub_ids.append(info[0]) - if subject_group == 'equalIndifference': - equal_indifference_ids.append(subject_id) - equal_indifference_sub_ids.append(info[0]) - elif subject_group == 'equalRange': - equal_range_ids.append(subject_id) - equal_range_sub_ids.append(info[0]) - + Parameters: + - list_1: list + - list_2: list - # Return sorted selected copes and varcopes by group, and corresponding ids - return \ - [c for c in copes if any(i in c for i in equal_indifference_sub_ids)],\ - [c for c in copes if any(i in c for i in equal_range_sub_ids)],\ - [c for c in copes if any(i in c for i in subject_list_sub_ids)],\ - [v for v in varcopes if any(i in v for i in equal_indifference_sub_ids)],\ - [v for v in varcopes if any(i in v for i in equal_range_sub_ids)],\ - [v for v in varcopes if any(i in v for i in subject_list_sub_ids)],\ - equal_indifference_ids, equal_range_ids + Returns: + - list, the intersection of list_1 and list_2 + """ + return [e for e in list_1 if e in list_2] diff --git a/narps_open/data/participants.py b/narps_open/data/participants.py index a9cc65a5..835e834f 100644 --- a/narps_open/data/participants.py +++ b/narps_open/data/participants.py @@ -49,3 +49,11 @@ def get_participants(team_id: str) -> list: def get_participants_subset(nb_participants: int = 108) -> list: """ Return a list of participants of length nb_participants """ return get_all_participants()[0:nb_participants] + +def get_group(group_name: str) -> list: + """ Return a list containing all the participants inside the group_name group + + Warning : the subject ids are return as written in the participants file (i.e.: 'sub-*') + """ + participants = get_participants_information() + return participants.loc[participants['group'] == group_name]['participant_id'].values.tolist() diff --git a/tests/core/test_common.py b/tests/core/test_common.py index a86add9f..0d50c05b 100644 --- a/tests/core/test_common.py +++ b/tests/core/test_common.py @@ -16,7 +16,7 @@ from pathlib import Path from pytest import mark, fixture -from nipype import Node, Function +from nipype import Node, Function, Workflow from narps_open.utils.configuration import Configuration import narps_open.core.common as co @@ -59,3 +59,261 @@ def test_remove_file(remove_test_dir): # Check file is removed assert not exists(test_file_path) + + @staticmethod + @mark.unit_test + def test_node_elements_in_string(): + """ Test the elements_in_string function as a nipype.Node """ + + # Inputs + string = 'test_string' + elements_false = ['z', 'u', 'warning'] + elements_true = ['z', 'u', 'warning', '_'] + + # Create a Nipype Node using elements_in_string + test_node = Node(Function( + function = co.elements_in_string, + input_names = ['input_str', 'elements'], + output_names = ['output'] + ), name = 'test_node') + test_node.inputs.input_str = string + test_node.inputs.elements = elements_true + out = test_node.run().outputs.output + + # Check return value + assert out == string + + # Change input and check return value + test_node = Node(Function( + function = co.elements_in_string, + input_names = ['input_str', 'elements'], + output_names = ['output'] + ), name = 'test_node') + test_node.inputs.input_str = string + test_node.inputs.elements = elements_false + out = test_node.run().outputs.output + assert out is None + + @staticmethod + @mark.unit_test + def test_connect_elements_in_string(remove_test_dir): + """ Test the elements_in_string function as evaluated in a connect """ + + # Inputs + string = 'test_string' + elements_false = ['z', 'u', 'warning'] + elements_true = ['z', 'u', 'warning', '_'] + function = lambda in_value: in_value + + # Create Nodes + node_1 = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_1') + node_1.inputs.in_value = string + node_true = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_true') + node_false = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_false') + + # Create Workflow + test_workflow = Workflow( + base_dir = TEMPORARY_DIR, + name = 'test_workflow' + ) + test_workflow.connect([ + # elements_in_string is evaluated as part of the connection + (node_1, node_true, [( + ('out_value', co.elements_in_string, elements_true), 'in_value')]), + (node_1, node_false, [( + ('out_value', co.elements_in_string, elements_false), 'in_value')]) + ]) + + test_workflow.run() + + test_file_t = join(TEMPORARY_DIR, 'test_workflow', 'node_true', '_report', 'report.rst') + with open(test_file_t, 'r', encoding = 'utf-8') as file: + assert '* out_value : test_string' in file.read() + + test_file_f = join(TEMPORARY_DIR, 'test_workflow', 'node_false', '_report', 'report.rst') + with open(test_file_f, 'r', encoding = 'utf-8') as file: + assert '* out_value : None' in file.read() + + @staticmethod + @mark.unit_test + def test_node_clean_list(): + """ Test the clean_list function as a nipype.Node """ + + # Inputs + input_list = ['z', '_', 'u', 'warning', '_', None] + element_to_remove_1 = '_' + output_list_1 = ['z', 'u', 'warning', None] + element_to_remove_2 = None + output_list_2 = ['z', '_', 'u', 'warning', '_'] + + # Create a Nipype Node using clean_list + test_node = Node(Function( + function = co.clean_list, + input_names = ['input_list', 'element'], + output_names = ['output'] + ), name = 'test_node') + test_node.inputs.input_list = input_list + test_node.inputs.element = element_to_remove_1 + + # Check return value + assert test_node.run().outputs.output == output_list_1 + + # Change input and check return value + test_node = Node(Function( + function = co.clean_list, + input_names = ['input_list', 'element'], + output_names = ['output'] + ), name = 'test_node') + test_node.inputs.input_list = input_list + test_node.inputs.element = element_to_remove_2 + + assert test_node.run().outputs.output == output_list_2 + + @staticmethod + @mark.unit_test + def test_connect_clean_list(remove_test_dir): + """ Test the clean_list function as evaluated in a connect """ + + # Inputs + input_list = ['z', '_', 'u', 'warning', '_', None] + element_to_remove_1 = '_' + output_list_1 = ['z', 'u', 'warning', None] + element_to_remove_2 = None + output_list_2 = ['z', '_', 'u', 'warning', '_'] + function = lambda in_value: in_value + + # Create Nodes + node_0 = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_0') + node_0.inputs.in_value = input_list + node_1 = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_1') + node_2 = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_2') + + # Create Workflow + test_workflow = Workflow( + base_dir = TEMPORARY_DIR, + name = 'test_workflow' + ) + test_workflow.connect([ + # elements_in_string is evaluated as part of the connection + (node_0, node_1, [(('out_value', co.clean_list, element_to_remove_1), 'in_value')]), + (node_0, node_2, [(('out_value', co.clean_list, element_to_remove_2), 'in_value')]) + ]) + test_workflow.run() + + test_file_1 = join(TEMPORARY_DIR, 'test_workflow', 'node_1', '_report', 'report.rst') + with open(test_file_1, 'r', encoding = 'utf-8') as file: + assert f'* out_value : {output_list_1}' in file.read() + + test_file_2 = join(TEMPORARY_DIR, 'test_workflow', 'node_2', '_report', 'report.rst') + with open(test_file_2, 'r', encoding = 'utf-8') as file: + assert f'* out_value : {output_list_2}' in file.read() + + @staticmethod + @mark.unit_test + def test_node_list_intersection(): + """ Test the list_intersection function as a nipype.Node """ + + # Inputs / ouptuts + input_list_1 = ['001', '002', '003', '004'] + input_list_2 = ['002', '004'] + input_list_3 = ['001', '003', '005'] + output_list_1 = ['002', '004'] + output_list_2 = ['001', '003'] + + # Create a Nipype Node using list_intersection + test_node = Node(Function( + function = co.list_intersection, + input_names = ['list_1', 'list_2'], + output_names = ['output'] + ), name = 'test_node') + test_node.inputs.list_1 = input_list_1 + test_node.inputs.list_2 = input_list_2 + + # Check return value + assert test_node.run().outputs.output == output_list_1 + + # Change input and check return value + test_node = Node(Function( + function = co.list_intersection, + input_names = ['list_1', 'list_2'], + output_names = ['output'] + ), name = 'test_node') + test_node.inputs.list_1 = input_list_1 + test_node.inputs.list_2 = input_list_3 + + assert test_node.run().outputs.output == output_list_2 + + @staticmethod + @mark.unit_test + def test_connect_list_intersection(remove_test_dir): + """ Test the list_intersection function as evaluated in a connect """ + + # Inputs / ouptuts + input_list_1 = ['001', '002', '003', '004'] + input_list_2 = ['002', '004'] + input_list_3 = ['001', '003', '005'] + output_list_1 = ['002', '004'] + output_list_2 = ['001', '003'] + function = lambda in_value: in_value + + # Create Nodes + node_0 = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_0') + node_0.inputs.in_value = input_list_1 + node_1 = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_1') + node_2 = Node(Function( + function = function, + input_names = ['in_value'], + output_names = ['out_value'] + ), name = 'node_2') + + # Create Workflow + test_workflow = Workflow( + base_dir = TEMPORARY_DIR, + name = 'test_workflow' + ) + test_workflow.connect([ + # elements_in_string is evaluated as part of the connection + (node_0, node_1, [(('out_value', co.list_intersection, input_list_2), 'in_value')]), + (node_0, node_2, [(('out_value', co.list_intersection, input_list_3), 'in_value')]) + ]) + test_workflow.run() + + test_file_1 = join(TEMPORARY_DIR, 'test_workflow', 'node_1', '_report', 'report.rst') + with open(test_file_1, 'r', encoding = 'utf-8') as file: + assert f'* out_value : {output_list_1}' in file.read() + + test_file_2 = join(TEMPORARY_DIR, 'test_workflow', 'node_2', '_report', 'report.rst') + with open(test_file_2, 'r', encoding = 'utf-8') as file: + assert f'* out_value : {output_list_2}' in file.read() diff --git a/tests/data/test_participants.py b/tests/data/test_participants.py index d30cd23e..f36f0a05 100644 --- a/tests/data/test_participants.py +++ b/tests/data/test_participants.py @@ -10,28 +10,46 @@ pytest -q test_participants.py pytest -q test_participants.py -k """ +from os.path import join -from pytest import mark +from pytest import mark, fixture import narps_open.data.participants as part +from narps_open.utils.configuration import Configuration + +@fixture +def mock_participants_data(mocker): + """ A fixture to provide mocked data from the test_data directory """ + + mocker.patch( + 'narps_open.data.participants.Configuration', + return_value = { + 'directories': { + 'dataset': join( + Configuration()['directories']['test_data'], + 'data', 'participants') + } + } + ) class TestParticipants: """ A class that contains all the unit tests for the participants module.""" @staticmethod @mark.unit_test - def test_get_participants_information(): + def test_get_participants_information(mock_participants_data): """ Test the get_participants_information function """ - """p_info = part.get_participants_information() - assert len(p_info) == 108 - assert p_info.at[5, 'participant_id'] == 'sub-006' - assert p_info.at[5, 'group'] == 'equalRange' - assert p_info.at[5, 'gender'] == 'M' - assert p_info.at[5, 'age'] == 30 - assert p_info.at[12, 'participant_id'] == 'sub-015' - assert p_info.at[12, 'group'] == 'equalIndifference' - assert p_info.at[12, 'gender'] == 'F' - assert p_info.at[12, 'age'] == 26""" + + p_info = part.get_participants_information() + assert len(p_info) == 4 + assert p_info.at[1, 'participant_id'] == 'sub-002' + assert p_info.at[1, 'group'] == 'equalRange' + assert p_info.at[1, 'gender'] == 'M' + assert p_info.at[1, 'age'] == 25 + assert p_info.at[2, 'participant_id'] == 'sub-003' + assert p_info.at[2, 'group'] == 'equalIndifference' + assert p_info.at[2, 'gender'] == 'F' + assert p_info.at[2, 'age'] == 27 @staticmethod @mark.unit_test @@ -87,3 +105,12 @@ def test_get_participants_subset(): assert len(participants_list) == 80 assert participants_list[0] == '020' assert participants_list[-1] == '003' + + @staticmethod + @mark.unit_test + def test_get_group(mock_participants_data): + """ Test the get_group function """ + + assert part.get_group('') == [] + assert part.get_group('equalRange') == ['sub-002', 'sub-004'] + assert part.get_group('equalIndifference') == ['sub-001', 'sub-003'] diff --git a/tests/test_data/data/participants/participants.tsv b/tests/test_data/data/participants/participants.tsv new file mode 100644 index 00000000..312dbcde --- /dev/null +++ b/tests/test_data/data/participants/participants.tsv @@ -0,0 +1,5 @@ +participant_id group gender age +sub-001 equalIndifference M 24 +sub-002 equalRange M 25 +sub-003 equalIndifference F 27 +sub-004 equalRange M 25 \ No newline at end of file