Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Single File] LTX support for loading original weights #10135

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
convert_flux_transformer_checkpoint_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_sd3_transformer_checkpoint_to_diffusers,
convert_stable_cascade_unet_single_file_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
Expand Down Expand Up @@ -82,6 +84,14 @@
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"LTXTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTX": {
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
"default_subfolder": "vae",
},
}


Expand Down Expand Up @@ -270,6 +280,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
subfolder=subfolder,
local_files_only=local_files_only,
token=token,
revision=revision,
)
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)

Expand Down
105 changes: 105 additions & 0 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@
"double_blocks.0.img_attn.norm.key_norm.scale",
"model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
],
"ltx-video": [
(
"model.diffusion_model.patchify_proj.weight",
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
),
],
}

DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
Expand Down Expand Up @@ -138,6 +144,7 @@
"animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"},
"flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"},
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
"ltx-video": {"pretrained_model_name_or_path": "Lightricks/LTX-Video"},
}

# Use to configure model sample size when original config is provided
Expand Down Expand Up @@ -564,6 +571,10 @@ def infer_diffusers_model_type(checkpoint):
model_type = "flux-dev"
else:
model_type = "flux-schnell"

elif any(all(key in checkpoint for key in key_list) for key_list in CHECKPOINT_KEY_NAMES["ltx-video"]):
model_type = "ltx-video"

else:
model_type = "v1"

Expand Down Expand Up @@ -2198,3 +2209,97 @@ def swap_scale_shift(weight):
)

return converted_state_dict


def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

def remove_keys_(key: str, state_dict):
state_dict.pop(key)

TRANSFORMER_KEYS_RENAME_DICT = {
"model.diffusion_model.": "",
"patchify_proj": "proj_in",
"adaln_single": "time_embed",
"q_norm": "norm_q",
"k_norm": "norm_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
"vae": remove_keys_,
}

for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)

for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)

return converted_state_dict


def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

def remove_keys_(key: str, state_dict):
state_dict.pop(key)

VAE_KEYS_RENAME_DICT = {
# common
"vae.": "",
# decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0",
"up_blocks.2": "up_blocks.1.upsamplers.0",
"up_blocks.3": "up_blocks.1",
"up_blocks.4": "up_blocks.2.conv_in",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
"up_blocks.7": "up_blocks.3.conv_in",
"up_blocks.8": "up_blocks.3.upsamplers.0",
"up_blocks.9": "up_blocks.3",
# encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.0.conv_out",
"down_blocks.3": "down_blocks.1",
"down_blocks.4": "down_blocks.1.downsamplers.0",
"down_blocks.5": "down_blocks.1.conv_out",
"down_blocks.6": "down_blocks.2",
"down_blocks.7": "down_blocks.2.downsamplers.0",
"down_blocks.8": "down_blocks.3",
"down_blocks.9": "mid_block",
# common
"conv_shortcut": "conv_shortcut.conv",
"res_blocks": "resnets",
"norm3.norm": "norm3",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}

VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"model.diffusion_model": remove_keys_,
}

for key in list(converted_state_dict.keys()):
new_key = key
for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = converted_state_dict.pop(key)

for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)

return converted_state_dict
3 changes: 2 additions & 1 deletion src/diffusers/models/autoencoders/autoencoder_kl_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
Expand Down Expand Up @@ -718,7 +719,7 @@ def create_forward(*inputs):
return hidden_states


class AutoencoderKLLTX(ModelMixin, ConfigMixin):
class AutoencoderKLLTX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[LTX](https://huggingface.co/Lightricks/LTX-Video).
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/models/transformers/transformer_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
Expand Down Expand Up @@ -266,7 +267,7 @@ def forward(


@maybe_allow_in_graph
class LTXTransformer3DModel(ModelMixin, ConfigMixin):
class LTXTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A Transformer model for video-like data used in [LTX](https://huggingface.co/Lightricks/LTX-Video).

Expand Down