diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index a461d1749..3e90364bb 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -890,7 +890,16 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize def quantize_nf4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_storage=torch.uint8): return quantize_4bit(A, absmax, out, blocksize, compress_statistics, 'nf4', quant_storage) -def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False, quant_type='fp4', quant_storage=torch.uint8) -> (Tensor, QuantState): + +def quantize_4bit( + A: Tensor, + absmax: Tensor = None, + out: Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type='fp4', + quant_storage=torch.uint8, +) -> Tuple[Tensor, QuantState]: """ Quantize tensor A in blocks of 4-bit values.