-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Issue 7] Update import torchvision.models as models (#26)
* [Issue 7] Update import torchvision.models as models * Move torchvision.models visitor to vision dir * Move torchvision.models visitor to vision dir
- Loading branch information
Showing
5 changed files
with
49 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import torchvision.models as models | ||
import torchvision.models as cnn | ||
from torchvision.models import resnet50, resnet101 | ||
import torchvision.models | ||
from torchvision.models import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
1:1 TOR203 Consider replacing 'import torchvision.models as models' with 'from torchvision import models'. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .pretrained import TorchVisionDeprecatedPretrainedVisitor # noqa: F401 | ||
from .to_tensor import TorchVisionDeprecatedToTensorVisitor # noqa: F401 | ||
from .models_import import TorchVisionModelsImportVisitor # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import libcst as cst | ||
|
||
from ...common import LintViolation, TorchVisitor | ||
|
||
|
||
class TorchVisionModelsImportVisitor(TorchVisitor): | ||
ERROR_CODE = "TOR203" | ||
|
||
def visit_Import(self, node: cst.Import) -> None: | ||
for imported_item in node.names: | ||
if isinstance(imported_item.name, cst.Attribute): | ||
if ( | ||
isinstance(imported_item.name.value, cst.Name) | ||
and imported_item.name.value.value == "torchvision" | ||
and isinstance(imported_item.name.attr, cst.Name) | ||
and imported_item.name.attr.value == "models" | ||
and imported_item.asname is not None | ||
and isinstance(imported_item.asname.name, cst.Name) | ||
and imported_item.asname.name.value == "models" | ||
): | ||
position = self.get_metadata( | ||
cst.metadata.WhitespaceInclusivePositionProvider, node | ||
) | ||
replacement = cst.ImportFrom( | ||
module=cst.Name("torchvision"), | ||
names=[cst.ImportAlias(name=cst.Name("models"))], | ||
) | ||
self.violations.append( | ||
LintViolation( | ||
error_code=self.ERROR_CODE, | ||
message=( | ||
"Consider replacing 'import torchvision.models as" | ||
" models' with 'from torchvision import models'." | ||
), | ||
line=position.start.line, | ||
column=position.start.column, | ||
node=node, | ||
replacement=replacement | ||
) | ||
) |