Skip to content

Commit

Permalink
add test for fusion--permute+matmul->linalg.matmul_transpose_b
Browse files Browse the repository at this point in the history
  • Loading branch information
“username” committed Nov 6, 2024
1 parent ef35634 commit 3284a8a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 1 deletion.
3 changes: 2 additions & 1 deletion frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def displace_node(self,node: Op,newnode:Op):
newnode.add_children(i)
users = [self.node_table[i] for i in node._children]
for user in users:
user._parents[user._parents.index(node.name)]=newnode.name
if node.name in user._parents:
user._parents[user._parents.index(node.name)]=newnode.name
user.args[user.args.index(node.name)]=newnode.name
node._children.clear()
#deal with parents+args
Expand Down
1 change: 1 addition & 0 deletions frontend/Python/graph/transform/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def simply_fuse(graph: Graph):
"""
new_op_group = []
device = DeviceType.UNKNOW

#Run the first round of op fusion
classic_fuse_check(graph)
for op in graph.body:
Expand Down
40 changes: 40 additions & 0 deletions tests/Python/test_permute+matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import torch
import torch._dynamo as dynamo
from torch._inductor.decomposition import decompositions as inductor_decomp
from torch._functorch.aot_autograd import aot_autograd_decompositions

from buddy.compiler.frontend import DynamoCompiler
from buddy.compiler.ops import linalg
from buddy.compiler.graph.transform import simply_fuse

def foo(m1, m2,map):
tmp = torch.ops.aten.permute(m2,map)
return torch.matmul(m1,tmp)

m1 = torch.ones([3, 4], dtype=torch.float32)
m2 = torch.ones([3, 4], dtype=torch.float32)
map = (1,0)
# Initialize the dynamo compiler.
dynamo_compiler = DynamoCompiler(
primary_registry=linalg.ops_registry,
aot_autograd_decomposition=aot_autograd_decompositions,
)

graphs = dynamo_compiler.importer(foo, m1,m2,map)
assert len(graphs) == 1
graph = graphs[0]
pattern_list = [simply_fuse]
graphs[0].fuse_ops(pattern_list)

graph.lower_to_top_level_ir()
print(graph._imported_module)

# CHECK: module {
# CHECK-LABEL: func.func @forward
# CHECK: %{{.*}} = arith.constant
# CHECK: %{{.*}} = linalg.matmul_transpose_b
# CHECK: return %{{.*}}
# CHECK: }
# CHECK: }

0 comments on commit 3284a8a

Please sign in to comment.