From 7994bd97d2d2dd97775172f74d9aa58c3acdc347 Mon Sep 17 00:00:00 2001 From: Sergii Dymchenko Date: Mon, 16 Sep 2024 15:35:52 -0700 Subject: [PATCH] Add TorchLog1pVisitor --- tests/fixtures/misc/checker/log1p.py | 6 ++++ tests/fixtures/misc/checker/log1p.txt | 4 +++ tests/test_torchfix.py | 5 ++- torchfix/torchfix.py | 8 +++-- torchfix/visitors/__init__.py | 13 +++++--- torchfix/visitors/misc/__init__.py | 44 +++++++++++++++++++++++++++ 6 files changed, 72 insertions(+), 8 deletions(-) create mode 100644 tests/fixtures/misc/checker/log1p.py create mode 100644 tests/fixtures/misc/checker/log1p.txt diff --git a/tests/fixtures/misc/checker/log1p.py b/tests/fixtures/misc/checker/log1p.py new file mode 100644 index 0000000..6ba8df6 --- /dev/null +++ b/tests/fixtures/misc/checker/log1p.py @@ -0,0 +1,6 @@ +import torch +a = torch.randn(5) +b = torch.log(1 + a) +c = torch.log(a + 1) +b = torch.log(1.0 + a) +c = torch.log(a + 1.0) diff --git a/tests/fixtures/misc/checker/log1p.txt b/tests/fixtures/misc/checker/log1p.txt new file mode 100644 index 0000000..3bcbeac --- /dev/null +++ b/tests/fixtures/misc/checker/log1p.txt @@ -0,0 +1,4 @@ +3:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`. +4:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`. +5:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`. +6:5 TOR106 Use `torch.log1p(x)` instead of `torch.log(1 + x)`. It is more accurate for small values of `x`. diff --git a/tests/test_torchfix.py b/tests/test_torchfix.py index 56fd05c..29f6cf9 100644 --- a/tests/test_torchfix.py +++ b/tests/test_torchfix.py @@ -35,7 +35,10 @@ def pytest_generate_tests(metafunc): ("ALL,TOR102", GET_ALL_ERROR_CODES()), ("TOR102", {"TOR102"}), ("TOR102,TOR101", {"TOR102", "TOR101"}), - ("TOR1,TOR102", {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105"}), + ( + "TOR1,TOR102", + {"TOR102", "TOR101", "TOR103", "TOR104", "TOR105", "TOR106"}, + ), (None, set(GET_ALL_ERROR_CODES()) - exclude_set), ] metafunc.parametrize("case,expected", cases, ids=[case for case, _ in cases]) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 24e88b5..80acb1c 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -9,6 +9,7 @@ from .visitors import ( TorchDeprecatedSymbolsVisitor, + TorchLog1pVisitor, TorchNonPublicAliasVisitor, TorchReentrantCheckpointVisitor, TorchRequireGradVisitor, @@ -28,15 +29,16 @@ ALL_VISITOR_CLS = [ TorchDeprecatedSymbolsVisitor, + TorchLog1pVisitor, + TorchNonPublicAliasVisitor, TorchRequireGradVisitor, + TorchReentrantCheckpointVisitor, TorchScopedLibraryVisitor, TorchSynchronizedDataLoaderVisitor, + TorchUnsafeLoadVisitor, TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, TorchVisionSingletonImportVisitor, - TorchUnsafeLoadVisitor, - TorchReentrantCheckpointVisitor, - TorchNonPublicAliasVisitor, ] diff --git a/torchfix/visitors/__init__.py b/torchfix/visitors/__init__.py index af2b62b..f63e405 100644 --- a/torchfix/visitors/__init__.py +++ b/torchfix/visitors/__init__.py @@ -1,6 +1,10 @@ from .deprecated_symbols import TorchDeprecatedSymbolsVisitor from .internal import TorchScopedLibraryVisitor -from .misc import TorchReentrantCheckpointVisitor, TorchRequireGradVisitor +from .misc import ( + TorchReentrantCheckpointVisitor, + TorchRequireGradVisitor, + TorchLog1pVisitor, +) from .nonpublic import TorchNonPublicAliasVisitor from .performance import TorchSynchronizedDataLoaderVisitor from .security import TorchUnsafeLoadVisitor @@ -12,13 +16,14 @@ __all__ = [ "TorchDeprecatedSymbolsVisitor", + "TorchLog1pVisitor", + "TorchNonPublicAliasVisitor", + "TorchReentrantCheckpointVisitor", "TorchRequireGradVisitor", "TorchScopedLibraryVisitor", "TorchSynchronizedDataLoaderVisitor", + "TorchUnsafeLoadVisitor", "TorchVisionDeprecatedPretrainedVisitor", "TorchVisionDeprecatedToTensorVisitor", "TorchVisionSingletonImportVisitor", - "TorchUnsafeLoadVisitor", - "TorchReentrantCheckpointVisitor", - "TorchNonPublicAliasVisitor", ] diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index ea8c7be..edc6809 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -77,3 +77,47 @@ def visit_Call(self, node): message=self.ERRORS[0].message(), replacement=replacement, ) + + +class TorchLog1pVisitor(TorchVisitor): + """ + Suggest using `torch.log1p(x)` instead of `torch.log(1 + x)`. + """ + + ERRORS = [ + TorchError( + "TOR106", + ( + "Use `torch.log1p(x)` instead of `torch.log(1 + x)`. " + "It is more accurate for small values of `x`." + ), + ) + ] + + def visit_Call(self, node): + if self.get_qualified_name_for_call(node) == "torch.log": + + if m.matches( + node, + m.Call( + args=[ + m.Arg( + value=m.BinaryOperation( + left=m.Integer(value="1") | m.Float(value="1.0"), + operator=m.Add(), + ) + | m.BinaryOperation( + operator=m.Add(), + right=m.Integer(value="1") | m.Float(value="1.0"), + ), + ), + ], + ), + ): + + self.add_violation( + node, + error_code=self.ERRORS[0].error_code, + message=self.ERRORS[0].message(), + replacement=None, + )