diff --git a/docs/source/en/using-diffusers/other-formats.md b/docs/source/en/using-diffusers/other-formats.md index 24ac9ced84ce..e662e3940a38 100644 --- a/docs/source/en/using-diffusers/other-formats.md +++ b/docs/source/en/using-diffusers/other-formats.md @@ -240,6 +240,47 @@ Benefits of using a single-file layout include: 1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout. 2. Easier to manage (download and share) a single file. +### DDUF + +> [!WARNING] +> DDUF is an experimental file format and APIs related to it can change in the future. + +DDUF (**D**DUF **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It provides a balance between Diffusers multi-folder format and the widely popular single-file format. + +Learn more details about DDUF on the Hugging Face Hub [documentation](https://huggingface.co/docs/hub/dduf). + +Pass a checkpoint to the `dduf_file` parameter to load it in [`DiffusionPipeline`]. + +```py +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained( + "DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16 +).to("cuda") +image = pipe( + "photo a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5 +).images[0] +image.save("cat.png") +``` + +To save a pipeline as a `.dduf` checkpoint, use the [`~huggingface_hub.export_folder_as_dduf`] utility, which takes care of all the necessary file-level validations. 
+ +```py +from huggingface_hub import export_folder_as_dduf +from diffusers import DiffusionPipeline +import torch + +pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16) + +save_folder = "flux-dev" +pipe.save_pretrained("flux-dev") +export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder) +``` + +> [!TIP] +> Packaging and loading quantized checkpoints in the DDUF format is supported as long as they respect the multi-folder structure. + ## Convert layout and files Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem. diff --git a/setup.py b/setup.py index d696c14ca842..0acdcbbb9c52 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,7 @@ "filelock", "flax>=0.4.1", "hf-doc-builder>=0.3.0", - "huggingface-hub>=0.23.2", + "huggingface-hub>=0.27.0", "requests-mock==1.10.0", "importlib_metadata", "invisible-watermark>=0.2.0", diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index d21ada6fbe60..9dd4f0121a44 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -24,10 +24,10 @@ import re from collections import OrderedDict from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np -from huggingface_hub import create_repo, hf_hub_download +from huggingface_hub import DDUFEntry, create_repo, hf_hub_download from huggingface_hub.utils import ( EntryNotFoundError, RepositoryNotFoundError, @@ -347,6 +347,7 @@ def load_config( _ = kwargs.pop("mirror", None) subfolder = kwargs.pop("subfolder", None) user_agent = kwargs.pop("user_agent", {}) + dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) user_agent = {**user_agent, "file_type": "config"} user_agent = http_user_agent(user_agent) @@ -358,8 +359,15 @@ def load_config( "`self.config_name` is not 
defined. Note that one should not load a config from " "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`" ) - - if os.path.isfile(pretrained_model_name_or_path): + # Custom path for now + if dduf_entries: + if subfolder is not None: + raise ValueError( + "DDUF files only allow for 1 level of directory (e.g. transformer/model1/model.safetensors is not allowed). " + "Please check the DDUF structure" + ) + config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries) + elif os.path.isfile(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): if subfolder is not None and os.path.isfile( @@ -426,10 +434,8 @@ def load_config( f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " f"containing a {cls.config_name} file" ) - try: - # Load config dict - config_dict = cls._dict_from_json_file(config_file) + config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries) commit_hash = extract_commit_hash(config_file) except (json.JSONDecodeError, UnicodeDecodeError): @@ -552,9 +558,14 @@ def extract_init_dict(cls, config_dict, **kwargs): return init_dict, unused_kwargs, hidden_config_dict @classmethod - def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): - with open(json_file, "r", encoding="utf-8") as reader: - text = reader.read() + def _dict_from_json_file( + cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None + ): + if dduf_entries: + text = dduf_entries[json_file].read_text() + else: + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() return json.loads(text) def __repr__(self): @@ -616,6 +627,20 @@ def to_json_file(self, json_file_path: Union[str, os.PathLike]): with open(json_file_path, "w", encoding="utf-8") as writer: writer.write(self.to_json_string()) + @classmethod + def 
_get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]): + # paths inside a DDUF file must always be "/" + config_file = ( + cls.config_name + if pretrained_model_name_or_path == "" + else "/".join([pretrained_model_name_or_path, cls.config_name]) + ) + if config_file not in dduf_entries: + raise ValueError( + f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}" + ) + return config_file + def register_to_config(init): r""" diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index bb5a54f73419..7999368f1417 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -9,7 +9,7 @@ "filelock": "filelock", "flax": "flax>=0.4.1", "hf-doc-builder": "hf-doc-builder>=0.3.0", - "huggingface-hub": "huggingface-hub>=0.23.2", + "huggingface-hub": "huggingface-hub>=0.27.0", "requests-mock": "requests-mock==1.10.0", "importlib_metadata": "importlib_metadata", "invisible-watermark": "invisible-watermark>=0.2.0", diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py index a3d006f18994..386c07e8747c 100644 --- a/src/diffusers/models/model_loading_utils.py +++ b/src/diffusers/models/model_loading_utils.py @@ -20,10 +20,11 @@ from array import array from collections import OrderedDict from pathlib import Path -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import safetensors import torch +from huggingface_hub import DDUFEntry from huggingface_hub.utils import EntryNotFoundError from ..utils import ( @@ -132,7 +133,10 @@ def _fetch_remapped_cls_from_config(config, old_class): def load_state_dict( - checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None, disable_mmap: bool = False + checkpoint_file: Union[str, os.PathLike], + variant: Optional[str] = None, + 
dduf_entries: Optional[Dict[str, DDUFEntry]] = None, + disable_mmap: bool = False, ): """ Reads a checkpoint file, returning properly formatted errors if they arise. @@ -144,6 +148,10 @@ def load_state_dict( try: file_extension = os.path.basename(checkpoint_file).split(".")[-1] if file_extension == SAFETENSORS_FILE_EXTENSION: + if dduf_entries: + # tensors are loaded on cpu + with dduf_entries[checkpoint_file].as_mmap() as mm: + return safetensors.torch.load(mm) if disable_mmap: return safetensors.torch.load(open(checkpoint_file, "rb").read()) else: @@ -284,6 +292,7 @@ def _fetch_index_file( revision, user_agent, commit_hash, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, ): if is_local: index_file = Path( @@ -309,8 +318,10 @@ def _fetch_index_file( subfolder=None, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) - index_file = Path(index_file) + if not dduf_entries: + index_file = Path(index_file) except (EntryNotFoundError, EnvironmentError): index_file = None @@ -319,7 +330,9 @@ def _fetch_index_file( # Adapted from # https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64 -def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata): +def _merge_sharded_checkpoints( + sharded_ckpt_cached_folder, sharded_metadata, dduf_entries: Optional[Dict[str, DDUFEntry]] = None +): weight_map = sharded_metadata.get("weight_map", None) if weight_map is None: raise KeyError("'weight_map' key not found in the shard index file.") @@ -332,14 +345,23 @@ def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata): # Load tensors from each unique file for file_name in files_to_load: part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name) - if not os.path.exists(part_file_path): - raise FileNotFoundError(f"Part file {file_name} not found.") + if dduf_entries: + if part_file_path not in dduf_entries: + raise FileNotFoundError(f"Part file 
{file_name} not found.") + else: + if not os.path.exists(part_file_path): + raise FileNotFoundError(f"Part file {file_name} not found.") if is_safetensors: - with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: - for tensor_key in f.keys(): - if tensor_key in weight_map: - merged_state_dict[tensor_key] = f.get_tensor(tensor_key) + if dduf_entries: + with dduf_entries[part_file_path].as_mmap() as mm: + tensors = safetensors.torch.load(mm) + merged_state_dict.update(tensors) + else: + with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f: + for tensor_key in f.keys(): + if tensor_key in weight_map: + merged_state_dict[tensor_key] = f.get_tensor(tensor_key) else: merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu")) @@ -360,6 +382,7 @@ def _fetch_index_file_legacy( revision, user_agent, commit_hash, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, ): if is_local: index_file = Path( @@ -400,6 +423,7 @@ def _fetch_index_file_legacy( subfolder=None, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) index_file = Path(index_file) deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`." 
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py index 17e9d2043150..fcd7775fb608 100644 --- a/src/diffusers/models/modeling_utils.py +++ b/src/diffusers/models/modeling_utils.py @@ -23,11 +23,11 @@ from collections import OrderedDict from functools import partial, wraps from pathlib import Path -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import safetensors import torch -from huggingface_hub import create_repo, split_torch_state_dict_into_shards +from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards from huggingface_hub.utils import validate_hf_hub_args from torch import Tensor, nn @@ -607,6 +607,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P variant = kwargs.pop("variant", None) use_safetensors = kwargs.pop("use_safetensors", None) quantization_config = kwargs.pop("quantization_config", None) + dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None) disable_mmap = kwargs.pop("disable_mmap", False) allow_pickle = False @@ -700,6 +701,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P revision=revision, subfolder=subfolder, user_agent=user_agent, + dduf_entries=dduf_entries, **kwargs, ) # no in-place modification of the original config. @@ -776,13 +778,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P "revision": revision, "user_agent": user_agent, "commit_hash": commit_hash, + "dduf_entries": dduf_entries, } index_file = _fetch_index_file(**index_file_kwargs) # In case the index file was not found we still have to consider the legacy format. # this becomes applicable when the variant is not None. 
if variant is not None and (index_file is None or not os.path.exists(index_file)): index_file = _fetch_index_file_legacy(**index_file_kwargs) - if index_file is not None and index_file.is_file(): + if index_file is not None and (dduf_entries or index_file.is_file()): is_sharded = True if is_sharded and from_flax: @@ -811,6 +814,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P model = load_flax_checkpoint_in_pytorch_model(model, model_file) else: + # in the case it is sharded, we have already the index if is_sharded: sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files( pretrained_model_name_or_path, @@ -822,10 +826,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P user_agent=user_agent, revision=revision, subfolder=subfolder or "", + dduf_entries=dduf_entries, ) # TODO: https://github.com/huggingface/diffusers/issues/10013 - if hf_quantizer is not None: - model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata) + if hf_quantizer is not None or dduf_entries: + model_file = _merge_sharded_checkpoints( + sharded_ckpt_cached_folder, sharded_metadata, dduf_entries=dduf_entries + ) logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.") is_sharded = False @@ -843,6 +850,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) except IOError as e: @@ -866,6 +874,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P subfolder=subfolder, user_agent=user_agent, commit_hash=commit_hash, + dduf_entries=dduf_entries, ) if low_cpu_mem_usage: @@ -887,7 +896,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # TODO (sayakpaul, SunMarc): remove this after model loading refactor else: param_device = torch.device(torch.cuda.current_device()) - 
state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap) + state_dict = load_state_dict( + model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap + ) model._convert_deprecated_attention_blocks(state_dict) # move the params from meta device to cpu @@ -983,7 +994,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P else: model = cls.from_config(config, **unused_kwargs) - state_dict = load_state_dict(model_file, variant=variant, disable_mmap=disable_mmap) + state_dict = load_state_dict( + model_file, variant=variant, dduf_entries=dduf_entries, disable_mmap=disable_mmap + ) model._convert_deprecated_attention_blocks(state_dict) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 23f1279e203d..a100dfe77bdf 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -12,19 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - import importlib import os import re import warnings from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union +import requests import torch -from huggingface_hub import ModelCard, model_info -from huggingface_hub.utils import validate_hf_hub_args +from huggingface_hub import DDUFEntry, ModelCard, model_info, snapshot_download +from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args from packaging import version +from requests.exceptions import HTTPError from .. 
import __version__ from ..utils import ( @@ -38,14 +38,16 @@ is_accelerate_available, is_peft_available, is_transformers_available, + is_transformers_version, logging, ) from ..utils.torch_utils import is_compiled_module +from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transformers_model_from_dduf if is_transformers_available(): import transformers - from transformers import PreTrainedModel + from transformers import PreTrainedModel, PreTrainedTokenizerBase from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME @@ -627,6 +629,7 @@ def load_sub_model( low_cpu_mem_usage: bool, cached_folder: Union[str, os.PathLike], use_safetensors: bool, + dduf_entries: Optional[Dict[str, DDUFEntry]], ): """Helper method to load the module `name` from `library_name` and `class_name`""" @@ -663,7 +666,7 @@ def load_sub_model( f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}." 
) - load_method = getattr(class_obj, load_method_name) + load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None) # add kwargs to loading method diffusers_module = importlib.import_module(__name__.split(".")[0]) @@ -721,7 +724,10 @@ def load_sub_model( loading_kwargs["low_cpu_mem_usage"] = False # check if the module is in a subdirectory - if os.path.isdir(os.path.join(cached_folder, name)): + if dduf_entries: + loading_kwargs["dduf_entries"] = dduf_entries + loaded_sub_model = load_method(name, **loading_kwargs) + elif os.path.isdir(os.path.join(cached_folder, name)): loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs) else: # else load from the root directory @@ -746,6 +752,22 @@ def load_sub_model( return loaded_sub_model +def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable: + """ + Return the method to load the sub model. + + In practice, this method will return the `"from_pretrained"` (or `load_method_name`) method of the class object + except if loading from a DDUF checkpoint. In that case, transformers models and tokenizers have a specific loading + method that we need to use. 
+ """ + if is_dduf: + if issubclass(class_obj, PreTrainedTokenizerBase): + return lambda *args, **kwargs: _load_tokenizer_from_dduf(class_obj, *args, **kwargs) + if issubclass(class_obj, PreTrainedModel): + return lambda *args, **kwargs: _load_transformers_model_from_dduf(class_obj, *args, **kwargs) + return getattr(class_obj, load_method_name) + + def _fetch_class_library_tuple(module): # import it here to avoid circular import diffusers_module = importlib.import_module(__name__.split(".")[0]) @@ -968,3 +990,70 @@ def _get_ignore_patterns( ) return ignore_patterns + + +def _download_dduf_file( + pretrained_model_name: str, + dduf_file: str, + pipeline_class_name: str, + cache_dir: str, + proxies: str, + local_files_only: bool, + token: str, + revision: str, +): + model_info_call_error = None + if not local_files_only: + try: + info = model_info(pretrained_model_name, token=token, revision=revision) + except (HTTPError, OfflineModeIsEnabled, requests.ConnectionError) as e: + logger.warning(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.") + local_files_only = True + model_info_call_error = e # save error to reraise it if model is not cached locally + + if ( + not local_files_only + and dduf_file is not None + and dduf_file not in (sibling.rfilename for sibling in info.siblings) + ): + raise ValueError(f"Requested {dduf_file} file is not available in {pretrained_model_name}.") + + try: + user_agent = {"pipeline_class": pipeline_class_name, "dduf": True} + cached_folder = snapshot_download( + pretrained_model_name, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=[dduf_file], + user_agent=user_agent, + ) + return cached_folder + except FileNotFoundError: + # Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache. + # This can happen in two cases: + # 1. 
If the user passed `local_files_only=True` => we raise the error directly + # 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error + if model_info_call_error is None: + # 1. user passed `local_files_only=True` + raise + else: + # 2. we forced `local_files_only=True` when `model_info` failed + raise EnvironmentError( + f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred" + " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace" + " above." + ) from model_info_call_error + + +def _maybe_raise_error_for_incorrect_transformers(config_dict): + has_transformers_component = False + for k in config_dict: + if isinstance(config_dict[k], list): + has_transformers_component = config_dict[k][0] == "transformers" + if has_transformers_component: + break + if has_transformers_component and not is_transformers_version(">", "4.47.1"): + raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.") diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index 527724d1de1a..3cafb77e5d63 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -29,10 +29,12 @@ import requests import torch from huggingface_hub import ( + DDUFEntry, ModelCard, create_repo, hf_hub_download, model_info, + read_dduf_file, snapshot_download, ) from huggingface_hub.utils import OfflineModeIsEnabled, validate_hf_hub_args @@ -72,6 +74,7 @@ CONNECTED_PIPES_KEYS, CUSTOM_PIPELINE_FILE_NAME, LOADABLE_CLASSES, + _download_dduf_file, _fetch_class_library_tuple, _get_custom_components_and_folders, _get_custom_pipeline_class, @@ -79,6 +82,7 @@ _get_ignore_patterns, _get_pipeline_class, _identify_model_variants, + _maybe_raise_error_for_incorrect_transformers, _maybe_raise_warning_for_inpainting, _resolve_custom_pipeline_and_cls, _unwrap_model, @@ -218,6 +222,7 
@@ class implements both a save and loading method. The pipeline is easily reloaded Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the repository you want to push to with `repo_id` (will default to the name of `save_directory` in your namespace). + kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ @@ -531,6 +536,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P - A path to a *directory* (for example `./my_pipeline_directory/`) containing pipeline weights saved using [`~DiffusionPipeline.save_pretrained`]. + - A path to a *directory* (for example `./my_pipeline_directory/`) containing a dduf file torch_dtype (`str` or `torch.dtype`, *optional*): Override the default `torch.dtype` and load the model with another dtype. If "auto" is passed, the dtype is automatically derived from the model's weights. @@ -625,6 +631,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. + dduf_file(`str`, *optional*): + Load weights from the specified dduf file. @@ -674,6 +682,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P offload_state_dict = kwargs.pop("offload_state_dict", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) variant = kwargs.pop("variant", None) + dduf_file = kwargs.pop("dduf_file", None) use_safetensors = kwargs.pop("use_safetensors", None) use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) @@ -722,6 +731,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P " dispatching. Please make sure to set `low_cpu_mem_usage=True`." 
) + if dduf_file: + if custom_pipeline: + raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.") + if load_connected_pipeline: + raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.") + # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): @@ -744,6 +759,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P custom_pipeline=custom_pipeline, custom_revision=custom_revision, variant=variant, + dduf_file=dduf_file, load_connected_pipeline=load_connected_pipeline, **kwargs, ) @@ -765,7 +781,17 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P ) logger.warning(warn_msg) - config_dict = cls.load_config(cached_folder) + dduf_entries = None + if dduf_file: + dduf_file_path = os.path.join(cached_folder, dduf_file) + dduf_entries = read_dduf_file(dduf_file_path) + # The reader contains already all the files needed, no need to check it again + cached_folder = "" + + config_dict = cls.load_config(cached_folder, dduf_entries=dduf_entries) + + if dduf_file: + _maybe_raise_error_for_incorrect_transformers(config_dict) # pop out "_ignore_files" as it is only needed for download config_dict.pop("_ignore_files", None) @@ -943,6 +969,7 @@ def load_module(name, value): low_cpu_mem_usage=low_cpu_mem_usage, cached_folder=cached_folder, use_safetensors=use_safetensors, + dduf_entries=dduf_entries, ) logger.info( f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}." @@ -1256,6 +1283,8 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: variant (`str`, *optional*): Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when loading `from_flax`. + dduf_file(`str`, *optional*): + Load weights from the specified DDUF file. 
use_safetensors (`bool`, *optional*, defaults to `None`): If set to `None`, the safetensors weights are downloaded if they're available **and** if the safetensors library is installed. If set to `True`, the model is forcibly loaded from safetensors @@ -1296,6 +1325,23 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: use_onnx = kwargs.pop("use_onnx", None) load_connected_pipeline = kwargs.pop("load_connected_pipeline", False) trust_remote_code = kwargs.pop("trust_remote_code", False) + dduf_file: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_file", None) + + if dduf_file: + if custom_pipeline: + raise NotImplementedError("Custom pipelines are not supported with DDUF at the moment.") + if load_connected_pipeline: + raise NotImplementedError("Connected pipelines are not supported with DDUF at the moment.") + return _download_dduf_file( + pretrained_model_name=pretrained_model_name, + dduf_file=dduf_file, + pipeline_class_name=cls.__name__, + cache_dir=cache_dir, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + ) allow_pickle = False if use_safetensors is None: @@ -1375,7 +1421,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: allow_patterns += [f"{custom_pipeline}.py"] if f"{custom_pipeline}.py" in filenames else [] # also allow downloading config.json files with the model allow_patterns += [os.path.join(k, "config.json") for k in model_folder_names] - allow_patterns += [ SCHEDULER_CONFIG_NAME, CONFIG_NAME, @@ -1471,7 +1516,6 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]: user_agent=user_agent, ) - # retrieve pipeline class from local file cls_name = cls.load_config(os.path.join(cached_folder, "model_index.json")).get("_class_name", None) cls_name = cls_name[4:] if isinstance(cls_name, str) and cls_name.startswith("Flax") else cls_name diff --git a/src/diffusers/pipelines/transformers_loading_utils.py 
def _load_tokenizer_from_dduf(
    cls: "PreTrainedTokenizer", name: str, dduf_entries: Dict[str, "DDUFEntry"], **kwargs
) -> "PreTrainedTokenizer":
    """
    Load a tokenizer from a DDUF archive.

    `transformers` does not provide a way to load a tokenizer directly from a DDUF archive. As a workaround, every
    archive entry located under `{name}/` is written to a temporary directory and the tokenizer is loaded from that
    directory. The extraction has a cost, but tokenizer files are usually small.

    Args:
        cls: Class exposing `from_pretrained(path, **kwargs)`, used to build the tokenizer.
        name: Name of the component folder inside the archive (e.g. `"tokenizer"`).
        dduf_entries: Mapping from archive entry name to entry; each used entry must provide `as_mmap()`.
        **kwargs: Forwarded unchanged to `cls.from_pretrained`.

    Returns:
        Whatever `cls.from_pretrained` returns for the extracted folder.

    Raises:
        EnvironmentError: If the archive holds no entry under `{name}/`.
    """
    prefix = name + "/"
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Folder handed to `from_pretrained`. It is computed up front rather than derived from the last entry
        # written, so it is well-defined regardless of the iteration order (or absence) of matching entries.
        tokenizer_dir = os.path.join(tmp_dir, name)
        extracted_any = False
        for entry_name, entry in dduf_entries.items():
            if not entry_name.startswith(prefix):
                continue
            tmp_entry_path = os.path.join(tmp_dir, *entry_name.split("/"))
            # need to create intermediary directory if they don't exist
            os.makedirs(os.path.dirname(tmp_entry_path), exist_ok=True)
            with open(tmp_entry_path, "wb") as f, entry.as_mmap() as mm:
                f.write(mm)
            extracted_any = True
        if not extracted_any:
            # Previously this case surfaced as an UnboundLocalError on `tmp_entry_path`.
            raise EnvironmentError(
                f"Could not find any tokenizer file for component {name} in DDUF file (contains {dduf_entries.keys()})."
            )
        # Load while the temporary directory still exists.
        return cls.from_pretrained(tokenizer_dir, **kwargs)


def _load_transformers_model_from_dduf(
    cls: "PreTrainedModel", name: str, dduf_entries: Dict[str, "DDUFEntry"], **kwargs
) -> "PreTrainedModel":
    """
    Load a transformers model from a DDUF archive.

    `transformers` does not provide a way to load a model directly from a DDUF archive. As a workaround, the model is
    instantiated from its config file and the weights are read from the `.safetensors` entries of the archive.

    Args:
        cls: Model class exposing `from_pretrained`.
        name: Name of the component folder inside the archive (e.g. `"text_encoder"`).
        dduf_entries: Mapping from archive entry name to entry.
        **kwargs: Forwarded unchanged to `cls.from_pretrained`.

    Returns:
        Whatever `cls.from_pretrained` returns for the assembled config and state dict.

    Raises:
        EnvironmentError: If `{name}/config.json` is missing, if no `.safetensors` entry exists under `{name}/`, or
            if `safetensors` is not installed.
        ImportError: If the installed `transformers` is older than 4.47.0.
    """
    config_file = dduf_entries.get(f"{name}/config.json")
    if config_file is None:
        raise EnvironmentError(
            f"Could not find a config.json file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    generation_config = dduf_entries.get(f"{name}/generation_config.json", None)

    weight_files = [
        entry
        for entry_name, entry in dduf_entries.items()
        if entry_name.startswith(f"{name}/") and entry_name.endswith(".safetensors")
    ]
    if not weight_files:
        raise EnvironmentError(
            f"Could not find any weight file for component {name} in DDUF file (contains {dduf_entries.keys()})."
        )
    if not is_safetensors_available():
        raise EnvironmentError(
            "Safetensors is not available, cannot load model from DDUF. Please `pip install safetensors`."
        )
    if is_transformers_version("<", "4.47.0"):
        # The check accepts 4.47.0 itself, so the requirement is `>=`, not `>`.
        raise ImportError(
            "You need to install `transformers>=4.47.0` in order to load a transformers model from a DDUF file. "
            "You can install it with: `pip install --upgrade transformers`"
        )

    with tempfile.TemporaryDirectory() as tmp_dir:
        from transformers import AutoConfig, GenerationConfig

        # The config classes are loaded from a path here, so the JSON entries are first copied to disk.
        # Write them as UTF-8 explicitly instead of relying on the platform default encoding.
        tmp_config_file = os.path.join(tmp_dir, "config.json")
        with open(tmp_config_file, "w", encoding="utf-8") as f:
            f.write(config_file.read_text())
        config = AutoConfig.from_pretrained(tmp_config_file)
        if generation_config is not None:
            tmp_generation_config_file = os.path.join(tmp_dir, "generation_config.json")
            with open(tmp_generation_config_file, "w", encoding="utf-8") as f:
                f.write(generation_config.read_text())
            generation_config = GenerationConfig.from_pretrained(tmp_generation_config_file)
        state_dict = {}
        with contextlib.ExitStack() as stack:
            for entry in tqdm(weight_files, desc="Loading state_dict"):  # Loop over safetensors files
                # Memory-map the safetensors file
                mmap = stack.enter_context(entry.as_mmap())
                # Load tensors from the memory-mapped file
                tensors = safetensors.torch.load(mmap)
                # Update the state dictionary with tensors
                state_dict.update(tensors)
            # The model is built before the stack closes the memory maps.
            # NOTE(review): presumably the tensors may still reference the mapped memory - confirm.
            return cls.from_pretrained(
                pretrained_model_name_or_path=None,
                config=config,
                generation_config=generation_config,
                state_dict=state_dict,
                **kwargs,
            )
revision: Optional[str] = None, commit_hash: Optional[str] = None, + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, ): pretrained_model_name_or_path = str(pretrained_model_name_or_path) - if os.path.isfile(pretrained_model_name_or_path): + + if dduf_entries: + if subfolder is not None: + raise ValueError( + "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). " + "Please check the DDUF structure" + ) + model_file = ( + weights_name + if pretrained_model_name_or_path == "" + else "/".join([pretrained_model_name_or_path, weights_name]) + ) + if model_file in dduf_entries: + return model_file + else: + raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_entries.keys()}.") + elif os.path.isfile(pretrained_model_name_or_path): return pretrained_model_name_or_path elif os.path.isdir(pretrained_model_name_or_path): if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): @@ -419,6 +437,7 @@ def _get_checkpoint_shard_files( user_agent=None, revision=None, subfolder="", + dduf_entries: Optional[Dict[str, DDUFEntry]] = None, ): """ For a given model: @@ -430,11 +449,18 @@ def _get_checkpoint_shard_files( For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). 
""" - if not os.path.isfile(index_filename): - raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + if dduf_entries: + if index_filename not in dduf_entries: + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") + else: + if not os.path.isfile(index_filename): + raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") - with open(index_filename, "r") as f: - index = json.loads(f.read()) + if dduf_entries: + index = json.loads(dduf_entries[index_filename].read_text()) + else: + with open(index_filename, "r") as f: + index = json.loads(f.read()) original_shard_filenames = sorted(set(index["weight_map"].values())) sharded_metadata = index["metadata"] @@ -448,6 +474,8 @@ def _get_checkpoint_shard_files( pretrained_model_name_or_path, subfolder=subfolder, original_shard_filenames=original_shard_filenames ) return shards_path, sharded_metadata + elif dduf_entries: + return shards_path, sharded_metadata # At this stage pretrained_model_name_or_path is a model identifier on the Hub allow_patterns = original_shard_filenames diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index 3014efebc82e..c7d002651f3a 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -115,6 +115,13 @@ except importlib_metadata.PackageNotFoundError: _transformers_available = False +_hf_hub_available = importlib.util.find_spec("huggingface_hub") is not None +try: + _hf_hub_version = importlib_metadata.version("huggingface_hub") + logger.debug(f"Successfully imported huggingface_hub version {_hf_hub_version}") +except importlib_metadata.PackageNotFoundError: + _hf_hub_available = False + _inflect_available = importlib.util.find_spec("inflect") is not None try: @@ -767,6 +774,21 @@ def is_transformers_version(operation: str, version: str): return 
def is_hf_hub_version(operation: str, version: str):
    """
    Check the installed `huggingface_hub` version against a reference version.

    Args:
        operation (`str`):
            Comparison operator given as a string, for instance `">"` or `"<="`
        version (`str`):
            Reference version string to compare against

    Returns `False` when `huggingface_hub` was not detected, otherwise the result of the comparison.
    """
    return compare_versions(parse(_hf_hub_version), operation, version) if _hf_hub_available else False
def require_hf_hub_version_greater(hf_hub_version):
    """
    Build a decorator that skips a test unless the installed `huggingface_hub` is strictly newer than
    `hf_hub_version`. The comparison uses the base version of the installed package.
    """

    def decorator(test_case):
        installed = version.parse(importlib.metadata.version("huggingface_hub"))
        # Re-parse the base version so that only the release part takes part in the comparison.
        is_newer = version.parse(installed.base_version) > version.parse(hf_hub_version)
        skip_reason = f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
        return unittest.skipUnless(is_newer, skip_reason)(test_case)

    return decorator
    @require_hf_hub_version_greater("0.26.5")
    @require_transformers_version_greater("4.47.1")
    def test_save_load_dduf(self):
        # Round-trip check: run the pipeline, export it to a DDUF archive, reload it from that archive and
        # compare the outputs of both runs.
        # reimplement because it needs `enable_tiling()` on the loaded pipe.
        from huggingface_hub import export_folder_as_dduf

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # Replace the generator provided by the dummy inputs with one seeded to 0.
        inputs = self.get_dummy_inputs(device="cpu")
        inputs.pop("generator")
        inputs["generator"] = torch.manual_seed(0)

        pipeline_out = pipe(**inputs)[0].cpu()

        with tempfile.TemporaryDirectory() as tmpdir:
            # The archive is written into the same temporary folder that holds the saved pipeline.
            dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf")
            pipe.save_pretrained(tmpdir, safe_serialization=True)
            export_folder_as_dduf(dduf_filename, folder_path=tmpdir)
            loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device)

        loaded_pipe.vae.enable_tiling()
        # Re-seed so the second run starts from the same generator state as the first one.
        inputs["generator"] = torch.manual_seed(0)
        loaded_pipeline_out = loaded_pipe(**inputs)[0].cpu()

        assert np.allclose(pipeline_out, loaded_pipeline_out)
100644 --- a/tests/pipelines/blipdiffusion/test_blipdiffusion.py +++ b/tests/pipelines/blipdiffusion/test_blipdiffusion.py @@ -60,6 +60,8 @@ class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase): "prompt_reps", ] + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) text_encoder_config = CLIPTextConfig( diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py index b12655d989d4..fc8ea5284ccc 100644 --- a/tests/pipelines/controlnet/test_controlnet.py +++ b/tests/pipelines/controlnet/test_controlnet.py @@ -291,6 +291,8 @@ class StableDiffusionMultiControlNetPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( @@ -523,6 +525,8 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py index 99a238caf53a..b4d3e3aaa8ed 100644 --- a/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py +++ b/tests/pipelines/controlnet/test_controlnet_blip_diffusion.py @@ -68,6 +68,8 @@ class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.Tes "prompt_reps", ] + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) text_encoder_config = CLIPTextConfig( diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py index 7c4ae716b37d..516fcc513b99 100644 --- 
a/tests/pipelines/controlnet/test_controlnet_img2img.py +++ b/tests/pipelines/controlnet/test_controlnet_img2img.py @@ -198,6 +198,8 @@ class StableDiffusionMultiControlNetPipelineFastTests( batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py index e49106334c2e..0e4dba4265e2 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py @@ -257,6 +257,8 @@ class MultiControlNetInpaintPipelineFastTests( params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py index d2c63137c99e..6e752804e2e0 100644 --- a/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_inpaint_sdxl.py @@ -78,6 +78,8 @@ class ControlNetPipelineSDXLFastTests( } ) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index ea7fff5537a5..fc15973faeaf 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -487,6 +487,8 @@ class StableDiffusionXLMultiControlNetPipelineFastTests( batch_params = TEXT_TO_IMAGE_BATCH_PARAMS image_params = frozenset([]) # TO_DO: add image_params once refactored VaeImageProcessor.preprocess + supports_dduf = 
    @require_hf_hub_version_greater("0.26.5")
    @require_transformers_version_greater("4.47.1")
    def test_save_load_dduf(self):
        # Run the inherited DDUF save/load test with both `atol` and `rtol` set to 1e-2.
        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
    @require_hf_hub_version_greater("0.26.5")
    @require_transformers_version_greater("4.47.1")
    def test_save_load_dduf(self):
        # Run the inherited DDUF save/load test with both `atol` and `rtol` set to 1e-2.
        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
    @require_hf_hub_version_greater("0.26.5")
    @require_transformers_version_greater("4.47.1")
    def test_save_load_dduf(self):
        # Run the inherited DDUF save/load test with both `atol` and `rtol` set to 1e-2.
        super().test_save_load_dduf(atol=1e-2, rtol=1e-2)
required_optional_params = frozenset(["num_inference_steps", "generator", "latents", "return_dict"]) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) scheduler = DDIMScheduler( diff --git a/tests/pipelines/kandinsky/test_kandinsky.py b/tests/pipelines/kandinsky/test_kandinsky.py index 8553ed96e9e1..1a13ec75d082 100644 --- a/tests/pipelines/kandinsky/test_kandinsky.py +++ b/tests/pipelines/kandinsky/test_kandinsky.py @@ -204,6 +204,8 @@ class KandinskyPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = Dummies() return dummy.get_dummy_components() diff --git a/tests/pipelines/kandinsky/test_kandinsky_combined.py b/tests/pipelines/kandinsky/test_kandinsky_combined.py index a7f861565cc9..3c8767a708d4 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky/test_kandinsky_combined.py @@ -52,6 +52,8 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase) ] test_xformers_attention = True + supports_dduf = False + def get_dummy_components(self): dummy = Dummies() prior_dummy = PriorDummies() @@ -160,6 +162,8 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = Img2ImgDummies() prior_dummy = PriorDummies() @@ -269,6 +273,8 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = InpaintDummies() prior_dummy = PriorDummies() diff --git a/tests/pipelines/kandinsky/test_kandinsky_img2img.py b/tests/pipelines/kandinsky/test_kandinsky_img2img.py index ea289c5ccd71..23f13ffee223 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_img2img.py +++ b/tests/pipelines/kandinsky/test_kandinsky_img2img.py @@ -226,6 +226,8 @@ 
class KandinskyImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummies = Dummies() return dummies.get_dummy_components() diff --git a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py index 740046678744..ebb1a4d88739 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_inpaint.py +++ b/tests/pipelines/kandinsky/test_kandinsky_inpaint.py @@ -220,6 +220,8 @@ class KandinskyInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummies = Dummies() return dummies.get_dummy_components() diff --git a/tests/pipelines/kandinsky/test_kandinsky_prior.py b/tests/pipelines/kandinsky/test_kandinsky_prior.py index 5f42447bd9d5..abb53bfb792f 100644 --- a/tests/pipelines/kandinsky/test_kandinsky_prior.py +++ b/tests/pipelines/kandinsky/test_kandinsky_prior.py @@ -184,6 +184,8 @@ class KandinskyPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = Dummies() return dummy.get_dummy_components() diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py index dbba0831397b..bbf2f08a7b08 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_combined.py @@ -57,6 +57,8 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa test_xformers_attention = True callback_cfg_params = ["image_embds"] + supports_dduf = False + def get_dummy_components(self): dummy = Dummies() prior_dummy = PriorDummies() @@ -181,6 +183,8 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest test_xformers_attention = False callback_cfg_params = 
["image_embds"] + supports_dduf = False + def get_dummy_components(self): dummy = Img2ImgDummies() prior_dummy = PriorDummies() @@ -302,6 +306,8 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest ] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummy = InpaintDummies() prior_dummy = PriorDummies() diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py index be0bc238d4da..bdec6c132f80 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior.py @@ -186,6 +186,8 @@ class KandinskyV22PriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase) callback_cfg_params = ["prompt_embeds", "text_encoder_hidden_states", "text_mask"] test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): dummies = Dummies() return dummies.get_dummy_components() diff --git a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py index e898824e2d17..0ea32981d518 100644 --- a/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py +++ b/tests/pipelines/kandinsky2_2/test_kandinsky_prior_emb2emb.py @@ -59,6 +59,8 @@ class KandinskyV22PriorEmb2EmbPipelineFastTests(PipelineTesterMixin, unittest.Te ] test_xformers_attention = False + supports_dduf = False + @property def text_embedder_hidden_size(self): return 32 diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py index de44af6d5908..e88ba0282096 100644 --- a/tests/pipelines/kolors/test_kolors.py +++ b/tests/pipelines/kolors/test_kolors.py @@ -47,6 +47,8 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + supports_dduf = False + def 
get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/kolors/test_kolors_img2img.py b/tests/pipelines/kolors/test_kolors_img2img.py index 2010dbd7055a..9f1ca43a081f 100644 --- a/tests/pipelines/kolors/test_kolors_img2img.py +++ b/tests/pipelines/kolors/test_kolors_img2img.py @@ -51,6 +51,8 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase): image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + supports_dduf = False + # Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/lumina/test_lumina_nextdit.py b/tests/pipelines/lumina/test_lumina_nextdit.py index 5fd0dbf06050..e0fd06847b77 100644 --- a/tests/pipelines/lumina/test_lumina_nextdit.py +++ b/tests/pipelines/lumina/test_lumina_nextdit.py @@ -31,6 +31,8 @@ class LuminaText2ImgPipelinePipelineFastTests(unittest.TestCase, PipelineTesterM ) batch_params = frozenset(["prompt", "negative_prompt"]) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) transformer = LuminaNextDiT2DModel( diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py index e51f5103933a..bdd536b6ff86 100644 --- a/tests/pipelines/musicldm/test_musicldm.py +++ b/tests/pipelines/musicldm/test_musicldm.py @@ -65,6 +65,8 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] ) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py index 8cfb2c3fd16a..cf9466988d85 100644 --- a/tests/pipelines/pag/test_pag_kolors.py +++ b/tests/pipelines/pag/test_pag_kolors.py @@ -56,6 +56,8 @@ 
class KolorsPAGPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"add_text_embeds", "add_time_ids"}) + supports_dduf = False + # Copied from tests.pipelines.kolors.test_kolors.KolorsPipelineFastTests.get_dummy_components def get_dummy_components(self, time_cond_proj_dim=None): torch.manual_seed(0) diff --git a/tests/pipelines/pag/test_pag_sana.py b/tests/pipelines/pag/test_pag_sana.py index 12addabeb0a8..a2c657297860 100644 --- a/tests/pipelines/pag/test_pag_sana.py +++ b/tests/pipelines/pag/test_pag_sana.py @@ -53,6 +53,8 @@ class SanaPAGPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) test_xformers_attention = False + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) transformer = SanaTransformer2DModel( diff --git a/tests/pipelines/pag/test_pag_sdxl_img2img.py b/tests/pipelines/pag/test_pag_sdxl_img2img.py index 7e5fc5fa28b9..33bd47bfee10 100644 --- a/tests/pipelines/pag/test_pag_sdxl_img2img.py +++ b/tests/pipelines/pag/test_pag_sdxl_img2img.py @@ -82,6 +82,8 @@ class StableDiffusionXLPAGImg2ImgPipelineFastTests( {"add_text_embeds", "add_time_ids", "add_neg_time_ids"} ) + supports_dduf = False + # based on tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_img2img_pipeline.get_dummy_components def get_dummy_components( self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False diff --git a/tests/pipelines/pag/test_pag_sdxl_inpaint.py b/tests/pipelines/pag/test_pag_sdxl_inpaint.py index efc37abd0682..8378b07e9f74 100644 --- a/tests/pipelines/pag/test_pag_sdxl_inpaint.py +++ b/tests/pipelines/pag/test_pag_sdxl_inpaint.py @@ -82,6 +82,8 @@ class StableDiffusionXLPAGInpaintPipelineFastTests( {"add_text_embeds", "add_time_ids", "mask", "masked_image_latents"} ) + supports_dduf = False + # based on 
tests.pipelines.stable_diffusion_xl.test_stable_diffusion_xl_inpaint.StableDiffusionXLInpaintPipelineFastTests.get_dummy_components def get_dummy_components( self, skip_first_text_encoder=False, time_cond_proj_dim=None, requires_aesthetics_score=False diff --git a/tests/pipelines/paint_by_example/test_paint_by_example.py b/tests/pipelines/paint_by_example/test_paint_by_example.py index c71e2d4761c2..6b668de2762a 100644 --- a/tests/pipelines/paint_by_example/test_paint_by_example.py +++ b/tests/pipelines/paint_by_example/test_paint_by_example.py @@ -46,6 +46,8 @@ class PaintByExamplePipelineFastTests(PipelineTesterMixin, unittest.TestCase): batch_params = IMAGE_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS image_params = frozenset([]) # TO_DO: update the image_prams once refactored VaeImageProcessor.preprocess + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/shap_e/test_shap_e_img2img.py b/tests/pipelines/shap_e/test_shap_e_img2img.py index f3661355e9dd..ac7096874b31 100644 --- a/tests/pipelines/shap_e/test_shap_e_img2img.py +++ b/tests/pipelines/shap_e/test_shap_e_img2img.py @@ -50,6 +50,8 @@ class ShapEImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ] test_xformers_attention = False + supports_dduf = False + @property def text_embedder_hidden_size(self): return 16 diff --git a/tests/pipelines/stable_audio/test_stable_audio.py b/tests/pipelines/stable_audio/test_stable_audio.py index 41ac94891c6f..b2ca3ddd0e84 100644 --- a/tests/pipelines/stable_audio/test_stable_audio.py +++ b/tests/pipelines/stable_audio/test_stable_audio.py @@ -70,6 +70,7 @@ class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ) # There is not xformers version of the StableAudioPipeline custom attention processor test_xformers_attention = False + supports_dduf = False def get_dummy_components(self): torch.manual_seed(0) diff --git 
a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py index 01a0a3abe4ee..430d99781a25 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_depth.py @@ -76,6 +76,8 @@ class StableDiffusionDepth2ImgPipelineFastTests( image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS.union({"depth_mask"}) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py index 2a1e691e9e8f..15f298c67e11 100644 --- a/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py +++ b/tests/pipelines/stable_diffusion_adapter/test_stable_diffusion_adapter.py @@ -389,6 +389,8 @@ def test_stable_diffusion_adapter_default_case(self): class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase): + supports_dduf = False + def get_dummy_components(self, time_cond_proj_dim=None): return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim) diff --git a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py index 748702541b1e..15e4c60db82d 100644 --- a/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py +++ b/tests/pipelines/stable_diffusion_gligen_text_image/test_stable_diffusion_gligen_text_image.py @@ -66,6 +66,8 @@ class GligenTextImagePipelineFastTests( image_params = TEXT_TO_IMAGE_IMAGE_PARAMS image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + supports_dduf = False + def get_dummy_components(self): 
torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py index 7a3b0f70ccb1..d7567afdee1f 100644 --- a/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py +++ b/tests/pipelines/stable_diffusion_image_variation/test_stable_diffusion_image_variation.py @@ -58,6 +58,8 @@ class StableDiffusionImageVariationPipelineFastTests( # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess image_latents_params = frozenset([]) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 7c7b03786563..23291b0407aa 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -422,6 +422,8 @@ def test_adapter_sdxl_lcm_custom_timesteps(self): class StableDiffusionXLMultiAdapterPipelineFastTests( StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase ): + supports_dduf = False + def get_dummy_components(self, time_cond_proj_dim=None): return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index db0905a48310..ceec86a811c0 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -77,6 +77,8 @@ class StableDiffusionXLImg2ImgPipelineFastTests( {"add_text_embeds", "add_time_ids", 
"add_neg_time_ids"} ) + supports_dduf = False + def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py index 964c7123dd32..c759f4c112d9 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py @@ -72,6 +72,8 @@ class StableDiffusionXLInpaintPipelineFastTests( } ) + supports_dduf = False + def get_dummy_components(self, skip_first_text_encoder=False, time_cond_proj_dim=None): torch.manual_seed(0) unet = UNet2DConditionModel( diff --git a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py index a5cbf7761501..34f2553a9184 100644 --- a/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py +++ b/tests/pipelines/stable_unclip/test_stable_unclip_img2img.py @@ -51,6 +51,8 @@ class StableUnCLIPImg2ImgPipelineFastTests( ) # TO-DO: update image_params once pipeline is refactored with VaeImageProcessor.preprocess image_latents_params = frozenset([]) + supports_dduf = False + def get_dummy_components(self): embedder_hidden_size = 32 embedder_projection_dim = embedder_hidden_size diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py index ac9acb26afd3..352477ecec56 100644 --- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -58,6 +58,8 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa ] ) + supports_dduf = False + def get_dummy_components(self): torch.manual_seed(0) unet = UNetSpatioTemporalConditionModel( diff --git 
a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 423c82e0602e..6665a005ba96 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -75,9 +75,11 @@ nightly, require_compel, require_flax, + require_hf_hub_version_greater, require_onnxruntime, require_torch_2, require_torch_gpu, + require_transformers_version_greater, run_test_in_subprocess, slow, torch_device, @@ -981,6 +983,18 @@ def test_download_ignore_files(self): assert not any(f in ["vae/diffusion_pytorch_model.bin", "text_encoder/config.json"] for f in files) assert len(files) == 14 + def test_download_dduf_with_custom_pipeline_raises_error(self): + with self.assertRaises(NotImplementedError): + _ = DiffusionPipeline.download( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline" + ) + + def test_download_dduf_with_connected_pipeline_raises_error(self): + with self.assertRaises(NotImplementedError): + _ = DiffusionPipeline.download( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True + ) + def test_get_pipeline_class_from_flax(self): flax_config = {"_class_name": "FlaxStableDiffusionPipeline"} config = {"_class_name": "StableDiffusionPipeline"} @@ -1802,6 +1816,55 @@ def test_pipe_same_device_id_offload(self): sd.maybe_free_model_hooks() assert sd._offload_gpu_id == 5 + @parameterized.expand([torch.float32, torch.float16]) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_load_dduf_from_hub(self, dtype): + with tempfile.TemporaryDirectory() as tmpdir: + pipe = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, torch_dtype=dtype + ).to(torch_device) + out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images + + pipe.save_pretrained(tmpdir) + loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, 
torch_dtype=dtype).to(torch_device) + + out_2 = loaded_pipe( + prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np" + ).images + + self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4)) + + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_load_dduf_from_hub_local_files_only(self): + with tempfile.TemporaryDirectory() as tmpdir: + pipe = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir + ).to(torch_device) + out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images + + local_files_pipe = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", cache_dir=tmpdir, local_files_only=True + ).to(torch_device) + out_2 = local_files_pipe( + prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np" + ).images + + self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4)) + + def test_dduf_raises_error_with_custom_pipeline(self): + with self.assertRaises(NotImplementedError): + _ = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", custom_pipeline="my_pipeline" + ) + + def test_dduf_raises_error_with_connected_pipeline(self): + with self.assertRaises(NotImplementedError): + _ = DiffusionPipeline.from_pretrained( + "DDUF/tiny-flux-dev-pipe-dduf", dduf_file="fluxpipeline.dduf", load_connected_pipeline=True + ) + def test_wrong_model(self): tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") with self.assertRaises(ValueError) as error_context: @@ -1812,6 +1875,27 @@ def test_wrong_model(self): assert "is of type" in str(error_context.exception) assert "but should be" in str(error_context.exception) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def 
test_dduf_load_sharded_checkpoint_diffusion_model(self): + with tempfile.TemporaryDirectory() as tmpdir: + pipe = DiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-flux-dev-pipe-sharded-checkpoint-DDUF", + dduf_file="tiny-flux-dev-pipe-sharded-checkpoint.dduf", + cache_dir=tmpdir, + ).to(torch_device) + + out_1 = pipe(prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np").images + + pipe.save_pretrained(tmpdir) + loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir).to(torch_device) + + out_2 = loaded_pipe( + prompt="dog", num_inference_steps=5, generator=torch.manual_seed(0), output_type="np" + ).images + + self.assertTrue(np.allclose(out_1, out_2, atol=1e-4, rtol=1e-4)) + @slow @require_torch_gpu diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f5494fbade2e..83b628e09f88 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -43,7 +43,9 @@ CaptureLogger, require_accelerate_version_greater, require_accelerator, + require_hf_hub_version_greater, require_torch, + require_transformers_version_greater, skip_mps, torch_device, ) @@ -986,6 +988,8 @@ class PipelineTesterMixin: test_xformers_attention = True + supports_dduf = True + def get_generator(self, seed): device = torch_device if torch_device != "mps" else "cpu" generator = torch.Generator(device).manual_seed(seed) @@ -1990,6 +1994,39 @@ def test_StableDiffusionMixin_component(self): ) ) + @require_hf_hub_version_greater("0.26.5") + @require_transformers_version_greater("4.47.1") + def test_save_load_dduf(self, atol=1e-4, rtol=1e-4): + if not self.supports_dduf: + return + + from huggingface_hub import export_folder_as_dduf + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe = pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device="cpu") + inputs.pop("generator") + 
inputs["generator"] = torch.manual_seed(0) + + pipeline_out = pipe(**inputs)[0] + + with tempfile.TemporaryDirectory() as tmpdir: + dduf_filename = os.path.join(tmpdir, f"{pipe.__class__.__name__.lower()}.dduf") + pipe.save_pretrained(tmpdir, safe_serialization=True) + export_folder_as_dduf(dduf_filename, folder_path=tmpdir) + loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, dduf_file=dduf_filename).to(torch_device) + + inputs["generator"] = torch.manual_seed(0) + loaded_pipeline_out = loaded_pipe(**inputs)[0] + + if isinstance(pipeline_out, np.ndarray) and isinstance(loaded_pipeline_out, np.ndarray): + assert np.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol) + elif isinstance(pipeline_out, torch.Tensor) and isinstance(loaded_pipeline_out, torch.Tensor): + assert torch.allclose(pipeline_out, loaded_pipeline_out, atol=atol, rtol=rtol) + @is_staging_test class PipelinePushToHubTester(unittest.TestCase): diff --git a/tests/pipelines/unclip/test_unclip_image_variation.py b/tests/pipelines/unclip/test_unclip_image_variation.py index dfc3acc0c0f2..23a6cd6663b7 100644 --- a/tests/pipelines/unclip/test_unclip_image_variation.py +++ b/tests/pipelines/unclip/test_unclip_image_variation.py @@ -66,6 +66,7 @@ class UnCLIPImageVariationPipelineFastTests(PipelineTesterMixin, unittest.TestCa "super_res_num_inference_steps", ] test_xformers_attention = False + supports_dduf = False @property def text_embedder_hidden_size(self): diff --git a/tests/pipelines/unidiffuser/test_unidiffuser.py b/tests/pipelines/unidiffuser/test_unidiffuser.py index 2e0ba1cfb8eb..310e46a2e8c6 100644 --- a/tests/pipelines/unidiffuser/test_unidiffuser.py +++ b/tests/pipelines/unidiffuser/test_unidiffuser.py @@ -86,6 +86,8 @@ class UniDiffuserPipelineFastTests( # vae_latents, not latents, is the argument that corresponds to VAE latent inputs image_latents_params = frozenset(["vae_latents"]) + supports_dduf = False + def get_dummy_components(self): unet = 
UniDiffuserModel.from_pretrained( "hf-internal-testing/unidiffuser-diffusers-test",