Skip to content

Commit

Permalink
Merge pull request #436 from akx/quanitze
Browse files Browse the repository at this point in the history
Fix typo "quanitze"
  • Loading branch information
TimDettmers authored Jan 2, 2024
2 parents 8c5c668 + 6b26402 commit 947db7c
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
6 changes: 3 additions & 3 deletions benchmarking/switchback/speed_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

# KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.

Expand Down Expand Up @@ -72,8 +72,8 @@ def get_time(k, fn, info_dict):
get_time('standard_gx', lambda : g.matmul(w), info)
get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
Expand Down
12 changes: 6 additions & 6 deletions bitsandbytes/nn/triton_based_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize


class _switchback_global(torch.autograd.Function):
Expand All @@ -29,7 +29,7 @@ def forward(ctx, X_3D, W, bias):

# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)

Expand All @@ -47,7 +47,7 @@ def backward(ctx, G_3D):
# so we transpose once then call .t() in the matmul
G_int8, state_G = quantize_rowwise(G)
W_int8, state_W = quantize_global_transpose(W)
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D.size()[:-1], -1
)
if ctx.needs_input_grad[1]:
Expand Down Expand Up @@ -119,7 +119,7 @@ def forward(ctx, X_3D, W, bias):

# matmult, fused dequant and add bias
# call "mixed" because we are mixing rowwise quantized and global quantized
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D_sz[:-1], -1)

Expand All @@ -143,7 +143,7 @@ def backward(ctx, G_3D):
G_int8, state_G = quantize_rowwise(G)
del G
W_int8 = W_int8.t().contiguous()
grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
*G_3D_sz[:-1], -1
)

Expand Down Expand Up @@ -215,7 +215,7 @@ def forward(self, x):
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)
else:
return int8_matmul_mixed_dequanitze(
return int8_matmul_mixed_dequantize(
X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
).view(*x.size()[:-1], -1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
else:

import triton
Expand Down Expand Up @@ -136,7 +136,7 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N,
tl.atomic_add(C, acc, mask=mask)


def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
device = a.device
divfactor = 1. / (127. * 127.)
has_bias = 0 if bias is None else 1
Expand Down

0 comments on commit 947db7c

Please sign in to comment.