diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index cc32958da..a86ed1953 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -407,8 +407,8 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
             A_wo_outliers = A.clone()
             if state.idx is not None:
                 A_wo_outliers[:, state.idx.long()] = 0
-            output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
-            output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
+            CB = state.CB.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
+            output = torch.nn.functional.linear(A_wo_outliers, CB)
             if bias is not None:
                 output = output.add_(bias)
 
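For context: the two formulations are algebraically equivalent. The patch only moves the per-output-channel dequantization scale (`SCB / 127`) from the matmul output onto the rows of the quantized weight `CB`. A minimal standalone check of that equivalence (the toy shapes and random data below are assumptions for illustration, not part of the patch):

```python
import torch

# Sketch: verify that scaling the weight rows before the matmul (new path)
# matches scaling the output columns after the matmul (old path).
# Shapes mirror the patch; batch/in/out sizes are arbitrary assumptions.
batch, in_features, out_features = 4, 8, 16

A = torch.randn(batch, in_features)                                   # activations
CB = torch.randint(-127, 128, (out_features, in_features)).float()    # int8 codes as float
SCB = torch.rand(out_features) * 10                                   # per-row absmax scales

# Old path: matmul against raw int8 codes, then scale each output column.
out_old = torch.nn.functional.linear(A, CB)
out_old = out_old * SCB.unsqueeze(0).mul(1.0 / 127.0)

# New path: dequantize the weight rows first, then a single matmul.
W = CB.mul(SCB.unsqueeze(1).mul(1.0 / 127.0))
out_new = torch.nn.functional.linear(A, W)

torch.testing.assert_close(out_old, out_new)
```

One plausible motivation for the reordering: with the scale folded into `CB`, the half-precision matmul accumulates values at the weights' original magnitudes rather than at raw int8 code values of up to 127, which reduces the risk of fp16 overflow in this fallback path.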