From a72c463c7f4cc73b21dce227c1b5d8954dd02a4d Mon Sep 17 00:00:00 2001
From: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Date: Tue, 5 Nov 2024 15:04:41 -0500
Subject: [PATCH] cleanup

---
 bitsandbytes/functional.py | 264 +++++++++++++------------------------
 1 file changed, 92 insertions(+), 172 deletions(-)

diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 60c7a1931..05495fe5b 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -828,13 +828,13 @@ def __eq__(self, other):
 
 
 def quantize_blockwise(
-    A: Tensor,
+    A: torch.Tensor,
     code: Optional[torch.Tensor] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize=4096,
     nested=False,
-) -> Tuple[Tensor, QuantState]:
+) -> Tuple[torch.Tensor, QuantState]:
     """
     Quantize tensor A in blocks of size 4096 values.
 
@@ -878,21 +878,11 @@ def quantize_blockwise(
         assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
         code = code.to(A.device)
 
-        is_on_gpu([code, A, out, absmax])
-
-        fn_map = {
-            torch.float32: "cquantize_blockwise_fp32",
-            torch.bfloat16: "cquantize_blockwise_bf16",
-            torch.float16: "cquantize_blockwise_fp16",
-        }
-        if A.dtype not in fn_map.keys():
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
-
-        fn = fn_map[A.dtype]
+        is_on_gpu([A, out, absmax])
 
         with torch.cuda.device_of(A):
-            lib[fn](
+            args = (
                 get_ptr(code),
                 get_ptr(A),
                 get_ptr(absmax),
                 get_ptr(out),
@@ -901,6 +891,15 @@ def quantize_blockwise(
                 ct.c_int(A.numel()),
             )
 
+            if A.dtype == torch.float16:
+                lib.cquantize_blockwise_fp16(*args)
+            elif A.dtype == torch.bfloat16:
+                lib.cquantize_blockwise_bf16(*args)
+            elif A.dtype == torch.float32:
+                lib.cquantize_blockwise_fp32(*args)
+            else:
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
     else:
         # cpu
         code = code.cpu()
@@ -932,14 +931,14 @@ def quantize_blockwise(
 
 
 def dequantize_blockwise(
-    A: Tensor,
+    A: torch.Tensor,
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     code: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 4096,
     nested=False,
-) -> Tensor:
+) -> torch.Tensor:
     """
     Dequantizes blockwise quantized values.
 
@@ -986,25 +985,15 @@ def dequantize_blockwise(
     if A.device.type != "cpu":
         code = quant_state.code.to(A.device)
-        if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]:
             raise ValueError(
-                f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
+                f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]",
             )
 
-        is_on_gpu([A, absmax, out])
-        fn_map = {
-            torch.float32: "cdequantize_blockwise_fp32",
-            torch.bfloat16: "cdequantize_blockwise_bf16",
-            torch.float16: "cdequantize_blockwise_fp16",
-        }
-
-        if out.dtype not in fn_map.keys():
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
-
-        fn = fn_map[out.dtype]
+        is_on_gpu([A, absmax, out])
 
         with torch.cuda.device_of(A):
-            lib[fn](
+            args = (
                 get_ptr(quant_state.code),
                 get_ptr(A),
                 get_ptr(absmax),
@@ -1013,6 +1002,15 @@ def dequantize_blockwise(
                 ct.c_int(A.numel()),
                 _get_tensor_stream(A),
             )
+
+            if out.dtype == torch.float16:
+                lib.cdequantize_blockwise_fp16(*args)
+            elif out.dtype == torch.bfloat16:
+                lib.cdequantize_blockwise_bf16(*args)
+            elif out.dtype == torch.float32:
+                lib.cdequantize_blockwise_fp32(*args)
+            else:
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
     else:
         code = quant_state.code.cpu()
         lib.cdequantize_blockwise_cpu_fp32(
@@ -1110,7 +1108,7 @@ def get_4bit_type(typename, device=None, blocksize=64):
 
 
 def quantize_fp4(
-    A: Tensor,
+    A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize=64,
@@ -1121,7 +1119,7 @@ def quantize_fp4(
 
 
 def quantize_nf4(
-    A: Tensor,
+    A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize=64,
@@ -1132,14 +1130,14 @@ def quantize_nf4(
 
 
 def quantize_4bit(
-    A: Tensor,
+    A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize=64,
     compress_statistics=False,
     quant_type="fp4",
    quant_storage=torch.uint8,
-) -> Tuple[Tensor, QuantState]:
+) -> Tuple[torch.Tensor, QuantState]:
     """
     Quantize tensor A in blocks of 4-bit values.
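The blockwise refactors above share one idea: build the ctypes argument tuple once, then branch only on dtype to pick the kernel. A minimal round trip through that path, as a sketch that assumes a CUDA device and a CUDA-enabled bitsandbytes build (tensor names are illustrative):

    import torch
    import bitsandbytes.functional as F

    # quantize_blockwise dispatches on A.dtype (fp16/bf16/fp32);
    # dequantize_blockwise mirrors it, dispatching on out.dtype.
    A = torch.randn(4096, device="cuda", dtype=torch.float16)
    qA, state = F.quantize_blockwise(A, blocksize=256)
    A_dq = F.dequantize_blockwise(qA, quant_state=state)
    print((A - A_dq).abs().max())  # small blockwise quantization error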
@@ -1184,71 +1182,34 @@ def quantize_4bit(
     assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
 
     is_on_gpu([A, out, absmax])
-    if A.dtype == torch.float32:
-        if quant_type == "fp4":
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_fp32_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-        else:
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_fp32_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-    elif A.dtype == torch.float16:
-        if quant_type == "fp4":
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_fp16_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-        else:
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_fp16_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-    elif A.dtype == torch.bfloat16:
-        if quant_type == "fp4":
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_bf16_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
+
+    with torch.cuda.device_of(A):
+        args = (
+            get_ptr(None),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int32(blocksize),
+            ct.c_int(n),
+        )
+
+        if A.dtype == torch.bfloat16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_bf16_fp4(*args)
+            else:
+                lib.cquantize_blockwise_bf16_nf4(*args)
+        elif A.dtype == torch.float16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp16_fp4(*args)
+            else:
+                lib.cquantize_blockwise_fp16_nf4(*args)
+        elif A.dtype == torch.float32:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp32_fp4(*args)
+            else:
+                lib.cquantize_blockwise_fp32_nf4(*args)
         else:
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_bf16_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-    else:
-        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
 
     code = get_4bit_type(quant_type, device=A.device)
 
@@ -1281,33 +1242,33 @@ def quantize_4bit(
 
 
 def dequantize_fp4(
-    A: Tensor,
+    A: torch.Tensor,
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 64,
-) -> Tensor:
+) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
 
 
 def dequantize_nf4(
-    A: Tensor,
+    A: torch.Tensor,
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 64,
-) -> Tensor:
+) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
 
 
 def dequantize_4bit(
-    A: Tensor,
+    A: torch.Tensor,
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 64,
     quant_type="fp4",
-) -> Tensor:
+) -> torch.Tensor:
     """
     Dequantizes FP4 blockwise quantized values.
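The quantize_4bit hunk above collapses six near-identical call sites into one args tuple plus a dispatch on (A.dtype, quant_type). A usage sketch under the same CUDA assumptions (the weight shape is arbitrary):

    import torch
    import bitsandbytes.functional as F

    # NF4 quantization; FP4 is the quant_type="fp4" default.
    W = torch.randn(128, 64, device="cuda", dtype=torch.bfloat16)
    qW, qstate = F.quantize_4bit(W, blocksize=64, quant_type="nf4")
    # Packed 4-bit codes land in uint8 storage (the quant_storage default);
    # blocksize and quant_type travel with the QuantState.
    print(qW.dtype, qstate.blocksize, qstate.quant_type)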
@@ -1368,76 +1329,35 @@ def dequantize_4bit( is_on_gpu([A, absmax, out]) stream = _get_tensor_stream(A) - if out.dtype == torch.float32: - if quant_state.quant_type == "fp4": - with torch.cuda.device_of(A): - lib.cdequantize_blockwise_fp32_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - else: - with torch.cuda.device_of(A): - lib.cdequantize_blockwise_fp32_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - elif out.dtype == torch.float16: - if quant_state.quant_type == "fp4": - with torch.cuda.device_of(A): - lib.cdequantize_blockwise_fp16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - else: - with torch.cuda.device_of(A): - lib.cdequantize_blockwise_fp16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - elif out.dtype == torch.bfloat16: - with torch.cuda.device_of(A): + + with torch.cuda.device_of(A): + args = ( + get_ptr(None), + get_ptr(A), + get_ptr(absmax), + get_ptr(out), + ct.c_int(quant_state.blocksize), + ct.c_int(n), + stream, + ) + + if out.dtype == torch.bfloat16: if quant_state.quant_type == "fp4": - lib.cdequantize_blockwise_bf16_fp4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) + lib.cdequantize_blockwise_bf16_fp4(*args) else: - lib.cdequantize_blockwise_bf16_nf4( - get_ptr(None), - get_ptr(A), - get_ptr(absmax), - get_ptr(out), - ct.c_int(quant_state.blocksize), - ct.c_int(n), - stream, - ) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") + lib.cdequantize_blockwise_bf16_nf4(*args) + elif out.dtype == torch.float16: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp16_fp4(*args) + else: + lib.cdequantize_blockwise_fp16_nf4(*args) + elif out.dtype == torch.float32: + if quant_state.quant_type == "fp4": + lib.cdequantize_blockwise_fp32_fp4(*args) + else: + lib.cdequantize_blockwise_fp32_nf4(*args) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") if A.shape[0] == 1: # is transposed, transpose back return out.t()
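dequantize_4bit gets the same treatment, selecting the kernel from out.dtype and raising on anything other than 16/32-bit floats. Closing the loop, a round-trip sketch that again assumes a CUDA device:

    import torch
    import bitsandbytes.functional as F

    # Rebuilds the fp16 tensor from the packed 4-bit codes plus the
    # per-block absmax carried in the QuantState.
    W = torch.randn(128, 64, device="cuda", dtype=torch.float16)
    qW, qstate = F.quantize_4bit(W, quant_type="fp4")
    W_dq = F.dequantize_4bit(qW, quant_state=qstate)
    assert W_dq.shape == W.shape and W_dq.dtype == W.dtype
    print((W - W_dq).abs().mean())  # reconstruction error from 4-bit codes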