Skip to content

Commit

Permalink
TorchScopedLibraryVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
kit1980 committed Feb 3, 2024
1 parent f219838 commit 3915c59
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 1 deletion.
3 changes: 3 additions & 0 deletions tests/fixtures/internal/checker/scoped_library.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import torch
from torch.library import Library, impl, fallthrough_kernel
my_lib1 = Library("aten", "IMPL")
1 change: 1 addition & 0 deletions tests/fixtures/internal/checker/scoped_library.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3:11 TOR901 Use `torch.library._scoped_library` instead of `torch.library.Library` in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 for details.
5 changes: 4 additions & 1 deletion torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
_UpdateFunctorchImports,
)

from .visitors.internal import TorchScopedLibraryVisitor

from .visitors.performance import TorchSynchronizedDataLoaderVisitor
from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor)

Expand All @@ -24,11 +26,12 @@

DEPRECATED_CONFIG_PATH = Path(__file__).absolute().parent / "deprecated_symbols.yaml"

DISABLED_BY_DEFAULT = ["TOR3", "TOR4"]
DISABLED_BY_DEFAULT = ["TOR3", "TOR4", "TOR9"]

ALL_VISITOR_CLS = [
TorchDeprecatedSymbolsVisitor,
TorchRequireGradVisitor,
TorchScopedLibraryVisitor,
TorchSynchronizedDataLoaderVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
Expand Down
33 changes: 33 additions & 0 deletions torchfix/visitors/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import libcst as cst
from ...common import TorchVisitor, LintViolation


class TorchScopedLibraryVisitor(TorchVisitor):
"""
Suggest `torch.library._scoped_library` for PyTorch tests.
"""

ERROR_CODE = "TOR901"
MESSAGE = (
"Use `torch.library._scoped_library` instead of `torch.library.Library` "
"in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 "
"for details."
)

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name == "torch.library.Library":
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)

self.violations.append(
LintViolation(
error_code=self.ERROR_CODE,
message=self.MESSAGE,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=None,
)
)

0 comments on commit 3915c59

Please sign in to comment.