diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index c654a0254..5d9983545 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -319,10 +319,10 @@ def forward(
 
         if ctx.needs_input_grad[1]:
             # Slower path
-            CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
+            CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
         else:
             # Fast path
-            CA, SCA, coo_tensorA = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
+            CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
             CAt = SCAt = None
 
         has_grad = False
@@ -339,8 +339,8 @@ def forward(
                 # 2. Quantize B
                 state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
 
-        if state.threshold > 0.0 and coo_tensorA is not None:
-            state.idx = torch.unique(coo_tensorA._indices()[1]).long()
+        if state.threshold > 0.0 and outlier_cols is not None:
+            state.idx = outlier_cols
 
             # Zero out the outliers in the transposed 8bit inputs.
             if CAt is not None:
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 33a5d72eb..15402d7d4 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -2546,7 +2546,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None,
     # Note: for inference, use the new int8_vectorwise_quant.
 
     # Use CUDA kernel for rowwise and COO tensor
-    quant_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=threshold)
+    quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold)
 
     # PyTorch impl for colwise
     _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
@@ -2559,7 +2559,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None,
     if out_col is not None:
         quant_col = out_col.copy_(quant_col)
 
-    return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
+    return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
 
 
 def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
@@ -2574,13 +2574,9 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
 
     if threshold > 0.0:
         # TODO we could improve perf of this
-
-        A.masked_fill(A.abs() < threshold, 0.0).to_sparse_coo()
-        # coo_tensor = extract_outliers_new(A, threshold)
-        coo_tensor = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo()
-
+        outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1)
     else:
-        coo_tensor = None
+        outlier_cols = None
 
     with torch.cuda.device_of(A):
         lib.cint8_vector_quant(
@@ -2593,7 +2589,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
             get_tensor_stream(A),
         )
 
-    return out_row, row_stats, coo_tensor
+    return out_row, row_stats, outlier_cols
 
 
 @deprecated(
diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py
index dd4de5df6..f8fe14c48 100644
--- a/bitsandbytes/research/autograd/_functions.py
+++ b/bitsandbytes/research/autograd/_functions.py
@@ -217,12 +217,11 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
-        CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
+        CA, CAt, SCA, SCAt, outlier_cols = F.double_quant(A.to(torch.float16), threshold=state.threshold)
 
-        if state.threshold > 0.0 and coo_tensorA is not None:
+        if state.threshold > 0.0 and outlier_cols is not None:
             if state.has_fp16_weights:
-                # idx = torch.unique(coo_tensorA.colidx).long()
-                idx = torch.unique(coo_tensorA._indices()[1]).long()
+                idx = outlier_cols
                 CA[:, idx] = 0
                 # CAt[:, idx] = 0
                 subA = A[:, idx]
@@ -257,9 +256,9 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
         else:
            has_grad = False
 
-        if coo_tensorA is not None and not state.has_fp16_weights:
+        if outlier_cols is not None and not state.has_fp16_weights:
             # extract outliers
-            state.idx = torch.unique(coo_tensorA._indices()[1]).long()
+            state.idx = outlier_cols
 
             # outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
             outliers = state.CB[:, state.idx.long()].clone()
@@ -287,7 +286,7 @@ def forward(ctx, A, B, out=None, bias=None, state: MatmulLtState = None):
             output = output.to(A.dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
-        if coo_tensorA is not None and subA is not None:
+        if outlier_cols is not None and subA is not None:
             output += torch.matmul(subA, state.subB)
 
         # 5. Save state
@@ -327,7 +326,7 @@ def backward(ctx, grad_output):
         if len(grad_output.shape) == 3:
             grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
 
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
+        Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.double_quant(grad_output.to(torch.float16))
 
         if req_gradB:
             # print('back A shape', A.shape)
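
Note on the technique: the quantization helpers now return outlier_cols, a flat int64 tensor of the column indices that hold at least one value with magnitude >= threshold, instead of a sparse COO copy of A from which callers recovered the same indices via torch.unique(coo_tensorA._indices()[1]). A minimal standalone sketch of the equivalence follows; the tensor shape and threshold are illustrative assumptions, not values taken from this diff.

    # Standalone sketch (not part of the diff): shape and threshold are illustrative.
    import torch

    A = torch.randn(4, 8, dtype=torch.float16)
    threshold = 1.0

    # Old path: zero the non-outliers, build a sparse COO tensor, then take the
    # unique column indices of the remaining (outlier) entries.
    coo_tensorA = torch.masked_fill(A, A.abs() < threshold, 0.0).to_sparse_coo()
    idx_old = torch.unique(coo_tensorA._indices()[1]).long()

    # New path: directly collect the indices of columns containing at least one
    # value with magnitude >= threshold, as a flat int64 tensor.
    outlier_cols = torch.argwhere((A.abs() >= threshold).any(dim=0)).view(-1)

    # Both are sorted int64 column indices, so they should match exactly.
    assert torch.equal(idx_old, outlier_cols)

The callers in both autograd modules only ever consumed the column indices (state.idx / idx), so dropping the sparse representation removes the masked copy of A and the COO construction from every forward pass with threshold > 0.0.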