From 80ed420206feb5ef12e98ec76cfaf6afa6eba37d Mon Sep 17 00:00:00 2001
From: Yu Chin Fabian Lim
Date: Wed, 6 Nov 2024 07:57:47 +0000
Subject: [PATCH] fix dropout in fused_lora

Signed-off-by: Yu Chin Fabian Lim
---
 .../fused_ops/unsloth_lora/bnb/fast_lora.py   | 17 +++++----
 .../fused_ops/unsloth_lora/gptq/fast_lora.py  | 37 +++++++++++++------
 .../fused_ops/unsloth_lora/utils.py           | 14 +++++--
 .../tests/test_fused_ops.py                   |  1 -
 4 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
index 63d7dcb7..55e35ca6 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py
@@ -121,7 +121,10 @@ def backward(ctx, dY : torch.Tensor):
         g = g .view(-1, g .shape[-1])
         dtype = X.dtype
 
-        DW = matmul_lora(dY, downW.t(), downW_quant, downB, downA, downS)
+        DW = matmul_lora(
+            dY, downW.t(), downW_quant, downB, downA, downS,
+            dropout=(downX !=0)
+        )
         DW, e, g = _backward_function(DW, e, g)
         h, df, de = DW, e, g
 
@@ -148,12 +151,12 @@
         upW = fast_dequantize(upW.t(), upW_quant)
         dX = torch.matmul(df, upW.t(), out = X)
         del upW
-        dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
+        dX += (upX != 0) * (df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()))
 
         gateW = fast_dequantize(gateW.t(), gateW_quant)
         dX += de @ gateW.t()
         del gateW
-        dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
+        dX += (gateX != 0) * (de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()))
 
         # gateW, gateW_quant, gate_bias, gateA, gateB, gateS,
         # upW, upW_quant, up_bias, upA, upB, upS,
@@ -333,19 +336,19 @@ def backward(ctx, dQ, dK, dV):
         QW = fast_dequantize(QW.t(), QW_quant)
         dX = torch.matmul(dQ, QW.t(), out = X)
         del QW
-        dX += (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
+        dX += (QX != 0) * (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
 
         # dK
         KW = fast_dequantize(KW.t(), KW_quant)
         dX += dK @ KW.t()
         del KW
-        dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
+        dX += (KX != 0) * (dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()))
 
         # dV
         VW = fast_dequantize(VW.t(), VW_quant)
         dX += dV @ VW.t()
         del VW
-        dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
+        dX += (VX != 0) * (dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()))
 
         # QW, QW_quant, Q_bias, QA, QB, QS,
         # KW, KW_quant, K_bias, KA, KB, KS,
@@ -443,7 +446,7 @@ def backward(ctx, dY : torch.Tensor):
         W = fast_dequantize(W.t(), W_quant)
         dX = dY @ W.t()
         del W
-        dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
+        dX += (OX != 0) * (dY @ B.to(dtype).t() @ (S * A.to(dtype).t()))
 
         # W, W_quant, A, B, S
         return (
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
index 2f0fc89c..31aa5d5e 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py
@@ -129,8 +129,13 @@ def matmul_lora_canonicalized(X, W, A, B, s, dropout=None):
     out = torch.matmul(X, W)
 
     if dropout is not None:
-        X = dropout(X)
-        dropout.X = X
+        if isinstance(dropout, torch.Tensor):
+            X *= dropout
+        elif isinstance(dropout, torch.nn.Module):
+            X = dropout(X)
+            dropout.X = X
+        else:
+            raise NotImplementedError("dropout must be a tensor or module.")
 
     A, B = A.t(), B.t()
     out += (X @ A) @ (s * B)
@@ -152,9 +157,14 @@ def matmul_lora(X, W, A, B, s, out=None, dropout=None):
     if A is not None:
         # LoRA is enabled
         if dropout is not None:
-            # save post-dropout X for backward computation
-            X = dropout(X)
-            dropout.X = X
+            if isinstance(dropout, torch.Tensor):
+                X *= dropout
+            elif isinstance(dropout, torch.nn.Module):
+                # save post-dropout X for backward computation
+                X = dropout(X)
+                dropout.X = X
+            else:
+                raise NotImplementedError("dropout must be a tensor or module.")
 
         A, B = A.t(), B.t()
         out += (X @ A.to(dtype)) @ (s * B.to(dtype))
@@ -343,7 +353,10 @@ def backward(ctx, dY: torch.Tensor):
         downW = dequant248(
             down_qweight, down_scales, down_qzeros, down_g_idx, down_bits
         )
-        DW = matmul_lora(dY, downW.t(), downB, downA, downS)
+        DW = matmul_lora(
+            dY, downW.t(), downB, downA, downS,
+            dropout=(downX !=0)
+        )
         # e = e.float()
         # se = 1.0 / (1.0 + torch.exp(-e))
         # f = (se * e).to(dtype)
@@ -377,14 +390,14 @@ def backward(ctx, dY: torch.Tensor):
         upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits)
         dX = torch.matmul(df, upW.t()) # , out=X)
         del upW
-        dX += df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t())
+        dX += (upX != 0) * (df @ upB.to(dtype).t() @ (upS * upA.to(dtype).t()))
 
         gateW = dequant248(
             gate_qweight, gate_scales, gate_qzeros, gate_g_idx, gate_bits
         )
         dX += de @ gateW.t()
         del gateW
-        dX += de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t())
+        dX += (gateX != 0) * (de @ gateB.to(dtype).t() @ (gateS * gateA.to(dtype).t()))
 
         # qweight, scales, qzeros, g_idx, bits
         # upW, upW_quant, upA, upB, upS,
@@ -648,19 +661,19 @@ def backward(ctx, dQ, dK, dV):
         QW = dequant248(Q_qweight, Q_scales, Q_qzeros, Q_g_idx, Q_bits)
         dX = torch.matmul(dQ, QW.t()) # , out=X)
         del QW
-        dX += dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t())
+        dX += (QX != 0) * (dQ @ QB.to(dtype).t() @ (QS * QA.to(dtype).t()))
 
         # dK
         KW = dequant248(K_qweight, K_scales, K_qzeros, K_g_idx, K_bits)
         dX += dK @ KW.t()
         del KW
-        dX += dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t())
+        dX += (KX != 0) * (dK @ KB.to(dtype).t() @ (KS * KA.to(dtype).t()))
 
         # dV
         VW = dequant248(V_qweight, V_scales, V_qzeros, V_g_idx, V_bits)
         dX += dV @ VW.t()
         del VW
-        dX += dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t())
+        dX += (VX != 0) * (dV @ VB.to(dtype).t() @ (VS * VA.to(dtype).t()))
 
         # Q_qweight, Q_scales, Q_qzeros, Q_wf, Q_g_idx, Q_bits, Q_bias, QA, QB, QS,
         # K_qweight, K_scales, K_qzeros, K_wf, K_g_idx, K_bits, K_bias, KA, KB, KS,
@@ -817,7 +830,7 @@ def backward(ctx, dY: torch.Tensor):
         W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits)
         dX = dY @ W.t()
         del W
-        dX += dY @ B.to(dtype).t() @ (S * A.to(dtype).t())
+        dX += (OX !=0) * (dY @ B.to(dtype).t() @ (S * A.to(dtype).t()))
 
         # O_qweight, O_scales, O_qzeros, O_wf, O_g_idx, O_bits, O_bias, A, B, S
         return (
diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
index 4316aa40..300cfbf3 100644
--- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
+++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/utils.py
@@ -260,10 +260,16 @@ def matmul_lora(X, W, W_quant, A, B, s, out = None, dropout=None):
     if A is not None:
         # LoRA is enabled
         if dropout is not None:
-            # in order to return the dropped out X to the
-            # top level, we save it on the dropout module
-            X = dropout(X)
-            dropout.X = X
+            if isinstance(dropout, torch.Tensor):
+                X *= dropout
+            elif isinstance(dropout, torch.nn.Module):
+                # in order to return the dropped out X to the
+                # top level, we save it on the dropout module
+                X = dropout(X)
+                dropout.X = X
+            else:
+                raise NotImplementedError("dropout must be a tensor or module.")
+
         A, B = A.t(), B.t()
         out += (X @ A.to(dtype)) @ (s * B.to(dtype))
     pass
diff --git a/plugins/fused-ops-and-kernels/tests/test_fused_ops.py b/plugins/fused-ops-and-kernels/tests/test_fused_ops.py
index cd75bbac..1ff44ffa 100644
--- a/plugins/fused-ops-and-kernels/tests/test_fused_ops.py
+++ b/plugins/fused-ops-and-kernels/tests/test_fused_ops.py
@@ -412,7 +412,6 @@ def test_adapter_gradients_match_with_attention_layer(
     assert (
         loss_unpatched - loss_patched
     ).abs() < LOSS_TOL, "Loss after foak patch do not match"
-    import pdb; pdb.set_trace()
 
     # check input gradients
     torch.allclose(
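For reference, the change above lets matmul_lora accept either the nn.Dropout module or a mask tensor derived from the saved post-dropout activation (e.g. dropout=(downX != 0)), and applies the same != 0 masking to the LoRA terms in the backward passes. Below is a minimal standalone sketch of that masking idea; it is illustrative only and uses hypothetical names (X, A, B, s, p) rather than the fused LoRA tensors touched by this patch.

# Illustrative sketch only (not part of the patch): recover the dropout
# keep-pattern in the backward pass from the saved post-dropout activation,
# instead of carrying the nn.Dropout module around.
import torch

torch.manual_seed(0)
p = 0.1
X = torch.randn(4, 8, requires_grad=True)   # layer input
A = torch.randn(8, 2)                       # LoRA A (in_dim x r)
B = torch.randn(2, 8)                       # LoRA B (r x out_dim)
s = 0.5                                     # LoRA scaling

drop = torch.nn.Dropout(p)                  # freshly built module is in training mode
Xd = drop(X)                                # post-dropout X, as saved via dropout.X in the fused forward
out = (Xd @ A) @ (s * B)                    # LoRA branch only
out.sum().backward()

# (Xd != 0) recovers which entries survived dropout without needing the module.
# nn.Dropout also rescales surviving entries by 1/(1-p), so that factor is
# folded into the mask here so this standalone check matches autograd exactly.
dY = torch.ones_like(out)
mask = (Xd != 0).to(X.dtype) / (1.0 - p)
dX_manual = mask * (dY @ (s * B).t() @ A.t())

assert torch.allclose(X.grad, dX_manual, atol=1e-6)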