Commit

cleanup
matthewdouglas committed Nov 5, 2024
1 parent 35dbb2e commit a72c463
Showing 1 changed file with 92 additions and 172 deletions.
264 changes: 92 additions & 172 deletions bitsandbytes/functional.py
@@ -828,13 +828,13 @@ def __eq__(self, other):
 
 
 def quantize_blockwise(
-    A: Tensor,
+    A: torch.Tensor,
     code: Optional[torch.Tensor] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize=4096,
     nested=False,
-) -> Tuple[Tensor, QuantState]:
+) -> Tuple[torch.Tensor, QuantState]:
     """
     Quantize tensor A in blocks of size 4096 values.
@@ -878,21 +878,11 @@ def quantize_blockwise
         assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
 
         code = code.to(A.device)
-        is_on_gpu([code, A, out, absmax])
-
-        fn_map = {
-            torch.float32: "cquantize_blockwise_fp32",
-            torch.bfloat16: "cquantize_blockwise_bf16",
-            torch.float16: "cquantize_blockwise_fp16",
-        }
-
-        if A.dtype not in fn_map.keys():
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
-
-        fn = fn_map[A.dtype]
+        is_on_gpu([A, out, absmax])
 
         with torch.cuda.device_of(A):
-            lib[fn](
+            args = (
                 get_ptr(code),
                 get_ptr(A),
                 get_ptr(absmax),
@@ -901,6 +891,15 @@ def quantize_blockwise
                 ct.c_int(A.numel()),
             )
 
+            if A.dtype == torch.float16:
+                lib.cquantize_blockwise_fp16(*args)
+            elif A.dtype == torch.bfloat16:
+                lib.cquantize_blockwise_bf16(*args)
+            elif A.dtype == torch.float32:
+                lib.cquantize_blockwise_fp32(*args)
+            else:
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+
     else:
         # cpu
         code = code.cpu()
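
Note on the refactor above: the old path resolved kernels through a string-keyed fn_map and lib[fn] indexing, while the cleanup builds one shared args tuple and calls each binding directly. A minimal, self-contained Python sketch of the two dispatch styles (the _kernel_* stand-ins are illustrative, not the bitsandbytes bindings):

import torch

# Stand-ins for the native kernels; purely illustrative.
def _kernel_fp16(x): return "fp16"
def _kernel_bf16(x): return "bf16"
def _kernel_fp32(x): return "fp32"

def dispatch_via_map(A: torch.Tensor) -> str:
    # Old style: resolve the kernel through a dtype-keyed table.
    fn_map = {torch.float16: _kernel_fp16, torch.bfloat16: _kernel_bf16, torch.float32: _kernel_fp32}
    if A.dtype not in fn_map:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
    return fn_map[A.dtype](A)

def dispatch_explicit(A: torch.Tensor) -> str:
    # New style: explicit if/elif chain over a shared argument set.
    if A.dtype == torch.float16:
        return _kernel_fp16(A)
    elif A.dtype == torch.bfloat16:
        return _kernel_bf16(A)
    elif A.dtype == torch.float32:
        return _kernel_fp32(A)
    else:
        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")

x = torch.ones(2, dtype=torch.bfloat16)
assert dispatch_via_map(x) == dispatch_explicit(x) == "bf16"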
@@ -932,14 +931,14 @@ def quantize_blockwise
 
 
 def dequantize_blockwise(
-    A: Tensor,
+    A: torch.Tensor,
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     code: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 4096,
     nested=False,
-) -> Tensor:
+) -> torch.Tensor:
     """
     Dequantizes blockwise quantized values.
@@ -986,25 +985,15 @@ def dequantize_blockwise
 
     if A.device.type != "cpu":
         code = quant_state.code.to(A.device)
-        if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
+        if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]:
             raise ValueError(
-                f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
+                f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]",
             )
-        is_on_gpu([A, absmax, out])
-
-        fn_map = {
-            torch.float32: "cdequantize_blockwise_fp32",
-            torch.bfloat16: "cdequantize_blockwise_bf16",
-            torch.float16: "cdequantize_blockwise_fp16",
-        }
-
-        if out.dtype not in fn_map.keys():
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
-
-        fn = fn_map[out.dtype]
+        is_on_gpu([A, absmax, out])
 
         with torch.cuda.device_of(A):
-            lib[fn](
+            args = (
                 get_ptr(quant_state.code),
                 get_ptr(A),
                 get_ptr(absmax),
@@ -1013,6 +1002,15 @@ def dequantize_blockwise
                 ct.c_int(A.numel()),
                 _get_tensor_stream(A),
             )
+
+            if out.dtype == torch.float16:
+                lib.cdequantize_blockwise_fp16(*args)
+            elif out.dtype == torch.bfloat16:
+                lib.cdequantize_blockwise_bf16(*args)
+            elif out.dtype == torch.float32:
+                lib.cdequantize_blockwise_fp32(*args)
+            else:
+                raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
     else:
         code = quant_state.code.cpu()
         lib.cdequantize_blockwise_cpu_fp32(
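
For context, a hedged usage sketch of the two functions changed above, a blockwise quantize/dequantize round trip. It assumes a CUDA-enabled bitsandbytes install; the blocksize and shape are illustrative:

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, device="cuda", dtype=torch.float16)

# Quantize to uint8 codes with one absmax scale per 256-element block.
qA, quant_state = F.quantize_blockwise(A, blocksize=256)

# Dequantize back; the output dtype is recovered from the QuantState.
A_restored = F.dequantize_blockwise(qA, quant_state)

print((A - A_restored).abs().max())  # small but nonzero quantization error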
@@ -1110,7 +1108,7 @@ def get_4bit_type(typename, device=None, blocksize=64):
 
 
 def quantize_fp4(
-    A: Tensor,
+    A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize=64,
@@ -1121,7 +1119,7 @@ def quantize_fp4(
 
 
 def quantize_nf4(
-    A: Tensor,
+    A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize=64,
@@ -1132,14 +1130,14 @@ def quantize_nf4(
 
 
 def quantize_4bit(
-    A: Tensor,
+    A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize=64,
     compress_statistics=False,
     quant_type="fp4",
     quant_storage=torch.uint8,
-) -> Tuple[Tensor, QuantState]:
+) -> Tuple[torch.Tensor, QuantState]:
     """
     Quantize tensor A in blocks of 4-bit values.
@@ -1184,71 +1182,34 @@ def quantize_4bit(
     assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
 
     is_on_gpu([A, out, absmax])
-    if A.dtype == torch.float32:
-        if quant_type == "fp4":
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_fp32_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-        else:
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_fp32_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-    elif A.dtype == torch.float16:
-        if quant_type == "fp4":
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_fp16_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-        else:
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_fp16_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-    elif A.dtype == torch.bfloat16:
-        if quant_type == "fp4":
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_bf16_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
+
+    with torch.cuda.device_of(A):
+        args = (
+            get_ptr(None),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int32(blocksize),
+            ct.c_int(n),
+        )
+
+        if A.dtype == torch.bfloat16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_bf16_fp4(*args)
+            else:
+                lib.cquantize_blockwise_bf16_nf4(*args)
+        elif A.dtype == torch.float16:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp16_fp4(*args)
+            else:
+                lib.cquantize_blockwise_fp16_nf4(*args)
+        elif A.dtype == torch.float32:
+            if quant_type == "fp4":
+                lib.cquantize_blockwise_fp32_fp4(*args)
+            else:
+                lib.cquantize_blockwise_fp32_nf4(*args)
         else:
-            with torch.cuda.device_of(A):
-                lib.cquantize_blockwise_bf16_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int32(blocksize),
-                    ct.c_int(n),
-                )
-    else:
-        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
 
     code = get_4bit_type(quant_type, device=A.device)
 
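A hedged usage sketch for quantize_4bit as restructured above (the shape and the NF4 choice are illustrative; a CUDA device is assumed). Two 4-bit values are packed per byte, so the packed tensor holds about half as many elements as the input:

import torch
import bitsandbytes.functional as F

W = torch.randn(1024, 1024, device="cuda", dtype=torch.bfloat16)

# NF4 quantization with one absmax per 64-element block (uint8 storage by default).
W4, state = F.quantize_4bit(W, blocksize=64, quant_type="nf4")

print(W4.dtype, W4.numel())  # packed storage: roughly W.numel() // 2 bytes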
@@ -1281,33 +1242,33 @@ def quantize_4bit(
 
 
 def dequantize_fp4(
-    A: Tensor,
+    A: torch.Tensor,
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 64,
-) -> Tensor:
+) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
 
 
 def dequantize_nf4(
-    A: Tensor,
+    A: torch.Tensor,
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 64,
-) -> Tensor:
+) -> torch.Tensor:
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
 
 
 def dequantize_4bit(
-    A: Tensor,
+    A: torch.Tensor,
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
     blocksize: int = 64,
     quant_type="fp4",
-) -> Tensor:
+) -> torch.Tensor:
     """
     Dequantizes FP4 blockwise quantized values.
@@ -1368,76 +1329,35 @@ def dequantize_4bit(
 
     is_on_gpu([A, absmax, out])
     stream = _get_tensor_stream(A)
-    if out.dtype == torch.float32:
-        if quant_state.quant_type == "fp4":
-            with torch.cuda.device_of(A):
-                lib.cdequantize_blockwise_fp32_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int(quant_state.blocksize),
-                    ct.c_int(n),
-                    stream,
-                )
-        else:
-            with torch.cuda.device_of(A):
-                lib.cdequantize_blockwise_fp32_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int(quant_state.blocksize),
-                    ct.c_int(n),
-                    stream,
-                )
-    elif out.dtype == torch.float16:
-        if quant_state.quant_type == "fp4":
-            with torch.cuda.device_of(A):
-                lib.cdequantize_blockwise_fp16_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int(quant_state.blocksize),
-                    ct.c_int(n),
-                    stream,
-                )
-        else:
-            with torch.cuda.device_of(A):
-                lib.cdequantize_blockwise_fp16_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int(quant_state.blocksize),
-                    ct.c_int(n),
-                    stream,
-                )
-    elif out.dtype == torch.bfloat16:
-        with torch.cuda.device_of(A):
+
+    with torch.cuda.device_of(A):
+        args = (
+            get_ptr(None),
+            get_ptr(A),
+            get_ptr(absmax),
+            get_ptr(out),
+            ct.c_int(quant_state.blocksize),
+            ct.c_int(n),
+            stream,
+        )
+
+        if out.dtype == torch.bfloat16:
             if quant_state.quant_type == "fp4":
-                lib.cdequantize_blockwise_bf16_fp4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int(quant_state.blocksize),
-                    ct.c_int(n),
-                    stream,
-                )
+                lib.cdequantize_blockwise_bf16_fp4(*args)
             else:
-                lib.cdequantize_blockwise_bf16_nf4(
-                    get_ptr(None),
-                    get_ptr(A),
-                    get_ptr(absmax),
-                    get_ptr(out),
-                    ct.c_int(quant_state.blocksize),
-                    ct.c_int(n),
-                    stream,
-                )
-    else:
-        raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
+                lib.cdequantize_blockwise_bf16_nf4(*args)
+        elif out.dtype == torch.float16:
+            if quant_state.quant_type == "fp4":
+                lib.cdequantize_blockwise_fp16_fp4(*args)
+            else:
+                lib.cdequantize_blockwise_fp16_nf4(*args)
+        elif out.dtype == torch.float32:
+            if quant_state.quant_type == "fp4":
+                lib.cdequantize_blockwise_fp32_fp4(*args)
+            else:
+                lib.cdequantize_blockwise_fp32_nf4(*args)
+        else:
+            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
 
     if A.shape[0] == 1: # is transposed, transpose back
         return out.t()
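
And a matching round trip through dequantize_4bit, again a hedged sketch under the same assumptions; the QuantState carries the blocksize, quant_type, and original shape and dtype, so only the packed codes need to be passed back:

import torch
import bitsandbytes.functional as F

W = torch.randn(256, 256, device="cuda", dtype=torch.float16)
W4, state = F.quantize_4bit(W, quant_type="nf4")

# Shape and dtype are restored from the QuantState.
W_restored = F.dequantize_4bit(W4, state)
assert W_restored.shape == W.shape and W_restored.dtype == W.dtype
print((W - W_restored).abs().mean())  # small NF4 quantization error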
