Skip to content

Commit

Permalink
debugging mriqc/niworkflows conversions
Browse files Browse the repository at this point in the history
  • Loading branch information
tclose committed Jun 7, 2024
1 parent abf3551 commit 8c696fd
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 96 deletions.
8 changes: 3 additions & 5 deletions nipype2pydra/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def used_symbols(self) -> UsedSymbols:
always_include=self.package.all_explicit,
translations=self.package.all_import_translations,
)
used.imports.update(i.to_statement() for i in self.imports)
used.import_stmts.update(i.to_statement() for i in self.imports)
return used

@cached_property
Expand All @@ -147,12 +147,10 @@ def converted_code(self) -> ty.List[str]:
@cached_property
def nested_interfaces(self):
potential_classes = {
full_address(c[1]): c[0]
for c in self.used_symbols.intra_pkg_classes
if c[0]
full_address(c[1]): c[0] for c in self.used_symbols.imported_classes if c[0]
}
potential_classes.update(
(full_address(c), c.__name__) for c in self.used_symbols.local_classes
(full_address(c), c.__name__) for c in self.used_symbols.classes
)
return {
potential_classes[address]: workflow
Expand Down
19 changes: 13 additions & 6 deletions nipype2pydra/interface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,9 @@ def pydra_fld_input(self, field, nm):
val = getattr(field, key)
if val is not None:
if key == "argstr" and "%" in val:
val = self.string_formats(argstr=val, name=nm)
val = self.string_formats(
argstr=val, name=nm, type_=field.trait_type
)
elif key == "mandatory" and pydra_default is not None:
val = False # Overwrite mandatory to False if default is provided
pydra_metadata[pydra_key_nm] = val
Expand All @@ -666,7 +668,9 @@ def pydra_fld_input(self, field, nm):
template = getattr(field, "name_template")
name_source = ensure_list(getattr(field, "name_source"))
if name_source:
tmpl = self.string_formats(argstr=template, name=name_source[0])
tmpl = self.string_formats(
argstr=template, name=name_source[0], type_=field.trait_type
)
else:
tmpl = template
if nm in self.nipype_interface.output_spec().class_trait_names():
Expand Down Expand Up @@ -829,11 +833,14 @@ def pydra_type_converter(self, field, spec_type, name):
pydra_type = ty.Any
return pydra_type

def string_formats(self, argstr, name):
def string_formats(self, argstr, name, type_):
keys = re.findall(r"(%[0-9\.]*(?:s|d|i|g|f))", argstr)
new_argstr = argstr
for i, key in enumerate(keys):
repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]"
if isinstance(type_, traits.trait_types.Bool):
repl = f"{name}:d"
else:
repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]"
match = re.match(r"%([0-9\.]+)f", key)
if match:
repl += ":" + match.group(1)
Expand Down Expand Up @@ -972,7 +979,7 @@ def _converted_test(self):
)

return spec_str, UsedSymbols(
module_name=self.nipype_module.__name__, imports=imports
module_name=self.nipype_module.__name__, import_stmts=imports
)

def create_doctests(self, input_fields, nonstd_types):
Expand Down Expand Up @@ -1032,7 +1039,7 @@ def _misc_cleanups(self, body: str) -> str:
body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"')

body = body.replace("self.output_spec().get()", "{}")
body = body.replace("self._outputs()", "{}")
body = body.replace("self._outputs().get()", "{}")
# body = re.sub(
# r"outputs = self\.(output_spec|_outputs)\(\).*$",
# r"outputs = {}",
Expand Down
12 changes: 6 additions & 6 deletions nipype2pydra/interface/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def types_to_names(spec_fields):
translations=self.package.all_import_translations,
absolute_imports=True,
)
used.update(method_used, from_other_module=False)
used.update(method_used)

method_body = ""
for field in input_fields:
Expand Down Expand Up @@ -129,7 +129,7 @@ def types_to_names(spec_fields):
translations=self.package.all_import_translations,
absolute_imports=True,
)
used.update(init_used, from_other_module=False)
used.update(init_used)
method_body += init_code + "\n"

# Combined src of run_interface and list_outputs
Expand Down Expand Up @@ -163,7 +163,7 @@ def types_to_names(spec_fields):
translations=self.package.all_import_translations,
absolute_imports=True,
)
used.update(run_interface_used, from_other_module=False)
used.update(run_interface_used)
method_body += run_interface_code + "\n"

