diff --git a/benchmarking/switchback/speed_benchmark.py b/benchmarking/switchback/speed_benchmark.py
index 9ad991194..b0983d0b8 100644
--- a/benchmarking/switchback/speed_benchmark.py
+++ b/benchmarking/switchback/speed_benchmark.py
@@ -8,7 +8,7 @@
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
 
 # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
@@ -72,8 +72,8 @@ def get_time(k, fn, info_dict):
             get_time('standard_gx', lambda : g.matmul(w), info)
             get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
             get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
-            get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
-            get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
+            get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
+            get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
             get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
             get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
             get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
diff --git a/bitsandbytes/nn/triton_based_modules.py b/bitsandbytes/nn/triton_based_modules.py
index 6fbf583b9..de07ac647 100644
--- a/bitsandbytes/nn/triton_based_modules.py
+++ b/bitsandbytes/nn/triton_based_modules.py
@@ -10,7 +10,7 @@
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize
 
 
 class _switchback_global(torch.autograd.Function):
@@ -29,7 +29,7 @@ def forward(ctx, X_3D, W, bias):
 
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequanitze(
+        return int8_matmul_mixed_dequantize(
             X_int8, W_int8.t(), state_X, state_W, bias
         ).view(*X_3D.size()[:-1], -1)
 
@@ -47,7 +47,7 @@ def backward(ctx, G_3D):
             # so we transpose once then call .t() in the matmul
             G_int8, state_G = quantize_rowwise(G)
             W_int8, state_W = quantize_global_transpose(W)
-            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                 *G_3D.size()[:-1], -1
             )
         if ctx.needs_input_grad[1]:
@@ -119,7 +119,7 @@ def forward(ctx, X_3D, W, bias):
 
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequanitze(
+        return int8_matmul_mixed_dequantize(
             X_int8, W_int8.t(), state_X, state_W, bias
         ).view(*X_3D_sz[:-1], -1)
 
@@ -143,7 +143,7 @@ def backward(ctx, G_3D):
             G_int8, state_G = quantize_rowwise(G)
             del G
             W_int8 = W_int8.t().contiguous()
-            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                 *G_3D_sz[:-1], -1
             )
 
@@ -215,7 +215,7 @@ def forward(self, x):
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
         else:
-            return int8_matmul_mixed_dequanitze(
+            return int8_matmul_mixed_dequantize(
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
 
diff --git a/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py
similarity index 98%
rename from bitsandbytes/triton/int8_matmul_mixed_dequanitze.py
rename to bitsandbytes/triton/int8_matmul_mixed_dequantize.py
index 60a56e698..b0961f558 100644
--- a/bitsandbytes/triton/int8_matmul_mixed_dequanitze.py
+++ b/bitsandbytes/triton/int8_matmul_mixed_dequantize.py
@@ -2,7 +2,7 @@
 from bitsandbytes.triton.triton_utils import is_triton_available
 
 if not is_triton_available():
-    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
 else:
 
     import triton
@@ -136,7 +136,7 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N,
 
             tl.atomic_add(C, acc, mask=mask)
 
-    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
         device = a.device
         divfactor = 1. / (127. * 127.)
         has_bias = 0 if bias is None else 1