From 94bae36a1f62b1e8c9ccba7f08e6681b68ec8fd3 Mon Sep 17 00:00:00 2001 From: lantiankaikai Date: Thu, 30 Mar 2023 17:46:10 +0000 Subject: [PATCH] Fix strip_function_call in GuardBuilder (#97810) repo: from #92670 this address one of the bug for TorchDynamo pytest ./generated/test_PeterouZh_CIPS_3D.py -k test_003 Issue: In GuardBuilder, when parsing argnames with "getattr(a.layers[slice(2)][0]._abc, '0')" it returns "getattr(a", where it suppose to return "a", and thus causing SyntaxError. This PR fix the regex and add couple test cases. Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/97810 Approved by: https://github.com/yanboliang --- test/dynamo/test_misc.py | 15 +++++++++++++++ torch/_dynamo/guards.py | 19 ++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index a1ad5fd00c3eb..130f56fc43d27 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -5231,6 +5231,21 @@ def forward(self, input): prof.report() ) + def test_guards_strip_function_call(self): + from torch._dynamo.guards import strip_function_call + + test_case = [ + ("___odict_getitem(a, 1)", "a"), + ("a.layers[slice(2)][0]._xyz", "a"), + ("getattr(a.layers[slice(2)][0]._abc, '0')", "a"), + ("getattr(getattr(a.x[3], '0'), '3')", "a"), + ("a.layers[slice(None, -1, None)][0]._xyz", "a"), + ("a.layers[func('offset', -1, None)][0]._xyz", "a"), + ] + # strip_function_call should extract the object from the string. + for name, expect_obj in test_case: + self.assertEqual(strip_function_call(name), expect_obj) + class CustomFunc1(torch.autograd.Function): @staticmethod diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index c16c48515857a..42b653e88cc10 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -67,10 +67,23 @@ def strip_function_call(name): """ "___odict_getitem(a, 1)" => "a" + "a.layers[slice(2)][0]._xyz" ==> "a" + "getattr(a.layers[slice(2)][0]._abc, '0')" ==> "a" + "getattr(getattr(a.x[3], '0'), '3')" ==> "a" + "a.layers[slice(None, -1, None)][0]._xyz" ==> "a" """ - m = re.search(r"([a-z0-9_]+)\(([^(),]+)[^()]*\)", name) - if m and m.group(1) != "slice": - return strip_function_call(m.group(2)) + # recursively find valid object name in fuction + valid_name = re.compile("[A-Za-z_].*") + curr = "" + for char in name: + if char in " (": + curr = "" + elif char in "),[]": + if curr and curr != "None" and valid_name.match(curr): + return strip_function_call(curr) + else: + curr += char + return strip_getattr_getitem(name)