Skip to content

Commit

Permalink
finally sorted out nested-workflow input/output propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed May 14, 2024
1 parent 9667474 commit 2f63f18
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 110 deletions.
102 changes: 59 additions & 43 deletions nipype2pydra/statements/workflow_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,24 +177,21 @@ def targets(self):

@property
def wf_in(self):
if self.source_name is None:
try:
self.workflow_converter.get_input_from_conn(self)
except KeyError:
return False
else:
return True
for inpt in self.workflow_converter.inputs.values():
if self.target_name == inpt.node_name and str(self.target_in) == inpt.field:
return True
return False

@property
def wf_out(self):
if self.target_name is None:
try:
self.workflow_converter.get_output_from_conn(self)
except KeyError:
return False
else:
return True
for output in self.workflow_converter.outputs.values():
if (
self.source_name == output.node_name
and str(self.source_out) == output.field
):
return True
return False

@cached_property
def conditional(self):
Expand All @@ -215,29 +212,11 @@ def workflow_variable(self):

@property
def wf_in_name(self):
if not self.wf_in:
raise ValueError(
f"Cannot get wf_in_name for {self} as it is not a workflow input"
)
if self.source_name is None:
return (
self.source_out
if not isinstance(self.source_out, DynamicField)
else self.source_out.varname
)
return self.workflow_converter.get_input(self.target_in, self.target_name).name
return self.workflow_converter.get_input_from_conn(self).name

@property
def wf_out_name(self):
if not self.wf_out:
raise ValueError(
f"Cannot get wf_out_name for {self} as it is not a workflow output"
)
if self.target_name is None:
return self.target_in
return self.workflow_converter.get_output(
self.source_out, self.source_name
).name
return self.workflow_converter.get_output_from_conn(self).name

def __str__(self):
if not self.include:
Expand Down Expand Up @@ -274,7 +253,7 @@ def __str__(self):
# to add an "identity" node to pass it through
intf_name = f"{base_task_name}_identity"
code_str += (
f"{self.indent}@pydra.mark.task\n"
f"\n{self.indent}@pydra.mark.task\n"
f"{self.indent}def {intf_name}({self.wf_in_name}: ty.Any) -> ty.Any:\n"
f"{self.indent} return {self.wf_in_name}\n\n"
f"{self.indent}{self.workflow_variable}.add("
Expand Down Expand Up @@ -669,11 +648,29 @@ def add_input_connection(self, conn: ConnectionStatement):
else:
target_in = conn.target_in
target_name = None
if target_name == self.nested_workflow.input_node:
# Check for replacements for the given target field
replacements = [
i
for i in self.nested_workflow.inputs.values()
if any(n == target_name and f == target_in for n, f in i.replaces)
]
if len(replacements) > 1:
raise ValueError(
f"Multiple inputs found for replacements of '{target_in}' "
f"field in '{target_name}' node in '{self.name}' workflow: "
+ ", ".join(str(m) for m in replacements)
)
elif len(replacements) == 1:
nested_input = replacements[0]
target_name = None
nested_input = self.nested_workflow.get_input(
target_in, node_name=target_name, create=True
)
else:
# If no replacements, create an input for the nested workflow
if target_name == self.nested_workflow.input_node:
target_name = None
nested_input = self.nested_workflow.make_input(
target_in,
node_name=target_name,
)
conn.target_in = nested_input.name
super().add_input_connection(conn)
if target_name:
Expand Down Expand Up @@ -716,11 +713,26 @@ def add_output_connection(self, conn: ConnectionStatement):
else:
source_out = conn.source_out
source_name = None
if source_name == self.nested_workflow.output_node:
replacements = [
o
for o in self.nested_workflow.outputs.values()
if any(n == source_name and f == source_out for n, f in o.replaces)
]
if len(replacements) > 1:
raise KeyError(
f"Multiple outputs found for replacements of '{source_out}' "
f"field in '{source_name}' node in '{self.name}' workflow: "
+ ", ".join(str(m) for m in replacements)
)
elif len(replacements) == 1:
nested_output = replacements[0]
source_name = None
nested_output = self.nested_workflow.get_output(
source_out, node_name=source_name, create=True
)
else:
if source_name == self.nested_workflow.output_node:
source_name = None
nested_output = self.nested_workflow.make_output(
source_out, node_name=source_name
)
conn.source_out = nested_output.name
super().add_output_connection(conn)
if source_name:
Expand Down Expand Up @@ -759,7 +771,7 @@ def __str__(self):
parts = self.attribute.split(".")
nested_node_name = parts[2]
attribute_name = parts[3]
target_in = nested_wf.get_input(attribute_name, nested_node_name).name
target_in = nested_wf.make_input(attribute_name, nested_node_name).name
attribute = ".".join(parts[:2] + [target_in] + parts[4:])
workflow_variable = self.nodes[0].workflow_variable
assert (n.workflow_variable == workflow_variable for n in self.nodes)
Expand All @@ -782,6 +794,10 @@ def matches(cls, stmt, node_names: ty.List[str]) -> bool:
return False
return bool(cls.match_re(node_names).match(stmt))

@property
def conditional(self):
return len(self.indent) != 4

@classmethod
def parse(
cls, statement: str, workflow_converter: "WorkflowConverter"
Expand Down
Loading

0 comments on commit 2f63f18

Please sign in to comment.