From 89a531ac4ef9951e279e8d10da8a169a96d195b7 Mon Sep 17 00:00:00 2001
From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com>
Date: Wed, 17 May 2023 13:14:08 +0000
Subject: [PATCH 1/2] Fix for Pascal NaN redux

Force push over-rode but it isn't fixed.

---
 bitsandbytes/autograd/_functions.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index cc32958da..ffda6a64f 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -407,8 +407,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
             A_wo_outliers = A.clone()
             if state.idx is not None:
                 A_wo_outliers[:, state.idx.long()] = 0
-            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
-            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
+            CB = state.CB.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+            output = torch.nn.functional.linear(A_wo_outliers, CB)
             if bias is not None:
                 output = output.add_(bias)
 
@@ -469,7 +469,8 @@ def backward(ctx, grad_output):
                 if state.CxBt is None:
                     state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
+                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
 
             elif state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))


From 7f291f708c3f3fb4d6edb9a84b7226258cce7d4e Mon Sep 17 00:00:00 2001
From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com>
Date: Mon, 4 Mar 2024 05:29:25 -0600
Subject: [PATCH 2/2] Remove backwards pass.

---
 bitsandbytes/autograd/_functions.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index ffda6a64f..a86ed1953 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -469,8 +469,7 @@ def backward(ctx, grad_output):
                 if state.CxBt is None:
                     state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
-                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
 
             elif state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
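
Note (not part of the patch): the forward-pass change in PATCH 1/2 folds the per-row scale SCB / 127 into a dequantized copy of CB before calling the regular torch linear, instead of running the linear on the unscaled weights and rescaling the output afterwards. Below is a minimal, self-contained sketch of that dequantize-then-matmul idea; the function name dequant_linear, the float32 CPU tensors, and the toy shapes are made up for illustration and are not bitsandbytes API.

# Illustrative sketch only: fold the per-row scale into a dequantized copy of
# the int8 weight, then run a plain torch linear.
import torch

def dequant_linear(A, CB_int8, SCB, bias=None):
    # CB_int8: [out_features, in_features] int8 weight rows
    # SCB:     [out_features] per-row absmax scales (int8 values span +/-127)
    CB = CB_int8.to(A.dtype) * (SCB.unsqueeze(1) / 127.0)   # dequantize rows
    out = torch.nn.functional.linear(A, CB)                 # A @ CB.T
    if bias is not None:
        out = out + bias
    return out

A = torch.randn(2, 8)                                       # activations
W = torch.randn(4, 8)                                       # fp32 reference weight
SCB = W.abs().amax(dim=1)                                   # per-row absmax
CB_int8 = torch.round(W / SCB.unsqueeze(1) * 127).to(torch.int8)
print(dequant_linear(A, CB_int8, SCB).shape)                # torch.Size([2, 4])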