[Quantization] enable multi-backend bitsandbytes #10574

Open
wants to merge 3 commits into base: main
Changes from all commits
42 changes: 32 additions & 10 deletions src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -61,8 +61,6 @@ def __init__(self, quantization_config, **kwargs):
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
"Using `bitsandbytes` 4-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
@@ -72,6 +70,12 @@ def validate_environment(self, *args, **kwargs):
"Using `bitsandbytes` 4-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
)

from ...utils import is_bitsandbytes_multi_backend_available
from .utils import validate_bnb_backend_availability

bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
validate_bnb_backend_availability(raise_exception=True)

if kwargs.get("from_flax", False):
raise ValueError(
"Converting into 4-bit weights from flax weights is currently not supported, please make"
@@ -87,7 +91,9 @@ def validate_environment(self, *args, **kwargs):
device_map_without_no_convert = {
key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
}
if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
pass
Comment on lines +94 to +95 (Member):
Because bnb is supported on intel CPUs?
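The exemption added on these lines means an all-CPU `device_map` no longer raises when a multi-backend build of bitsandbytes is installed. A minimal usage sketch of what that unlocks, assuming an Intel CPU with a compatible `intel_extension_for_pytorch` install; the model id and exact arguments are illustrative and not taken from this PR:

```python
import torch
from diffusers import BitsAndBytesConfig, SD3Transformer2DModel

# 4-bit quantization; compute dtype kept at float32 for CPU execution.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float32)

# With the bnb multi-backend available, an all-CPU device_map now passes validate_environment.
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quant_config,
    device_map={"": "cpu"},
)
```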

elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
raise ValueError(
"Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
"quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
@@ -240,10 +246,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
# Commenting this for discussions on the PR.
# def update_device_map(self, device_map):
# if device_map is None:
# device_map = {"": torch.cuda.current_device()}
# if torch.cuda.is_available():
# device_map = {"": torch.cuda.current_device()}
# elif is_torch_xpu_available():
# device_map = {"": f"xpu:{torch.xpu.current_device()}"}
# else:
# device_map = {"": "cpu"}
# logger.info(
# "The device_map was not initialized. "
# "Setting device_map to {'':torch.cuda.current_device()}. "
# f"Setting device_map to {device_map}. "
# "If you want to use the model for inference, please set device_map ='auto' "
# )
# return device_map
@@ -344,8 +355,6 @@ def __init__(self, quantization_config, **kwargs):
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules

def validate_environment(self, *args, **kwargs):
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
raise ImportError(
"Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install 'accelerate>=0.26.0'`"
@@ -355,6 +364,12 @@ def validate_environment(self, *args, **kwargs):
"Using `bitsandbytes` 8-bit quantization requires the latest version of bitsandbytes: `pip install -U bitsandbytes`"
)

from ...utils import is_bitsandbytes_multi_backend_available
from .utils import validate_bnb_backend_availability

bnb_multibackend_is_enabled = is_bitsandbytes_multi_backend_available()
validate_bnb_backend_availability(raise_exception=True)

if kwargs.get("from_flax", False):
raise ValueError(
"Converting into 8-bit weights from flax weights is currently not supported, please make"
@@ -370,7 +385,9 @@ def validate_environment(self, *args, **kwargs):
device_map_without_no_convert = {
key: device_map[key] for key in device_map.keys() if key not in self.modules_to_not_convert
}
if "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
if set(device_map.values()) == {"cpu"} and bnb_multibackend_is_enabled:
pass
elif "cpu" in device_map_without_no_convert.values() or "disk" in device_map_without_no_convert.values():
Comment (Member):
The common piece of code between the two utilities could be clubbed into a small function and reused?

Previously we didn't do it because it was relatively small and was better off in-line.
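A rough sketch of what such a shared helper could look like if the refactor is taken up, assuming it lives in `bnb_quantizer.py` next to the existing imports; the name and signature are hypothetical and not part of this PR:

```python
def _validate_bnb_common_environment(bits: int) -> bool:
    """Checks shared by the 4-bit and 8-bit quantizers (hypothetical helper).

    Returns whether the bitsandbytes multi-backend is enabled.
    """
    if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
        raise ImportError(
            f"Using `bitsandbytes` {bits}-bit quantization requires Accelerate: "
            "`pip install 'accelerate>=0.26.0'`"
        )
    if not is_bitsandbytes_available():
        raise ImportError(
            f"Using `bitsandbytes` {bits}-bit quantization requires the latest version "
            "of bitsandbytes: `pip install -U bitsandbytes`"
        )

    from ...utils import is_bitsandbytes_multi_backend_available
    from .utils import validate_bnb_backend_availability

    validate_bnb_backend_availability(raise_exception=True)
    return is_bitsandbytes_multi_backend_available()
```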

raise ValueError(
"Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit the "
"quantized model. If you want to dispatch the model on the CPU or the disk while keeping these modules "
@@ -403,10 +420,15 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
# # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
# def update_device_map(self, device_map):
# if device_map is None:
# device_map = {"": torch.cuda.current_device()}
# if torch.cuda.is_available():
# device_map = {"": torch.cuda.current_device()}
# elif is_torch_xpu_available():
# device_map = {"": f"xpu:{torch.xpu.current_device()}"}
# else:
# device_map = {"": "cpu"}
# logger.info(
# "The device_map was not initialized. "
# "Setting device_map to {'':torch.cuda.current_device()}. "
# f"Setting device_map to {device_map}. "
# "If you want to use the model for inference, please set device_map ='auto' "
# )
# return device_map
101 changes: 96 additions & 5 deletions src/diffusers/quantizers/bitsandbytes/utils.py
@@ -16,11 +16,22 @@
https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/integrations/bitsandbytes.py
"""

import importlib
import inspect
from inspect import signature
from typing import Union

from ...utils import is_accelerate_available, is_bitsandbytes_available, is_torch_available, logging
from packaging import version

from ...utils import (
get_available_devices,
is_accelerate_available,
is_bitsandbytes_available,
is_bitsandbytes_multi_backend_available,
is_ipex_available,
is_torch_available,
logging,
)
from ..quantization_config import QuantizationMethod


@@ -154,7 +165,7 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name


# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
def dequantize_bnb_weight(weight: "torch.nn.Parameter", dtype: "torch.dtype", state=None):
"""
Helper function to dequantize 4bit or 8bit bnb weights.

@@ -172,7 +183,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
logger.warning_once(
f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
)
return output_tensor
return output_tensor.to(dtype)

if state.SCB is None:
state.SCB = weight.SCB
@@ -183,7 +194,7 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
if state.CxB is None:
state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t().to(dtype)
Comment (Member):
Note: #10401

Reply (Collaborator, Author):
Will rebase after that PR has merged.


def _create_accelerate_new_hook(old_hook):
@@ -205,6 +216,7 @@ def _create_accelerate_new_hook(old_hook):

def _dequantize_and_replace(
model,
dtype,
modules_to_not_convert=None,
current_key_name=None,
quantization_config=None,
@@ -244,7 +256,7 @@ def _dequantize_and_replace(
else:
state = None

new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, dtype, state))

if bias is not None:
new_module.bias = bias
@@ -263,6 +275,7 @@
if len(list(module.children())) > 0:
_, has_been_replaced = _dequantize_and_replace(
module,
dtype,
modules_to_not_convert,
current_key_name,
quantization_config,
@@ -280,6 +293,7 @@ def dequantize_and_replace(
):
model, has_been_replaced = _dequantize_and_replace(
model,
model.dtype,
modules_to_not_convert=modules_to_not_convert,
quantization_config=quantization_config,
)
@@ -304,3 +318,80 @@ def _check_bnb_status(module) -> Union[bool, bool]:
and getattr(module, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
)
return is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb


def _validate_bnb_multi_backend_availability(raise_exception):
Comment (Member):
@matthewdouglas I wonder if it makes sense to have these as utility functions in bitsandbytes so that they can be reused in transformers and diffusers (and any other libraries)?

import bitsandbytes as bnb

bnb_supported_devices = getattr(bnb, "supported_torch_devices", set())
available_devices = get_available_devices()

if available_devices == {"cpu"} and not is_ipex_available():
from importlib.util import find_spec

if find_spec("intel_extension_for_pytorch"):
logger.warning(
"You have Intel IPEX installed but if you're intending to use it for CPU, it might not have the right version. Be sure to double check that your PyTorch and IPEX installs are compatible."
)

available_devices.discard("cpu") # Only Intel CPU is supported by BNB at the moment

if not available_devices.intersection(bnb_supported_devices):
if raise_exception:
bnb_supported_devices_with_info = set( # noqa: C401
'"cpu" (needs an Intel CPU and intel_extension_for_pytorch installed and compatible with the PyTorch version)'
if device == "cpu"
else device
for device in bnb_supported_devices
)
err_msg = (
f"None of the available devices `available_devices = {available_devices or None}` are supported by the bitsandbytes version you have installed: `bnb_supported_devices = {bnb_supported_devices_with_info}`. "
"Please check the docs to see if the backend you intend to use is available and how to install it: https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
)

logger.error(err_msg)
raise RuntimeError(err_msg)

logger.warning("No supported devices found for bitsandbytes multi-backend.")
return False

logger.debug("Multi-backend validation successful.")
return True
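
As an aside on the `getattr(bnb, "supported_torch_devices", set())` lookup above: the attribute appears to be present only on multi-backend builds, which is why the code falls back to an empty set. A quick interactive check of what the installed build reports (illustrative):

```python
import bitsandbytes as bnb

# Multi-backend builds report the torch device types they support; older
# CUDA-only builds lack the attribute, so fall back to an empty set.
print(getattr(bnb, "supported_torch_devices", set()))
```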


def _validate_bnb_cuda_backend_availability(raise_exception):
if not is_torch_available():
return False

import torch

if not torch.cuda.is_available():
log_msg = (
"CUDA is required but not available for bitsandbytes. Please consider installing the multi-platform enabled version of bitsandbytes, which is currently a work in progress. "
"Please check currently supported platforms and installation instructions at https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
)
if raise_exception:
logger.error(log_msg)
raise RuntimeError(log_msg)

logger.warning(log_msg)
return False

logger.debug("CUDA backend validation successful.")
return True


def validate_bnb_backend_availability(raise_exception=False):
"""
Validates if the available devices are supported by bitsandbytes, optionally raising an exception if not.
"""
if not is_bitsandbytes_available():
if importlib.util.find_spec("bitsandbytes") and version.parse(
importlib.metadata.version("bitsandbytes")
) < version.parse("0.43.1"):
return _validate_bnb_cuda_backend_availability(raise_exception)
return False

if is_bitsandbytes_multi_backend_available():
return _validate_bnb_multi_backend_availability(raise_exception)
return _validate_bnb_cuda_backend_availability(raise_exception)
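
A minimal sketch of calling the validator in its non-raising mode, assuming this PR is merged and bitsandbytes is installed; the import path mirrors the file location above:

```python
from diffusers.quantizers.bitsandbytes.utils import validate_bnb_backend_availability

# With raise_exception=False the helper only logs and returns a boolean, so a
# caller can decide whether to offer bnb quantization at all.
if validate_bnb_backend_availability(raise_exception=False):
    print("bitsandbytes has a usable backend on this machine")
else:
    print("no usable bitsandbytes backend; fall back to unquantized loading")
```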
37 changes: 37 additions & 0 deletions src/diffusers/utils/__init__.py
@@ -14,6 +14,8 @@


import os
from functools import lru_cache
from typing import FrozenSet

from packaging import version

@@ -63,6 +65,7 @@
is_accelerate_available,
is_accelerate_version,
is_bitsandbytes_available,
is_bitsandbytes_multi_backend_available,
is_bitsandbytes_version,
is_bs4_available,
is_flax_available,
@@ -73,6 +76,7 @@
is_hf_hub_version,
is_inflect_available,
is_invisible_watermark_available,
is_ipex_available,
is_k_diffusion_available,
is_k_diffusion_version,
is_librosa_available,
@@ -87,10 +91,15 @@
is_tensorboard_available,
is_timm_available,
is_torch_available,
is_torch_cuda_available,
is_torch_mlu_available,
is_torch_mps_available,
is_torch_musa_available,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
is_torch_xla_version,
is_torch_xpu_available,
is_torchao_available,
is_torchsde_available,
is_torchvision_available,
@@ -139,3 +148,31 @@ def check_min_version(min_version):
error_message = f"This example requires a minimum version of {min_version},"
error_message += f" but the version found is {__version__}.\n"
raise ImportError(error_message)


@lru_cache()
def get_available_devices() -> FrozenSet[str]:
"""
Returns a frozenset of devices available for the current PyTorch installation.
"""
devices = {"cpu"} # `cpu` is always supported as a device in PyTorch

if is_torch_cuda_available():
devices.add("cuda")

if is_torch_mps_available():
devices.add("mps")

if is_torch_xpu_available():
devices.add("xpu")

if is_torch_npu_available():
devices.add("npu")

if is_torch_mlu_available():
devices.add("mlu")

if is_torch_musa_available():
devices.add("musa")

return frozenset(devices)
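
A short usage sketch; the exact contents of the frozenset depend on the local PyTorch build:

```python
from diffusers.utils import get_available_devices

devices = get_available_devices()  # cached after the first call via lru_cache
print(devices)  # e.g. frozenset({"cpu", "cuda"})

# The bnb multi-backend validation intersects this set with the devices the
# installed bitsandbytes build supports.
if devices == {"cpu"}:
    print("CPU-only environment; bitsandbytes needs its multi-backend build here.")
```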