Commit 0561ee8
Merge branch 'revert-45647076' into 'main'
Revert "Merge branch 'ppadjin/mistral-7b-bringup' into 'main'"

See merge request tenstorrent/tvm!60
nvukobratTT committed Apr 29, 2024
2 parents 4564707 + 9e6047c commit 0561ee8
Showing 1 changed file with 0 additions and 71 deletions.
python/tvm/relay/op/contrib/buda/buda_passes.py
@@ -3727,76 +3727,6 @@ def callback(self, pre, post, node_map):
        return unpadded_bmm


class GQABroadcastReshape(DFPatternCallback):
"""
Callback for Grouped Query Attention Pattern. When parsing a standard GQA,
A subpattern appears that is in the form:
(bs, n_kv_heads, seq_len, head_dim) ->[reshape0]-> (bs, n_kv_heads, 1, seq_len, head_dim) ->[bc0]->
(bs, n_kv_heads, 1, seq_len, head_dim) ->[bc1]-> (bs, n_kv_heads, n_kv_blocks, seq_len, head_dim) ->[reshape1]->
(n_query_heads, bs*seq_len, head_dim) ->[transpose]-> (n_query_heads, head_dim, bs*seq_len)
Where n_query_heads == n_kv_heads * n_kv_blocks. The problem with this subpattern is this broadcast that is
performed (bc1) which generates a 5D tensor with 4 dimensions that are not equal to 1.
That bc output is then input to reshape1 and pybuda compiler has no way to decompose a reshape that is performed
on such input tensor. That is why we change this pattern so that this doesn't occur.
Modification:
(bs, n_kv_heads, seq_len, head_dim) ->[transpose] -> (bs, seq_len, n_kv_heads, head_dim)
->[reshape]-> (bs, n_kv_heads*seq_len, 1, head_dim) ->[bc]-> (bs, seq_len*n_kv_heads, brcst_val, head_dim)
->[reshape]-> (bs*seqlen, n_kv_heads*brcst_val, head_dim) ->[transpose]-> (n_kv_heads*brcst_val, bs*seqlen, head_dim)
->[transpose]-> (n_kv_heads*brcst_val, head_dim, bs*seqlen)
"""
    def __init__(self, require_type=False, rewrite_once=False):
        super().__init__(require_type, rewrite_once)
        self.act = wildcard()
        self.reshape0 = is_op('reshape')(self.act)
        self.bc0 = is_op('broadcast_to')(self.reshape0)
        self.bc1 = is_op('broadcast_to')(self.bc0)
        self.reshape1 = is_op('reshape')(self.bc1)
        self.pattern = is_op('transpose')(self.reshape1)

    def callback(self, pre, post, node_map):
        act = node_map[self.act][0]  # [bs, n_kv_heads, seq_len, head_dim]
        orig_shape = act.checked_type.shape

        # The idea is to catch only reshapes [bs, n_kv_heads, seq_len, head_dim] -> [bs, n_kv_heads, 1, seq_len, head_dim]
        if len(orig_shape) != 4:
            return post

        if len(node_map[self.reshape0][0].attrs.newshape) != 5:
            return post

        transpose0 = tvm.relay.transpose(act, axes=[0, 2, 1, 3])  # [bs, seq_len, n_kv_heads, head_dim]
        prev_shape = (orig_shape[-4], orig_shape[-2], orig_shape[-3], orig_shape[-1])  # i.e. the shape of transpose0

        new_shape = (prev_shape[-4], int(prev_shape[-3] * prev_shape[-2]), 1, prev_shape[-1])

        reshape0 = tvm.relay.reshape(transpose0, newshape=new_shape)  # (bs, seq_len*n_kv_heads, 1, head_dim)

        # Bail out if the integer cast above did not preserve the exact product.
        if new_shape[-3] != prev_shape[-3] * prev_shape[-2]:
            return post

        bc1 = node_map[self.bc1][0]
        pre_broadcast_shape = list(bc1.type_args[0].shape)
        post_broadcast_shape = list(bc1.attrs.shape)

        # Get the value of the dimension that differs after applying the broadcast.
        broadcasted_value = [el for idx, el in enumerate(post_broadcast_shape) if el != pre_broadcast_shape[idx]][0]

        new_broadcast_shape = list(new_shape)
        new_broadcast_shape[-2] = broadcasted_value

        bc = tvm.relay.broadcast_to(reshape0, new_broadcast_shape)  # (bs, seq_len*n_kv_heads, 1, head_dim) -> (bs, seq_len*n_kv_heads, brcst_val, head_dim)

        new_shape = [prev_shape[-4] * prev_shape[-3], new_broadcast_shape[-3] * new_broadcast_shape[-2] // prev_shape[-3], prev_shape[-1]]  # (bs*seq_len, n_kv_heads*brcst_val, head_dim)
        reshape1 = tvm.relay.reshape(bc, new_shape)

        transpose1 = tvm.relay.transpose(reshape1, axes=[1, 0, 2])  # (n_kv_heads*brcst_val, bs*seq_len, head_dim)
        transpose2 = tvm.relay.transpose(transpose1, axes=[0, 2, 1])  # (n_kv_heads*brcst_val, head_dim, bs*seq_len)
        return transpose2


def _get_callback_name(callback):
@@ -3935,7 +3865,6 @@ def run_buda_compile_passes(relay_module, params=None, inputs=None, target=None,
            # LowerSplitToStridedSlice(),
            PadSpecificBatchMatmulShapes(),
            SimplifyVITOnnxAttention(),
            GQABroadcastReshape(),
        ],
        params=params,
        inputs=inputs,
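
For reference, the shape rewrite that the removed GQABroadcastReshape pass performed can be checked outside of TVM. Below is a minimal numpy sketch (not from this repository; the sizes are illustrative, and it assumes bs == 1, since the reshape to (n_query_heads, bs*seq_len, head_dim) is a pure relayout only in that case) showing that the rewritten transpose/reshape/broadcast sequence reproduces the output of the original subpattern:

import numpy as np

# Illustrative sizes; n_query_heads = n_kv_heads * n_kv_blocks.
bs, n_kv_heads, n_kv_blocks, seq_len, head_dim = 1, 8, 4, 32, 64
x = np.random.rand(bs, n_kv_heads, seq_len, head_dim).astype("float32")

# Original subpattern: reshape0 -> bc1 (5D, four non-unit dims) -> reshape1 -> transpose.
orig = x.reshape(bs, n_kv_heads, 1, seq_len, head_dim)
orig = np.broadcast_to(orig, (bs, n_kv_heads, n_kv_blocks, seq_len, head_dim))
orig = orig.reshape(n_kv_heads * n_kv_blocks, bs * seq_len, head_dim)
orig = orig.transpose(0, 2, 1)  # (n_query_heads, head_dim, bs*seq_len)

# Rewritten sequence: transpose first so the broadcast expands a single unit axis.
new = x.transpose(0, 2, 1, 3)  # (bs, seq_len, n_kv_heads, head_dim)
new = new.reshape(bs, seq_len * n_kv_heads, 1, head_dim)
new = np.broadcast_to(new, (bs, seq_len * n_kv_heads, n_kv_blocks, head_dim))
new = new.reshape(bs * seq_len, n_kv_heads * n_kv_blocks, head_dim)
new = new.transpose(1, 0, 2).transpose(0, 2, 1)  # (n_query_heads, head_dim, bs*seq_len)

assert np.array_equal(orig, new)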
