Skip to content

Commit

Permalink
add single file to pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
a-r-r-o-w committed Dec 10, 2024
1 parent 9f9e016 commit db16983
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
14 changes: 5 additions & 9 deletions src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2212,10 +2212,9 @@ def swap_scale_shift(weight):


def convert_ltx_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}

def remove_keys_(key: str, state_dict):
state_dict.pop(key)
converted_state_dict = {
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "model.diffusion_model." in key
}

TRANSFORMER_KEYS_RENAME_DICT = {
"model.diffusion_model.": "",
Expand All @@ -2225,9 +2224,7 @@ def remove_keys_(key: str, state_dict):
"k_norm": "norm_k",
}

TRANSFORMER_SPECIAL_KEYS_REMAP = {
"vae": remove_keys_,
}
TRANSFORMER_SPECIAL_KEYS_REMAP = {}

for key in list(converted_state_dict.keys()):
new_key = key
Expand All @@ -2245,7 +2242,7 @@ def remove_keys_(key: str, state_dict):


def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys()) if "vae." in key}

def remove_keys_(key: str, state_dict):
state_dict.pop(key)
Expand Down Expand Up @@ -2287,7 +2284,6 @@ def remove_keys_(key: str, state_dict):
"per_channel_statistics.channel": remove_keys_,
"per_channel_statistics.mean-of-means": remove_keys_,
"per_channel_statistics.mean-of-stds": remove_keys_,
"model.diffusion_model": remove_keys_,
}

for key in list(converted_state_dict.keys()):
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/ltx/pipeline_ltx.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from transformers import T5EncoderModel, T5TokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKLLTX
from ...models.transformers import LTXTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
Expand Down Expand Up @@ -139,7 +140,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class LTXPipeline(DiffusionPipeline):
class LTXPipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
Pipeline for text-to-video generation.
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/pipelines/ltx/pipeline_ltx_image2video.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import FromSingleFileMixin
from ...models.autoencoders import AutoencoderKLLTX
from ...models.transformers import LTXTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
Expand Down Expand Up @@ -158,7 +159,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


class LTXImageToVideoPipeline(DiffusionPipeline):
class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin):
r"""
Pipeline for image-to-video generation.
Expand Down

0 comments on commit db16983

Please sign in to comment.