diff --git a/tests/fixtures/vision/checker/models_import.py b/tests/fixtures/vision/checker/models_import.py new file mode 100644 index 0000000..8eae98e --- /dev/null +++ b/tests/fixtures/vision/checker/models_import.py @@ -0,0 +1,5 @@ +import torchvision.models as models +import torchvision.models as cnn +from torchvision.models import resnet50, resnet101 +import torchvision.models +from torchvision.models import * diff --git a/tests/fixtures/vision/checker/models_import.txt b/tests/fixtures/vision/checker/models_import.txt new file mode 100644 index 0000000..864cf35 --- /dev/null +++ b/tests/fixtures/vision/checker/models_import.txt @@ -0,0 +1 @@ +1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index a26097a..a38d81d 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -19,6 +19,7 @@ from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, + TorchVisionModelsImportVisitor, ) from .visitors.security import TorchUnsafeLoadVisitor @@ -35,6 +36,7 @@ TorchSynchronizedDataLoaderVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, + TorchVisionModelsImportVisitor, TorchUnsafeLoadVisitor, TorchReentrantCheckpointVisitor, ] diff --git a/torchfix/visitors/vision/__init__.py b/torchfix/visitors/vision/__init__.py index 7adcc19..9bc944e 100644 --- a/torchfix/visitors/vision/__init__.py +++ b/torchfix/visitors/vision/__init__.py @@ -1,2 +1,3 @@ from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401 from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401 +from .models_import import TorchVisionModelsImportVisitor # noqa: F401 diff --git a/torchfix/visitors/vision/models_import.py b/torchfix/visitors/vision/models_import.py new file mode 100644 index 0000000..ba5a325 --- /dev/null +++ b/torchfix/visitors/vision/models_import.py @@ -0,0 +1,40 @@ +import libcst as cst + +from ...common import LintViolation, TorchVisitor + + +class TorchVisionModelsImportVisitor(TorchVisitor): + ERROR_CODE = "TOR203" + + def visit_Import(self, node: cst.Import) -> None: + for imported_item in node.names: + if isinstance(imported_item.name, cst.Attribute): + if ( + isinstance(imported_item.name.value, cst.Name) + and imported_item.name.value.value == "torchvision" + and isinstance(imported_item.name.attr, cst.Name) + and imported_item.name.attr.value == "models" + and imported_item.asname is not None + and isinstance(imported_item.asname.name, cst.Name) + and imported_item.asname.name.value == "models" + ): + position = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + replacement = cst.ImportFrom( + module=cst.Name("torchvision"), + names=[cst.ImportAlias(name=cst.Name("models"))], + ) + self.violations.append( + LintViolation( + error_code=self.ERROR_CODE, + message=( + "Consider replacing 'import torchvision.models as" + " models' with 'from torchvision import models'." + ), + line=position.start.line, + column=position.start.column, + node=node, + replacement=replacement + ) + )