diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8fca380c8255..e2351a0c53b8 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -31,7 +31,7 @@ "loaders": ["FromOriginalModelMixin"], "models": [], "pipelines": [], - "quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig", "GGUFQuantizationConfig"], + "quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"], "schedulers": [], "utils": [ "OptionalDependencyNotAvailable", @@ -428,8 +428,7 @@ if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import \ - dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [ name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_") @@ -442,8 +441,7 @@ if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import \ - dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [ name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_") @@ -456,8 +454,7 @@ if not (is_torch_available() and is_transformers_available() and is_onnx_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import \ - dummy_torch_and_transformers_and_onnx_objects # noqa F403 + from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403 _import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [ name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_") @@ -492,8 +489,7 @@ if not (is_transformers_available() and is_torch_available() and is_note_seq_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: - from .utils import \ - dummy_transformers_and_torch_and_note_seq_objects # noqa F403 + from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403 _import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [ name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_") diff --git a/src/diffusers/quantizers/quantization_config.py b/src/diffusers/quantizers/quantization_config.py index 5d74eb7008cd..fc22e4e65a9a 100644 --- a/src/diffusers/quantizers/quantization_config.py +++ b/src/diffusers/quantizers/quantization_config.py @@ -395,19 +395,29 @@ def to_diff_dict(self) -> Dict[str, Any]: return serializable_config_dict +@dataclass class GGUFQuantizationConfig(QuantizationConfigMixin): - def __init__(self, compute_dtype=None, quant_storage=None, modules_to_not_convert=None): + """This is a config class for GGUF Quantization techniques. + + Args: + compute_dtype: (`torch.dtype`, defaults to `torch.float32`): + This sets the computational type which might be different than the input type. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. + + """ + + def __init__(self, compute_dtype: torch.dtype = None): self.quant_method = QuantizationMethod.GGUF self.compute_dtype = compute_dtype - self.quant_storage = quant_storage self.pre_quantized = True - self.modules_to_not_convert = modules_to_not_convert + + # TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints. + self.modules_to_not_convert = [] if self.compute_dtype is None: self.compute_dtype = torch.float32 - if self.quant_storage is None: - self.quant_storage = torch.uint8 + @dataclass class TorchAoConfig(QuantizationConfigMixin): """This is a config class for torchao quantization/sparsity techniques. diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 40983fe8cae2..3014efebc82e 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -479,6 +479,8 @@ def is_imageio_available(): def is_gguf_available(): return _is_gguf_available + + def is_torchao_available(): return _is_torchao_available @@ -622,7 +624,8 @@ def is_torchao_available(): """ TORCHAO_IMPORT_ERROR = """ -{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install torchao` +{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install +torchao` """ BACKENDS_MAPPING = OrderedDict( diff --git a/src/diffusers/utils/testing_utils.py b/src/diffusers/utils/testing_utils.py index 4753bc4785b5..e5eac05ac4cd 100644 --- a/src/diffusers/utils/testing_utils.py +++ b/src/diffusers/utils/testing_utils.py @@ -487,6 +487,7 @@ def decorator(test_case): correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}." )(test_case) + def require_torchao_version_greater(torchao_version): def decorator(test_case): correct_torchao_version = is_torchao_available() and version.parse(