Commit c469b74: "it works!"
ElizaWszola committed Jul 12, 2024 (1 parent: cda9a0f)
Showing 4 changed files with 46 additions and 48 deletions.
vllm/model_executor/layers/fused_moe/fused_moe.py (18 changes: 10 additions & 8 deletions)

@@ -737,14 +737,16 @@ def fused_marlin_moe(hidden_states: torch.Tensor,
     E = w1.shape[0]
     N = w2.shape[1] * 16
 
-    print("hidden_states shape:", hidden_states.shape)
-    print("w1 shape:", w1.shape)
-    print("w2 shape:", w2.shape)
-    print("gating_output shape:", gating_output.shape)
-    print("g_idx1 shape:", g_idx1.shape)
-    print("g_idx2 shape:", g_idx2.shape)
-    print("w1_scale shape:", w1_scale.shape)
-    print("w2_scale shape:", w2_scale.shape)
+    # print("hidden_states shape:", hidden_states)
+    # print("w1 shape:", w1)
+    # print("w2 shape:", w2)
+    # print("gating_output shape:", gating_output)
+    # print("g_idx1 shape:", g_idx1)
+    # print("g_idx2 shape:", g_idx2)
+    # print("w1_scale shape:", w1_scale)
+    # print("w2_scale shape:", w2_scale)
+
+    # raise ValueError("stop")
 
     topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
                                         renormalize)
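For context, `fused_topk` above selects the routing experts from the gating logits. A minimal reference sketch of the equivalent computation, assuming Mixtral-style softmax-then-top-k routing (the helper name is illustrative, not the actual kernel):

import torch

def naive_fused_topk(hidden_states: torch.Tensor,
                     gating_output: torch.Tensor,
                     topk: int,
                     renormalize: bool):
    # Softmax over the expert dimension, then keep the top-k experts per token.
    scores = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(scores, topk, dim=-1)
    if renormalize:
        # Rescale so the kept routing weights sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids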
vllm/model_executor/layers/fused_moe/layer.py (46 changes: 21 additions & 25 deletions)

@@ -215,8 +215,8 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
         w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx)
         w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx)
         for e in range(num_experts):
-            w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e])#.to(torch.int)
-            w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e])#.to(torch.int)
+            w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_g_idx[e]).to(torch.int32)
+            w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to(torch.int32)
             w13_sorted_g_idx[e] = layer.w13_g_idx[e][w13_g_idx_sort_indices[e]]
             w2_sorted_g_idx[e] = layer.w2_g_idx[e][w2_g_idx_sort_indices[e]]
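A note on the dtype change in this hunk: `torch.argsort` always returns `int64` indices regardless of the input dtype, while the Marlin MoE kernel presumably expects `int32` sort indices, so the cast must be applied rather than left commented out. A minimal illustration:

import torch

g_idx = torch.tensor([2, 0, 1, 0], dtype=torch.int32)
sort_indices = torch.argsort(g_idx)
print(sort_indices.dtype)                    # torch.int64 -- argsort ignores the input dtype
sort_indices = sort_indices.to(torch.int32)  # match the kernel's expected index dtype
print(g_idx[sort_indices])                   # group indices in ascending order: [0, 0, 1, 2]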

@@ -246,16 +246,16 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
 
         # print("3", layer.w13_scales)
 
-        print("*")
-        print("hidden:", x.shape)
-        print("w13 before:", layer.w13_qweight.shape)
-        print("w2 before:", layer.w2_qweight.shape)
-        print("w13 args:", layer.w13_qweight.shape[1]
-              * self.quant_config.pack_factor,
-              layer.w13_qweight.shape[2])
-        print("w2 args:", layer.w2_qweight.shape[1]
-              * self.quant_config.pack_factor,
-              layer.w2_qweight.shape[2])
+        # print("*")
+        # print("hidden:", x.shape)
+        # print("w13 before:", layer.w13_qweight.shape)
+        # print("w2 before:", layer.w2_qweight.shape)
+        # print("w13 args:", layer.w13_qweight.shape[1]
+        #       * self.quant_config.pack_factor,
+        #       layer.w13_qweight.shape[2])
+        # print("w2 args:", layer.w2_qweight.shape[1]
+        #       * self.quant_config.pack_factor,
+        #       layer.w2_qweight.shape[2])
 
         # print("weight type:", layer.w13_qweight.dtype)
 
@@ -277,14 +277,14 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
         )
         replace_tensor("w2_qweight", marlin_w2_qweight)
 
-        print("w13 after:", marlin_w13_qweight.shape)
-        print("w2 after:", marlin_w2_qweight.shape)
+        # print("w13 after:", marlin_w13_qweight.shape)
+        # print("w2 after:", marlin_w2_qweight.shape)
 
-        print("w13 scales before:", layer.w13_scales.shape)
-        print("w2 scales before:", layer.w2_scales.shape)
-        print("w13 args:", x.shape[1], layer.w13_scales.shape[2])
-        print("w2 args:", layer.w2_scales.shape[1] * self.quant_config.pack_factor,
-              x.shape[1])
+        # print("w13 scales before:", layer.w13_scales.shape)
+        # print("w2 scales before:", layer.w2_scales.shape)
+        # print("w13 args:", x.shape[1], layer.w13_scales.shape[2])
+        # print("w2 args:", layer.w2_scales.shape[1] * self.quant_config.pack_factor,
+        #       x.shape[1])
 
         # Repack scales
         marlin_w13_scales = marlin_moe_permute_scales(
@@ -305,12 +305,8 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
         )
         replace_tensor("w2_scales", marlin_w2_scales)
 
-        print("w13 scales after:", marlin_w13_scales.shape)
-        print("w2 scales after:", marlin_w2_scales.shape)
-
-        print(x.shape)
-        print(layer.w13_qweight.shape)
-        print(layer.w2_qweight.shape)
+        # print("w13 scales after:", marlin_w13_scales.shape)
+        # print("w2 scales after:", marlin_w2_scales.shape)
 
         return fused_marlin_moe(x,
                                 layer.w13_qweight,
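The `pack_factor` arithmetic in the commented-out debug prints recovers logical weight dimensions from packed storage: GPTQ stores several low-bit values per `int32` word, so a packed dimension is `pack_factor` times smaller than the logical one. A sketch of the relationship, assuming 4-bit quantization (consistent with the literal `* 8` in the gptq_marlin.py prints below):

# Assumed: 4-bit GPTQ weights packed into 32-bit words.
num_bits = 4
pack_factor = 32 // num_bits       # 8 quantized values per int32

size_k = 4096                      # logical input dimension of the layer
packed_k = size_k // pack_factor   # rows actually stored in qweight: 512
assert packed_k * pack_factor == size_k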
vllm/model_executor/layers/quantization/gptq_marlin.py (28 changes: 14 additions & 14 deletions)

@@ -837,12 +837,12 @@ def replace_tensor(name, new_t):
 
         layer_qweight13 = torch.cat((layer.qweight1, layer.qweight3), 1)
 
-        print("*")
-        print("hidden:", x.shape)
-        print("w13 before:", layer_qweight13.shape)
-        print("w2 before:", layer.qweight2.shape)
-        print("w13 args:", part_size_k, layer_qweight13.shape[1])
-        print("w2 args:", part_size_n, part_size_k)
+        # print("*")
+        # print("hidden:", x.shape)
+        # print("w13 before:", layer_qweight13.shape)
+        # print("w2 before:", layer.qweight2.shape)
+        # print("w13 args:", part_size_k, layer_qweight13.shape[1])
+        # print("w2 args:", part_size_n, part_size_k)
 
         # Repack weights
         # marlin_qweight1 = ops.gptq_marlin_repack(
@@ -880,8 +880,8 @@ def replace_tensor(name, new_t):
         )
         replace_tensor("qweight13", marlin_qweight13)
 
-        print("w13 after:", marlin_qweight13.shape)
-        print("w2 after:", marlin_qweight2.shape)
+        # print("w13 after:", marlin_qweight13.shape)
+        # print("w2 after:", marlin_qweight2.shape)
 
         # print("done repack", layer.get_parameter("qweight1").shape,
         #       layer.get_parameter("qweight2").shape,
@@ -896,10 +896,10 @@ def replace_tensor(name, new_t):
 
         layer_scales13 = torch.cat((layer.scales1, layer.scales3), 1)
 
-        print("w13 scales before:", layer_scales13.shape)
-        print("w2 scales before:", layer.scales2.shape)
-        print("w13 args:", part_size_k, layer_qweight13.shape[1])
-        print("w2 args:", layer.scales2.shape[0] * 8, layer.scales2.shape[1])
+        # print("w13 scales before:", layer_scales13.shape)
+        # print("w2 scales before:", layer.scales2.shape)
+        # print("w13 args:", part_size_k, layer_qweight13.shape[1])
+        # print("w2 args:", layer.scales2.shape[0] * 8, layer.scales2.shape[1])
 
         # marlin_scales1 = marlin_permute_scales(
         #     layer.scales1,
Expand Down Expand Up @@ -934,8 +934,8 @@ def replace_tensor(name, new_t):
)
replace_tensor("scales13", marlin_scales13)

print("w13 scales after:", marlin_scales13.shape)
print("w2 scales after:", marlin_scales2.shape)
# print("w13 scales after:", marlin_scales13.shape)
# print("w2 scales after:", marlin_scales2.shape)

# raise ValueError("stop")

Expand Down
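The `qweight13` / `scales13` tensors in this file concatenate the gate (`w1`) and up (`w3`) projections along the output dimension so that one fused GEMM produces both halves of the MLP activation. A sketch of the layout, assuming unquantized weights for clarity and the same `torch.cat((..., ...), 1)` convention used above:

import torch

k, n = 64, 32
w1 = torch.randn(k, n)                  # gate projection
w3 = torch.randn(k, n)                  # up projection
w13 = torch.cat((w1, w3), dim=1)        # (k, 2n): one matmul yields both halves
x = torch.randn(8, k)
gate, up = (x @ w13).chunk(2, dim=-1)   # split the fused output back apart
assert torch.allclose(gate, x @ w1) and torch.allclose(up, x @ w3)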
vllm/model_executor/models/mixtral_quant.py (2 changes: 1 addition & 1 deletion)

@@ -171,7 +171,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.experimental_fused_moe:
 
             if not self.old_code:
-                return self.experts(hidden_states, router_logits)
+                return self.experts(hidden_states.half(), router_logits).bfloat16()
 
             qweight13_l = []
             scales13_l = []
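The one-line change above wraps the fused-experts call in a dtype round-trip, presumably because the fused Marlin MoE path at this point only supports float16 activations while the model otherwise runs in bfloat16. A minimal sketch of the pattern (the kernel stand-in is hypothetical):

import torch

def fp16_only_kernel(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a fused kernel that accepts only float16 inputs.
    assert x.dtype == torch.float16
    return x * 2.0

hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)
# Cast down for the kernel, then back to the model's dtype.
out = fp16_only_kernel(hidden_states.half()).bfloat16()
assert out.dtype == torch.bfloat16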
