Skip to content

Commit

Permalink
Refactored each import item from node using libcst.matchers.matches (#37
Browse files Browse the repository at this point in the history
)

* refactored each import item from node using libcst.matchers.matches

* add proper import statement

* remove comment

* fix code style with prettier indent

* prettify the code by removing more indent

* prettify the code

---------

Co-authored-by: Francesca Wang <[email protected]>
  • Loading branch information
Francescaaa and Francesca Wang authored Apr 3, 2024
1 parent c909236 commit 3387070
Showing 1 changed file with 18 additions and 23 deletions.
41 changes: 18 additions & 23 deletions torchfix/visitors/vision/models_import.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import libcst as cst
import libcst.matchers as m

from ...common import TorchVisitor

Expand All @@ -13,27 +14,21 @@ class TorchVisionModelsImportVisitor(TorchVisitor):
def visit_Import(self, node: cst.Import) -> None:
replacement = None
for imported_item in node.names:
if isinstance(imported_item.name, cst.Attribute):
# TODO refactor using libcst.matchers.matches
if (
isinstance(imported_item.name.value, cst.Name)
and imported_item.name.value.value == "torchvision"
and isinstance(imported_item.name.attr, cst.Name)
and imported_item.name.attr.value == "models"
and imported_item.asname is not None
and isinstance(imported_item.asname.name, cst.Name)
and imported_item.asname.name.value == "models"
):
# Replace only if the import statement has no other names
if len(node.names) == 1:
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
if m.matches(imported_item, m.ImportAlias(
name=m.Attribute(value=m.Name("torchvision"),
attr=m.Name("models")),
asname=m.AsName(name=m.Name("models"))
)):
# Replace only if the import statement has no other names
if len(node.names) == 1:
replacement = cst.ImportFrom(
module=cst.Name("torchvision"),
names=[cst.ImportAlias(name=cst.Name("models"))],
)
break
self.add_violation(
node,
error_code=self.ERROR_CODE,
message=self.MESSAGE,
replacement=replacement,
)
break

0 comments on commit 3387070

Please sign in to comment.