Skip to content

Commit

Permalink
Add TorchNonPublicAliasVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
kit1980 committed Mar 15, 2024
1 parent 40e021a commit a0ec0aa
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 1 deletion.
11 changes: 11 additions & 0 deletions tests/fixtures/nonpublic/checker/default_collate_convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from torch.utils.data import _utils
batch = _utils.collate.default_collate(batch)

from torch.utils.data._utils.collate import default_collate
inputs, labels, video_idx = default_collate(inputs), default_collate(labels), default_collate(video_idx)

from torch.utils.data._utils.collate import default_convert
values = default_convert(values)

import torch
values = torch.utils.data._utils.collate.default_convert(values)
8 changes: 8 additions & 0 deletions tests/fixtures/nonpublic/checker/default_collate_convert.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
2:9 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
4:1 TOR105 Import of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
5:29 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
5:54 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
5:79 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_collate`, please use `torch.utils.data.dataloader.default_collate` instead
7:1 TOR105 Import of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead
8:10 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead
11:10 TOR104 Use of non-public function `torch.utils.data._utils.collate.default_convert`, please use `torch.utils.data.dataloader.default_convert` instead
2 changes: 1 addition & 1 deletion tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def test_parse_error_code_str():
("ALL,TOR102", GET_ALL_ERROR_CODES()),
("TOR102", {"TOR102"}),
("TOR102,TOR101", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101", "TOR104", "TOR105"}),
(None, GET_ALL_ERROR_CODES() - exclude_set),
]
for case, expected in cases:
Expand Down
21 changes: 21 additions & 0 deletions torchfix/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,27 @@ def get_specific_arg(
return arg
return None

def add_violation(
self,
node: cst.CSTNode,
error_code: str,
message: str,
replacement: Optional[cst.CSTNode] = None,
) -> None:
position_metadata = self.get_metadata(
cst.metadata.WhitespaceInclusivePositionProvider, node
)
self.violations.append(
LintViolation(
error_code=error_code,
message=message,
line=position_metadata.start.line,
column=position_metadata.start.column,
node=node,
replacement=replacement,
)
)

def get_qualified_name_for_call(self, node: cst.Call) -> Optional[str]:
# Guard against situations like `vmap(a)(b)`:
#
Expand Down
2 changes: 2 additions & 0 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .visitors.performance import TorchSynchronizedDataLoaderVisitor
from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor)
from .visitors.nonpublic import TorchNonPublicAliasVisitor

from .visitors.vision import (
TorchVisionDeprecatedPretrainedVisitor,
Expand All @@ -39,6 +40,7 @@
TorchVisionModelsImportVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
TorchNonPublicAliasVisitor,
]


Expand Down
51 changes: 51 additions & 0 deletions torchfix/visitors/nonpublic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Sequence

import libcst as cst
from ...common import TorchVisitor


class TorchNonPublicAliasVisitor(TorchVisitor):
"""
Suggest to use public APIs instead of non-public aliases.
Currently implemented for
torch.utils.data._utils.collate.default_collate and
torch.utils.data._utils.collate.default_convert,
see https://github.com/pytorch/pytorch/pull/69862/files
"""

ERROR_CODE = ["TOR104", "TOR105"]

# fmt: off
ALIASES = {
"torch.utils.data._utils.collate.default_collate": "torch.utils.data.dataloader.default_collate", # noqa: E501
"torch.utils.data._utils.collate.default_convert": "torch.utils.data.dataloader.default_convert", # noqa: E501
}
# fmt: on

def visit_Call(self, node):
qualified_name = self.get_qualified_name_for_call(node)
if qualified_name is None:
return

if qualified_name in self.ALIASES:
public_name = self.ALIASES[qualified_name]
error_code = self.ERROR_CODE[0]
message = f"Use of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
self.add_violation(node, error_code=error_code, message=message)

def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
if node.module is None:
return

module = cst.helpers.get_full_name_for_node(node.module)
if not isinstance(node.names, Sequence):
return

for name in node.names:
qualified_name = f"{module}.{name.name.value}"
if qualified_name in self.ALIASES:
public_name = self.ALIASES[qualified_name]
error_code = self.ERROR_CODE[1]
message = f"Import of non-public function `{qualified_name}`, please use `{public_name}` instead" # noqa: E501
self.add_violation(node, error_code=error_code, message=message)

0 comments on commit a0ec0aa

Please sign in to comment.