From a0ec0aa0601ec5a01a979cd054391be014c60fae Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Thu, 14 Mar 2024 17:40:09 -0700 Subject: [PATCH] Add TorchNonPublicAliasVisitor --- .../checker/default_collate_convert.py | 11 ++++ .../checker/default_collate_convert.txt | 8 +++ tests/test_torchfix.py | 2 +- torchfix/common.py | 21 ++++++++ torchfix/torchfix.py | 2 + torchfix/visitors/nonpublic/__init__.py | 51 +++++++++++++++++++ 6 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 tests/fixtures/nonpublic/checker/default_collate_convert.py create mode 100644 tests/fixtures/nonpublic/checker/default_collate_convert.txt create mode 100644 torchfix/visitors/nonpublic/__init__.py diff --git a/tests/fixtures/nonpublic/checker/default_collate_convert.py b/tests/fixtures/nonpublic/checker/default_collate_convert.py new file mode 100644 index 0000000..c7b7c65 --- /dev/null +++ b/tests/fixtures/nonpublic/checker/default_collate_convert.py @@ -0,0 +1,11 @@ +from torch.utils.data import _utils +batch = _utils.collate.default_collate(batch) + +from torch.utils.data._utils.collate import default_collate +inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx) + +from torch.utils.data._utils.collate import default_convert +values = default_convert(values) + +import torch +values = torch.utils.data._utils.collate.default_convert(values) diff --git a/tests/fixtures/nonpublic/checker/default_collate_convert.txt b/tests/fixtures/nonpublic/checker/default_collate_convert.txt new file mode 100644 index 0000000..edfc9a9 --- /dev/null +++ b/tests/fixtures/nonpublic/checker/default_collate_convert.txt @@ -0,0 +1,8 @@ +2:9 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +4:1 TOR105 Import of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +5:29 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +5:54 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +5:79 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead +7:1 TOR105 Import of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead +8:10 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead +11:10 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 6e7e0c6..7b9a051 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -75,7 +75,7 @@ def test_parse_error_code_str(): ("ALL,TOR102", GET_ALL_ERROR_CODES()), ("TOR102", {"TOR102"}), ("TOR102,TOR101", {"TOR102", "TOR101"}), - ("TOR1,TOR102", {"TOR102", "TOR101"}), + ("TOR1,TOR102", {"TOR102", "TOR101", "TOR104", "TOR105"}), (None, GET_ALL_ERROR_CODES() - exclude_set), ] for case, expected in cases: diff --git a/torchfix/common.py b/torchfix/common.py index b302346..7fdd00a 100644 --- a/torchfix/common.py +++ b/torchfix/common.py @@ -61,6 +61,27 @@ def get_specific_arg( return arg return None + def add_violation( + self, + node: cst.CSTNode, + error_code: str, + message: str, + replacement: Optional[cst.CSTNode] = None, + ) -> None: + position_metadata = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + self.violations.append( + LintViolation( + error_code=error_code, + message=message, + line=position_metadata.start.line, + column=position_metadata.start.column, + node=node, + replacement=replacement, + ) + ) + def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]: # Guard against situations like `vmap(a)(b)`: # diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index a38d81d..21a8994 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -15,6 +15,7 @@ from .visitors.performance import TorchSynchronizedDataLoaderVisitor from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) +from .visitors.nonpublic import TorchNonPublicAliasVisitor from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, @@ -39,6 +40,7 @@ TorchVisionModelsImportVisitor, TorchUnsafeLoadVisitor, TorchReentrantCheckpointVisitor, + TorchNonPublicAliasVisitor, ] diff --git a/torchfix/visitors/nonpublic/__init__.py b/torchfix/visitors/nonpublic/__init__.py new file mode 100644 index 0000000..0b8318f --- /dev/null +++ b/torchfix/visitors/nonpublic/__init__.py @@ -0,0 +1,51 @@ +from typing import Sequence + +import libcst as cst +from ...common import TorchVisitor + + +class TorchNonPublicAliasVisitor(TorchVisitor): + """ + Suggest to use public APIs instead of non-public aliases. + + Currently implemented for + torch.utils.data._utils.collate.default_collate and + torch.utils.data._utils.collate.default_convert, + see https://github.com/pytorch/pytorch/pull/69862/files + """ + + ERROR_CODE = ["TOR104", "TOR105"] + + # fmt: off + ALIASES = { + "torch.utils.data._utils.collate.default_collate": "torch.utils.data.dataloader.default_collate", # noqa: E501 + "torch.utils.data._utils.collate.default_convert": "torch.utils.data.dataloader.default_convert", # noqa: E501 + } + # fmt: on + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + if qualified_name is None: + return + + if qualified_name in self.ALIASES: + public_name = self.ALIASES[qualified_name] + error_code = self.ERROR_CODE[0] + message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 + self.add_violation(node, error_code=error_code, message=message) + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + if node.module is None: + return + + module = cst.helpers.get_full_name_for_node(node.module) + if not isinstance(node.names, Sequence): + return + + for name in node.names: + qualified_name = f"{module}.{name.name.value}" + if qualified_name in self.ALIASES: + public_name = self.ALIASES[qualified_name] + error_code = self.ERROR_CODE[1] + message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501 + self.add_violation(node, error_code=error_code, message=message)