From 3915c59ef92cfe04c5bdbeb25febb7de970221de Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Fri, 2 Feb 2024 17:07:04 -0800 Subject: [PATCH] TorchScopedLibraryVisitor --- .../internal/checker/scoped_library.py | 3 ++ .../internal/checker/scoped_library.txt | 1 + torchfix/torchfix.py | 5 ++- torchfix/visitors/internal/__init__.py | 33 +++++++++++++++++++ 4 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/internal/checker/scoped_library.py create mode 100644 tests/fixtures/internal/checker/scoped_library.txt create mode 100644 torchfix/visitors/internal/__init__.py diff --git a/tests/fixtures/internal/checker/scoped_library.py b/tests/fixtures/internal/checker/scoped_library.py new file mode 100644 index 0000000..54d5316 --- /dev/null +++ b/tests/fixtures/internal/checker/scoped_library.py @@ -0,0 +1,3 @@ +import torch +from torch.library import Library, impl, fallthrough_kernel +my_lib1 = Library("aten", "IMPL") diff --git a/tests/fixtures/internal/checker/scoped_library.txt b/tests/fixtures/internal/checker/scoped_library.txt new file mode 100644 index 0000000..1f1e7f8 --- /dev/null +++ b/tests/fixtures/internal/checker/scoped_library.txt @@ -0,0 +1 @@ +3:11 TOR901 Use `torch.library._scoped_library` instead of `torch.library.Library` in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 for details. diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index d1d648d..e6d01d1 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -11,6 +11,8 @@ _UpdateFunctorchImports, ) +from .visitors.internal import TorchScopedLibraryVisitor + from .visitors.performance import TorchSynchronizedDataLoaderVisitor from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) @@ -24,11 +26,12 @@ DEPRECATED_CONFIG_PATH = Path(__file__).absolute().parent / "deprecated_symbols.yaml" -DISABLED_BY_DEFAULT = ["TOR3", "TOR4"] +DISABLED_BY_DEFAULT = ["TOR3", "TOR4", "TOR9"] ALL_VISITOR_CLS = [ TorchDeprecatedSymbolsVisitor, TorchRequireGradVisitor, + TorchScopedLibraryVisitor, TorchSynchronizedDataLoaderVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, diff --git a/torchfix/visitors/internal/__init__.py b/torchfix/visitors/internal/__init__.py new file mode 100644 index 0000000..424e1f2 --- /dev/null +++ b/torchfix/visitors/internal/__init__.py @@ -0,0 +1,33 @@ +import libcst as cst +from ...common import TorchVisitor, LintViolation + + +class TorchScopedLibraryVisitor(TorchVisitor): + """ + Suggest `torch.library._scoped_library` for PyTorch tests. + """ + + ERROR_CODE = "TOR901" + MESSAGE = ( + "Use `torch.library._scoped_library` instead of `torch.library.Library` " + "in PyTorch tests files. See https://github.com/pytorch/pytorch/pull/118318 " + "for details." + ) + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + if qualified_name == "torch.library.Library": + position_metadata = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + + self.violations.append( + LintViolation( + error_code=self.ERROR_CODE, + message=self.MESSAGE, + line=position_metadata.start.line, + column=position_metadata.start.column, + node=node, + replacement=None, + ) + )