Commit

small perf optimization for single-GPU systems
matthewdouglas committed Nov 5, 2024
1 parent ed922b8 commit a93b91f
Showing 1 changed file with 64 additions and 57 deletions.
121 changes: 64 additions & 57 deletions bitsandbytes/functional.py
@@ -191,6 +191,16 @@ def get_instance(cls):
 
 FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
 
+if torch.cuda.device_count() > 1:
+
+    def _cuda_device_of(a: torch.Tensor):
+        return torch.cuda.device_of(a)
+else:
+    import contextlib
+
+    def _cuda_device_of(a: torch.Tensor):
+        return contextlib.nullcontext()
+
 
 def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
     num_bytes = dtype2bytes[dtype] * prod(shape)
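The branch is evaluated once at import time: with more than one GPU visible, _cuda_device_of defers to torch.cuda.device_of, while with one (or zero) every call returns a no-op contextlib.nullcontext(), skipping the device query and switch entirely. A rough micro-benchmark sketch of the per-call cost being avoided (assumes a CUDA build of PyTorch; names and numbers are illustrative, not part of this commit):

import contextlib
import timeit

import torch

t = torch.empty(1, device="cuda")

def with_device_of():
    # Queries the tensor's device and enters/exits a CUDA device context.
    with torch.cuda.device_of(t):
        pass

def with_nullcontext():
    # The single-GPU fast path: nothing to save or restore.
    with contextlib.nullcontext():
        pass

print("device_of:  ", timeit.timeit(with_device_of, number=100_000))
print("nullcontext:", timeit.timeit(with_nullcontext, number=100_000))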
@@ -881,7 +891,7 @@ def quantize_blockwise(
 
     is_on_gpu([A, out, absmax])
 
-    with torch.cuda.device_of(A):
+    with _cuda_device_of(A):
         args = (
             get_ptr(code),
             get_ptr(A),
@@ -992,7 +1002,7 @@ def dequantize_blockwise(
 
     is_on_gpu([A, absmax, out])
 
-    with torch.cuda.device_of(A):
+    with _cuda_device_of(A):
         args = (
             get_ptr(quant_state.code),
             get_ptr(A),
@@ -1183,7 +1193,7 @@ def quantize_4bit(
 
     is_on_gpu([A, out, absmax])
 
-    with torch.cuda.device_of(A):
+    with _cuda_device_of(A):
         args = (
             get_ptr(None),
             get_ptr(A),
@@ -1330,7 +1340,7 @@ def dequantize_4bit(
     is_on_gpu([A, absmax, out])
     stream = _get_tensor_stream(A)
 
-    with torch.cuda.device_of(A):
+    with _cuda_device_of(A):
         args = (
             get_ptr(None),
             get_ptr(A),
@@ -1547,28 +1557,28 @@ def optimizer_update_32bit(
     )
 
     is_on_gpu([g, p, state1, state2, unorm_vec])
-    prev_device = pre_call(g.device)
-    optim_func(
-        get_ptr(g),
-        get_ptr(p),
-        get_ptr(state1),
-        get_ptr(state2),
-        get_ptr(unorm_vec),
-        ct.c_float(max_unorm),
-        ct.c_float(param_norm),
-        ct.c_float(beta1),
-        ct.c_float(beta2),
-        ct.c_float(beta3),
-        ct.c_float(alpha),
-        ct.c_float(eps),
-        ct.c_float(weight_decay),
-        ct.c_int32(step),
-        ct.c_float(lr),
-        ct.c_float(gnorm_scale),
-        ct.c_bool(skip_zeros),
-        ct.c_int32(g.numel()),
-    )
-    post_call(prev_device)
+
+    with _cuda_device_of(g):
+        optim_func(
+            get_ptr(g),
+            get_ptr(p),
+            get_ptr(state1),
+            get_ptr(state2),
+            get_ptr(unorm_vec),
+            ct.c_float(max_unorm),
+            ct.c_float(param_norm),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_float(weight_decay),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
 
 
 @deprecated(
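In this hunk the old save/restore pair (pre_call caches the current device and switches to g.device; post_call switches back) becomes a single with block, which on single-GPU hosts collapses to the nullcontext no-op. A sketch of the equivalence; pre_call/post_call behavior is approximated from how they are used here, not from their source:

import contextlib

import torch

if torch.cuda.device_count() > 1:
    def _cuda_device_of(a):
        return torch.cuda.device_of(a)
else:
    def _cuda_device_of(a):
        return contextlib.nullcontext()

def old_style(g, kernel):
    # ~ prev_device = pre_call(g.device); kernel(); post_call(prev_device)
    prev_device = torch.cuda.current_device()
    torch.cuda.set_device(g.device)
    kernel()
    torch.cuda.set_device(prev_device)

def new_style(g, kernel):
    # Same active-device guarantee, but zero cudaSetDevice traffic
    # when only one GPU exists.
    with _cuda_device_of(g):
        kernel()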
@@ -1731,8 +1741,7 @@ def optimizer_update_8bit_blockwise(
     skip_zeros=False,
 ) -> None:
     optim_func = None
-    prev_device = pre_call(g.device)
-    is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
+
     if g.dtype == torch.float32 and state1.dtype == torch.uint8:
         optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
     elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
@@ -1747,33 +1756,31 @@
         raise ValueError(
             f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
         )
-    post_call(prev_device)
 
-    prev_device = pre_call(g.device)
-    optim_func(
-        get_ptr(p),
-        get_ptr(g),
-        get_ptr(state1),
-        get_ptr(state2),
-        ct.c_float(beta1),
-        ct.c_float(beta2),
-        ct.c_float(beta3),
-        ct.c_float(alpha),
-        ct.c_float(eps),
-        ct.c_int32(step),
-        ct.c_float(lr),
-        get_ptr(qmap1),
-        get_ptr(qmap2),
-        get_ptr(absmax1),
-        get_ptr(absmax2),
-        ct.c_float(weight_decay),
-        ct.c_float(gnorm_scale),
-        ct.c_bool(skip_zeros),
-        ct.c_int32(g.numel()),
-    )
-    post_call(prev_device)
+    is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
+
+    with _cuda_device_of(g):
+        optim_func(
+            get_ptr(p),
+            get_ptr(g),
+            get_ptr(state1),
+            get_ptr(state2),
+            ct.c_float(beta1),
+            ct.c_float(beta2),
+            ct.c_float(beta3),
+            ct.c_float(alpha),
+            ct.c_float(eps),
+            ct.c_int32(step),
+            ct.c_float(lr),
+            get_ptr(qmap1),
+            get_ptr(qmap2),
+            get_ptr(absmax1),
+            get_ptr(absmax2),
+            ct.c_float(weight_decay),
+            ct.c_float(gnorm_scale),
+            ct.c_bool(skip_zeros),
+            ct.c_int32(g.numel()),
+        )
 
 
 @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
@@ -1966,7 +1973,7 @@ def gemv_4bit(
     ldc = ct.c_int32(ldc)
     stream = _get_tensor_stream(A)
 
-    with torch.cuda.device_of(A):
+    with _cuda_device_of(A):
         if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
             if A.dtype == torch.float16:
                 lib.cgemm_4bit_inference_naive_fp16(
Expand Down Expand Up @@ -2285,7 +2292,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten

is_on_gpu([A, B, out])

with torch.cuda.device_of(A):
with _cuda_device_of(A):
ctx = CUBLAS_Context.get_instance().get_context(A.device)
ptrA = get_ptr(A)
ptrB = get_ptr(B)
@@ -2343,7 +2350,7 @@ def int8_mm_dequant(
 
     is_on_gpu([A, row_stats, col_stats, out, bias])
 
-    with torch.cuda.device_of(A):
+    with _cuda_device_of(A):
         lib.cdequant_mm_int32_fp16(
             ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
         )
@@ -2407,7 +2414,7 @@ def get_row_absmax(A: torch.Tensor, threshold=0.0):
 
     is_on_gpu([A])
 
-    with torch.cuda.device_of(A):
+    with _cuda_device_of(A):
         lib.cget_row_stats(
             get_ptr(A),
             get_ptr(row_stats),
@@ -2550,7 +2557,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
     if outliers.any():
         outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)
 
-    with torch.cuda.device_of(A):
+    with _cuda_device_of(A):
         lib.cint8_vector_quant(
             get_ptr(A),
             get_ptr(out_row),
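None of the public signatures change, so callers are unaffected; only the per-call device bookkeeping differs. A quick smoke-test sketch (assumes a CUDA build of bitsandbytes; the tolerance is illustrative):

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, device="cuda")
q, state = F.quantize_blockwise(A)       # kernel launches under _cuda_device_of(A)
out = F.dequantize_blockwise(q, state)
assert torch.allclose(A, out, atol=0.1)  # 8-bit blockwise round-trip error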
