# The choice between the two implementations is made once, when this module is
# imported, from the number of CUDA devices torch reports at that moment.
if torch.cuda.device_count() <= 1:
    import contextlib

    def _cuda_device_of(a: torch.Tensor):
        """Return a no-op context manager; the tensor ``a`` is not inspected.

        This variant is installed when torch reports at most one CUDA device
        at import time.
        NOTE(review): presumably this skips the device-switching cost when
        there is no other device to switch to -- confirm with the author.
        """
        return contextlib.nullcontext()
else:

    def _cuda_device_of(a: torch.Tensor):
        """Return ``torch.cuda.device_of(a)`` for use in a ``with`` statement.

        This variant is installed when torch reports more than one CUDA
        device at import time.
        """
        return torch.cuda.device_of(a)
ct.c_float(max_unorm), + ct.c_float(param_norm), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_float(weight_decay), + ct.c_int32(step), + ct.c_float(lr), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) @deprecated( @@ -1731,8 +1741,7 @@ def optimizer_update_8bit_blockwise( skip_zeros=False, ) -> None: optim_func = None - prev_device = pre_call(g.device) - is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2]) + if g.dtype == torch.float32 and state1.dtype == torch.uint8: optim_func = str2optimizer8bit_blockwise[optimizer_name][0] elif g.dtype == torch.float16 and state1.dtype == torch.uint8: @@ -1747,33 +1756,31 @@ def optimizer_update_8bit_blockwise( raise ValueError( f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", ) - post_call(prev_device) is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) - prev_device = pre_call(g.device) - optim_func( - get_ptr(p), - get_ptr(g), - get_ptr(state1), - get_ptr(state2), - ct.c_float(beta1), - ct.c_float(beta2), - ct.c_float(beta3), - ct.c_float(alpha), - ct.c_float(eps), - ct.c_int32(step), - ct.c_float(lr), - get_ptr(qmap1), - get_ptr(qmap2), - get_ptr(absmax1), - get_ptr(absmax2), - ct.c_float(weight_decay), - ct.c_float(gnorm_scale), - ct.c_bool(skip_zeros), - ct.c_int32(g.numel()), - ) - post_call(prev_device) + with _cuda_device_of(g): + optim_func( + get_ptr(p), + get_ptr(g), + get_ptr(state1), + get_ptr(state2), + ct.c_float(beta1), + ct.c_float(beta2), + ct.c_float(beta3), + ct.c_float(alpha), + ct.c_float(eps), + ct.c_int32(step), + ct.c_float(lr), + get_ptr(qmap1), + get_ptr(qmap2), + get_ptr(absmax1), + get_ptr(absmax2), + ct.c_float(weight_decay), + ct.c_float(gnorm_scale), + ct.c_bool(skip_zeros), + ct.c_int32(g.numel()), + ) @deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning) @@ 
-1966,7 +1973,7 @@ def gemv_4bit( ldc = ct.c_int32(ldc) stream = _get_tensor_stream(A) - with torch.cuda.device_of(A): + with _cuda_device_of(A): if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: lib.cgemm_4bit_inference_naive_fp16( @@ -2285,7 +2292,7 @@ def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Ten is_on_gpu([A, B, out]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): ctx = CUBLAS_Context.get_instance().get_context(A.device) ptrA = get_ptr(A) ptrB = get_ptr(B) @@ -2343,7 +2350,7 @@ def int8_mm_dequant( is_on_gpu([A, row_stats, col_stats, out, bias]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): lib.cdequant_mm_int32_fp16( ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A) ) @@ -2407,7 +2414,7 @@ def get_row_absmax(A: torch.Tensor, threshold=0.0): is_on_gpu([A]) - with torch.cuda.device_of(A): + with _cuda_device_of(A): lib.cget_row_stats( get_ptr(A), get_ptr(row_stats), @@ -2550,7 +2557,7 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0): if outliers.any(): outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1) - with torch.cuda.device_of(A): + with _cuda_device_of(A): lib.cint8_vector_quant( get_ptr(A), get_ptr(out_row),