diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 133f9e066..41555a450 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -308,7 +308,14 @@ def forward( # 1. Quantize A if len(A.shape) == 3: A = A.reshape(-1, A.shape[-1]) - CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + + if ctx.needs_input_grad[1]: + # Slower path + CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) + else: + # Fast path + CA, SCA, coo_tensorA = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold) + CAt = SCAt = None has_grad = False @@ -322,20 +329,24 @@ def forward( state.reset_grads() # 2. Quantize B - ( - state.CB, - state.CBt, - state.SCB, - state.SCBt, - _, - ) = F.double_quant(B.to(torch.float16)) + state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16)) + + # ( + # state.CB, + # state.CBt, + # state.SCB, + # state.SCBt, + # _, + # ) = F.double_quant(B.to(torch.float16)) if state.threshold > 0.0 and coo_tensorA is not None: state.idx = torch.unique(coo_tensorA._indices()[1]).long() # Zero out the outliers in the int8 inputs CA[:, state.idx] = 0 - # CAt[:, state.idx] = 0 + + if CAt is not None: + CAt[:, state.idx] = 0 # Extract the input outliers in original precision subA = A[:, state.idx] @@ -372,7 +383,7 @@ def forward( ctx.tensors = (CAt, subA, A) ctx.tensor_states = (SCAt, state.idx) else: - ctx.tensors = [None, None, None] # A] + ctx.tensors = [None, None, None] ctx.tensor_states = (None, None) ctx.save_for_backward(None, None) @@ -403,17 +414,16 @@ def backward(ctx, grad_output): if len(grad_output.shape) == 3: grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() - Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) - # if req_gradB: - - # grad_B = torch.matmul(grad_output.t(), A) - # if state.threshold > 0.0 and subA is not None: - # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradB: - gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) - grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + grad_B = torch.matmul(grad_output.t(), A) if state.threshold > 0.0 and subA is not None: grad_B[:, idx] += torch.matmul(grad_output.t(), subA) + # Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16)) + # if req_gradB: + # gradB32, SgradB32 = F.igemmlt(Cgrad.t().contiguous(), CAt.t()) + # grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) + # if state.threshold > 0.0 and subA is not None: + # grad_B[:, idx] += torch.matmul(grad_output.t(), subA) if req_gradA: # grad_output @ B.T diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6c8ffe3d1..f4ff3eafa 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -2409,6 +2409,7 @@ def get_colrow_absmax( if row_stats is None: # shape [rows]; unsqueeze(-1) gives [rows,1] + # We have a CUDA kernel for row max, but not yet for cols. row_stats = get_row_absmax(A, threshold) if col_stats is None: @@ -2521,29 +2522,42 @@ def extract_outliers_new(A: torch.Tensor, threshold: float): def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): + # TODO: Optimize/write CUDA kernel for this? + # Note: for inference, use the new int8_vectorwise_quant. + + # Use CUDA kernel for rowwise and COO tensor + quant_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=threshold) + + # PyTorch impl for colwise + _, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold) + if threshold > 0.0 and outlier_mask is not None: + A = A.masked_fill(outlier_mask, 0.0) + quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8) + + if out_row is not None: + quant_row = out_row.copy_(quant_row) + if out_col is not None: + quant_col = out_col.copy_(quant_col) + + return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor + + +def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): assert A.dtype == torch.half + is_on_gpu([A]) rows = prod(A.shape[:-1]) cols = A.shape[-1] row_stats = torch.empty((rows,), device=A.device, dtype=torch.float32) - out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8) if threshold > 0.0: - # Extract outliers to COO tensor: - # 1. Zero out all of the non-outliers, convert to COO. - # 2. Zero out the outliers in the dense tensor. # TODO we could improve perf of this - # outlier_mask = A.abs() >= threshold - # coo_tensor = A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo() - # A = A.masked_fill(outlier_mask, 0.0) coo_tensor = extract_outliers_new(A, threshold) else: coo_tensor = None - is_on_gpu([A, row_stats]) - with torch.cuda.device_of(A): lib.cint8_vector_quant( get_ptr(A), @@ -2554,9 +2568,7 @@ def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, ct.c_int32(cols), ) - # TODO: col_stats - - return out_row, None, row_stats, None, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor + return out_row, row_stats, coo_tensor # coo_tensor #col_stats.flatten().float(), coo_tensor def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None): diff --git a/csrc/kernels.cu b/csrc/kernels.cu index d0ea0d270..45ee0a3ed 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -3612,7 +3612,7 @@ template __global__ void kgemm_4bit_inferenc #pragma unroll for(int k = 0; k < num_values_8bit/4; k++) { - #if __CUDA_ARCH__ >= 800 + #if BNB_BF16_AVAILABLE local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; #else @@ -3649,7 +3649,7 @@ template __global__ void kgemm_4bit_inferenc #pragma unroll for(int k = 0; k < num_values_4bit/4; k++) { - #if __CUDA_ARCH__ >= 800 + #if BNB_BF16_AVAILABLE local_C += (float)(local_A[k]*local_B[k]); #else // bf16 multipliation not supported diff --git a/tests/test_autograd.py b/tests/test_autograd.py index 89dce644b..3717a9572 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -253,13 +253,16 @@ def test_matmullt(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, dec if not has_fp16_weights: if not transpose[0] and not transpose[1]: B2 = B2.t().contiguous() - ( - state.CB, - CBt, - state.SCB, - SCBt, - coo_tensorB, - ) = bnb.functional.double_quant(B2.to(torch.float16)) + + state.CB, state.SCB, _ = bnb.functional.int8_vectorwise_quant(B2.to(torch.float16)) + + # ( + # state.CB, + # CBt, + # state.SCB, + # SCBt, + # coo_tensorB, + # ) = bnb.functional.double_quant(B2.to(torch.float16)) B2 = state.CB if not transpose[0] and transpose[1]: diff --git a/tests/test_functional.py b/tests/test_functional.py index 9b7004946..34dbf56fd 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1132,17 +1132,37 @@ def test_overflow(): c2 = torch.matmul(a.float(), b.float().t()) +@pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) +@pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) +def test_coo_double_quant(dim1, dim2): + threshold = 2.00 + for i in range(k): + A = torch.randn(dim1, dim2, device="cuda").half() + + idx = torch.abs(A) >= threshold + CA, _, statsA, _, coo_tensor = F.double_quant(A, threshold=threshold) + + if coo_tensor is not None: + A1 = A * idx + A2 = coo_tensor.to_dense() + torch.testing.assert_close(A1, A2) + + A1 = A * (idx == 0) + A2 = (CA.float() * statsA.unsqueeze(1) / 127).half() + torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) + + # @pytest.mark.parametrize("dim1", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim1")) # @pytest.mark.parametrize("dim2", get_test_dims(1, 4 * 1024, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim1", [512, 2048], ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [1024, 4096], ids=id_formatter("dim2")) -def test_coo_double_quant(dim1, dim2): +def test_coo_int8_vectorwise_quant(dim1, dim2): threshold = 3.00 for i in range(k): A = torch.randn(dim1, dim2, device="cuda").half() idx = torch.abs(A) >= threshold - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold) if coo_tensor is not None: A1 = A * idx @@ -1239,13 +1259,13 @@ def test_integrated_sparse_decomp(dim1, dim2): w1 = torch.randn(dim1, dim2).cuda().half() out1 = torch.matmul(A, w1.t()) - Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A) + Cw1, statsw1, coo_tensor = F.int8_vectorwise_quant(w1) + CA, statsA, coo_tensor = F.int8_vectorwise_quant(A) out1_32, Sout1_32 = F.igemmlt(CA, Cw1) out2 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) - CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=threshold) + CA, statsA, coo_tensor = F.int8_vectorwise_quant(A, threshold=threshold) out1_32, Sout1_32 = F.igemmlt(CA, Cw1) out3 = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1) diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py index 48c3a9ea8..3f80beacf 100644 --- a/tests/test_linear8bitlt.py +++ b/tests/test_linear8bitlt.py @@ -72,10 +72,11 @@ def test_linear_no_igemmlt(): assert linear_custom.state.CB is not None assert not linear_custom.state.has_fp16_weights - assert torch.allclose(fx_ref, fx_ours, atol=0.02) - assert torch.allclose(x_ref.grad, x_ours.grad, atol=0.01) - # assert linear_custom.state.CxB is None + idx = torch.isclose(fx_ref, fx_ours, atol=0.02, rtol=1e-5) + assert (idx == 0).sum().item() < fx_ref.numel() * 2.5e-4 + torch.testing.assert_close(fx_ref, fx_ours, atol=0.03, rtol=1e-5) + torch.testing.assert_close(x_ref.grad, x_ours.grad, atol=0.01, rtol=1e-5) @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))