Skip to content

Commit

Permalink
debugging input/output mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed May 13, 2024
1 parent 28d185b commit 9667474
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 42 deletions.
56 changes: 35 additions & 21 deletions nipype2pydra/statements/workflow_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,17 +177,24 @@ def targets(self):

@property
def wf_in(self):
return self.source_name is None or (
(self.target_name, str(self.target_in))
in self.workflow_converter._input_mapping
)
if self.source_name is None:
return True
for inpt in self.workflow_converter.inputs.values():
if self.target_name == inpt.node_name and str(self.target_in) == inpt.field:
return True
return False

@property
def wf_out(self):
return self.target_name is None or (
(self.source_name, str(self.source_out))
in self.workflow_converter._output_mapping
)
if self.target_name is None:
return True
for output in self.workflow_converter.outputs.values():
if (
self.source_name == output.node_name
and str(self.source_out) == output.field
):
return True
return False

@cached_property
def conditional(self):
Expand All @@ -212,24 +219,29 @@ def wf_in_name(self):
raise ValueError(
f"Cannot get wf_in_name for {self} as it is not a workflow input"
)
# source_out_name = (
# self.source_out
# if not isinstance(self.source_out, DynamicField)
# else self.source_out.varname
# )
return self.workflow_converter.get_input(self.source_out, self.source_name).name
if self.source_name is None:
return (
self.source_out
if not isinstance(self.source_out, DynamicField)
else self.source_out.varname
)
return self.workflow_converter.get_input(self.target_in, self.target_name).name

@property
def wf_out_name(self):
if not self.wf_out:
raise ValueError(
f"Cannot get wf_out_name for {self} as it is not a workflow output"
)
return self.workflow_converter.get_output(self.target_in, self.target_name).name
if self.target_name is None:
return self.target_in
return self.workflow_converter.get_output(
self.source_out, self.source_name
).name

def __str__(self):
if not self.include:
return f"{self.indent}pass\n" if self.conditional else ""
return f"{self.indent}pass" if self.conditional else ""
code_str = ""
# Get source lazy-field
if self.wf_in:
Expand Down Expand Up @@ -450,7 +462,7 @@ def converted_interface(self):

def __str__(self):
if not self.include:
return f"{self.indent}pass\n" if self.conditional else ""
return f"{self.indent}pass" if self.conditional else ""
args = ["=".join(a) for a in self.arg_name_vals]
conn_args = []
for conn in sorted(self.in_conns, key=attrgetter("target_in")):
Expand Down Expand Up @@ -580,7 +592,7 @@ class AddNestedWorkflowStatement(AddNodeStatement):

