fix dropout in fused_lora
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
fabianlim committed Nov 6, 2024
1 parent 2cdd799 commit 80ed420
Showing 4 changed files with 45 additions and 24 deletions.
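The change threads the dropout pattern from the forward pass into the backward LoRA terms: the forward pass saves the post-dropout activation (e.g. `downX`, `upX`, `gateX`), and the backward pass recovers the dropout pattern as `(X != 0)` and applies it to the LoRA contribution. A minimal standalone sketch of that idea, independent of the repository's classes (it assumes dropped entries are exactly zero and that no surviving activation happens to be exactly zero, which is the implicit assumption behind the `(X != 0)` masks in the diffs below):

```python
import torch

torch.manual_seed(0)
p = 0.1
drop = torch.nn.Dropout(p)

X = torch.randn(4, 8)
X_dropped = drop(X)       # forward: survivors are scaled by 1/(1 - p), the rest are zeroed
mask = (X_dropped != 0)   # recover the dropout pattern from the saved activation

grad_term = torch.randn(4, 8)
masked = mask * grad_term  # zero the LoRA gradient term wherever the activation was dropped
```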
@@ -121,7 +121,10 @@ def backward(ctx, dY : torch.Tensor):
g = g .view(-1, g .shape[-1])
dtype = X.dtype

-DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
+DW = matmul_lora(
+    dY, downW.t(), downW_quant, downB, downA, downS,
+    dropout=(downX !=0)
+)
DW, e, g = _backward_function(DW, e, g)
h, df, de = DW, e, g

@@ -148,12 +151,12 @@ def backward(ctx, dY : torch.Tensor):
upW = fast_dequantize(upW.t(), upW_quant)
dX = torch.matmul(df, upW.t(), out = X)
del upW
-dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
+dX += (upX != 0) * (df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()))

gateW = fast_dequantize(gateW.t(), gateW_quant)
dX += de @ gateW.t()
del gateW
-dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
+dX += (gateX != 0) * (de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()))

# gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
# upW, upW_quant, up_bias, upA, upB, upS,
@@ -333,19 +336,19 @@ def backward(ctx, dQ, dK, dV):
QW = fast_dequantize(QW.t(), QW_quant)
dX = torch.matmul(dQ, QW.t(), out = X)
del QW
-dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
+dX += (QX != 0) * (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))

# dK
KW = fast_dequantize(KW.t(), KW_quant)
dX += dK @ KW.t()
del KW
-dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
+dX += (KX != 0) * (dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()))

# dV
VW = fast_dequantize(VW.t(), VW_quant)
dX += dV @ VW.t()
del VW
-dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
+dX += (VX != 0) * (dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()))

# QW, QW_quant, Q_bias, QA, QB, QS,
# KW, KW_quant, K_bias, KA, KB, KS,
@@ -443,7 +446,7 @@ def backward(ctx, dY : torch.Tensor):
W = fast_dequantize(W.t(), W_quant)
dX = dY @ W.t()
del W
-dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
+dX += (OX != 0) * (dY @ B.to(dtype).t() @ (S * A.to(dtype).t()))

# W, W_quant, A, B, S
return (
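The same masking pattern repeats for every projection above and in the GPTQ variants below: the LoRA contribution to the input gradient is multiplied elementwise by the mask recovered from the saved post-dropout input. An illustrative, self-contained rendering with made-up shapes (names and transposes chosen for the example, not taken from the repository):

```python
import torch

n, d_in, d_out, r, s = 4, 16, 32, 8, 2.0
A = torch.randn(r, d_in)          # LoRA A matrix
B = torch.randn(d_out, r)         # LoRA B matrix
grad_out = torch.randn(n, d_out)  # gradient w.r.t. the projection output
X_dropped = torch.nn.Dropout(0.1)(torch.randn(n, d_in))  # saved in the forward pass

# LoRA part of the input gradient, masked by the forward dropout pattern.
dX_lora = (X_dropped != 0) * (grad_out @ B @ (s * A))
assert dX_lora.shape == (n, d_in)
```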
@@ -129,8 +129,13 @@ def matmul_lora_canonicalized(X, W, A, B, s, dropout=None):

out = torch.matmul(X, W)
if dropout is not None:
-X = dropout(X)
-dropout.X = X
+if isinstance(dropout, torch.Tensor):
+    X *= dropout
+elif isinstance(dropout, torch.nn.Module):
+    X = dropout(X)
+    dropout.X = X
+else:
+    raise NotImplementedError("dropout must be a tensor or module.")
A, B = A.t(), B.t()
out += (X @ A) @ (s * B)

@@ -152,9 +157,14 @@ def matmul_lora(X, W, A, B, s, out=None, dropout=None):
if A is not None:
# LoRA is enabled
if dropout is not None:
-# save post-dropout X for backward computation
-X = dropout(X)
-dropout.X = X
+if isinstance(dropout, torch.Tensor):
+    X *= dropout
+elif isinstance(dropout, torch.nn.Module):
+    # save post-dropout X for backward computation
+    X = dropout(X)
+    dropout.X = X
+else:
+    raise NotImplementedError("dropout must be a tensor or module.")
A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))

@@ -343,7 +353,10 @@ def backward(ctx, dY: torch.Tensor):
downW = dequant248(
down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
)
-DW = matmul_lora(dY, downW.t(), downB, downA, downS)
+DW = matmul_lora(
+    dY, downW.t(), downB, downA, downS,
+    dropout=(downX !=0)
+)
# e = e.float()
# se = 1.0 / (1.0 + torch.exp(-e))
# f = (se * e).to(dtype)
@@ -377,14 +390,14 @@ def backward(ctx, dY: torch.Tensor):
upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
dX = torch.matmul(df, upW.t()) # , out=X)
del upW
-dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
+dX += (upX != 0) * (df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()))

gateW = dequant248(
gate_qweight, gate_scales, gate_qzeros, gate_g_idx, gate_bits
)
dX += de @ gateW.t()
del gateW
-dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
+dX += (gateX != 0) * (de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()))

# qweight, scales, qzeros, g_idx, bits
# upW, upW_quant, upA, upB, upS,
@@ -648,19 +661,19 @@ def backward(ctx, dQ, dK, dV):
QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits)
dX = torch.matmul(dQ, QW.t()) # , out=X)
del QW
-dX += dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())
+dX += (QX != 0) * (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))

# dK
KW = dequant248(K_qweight, K_scales, K_qzeros, K_g_idx, K_bits)
dX += dK @ KW.t()
del KW
-dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
+dX += (KX != 0) * (dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()))

# dV
VW = dequant248(V_qweight, V_scales, V_qzeros, V_g_idx, V_bits)
dX += dV @ VW.t()
del VW
-dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
+dX += (VX != 0) * (dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()))

# Q_qweight, Q_scales, Q_qzeros, Q_wf, Q_g_idx, Q_bits, Q_bias, QA, QB, QS,
# K_qweight, K_scales, K_qzeros, K_wf, K_g_idx, K_bits, K_bias, KA, KB, KS,
@@ -817,7 +830,7 @@ def backward(ctx, dY: torch.Tensor):
W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
dX = dY @ W.t()
del W
-dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
+dX += (OX !=0) * (dY @ B.to(dtype).t() @ (S * A.to(dtype).t()))

# O_qweight, O_scales, O_qzeros, O_wf, O_g_idx, O_bits, O_bias, A, B, S
return (
@@ -260,10 +260,16 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None, dropout=None):
if A is not None:
# LoRA is enabled
if dropout is not None:
-# in order to return the dropped out X to the
-# top level, we save it on the dropout module
-X = dropout(X)
-dropout.X = X
+if isinstance(dropout, torch.Tensor):
+    X *= dropout
+elif isinstance(dropout, torch.nn.Module):
+    # in order to return the dropped out X to the
+    # top level, we save it on the dropout module
+    X = dropout(X)
+    dropout.X = X
+else:
+    raise NotImplementedError("dropout must be a tensor or module.")

A, B = A.t(), B.t()
out += (X @ A.to(dtype)) @ (s * B.to(dtype))
pass
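All three `matmul_lora` variants above now accept either an `nn.Dropout` module (forward path, where the post-dropout input is stashed on the module) or a precomputed mask tensor (backward path). A self-contained sketch of that dispatch, with an illustrative name and simplified signature rather than the repository's exact ones:

```python
import torch

def lora_matmul_sketch(X, W, A, B, s, dropout=None):
    # Base projection uses the un-dropped input.
    out = torch.matmul(X, W)
    if dropout is not None:
        if isinstance(dropout, torch.Tensor):
            # Backward-style call: a precomputed mask, e.g. (saved_X != 0).
            X = X * dropout
        elif isinstance(dropout, torch.nn.Module):
            # Forward-style call: apply dropout and stash the result for backward.
            X = dropout(X)
            dropout.X = X
        else:
            raise NotImplementedError("dropout must be a tensor or module.")
    # Low-rank update: A is (r, d_in), B is (d_out, r), s is the LoRA scale.
    return out + (X @ A.t()) @ (s * B.t())

# Forward with a dropout module, then reuse its saved activation as a backward mask.
X = torch.randn(4, 16)
W = torch.randn(16, 32)
A, B = torch.randn(8, 16), torch.randn(32, 8)
drop = torch.nn.Dropout(0.1)
_ = lora_matmul_sketch(X, W, A, B, 2.0, dropout=drop)
_ = lora_matmul_sketch(X, W, A, B, 2.0, dropout=(drop.X != 0))
```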
1 change: 0 additions & 1 deletion plugins/fused-ops-and-kernels/tests/test_fused_ops.py
@@ -412,7 +412,6 @@ def test_adapter_gradients_match_with_attention_layer(
assert (
loss_unpatched - loss_patched
).abs() < LOSS_TOL, "Loss after foak patch do not match"
-import pdb; pdb.set_trace()

# check input gradients
torch.allclose(
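The hunk above is shown truncated; for context, a hedged sketch of the kind of gradient-matching check this test performs. Every name here is illustrative and not the repository's test API:

```python
import torch

def assert_input_grads_match(model_a, model_b, X, atol=1e-5):
    # Run the same input through both models and compare input gradients.
    Xa = X.clone().requires_grad_(True)
    Xb = X.clone().requires_grad_(True)
    model_a(Xa).sum().backward()
    model_b(Xb).sum().backward()
    assert torch.allclose(Xa.grad, Xb.grad, atol=atol), "input gradients do not match"
```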
