diff --git a/tests/fixtures/misc/checker/reentrant_checkpoint.py b/tests/fixtures/misc/checker/reentrant_checkpoint.py new file mode 100644 index 0000000..938a41f --- /dev/null +++ b/tests/fixtures/misc/checker/reentrant_checkpoint.py @@ -0,0 +1,13 @@ +import torch +def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) + +import torch.utils.checkpoint +def fn(x, y): + return checkpoint(gn, torch.sin(x), y) + return checkpoint(gn, torch.sin(x), y, use_reentrant=False) + +from torch.utils.checkpoint import checkpoint +def fn(x, y): + return checkpoint(gn, torch.sin(x), y) + return checkpoint(gn, torch.sin(x), y, use_reentrant=True) diff --git a/tests/fixtures/misc/checker/reentrant_checkpoint.txt b/tests/fixtures/misc/checker/reentrant_checkpoint.txt new file mode 100644 index 0000000..af867d6 --- /dev/null +++ b/tests/fixtures/misc/checker/reentrant_checkpoint.txt @@ -0,0 +1,2 @@ +7:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`. +12:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`. diff --git a/tests/fixtures/misc/codemod/reentrant_checkpoint.py b/tests/fixtures/misc/codemod/reentrant_checkpoint.py new file mode 100644 index 0000000..3d0051d --- /dev/null +++ b/tests/fixtures/misc/codemod/reentrant_checkpoint.py @@ -0,0 +1,6 @@ +import torch +from torch.utils.checkpoint import checkpoint +def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) +def fn(x, y): + return checkpoint(gn, torch.sin(x), y) diff --git a/tests/fixtures/misc/codemod/reentrant_checkpoint.py.out b/tests/fixtures/misc/codemod/reentrant_checkpoint.py.out new file mode 100644 index 0000000..57c69b7 --- /dev/null +++ b/tests/fixtures/misc/codemod/reentrant_checkpoint.py.out @@ -0,0 +1,6 @@ +import torch +from torch.utils.checkpoint import checkpoint +def gn(x, y): + return torch.sigmoid(torch.matmul(x, y)) +def fn(x, y): + return checkpoint(gn, torch.sin(x), y, use_reentrant=False) diff --git a/torchfix/torchfix.py b/torchfix/torchfix.py index 1a47e20..c0e7da9 100644 --- a/torchfix/torchfix.py +++ b/torchfix/torchfix.py @@ -11,7 +11,8 @@ ) from .visitors.performance import TorchSynchronizedDataLoaderVisitor -from .visitors.misc import TorchRequireGradVisitor +from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor) + from .visitors.vision import ( TorchVisionDeprecatedPretrainedVisitor, TorchVisionDeprecatedToTensorVisitor, @@ -33,6 +34,7 @@ def GET_ALL_VISITORS(): TorchVisionDeprecatedPretrainedVisitor(), TorchVisionDeprecatedToTensorVisitor(), TorchUnsafeLoadVisitor(), + TorchReentrantCheckpointVisitor(), ] diff --git a/torchfix/visitors/misc/__init__.py b/torchfix/visitors/misc/__init__.py index ef83d91..6ce7c84 100644 --- a/torchfix/visitors/misc/__init__.py +++ b/torchfix/visitors/misc/__init__.py @@ -46,3 +46,44 @@ def visit_Assign(self, node): replacement=replacement, ) ) + + +class TorchReentrantCheckpointVisitor(TorchVisitor): + """ + Find and fix common misuse of reentrant checkpoints. + """ + + ERROR_CODE = "TOR003" + MESSAGE = ( + "Please pass `use_reentrant` explicitly to `checkpoint`. " + "To maintain old behavior, pass `use_reentrant=True`. " + "It is recommended to use `use_reentrant=False`." + ) + + def visit_Call(self, node): + qualified_name = self.get_qualified_name_for_call(node) + if qualified_name == "torch.utils.checkpoint.checkpoint": + use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1) + if use_reentrant_arg is None: + position_metadata = self.get_metadata( + cst.metadata.WhitespaceInclusivePositionProvider, node + ) + + # This codemod maybe unsafe correctness-wise + # if reentrant behavior is actually needed, + # so the changes need to be verified/tested. + use_reentrant_arg = cst.ensure_type( + cst.parse_expression("f(use_reentrant=False)"), cst.Call + ).args[0] + replacement = node.with_changes(args=node.args + (use_reentrant_arg,)) + + self.violations.append( + LintViolation( + error_code=self.ERROR_CODE, + message=self.MESSAGE, + line=position_metadata.start.line, + column=position_metadata.start.column, + node=node, + replacement=replacement, + ) + )