Skip to content

Commit

Permalink
Add TorchLog1pVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
kit1980 committed Sep 16, 2024
1 parent aef3ea1 commit 7994bd9
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 8 deletions.
6 changes: 6 additions & 0 deletions tests/fixtures/misc/checker/log1p.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch
a = torch.randn(5)
b = torch.log(1 + a)
c = torch.log(a + 1)
b = torch.log(1.0 + a)
c = torch.log(a + 1.0)
4 changes: 4 additions & 0 deletions tests/fixtures/misc/checker/log1p.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
3:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.
4:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.
5:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.
6:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`.
5 changes: 4 additions & 1 deletion tests/test_torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def pytest_generate_tests(metafunc):
("ALL,TOR102", GET_ALL_ERROR_CODES()),
("TOR102", {"TOR102"}),
("TOR102,TOR101", {"TOR102", "TOR101"}),
("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}),
(
"TOR1,TOR102",
{"TOR102", "TOR101", "TOR103", "TOR104", "TOR105", "TOR106"},
),
(None, set(GET_ALL_ERROR_CODES()) - exclude_set),
]
metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases])
Expand Down
8 changes: 5 additions & 3 deletions torchfix/torchfix.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from .visitors import (
TorchDeprecatedSymbolsVisitor,
TorchLog1pVisitor,
TorchNonPublicAliasVisitor,
TorchReentrantCheckpointVisitor,
TorchRequireGradVisitor,
Expand All @@ -28,15 +29,16 @@

ALL_VISITOR_CLS = [
TorchDeprecatedSymbolsVisitor,
TorchLog1pVisitor,
TorchNonPublicAliasVisitor,
TorchRequireGradVisitor,
TorchReentrantCheckpointVisitor,
TorchScopedLibraryVisitor,
TorchSynchronizedDataLoaderVisitor,
TorchUnsafeLoadVisitor,
TorchVisionDeprecatedPretrainedVisitor,
TorchVisionDeprecatedToTensorVisitor,
TorchVisionSingletonImportVisitor,
TorchUnsafeLoadVisitor,
TorchReentrantCheckpointVisitor,
TorchNonPublicAliasVisitor,
]


Expand Down
13 changes: 9 additions & 4 deletions torchfix/visitors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from .deprecated_symbols import TorchDeprecatedSymbolsVisitor
from .internal import TorchScopedLibraryVisitor
from .misc import TorchReentrantCheckpointVisitor, TorchRequireGradVisitor
from .misc import (
TorchReentrantCheckpointVisitor,
TorchRequireGradVisitor,
TorchLog1pVisitor,
)
from .nonpublic import TorchNonPublicAliasVisitor
from .performance import TorchSynchronizedDataLoaderVisitor
from .security import TorchUnsafeLoadVisitor
Expand All @@ -12,13 +16,14 @@

__all__ = [
"TorchDeprecatedSymbolsVisitor",
"TorchLog1pVisitor",
"TorchNonPublicAliasVisitor",
"TorchReentrantCheckpointVisitor",
"TorchRequireGradVisitor",
"TorchScopedLibraryVisitor",
"TorchSynchronizedDataLoaderVisitor",
"TorchUnsafeLoadVisitor",
"TorchVisionDeprecatedPretrainedVisitor",
"TorchVisionDeprecatedToTensorVisitor",
"TorchVisionSingletonImportVisitor",
"TorchUnsafeLoadVisitor",
"TorchReentrantCheckpointVisitor",
"TorchNonPublicAliasVisitor",
]
44 changes: 44 additions & 0 deletions torchfix/visitors/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,47 @@ def visit_Call(self, node):
message=self.ERRORS[0].message(),
replacement=replacement,
)


class TorchLog1pVisitor(TorchVisitor):
"""
Suggest using `torch.log1p(x)` instead of `torch.log(1 + x)`.
"""

ERRORS = [
TorchError(
"TOR106",
(
"Use `torch.log1p(x)` instead of `torch.log(1 + x)`. "
"It is more accurate for small values of `x`."
),
)
]

def visit_Call(self, node):
if self.get_qualified_name_for_call(node) == "torch.log":

if m.matches(
node,
m.Call(
args=[
m.Arg(
value=m.BinaryOperation(
left=m.Integer(value="1") | m.Float(value="1.0"),
operator=m.Add(),
)
| m.BinaryOperation(
operator=m.Add(),
right=m.Integer(value="1") | m.Float(value="1.0"),
),
),
],
),
):

self.add_violation(
node,
error_code=self.ERRORS[0].error_code,
message=self.ERRORS[0].message(),
replacement=None,
)

0 comments on commit 7994bd9

Please sign in to comment.