From f1bc68f9d868c34dba0675354d9b42b8e25e8622 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Tue, 23 Jan 2024 16:26:13 -0800 Subject: [PATCH] [Mixtral] Use col major for MoE gemm and use small batch specialization (#171) --- mlc_llm/relax_model/commons.py | 34 +++++++++++--------- mlc_llm/relax_model/llama.py | 57 +++++++++------------------------- mlc_llm/relax_model/mixtral.py | 2 +- 3 files changed, 34 insertions(+), 59 deletions(-) diff --git a/mlc_llm/relax_model/commons.py b/mlc_llm/relax_model/commons.py index 1e90a63434..676ff610a2 100644 --- a/mlc_llm/relax_model/commons.py +++ b/mlc_llm/relax_model/commons.py @@ -82,33 +82,32 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo): return func def moe_shard_k_weight_scale(weight: relax.TensorStructInfo): - (num_experts, red, spatial), dtype = weight.shape, weight.dtype + (num_experts, spatial, red), dtype = weight.shape, weight.dtype spatial, red = int(spatial), int(red) if param_shape_is_already_sharded: red *= num_shards - a = te.placeholder((num_experts, red, spatial), dtype=dtype) - w = topi.reshape(a, (num_experts, num_shards, red // num_shards, spatial)) - w = topi.transpose(w, (1, 0, 2, 3)) + a = te.placeholder((num_experts, spatial, red), dtype=dtype) + w = topi.reshape(a, (num_experts, spatial, num_shards, red // num_shards)) + w = topi.transpose(w, (2, 0, 1, 3)) func = te.create_prim_func([a, w]) return func def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo): - (num_experts, red, spatial), dtype = weight.shape, weight.dtype + (num_experts, spatial, red), dtype = weight.shape, weight.dtype spatial, red = int(spatial), int(red) if param_shape_is_already_sharded: spatial *= num_shards - a = te.placeholder((num_experts, red, spatial), dtype=dtype) - g = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, j]) - u = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, spatial // 2 + j]) - g = topi.reshape(g, (num_experts, red, num_shards, spatial // 2 // num_shards)) - u = topi.reshape(u, (num_experts, red, num_shards, spatial // 2 // num_shards)) - w = topi.concatenate((g, u), axis=3) - w = topi.reshape(w, (num_experts, red, num_shards, spatial // num_shards)) - w = topi.transpose(w, (2, 0, 1, 3)) + a = te.placeholder((num_experts, spatial, red), dtype=dtype) + g = te.compute((num_experts, spatial // 2, red), lambda e, i, j: a[e, i, j]) + u = te.compute((num_experts, spatial // 2, red), lambda e, i, j: a[e, spatial // 2 + i, j]) + g = topi.reshape(g, (num_experts, num_shards, spatial // 2 // num_shards, red)) + u = topi.reshape(u, (num_experts, num_shards, spatial // 2 // num_shards, red)) + w = topi.concatenate((g, u), axis=2) + w = topi.reshape(w, (num_experts, num_shards, spatial // num_shards, red)) + w = topi.transpose(w, (1, 0, 2, 3)) func = te.create_prim_func([a, w]) return func - # pylint: enable=invalid-name return { @@ -233,7 +232,12 @@ def add_to_shard_info(param_name: str, func_name: Optional[str]): def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule: - use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"] + use_ft_quant = args.quantization.name in [ + "q4f16_ft", + "q8f16_ft", + "q4f16_ft_group", + "q8f16_ft_group", + ] if use_ft_quant: shard_strategy_to_func = _get_shard_strategies_ft( diff --git a/mlc_llm/relax_model/llama.py b/mlc_llm/relax_model/llama.py index 95aebc7754..b21553abd1 100644 --- a/mlc_llm/relax_model/llama.py +++ b/mlc_llm/relax_model/llama.py @@ -726,7 +726,9 @@ def forward( class LlamaModelForSingleSequence(LlamaModelBase): - def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False): + def __init__( + self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False + ): super().__init__(config, vocab_size_var, sep_embed, enable_batching=False) def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype): @@ -1251,7 +1253,6 @@ def f_convert_pname_fwd(pname: str) -> List[str]: if isinstance(config, MixtralConfig): for k, v in mappings: pname = pname.replace(k, v) - # pname = pname.replace("model.", "") if config.quantization_scheme.name == "q4f16_ft": if pname.endswith("scales"): # TODO: remove after quantization integarted @@ -1315,7 +1316,7 @@ def f_convert_param_bkwd(torch_pname: str, torch_param): def quantize(experts, relax_pname): print("quantizing experts", relax_pname) func = tvm.get_global_func("cutlass.symmetric_quantize") - nd_experts = tvm.nd.array(experts) + nd_experts = tvm.nd.array(experts.transpose(0, 2, 1)) qweight, qscale = func(nd_experts, True) if relax_pname.endswith("weight"): return qweight @@ -1332,50 +1333,20 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]): # combine along out_features dimension and then experts dimension experts = [] assert len(torch_params) == 2 * config.num_local_experts - - use_pytorch = True - if use_pytorch and dtype == "float16": - import torch - - torch_params = [torch.from_numpy(param).cuda() for param in torch_params] - for i in range(config.num_local_experts): - gate, up = ( - torch_params[i], - torch_params[i + config.num_local_experts], - ) # torch weight in col major - gate_up = torch.concatenate([gate, up], axis=0).type(torch.float16) - experts.append(gate_up.transpose(1, 0)) - result = torch.stack(experts) - result = result.cpu().numpy() - else: - for i in range(config.num_local_experts): - gate, up = ( - torch_params[i], - torch_params[i + config.num_local_experts], - ) # torch weight in col major - gate_up = np.concatenate([gate, up], axis=0).astype(dtype) - experts.append(gate_up.transpose()) - result = np.stack(experts) - # print(config.quantization_scheme.name) + for i in range(config.num_local_experts): + gate, up = ( + torch_params[i], + torch_params[i + config.num_local_experts], + ) # torch weight in col major + gate_up = np.concatenate([gate, up], axis=0).astype(dtype) + experts.append(gate_up) + result = np.stack(experts) if config.quantization_scheme.name == "q4f16_ft" and "experts" in relax_pname: result = quantize(result, relax_pname) return result if "experts" in relax_pname: - use_pytorch = True - if use_pytorch and dtype == "float16": - import torch - - torch_params = [torch.from_numpy(param).cuda() for param in torch_params] - experts = torch.stack( - [expert.type(torch.float16).transpose(1, 0) for expert in torch_params] - ) - result = experts.cpu().numpy() - else: - experts = [expert.astype(dtype).transpose() for expert in torch_params] - result = np.stack(experts) - # torch_params = [torch.from_numpy(param).cuda() for param in torch_params] - # experts = [expert.type(dtype).transpose(1, 0) for expert in torch_params] - # result = torch.stack(experts).detach().numpy() + experts = [expert.astype(dtype) for expert in torch_params] + result = np.stack(experts) if config.quantization_scheme.name == "q4f16_ft" and "experts" in relax_pname: result = quantize(result, relax_pname) return result diff --git a/mlc_llm/relax_model/mixtral.py b/mlc_llm/relax_model/mixtral.py index ea4b112780..b65472aa2b 100644 --- a/mlc_llm/relax_model/mixtral.py +++ b/mlc_llm/relax_model/mixtral.py @@ -18,7 +18,7 @@ def __init__(self, config: MixtralConfig, num_experts, in_features, out_features if config.quantization_scheme.name == "q0f16": # weight is row major self.weight = nn.Parameter( - (num_experts, in_features, out_features), + (num_experts, out_features, in_features), dtype="float16", ) elif config.quantization_scheme.name == "q4f16_ft":