Skip to content

Commit

Permalink
[Issue 7] Update import torchvision.models as models (#26)
Browse files Browse the repository at this point in the history
* [Issue 7] Update import torchvision.models as models

* Move torchvision.models visitor to vision dir

* Move torchvision.models visitor to vision dir
  • Loading branch information
gesuwen authored Mar 5, 2024
1 parent 35f2488 commit 8846f5c
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/fixtures/vision/checker/models_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import torchvision.models as models
import torchvision.models as cnn
from torchvision.models import resnet50, resnet101
import torchvision.models
from torchvision.models import *
1 change: 1 addition & 0 deletions tests/fixtures/vision/checker/models_import.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'.
2 changes: 2 additions & 0 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from .visitors.vision import (
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionModelsImportVisitor,
)
from .visitors.security import TorchUnsafeLoadVisitor

Expand All @@ -35,6 +36,7 @@
TorchSynchronizedDataLoaderVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionModelsImportVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
]
Expand Down
1 change: 1 addition & 0 deletions torchfix/visitors/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401
from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401
from .models_import import TorchVisionModelsImportVisitor # noqa: F401
40 changes: 40 additions & 0 deletions torchfix/visitors/vision/models_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import libcst as cst

from ...common import LintViolation, TorchVisitor


class TorchVisionModelsImportVisitor(TorchVisitor):
ERROR_CODE = "TOR203"

def visit_Import(self, node: cst.Import) -> None:
for imported_item in node.names:
if isinstance(imported_item.name, cst.Attribute):
if (
isinstance(imported_item.name.value, cst.Name)
and imported_item.name.value.value == "torchvision"
and isinstance(imported_item.name.attr, cst.Name)
and imported_item.name.attr.value == "models"
and imported_item.asname is not None
and isinstance(imported_item.asname.name, cst.Name)
and imported_item.asname.name.value == "models"
):
position = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=(
"Consider replacing 'import torchvision.models as"
" models' with 'from torchvision import models'."
),
line=position.start.line,
column=position.start.column,
node=node,
replacement=replacement
)
)

0 comments on commit 8846f5c

Please sign in to comment.