From 8c696fd451b1418d7fadd5c9a51b40b2f99db2e9 Mon Sep 17 00:00:00 2001 From: Tom Close Date: Fri, 7 Jun 2024 15:42:51 +1000 Subject: [PATCH] debugging mriqc/niworkflows conversions --- nipype2pydra/helpers.py | 8 +- nipype2pydra/interface/base.py | 19 ++- nipype2pydra/interface/function.py | 12 +- nipype2pydra/interface/shell_command.py | 2 +- nipype2pydra/package.py | 40 +++---- nipype2pydra/pkg_gen/__init__.py | 10 +- nipype2pydra/statements/imports.py | 2 + nipype2pydra/utils/misc.py | 9 ++ nipype2pydra/utils/symbols.py | 113 +++++++++++------- .../utils/tests/test_utils_imports.py | 28 +++-- nipype2pydra/workflow.py | 8 +- 11 files changed, 155 insertions(+), 96 deletions(-) diff --git a/nipype2pydra/helpers.py b/nipype2pydra/helpers.py index 3892cd3..d344d9d 100644 --- a/nipype2pydra/helpers.py +++ b/nipype2pydra/helpers.py @@ -133,7 +133,7 @@ def used_symbols(self) -> UsedSymbols: always_include=self.package.all_explicit, translations=self.package.all_import_translations, ) - used.imports.update(i.to_statement() for i in self.imports) + used.import_stmts.update(i.to_statement() for i in self.imports) return used @cached_property @@ -147,12 +147,10 @@ def converted_code(self) -> ty.List[str]: @cached_property def nested_interfaces(self): potential_classes = { - full_address(c[1]): c[0] - for c in self.used_symbols.intra_pkg_classes - if c[0] + full_address(c[1]): c[0] for c in self.used_symbols.imported_classes if c[0] } potential_classes.update( - (full_address(c), c.__name__) for c in self.used_symbols.local_classes + (full_address(c), c.__name__) for c in self.used_symbols.classes ) return { potential_classes[address]: workflow diff --git a/nipype2pydra/interface/base.py b/nipype2pydra/interface/base.py index ec424d5..d7a2c1e 100644 --- a/nipype2pydra/interface/base.py +++ b/nipype2pydra/interface/base.py @@ -657,7 +657,9 @@ def pydra_fld_input(self, field, nm): val = getattr(field, key) if val is not None: if key == "argstr" and "%" in val: - val = self.string_formats(argstr=val, name=nm) + val = self.string_formats( + argstr=val, name=nm, type_=field.trait_type + ) elif key == "mandatory" and pydra_default is not None: val = False # Overwrite mandatory to False if default is provided pydra_metadata[pydra_key_nm] = val @@ -666,7 +668,9 @@ def pydra_fld_input(self, field, nm): template = getattr(field, "name_template") name_source = ensure_list(getattr(field, "name_source")) if name_source: - tmpl = self.string_formats(argstr=template, name=name_source[0]) + tmpl = self.string_formats( + argstr=template, name=name_source[0], type_=field.trait_type + ) else: tmpl = template if nm in self.nipype_interface.output_spec().class_trait_names(): @@ -829,11 +833,14 @@ def pydra_type_converter(self, field, spec_type, name): pydra_type = ty.Any return pydra_type - def string_formats(self, argstr, name): + def string_formats(self, argstr, name, type_): keys = re.findall(r"(%[0-9\.]*(?:s|d|i|g|f))", argstr) new_argstr = argstr for i, key in enumerate(keys): - repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]" + if isinstance(type_, traits.trait_types.Bool): + repl = f"{name}:d" + else: + repl = f"{name}" if len(keys) == 1 else f"{name}[{i}]" match = re.match(r"%([0-9\.]+)f", key) if match: repl += ":" + match.group(1) @@ -972,7 +979,7 @@ def _converted_test(self): ) return spec_str, UsedSymbols( - module_name=self.nipype_module.__name__, imports=imports + module_name=self.nipype_module.__name__, import_stmts=imports ) def create_doctests(self, input_fields, nonstd_types): @@ -1032,7 +1039,7 @@ def _misc_cleanups(self, body: str) -> str: body = body.replace("self.cmd", f'"{self.nipype_interface._cmd}"') body = body.replace("self.output_spec().get()", "{}") - body = body.replace("self._outputs()", "{}") + body = body.replace("self._outputs().get()", "{}") # body = re.sub( # r"outputs = self\.(output_spec|_outputs)\(\).*$", # r"outputs = {}", diff --git a/nipype2pydra/interface/function.py b/nipype2pydra/interface/function.py index 1f8f6f4..27c6f8e 100644 --- a/nipype2pydra/interface/function.py +++ b/nipype2pydra/interface/function.py @@ -91,7 +91,7 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) - used.update(method_used, from_other_module=False) + used.update(method_used) method_body = "" for field in input_fields: @@ -129,7 +129,7 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) - used.update(init_used, from_other_module=False) + used.update(init_used) method_body += init_code + "\n" # Combined src of run_interface and list_outputs @@ -163,7 +163,7 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) - used.update(run_interface_used, from_other_module=False) + used.update(run_interface_used) method_body += run_interface_code + "\n" list_outputs_code = inspect.getsource( @@ -197,7 +197,7 @@ def types_to_names(spec_fields): translations=self.package.all_import_translations, absolute_imports=True, ) - used.update(list_outputs_used, from_other_module=False) + used.update(list_outputs_used) method_body += list_outputs_code + "\n" assert method_body, "Neither `run_interface` and `list_outputs` are defined" @@ -250,12 +250,12 @@ def types_to_names(spec_fields): additional_imports.add(imprt) spec_str = repl_spec_str - used.imports.update( + used.import_stmts.update( self.construct_imports( nonstd_types, spec_str, include_task=False, - base=base_imports + list(used.imports) + list(additional_imports), + base=base_imports + list(used.import_stmts) + list(additional_imports), ) ) diff --git a/nipype2pydra/interface/shell_command.py b/nipype2pydra/interface/shell_command.py index 4a551a8..7d2dcb6 100644 --- a/nipype2pydra/interface/shell_command.py +++ b/nipype2pydra/interface/shell_command.py @@ -210,7 +210,7 @@ def types_to_names(spec_fields): ) used.update(super_used) - used.imports.update( + used.import_stmts.update( self.construct_imports( nonstd_types, spec_str, diff --git a/nipype2pydra/package.py b/nipype2pydra/package.py index dfbc01f..4e2741d 100644 --- a/nipype2pydra/package.py +++ b/nipype2pydra/package.py @@ -400,7 +400,7 @@ def write(self, package_root: Path, to_include: ty.List[str] = None): workflow.prepare_connections() def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): - for _, klass in used.intra_pkg_classes: + for _, klass in used.imported_classes: address = full_address(klass) if address in self.nipype_port_converters: if port_nipype: @@ -412,10 +412,10 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): ) elif full_address(klass) not in self.interfaces: intra_pkg_modules[klass.__module__].add(klass) - for _, func in used.intra_pkg_funcs: + for _, func in used.imported_funcs: if full_address(func) not in list(self.workflows): intra_pkg_modules[func.__module__].add(func) - for const_mod_address, _, const_name in used.intra_pkg_constants: + for const_mod_address, _, const_name in used.imported_constants: intra_pkg_modules[const_mod_address].add(const_name) for conv in list(self.functions.values()) + list(self.classes.values()): @@ -429,7 +429,7 @@ def collect_intra_pkg_objects(used: UsedSymbols, port_nipype: bool = True): package_root, already_converted=already_converted, ) - class_addrs = [full_address(c) for _, c in all_used.intra_pkg_classes] + class_addrs = [full_address(c) for _, c in all_used.imported_classes] included_addrs = [c.full_address for c in interfaces_to_include] interfaces_to_include.extend( self.interfaces[a] @@ -555,14 +555,12 @@ def write_intra_pkg_modules( always_include=self.all_explicit, ) - classes = used.local_classes + [ - o for o in objs if inspect.isclass(o) and o not in used.local_classes + classes = used.classes + [ + o for o in objs if inspect.isclass(o) and o not in used.classes ] - functions = list(used.local_functions) + [ - o - for o in objs - if inspect.isfunction(o) and o not in used.local_functions + functions = list(used.functions) + [ + o for o in objs if inspect.isfunction(o) and o not in used.functions ] self.write_to_module( @@ -570,10 +568,10 @@ def write_intra_pkg_modules( module_name=out_mod_name, used=UsedSymbols( module_name=mod_name, - imports=used.imports, + import_stmts=used.import_stmts, constants=used.constants, - local_classes=classes, - local_functions=functions, + classes=classes, + functions=functions, ), find_replace=self.find_replace, inline_intra_pkg=False, @@ -871,11 +869,11 @@ def write_to_module( existing_imports = parse_imports(existing_import_strs, relative_to=module_name) converter_imports = [] - for klass in used.local_classes: + for klass in used.classes: if f"\nclass {klass.__name__}(" not in code_str: try: class_converter = self.classes[full_address(klass)] - converter_imports.extend(class_converter.used_symbols.imports) + converter_imports.extend(class_converter.used_symbols.import_stmts) except KeyError: class_converter = ClassConverter.from_object(klass, self) code_str += "\n" + class_converter.converted_code + "\n" @@ -903,11 +901,13 @@ def write_to_module( if converted_code.strip() not in code_str: code_str += "\n" + converted_code + "\n" - for func in sorted(used.local_functions, key=attrgetter("__name__")): + for func in sorted(used.functions, key=attrgetter("__name__")): if f"\ndef {func.__name__}(" not in code_str: if func.__name__ in self.functions: function_converter = self.functions[full_address(func)] - converter_imports.extend(function_converter.used_symbols.imports) + converter_imports.extend( + function_converter.used_symbols.import_stmts + ) else: function_converter = FunctionConverter.from_object(func, self) code_str += "\n" + function_converter.converted_code + "\n" @@ -923,7 +923,7 @@ def write_to_module( code_str += ( "\n\n# Intra-package imports that have been inlined in this module\n\n" ) - for func_name, func in sorted(used.intra_pkg_funcs, key=itemgetter(0)): + for func_name, func in sorted(used.imported_funcs, key=itemgetter(0)): func_src = get_source_code(func) func_src = re.sub( r"^(#[^\n]+\ndef) (\w+)(?=\()", @@ -934,7 +934,7 @@ def write_to_module( code_str += "\n\n" + cleanup_function_body(func_src) inlined_symbols.append(func_name) - for klass_name, klass in sorted(used.intra_pkg_classes, key=itemgetter(0)): + for klass_name, klass in sorted(used.imported_classes, key=itemgetter(0)): klass_src = get_source_code(klass) klass_src = re.sub( r"^(#[^\n]+\nclass) (\w+)(?=\()", @@ -973,7 +973,7 @@ def write_to_module( imports = ImportStatement.collate( existing_imports + converter_imports - + [i for i in used.imports if not i.indent] + + [i for i in used.import_stmts if not i.indent] + GENERIC_PYDRA_IMPORTS + additional_imports ) diff --git a/nipype2pydra/pkg_gen/__init__.py b/nipype2pydra/pkg_gen/__init__.py index 9ab8151..170a6d5 100644 --- a/nipype2pydra/pkg_gen/__init__.py +++ b/nipype2pydra/pkg_gen/__init__.py @@ -1123,13 +1123,13 @@ def insert_args_in_method_calls( mod = import_module(mod_name) used = UsedSymbols.find(mod, methods, omit_classes=(BaseInterface, TraitedSpec)) all_funcs.update(methods) - for func in used.local_functions: + for func in used.functions: all_funcs.add(cleanup_function_body(get_source_code(func))) - for klass in used.local_classes: + for klass in used.classes: klass_src = cleanup_function_body(get_source_code(klass)) if klass_src not in all_classes: all_classes.append(klass_src) - for new_func_name, func in used.intra_pkg_funcs: + for new_func_name, func in used.imported_funcs: if new_func_name is None: continue # Not referenced directly in this module func_src = get_source_code(func) @@ -1148,7 +1148,7 @@ def insert_args_in_method_calls( + match.group(2) ) all_funcs.add(cleanup_function_body(func_src)) - for new_klass_name, klass in used.intra_pkg_classes: + for new_klass_name, klass in used.imported_classes: if new_klass_name is None: continue # Not referenced directly in this module klass_src = get_source_code(klass) @@ -1169,7 +1169,7 @@ def insert_args_in_method_calls( klass_src = cleanup_function_body(klass_src) if klass_src not in all_classes: all_classes.append(klass_src) - all_imports.update(used.imports) + all_imports.update(used.import_stmts) all_constants.update(used.constants) return ( sorted( diff --git a/nipype2pydra/statements/imports.py b/nipype2pydra/statements/imports.py index 0e9887a..a1b993a 100644 --- a/nipype2pydra/statements/imports.py +++ b/nipype2pydra/statements/imports.py @@ -587,6 +587,8 @@ def parse_imports( "from fileformats.generic import File, Directory", "from pydra.engine.specs import MultiInputObj", "from pathlib import Path", + "import json", + "import yaml", "import logging", "import pydra.mark", "import typing as ty", diff --git a/nipype2pydra/utils/misc.py b/nipype2pydra/utils/misc.py index 78c4f99..c91c58c 100644 --- a/nipype2pydra/utils/misc.py +++ b/nipype2pydra/utils/misc.py @@ -22,6 +22,7 @@ from importlib import import_module from logging import getLogger +from pydra.engine.specs import MultiInputObj logger = getLogger("nipype2pydra") @@ -482,12 +483,20 @@ def from_named_dicts_converter( def str_to_type(type_str: str) -> type: """Resolve a string representation of a type into a valid type""" if "/" in type_str: + if type_str.startswith("multi["): + assert type_str.endswith("]"), f"Invalid multi type: {type_str}" + type_str = type_str[6:-1] + multi = True + else: + multi = False tp = from_mime(type_str) try: # If datatype is a field, use its primitive instead tp = tp.primitive # type: ignore except AttributeError: pass + if multi: + tp = MultiInputObj[tp] else: def resolve_type(type_str: str) -> type: diff --git a/nipype2pydra/utils/symbols.py b/nipype2pydra/utils/symbols.py index c4524eb..0d169c8 100644 --- a/nipype2pydra/utils/symbols.py +++ b/nipype2pydra/utils/symbols.py @@ -24,7 +24,9 @@ class UsedSymbols: A class to hold the used symbols in a module Parameters - ------- + ---------- + module_name: str + the name of the module containing the functions to be converted imports : list[str] the import statements that need to be included in the converted file local_functions: set[callable] @@ -45,16 +47,31 @@ class UsedSymbols: set of all the constants defined within the package that are referenced by the function, (, , ), where the local alias and the definition of the constant + methods: set[callable] + the names of the methods that are referenced, by default None is a function not + a method + class_constants: set[tuple[str, str]] + the names of the class attributes that are referenced by the method + + class_name: str, optional + the name of the class that the methods originate from """ module_name: str - imports: ty.Set[str] = attrs.field(factory=set) - local_functions: ty.Set[ty.Callable] = attrs.field(factory=set) - local_classes: ty.List[type] = attrs.field(factory=list) + import_stmts: ty.Set[str] = attrs.field(factory=set) + functions: ty.Set[ty.Callable] = attrs.field(factory=set) + classes: ty.List[type] = attrs.field(factory=list) constants: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) - intra_pkg_funcs: ty.Set[ty.Tuple[str, ty.Callable]] = attrs.field(factory=set) - intra_pkg_classes: ty.List[ty.Tuple[str, ty.Callable]] = attrs.field(factory=list) - intra_pkg_constants: ty.Set[ty.Tuple[str, str, str]] = attrs.field(factory=set) + methods: ty.Set[ty.Callable] = attrs.field(factory=set) + class_attrs: ty.Set[ty.Tuple[str, str]] = attrs.field(factory=set) + imported_funcs: ty.Set[ty.Tuple[str, ty.Callable]] = attrs.field(factory=set) + imported_classes: ty.List[ty.Tuple[str, ty.Callable]] = attrs.field(factory=list) + imported_constants: ty.Set[ty.Tuple[str, str, str]] = attrs.field(factory=set) + super_methoods: ty.Set[ty.Tuple[type, ty.Callable]] = attrs.field(factory=set) + super_class_attrs: ty.Set[ty.Tuple[type, ty.Tuple[str, str]]] = attrs.field( + factory=set + ) + klass: ty.Optional[type] = None ALWAYS_OMIT_MODULES = [ "traits.trait_handlers", # Old traits module, pre v6.0 @@ -74,34 +91,46 @@ def update( other: "UsedSymbols", absolute_imports: bool = False, to_be_inlined: bool = False, - from_other_module: bool = True, ): - if to_be_inlined or not from_other_module: - self.imports.update( - i.absolute() if absolute_imports else i for i in other.imports + if (self.module_name == other.module_name) or to_be_inlined: + self.import_stmts.update( + i.absolute() if absolute_imports else i for i in other.import_stmts ) - self.intra_pkg_funcs.update(other.intra_pkg_funcs) - self.intra_pkg_classes.extend( - c for c in other.intra_pkg_classes if c not in self.intra_pkg_classes + self.imported_funcs.update(other.imported_funcs) + self.imported_classes.extend( + c for c in other.imported_classes if c not in self.imported_classes ) - self.intra_pkg_constants.update(other.intra_pkg_constants) - if from_other_module: - self.intra_pkg_funcs.update((None, f) for f in other.local_functions) - self.intra_pkg_classes.extend( + self.imported_constants.update(other.imported_constants) + if self.module_name != other.module_name: + self.imported_funcs.update((None, f) for f in other.functions) + self.imported_classes.extend( (None, c) - for c in other.local_classes - if (None, c) not in self.intra_pkg_classes + for c in other.classes + if (None, c) not in self.imported_classes ) - self.intra_pkg_constants.update( + self.imported_constants.update( (other.module_name, None, c[0]) for c in other.constants ) else: - self.local_functions.update(other.local_functions) - self.local_classes.extend( - c for c in other.local_classes if c not in self.local_classes - ) - + self.functions.update(other.functions) + self.classes.extend(c for c in other.classes if c not in self.classes) self.constants.update(other.constants) + if other.klass: + if not self.klass: + raise ValueError( + f"Attempting to merge class symbols for {other.klass} with module " + f"symbols ({self.module_name}) with different names" + ) + if self.klass is other.klass: + self.methods.update(other.methods) + self.constants.update(other.constants) + else: + self.super_methoods.update( + (other.klass, m) for m in other.super_methoods + ) + self.super_class_attrs.update( + (other.klass, a) for a in other.super_class_attrs + ) DEFAULT_FILTERED_CONSTANTS = ( Undefined, @@ -243,19 +272,19 @@ def find( for local_func in local_functions: if ( local_func.__name__ in used_symbols - and local_func not in used.local_functions + and local_func not in used.functions ): - used.local_functions.add(local_func) + used.functions.add(local_func) cls._get_symbols(local_func, used_symbols) all_src += "\n\n" + inspect.getsource(local_func) for local_class in local_classes: if ( local_class.__name__ in used_symbols - and local_class not in used.local_classes + and local_class not in used.classes ): if issubclass(local_class, (BaseInterface, TraitedSpec)): continue - used.local_classes.append(local_class) + used.classes.append(local_class) class_body = inspect.getsource(local_class) bases = extract_args(class_body)[1] used_symbols.update(bases) @@ -334,12 +363,12 @@ def find( ) or inspect.isbuiltin(imported.object): # Case where an object is a nested import from a different package # which is imported in a chain from a neighbouring module - used.imports.add( + used.import_stmts.add( imported.as_independent_statement(resolve=True) ) stmt.drop(imported) elif inspect.isfunction(imported.object): - used.intra_pkg_funcs.add((imported.local_name, imported.object)) + used.imported_funcs.add((imported.local_name, imported.object)) # Recursively include objects imported in the module intra_pkg_objs[import_module(imported.object.__module__)].add( imported.object @@ -353,8 +382,8 @@ def find( # like we did for functions here because we need to preserve the # order the classes are defined in the module in case one inherits # from the other - if class_def not in used.intra_pkg_classes: - used.intra_pkg_classes.append(class_def) + if class_def not in used.imported_classes: + used.imported_classes.append(class_def) # Recursively include objects imported in the module intra_pkg_objs[import_module(imported.object.__module__)].add( imported.object, @@ -375,15 +404,15 @@ def find( obj = getattr(imported.object, attr_name) if inspect.isfunction(obj): - used.intra_pkg_funcs.add((obj.__name__, obj)) + used.imported_funcs.add((obj.__name__, obj)) intra_pkg_objs[imported.object.__name__].add(obj) elif inspect.isclass(obj): class_def = (obj.__name__, obj) - if class_def not in used.intra_pkg_classes: - used.intra_pkg_classes.append(class_def) + if class_def not in used.imported_classes: + used.imported_classes.append(class_def) intra_pkg_objs[imported.object.__name__].add(obj) else: - used.intra_pkg_constants.add( + used.imported_constants.add( ( imported.object.__name__, attr_name, @@ -397,7 +426,7 @@ def find( f"Cannot inline imported module in statement '{stmt}'" ) else: - used.intra_pkg_constants.add( + used.imported_constants.add( ( stmt.module_name, imported.local_name, @@ -423,7 +452,7 @@ def find( ) used.update(used_in_mod, to_be_inlined=collapse_intra_pkg) if stmt: - used.imports.add(stmt) + used.import_stmts.add(stmt) return used @classmethod @@ -495,7 +524,7 @@ def get_imported_object(self, name: str) -> ty.Any: # if not i.from_ # } all_imported = {} - for stmt in self.imports: + for stmt in self.import_stmts: all_imported.update(stmt.imported) try: return all_imported[name].object @@ -514,7 +543,7 @@ def get_imported_object(self, name: str) -> ty.Any: if imported_obj is None: raise ImportError( f"Could not find object named {name} in any of the imported modules:\n" - + "\n".join(str(i) for i in self.imports) + + "\n".join(str(i) for i in self.import_stmts) ) for part in parts[-i:]: imported_obj = getattr(imported_obj, part) diff --git a/nipype2pydra/utils/tests/test_utils_imports.py b/nipype2pydra/utils/tests/test_utils_imports.py index 483ebf0..ecb12c4 100644 --- a/nipype2pydra/utils/tests/test_utils_imports.py +++ b/nipype2pydra/utils/tests/test_utils_imports.py @@ -49,7 +49,9 @@ def test_get_imported_object1(): import_stmts = [ "import nipype.interfaces.utility as niu", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("niu.IdentityInterface") is nipype.interfaces.utility.IdentityInterface @@ -60,7 +62,9 @@ def test_get_imported_object2(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("nipype.interfaces.utility") is nipype.interfaces.utility @@ -71,7 +75,9 @@ def test_get_imported_object3(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("IdentityInterface") is nipype.interfaces.utility.IdentityInterface @@ -82,7 +88,9 @@ def test_get_imported_object4(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object("IdentityInterface.input_spec") is nipype.interfaces.utility.IdentityInterface.input_spec @@ -93,7 +101,9 @@ def test_get_imported_object5(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) assert ( used.get_imported_object( "nipype.interfaces.utility.IdentityInterface.input_spec" @@ -106,7 +116,9 @@ def test_get_imported_object_fail1(): import_stmts = [ "import nipype.interfaces.utility", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) with pytest.raises(ImportError, match="Could not find object named"): used.get_imported_object("nipype.interfaces.utilityboo") @@ -115,6 +127,8 @@ def test_get_imported_object_fail2(): import_stmts = [ "from nipype.interfaces.utility import IdentityInterface", ] - used = UsedSymbols(module_name="test_module", imports=parse_imports(import_stmts)) + used = UsedSymbols( + module_name="test_module", import_stmts=parse_imports(import_stmts) + ) with pytest.raises(ImportError, match="Could not find object named"): used.get_imported_object("IdentityBoo") diff --git a/nipype2pydra/workflow.py b/nipype2pydra/workflow.py index f32544c..033a5b7 100644 --- a/nipype2pydra/workflow.py +++ b/nipype2pydra/workflow.py @@ -666,10 +666,10 @@ def func_body(self): @cached_property def nested_workflows(self): potential_funcs = { - full_address(f[1]): f[0] for f in self.used_symbols.intra_pkg_funcs if f[0] + full_address(f[1]): f[0] for f in self.used_symbols.imported_funcs if f[0] } potential_funcs.update( - (full_address(f), f.__name__) for f in self.used_symbols.local_functions + (full_address(f), f.__name__) for f in self.used_symbols.functions ) return { potential_funcs[address]: workflow @@ -731,7 +731,7 @@ def write( # main workflow code_str = self.converted_code - local_func_names = {f.__name__ for f in used.local_functions} + local_func_names = {f.__name__ for f in used.functions} # Convert any nested workflows for name, conv in self.nested_workflows.items(): if conv.address in already_converted: @@ -990,7 +990,7 @@ def test_used(self): return UsedSymbols( module_name=self.nipype_module.__name__, - imports=( + import_stmts=( nonstd_type_imports + parse_imports( [