Skip to content

Commit

Permalink
missed out biases in fused-lora
Browse files Browse the repository at this point in the history
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim committed Oct 11, 2024
1 parent 73f8a58 commit 87ad314
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def forward(ctx, X : torch.Tensor,

e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS, dropout=dropout_gate)
g = matmul_lora(X, upW, upW_quant, upA, upB, upS, dropout=dropout_up)
e += gate_bias
g += up_bias
if gate_bias is not None: e += gate_bias
if up_bias is not None: g += up_bias
h = _forward_function(e, g)
i = matmul_lora(h, downW, downW_quant, downA, downB, downS, dropout=dropout_down)
i += down_bias
if down_bias is not None: i += down_bias

# Extract post-dropout X for use in backward computation
_dropped_X = []
Expand Down Expand Up @@ -261,9 +261,9 @@ def forward(ctx, X : torch.Tensor,
K = matmul_lora(X, KW, KW_quant, KA, KB, KS, dropout=dropout_K)
V = matmul_lora(X, VW, VW_quant, VA, VB, VS, dropout=dropout_V)

Q += Q_bias
K += K_bias
V += V_bias
if Q_bias is not None: Q += Q_bias
if K_bias is not None: K += K_bias
if V_bias is not None: V += V_bias

# Extract post-dropout X for use in backward computation
_dropped_X = []
Expand Down Expand Up @@ -406,7 +406,7 @@ def forward(ctx, X : torch.Tensor,
W, W_quant, bias, A, B, S, dropout_O):
dtype = X.dtype
XW = matmul_lora(X, W, W_quant, A, B, S, dropout=dropout_O)
XW += bias
if bias is not None: XW += bias

# Extract post-dropout X for use in backward computation
if dropout_O is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ def forward(
e = matmul_lora(X, gateW, gateA, gateB, gateS, dropout=dropout_gate)
upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
g = matmul_lora(X, upW, upA, upB, upS, dropout=dropout_up)
e += gate_bias
g += up_bias
if gate_bias is not None: e += gate_bias
if up_bias is not None: g += up_bias
# f = torch.nn.functional.silu(e)
# h = f * g
h = swiglu_fg_kernel(e, g)
Expand All @@ -257,7 +257,7 @@ def forward(
down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
)
i = matmul_lora(h, downW, downA, downB, downS, dropout=dropout_down)
i += down_bias
if down_bias is not None: i += down_bias

ctx.custom_saved_tensors = (
gate_qweight,
Expand Down Expand Up @@ -529,9 +529,9 @@ def forward(
K = matmul_lora(X, KW, KA, KB, KS, dropout=dropout_K)
V = matmul_lora(X, VW, VA, VB, VS, dropout=dropout_V)

Q += Q_bias
K += K_bias
V += V_bias
if Q_bias is not None: Q += Q_bias
if K_bias is not None: K += K_bias
if V_bias is not None: V += V_bias

ctx.custom_saved_tensors = (
Q_qweight,
Expand Down Expand Up @@ -774,7 +774,7 @@ def forward(
):
W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
XW = matmul_lora(X, W, A, B, S, dropout=dropout_O)
XW += O_bias
if O_bias is not None: XW += O_bias
del W
ctx.custom_saved_tensors = (
O_qweight,
Expand Down

0 comments on commit 87ad314

Please sign in to comment.