diff --git a/nipype2pydra/statements/workflow_build.py b/nipype2pydra/statements/workflow_build.py index 9dfce6bd..62c13080 100644 --- a/nipype2pydra/statements/workflow_build.py +++ b/nipype2pydra/statements/workflow_build.py @@ -177,17 +177,24 @@ def targets(self): @property def wf_in(self): - return self.source_name is None or ( - (self.target_name, str(self.target_in)) - in self.workflow_converter._input_mapping - ) + if self.source_name is None: + return True + for inpt in self.workflow_converter.inputs.values(): + if self.target_name == inpt.node_name and str(self.target_in) == inpt.field: + return True + return False @property def wf_out(self): - return self.target_name is None or ( - (self.source_name, str(self.source_out)) - in self.workflow_converter._output_mapping - ) + if self.target_name is None: + return True + for output in self.workflow_converter.outputs.values(): + if ( + self.source_name == output.node_name + and str(self.source_out) == output.field + ): + return True + return False @cached_property def conditional(self): @@ -212,12 +219,13 @@ def wf_in_name(self): raise ValueError( f"Cannot get wf_in_name for {self} as it is not a workflow input" ) - # source_out_name = ( - # self.source_out - # if not isinstance(self.source_out, DynamicField) - # else self.source_out.varname - # ) - return self.workflow_converter.get_input(self.source_out, self.source_name).name + if self.source_name is None: + return ( + self.source_out + if not isinstance(self.source_out, DynamicField) + else self.source_out.varname + ) + return self.workflow_converter.get_input(self.target_in, self.target_name).name @property def wf_out_name(self): @@ -225,11 +233,15 @@ def wf_out_name(self): raise ValueError( f"Cannot get wf_out_name for {self} as it is not a workflow output" ) - return self.workflow_converter.get_output(self.target_in, self.target_name).name + if self.target_name is None: + return self.target_in + return self.workflow_converter.get_output( + self.source_out, self.source_name + ).name def __str__(self): if not self.include: - return f"{self.indent}pass\n" if self.conditional else "" + return f"{self.indent}pass" if self.conditional else "" code_str = "" # Get source lazy-field if self.wf_in: @@ -450,7 +462,7 @@ def converted_interface(self): def __str__(self): if not self.include: - return f"{self.indent}pass\n" if self.conditional else "" + return f"{self.indent}pass" if self.conditional else "" args = ["=".join(a) for a in self.arg_name_vals] conn_args = [] for conn in sorted(self.in_conns, key=attrgetter("target_in")): @@ -580,7 +592,7 @@ class AddNestedWorkflowStatement(AddNodeStatement): def __str__(self): if not self.include: - return f"{self.indent}pass\n" if self.conditional else "" + return f"{self.indent}pass" if self.conditional else "" if self.nested_workflow: config_params = [ f"{n}_{c}={n}_{c}" for n, c in self.nested_workflow.used_configs @@ -659,7 +671,9 @@ def add_input_connection(self, conn: ConnectionStatement): target_name = None if target_name == self.nested_workflow.input_node: target_name = None - nested_input = self.nested_workflow.get_input(target_in, node_name=target_name) + nested_input = self.nested_workflow.get_input( + target_in, node_name=target_name, create=True + ) conn.target_in = nested_input.name super().add_input_connection(conn) if target_name: @@ -705,7 +719,7 @@ def add_output_connection(self, conn: ConnectionStatement): if source_name == self.nested_workflow.output_node: source_name = None nested_output = self.nested_workflow.get_output( - source_out, node_name=source_name + source_out, node_name=source_name, create=True ) conn.source_out = nested_output.name super().add_output_connection(conn) @@ -736,7 +750,7 @@ class NodeAssignmentStatement: def __str__(self): if not any(n.include for n in self.nodes): - return "" + return f"{self.indent}pass" if self.conditional else "" node = self.nodes[0] node_name = node.name workflow_variable = self.nodes[0].workflow_variable diff --git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index d65edfad..56254d3b 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -463,7 +463,7 @@ def from_named_dicts_converter( allow_none=False, ) -> ty.Dict[str, T]: converted = {} - for name, conv in dct.items() or []: + for name, conv in (dct or {}).items(): if isinstance(conv, dict): conv = klass(name=name, **conv) converted[name] = conv diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index fed02f9b..181dfdde 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -12,7 +12,7 @@ import black.report import attrs import yaml -from fileformats.core import from_mime, FileSet +from fileformats.core import from_mime, FileSet, Field from .utils import ( UsedSymbols, split_source_into_statements, @@ -114,6 +114,8 @@ def type_repr_(t): ) if t in (ty.Any, ty.Union, ty.List, ty.Tuple): return f"ty.{t.__name__}" + elif issubclass(t, Field): + return t.primative.__name__ elif issubclass(t, FileSet): return t.__name__ else: @@ -407,7 +409,7 @@ def exported_outputs(self): return (o for o in self.outputs.values() if o.export) def get_input( - self, field_name: str, node_name: ty.Optional[str] = None + self, field_name: str, node_name: ty.Optional[str] = None, create: bool = False ) -> WorkflowInput: """ Returns the name of the input field in the workflow for the given node and field @@ -416,17 +418,21 @@ def get_input( try: return self._input_mapping[(node_name, field_name)] except KeyError: - inpt_name = ( - field_name - if node_name is None or node_name == self.input_node - else f"{node_name}_{field_name}" - ) + if node_name is None or node_name == self.input_node: + inpt_name = field_name + elif create: + inpt_name = f"{node_name}_{field_name}" + else: + raise KeyError( + f"Unrecognised output corresponding to {node_name}:{field_name} field, " + "set create=True to auto-create" + ) inpt = WorkflowInput(name=inpt_name, field=field_name, node_name=node_name) self.inputs[inpt_name] = self._input_mapping[(node_name, field_name)] = inpt return inpt def get_output( - self, field_name: str, node_name: ty.Optional[str] = None + self, field_name: str, node_name: ty.Optional[str] = None, create: bool = False ) -> WorkflowOutput: """ Returns the name of the input field in the workflow for the given node and field @@ -435,11 +441,15 @@ def get_output( try: return self._output_mapping[(node_name, field_name)] except KeyError: - outpt_name = ( - field_name - if node_name is None or node_name == self.input_node - else f"{node_name}_{field_name}" - ) + if node_name is None or node_name == self.output_node: + outpt_name = field_name + elif create: + outpt_name = f"{node_name}_{field_name}" + else: + raise KeyError( + f"Unrecognised output corresponding to {node_name}:{field_name} field, " + "set create=True to auto-create" + ) outpt = WorkflowOutput( name=outpt_name, field=field_name, node_name=node_name ) @@ -923,11 +933,11 @@ def prepare_connections(self): for node in nodes: if isinstance(node, AddNestedWorkflowStatement): exported_inputs.update( - (i.name, self.get_input(i.name, node_name)) + (i.name, self.get_input(i.name, node_name, create=True)) for i in node.nested_workflow.exported_inputs ) exported_outputs.update( - (o.name, self.get_output(o.name, node_name)) + (o.name, self.get_output(o.name, node_name, create=True)) for o in node.nested_workflow.exported_outputs ) for inpt_name, exp_inpt in exported_inputs: @@ -957,16 +967,20 @@ def prepare_connections(self): self.parsed_statements.append(conn_stmt) while self._unprocessed_connections: conn = self._unprocessed_connections.pop() - if conn.wf_in: - self.get_input(conn.source_out).out_conns.append(conn) - else: + try: + inpt = self.get_input(conn.source_out, node_name=conn.source_name) + except KeyError: for src_node in self.nodes[conn.source_name]: src_node.add_output_connection(conn) - if conn.wf_out: - self.get_output(conn.target_in).in_conns.append(conn) else: + inpt.out_conns.append(conn) + try: + outpt = self.get_output(conn.target_in, node_name=conn.target_name) + except KeyError: for tgt_node in self.nodes[conn.target_name]: tgt_node.add_input_connection(conn) + else: + outpt.in_conns.append(conn) def _parse_statements(self, func_body: str) -> ty.Tuple[ ty.List[