Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
DN6 committed Dec 17, 2024
1 parent 391b5a9 commit e67c25a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 15 deletions.
14 changes: 5 additions & 9 deletions src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "TorchAoConfig", "GGUFQuantizationConfig"],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
Expand Down Expand Up @@ -428,8 +428,7 @@
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import \
dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403

_import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
Expand All @@ -442,8 +441,7 @@
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import \
dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
from .utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403

_import_structure["utils.dummy_torch_and_transformers_and_sentencepiece_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_sentencepiece_objects) if not name.startswith("_")
Expand All @@ -456,8 +454,7 @@
if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import \
dummy_torch_and_transformers_and_onnx_objects # noqa F403
from .utils import dummy_torch_and_transformers_and_onnx_objects # noqa F403

_import_structure["utils.dummy_torch_and_transformers_and_onnx_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_onnx_objects) if not name.startswith("_")
Expand Down Expand Up @@ -492,8 +489,7 @@
if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import \
dummy_transformers_and_torch_and_note_seq_objects # noqa F403
from .utils import dummy_transformers_and_torch_and_note_seq_objects # noqa F403

_import_structure["utils.dummy_transformers_and_torch_and_note_seq_objects"] = [
name for name in dir(dummy_transformers_and_torch_and_note_seq_objects) if not name.startswith("_")
Expand Down
20 changes: 15 additions & 5 deletions src/diffusers/quantizers/quantization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,19 +395,29 @@ def to_diff_dict(self) -> Dict[str, Any]:
return serializable_config_dict


@dataclass
class GGUFQuantizationConfig(QuantizationConfigMixin):
    """This is a config class for GGUF Quantization techniques.

    Args:
        compute_dtype (`torch.dtype`, defaults to `torch.float32`):
            This sets the computational type which might be different than the input type. For example, inputs might be
            fp32, but computation can be set to bf16 for speedups.
    """

    def __init__(self, compute_dtype: torch.dtype = None):
        # Identifies this config as GGUF-style quantization to the quantizer dispatch.
        self.quant_method = QuantizationMethod.GGUF
        # Consolidated default: fall back to float32 when no dtype is supplied
        # (the original assigned the argument, then patched None afterwards).
        self.compute_dtype = compute_dtype if compute_dtype is not None else torch.float32
        # GGUF checkpoints are always loaded already-quantized; there is no
        # on-the-fly quantization path for this method.
        self.pre_quantized = True

        # TODO: (Dhruv) Add this as an init argument when we can support loading unquantized checkpoints.
        self.modules_to_not_convert = []

@dataclass
class TorchAoConfig(QuantizationConfigMixin):
"""This is a config class for torchao quantization/sparsity techniques.
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,8 @@ def is_imageio_available():

def is_gguf_available():
    """Report whether the optional `gguf` package was found at import time.

    Returns the module-level probe result computed when this module was loaded.
    """
    gguf_ready = _is_gguf_available
    return gguf_ready


def is_torchao_available():
    """Report whether the optional `torchao` package was found at import time.

    Returns the module-level probe result computed when this module was loaded.
    """
    torchao_ready = _is_torchao_available
    return torchao_ready

Expand Down Expand Up @@ -622,7 +624,8 @@ def is_torchao_available():
"""

# Error template shown when a torchao-dependent feature is requested but the
# `torchao` package is not installed; `{0}` is filled with the requesting
# object's name via str.format.
TORCHAO_IMPORT_ERROR = """
{0} requires the torchao library but it was not found in your environment. You can install it with pip: `pip install
torchao`
"""

BACKENDS_MAPPING = OrderedDict(
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,7 @@ def decorator(test_case):
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
)(test_case)


def require_torchao_version_greater(torchao_version):
def decorator(test_case):
correct_torchao_version = is_torchao_available() and version.parse(
Expand Down

0 comments on commit e67c25a

Please sign in to comment.