Skip to content

Commit

Permalink
fuse TransposeOp + MatmulOp -> TransposeMatmulFusedOp
Browse files Browse the repository at this point in the history
  • Loading branch information
"username" committed Oct 23, 2024
1 parent 41784eb commit ef35634
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 0 deletions.
42 changes: 42 additions & 0 deletions frontend/Python/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,48 @@ def add_node(self, node: Op):
self._body.append(node)
self.node_table[node.name] = node

def check_deletenode(self, node : Op) -> bool:
if (not(node.name in self.node_table) ):
raise KeyError("node{0} not in graph".format(node.name))

if (len(node._children)==0):
return True
return False;

def delete_node(self, node: Op,parents : List[Op]):
for i in parents:
i._children.remove(node.name)
node.args.clear()
node.kwargs.clear()
node._children.clear()
self._body.remove(node)
self.node_table.pop(node.name)

def displace_node(self,node: Op,newnode:Op):
    """Replace *node* with *newnode* in the graph, rewiring all edges.

    The replacement inherits the old node's arguments, keyword
    arguments, tensor metadata and op type, takes over its children and
    parents, and replaces it in both the body list and the node table.

    Args:
        node: The operation currently in the graph to be replaced.
        newnode: The replacement operation (its .name must already be set).
    """
    # NOTE(review): writes the old node's public args/kwargs into the
    # private backing fields of newnode — assumes Op.args/Op.kwargs are
    # properties over _arguments/_keyword_arguments; confirm against Op.
    newnode._arguments = node.args
    newnode._keyword_arguments = node.kwargs
    newnode._tensor_meta = node.tensor_meta
    newnode._op_type = node._op_type

    # Rewire users: every child of the old node now points at newnode.
    for i in node._children:
        newnode.add_children(i)
    users = [self.node_table[i] for i in node._children]
    for user in users:
        user._parents[user._parents.index(node.name)]=newnode.name
        # assumes node.name appears positionally in user.args —
        # TODO(review): confirm a user can never reference it via kwargs only.
        user.args[user.args.index(node.name)]=newnode.name
    node._children.clear()
    #deal with parents+args
    for i in node._parents:
        newnode.add_parent(i)
    parents = [self.node_table[i] for i in node._parents]
    for parent in parents:
        parent._children[parent._children.index(node.name)]=newnode.name
    node._parents.clear()
    #update node table: same position in the body, old name dropped.
    self._body[self._body.index(node)] = newnode
    self.node_table.pop(node.name)
    self.node_table[newnode.name] = newnode

