diff --git a/bitsandbytes/triton/quantize_columnwise_and_transpose.py b/bitsandbytes/triton/quantize_columnwise_and_transpose.py index b8eeffd0c..6ac02e719 100644 --- a/bitsandbytes/triton/quantize_columnwise_and_transpose.py +++ b/bitsandbytes/triton/quantize_columnwise_and_transpose.py @@ -54,7 +54,7 @@ def _quantize_columnwise_and_transpose( x = tl.load(x_ptr + offsets, mask=p2_arange_mask) abs_x = tl.abs(x) max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127.0 * (x / max_val)) + output = tl.extra.cuda.libdevice.rint(127.0 * (x / max_val)) new_start = pid * M new_offsets = new_start + p2_arange diff --git a/bitsandbytes/triton/quantize_global.py b/bitsandbytes/triton/quantize_global.py index f35bdd304..bdd9e727a 100644 --- a/bitsandbytes/triton/quantize_global.py +++ b/bitsandbytes/triton/quantize_global.py @@ -35,7 +35,7 @@ def _quantize_global( mask = offsets < n_elements x = tl.load(x_ptr + offsets, mask=mask) absmax_inv = tl.load(absmax_inv_ptr) - output = tl.libdevice.llrint(127.0 * (x * absmax_inv)) + output = tl.extra.cuda.libdevice.rint(127.0 * (x * absmax_inv)) tl.store(output_ptr + offsets, output, mask=mask) def quantize_global(x: torch.Tensor): @@ -95,7 +95,7 @@ def _quantize_global_transpose( B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn) mask = (rm < M)[:, None] & (rn < N)[None, :] - output = tl.libdevice.llrint(127.0 * (a * absmax_inv)) + output = tl.extra.cuda.libdevice.rint(127.0 * (a * absmax_inv)) tl.store(B, output, mask=mask) diff --git a/bitsandbytes/triton/quantize_rowwise.py b/bitsandbytes/triton/quantize_rowwise.py index f92ace02c..7d0c500a1 100644 --- a/bitsandbytes/triton/quantize_rowwise.py +++ b/bitsandbytes/triton/quantize_rowwise.py @@ -50,7 +50,7 @@ def _quantize_rowwise( abs_x = tl.abs(x) max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0) - output = tl.libdevice.llrint(127.0 * (x / max_val)) + output = tl.extra.cuda.libdevice.rint(127.0 * (x / max_val)) tl.store(output_ptr + offsets, output, mask=row_mask) tl.store(output_maxs + pid, max_val)