Skip to content

Commit

Permalink
Fix strip_function_call in GuardBuilder (pytorch#97810)
Browse files Browse the repository at this point in the history
repo:
from pytorch#92670 this address one of the bug for TorchDynamo

pytest ./generated/test_PeterouZh_CIPS_3D.py -k test_003

Issue:
In GuardBuilder, when parsing argnames with "getattr(a.layers[slice(2)][0]._abc, '0')" it returns "getattr(a", where it suppose to return "a", and thus causing SyntaxError.

This PR fix the regex and add couple test cases.

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#97810
Approved by: https://github.com/yanboliang
  • Loading branch information
lantiankaikai authored and pytorchmergebot committed Mar 30, 2023
1 parent ffd76d1 commit 94bae36
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
15 changes: 15 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5231,6 +5231,21 @@ def forward(self, input):
prof.report()
)

def test_guards_strip_function_call(self):
from torch._dynamo.guards import strip_function_call

test_case = [
("___odict_getitem(a, 1)", "a"),
("a.layers[slice(2)][0]._xyz", "a"),
("getattr(a.layers[slice(2)][0]._abc, '0')", "a"),
("getattr(getattr(a.x[3], '0'), '3')", "a"),
("a.layers[slice(None, -1, None)][0]._xyz", "a"),
("a.layers[func('offset', -1, None)][0]._xyz", "a"),
]
# strip_function_call should extract the object from the string.
for name, expect_obj in test_case:
self.assertEqual(strip_function_call(name), expect_obj)


class CustomFunc1(torch.autograd.Function):
@staticmethod
Expand Down
19 changes: 16 additions & 3 deletions torch/_dynamo/guards.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,23 @@
def strip_function_call(name):
"""
"___odict_getitem(a, 1)" => "a"
"a.layers[slice(2)][0]._xyz" ==> "a"
"getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a"
"getattr(getattr(a.x[3], '0'), '3')" ==> "a"
"a.layers[slice(None, -1, None)][0]._xyz" ==> "a"
"""
m = re.search(r"([a-z0-9_]+)\(([^(),]+)[^()]*\)", name)
if m and m.group(1) != "slice":
return strip_function_call(m.group(2))
# recursively find valid object name in fuction
valid_name = re.compile("[A-Za-z_].*")
curr = ""
for char in name:
if char in " (":
curr = ""
elif char in "),[]":
if curr and curr != "None" and valid_name.match(curr):
return strip_function_call(curr)
else:
curr += char

return strip_getattr_getitem(name)


Expand Down

0 comments on commit 94bae36

Please sign in to comment.