[Single File] Add GGUF support (#9964)

* Update src/diffusers/quantizers/gguf/utils.py

Co-authored-by: Sayak Paul <[email protected]>

* Update docs/source/en/quantization/gguf.md

Co-authored-by: Steven Liu <[email protected]>

---------

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Steven Liu <[email protected]>
3 people authored Dec 17, 2024
1 parent f9d5a93 commit e24941b
Showing 22 changed files with 1,321 additions and 21 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/nightly_tests.yml
@@ -357,6 +357,8 @@ jobs:
        config:
          - backend: "bitsandbytes"
            test_location: "bnb"
          - backend: "gguf"
            test_location: "gguf"
    runs-on:
      group: aws-g6e-xlarge-plus
    container:
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -157,6 +157,8 @@
    title: Getting Started
  - local: quantization/bitsandbytes
    title: bitsandbytes
  - local: quantization/gguf
    title: gguf
  - local: quantization/torchao
    title: torchao
  title: Quantization Methods
3 changes: 3 additions & 0 deletions docs/source/en/api/quantization.md
@@ -28,6 +28,9 @@ Learn how to quantize models in the [Quantization](../quantization/overview) gui

[[autodoc]] BitsAndBytesConfig

## GGUFQuantizationConfig

[[autodoc]] GGUFQuantizationConfig

## TorchAoConfig

[[autodoc]] TorchAoConfig
70 changes: 70 additions & 0 deletions docs/source/en/quantization/gguf.md
@@ -0,0 +1,70 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# GGUF

The GGUF file format is typically used to store models for inference with [GGML](https://github.com/ggerganov/ggml) and supports a variety of block-wise quantization options. Diffusers supports loading checkpoints pre-quantized and saved in the GGUF format via `from_single_file` loading with model classes. Loading GGUF checkpoints via pipelines is currently not supported.

The following example loads the [FLUX.1 DEV](https://huggingface.co/black-forest-labs/FLUX.1-dev) transformer model using the GGUF Q2_K quantization variant.

Before starting, please install gguf in your environment:

```shell
pip install -U gguf
```

Since GGUF is a single file format, use [`~FromSingleFileMixin.from_single_file`] to load the model and pass in the [`GGUFQuantizationConfig`].

When using GGUF checkpoints, the quantized weights remain in a low-memory `dtype` (typically `torch.uint8`) and are dynamically dequantized and cast to the configured `compute_dtype` during each module's forward pass through the model. The `GGUFQuantizationConfig` allows you to set the `compute_dtype`.

The functions used for dynamic dequantization are based on the great work done by [city96](https://github.com/city96/ComfyUI-GGUF), who created the PyTorch ports of the original [`numpy`](https://github.com/ggerganov/llama.cpp/blob/master/gguf-py/gguf/quants.py) implementation by [compilade](https://github.com/compilade).

```python
import torch

from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

ckpt_path = (
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
)
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

prompt = "A cat holding a sign that says hello world"
image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
image.save("flux-gguf.png")
```
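
To build intuition for what dynamic dequantization means in practice, here is a minimal conceptual sketch (not the actual `diffusers.quantizers.gguf.utils` implementation, and using a much simpler per-tensor int8 scheme than GGUF's block-wise formats): the weight stays in its compact storage dtype, and a higher-precision copy in `compute_dtype` is materialized only inside the forward pass.

```python
import torch
import torch.nn as nn


class DequantizingLinear(nn.Module):
    """Toy layer: stores an int8 weight plus a scale and dequantizes on the fly."""

    def __init__(self, qweight: torch.Tensor, scale: float, compute_dtype=torch.bfloat16):
        super().__init__()
        self.register_buffer("qweight", qweight)  # compact storage, e.g. torch.int8
        self.scale = scale
        self.compute_dtype = compute_dtype

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dequantize just-in-time; the high-precision weight copy is temporary.
        weight = self.qweight.to(self.compute_dtype) * self.scale
        return torch.nn.functional.linear(x.to(self.compute_dtype), weight)


# Hypothetical usage with random data
qweight = torch.randint(-128, 127, (16, 8), dtype=torch.int8)
layer = DequantizingLinear(qweight, scale=0.02)
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 16])
```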

## Supported Quantization Types

- BF16
- Q4_0
- Q4_1
- Q5_0
- Q5_1
- Q8_0
- Q2_K
- Q3_K
- Q4_K
- Q5_K
- Q6_K
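
As a quick check of which of these quantization types a particular file actually contains, the following sketch (assuming the `gguf` package is installed and the checkpoint has been downloaded to a hypothetical local path) tallies the tensor types reported by the GGUF reader:

```python
from collections import Counter

from gguf import GGUFReader

reader = GGUFReader("flux1-dev-Q2_K.gguf")  # hypothetical local path
counts = Counter(tensor.tensor_type.name for tensor in reader.tensors)
for quant_type, num_tensors in sorted(counts.items()):
    print(f"{quant_type}: {num_tensors} tensors")
```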

9 changes: 7 additions & 2 deletions docs/source/en/quantization/overview.md
@@ -17,7 +17,7 @@ Quantization techniques focus on representing data with less information while a

<Tip>

Interested in adding a new quantization method to Transformers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.
Interested in adding a new quantization method to Diffusers? Refer to the [Contribute new quantization method guide](https://huggingface.co/docs/transformers/main/en/quantization/contribute) to learn more about adding a new quantization method.

</Tip>

@@ -32,4 +32,9 @@ If you are new to the quantization field, we recommend you to check out these be

## When to use what?

Diffusers supports [bitsandbytes](https://huggingface.co/docs/bitsandbytes/main/en/index) and [torchao](https://github.com/pytorch/ao). Refer to this [table](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) to help you determine which quantization backend to use.
Diffusers currently supports the following quantization methods:
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
4 changes: 2 additions & 2 deletions src/diffusers/__init__.py
@@ -31,7 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig"],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -569,7 +569,7 @@

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .configuration_utils import ConfigMixin
    from .quantizers.quantization_config import BitsAndBytesConfig, TorchAoConfig
    from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig

    try:
        if not is_onnx_available():
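
With this export in place, the new config class is importable from the top-level package. A one-line sketch:

```python
import torch

from diffusers import GGUFQuantizationConfig

config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)
print(config.compute_dtype)
```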
46 changes: 44 additions & 2 deletions src/diffusers/loaders/single_file_model.py
@@ -17,8 +17,10 @@
from contextlib import nullcontext
from typing import Optional

import torch
from huggingface_hub.utils import validate_hf_hub_args

from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
    SingleFileComponentError,
@@ -214,6 +216,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
        subfolder = kwargs.pop("subfolder", None)
        revision = kwargs.pop("revision", None)
        torch_dtype = kwargs.pop("torch_dtype", None)
        quantization_config = kwargs.pop("quantization_config", None)
        device = kwargs.pop("device", None)

        if isinstance(pretrained_model_link_or_path_or_dict, dict):
            checkpoint = pretrained_model_link_or_path_or_dict
@@ -227,6 +231,12 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
                local_files_only=local_files_only,
                revision=revision,
            )
        if quantization_config is not None:
            hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
            hf_quantizer.validate_environment()

        else:
            hf_quantizer = None

        mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]

@@ -309,8 +319,36 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
        with ctx():
            model = cls.from_config(diffusers_model_config)

        # Check if `_keep_in_fp32_modules` is not None
        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
        )
        if use_keep_in_fp32_modules:
            keep_in_fp32_modules = cls._keep_in_fp32_modules
            if not isinstance(keep_in_fp32_modules, list):
                keep_in_fp32_modules = [keep_in_fp32_modules]

        else:
            keep_in_fp32_modules = []

        if hf_quantizer is not None:
            hf_quantizer.preprocess_model(
                model=model,
                device_map=None,
                state_dict=diffusers_format_checkpoint,
                keep_in_fp32_modules=keep_in_fp32_modules,
            )

        if is_accelerate_available():
            unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
            param_device = torch.device(device) if device else torch.device("cpu")
            unexpected_keys = load_model_dict_into_meta(
                model,
                diffusers_format_checkpoint,
                dtype=torch_dtype,
                device=param_device,
                hf_quantizer=hf_quantizer,
                keep_in_fp32_modules=keep_in_fp32_modules,
            )

        else:
            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
@@ -324,7 +362,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
)

if torch_dtype is not None:
if hf_quantizer is not None:
hf_quantizer.postprocess_model(model)
model.hf_quantizer = hf_quantizer

if torch_dtype is not None and hf_quantizer is None:
model.to(torch_dtype)

model.eval()
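
Taken together, these changes let `from_single_file` accept a `quantization_config` and an optional `device`. A minimal sketch of exercising the new path (the checkpoint URL is the one from the docs above; `device` is optional and parameters are placed on CPU when it is omitted):

```python
import torch

from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf",
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    device="cuda",  # optional; defaults to CPU placement when not set
)
print(transformer.hf_quantizer)  # the quantizer is attached to the model after loading
```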
25 changes: 19 additions & 6 deletions src/diffusers/loaders/single_file_utils.py
@@ -81,8 +81,14 @@
"open_clip_sd3": "text_encoders.clip_g.transformer.text_model.embeddings.position_embedding.weight",
"stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight",
"stable_cascade_stage_c": "clip_txt_mapper.weight",
"sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
"sd35_large": "model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
"sd3": [
"joint_blocks.0.context_block.adaLN_modulation.1.bias",
"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
],
"sd35_large": [
"joint_blocks.37.x_block.mlp.fc1.weight",
"model.diffusion_model.joint_blocks.37.x_block.mlp.fc1.weight",
],
"animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe",
"animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias",
"animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight",
@@ -542,13 +548,20 @@ def infer_diffusers_model_type(checkpoint):
    ):
        model_type = "stable_cascade_stage_b"

    elif CHECKPOINT_KEY_NAMES["sd3"] in checkpoint and checkpoint[CHECKPOINT_KEY_NAMES["sd3"]].shape[-1] == 9216:
        if checkpoint["model.diffusion_model.pos_embed"].shape[1] == 36864:
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd3"]) and any(
        checkpoint[key].shape[-1] == 9216 if key in checkpoint else False for key in CHECKPOINT_KEY_NAMES["sd3"]
    ):
        if "model.diffusion_model.pos_embed" in checkpoint:
            key = "model.diffusion_model.pos_embed"
        else:
            key = "pos_embed"

        if checkpoint[key].shape[1] == 36864:
            model_type = "sd3"
        elif checkpoint["model.diffusion_model.pos_embed"].shape[1] == 147456:
        elif checkpoint[key].shape[1] == 147456:
            model_type = "sd35_medium"

    elif CHECKPOINT_KEY_NAMES["sd35_large"] in checkpoint:
    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["sd35_large"]):
        model_type = "sd35_large"

    elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint:
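
The detection logic now accepts either key layout (with or without the `model.diffusion_model.` prefix). A toy sketch of the `any(...)` pattern, using hypothetical minimal checkpoint dicts:

```python
SD3_KEYS = [
    "joint_blocks.0.context_block.adaLN_modulation.1.bias",
    "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias",
]


def has_sd3_keys(checkpoint: dict) -> bool:
    # True if the checkpoint uses either the prefixed or the un-prefixed layout.
    return any(key in checkpoint for key in SD3_KEYS)


print(has_sd3_keys({"joint_blocks.0.context_block.adaLN_modulation.1.bias": 0}))  # True
print(has_sd3_keys({"model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias": 0}))  # True
print(has_sd3_keys({"unrelated.key": 0}))  # False
```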
84 changes: 83 additions & 1 deletion src/diffusers/models/model_loading_utils.py
@@ -17,6 +17,7 @@
import importlib
import inspect
import os
from array import array
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
@@ -26,13 +27,16 @@
from huggingface_hub.utils import EntryNotFoundError

from ..utils import (
    GGUF_FILE_EXTENSION,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFETENSORS_FILE_EXTENSION,
    WEIGHTS_INDEX_NAME,
    _add_variant,
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_gguf_available,
    is_torch_available,
    is_torch_version,
    logging,
)
@@ -139,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
        if file_extension == SAFETENSORS_FILE_EXTENSION:
            return safetensors.torch.load_file(checkpoint_file, device="cpu")
        elif file_extension == GGUF_FILE_EXTENSION:
            return load_gguf_checkpoint(checkpoint_file)
        else:
            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
            return torch.load(
@@ -211,13 +217,14 @@ def load_model_dict_into_meta(
                    set_module_kwargs["dtype"] = dtype

        # bnb params are flattened.
        # gguf quants have a different shape based on the type of quantization applied
        if empty_state_dict[param_name].shape != param.shape:
            if (
                is_quantized
                and hf_quantizer.pre_quantized
                and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
            ):
                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
            else:
                model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                raise ValueError(
@@ -396,3 +403,78 @@ def _fetch_index_file_legacy(
        index_file = None

    return index_file


def _gguf_parse_value(_value, data_type):
    if not isinstance(data_type, list):
        data_type = [data_type]
    if len(data_type) == 1:
        data_type = data_type[0]
        array_data_type = None
    else:
        if data_type[0] != 9:
            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
        data_type, array_data_type = data_type

    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
        _value = int(_value[0])
    elif data_type in [6, 12]:
        _value = float(_value[0])
    elif data_type in [7]:
        _value = bool(_value[0])
    elif data_type in [8]:
        _value = array("B", list(_value)).tobytes().decode()
    elif data_type in [9]:
        _value = _gguf_parse_value(_value, array_data_type)
    return _value
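
For reference, the integer codes correspond to gguf's value-type enum (0-5, 10 and 11 are integer types, 6 and 12 are floats, 7 is bool, 8 is a string, and 9 marks an array). A small hypothetical usage sketch with raw values shaped like those the reader exposes:

```python
print(_gguf_parse_value([42], 4))     # 42     (code 4: 32-bit unsigned integer)
print(_gguf_parse_value([2.5], 6))    # 2.5    (code 6: 32-bit float)
print(_gguf_parse_value([1], 7))      # True   (code 7: bool)
print(_gguf_parse_value(b"gguf", 8))  # "gguf" (code 8: UTF-8 string)
```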


def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
    """
    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
    attributes.

    Args:
        gguf_checkpoint_path (`str`):
            The path to the GGUF file to load.
        return_tensors (`bool`, defaults to `False`):
            Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
            metadata in memory.
    """

    if is_gguf_available() and is_torch_available():
        import gguf
        from gguf import GGUFReader

        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
    else:
        logger.error(
            "Loading a GGUF checkpoint in PyTorch requires both PyTorch and gguf>=0.10.0 to be installed. Please see "
            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
        )
        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")

    reader = GGUFReader(gguf_checkpoint_path)

    parsed_parameters = {}
    for tensor in reader.tensors:
        name = tensor.name
        quant_type = tensor.tensor_type

        # if the tensor is a torch supported dtype do not use GGUFParameter
        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
            _supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
            raise ValueError(
                (
                    f"{name} has a quantization type: {str(quant_type)} which is unsupported."
                    "\n\nCurrently the following quantization types are supported: \n\n"
                    f"{_supported_quants_str}"
                    "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
                )
            )

        weights = torch.from_numpy(tensor.data.copy())
        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

    return parsed_parameters
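
A short sketch (hypothetical local file path) of what the loader returns: a plain state dict whose quantized entries are `GGUFParameter` tensors that carry their GGML quantization type, while F16/F32 tensors come back as ordinary tensors:

```python
from diffusers.models.model_loading_utils import load_gguf_checkpoint

state_dict = load_gguf_checkpoint("flux1-dev-Q2_K.gguf")  # hypothetical local path
for name, tensor in list(state_dict.items())[:5]:
    quant_type = getattr(tensor, "quant_type", None)  # None for plain F16/F32 tensors
    print(name, tuple(tensor.shape), quant_type)
```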