list_outputs_code = inspect.getsource(
Expand Down Expand Up @@ -197,7 +197,7 @@ def types_to_names(spec_fields):
translations=self.package.all_import_translations,
absolute_imports=True,
)
used.update(list_outputs_used, from_other_module=False)
used.update(list_outputs_used)
method_body += list_outputs_code + "\n"

assert method_body, "Neither `run_interface` and `list_outputs` are defined"
Expand Down Expand Up @@ -250,12 +250,12 @@ def types_to_names(spec_fields):
additional_imports.add(imprt)
spec_str = repl_spec_str

used.imports.update(
used.import_stmts.update(
self.construct_imports(
nonstd_types,
spec_str,
include_task=False,
base=base_imports + list(used.imports) + list(additional_imports),
base=base_imports + list(used.import_stmts) + list(additional_imports),
)
)

Expand Down
2 changes: 1 addition & 1 deletion nipype2pydra/interface/shell_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def types_to_names(spec_fields):
)
used.update(super_used)

used.imports.update(
used.import_stmts.update(
self.construct_imports(
nonstd_types,
spec_str,
Expand Down
40 changes: 20 additions & 20 deletions nipype2pydra/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,7 +400,7 @@ def write(self, package_root: Path, to_include: ty.List[str] = None):
workflow.prepare_connections()

def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True):
for _, klass in used.intra_pkg_classes:
for _, klass in used.imported_classes:
address = full_address(klass)
if address in self.nipype_port_converters:
if port_nipype:
Expand All @@ -412,10 +412,10 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True):
)
elif full_address(klass) not in self.interfaces:
intra_pkg_modules[klass.__module__].add(klass)
for _, func in used.intra_pkg_funcs:
for _, func in used.imported_funcs:
if full_address(func) not in list(self.workflows):
intra_pkg_modules[func.__module__].add(func)
for const_mod_address, _, const_name in used.intra_pkg_constants:
for const_mod_address, _, const_name in used.imported_constants:
intra_pkg_modules[const_mod_address].add(const_name)

