
[FEAT] DDUF format #10037

Status: Open. Wants to merge 87 commits into `main`. Changes shown below are from 52 of the 87 commits.
Commits:

1fb86e3  load and save dduf archive (SunMarc, Nov 27, 2024)
0389333  style (SunMarc, Nov 27, 2024)
2eeda25  switch to zip uncompressed (SunMarc, Nov 28, 2024)
cbee7cb  Merge branch 'main' into dduf (sayakpaul, Nov 30, 2024)
d840867  updates (sayakpaul, Nov 29, 2024)
78135f1  Merge branch 'main' into dduf (sayakpaul, Dec 4, 2024)
7d2c7d5  Update src/diffusers/pipelines/pipeline_utils.py (SunMarc, Dec 4, 2024)
e66c4d0  Update src/diffusers/pipelines/pipeline_utils.py (SunMarc, Dec 4, 2024)
b14bffe  first draft (SunMarc, Dec 4, 2024)
1cd5155  remove print (SunMarc, Dec 4, 2024)
b8a43e7  switch to dduf_file for consistency (SunMarc, Dec 5, 2024)
cac988a  Merge remote-tracking branch 'origin/main' into dduf (SunMarc, Dec 5, 2024)
977baa3  switch to huggingface hub api (SunMarc, Dec 5, 2024)
81bd097  fix log (SunMarc, Dec 6, 2024)
c4df147  Merge branch 'main' into dduf (sayakpaul, Dec 8, 2024)
d0a861c  add a basic test (sayakpaul, Dec 8, 2024)
1ec988f  Update src/diffusers/configuration_utils.py (SunMarc, Dec 9, 2024)
5217712  Update src/diffusers/pipelines/pipeline_utils.py (SunMarc, Dec 9, 2024)
6922226  Update src/diffusers/pipelines/pipeline_utils.py (SunMarc, Dec 9, 2024)
3b0d84d  fix (SunMarc, Dec 9, 2024)
1bc953b  Merge remote-tracking branch 'origin/dduf' into dduf (SunMarc, Dec 9, 2024)
04ecf0e  fix variant (SunMarc, Dec 9, 2024)
9fff68a  change saving logic (SunMarc, Dec 10, 2024)
ed6c727  DDUF - Load transformers components manually (#10171) (Wauplin, Dec 11, 2024)
17d50d1  working version with transformers and tokenizer ! (SunMarc, Dec 11, 2024)
59929a5  add generation_config case (SunMarc, Dec 12, 2024)
aa0d497  fix tests (sayakpaul, Dec 12, 2024)
a793066  Merge branch 'main' into dduf (sayakpaul, Dec 12, 2024)
7602952  remove saving for now (SunMarc, Dec 12, 2024)
660d7c8  typing (SunMarc, Dec 12, 2024)
8358ef6  need next version from transformers (SunMarc, Dec 12, 2024)
4e7d15a  Update src/diffusers/configuration_utils.py (SunMarc, Dec 12, 2024)
cc75db3  check path corectly (SunMarc, Dec 12, 2024)
63575af  Merge remote-tracking branch 'origin/dduf' into dduf (SunMarc, Dec 12, 2024)
1e5ebf5  Apply suggestions from code review (SunMarc, Dec 12, 2024)
53e100b  udapte (SunMarc, Dec 12, 2024)
54991ac  Merge remote-tracking branch 'origin/dduf' into dduf (SunMarc, Dec 12, 2024)
1eb25dc  typing (SunMarc, Dec 12, 2024)
0cb1b98  remove check for subfolder (SunMarc, Dec 12, 2024)
5ec3951  quality (SunMarc, Dec 12, 2024)
1785eaa  revert setup changes (SunMarc, Dec 12, 2024)
7486016  oups (SunMarc, Dec 12, 2024)
73e81a5  more readable condition (SunMarc, Dec 12, 2024)
ea0126d  add loading from the hub test (sayakpaul, Dec 13, 2024)
3ebdcff  add basic docs. (sayakpaul, Dec 13, 2024)
5943a60  Merge branch 'main' into dduf (sayakpaul, Dec 13, 2024)
021abf0  Apply suggestions from code review (SunMarc, Dec 13, 2024)
af2ca07  add example (SunMarc, Dec 13, 2024)
9d70b6c  Merge branch 'main' into dduf (sayakpaul, Dec 13, 2024)
47cb92c  Merge remote-tracking branch 'origin/dduf' into dduf (SunMarc, Dec 13, 2024)
c9734ab  add (SunMarc, Dec 13, 2024)
27ebf9e  make functions private (SunMarc, Dec 13, 2024)
d5dbb5c  Apply suggestions from code review (SunMarc, Dec 13, 2024)
627aec0  resolve conflicts. (sayakpaul, Dec 18, 2024)
f0e21a9  Merge branch 'main' into dduf (sayakpaul, Dec 18, 2024)
e9b7429  minor. (sayakpaul, Dec 18, 2024)
b8b699a  fixes (sayakpaul, Dec 18, 2024)
6003176  Merge branch 'main' into dduf (sayakpaul, Dec 18, 2024)
0e54b06  fix (sayakpaul, Dec 18, 2024)
a026055  change the precdence of parameterized. (sayakpaul, Dec 18, 2024)
03e30b4  Merge remote-tracking branch 'upstream/main' into dduf (SunMarc, Dec 23, 2024)
da48dcb  Merge branch 'main' into dduf (SunMarc, Dec 23, 2024)
0fbea9a  Merge branch 'main' into dduf (sayakpaul, Dec 30, 2024)
6a163c7  error out when custom pipeline is passed with dduf_file. (sayakpaul, Dec 30, 2024)
67b617e  updates (sayakpaul, Dec 31, 2024)
b40272e  fix (sayakpaul, Dec 31, 2024)
ce237f3  updates (sayakpaul, Dec 31, 2024)
454b9b9  fixes (sayakpaul, Dec 31, 2024)
366aa2f  updates (sayakpaul, Dec 31, 2024)
6648995  fix xfail condition. (sayakpaul, Dec 31, 2024)
21ae7ee  fix xfail (sayakpaul, Dec 31, 2024)
15d4569  fixes (sayakpaul, Dec 31, 2024)
f3a4ddc  sharded checkpoint compat (SunMarc, Jan 3, 2025)
a032025  Merge branch 'dduf' of github.com:huggingface/diffusers into dduf (SunMarc, Jan 3, 2025)
faa0cac  Merge branch 'main' into dduf (sayakpaul, Jan 3, 2025)
0205cc8  add test for sharded checkpoint (SunMarc, Jan 3, 2025)
9cda4c1  Merge branch 'dduf' of github.com:huggingface/diffusers into dduf (SunMarc, Jan 3, 2025)
cd0734e  Merge branch 'main' into dduf (sayakpaul, Jan 4, 2025)
5037d39  add suggestions (SunMarc, Jan 6, 2025)
c9e08da  Merge branch 'dduf' of github.com:huggingface/diffusers into dduf (SunMarc, Jan 6, 2025)
02a368b  Update src/diffusers/models/model_loading_utils.py (SunMarc, Jan 7, 2025)
7bc9347  from suggestions (SunMarc, Jan 7, 2025)
fff5954  add class attributes to flag dduf tests (SunMarc, Jan 7, 2025)
da402da  last one (SunMarc, Jan 7, 2025)
9ebbf84  Merge branch 'main' into dduf (sayakpaul, Jan 8, 2025)
aaaa947  Merge branch 'main' into dduf (sayakpaul, Jan 8, 2025)
290b88d  resolve conflicts. (sayakpaul, Jan 12, 2025)
docs/source/en/using-diffusers/other-formats.md (44 additions, 0 deletions)

@@ -240,6 +240,50 @@ Benefits of using a single-file layout include:
1. Easy compatibility with diffusion interfaces such as [ComfyUI](https://github.com/comfyanonymous/ComfyUI) or [Automatic1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which commonly use a single-file layout.
2. Easier to manage (download and share) a single file.

### DDUF

<Tip warning={true}>

DDUF is an experimental file format and APIs related to it can change in the future.

</Tip>

DDUF (**D**DUF's **D**iffusion **U**nified **F**ormat) is a file format designed to make storing, distributing, and using diffusion models much easier. Built on the ZIP file format, DDUF offers a standardized, efficient, and flexible way to package all parts of a diffusion model into a single, easy-to-manage file. It aims for a sweet spot between the multi-folder layout and the widely popular single-file format. To learn more, check out the [DDUF documentation](https://huggingface.co/docs/hub/dduf).
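Because DDUF is ZIP at the storage level (entries are stored uncompressed), you can peek inside an archive with the standard library alone. A minimal sketch, with the archive name assumed for illustration:

```py
import zipfile

# DDUF archives are plain ZIP containers, so the stdlib can list their contents.
with zipfile.ZipFile("FLUX.1-dev.dduf") as zf:
    for info in zf.infolist():
        print(info.filename, info.file_size)
```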

Below we show how to load a DDUF checkpoint in a [`DiffusionPipeline`]:

```py
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained(
    "DDUF/FLUX.1-dev-DDUF", dduf_file="FLUX.1-dev.dduf", torch_dtype=torch.bfloat16
).to("cuda")
image = pipe(
    "photo of a cat holding a sign that says Diffusers", num_inference_steps=50, guidance_scale=3.5
).images[0]
image.save("cat.png")
```

To save a pipeline as a `.dduf` checkpoint, we rely on `huggingface_hub`'s `export_folder_as_dduf()` utility, which takes care of all the necessary file-level validations:

```py
from huggingface_hub import export_folder_as_dduf
from diffusers import DiffusionPipeline
import torch

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

save_folder = "flux-dev"
pipe.save_pretrained(save_folder)
export_folder_as_dduf("flux-dev.dduf", folder_path=save_folder)
```
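As a quick sanity check on the exported archive, `huggingface_hub` also exposes a `read_dduf_file()` helper that maps entry paths to `DDUFEntry` objects; a minimal sketch:

```py
from huggingface_hub import read_dduf_file

# Keys are archive paths, e.g. "model_index.json" or "transformer/config.json".
entries = read_dduf_file("flux-dev.dduf")
print(sorted(entries))
```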

<Tip>

We support packaging and loading quantized checkpoints in the DDUF format as long as they respect the multi-folder structure.

</Tip>

## Convert layout and files

Diffusers provides many scripts and methods to convert storage layouts and file formats to enable broader support across the diffusion ecosystem.
setup.py (1 addition, 1 deletion)

@@ -101,7 +101,7 @@
"filelock",
"flax>=0.4.1",
"hf-doc-builder>=0.3.0",
"huggingface-hub>=0.23.2",
"huggingface-hub>=0.27.0",
"requests-mock==1.10.0",
"importlib_metadata",
"invisible-watermark>=0.2.0",
src/diffusers/configuration_utils.py (30 additions, 10 deletions)

@@ -24,10 +24,10 @@
import re
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
from huggingface_hub import create_repo, hf_hub_download
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
from huggingface_hub.utils import (
EntryNotFoundError,
RepositoryNotFoundError,
@@ -347,6 +347,7 @@ def load_config(
_ = kwargs.pop("mirror", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = kwargs.pop("user_agent", {})
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)

user_agent = {**user_agent, "file_type": "config"}
user_agent = http_user_agent(user_agent)
@@ -358,8 +359,24 @@
"`self.config_name` is not defined. Note that one should not load a config from "
"`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
)

if os.path.isfile(pretrained_model_name_or_path):
# Custom path for now
if dduf_entries:
if subfolder is not None:
raise ValueError(
"DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
"Please check the DDUF structure"
)
# paths inside a DDUF file must always be "/"
config_file = (
cls.config_name
if pretrained_model_name_or_path == ""
else "/".join([pretrained_model_name_or_path, cls.config_name])
)
if config_file not in dduf_entries:
raise ValueError(
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
)
elif os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if subfolder is not None and os.path.isfile(
@@ -426,10 +443,8 @@
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a {cls.config_name} file"
)

try:
# Load config dict
config_dict = cls._dict_from_json_file(config_file)
config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)

commit_hash = extract_commit_hash(config_file)
except (json.JSONDecodeError, UnicodeDecodeError):
@@ -552,9 +567,14 @@ def extract_init_dict(cls, config_dict, **kwargs):
return init_dict, unused_kwargs, hidden_config_dict

@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
def _dict_from_json_file(
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
):
if dduf_entries:
text = dduf_entries[json_file].read_text()
else:
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)

def __repr__(self):
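For orientation (not part of the diff), a hedged sketch of how `dduf_entries` would reach this code path, assuming `huggingface_hub`'s `read_dduf_file()`, which returns a `Dict[str, DDUFEntry]`:

```py
from huggingface_hub import read_dduf_file
from diffusers import FluxTransformer2DModel

# read_dduf_file maps archive paths like "transformer/config.json" to DDUFEntry objects.
dduf_entries = read_dduf_file("FLUX.1-dev.dduf")

# With the change above, load_config joins the subpath and the class's config_name
# with "/" (paths inside a DDUF file always use "/") and reads the JSON from the archive.
config = FluxTransformer2DModel.load_config("transformer", dduf_entries=dduf_entries)
```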
src/diffusers/dependency_versions_table.py (1 addition, 1 deletion)

@@ -9,7 +9,7 @@
"filelock": "filelock",
"flax": "flax>=0.4.1",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"huggingface-hub": "huggingface-hub>=0.23.2",
"huggingface-hub": "huggingface-hub>=0.27.0",
"requests-mock": "requests-mock==1.10.0",
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
src/diffusers/models/model_loading_utils.py (18 additions, 3 deletions)

@@ -19,10 +19,11 @@
import os
from collections import OrderedDict
from pathlib import Path
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import safetensors
import torch
from huggingface_hub import DDUFEntry
from huggingface_hub.utils import EntryNotFoundError

from ..quantizers.quantization_config import QuantizationMethod
@@ -128,7 +129,11 @@ def _fetch_remapped_cls_from_config(config, old_class):
return old_class


def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
def load_state_dict(
checkpoint_file: Union[str, os.PathLike],
variant: Optional[str] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
"""
Reads a checkpoint file, returning properly formatted errors if they arise.
"""
@@ -139,7 +144,13 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
try:
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
if dduf_entries:
# tensors are loaded on cpu
with dduf_entries[checkpoint_file].as_mmap() as mm:
return safetensors.torch.load(mm)
else:
return safetensors.torch.load_file(checkpoint_file, device="cpu")

else:
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
@@ -274,6 +285,7 @@ def _fetch_index_file(
revision,
user_agent,
commit_hash,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
if is_local:
index_file = Path(
@@ -299,6 +311,7 @@
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
index_file = Path(index_file)
except (EntryNotFoundError, EnvironmentError):
@@ -350,6 +363,7 @@ def _fetch_index_file_legacy(
revision,
user_agent,
commit_hash,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
):
if is_local:
index_file = Path(
@@ -390,6 +404,7 @@ def _fetch_index_file_legacy(
subfolder=None,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)
index_file = Path(index_file)
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
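Stepping back to the `load_state_dict()` change above, here is a standalone hedged sketch of the same mmap idea (archive name and entry path assumed for illustration):

```py
import safetensors.torch
from huggingface_hub import read_dduf_file

entries = read_dduf_file("FLUX.1-dev.dduf")

# DDUFEntry.as_mmap() memory-maps the entry's byte range inside the (uncompressed)
# archive, so safetensors can deserialize it without extracting a temporary copy.
# The resulting tensors are loaded on CPU.
with entries["transformer/diffusion_pytorch_model.safetensors"].as_mmap() as mm:
    state_dict = safetensors.torch.load(mm)
```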
src/diffusers/models/modeling_utils.py (11 additions, 5 deletions)

@@ -23,11 +23,11 @@
from collections import OrderedDict
from functools import partial, wraps
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import safetensors
import torch
from huggingface_hub import create_repo, split_torch_state_dict_into_shards
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn

@@ -586,6 +586,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
variant = kwargs.pop("variant", None)
use_safetensors = kwargs.pop("use_safetensors", None)
quantization_config = kwargs.pop("quantization_config", None)
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)

allow_pickle = False
if use_safetensors is None:
@@ -678,6 +679,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
dduf_entries=dduf_entries,
**kwargs,
)
# no in-place modification of the original config.
@@ -753,6 +755,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
"revision": revision,
"user_agent": user_agent,
"commit_hash": commit_hash,
"dduf_entries": dduf_entries,
}
index_file = _fetch_index_file(**index_file_kwargs)
# In case the index file was not found we still have to consider the legacy format.
@@ -788,7 +791,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

model = load_flax_checkpoint_in_pytorch_model(model, model_file)
else:
if is_sharded:
# in the case it is sharded, we already have the index
if is_sharded and not dduf_entries:
[Review thread on this change]

Collaborator: Do we support dduf_entries for sharded checkpoints? The `is_sharded` logic (on line 786) depends on whether the index_file exists in the filesystem, which is just never going to be True for DDUF because all the files are contained in the dduf_entries. It is unclear to me whether this is intended or not.

SunMarc (Member, Author), Dec 27, 2024: dduf_entries do not support sharded checkpoints indeed. Thanks for noticing this. Maybe it is better to refactor the loading as we discussed a long time ago, sayak, and do the same as transformers, where we just pass the state_dict and the config? WDYT @sayakpaul @yiyixuxu?

Member: Hmm, the loading refactoring (#10013) is turning out to be more and more urgent, which @DN6 wants to tackle. If we can make dduf_entries distinct from is_sharded in this PR, I think that might be a better option to explore for the minimal DDUF PoC.

SunMarc (Member, Author): For now, I used the `_merge_sharded_checkpoints` path, just like how we do for quantized models, cc @sayakpaul. But after the refactoring in a follow-up PR, we will be able to make it a lot simpler.

sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
pretrained_model_name_or_path,
index_file,
Expand Down Expand Up @@ -819,6 +823,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

except IOError as e:
@@ -842,6 +847,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
subfolder=subfolder,
user_agent=user_agent,
commit_hash=commit_hash,
dduf_entries=dduf_entries,
)

if low_cpu_mem_usage:
@@ -866,7 +872,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# TODO (sayakpaul, SunMarc): remove this after model loading refactor
elif is_quant_method_bnb:
param_device = torch.device(torch.cuda.current_device())
state_dict = load_state_dict(model_file, variant=variant)
state_dict = load_state_dict(model_file, variant=variant, dduf_entries=dduf_entries)
model._convert_deprecated_attention_blocks(state_dict)

# move the params from meta device to cpu
Expand Down Expand Up @@ -966,7 +972,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
model = cls.from_config(config, **unused_kwargs)

state_dict = load_state_dict(model_file, variant=variant)
state_dict = load_state_dict(model_file, variant=variant, dduf_entries=dduf_entries)
model._convert_deprecated_attention_blocks(state_dict)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
Expand Down
33 changes: 26 additions & 7 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import importlib
import os
import re
import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from huggingface_hub import ModelCard, model_info
from huggingface_hub import DDUFEntry, ModelCard, model_info
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version

@@ -41,11 +39,12 @@
logging,
)
from ..utils.torch_utils import is_compiled_module
from .transformers_loading_utils import _load_tokenizer_from_dduf, _load_transformers_model_from_dduf


if is_transformers_available():
import transformers
from transformers import PreTrainedModel
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME
from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME
from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME
@@ -627,6 +626,7 @@ def load_sub_model(
low_cpu_mem_usage: bool,
cached_folder: Union[str, os.PathLike],
use_safetensors: bool,
dduf_entries: Optional[Dict[str, DDUFEntry]],
):
"""Helper method to load the module `name` from `library_name` and `class_name`"""

@@ -663,7 +663,7 @@ def load_sub_model(
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)

load_method = getattr(class_obj, load_method_name)
load_method = _get_load_method(class_obj, load_method_name, is_dduf=dduf_entries is not None)

# add kwargs to loading method
diffusers_module = importlib.import_module(__name__.split(".")[0])
@@ -721,7 +721,10 @@ def load_sub_model(
loading_kwargs["low_cpu_mem_usage"] = False

# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
if dduf_entries:
loading_kwargs["dduf_entries"] = dduf_entries
loaded_sub_model = load_method(name, **loading_kwargs)
elif os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory
@@ -746,6 +749,22 @@ def load_sub_model(
return loaded_sub_model


def _get_load_method(class_obj: object, load_method_name: str, is_dduf: bool) -> Callable:
"""
Return the method to load the sub model.

In practice, this method will return the `"from_pretrained"` (or `load_method_name`) method of the class object
except if loading from a DDUF checkpoint. In that case, transformers models and tokenizers have a specific loading
method that we need to use.
"""
if is_dduf:
if issubclass(class_obj, PreTrainedTokenizerBase):
return lambda *args, **kwargs: _load_tokenizer_from_dduf(class_obj, *args, **kwargs)
if issubclass(class_obj, PreTrainedModel):
return lambda *args, **kwargs: _load_transformers_model_from_dduf(class_obj, *args, **kwargs)
return getattr(class_obj, load_method_name)
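To make the dispatch concrete, a hedged illustration (the class names are examples, not taken from the diff):

```py
from transformers import CLIPTextModel, CLIPTokenizer, PreTrainedModel, PreTrainedTokenizerBase

# Mirrors the issubclass tests in _get_load_method when loading from a DDUF archive:
assert issubclass(CLIPTokenizer, PreTrainedTokenizerBase)  # routed to _load_tokenizer_from_dduf
assert issubclass(CLIPTextModel, PreTrainedModel)          # routed to _load_transformers_model_from_dduf

# A diffusers class (e.g. a UNet or DiT transformer) matches neither check and keeps
# its regular loader via getattr(class_obj, load_method_name).
```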


def _fetch_class_library_tuple(module):
# import it here to avoid circular import
diffusers_module = importlib.import_module(__name__.split(".")[0])