Patch summary (free text placed before the first "diff --git" header).

1. frontend/Python/graph/graph.py, displace_node:
   a user's _parents entry is rewritten to newnode.name only when node.name
   is actually present in user._parents.
   NOTE(review): the user.args rewrite on the following context line still
   calls user.args.index(node.name) without a guard; list.index raises
   ValueError when the value is absent -- confirm node.name is always in
   user.args at this point.
2. frontend/Python/graph/transform/fuse_ops.py, simply_fuse:
   adds a comment above the classic_fuse_check(graph) call.
3. tests/Python/test_permute+matmul.py (new file):
   imports foo (torch.ops.aten.permute followed by torch.matmul) through
   DynamoCompiler, applies simply_fuse, lowers the graph, prints the module
   and FileChecks the output for linalg.matmul_transpose_b.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
boundaries follow the +/- markers and the hunk line counts; the leading
indentation of hunk lines is reconstructed from Python syntax -- verify it
against the target files before applying.

diff --git a/frontend/Python/graph/graph.py b/frontend/Python/graph/graph.py
index d26e354697..a3fc3c5b6c 100644
--- a/frontend/Python/graph/graph.py
+++ b/frontend/Python/graph/graph.py
@@ -191,7 +191,8 @@ def displace_node(self,node: Op,newnode:Op):
             newnode.add_children(i)
         users = [self.node_table[i] for i in node._children]
         for user in users:
-            user._parents[user._parents.index(node.name)]=newnode.name
+            if node.name in user._parents:
+                user._parents[user._parents.index(node.name)]=newnode.name
             user.args[user.args.index(node.name)]=newnode.name
         node._children.clear()
         #deal with parents+args
diff --git a/frontend/Python/graph/transform/fuse_ops.py b/frontend/Python/graph/transform/fuse_ops.py
index e4b3f14bcd..f232b79242 100644
--- a/frontend/Python/graph/transform/fuse_ops.py
+++ b/frontend/Python/graph/transform/fuse_ops.py
@@ -94,6 +94,7 @@ def simply_fuse(graph: Graph):
     """
     new_op_group = []
     device = DeviceType.UNKNOW
+    #Run the first round of op fusion
     classic_fuse_check(graph)
 
     for op in graph.body:
diff --git a/tests/Python/test_permute+matmul.py b/tests/Python/test_permute+matmul.py
new file mode 100644
index 0000000000..644bcb8558
--- /dev/null
+++ b/tests/Python/test_permute+matmul.py
@@ -0,0 +1,40 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import torch
+import torch._dynamo as dynamo
+from torch._inductor.decomposition import decompositions as inductor_decomp
+from torch._functorch.aot_autograd import aot_autograd_decompositions
+
+from buddy.compiler.frontend import DynamoCompiler
+from buddy.compiler.ops import linalg
+from buddy.compiler.graph.transform import simply_fuse
+
+def foo(m1, m2,map):
+    tmp = torch.ops.aten.permute(m2,map)
+    return torch.matmul(m1,tmp)
+
+m1 = torch.ones([3, 4], dtype=torch.float32)
+m2 = torch.ones([3, 4], dtype=torch.float32)
+map = (1,0)
+# Initialize the dynamo compiler. 
+dynamo_compiler = DynamoCompiler(
+    primary_registry=linalg.ops_registry,
+    aot_autograd_decomposition=aot_autograd_decompositions,
+)
+
+graphs = dynamo_compiler.importer(foo, m1,m2,map)
+assert len(graphs) == 1
+graph = graphs[0]
+pattern_list = [simply_fuse]
+graphs[0].fuse_ops(pattern_list)
+
+graph.lower_to_top_level_ir()
+print(graph._imported_module)
+
+# CHECK: module {
+# CHECK-LABEL: func.func @forward
+# CHECK: %{{.*}} = arith.constant
+# CHECK: %{{.*}} = linalg.matmul_transpose_b
+# CHECK: return %{{.*}}
+# CHECK: }
+# CHECK: }