From 2f63f18255a53e4bf2c41b4cb4762118dfa2b5ad Mon Sep 17 00:00:00 2001 From: Tom Close Date: Tue, 14 May 2024 19:35:09 +1000 Subject: [PATCH] finally sorted out nested-workflow input/output propagation --- nipype2pydra/statements/workflow_build.py | 102 +++++++------ nipype2pydra/workflow.py | 172 +++++++++++++--------- 2 files changed, 164 insertions(+), 110 deletions(-) diff --git a/nipype2pydra/statements/workflow_build.py b/nipype2pydra/statements/workflow_build.py index 62c1308..79b356c 100644 --- a/nipype2pydra/statements/workflow_build.py +++ b/nipype2pydra/statements/workflow_build.py @@ -177,24 +177,21 @@ def targets(self): @property def wf_in(self): - if self.source_name is None: + try: + self.workflow_converter.get_input_from_conn(self) + except KeyError: + return False + else: return True - for inpt in self.workflow_converter.inputs.values(): - if self.target_name == inpt.node_name and str(self.target_in) == inpt.field: - return True - return False @property def wf_out(self): - if self.target_name is None: + try: + self.workflow_converter.get_output_from_conn(self) + except KeyError: + return False + else: return True - for output in self.workflow_converter.outputs.values(): - if ( - self.source_name == output.node_name - and str(self.source_out) == output.field - ): - return True - return False @cached_property def conditional(self): @@ -215,29 +212,11 @@ def workflow_variable(self): @property def wf_in_name(self): - if not self.wf_in: - raise ValueError( - f"Cannot get wf_in_name for {self} as it is not a workflow input" - ) - if self.source_name is None: - return ( - self.source_out - if not isinstance(self.source_out, DynamicField) - else self.source_out.varname - ) - return self.workflow_converter.get_input(self.target_in, self.target_name).name + return self.workflow_converter.get_input_from_conn(self).name @property def wf_out_name(self): - if not self.wf_out: - raise ValueError( - f"Cannot get wf_out_name for {self} as it is not a workflow output" - ) - if self.target_name is None: - return self.target_in - return self.workflow_converter.get_output( - self.source_out, self.source_name - ).name + return self.workflow_converter.get_output_from_conn(self).name def __str__(self): if not self.include: @@ -274,7 +253,7 @@ def __str__(self): # to add an "identity" node to pass it through intf_name = f"{base_task_name}_identity" code_str += ( - f"{self.indent}@pydra.mark.task\n" + f"\n{self.indent}@pydra.mark.task\n" f"{self.indent}def {intf_name}({self.wf_in_name}: ty.Any) -> ty.Any:\n" f"{self.indent} return {self.wf_in_name}\n\n" f"{self.indent}{self.workflow_variable}.add(" @@ -669,11 +648,29 @@ def add_input_connection(self, conn: ConnectionStatement): else: target_in = conn.target_in target_name = None - if target_name == self.nested_workflow.input_node: + # Check for replacements for the given target field + replacements = [ + i + for i in self.nested_workflow.inputs.values() + if any(n == target_name and f == target_in for n, f in i.replaces) + ] + if len(replacements) > 1: + raise ValueError( + f"Multiple inputs found for replacements of '{target_in}' " + f"field in '{target_name}' node in '{self.name}' workflow: " + + ", ".join(str(m) for m in replacements) + ) + elif len(replacements) == 1: + nested_input = replacements[0] target_name = None - nested_input = self.nested_workflow.get_input( - target_in, node_name=target_name, create=True - ) + else: + # If no replacements, create an input for the nested workflow + if target_name == self.nested_workflow.input_node: + target_name = None + nested_input = self.nested_workflow.make_input( + target_in, + node_name=target_name, + ) conn.target_in = nested_input.name super().add_input_connection(conn) if target_name: @@ -716,11 +713,26 @@ def add_output_connection(self, conn: ConnectionStatement): else: source_out = conn.source_out source_name = None - if source_name == self.nested_workflow.output_node: + replacements = [ + o + for o in self.nested_workflow.outputs.values() + if any(n == source_name and f == source_out for n, f in o.replaces) + ] + if len(replacements) > 1: + raise KeyError( + f"Multiple outputs found for replacements of '{source_out}' " + f"field in '{source_name}' node in '{self.name}' workflow: " + + ", ".join(str(m) for m in replacements) + ) + elif len(replacements) == 1: + nested_output = replacements[0] source_name = None - nested_output = self.nested_workflow.get_output( - source_out, node_name=source_name, create=True - ) + else: + if source_name == self.nested_workflow.output_node: + source_name = None + nested_output = self.nested_workflow.make_output( + source_out, node_name=source_name + ) conn.source_out = nested_output.name super().add_output_connection(conn) if source_name: @@ -759,7 +771,7 @@ def __str__(self): parts = self.attribute.split(".") nested_node_name = parts[2] attribute_name = parts[3] - target_in = nested_wf.get_input(attribute_name, nested_node_name).name + target_in = nested_wf.make_input(attribute_name, nested_node_name).name attribute = ".".join(parts[:2] + [target_in] + parts[4:]) workflow_variable = self.nodes[0].workflow_variable assert (n.workflow_variable == workflow_variable for n in self.nodes) @@ -782,6 +794,10 @@ def matches(cls, stmt, node_names: ty.List[str]) -> bool: return False return bool(cls.match_re(node_names).match(stmt)) + @property + def conditional(self): + return len(self.indent) != 4 + @classmethod def parse( cls, statement: str, workflow_converter: "WorkflowConverter" diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index 181dfdd..29c6efe 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -60,8 +60,7 @@ class WorkflowInterfaceField: "help": "Name of the input/output field in the converted workflow", }, ) - node_name: str = attrs.field( - converter=str, + node_name: ty.Optional[str] = attrs.field( metadata={ "help": "The name of the node that the input/output is connected to", }, @@ -115,7 +114,7 @@ def type_repr_(t): if t in (ty.Any, ty.Union, ty.List, ty.Tuple): return f"ty.{t.__name__}" elif issubclass(t, Field): - return t.primative.__name__ + return t.primitive.__name__ elif issubclass(t, FileSet): return t.__name__ else: @@ -322,40 +321,10 @@ class WorkflowConverter: _unprocessed_connections: ty.List[ConnectionStatement] = attrs.field( factory=list, repr=False ) - _input_mapping: ty.Dict[str, WorkflowInput] = attrs.field( - factory=dict, - init=False, - repr=False, - metadata={ - "help": ( - "The mapping of node and field names to the inputs they are connected to" - ), - }, - ) - _output_mapping: ty.Dict[str, WorkflowOutput] = attrs.field( - factory=dict, - init=False, - repr=False, - metadata={ - "help": ( - "The mapping of node and field names to the inputs they are connected to" - ), - }, - ) def __attrs_post_init__(self): if self.workflow_variable is None: self.workflow_variable = self.workflow_variable_default() - for inpt in self.inputs.values(): - self._input_mapping[(inpt.node_name, inpt.field)] = inpt - self._input_mapping.update( - {(node_name, field): inpt for node_name, field in inpt.replaces} - ) - for outpt in self.outputs.values(): - self._output_mapping[(outpt.node_name, outpt.field)] = outpt - self._output_mapping.update( - {(node_name, field): outpt for node_name, field in outpt.replaces} - ) @nipype_module.validator def _nipype_module_validator(self, _, value): @@ -408,55 +377,120 @@ def exported_inputs(self): def exported_outputs(self): return (o for o in self.outputs.values() if o.export) - def get_input( - self, field_name: str, node_name: ty.Optional[str] = None, create: bool = False + def get_input_from_conn(self, conn: ConnectionStatement) -> WorkflowInput: + """ + Returns the name of the input field in the workflow for the given node and field + escaped by the prefix of the node if present""" + if conn.source_name is None or conn.source_name == self.input_node: + return self.make_input(field_name=conn.source_out) + elif conn.target_name is None: + raise KeyError( + f"Could not find output corresponding to '{conn.source_out}' input" + ) + return self.make_input( + field_name=conn.target_in, node_name=conn.target_name, input_node_only=True + ) + + def get_output_from_conn(self, conn: ConnectionStatement) -> WorkflowOutput: + """ + Returns the name of the input field in the workflow for the given node and field + escaped by the prefix of the node if present""" + if conn.target_name is None or conn.target_name == self.output_node: + return self.make_output(field_name=conn.target_in) + elif conn.source_name is None: + raise KeyError( + f"Could not find output corresponding to '{conn.source_out}' input" + ) + return self.make_output( + field_name=conn.source_out, + node_name=conn.source_name, + output_node_only=True, + ) + + def make_input( + self, + field_name: str, + node_name: ty.Optional[str] = None, + input_node_only: bool = False, ) -> WorkflowInput: """ Returns the name of the input field in the workflow for the given node and field escaped by the prefix of the node if present""" field_name = str(field_name) - try: - return self._input_mapping[(node_name, field_name)] - except KeyError: + matching = [ + i + for i in self.inputs.values() + if i.node_name == node_name and i.field == field_name + ] + if len(matching) > 1: + raise KeyError( + f"Multiple inputs found for '{field_name}' field in " + f"'{node_name}' node in '{self.name}' workflow" + ) + elif len(matching) == 1: + return matching[0] + else: if node_name is None or node_name == self.input_node: inpt_name = field_name - elif create: - inpt_name = f"{node_name}_{field_name}" - else: + elif input_node_only: raise KeyError( - f"Unrecognised output corresponding to {node_name}:{field_name} field, " - "set create=True to auto-create" + f"Could not find input corresponding to '{field_name}' field in " + f"'{node_name}' node in '{self.name}' workflow, set " + "`only_input_node=False` to make an input for any node input" + ) from None + else: + inpt_name = f"{node_name}_{field_name}" + try: + return self.inputs[inpt_name] + except KeyError: + inpt = WorkflowInput( + name=inpt_name, field=field_name, node_name=node_name ) - inpt = WorkflowInput(name=inpt_name, field=field_name, node_name=node_name) - self.inputs[inpt_name] = self._input_mapping[(node_name, field_name)] = inpt - return inpt + self.inputs[inpt_name] = inpt + return inpt - def get_output( - self, field_name: str, node_name: ty.Optional[str] = None, create: bool = False + def make_output( + self, + field_name: str, + node_name: ty.Optional[str] = None, + output_node_only: bool = False, ) -> WorkflowOutput: """ Returns the name of the input field in the workflow for the given node and field escaped by the prefix of the node if present""" field_name = str(field_name) - try: - return self._output_mapping[(node_name, field_name)] - except KeyError: + matching = [ + o + for o in self.outputs.values() + if o.node_name == node_name and o.field == field_name + ] + if len(matching) > 1: + raise KeyError( + f"Multiple outputs found for '{field_name}' field in " + f"'{node_name}' node in '{self.name}' workflow: " + + ", ".join(str(m) for m in matching) + ) + elif len(matching) == 1: + return matching[0] + else: if node_name is None or node_name == self.output_node: outpt_name = field_name - elif create: - outpt_name = f"{node_name}_{field_name}" - else: + elif output_node_only: raise KeyError( - f"Unrecognised output corresponding to {node_name}:{field_name} field, " - "set create=True to auto-create" + f"Could not find output corresponding to '{field_name}' field in " + f"'{node_name}' node in '{self.name}' workflow, set " + "`only_output_node=False` to make an output for any node output" + ) from None + else: + outpt_name = f"{node_name}_{field_name}" + try: + return self.outputs[outpt_name] + except KeyError: + outpt = WorkflowOutput( + name=outpt_name, field=field_name, node_name=node_name ) - outpt = WorkflowOutput( - name=outpt_name, field=field_name, node_name=node_name - ) - self.outputs[outpt_name] = self._output_mapping[(node_name, field_name)] = ( - outpt - ) - return outpt + self.outputs[outpt_name] = outpt + return outpt def add_connection_to_input(self, in_conn: ConnectionStatement): """Add a in_connection to an input of the workflow, adding the input if not present""" @@ -933,11 +967,11 @@ def prepare_connections(self): for node in nodes: if isinstance(node, AddNestedWorkflowStatement): exported_inputs.update( - (i.name, self.get_input(i.name, node_name, create=True)) + (i.name, self.make_input(i.name, node_name)) for i in node.nested_workflow.exported_inputs ) exported_outputs.update( - (o.name, self.get_output(o.name, node_name, create=True)) + (o.name, self.make_output(o.name, node_name)) for o in node.nested_workflow.exported_outputs ) for inpt_name, exp_inpt in exported_inputs: @@ -968,18 +1002,22 @@ def prepare_connections(self): while self._unprocessed_connections: conn = self._unprocessed_connections.pop() try: - inpt = self.get_input(conn.source_out, node_name=conn.source_name) + inpt = self.get_input_from_conn(conn) except KeyError: for src_node in self.nodes[conn.source_name]: src_node.add_output_connection(conn) else: + conn.source_name = None + conn.source_out = inpt.name inpt.out_conns.append(conn) try: - outpt = self.get_output(conn.target_in, node_name=conn.target_name) + outpt = self.get_output_from_conn(conn) except KeyError: for tgt_node in self.nodes[conn.target_name]: tgt_node.add_input_connection(conn) else: + conn.target_name = None + conn.target_in = outpt.name outpt.in_conns.append(conn) def _parse_statements(self, func_body: str) -> ty.Tuple[