Commit

Add a rule for use_reentrant with checkpoint (#7)
kit1980 authored Dec 26, 2023
1 parent 4748cfb commit 03ea18c
Showing 6 changed files with 71 additions and 1 deletion.
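
For orientation, here is a minimal sketch (illustrative only, not part of the commit) of the pattern the new TOR003 rule targets and the fix it suggests; the function and tensor names are made up for this example.

import torch
from torch.utils.checkpoint import checkpoint

def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))

x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)

# Flagged by TOR003: `use_reentrant` is not passed explicitly.
out = checkpoint(gn, torch.sin(x), y)

# Suggested fix: pass `use_reentrant` explicitly; `False` is the recommended value.
out = checkpoint(gn, torch.sin(x), y, use_reentrant=False)
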
13 changes: 13 additions & 0 deletions tests/fixtures/misc/checker/reentrant_checkpoint.py
@@ -0,0 +1,13 @@
import torch
def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))

import torch.utils.checkpoint
def fn(x, y):
    return checkpoint(gn, torch.sin(x), y)
    return checkpoint(gn, torch.sin(x), y, use_reentrant=False)

from torch.utils.checkpoint import checkpoint
def fn(x, y):
    return checkpoint(gn, torch.sin(x), y)
    return checkpoint(gn, torch.sin(x), y, use_reentrant=True)
2 changes: 2 additions & 0 deletions tests/fixtures/misc/checker/reentrant_checkpoint.txt
@@ -0,0 +1,2 @@
7:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`.
12:12 TOR003 Please pass `use_reentrant` explicitly to `checkpoint`. To maintain old behavior, pass `use_reentrant=True`. It is recommended to use `use_reentrant=False`.
6 changes: 6 additions & 0 deletions tests/fixtures/misc/codemod/reentrant_checkpoint.py
@@ -0,0 +1,6 @@
import torch
from torch.utils.checkpoint import checkpoint
def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
    return checkpoint(gn, torch.sin(x), y)
6 changes: 6 additions & 0 deletions tests/fixtures/misc/codemod/reentrant_checkpoint.py.out
@@ -0,0 +1,6 @@
import torch
from torch.utils.checkpoint import checkpoint
def gn(x, y):
    return torch.sigmoid(torch.matmul(x, y))
def fn(x, y):
    return checkpoint(gn, torch.sin(x), y, use_reentrant=False)
4 changes: 3 additions & 1 deletion torchfix/torchfix.py
@@ -11,7 +11,8 @@
)

from .visitors.performance import TorchSynchronizedDataLoaderVisitor
-from .visitors.misc import TorchRequireGradVisitor
+from .visitors.misc import (TorchRequireGradVisitor, TorchReentrantCheckpointVisitor)

from .visitors.vision import (
    TorchVisionDeprecatedPretrainedVisitor,
    TorchVisionDeprecatedToTensorVisitor,
@@ -33,6 +34,7 @@ def GET_ALL_VISITORS():
        TorchVisionDeprecatedPretrainedVisitor(),
        TorchVisionDeprecatedToTensorVisitor(),
        TorchUnsafeLoadVisitor(),
+        TorchReentrantCheckpointVisitor(),
    ]


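
As a rough usage sketch (not from this commit), the registered visitors could be run over a source string with libcst's batched visiting. This assumes GET_ALL_VISITORS is importable from torchfix.torchfix as this diff suggests, that the visitors are libcst batchable visitors, and that each one collects LintViolation objects with line, column, error_code, and message fields in a violations list, as the visitor code below does.

import libcst as cst

from torchfix.torchfix import GET_ALL_VISITORS

source = (
    "from torch.utils.checkpoint import checkpoint\n"
    "result = checkpoint(gn, x, y)\n"
)
module = cst.parse_module(source)
visitors = GET_ALL_VISITORS()

# MetadataWrapper resolves the metadata the visitors declare (qualified names,
# positions) and runs them all in a single batched traversal.
cst.MetadataWrapper(module).visit_batched(visitors)

for visitor in visitors:
    for violation in visitor.violations:
        print(
            f"{violation.line}:{violation.column} "
            f"{violation.error_code} {violation.message}"
        )
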
41 changes: 41 additions & 0 deletions torchfix/visitors/misc/__init__.py
@@ -46,3 +46,44 @@ def visit_Assign(self, node):
                    replacement=replacement,
                )
            )


class TorchReentrantCheckpointVisitor(TorchVisitor):
    """
    Find and fix common misuse of reentrant checkpoints.
    """

    ERROR_CODE = "TOR003"
    MESSAGE = (
        "Please pass `use_reentrant` explicitly to `checkpoint`. "
        "To maintain old behavior, pass `use_reentrant=True`. "
        "It is recommended to use `use_reentrant=False`."
    )

    def visit_Call(self, node):
        qualified_name = self.get_qualified_name_for_call(node)
        if qualified_name == "torch.utils.checkpoint.checkpoint":
            use_reentrant_arg = self.get_specific_arg(node, "use_reentrant", -1)
            if use_reentrant_arg is None:
                position_metadata = self.get_metadata(
                    cst.metadata.WhitespaceInclusivePositionProvider, node
                )

                # This codemod may be unsafe correctness-wise
                # if reentrant behavior is actually needed,
                # so the changes need to be verified/tested.
                use_reentrant_arg = cst.ensure_type(
                    cst.parse_expression("f(use_reentrant=False)"), cst.Call
                ).args[0]
                replacement = node.with_changes(args=node.args + (use_reentrant_arg,))

                self.violations.append(
                    LintViolation(
                        error_code=self.ERROR_CODE,
                        message=self.MESSAGE,
                        line=position_metadata.start.line,
                        column=position_metadata.start.column,
                        node=node,
                        replacement=replacement,
                    )
                )
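
For readers unfamiliar with libcst, here is a standalone sketch (using only public libcst APIs, independent of TorchFix's classes) of how the replacement above is composed: a keyword Arg node is harvested from a throwaway parsed call and appended to the flagged call's argument tuple, and libcst fills in the missing comma when rendering, as the codemod fixture's .out file shows.

import libcst as cst

# A call that lacks `use_reentrant`, mirroring the flagged pattern.
call = cst.ensure_type(
    cst.parse_expression("checkpoint(gn, torch.sin(x), y)"), cst.Call
)

# Parse a dummy call just to obtain a ready-made keyword-argument node,
# then append it to the existing argument tuple.
use_reentrant_arg = cst.ensure_type(
    cst.parse_expression("f(use_reentrant=False)"), cst.Call
).args[0]
fixed = call.with_changes(args=call.args + (use_reentrant_arg,))

print(cst.Module(body=[]).code_for_node(fixed))
# checkpoint(gn, torch.sin(x), y, use_reentrant=False)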