for conv in list(self.functions.values()) + list(self.classes.values()):
Expand All @@ -429,7 +429,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True):
package_root,
already_converted=already_converted,
)
class_addrs = [full_address(c) for _, c in all_used.intra_pkg_classes]
class_addrs = [full_address(c) for _, c in all_used.imported_classes]
included_addrs = [c.full_address for c in interfaces_to_include]
interfaces_to_include.extend(
self.interfaces[a]
Expand Down Expand Up @@ -555,25 +555,23 @@ def write_intra_pkg_modules(
always_include=self.all_explicit,
)

classes = used.local_classes + [
o for o in objs if inspect.isclass(o) and o not in used.local_classes
classes = used.classes + [
o for o in objs if inspect.isclass(o) and o not in used.classes
]

functions = list(used.local_functions) + [
o
for o in objs
if inspect.isfunction(o) and o not in used.local_functions
functions = list(used.functions) + [
o for o in objs if inspect.isfunction(o) and o not in used.functions
]

self.write_to_module(
package_root=package_root,
module_name=out_mod_name,
used=UsedSymbols(
module_name=mod_name,
imports=used.imports,
import_stmts=used.import_stmts,
constants=used.constants,
local_classes=classes,
local_functions=functions,
classes=classes,
functions=functions,
),
find_replace=self.find_replace,
inline_intra_pkg=False,
Expand Down Expand Up @@ -871,11 +869,11 @@ def write_to_module(
existing_imports = parse_imports(existing_import_strs, relative_to=module_name)
converter_imports = []

for klass in used.local_classes:
for klass in used.classes:
if f"\nclass {klass.__name__}(" not in code_str:
try:
class_converter = self.classes[full_address(klass)]
converter_imports.extend(class_converter.used_symbols.imports)
converter_imports.extend(class_converter.used_symbols.import_stmts)
except KeyError:
class_converter = ClassConverter.from_object(klass, self)
code_str += "\n" + class_converter.converted_code + "\n"
Expand Down Expand Up @@ -903,11 +901,13 @@ def write_to_module(
if converted_code.strip() not in code_str:
code_str += "\n" + converted_code + "\n"

for func in sorted(used.local_functions, key=attrgetter("__name__")):
for func in sorted(used.functions, key=attrgetter("__name__")):
if f"\ndef {func.__name__}(" not in code_str:
if func.__name__ in self.functions:
function_converter = self.functions[full_address(func)]
converter_imports.extend(function_converter.used_symbols.imports)
converter_imports.extend(
function_converter.used_symbols.import_stmts
)
else:
function_converter = FunctionConverter.from_object(func, self)
code_str += "\n" + function_converter.converted_code + "\n"
Expand All @@ -923,7 +923,7 @@ def write_to_module(
code_str += (
"\n\n# Intra-package imports that have been inlined in this module\n\n"
)
for func_name, func in sorted(used.intra_pkg_funcs, key=itemgetter(0)):
for func_name, func in sorted(used.imported_funcs, key=itemgetter(0)):
func_src = get_source_code(func)
func_src = re.sub(
r"^(#[^\n]+\ndef) (\w+)(?=\()",
Expand All @@ -934,7 +934,7 @@ def write_to_module(
code_str += "\n\n" + cleanup_function_body(func_src)
inlined_symbols.append(func_name)

for klass_name, klass in sorted(used.intra_pkg_classes, key=itemgetter(0)):
for klass_name, klass in sorted(used.imported_classes, key=itemgetter(0)):
klass_src = get_source_code(klass)
klass_src = re.sub(
r"^(#[^\n]+\nclass) (\w+)(?=\()",
Expand Down Expand Up @@ -973,7 +973,7 @@ def write_to_module(
imports = ImportStatement.collate(
existing_imports
+ converter_imports
+ [i for i in used.imports if not i.indent]
+ [i for i in used.import_stmts if not i.indent]
+ GENERIC_PYDRA_IMPORTS
+ additional_imports
)
Expand Down
10 changes: 5 additions & 5 deletions nipype2pydra/pkg_gen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,13 +1123,13 @@ def insert_args_in_method_calls(
mod = import_module(mod_name)
used = UsedSymbols.find(mod, methods, omit_classes=(BaseInterface, TraitedSpec))
all_funcs.update(methods)
for func in used.local_functions:
for func in used.functions:
all_funcs.add(cleanup_function_body(get_source_code(func)))
for klass in used.local_classes:
for klass in used.classes:
klass_src = cleanup_function_body(get_source_code(klass))
if klass_src not in all_classes:
all_classes.append(klass_src)
for new_func_name, func in used.intra_pkg_funcs:
for new_func_name, func in used.imported_funcs:
if new_func_name is None:
continue # Not referenced directly in this module
func_src = get_source_code(func)
Expand All @@ -1148,7 +1148,7 @@ def insert_args_in_method_calls(
+ match.group(2)
)
all_funcs.add(cleanup_function_body(func_src))
for new_klass_name, klass in used.intra_pkg_classes:
for new_klass_name, klass in used.imported_classes:
if new_klass_name is None:
continue # Not referenced directly in this module
klass_src = get_source_code(klass)
Expand All @@ -1169,7 +1169,7 @@ def insert_args_in_method_calls(
klass_src = cleanup_function_body(klass_src)
if klass_src not in all_classes:
all_classes.append(klass_src)
all_imports.update(used.imports)
all_imports.update(used.import_stmts)
all_constants.update(used.constants)
return (
sorted(
Expand Down
2 changes: 2 additions & 0 deletions nipype2pydra/statements/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,8 @@ def parse_imports(
"from fileformats.generic import File, Directory",
"from pydra.engine.specs import MultiInputObj",
"from pathlib import Path",
"import json",
"import yaml",
"import logging",
"import pydra.mark",
"import typing as ty",
Expand Down
9 changes: 9 additions & 0 deletions nipype2pydra/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from importlib import import_module
from logging import getLogger
from pydra.engine.specs import MultiInputObj


logger = getLogger("nipype2pydra")
Expand Down Expand Up @@ -482,12 +483,20 @@ def from_named_dicts_converter(
def str_to_type(type_str: str) -> type:
"""Resolve a string representation of a type into a valid type"""
if "/" in type_str:
if type_str.startswith("multi["):
assert type_str.endswith("]"), f"Invalid multi type: {type_str}"
type_str = type_str[6:-1]
multi = True
else:
multi = False
tp = from_mime(type_str)
try:
# If datatype is a field, use its primitive instead
tp = tp.primitive # type: ignore
except AttributeError:
pass
if multi:
tp = MultiInputObj[tp]
else:

def resolve_type(type_str: str) -> type:
Expand Down
Loading

0 comments on commit 8c696fd

Please sign in to comment.