[Quantization] enable multi-backend bitsandbytes #10574
Base: main

The first set of hunks is from the bitsandbytes quantizer module (`diffusers.quantizers.bitsandbytes.bnb_quantizer`); they remove the hard CUDA requirement and route backend validation through the new multi-backend bitsandbytes helpers.
@@ -61,8 +61,6 @@ def __init__(self, quantization_config, **kwargs):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

     def validate_environment(self, *args, **kwargs):
-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
             raise ImportError(
                 "Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
@@ -72,6 +70,12 @@ def validate_environment(self, *args, **kwargs):
                 "Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
             )

+        from ...utils import is_bitsandbytes_multi_backend_available
+        from .utils import validate_bnb_backend_availability
+
+        bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
+        validate_bnb_backend_availability(raise_exception=True)
+
         if kwargs.get("from_flax", False):
             raise ValueError(
                 "Converting into 4-bit weights from flax weights is currently not supported, please make"
@@ -87,7 +91,9 @@ def validate_environment(self, *args, **kwargs):
             device_map_without_no_convert = {
                 key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
             }
-            if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
+            if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
+                pass
+            elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
                 raise ValueError(
                     "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                     "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
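The practical effect of the hunks above: with a multi-backend bitsandbytes build installed, an all-CPU `device_map` no longer trips the CPU/disk offload error. A minimal sketch of how this could be exercised from user code; the checkpoint, subfolder, compute dtype, and `device_map` usage are illustrative assumptions, not part of this PR:

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# 4-bit config; the compute dtype here is illustrative.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float32)

# With multi-backend bitsandbytes, validate_environment no longer requires CUDA,
# and an all-CPU device_map is explicitly allowed by the new `set(...) == {"cpu"}` branch.
# On a CUDA-only bitsandbytes build this call would still raise.
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    device_map={"": "cpu"},
)
```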
@@ -240,10 +246,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
     # Commenting this for discussions on the PR.
     # def update_device_map(self, device_map):
     #     if device_map is None:
-    #         device_map = {"": torch.cuda.current_device()}
+    #         if torch.cuda.is_available():
+    #             device_map = {"": torch.cuda.current_device()}
+    #         elif is_torch_xpu_available():
+    #             device_map = {"": f"xpu:{torch.xpu.current_device()}"}
+    #         else:
+    #             device_map = {"": "cpu"}
     #         logger.info(
     #             "The device_map was not initialized. "
-    #             "Setting device_map to {'':torch.cuda.current_device()}. "
+    #             f"Setting device_map to {device_map}. "
     #             "If you want to use the model for inference, please set device_map ='auto' "
     #         )
     #         return device_map
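For reference, the device-selection order the commented-out block proposes (CUDA, then XPU, then CPU) can be sketched as a standalone helper; the function name is hypothetical and plain `torch` attribute checks stand in for the diffusers `is_torch_xpu_available` utility:

```python
import torch


def _default_bnb_device_map():
    """Hypothetical sketch of the proposed default device_map selection."""
    if torch.cuda.is_available():
        return {"": torch.cuda.current_device()}
    # torch.xpu is only present on XPU-enabled PyTorch builds.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return {"": f"xpu:{torch.xpu.current_device()}"}
    return {"": "cpu"}


print(_default_bnb_device_map())
```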
@@ -344,8 +355,6 @@ def __init__(self, quantization_config, **kwargs):
             self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

     def validate_environment(self, *args, **kwargs):
-        if not torch.cuda.is_available():
-            raise RuntimeError("No GPU found. A GPU is needed for quantization.")
         if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
             raise ImportError(
                 "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
@@ -355,6 +364,12 @@ def validate_environment(self, *args, **kwargs):
                 "Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
             )

+        from ...utils import is_bitsandbytes_multi_backend_available
+        from .utils import validate_bnb_backend_availability
+
+        bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
+        validate_bnb_backend_availability(raise_exception=True)
+
         if kwargs.get("from_flax", False):
             raise ValueError(
                 "Converting into 8-bit weights from flax weights is currently not supported, please make"
@@ -370,7 +385,9 @@ def validate_environment(self, *args, **kwargs):
             device_map_without_no_convert = {
                 key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
             }
-            if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
+            if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
+                pass
+            elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
                 raise ValueError(
                     "Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
                     "quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "

Review comment (on the repeated device_map check): The common piece of code between the two utilities could be clubbed into a small function and reused? Previously we didn't do this because it was relatively small and was better off inline.
@@ -403,10 +420,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
     # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
     # def update_device_map(self, device_map):
     #     if device_map is None:
-    #         device_map = {"": torch.cuda.current_device()}
+    #         if torch.cuda.is_available():
+    #             device_map = {"": torch.cuda.current_device()}
+    #         elif is_torch_xpu_available():
+    #             device_map = {"": f"xpu:{torch.xpu.current_device()}"}
+    #         else:
+    #             device_map = {"": "cpu"}
     #         logger.info(
     #             "The device_map was not initialized. "
-    #             "Setting device_map to {'':torch.cuda.current_device()}. "
+    #             f"Setting device_map to {device_map}. "
     #             "If you want to use the model for inference, please set device_map ='auto' "
     #         )
     #         return device_map
The hunks below are from the second changed file, the bitsandbytes integration utilities module (the quantizer above imports `validate_bnb_backend_availability` from it as `.utils`).
@@ -16,11 +16,22 @@
 https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py
 """

 import importlib
 import inspect
 from inspect import signature
 from typing import Union

-from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
+from packaging import version
+
+from ...utils import (
+    get_available_devices,
+    is_accelerate_available,
+    is_bitsandbytes_available,
+    is_bitsandbytes_multi_backend_available,
+    is_ipex_available,
+    is_torch_available,
+    logging,
+)
 from ..quantization_config import QuantizationMethod
@@ -154,7 +165,7 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name


 # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.
@@ -172,7 +183,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
         logger.warning_once(
             f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
         )
-        return output_tensor
+        return output_tensor.to(dtype)

     if state.SCB is None:
         state.SCB = weight.SCB
@@ -183,7 +194,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.CxB is None:
         state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
     out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t().to(dtype)


 def _create_accelerate_new_hook(old_hook):

Review comment: Note: #10401. Reply: Will rebase after that PR has merged.
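The new `dtype` argument exists so that dequantized weights are cast back to the model's original dtype instead of whatever dtype bitsandbytes' dequantization happens to produce. A self-contained illustration of that cast, using stand-in tensors only (no bitsandbytes required):

```python
import torch

model_dtype = torch.bfloat16  # dtype the model was originally loaded in (illustrative)
dequant_output = torch.randn(4, 4, dtype=torch.float16)  # stand-in for bnb's dequantized weight

# Without the trailing .to(dtype), the rebuilt nn.Parameter would silently keep float16;
# with it, the restored weight matches the rest of the model.
restored = torch.nn.Parameter(dequant_output.to(model_dtype))
assert restored.dtype == model_dtype
```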
@@ -205,6 +216,7 @@ def _create_accelerate_new_hook(old_hook):

 def _dequantize_and_replace(
     model,
+    dtype,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -244,7 +256,7 @@ def _dequantize_and_replace(
             else:
                 state = None

-            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, dtype, state))

             if bias is not None:
                 new_module.bias = bias
@@ -263,6 +275,7 @@ def _dequantize_and_replace(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _dequantize_and_replace(
                 module,
+                dtype,
                 modules_to_not_convert,
                 current_key_name,
                 quantization_config,
@@ -280,6 +293,7 @@ def dequantize_and_replace(
 ):
     model, has_been_replaced = _dequantize_and_replace(
         model,
+        model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )
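Since `dequantize_and_replace` now threads `model.dtype` through, dequantizing a bnb-quantized model should hand back weights in the dtype the model was loaded with. A hedged usage sketch, assuming a quantized `model` like the one loaded in the earlier sketch and assuming it exposes the `dequantize()` helper from diffusers' quantizer API:

```python
# `model` is assumed to be a bnb-quantized diffusers model (see the earlier sketch).
model.dequantize()

# With this PR the restored parameters follow model.dtype rather than the
# dtype bitsandbytes produced during dequantization.
print(next(model.parameters()).dtype)
```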
@@ -304,3 +318,80 @@ def _check_bnb_status(module) -> Union[bool, bool]:
         and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
     )
     return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb
+
+
+def _validate_bnb_multi_backend_availability(raise_exception):
+    import bitsandbytes as bnb
+
+    bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
+    available_devices = get_available_devices()
+
+    if available_devices == {"cpu"} and not is_ipex_available():
+        from importlib.util import find_spec
+
+        if find_spec("intel_extension_for_pytorch"):
+            logger.warning(
+                "You have Intel IPEX installed but if you're intending to use it for CPU, it might not have the right version. Be sure to double check that your PyTorch and IPEX installs are compatible."
+            )
+
+        available_devices.discard("cpu")  # Only Intel CPU is supported by BNB at the moment
+
+    if not available_devices.intersection(bnb_supported_devices):
+        if raise_exception:
+            bnb_supported_devices_with_info = set(  # noqa: C401
+                '"cpu" (needs an Intel CPU and intel_extension_for_pytorch installed and compatible with the PyTorch version)'
+                if device == "cpu"
+                else device
+                for device in bnb_supported_devices
+            )
+            err_msg = (
+                f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices_with_info}`. "
+                "Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
+            )
+
+            logger.error(err_msg)
+            raise RuntimeError(err_msg)
+
+        logger.warning("No supported devices found for bitsandbytes multi-backend.")
+        return False
+
+    logger.debug("Multi-backend validation successful.")
+    return True
+
+
+def _validate_bnb_cuda_backend_availability(raise_exception):
+    if not is_torch_available():
+        return False
+
+    import torch
+
+    if not torch.cuda.is_available():
+        log_msg = (
+            "CUDA is required but not available for bitsandbytes. Please consider installing the multi-platform enabled version of bitsandbytes, which is currently a work in progress. "
+            "Please check currently supported platforms and installation instructions at https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
+        )
+        if raise_exception:
+            logger.error(log_msg)
+            raise RuntimeError(log_msg)
+
+        logger.warning(log_msg)
+        return False
+
+    logger.debug("CUDA backend validation successful.")
+    return True
+
+
+def validate_bnb_backend_availability(raise_exception=False):
+    """
+    Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
+    """
+    if not is_bitsandbytes_available():
+        if importlib.util.find_spec("bitsandbytes") and version.parse(
+            importlib.metadata.version("bitsandbytes")
+        ) < version.parse("0.43.1"):
+            return _validate_bnb_cuda_backend_availability(raise_exception)
+        return False
+
+    if is_bitsandbytes_multi_backend_available():
+        return _validate_bnb_multi_backend_availability(raise_exception)
+    return _validate_bnb_cuda_backend_availability(raise_exception)

Review comment (on `_validate_bnb_multi_backend_availability`): @matthewdouglas I wonder if it makes sense to have these as utility functions in …
Review comment: Because bnb is supported on Intel CPUs?
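Finally, a minimal sketch of how the new helpers could be probed directly; the module paths are inferred from the diff, and diffusers normally calls these inside `validate_environment` rather than from user code:

```python
from diffusers.utils import is_bitsandbytes_multi_backend_available
from diffusers.quantizers.bitsandbytes.utils import validate_bnb_backend_availability

# True when the installed bitsandbytes is a multi-backend build
# (0.43.1 or newer, matching the version referenced above).
print("multi-backend bitsandbytes:", is_bitsandbytes_multi_backend_available())

# Returns False (or raises when raise_exception=True) if none of the available
# devices are supported by the installed bitsandbytes build.
print("bitsandbytes backend usable:", validate_bnb_backend_availability(raise_exception=False))
```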