def __str__(self):
if not self.include:
return f"{self.indent}pass\n" if self.conditional else ""
return f"{self.indent}pass" if self.conditional else ""
if self.nested_workflow:
config_params = [
f"{n}_{c}={n}_{c}" for n, c in self.nested_workflow.used_configs
Expand Down Expand Up @@ -659,7 +671,9 @@ def add_input_connection(self, conn: ConnectionStatement):
target_name = None
if target_name == self.nested_workflow.input_node:
target_name = None
nested_input = self.nested_workflow.get_input(target_in, node_name=target_name)
nested_input = self.nested_workflow.get_input(
target_in, node_name=target_name, create=True
)
conn.target_in = nested_input.name
super().add_input_connection(conn)
if target_name:
Expand Down Expand Up @@ -705,7 +719,7 @@ def add_output_connection(self, conn: ConnectionStatement):
if source_name == self.nested_workflow.output_node:
source_name = None
nested_output = self.nested_workflow.get_output(
source_out, node_name=source_name
source_out, node_name=source_name, create=True
)
conn.source_out = nested_output.name
super().add_output_connection(conn)
Expand Down Expand Up @@ -736,7 +750,7 @@ class NodeAssignmentStatement:

def __str__(self):
if not any(n.include for n in self.nodes):
return ""
return f"{self.indent}pass" if self.conditional else ""
node = self.nodes[0]
node_name = node.name
workflow_variable = self.nodes[0].workflow_variable
Expand Down
2 changes: 1 addition & 1 deletion nipype2pydra/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def from_named_dicts_converter(
allow_none=False,
) -> ty.Dict[str, T]:
converted = {}
for name, conv in dct.items() or []:
for name, conv in (dct or {}).items():
if isinstance(conv, dict):
conv = klass(name=name, **conv)
converted[name] = conv
Expand Down
54 changes: 34 additions & 20 deletions nipype2pydra/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import black.report
import attrs
import yaml
from fileformats.core import from_mime, FileSet
from fileformats.core import from_mime, FileSet, Field
from .utils import (
UsedSymbols,
split_source_into_statements,
Expand Down Expand Up @@ -114,6 +114,8 @@ def type_repr_(t):
)
if t in (ty.Any, ty.Union, ty.List, ty.Tuple):
return f"ty.{t.__name__}"
elif issubclass(t, Field):
return t.primative.__name__
elif issubclass(t, FileSet):
return t.__name__
else:
Expand Down Expand Up @@ -407,7 +409,7 @@ def exported_outputs(self):
return (o for o in self.outputs.values() if o.export)

def get_input(
self, field_name: str, node_name: ty.Optional[str] = None
self, field_name: str, node_name: ty.Optional[str] = None, create: bool = False
) -> WorkflowInput:
"""
Returns the name of the input field in the workflow for the given node and field
Expand All @@ -416,17 +418,21 @@ def get_input(
try:
return self._input_mapping[(node_name, field_name)]
except KeyError:
inpt_name = (
field_name
if node_name is None or node_name == self.input_node
else f"{node_name}_{field_name}"
)
if node_name is None or node_name == self.input_node:
inpt_name = field_name
elif create:
inpt_name = f"{node_name}_{field_name}"
else:
raise KeyError(
f"Unrecognised output corresponding to {node_name}:{field_name} field, "
"set create=True to auto-create"
)
inpt = WorkflowInput(name=inpt_name, field=field_name, node_name=node_name)
self.inputs[inpt_name] = self._input_mapping[(node_name, field_name)] = inpt
return inpt

def get_output(
self, field_name: str, node_name: ty.Optional[str] = None
self, field_name: str, node_name: ty.Optional[str] = None, create: bool = False
) -> WorkflowOutput:
"""
Returns the name of the input field in the workflow for the given node and field
Expand All @@ -435,11 +441,15 @@ def get_output(
try:
return self._output_mapping[(node_name, field_name)]
except KeyError:
outpt_name = (
field_name
if node_name is None or node_name == self.input_node
else f"{node_name}_{field_name}"
)
if node_name is None or node_name == self.output_node:
outpt_name = field_name
elif create:
outpt_name = f"{node_name}_{field_name}"
else:
raise KeyError(
f"Unrecognised output corresponding to {node_name}:{field_name} field, "
"set create=True to auto-create"
)
outpt = WorkflowOutput(
name=outpt_name, field=field_name, node_name=node_name
)
Expand Down Expand Up @@ -923,11 +933,11 @@ def prepare_connections(self):
for node in nodes:
if isinstance(node, AddNestedWorkflowStatement):
exported_inputs.update(
(i.name, self.get_input(i.name, node_name))
(i.name, self.get_input(i.name, node_name, create=True))
for i in node.nested_workflow.exported_inputs
)
exported_outputs.update(
(o.name, self.get_output(o.name, node_name))
(o.name, self.get_output(o.name, node_name, create=True))
for o in node.nested_workflow.exported_outputs
)
for inpt_name, exp_inpt in exported_inputs:
Expand Down Expand Up @@ -957,16 +967,20 @@ def prepare_connections(self):
self.parsed_statements.append(conn_stmt)
while self._unprocessed_connections:
conn = self._unprocessed_connections.pop()
if conn.wf_in:
self.get_input(conn.source_out).out_conns.append(conn)
else:
try:
inpt = self.get_input(conn.source_out, node_name=conn.source_name)
except KeyError:
for src_node in self.nodes[conn.source_name]:
src_node.add_output_connection(conn)
if conn.wf_out:
self.get_output(conn.target_in).in_conns.append(conn)
else:
inpt.out_conns.append(conn)
try:
outpt = self.get_output(conn.target_in, node_name=conn.target_name)
except KeyError:
for tgt_node in self.nodes[conn.target_name]:
tgt_node.add_input_connection(conn)
else:
outpt.in_conns.append(conn)

def _parse_statements(self, func_body: str) -> ty.Tuple[
ty.List[
Expand Down

0 comments on commit 9667474

Please sign in to comment.