diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py index 5ca4f8bf..63d7dcb7 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/bnb/fast_lora.py @@ -69,11 +69,11 @@ def forward(ctx, X : torch.Tensor, e = matmul_lora(X, gateW, gateW_quant, gateA, gateB, gateS, dropout=dropout_gate) g = matmul_lora(X, upW, upW_quant, upA, upB, upS, dropout=dropout_up) - e += gate_bias - g += up_bias + if gate_bias is not None: e += gate_bias + if up_bias is not None: g += up_bias h = _forward_function(e, g) i = matmul_lora(h, downW, downW_quant, downA, downB, downS, dropout=dropout_down) - i += down_bias + if down_bias is not None: i += down_bias # Extract post-dropout X for use in backward computation _dropped_X = [] @@ -261,9 +261,9 @@ def forward(ctx, X : torch.Tensor, K = matmul_lora(X, KW, KW_quant, KA, KB, KS, dropout=dropout_K) V = matmul_lora(X, VW, VW_quant, VA, VB, VS, dropout=dropout_V) - Q += Q_bias - K += K_bias - V += V_bias + if Q_bias is not None: Q += Q_bias + if K_bias is not None: K += K_bias + if V_bias is not None: V += V_bias # Extract post-dropout X for use in backward computation _dropped_X = [] @@ -406,7 +406,7 @@ def forward(ctx, X : torch.Tensor, W, W_quant, bias, A, B, S, dropout_O): dtype = X.dtype XW = matmul_lora(X, W, W_quant, A, B, S, dropout=dropout_O) - XW += bias + if bias is not None: XW += bias # Extract post-dropout X for use in backward computation if dropout_O is not None: diff --git a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py index b8e1cfcb..c372c1dd 100644 --- a/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py +++ b/plugins/fused-ops-and-kernels/src/fms_acceleration_foak/fused_ops/unsloth_lora/gptq/fast_lora.py @@ -247,8 +247,8 @@ def forward( e = matmul_lora(X, gateW, gateA, gateB, gateS, dropout=dropout_gate) upW = dequant248(up_qweight, up_scales, up_qzeros, up_g_idx, up_bits) g = matmul_lora(X, upW, upA, upB, upS, dropout=dropout_up) - e += gate_bias - g += up_bias + if gate_bias is not None: e += gate_bias + if up_bias is not None: g += up_bias # f = torch.nn.functional.silu(e) # h = f * g h = swiglu_fg_kernel(e, g) @@ -257,7 +257,7 @@ def forward( down_qweight, down_scales, down_qzeros, down_g_idx, down_bits ) i = matmul_lora(h, downW, downA, downB, downS, dropout=dropout_down) - i += down_bias + if down_bias is not None: i += down_bias ctx.custom_saved_tensors = ( gate_qweight, @@ -529,9 +529,9 @@ def forward( K = matmul_lora(X, KW, KA, KB, KS, dropout=dropout_K) V = matmul_lora(X, VW, VA, VB, VS, dropout=dropout_V) - Q += Q_bias - K += K_bias - V += V_bias + if Q_bias is not None: Q += Q_bias + if K_bias is not None: K += K_bias + if V_bias is not None: V += V_bias ctx.custom_saved_tensors = ( Q_qweight, @@ -774,7 +774,7 @@ def forward( ): W = dequant248(O_qweight, O_scales, O_qzeros, O_g_idx, O_bits) XW = matmul_lora(X, W, A, B, S, dropout=dropout_O) - XW += O_bias + if O_bias is not None: XW += O_bias del W ctx.custom_saved_tensors = ( O_qweight,