[Mixtral] Use col major for MoE gemm and use small batch specialization (#171)
vinx13 authored Jan 24, 2024
1 parent 274ac99 commit f1bc68f
Showing 3 changed files with 34 additions and 59 deletions.
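The gist of the change, as a minimal NumPy sketch (not part of the diff; sizes are made up): per-expert MoE weights move from (num_experts, in_features, out_features) to (num_experts, out_features, in_features), so each expert's GEMM reads its weight transposed.

import numpy as np

# Hypothetical sizes, for illustration only.
num_experts, in_features, out_features = 4, 8, 16
x = np.random.rand(3, in_features)  # a few tokens routed to expert 0

# Old layout: (num_experts, in_features, out_features); per-expert GEMM is x @ w[e].
w_old = np.random.rand(num_experts, in_features, out_features)

# New layout: (num_experts, out_features, in_features); per-expert GEMM is x @ w[e].T.
w_new = w_old.transpose(0, 2, 1)

assert np.allclose(x @ w_old[0], x @ w_new[0].T)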
34 changes: 19 additions & 15 deletions mlc_llm/relax_model/commons.py
@@ -82,33 +82,32 @@ def shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
         return func

     def moe_shard_k_weight_scale(weight: relax.TensorStructInfo):
-        (num_experts, red, spatial), dtype = weight.shape, weight.dtype
+        (num_experts, spatial, red), dtype = weight.shape, weight.dtype
         spatial, red = int(spatial), int(red)
         if param_shape_is_already_sharded:
             red *= num_shards
-        a = te.placeholder((num_experts, red, spatial), dtype=dtype)
-        w = topi.reshape(a, (num_experts, num_shards, red // num_shards, spatial))
-        w = topi.transpose(w, (1, 0, 2, 3))
+        a = te.placeholder((num_experts, spatial, red), dtype=dtype)
+        w = topi.reshape(a, (num_experts, spatial, num_shards, red // num_shards))
+        w = topi.transpose(w, (2, 0, 1, 3))
         func = te.create_prim_func([a, w])
         return func

     def moe_shard_gate_up_weight_scale(weight: relax.TensorStructInfo):
-        (num_experts, red, spatial), dtype = weight.shape, weight.dtype
+        (num_experts, spatial, red), dtype = weight.shape, weight.dtype
         spatial, red = int(spatial), int(red)
         if param_shape_is_already_sharded:
             spatial *= num_shards
-        a = te.placeholder((num_experts, red, spatial), dtype=dtype)
-        g = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, j])
-        u = te.compute((num_experts, red, spatial // 2), lambda e, i, j: a[e, i, spatial // 2 + j])
-        g = topi.reshape(g, (num_experts, red, num_shards, spatial // 2 // num_shards))
-        u = topi.reshape(u, (num_experts, red, num_shards, spatial // 2 // num_shards))
-        w = topi.concatenate((g, u), axis=3)
-        w = topi.reshape(w, (num_experts, red, num_shards, spatial // num_shards))
-        w = topi.transpose(w, (2, 0, 1, 3))
+        a = te.placeholder((num_experts, spatial, red), dtype=dtype)
+        g = te.compute((num_experts, spatial // 2, red), lambda e, i, j: a[e, i, j])
+        u = te.compute((num_experts, spatial // 2, red), lambda e, i, j: a[e, spatial // 2 + i, j])
+        g = topi.reshape(g, (num_experts, num_shards, spatial // 2 // num_shards, red))
+        u = topi.reshape(u, (num_experts, num_shards, spatial // 2 // num_shards, red))
+        w = topi.concatenate((g, u), axis=2)
+        w = topi.reshape(w, (num_experts, num_shards, spatial // num_shards, red))
+        w = topi.transpose(w, (1, 0, 2, 3))
         func = te.create_prim_func([a, w])
         return func


     # pylint: enable=invalid-name

     return {
@@ -233,7 +232,12 @@ def add_to_shard_info(param_name: str, func_name: Optional[str]):


 def create_shard_transformation_func(param_manager, args, model_config) -> tvm.IRModule:
-    use_ft_quant = args.quantization.name in ["q4f16_ft", "q8f16_ft", "q4f16_ft_group", "q8f16_ft_group"]
+    use_ft_quant = args.quantization.name in [
+        "q4f16_ft",
+        "q8f16_ft",
+        "q4f16_ft_group",
+        "q8f16_ft_group",
+    ]

     if use_ft_quant:
         shard_strategy_to_func = _get_shard_strategies_ft(
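A small NumPy check (not part of the diff; toy sizes) of what the new moe_shard_k_weight_scale layout does: the reshape and transpose hand each shard a contiguous slice of the reduction dimension of every expert.

import numpy as np

# Toy sizes; the real ones come from the model config.
num_experts, spatial, red, num_shards = 4, 6, 8, 2
w = np.random.rand(num_experts, spatial, red)

# Same reshape + transpose as the TE compute above, expressed in NumPy.
sharded = w.reshape(num_experts, spatial, num_shards, red // num_shards).transpose(2, 0, 1, 3)

# Each shard holds a contiguous slice of the reduction (k) dimension of every expert.
chunk = red // num_shards
for i in range(num_shards):
    assert np.array_equal(sharded[i], w[:, :, i * chunk : (i + 1) * chunk])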
57 changes: 14 additions & 43 deletions mlc_llm/relax_model/llama.py
@@ -726,7 +726,9 @@ def forward(


 class LlamaModelForSingleSequence(LlamaModelBase):
-    def __init__(self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False):
+    def __init__(
+        self, config: LlamaConfig, vocab_size_var: tvm.tir.SizeVar, sep_embed: bool = False
+    ):
         super().__init__(config, vocab_size_var, sep_embed, enable_batching=False)

     def _prepare_decoder_attention_mask(self, input_shape, src_len, dtype):
@@ -1251,7 +1253,6 @@ def f_convert_pname_fwd(pname: str) -> List[str]:
         if isinstance(config, MixtralConfig):
             for k, v in mappings:
                 pname = pname.replace(k, v)
-                # pname = pname.replace("model.", "")
             if config.quantization_scheme.name == "q4f16_ft":
                 if pname.endswith("scales"):
                     # TODO: remove after quantization integarted
@@ -1315,7 +1316,7 @@ def f_convert_param_bkwd(torch_pname: str, torch_param):
     def quantize(experts, relax_pname):
         print("quantizing experts", relax_pname)
         func = tvm.get_global_func("cutlass.symmetric_quantize")
-        nd_experts = tvm.nd.array(experts)
+        nd_experts = tvm.nd.array(experts.transpose(0, 2, 1))
         qweight, qscale = func(nd_experts, True)
         if relax_pname.endswith("weight"):
             return qweight
@@ -1332,50 +1333,20 @@ def f_compute_relax_param(relax_pname: str, torch_params: List[Any]):
                 # combine along out_features dimension and then experts dimension
                 experts = []
                 assert len(torch_params) == 2 * config.num_local_experts
-
-                use_pytorch = True
-                if use_pytorch and dtype == "float16":
-                    import torch
-
-                    torch_params = [torch.from_numpy(param).cuda() for param in torch_params]
-                    for i in range(config.num_local_experts):
-                        gate, up = (
-                            torch_params[i],
-                            torch_params[i + config.num_local_experts],
-                        )  # torch weight in col major
-                        gate_up = torch.concatenate([gate, up], axis=0).type(torch.float16)
-                        experts.append(gate_up.transpose(1, 0))
-                    result = torch.stack(experts)
-                    result = result.cpu().numpy()
-                else:
-                    for i in range(config.num_local_experts):
-                        gate, up = (
-                            torch_params[i],
-                            torch_params[i + config.num_local_experts],
-                        )  # torch weight in col major
-                        gate_up = np.concatenate([gate, up], axis=0).astype(dtype)
-                        experts.append(gate_up.transpose())
-                    result = np.stack(experts)
-                # print(config.quantization_scheme.name)
+                for i in range(config.num_local_experts):
+                    gate, up = (
+                        torch_params[i],
+                        torch_params[i + config.num_local_experts],
+                    )  # torch weight in col major
+                    gate_up = np.concatenate([gate, up], axis=0).astype(dtype)
+                    experts.append(gate_up)
+                result = np.stack(experts)
                 if config.quantization_scheme.name == "q4f16_ft" and "experts" in relax_pname:
                     result = quantize(result, relax_pname)
                 return result
             if "experts" in relax_pname:
-                use_pytorch = True
-                if use_pytorch and dtype == "float16":
-                    import torch
-
-                    torch_params = [torch.from_numpy(param).cuda() for param in torch_params]
-                    experts = torch.stack(
-                        [expert.type(torch.float16).transpose(1, 0) for expert in torch_params]
-                    )
-                    result = experts.cpu().numpy()
-                else:
-                    experts = [expert.astype(dtype).transpose() for expert in torch_params]
-                    result = np.stack(experts)
-                # torch_params = [torch.from_numpy(param).cuda() for param in torch_params]
-                # experts = [expert.type(dtype).transpose(1, 0) for expert in torch_params]
-                # result = torch.stack(experts).detach().numpy()
+                experts = [expert.astype(dtype) for expert in torch_params]
+                result = np.stack(experts)
                 if config.quantization_scheme.name == "q4f16_ft" and "experts" in relax_pname:
                     result = quantize(result, relax_pname)
                 return result
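For reference, a NumPy sketch (not part of the diff; hypothetical shapes) of what the simplified gate/up branch of f_compute_relax_param now produces: checkpoint Linear weights are stored as (out_features, in_features), so concatenating gate and up along axis 0 and stacking the experts already yields the column-major (num_experts, 2 * ffn_size, hidden_size) layout, and no transpose is needed.

import numpy as np

# Hypothetical checkpoint tensors: nn.Linear weights are stored as (out_features, in_features).
num_local_experts, hidden_size, ffn_size = 2, 8, 16
gate_weights = [np.random.rand(ffn_size, hidden_size) for _ in range(num_local_experts)]
up_weights = [np.random.rand(ffn_size, hidden_size) for _ in range(num_local_experts)]
torch_params = gate_weights + up_weights  # gates first, then ups, as the loop above indexes them

experts = []
for i in range(num_local_experts):
    gate, up = torch_params[i], torch_params[i + num_local_experts]
    experts.append(np.concatenate([gate, up], axis=0))  # (2 * ffn_size, hidden_size), no transpose
result = np.stack(experts)

assert result.shape == (num_local_experts, 2 * ffn_size, hidden_size)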
2 changes: 1 addition & 1 deletion mlc_llm/relax_model/mixtral.py
@@ -18,7 +18,7 @@ def __init__(self, config: MixtralConfig, num_experts, in_features, out_features
         if config.quantization_scheme.name == "q0f16":
             # weight is row major
             self.weight = nn.Parameter(
-                (num_experts, in_features, out_features),
+                (num_experts, out_features, in_features),
                 dtype="float16",
             )
         elif config.quantization_scheme.name == "q4f16_ft":
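To tie the shape change together, a hypothetical NumPy loop (not part of the diff; routing and sizes are made up) showing how a routed MoE GEMM would consume the (num_experts, out_features, in_features) weight.

import numpy as np

# Hypothetical routing and sizes, just to show how the (experts, out, in) weight is read.
num_experts, in_features, out_features, num_tokens = 4, 8, 16, 10
weight = np.random.rand(num_experts, out_features, in_features)
tokens = np.random.rand(num_tokens, in_features)
expert_of_token = np.random.randint(0, num_experts, size=num_tokens)

output = np.zeros((num_tokens, out_features))
for e in range(num_experts):
    routed = expert_of_token == e
    # (out, in) storage means the GEMM uses the transpose of each expert's slice.
    output[routed] = tokens[routed] @ weight[e].T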
