From de70a1fff0b270e3e0e2a3a9bd4f394de1f2de21 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Tue, 31 Oct 2023 00:13:05 +0000
Subject: [PATCH] Enable split rotary fusion for batched model

---
 mlc_llm/core.py                               |   3 +-
 .../transform/fuse_split_rotary_embedding.py  | 286 ++++++++++++++++--
 2 files changed, 269 insertions(+), 20 deletions(-)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 6b993c07b5..a82bfd393a 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -453,8 +453,7 @@ def mod_transform_before_build(
     mod = mlc_llm.transform.FuseDecodeTranspose(skip_gemm=not use_ft_quant)(mod)
 
     if (
-        not args.enable_batching
-        and hasattr(config, "num_attention_heads")
+        hasattr(config, "num_attention_heads")
         and hasattr(config, "hidden_size")
         and hasattr(config, "position_embedding_base")
         and getattr(config, "dtype", "float16") == "float16"
diff --git a/mlc_llm/transform/fuse_split_rotary_embedding.py b/mlc_llm/transform/fuse_split_rotary_embedding.py
index ed19a7095c..8e7e84959e 100644
--- a/mlc_llm/transform/fuse_split_rotary_embedding.py
+++ b/mlc_llm/transform/fuse_split_rotary_embedding.py
@@ -8,11 +8,30 @@
     is_tuple_get_item,
     GlobalVarPattern,
     TuplePattern,
-    is_shape,
 )
 from tvm.script import relax as R, tir as T
 
 
+def update_param_sinfo(f):
+    param_sinfo = []
+    for param in f.params:
+        if param in f.buffer_map:
+            buf = f.buffer_map[param]
+            sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype)
+        else:
+            sinfo = relax.PrimStructInfo(param.dtype)
+        param_sinfo.append(sinfo)
+
+    relax.expr._update_struct_info(
+        f,
+        tvm.relax.FuncStructInfo(
+            params=param_sinfo,
+            ret=relax.TupleStructInfo([]),
+            purity=False,
+        ),
+    )
+
+
 def get_dynamic_split_rotary():
     """Implementation of R.split(rotary_embedding(fused_qkv))
 
@@ -88,23 +107,91 @@ def split_rotary(
                     batch_i, seq_i, head_num - num_query_heads - num_kv_heads, head_i
                 ] = input_value
 
-    param_sinfo = []
-    for param in split_rotary.params:
-        if param in split_rotary.buffer_map:
-            buf = split_rotary.buffer_map[param]
-            sinfo = relax.TensorStructInfo(shape=buf.shape, dtype=buf.dtype)
-        else:
-            sinfo = relax.PrimStructInfo(param.dtype)
-        param_sinfo.append(sinfo)
+    update_param_sinfo(split_rotary)
 
-    relax.expr._update_struct_info(
-        split_rotary,
-        tvm.relax.FuncStructInfo(
-            params=param_sinfo,
-            ret=relax.TupleStructInfo([]),
-            purity=False,
-        ),
-    )
 
     return split_rotary
+
+
+def get_dynamic_split_rotary_batched():
+    """Implementation of R.split(rotary_embedding(fused_qkv))
+
+    Implementation is generic over the number of query heads,
+    key/value heads, number of tokens, head dimension, and position
+    embedding base. These parameters can be replaced with static
+    values using `PrimFunc.specialize`.
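+
+    An illustrative sketch (the `dyn_*` variables are the dynamic
+    parameters taken from `split_rotary.params[-5:]`, as in the pass
+    below; only the number of tokens is left dynamic, and 128 is an
+    example head dimension):
+
+        split_rotary = split_rotary.specialize(
+            {dyn_head_dim: T.int64(128),
+             dyn_num_token: tvm.tir.Var("num_token", "int64")}
+        )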
+ """ + + @T.prim_func(private=True) + def split_rotary( + fused_qkv_handle: T.handle, + position_handle: T.handle, + embedded_query_handle: T.handle, + embedded_key_handle: T.handle, + value_handle: T.handle, + num_token: T.int64, + num_query_heads: T.int64, + num_kv_heads: T.int64, + head_dim: T.int64, + position_embedding_base: T.float32, + ): + Fused_QKV = T.match_buffer( + fused_qkv_handle, + [num_token, num_query_heads + num_kv_heads * 2, head_dim], + dtype="float16", + ) + EmbeddedQuery = T.match_buffer( + embedded_query_handle, + [num_token, num_query_heads, head_dim], + dtype="float16", + ) + EmbeddedKey = T.match_buffer( + embedded_key_handle, + [num_token, num_kv_heads, head_dim], + dtype="float16", + ) + Value = T.match_buffer( + value_handle, + [num_token, num_kv_heads, head_dim], + dtype="float16", + ) + Position = T.match_buffer( + position_handle, + [num_token], + dtype="int32", + ) + + T.func_attr({"op_pattern": 2, "tir.noalias": T.bool(True)}) + + for iters in T.grid(num_token, num_query_heads + num_kv_heads * 2, head_dim): + with T.block("FusedRotaryEmbeddingAndSplitQKV"): + tok_i, head_num, head_i = T.axis.remap("SSS", iters) + pos: T.float32 = T.Cast("float32", Position[tok_i]) + + inv_freq: T.float32 = T.float32(1) / T.pow( + position_embedding_base, + T.Cast("float32", (head_i * 2) % head_dim) / T.float32(head_dim), + ) + freq: T.float32 = pos * inv_freq + cos_value: T.float16 = T.Cast("float16", T.cos(freq)) + sin_value: T.float16 = T.Cast("float16", T.sin(freq)) + + input_value = Fused_QKV[tok_i, head_num, head_i] + embedded_value = cos_value * input_value + sin_value * T.Select( + head_i < T.int64(head_dim // 2), + Fused_QKV[tok_i, head_num, head_i + T.int64(head_dim // 2)] + * T.float16(-1), + Fused_QKV[tok_i, head_num, head_i - T.int64(head_dim // 2)], + ) + if head_num < num_query_heads: + EmbeddedQuery[tok_i, head_num, head_i] = embedded_value + elif head_num < num_query_heads + num_kv_heads: + EmbeddedKey[tok_i, head_num - num_query_heads, head_i] = embedded_value + else: + Value[ + tok_i, head_num - num_query_heads - num_kv_heads, head_i + ] = input_value + + update_param_sinfo(split_rotary) return split_rotary @@ -139,6 +226,8 @@ def ir_module_pass(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: } ) + update_param_sinfo(split_rotary) + mod["split_rotary"] = split_rotary split_rotary_gvar = mod.get_global_var("split_rotary") @@ -281,4 +370,165 @@ def rewriter(matchings, bindings): new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos) return new_mod - return ir_module_pass + @tvm.ir.transform.module_pass(opt_level=0, name="fuse_split_rotary_embedding") + def ir_module_pass_batched(mod: tvm.IRModule, _pass_context) -> tvm.IRModule: + head_dim = hidden_size // num_query_heads + split_rotary = get_dynamic_split_rotary_batched() + + ( + dyn_num_token, + dyn_num_query_heads, + dyn_num_kv_heads, + dyn_head_dim, + dyn_position_embedding_base, + ) = split_rotary.params[-5:] + + split_rotary = split_rotary.specialize( + { + # Static model parameters + dyn_num_query_heads: T.int64(num_query_heads), + dyn_num_kv_heads: T.int64(num_kv_heads), + dyn_head_dim: T.int64(head_dim), + dyn_position_embedding_base: T.float32(position_embedding_base), + # Dynamic parameters, to be inferred from TIR Buffer shapes + dyn_num_token: tvm.tir.Var("num_token", "int64"), + } + ) + + update_param_sinfo(split_rotary) + + mod["split_rotary"] = split_rotary + + split_rotary_gvar = mod.get_global_var("split_rotary") + 
+        relax.expr._update_struct_info(
+            split_rotary_gvar, mod["split_rotary_batched"].struct_info
+        )
+
+        with PatternContext() as ctx:
+            # flat_qkv_tuple: R.Tuple(
+            #     R.Tensor((num_token, 4096), dtype="float16"),
+            #     R.Tensor((num_token, 4096), dtype="float16"),
+            #     R.Tensor((num_token, 4096), dtype="float16"),
+            # ) = R.split(flat_fused_qkv, indices_or_sections=[4096, 8192], axis=1)
+            #
+            # flat_query: R.Tensor((num_token, 4096), dtype="float16") = flat_qkv_tuple[0]
+            # query: R.Tensor((num_token, 32, 128), dtype="float16") = R.reshape(
+            #     flat_query, R.shape([num_token, 32, 128])
+            # )
+            # flat_key: R.Tensor((num_token, 4096), dtype="float16") = flat_qkv_tuple[1]
+            # key: R.Tensor((num_token, 32, 128), dtype="float16") = R.reshape(
+            #     flat_key, R.shape([num_token, 32, 128])
+            # )
+            # flat_value: R.Tensor((num_token, 4096), dtype="float16") = flat_qkv_tuple[2]
+            # value: R.Tensor((num_token, 32, 128), dtype="float16") = R.reshape(
+            #     flat_value, R.shape([num_token, 32, 128])
+            # )
+            # embedded_query = R.call_tir(
+            #     cls.rotary_embedding1,
+            #     [query, positions],
+            #     out_sinfo=R.Tensor((num_token, 32, 128), dtype="float16"),
+            # )
+            # embedded_key = R.call_tir(
+            #     cls.rotary_embedding1,
+            #     [key, positions],
+            #     out_sinfo=R.Tensor((num_token, 32, 128), dtype="float16"),
+            # )
+
+            pat_rotary_embedding_gvar = GlobalVarPattern()
+
+            pat_flat_fused_qkv = wildcard()
+            pat_position = wildcard()
+
+            # query_shape = is_shape([num_token, num_query_heads, head_dim])
+            pat_query_shape = wildcard()
+            # key_shape = is_shape([num_token, num_kv_heads, head_dim])
+            pat_key_shape = wildcard()
+            # value_shape = is_shape([num_token, num_kv_heads, head_dim])
+            pat_value_shape = wildcard()
+
+            pat_flat_qkv_tuple = is_op("relax.split")(pat_flat_fused_qkv)
+            pat_flat_query = is_tuple_get_item(pat_flat_qkv_tuple, 0)
+            pat_query = is_op("relax.reshape")(
+                pat_flat_query, pat_query_shape, add_constraint=False
+            )
+            pat_flat_query.used_by(pat_query)
+            pat_flat_key = is_tuple_get_item(pat_flat_qkv_tuple, 1)
+            pat_key = is_op("relax.reshape")(pat_flat_key, pat_key_shape, add_constraint=False)
+            pat_flat_key.used_by(pat_key)
+            pat_flat_value = is_tuple_get_item(pat_flat_qkv_tuple, 2)
+            pat_value = is_op("relax.reshape")(
+                pat_flat_value, pat_value_shape, add_constraint=False
+            )
+            pat_flat_value.used_by(pat_value)
+
+            pat_embedded_query = is_op("relax.call_tir")(
+                pat_rotary_embedding_gvar,
+                TuplePattern([pat_query, pat_position]),
+                add_constraint=False,
+            )
+            pat_embedded_key = is_op("relax.call_tir")(
+                pat_rotary_embedding_gvar,
+                TuplePattern([pat_key, pat_position]),
+                add_constraint=False,
+            )
+
+            pat_flat_qkv_tuple.used_by(pat_flat_query)
+            pat_flat_qkv_tuple.used_by(pat_flat_key)
+            pat_flat_qkv_tuple.used_by(pat_flat_value)
+            pat_query.used_by(pat_embedded_query)
+            pat_key.used_by(pat_embedded_key)
+            pat_position.used_by(pat_embedded_query)
+            pat_position.used_by(pat_embedded_key)
+
+        def rewriter(matchings, bindings):
+            # Extracting all the relax and TIR variables that we'll need
+            flat_fused_qkv = matchings[pat_flat_fused_qkv]
+
+            query = matchings[pat_query]
+            key = matchings[pat_key]
+            value = matchings[pat_value]
+
+            position = matchings[pat_position]
+
+            embedded_query = matchings[pat_embedded_query]
+            embedded_key = matchings[pat_embedded_key]
+
+            num_token, num_query_heads, head_dim = query.struct_info.shape
+            num_token, num_kv_heads, _head_dim = key.struct_info.shape
+
+            # Rewriting along the new path
+
+            fused_qkv = relax.op.reshape(
+                flat_fused_qkv, [num_token, num_query_heads + 2 * num_kv_heads, head_dim]
+            )
+
+            split_rotary_sinfo = [
+                R.Tensor((num_token, num_query_heads, head_dim), dtype="float16"),
+                R.Tensor((num_token, num_kv_heads, head_dim), dtype="float16"),
+                R.Tensor((num_token, num_kv_heads, head_dim), dtype="float16"),
+            ]
+            qkv_tuple_new = R.call_tir(
+                split_rotary_gvar,
+                (fused_qkv, position),
+                out_sinfo=split_rotary_sinfo,
+            )
+
+            embedded_query_new = qkv_tuple_new[0]
+            embedded_key_new = qkv_tuple_new[1]
+            value_new = qkv_tuple_new[2]
+
+            return {
+                value: value_new,
+                embedded_query: embedded_query_new,
+                embedded_key: embedded_key_new,
+            }
+
+        new_mod = {}
+        for gvar, func in mod.functions.items():
+            if isinstance(func, relax.Function):
+                func = rewrite_bindings(ctx, rewriter, func)
+            new_mod[gvar] = func
+
+        new_mod = tvm.IRModule(new_mod, mod.type_definitions, mod.attrs, mod.global_infos)
+        return new_mod
+
+    return tvm.transform.Sequential([ir_module_pass, ir_module_pass_batched])
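
Usage sketch (illustrative, not part of the patch): the factory now returns
both rewrites wrapped in a `tvm.transform.Sequential`, so the call site in
`mod_transform_before_build` stays a single application for batched and
non-batched models alike. The sketch below assumes the enclosing factory is
exposed as `mlc_llm.transform.fuse_split_rotary_embedding` and takes the
closure parameters seen in this diff; the argument values are examples, not
taken from the patch:

    import mlc_llm

    # Llama-7B-style values; real values come from the model config inspected
    # in mod_transform_before_build.
    fuse = mlc_llm.transform.fuse_split_rotary_embedding(
        num_query_heads=32,
        num_kv_heads=32,
        hidden_size=4096,
        position_embedding_base=10000.0,
    )

    # Each pass in the Sequential only fires on the QKV-split pattern it was
    # written for (batched or non-batched), so applying both is safe.
    mod = fuse(mod)  # mod: tvm.IRModule produced by the model definition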