diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py index 928f2c1..de75d5a 100644 --- a/torchfix/visitors/vision/models_import.py +++ b/torchfix/visitors/vision/models_import.py @@ -1,4 +1,5 @@ import libcst as cst +import libcst.matchers as m from ...common import TorchVisitor @@ -13,27 +14,21 @@ class TorchVisionModelsImportVisitor(TorchVisitor): def visit_Import(self, node: cst.Import) -> None: replacement = None for imported_item in node.names: - if isinstance(imported_item.name, cst.Attribute): - # TODO refactor using libcst.matchers.matches - if ( - isinstance(imported_item.name.value, cst.Name) - and imported_item.name.value.value == "torchvision" - and isinstance(imported_item.name.attr, cst.Name) - and imported_item.name.attr.value == "models" - and imported_item.asname is not None - and isinstance(imported_item.asname.name, cst.Name) - and imported_item.asname.name.value == "models" - ): - # Replace only if the import statement has no other names - if len(node.names) == 1: - replacement = cst.ImportFrom( - module=cst.Name("torchvision"), - names=[cst.ImportAlias(name=cst.Name("models"))], - ) - self.add_violation( - node, - error_code=self.ERROR_CODE, - message=self.MESSAGE, - replacement=replacement, + if m.matches(imported_item, m.ImportAlias( + name=m.Attribute(value=m.Name("torchvision"), + attr=m.Name("models")), + asname=m.AsName(name=m.Name("models")) + )): + # Replace only if the import statement has no other names + if len(node.names) == 1: + replacement = cst.ImportFrom( + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], ) - break + self.add_violation( + node, + error_code=self.ERROR_CODE, + message=self.MESSAGE, + replacement=replacement, + ) + break