def init_op_group(self):
"""
Initializes operation groups within the graph.
Expand Down
6 changes: 6 additions & 0 deletions frontend/Python/graph/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,12 @@ def __init__(self) -> None:
self._op_type = OpType.ReduceType


class TransposeMatmulFusedOp(Op):
    """Fused transpose(B) + matmul operation produced by op fusion.

    Created by the classic-fusion pass when a matmul consumes a 2-D
    transpose; lowered via the "TransposeMatmulFusedOp" entry in the
    linalg ops registry.
    """
    def __init__(self) -> None:
        super().__init__()
        # ReduceType, presumably to mirror the matmul it replaces —
        # TODO(review): confirm against MatmulOp's op type.
        self._op_type = OpType.ReduceType

class GetItemOp(Op):
def __init__(self) -> None:
super().__init__()
Expand Down
56 changes: 56 additions & 0 deletions frontend/Python/graph/transform/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,67 @@
from .. import Graph
from ..operation import *
from .. import DeviceType
from torch.fx.immutable_collections import immutable_list

# Registry mapping a fusion-pattern name to the fused Op class that
# replaces the matched subgraph.
# NOTE(review): the key spelling ("mamtmul") is load-bearing — the
# string built in check_classicfusetype must match it exactly.
classicfuse_register = {
    "transpose+mamtmul2D": TransposeMatmulFusedOp
}

# TODO: classify op type for op fusion
# OP_TYPE_FUSABLE = [OpType.BroadcastType, OpType.ElementwiseType, OpType.ReshapeType]
# OP_TYPE_UNFUSABLE = [OpType.Unfusable, OpType.ConcatType]
# OP_TYPE_FUSABLE_BY_SPECIFIC_PASS = []
# ANCHOR_OP_TYPE = []

def check_classicfusetype(graph: Graph, op: Op):
    """Check whether *op* anchors a known classic-fusion pattern.

    Currently detects transpose + 2-D matmul: a MatmulOp whose
    right-hand operand is a PermuteOp with permutation [1, 0].  Only the
    RHS is matched because the fused op is lowered to
    linalg.matmul_transpose_b, which transposes the second operand; a
    transposed LHS must not be fused this way.

    Args:
        graph: The graph the candidate op belongs to.
        op: The candidate anchor operation.

    Returns:
        A (target, parents, pattern_name) tuple when a pattern matches,
        otherwise None.
    """
    pattern = None
    if isinstance(op, MatmulOp) and len(op.args) >= 2:
        parentop = [graph.node_table[str(i)] for i in op._parents]
        # Inspect the second matmul operand (B) only.
        target = graph.node_table.get(str(op.args[1]))
        if (
            isinstance(target, PermuteOp)
            and len(target.args) > 1
            and target.args[1] == immutable_list([1, 0])
        ):
            pattern = target, parentop, "transpose+mamtmul2D"
    # TODO: other classic fusion patterns
    return pattern

def classic_fuse_check(graph: Graph):
    """Scan the graph and fuse every op that matches a classic pattern.

    Args:
        graph: The graph to rewrite in place.
    """
    for candidate in graph.body:
        match = check_classicfusetype(graph, candidate)
        if not match:
            continue
        target, parent_ops, pattern_name = match
        do_classicfusion(graph, candidate, target, parent_ops, pattern_name)

def do_classicfusion(graph : Graph,node,target : Op,parents : List[Op],pattern : str):
    """
    Fuse a matched pair of operations into one fused operation,
    e.g. transpose + matmul -> TransposeMatmulFusedOp.
    Args:
    - graph (Graph): The input graph to be simplified.
    - node (Op): The anchor operation to be fused (e.g. the matmul).
    - target (Op): The producer operation folded into the anchor
      (e.g. the transpose).
    - parents (List[Op]): The parents of the anchor node.
    - pattern (str): Key into classicfuse_register selecting the fused op class.
    Returns:
    - None: Modifies the input graph in place.
    """
    # Instantiate the fused op class registered for this pattern.
    fusedop = classicfuse_register.get(pattern)()
    #matmulop -> fusedmatmulopnode
    fusedop.name = "fused"+node.name
    graph.displace_node(node,fusedop)
    # Drop the target (transpose) from the fused op's inputs and splice
    # in the target's own inputs instead.
    fusedop.args.pop(fusedop.args.index(target.name))
    fusedop._parents.pop(fusedop._parents.index(target.name))
    # NOTE(review): extends with ALL of target.args — for a PermuteOp
    # this also appends the permutation list [1, 0] as a trailing arg;
    # the lowering reads only args[0]/args[1], so the extra entry looks
    # harmless — TODO confirm.
    fusedop.args.extend(target.args)

    fusedop._parents.extend(target._parents)
    targets_parent = [graph.node_table[i] for i in target._parents]
    for i in targets_parent:
        i.add_children(fusedop.name)
    target._children.pop(target._children.index(fusedop.name))

    # Delete the target only when nothing else still consumes its output.
    if(graph.check_deletenode(target)):
        graph.delete_node(target,targets_parent)

def simply_fuse(graph: Graph):
"""
Function to fuse all operations into one graph.
Expand All @@ -40,6 +94,8 @@ def simply_fuse(graph: Graph):
"""
new_op_group = []
device = DeviceType.UNKNOW
#Run the first round of op fusion
classic_fuse_check(graph)
for op in graph.body:
if isinstance(op, PlaceholderOp):
continue
Expand Down
21 changes: 21 additions & 0 deletions frontend/Python/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,6 +1151,26 @@ def matmul_op(
return op


def matmul_transpose_b_op(
    node: TransposeMatmulFusedOp,
    symbol_table: Dict[Tuple[str, int], ir.Operation],
):
    """Lower a TransposeMatmulFusedOp to linalg.matmul_transpose_b.

    Emits input1 @ transpose(input2) into a zero-initialized tensor of
    the node's output shape.

    Args:
        node: The fused transpose+matmul node to lower.
        symbol_table: Maps (operand name, result index) to MLIR values.

    Returns:
        The linalg.matmul_transpose_b operation, or None when either
        input has not been materialized in the symbol table yet.
    """
    lhs = symbol_table.get((str(node.args[0]), 0))
    rhs = symbol_table.get((str(node.args[1]), 0))
    if lhs is None or rhs is None:
        return

    shape = list(node.tensor_meta["shape"])
    dtype = node.tensor_meta["dtype"]
    elem_type = mlir_element_type_get(dtype)
    result_type = ir.RankedTensorType.get(shape, elem_type)
    # linalg named ops write into an output operand; seed it with zeros.
    zero = mlir_element_attr_get(dtype, 0.0)
    splat = ir.DenseElementsAttr.get_splat(result_type, zero)
    init = arith.ConstantOp(result_type, splat).result
    return linalg.matmul_transpose_b(lhs, rhs, outs=[init])


def transpose_op(
node: TransposeOp,
symbol_table: Dict[Tuple[str, int], ir.Operation],
Expand Down Expand Up @@ -1968,6 +1988,7 @@ def gt_op(node: GtOp, symbol_table):

ops_registry = {
"MatmulOp": matmul_op,
"TransposeMatmulFusedOp": matmul_transpose_b_op,
"ArangeOp": arange_op,
"UnsqueezeOp": unsqueeze_op,
"ViewOp": view_op,
Expand Down

0 comments on commit ef35634

Please sign in to comment.