diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml
index 702b95458d78..b85459f71aae 100644
--- a/docs/source/en/_toctree.yml
+++ b/docs/source/en/_toctree.yml
@@ -310,6 +310,8 @@
title: Versatile Diffusion
- local: api/pipelines/vq_diffusion
title: VQ Diffusion
+ - local: api/pipelines/wuerstchen
+ title: Wuerstchen
title: Pipelines
- sections:
- local: api/schedulers/overview
diff --git a/docs/source/en/api/pipelines/wuerstchen.md b/docs/source/en/api/pipelines/wuerstchen.md
new file mode 100644
index 000000000000..4316bc739ca9
--- /dev/null
+++ b/docs/source/en/api/pipelines/wuerstchen.md
@@ -0,0 +1,140 @@
+# Würstchen
+
+
+
+[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville.
+
+The abstract from the paper is:
+
+*We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.*
+
+## Würstchen v2 comes to Diffusers
+
+After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competetive to current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements.
+
+- Higher resolution (1024x1024 up to 2048x2048)
+- Faster inference
+- Multi Aspect Resolution Sampling
+- Better quality
+
+We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are:
+- v2-base
+- v2-aesthetic
+- v2-interpolated (50% interpolation between v2-base and v2-aesthetic)
+
+We recommend to use v2-interpolated, as it has a nice touch of both photorealism and aesthetic. Use v2-base for finetunings as it does not have a style bias and use v2-aesthetic for very artistic generations.
+A comparison can be seen here:
+
+
+
+## Text-to-Image Generation
+
+For the sake of usability Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenCombinedPipeline` and can be used as follows:
+
+```python
+import torch
+from diffusers import AutoPipelineForText2Image
+
+device = "cuda"
+dtype = torch.float16
+num_images_per_prompt = 2
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+ "warp-diffusion/wuerstchen", torch_dtype=dtype
+).to(device)
+
+caption = "Anthropomorphic cat dressed as a fire fighter"
+negative_prompt = ""
+
+output = pipeline(
+ prompt=caption,
+ height=1024,
+ width=1024,
+ negative_prompt=negative_prompt,
+ prior_guidance_scale=4.0,
+ decoder_guidance_scale=0.0,
+ num_images_per_prompt=num_images_per_prompt,
+ output_type="pil",
+).images
+```
+
+For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, Stage A. They all have different jobs and work only together. When generating text-conditional images, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A are both encapsulated in the `decoder_pipeline`. For more details, take a look the [paper](https://huggingface.co/papers/2306.00637).
+
+```python
+import torch
+from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline
+
+device = "cuda"
+dtype = torch.float16
+num_images_per_prompt = 2
+
+prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
+ "warp-diffusion/wuerstchen-prior", torch_dtype=dtype
+).to(device)
+decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
+ "warp-diffusion/wuerstchen", torch_dtype=dtype
+).to(device)
+
+caption = "A captivating artwork of a mysterious stone golem"
+negative_prompt = ""
+
+prior_output = prior_pipeline(
+ prompt=caption,
+ height=1024,
+ width=1024,
+ negative_prompt=negative_prompt,
+ guidance_scale=4.0,
+ num_images_per_prompt=num_images_per_prompt,
+)
+decoder_output = decoder_pipeline(
+ image_embeddings=prior_output.image_embeddings,
+ prompt=caption,
+ negative_prompt=negative_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ guidance_scale=0.0,
+ output_type="pil",
+).images
+```
+
+## Speed-Up Inference
+You can make use of ``torch.compile`` function and gain a speed-up of about 2-3x:
+
+```python
+pipeline.prior = torch.compile(pipeline.prior, mode="reduce-overhead", fullgraph=True)
+pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
+```
+
+## Limitations
+- Due to the high compression employed by Würstchen, generations can lack a good amount
+of detail. To our human eye, this is especially noticeable in faces, hands etc.
+- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution
+after 1024x1024 is 1152x1152
+- The model lacks the ability to render correct text in images
+- The model often does not achieve photorealism
+- Difficult compositional prompts are hard for the model
+
+
+The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).
+
+## WuerschenPipeline
+
+[[autodoc]] WuerstchenCombinedPipeline
+ - all
+ - __call__
+
+## WuerstchenPriorPipeline
+
+[[autodoc]] WuerstchenDecoderPipeline
+
+ - all
+ - __call__
+
+## WuerstchenPriorPipelineOutput
+
+[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput
+
+## WuerstchenDecoderPipeline
+
+[[autodoc]] WuerstchenDecoderPipeline
+ - all
+ - __call__
diff --git a/scripts/convert_wuerstchen.py b/scripts/convert_wuerstchen.py
new file mode 100644
index 000000000000..91fd9b79b4ee
--- /dev/null
+++ b/scripts/convert_wuerstchen.py
@@ -0,0 +1,115 @@
+# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/
+import os
+
+import torch
+from transformers import AutoTokenizer, CLIPTextModel
+from vqgan import VQModel
+
+from diffusers import (
+ DDPMWuerstchenScheduler,
+ WuerstchenCombinedPipeline,
+ WuerstchenDecoderPipeline,
+ WuerstchenPriorPipeline,
+)
+from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
+
+
+model_path = "models/"
+device = "cpu"
+
+paella_vqmodel = VQModel()
+state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
+paella_vqmodel.load_state_dict(state_dict)
+
+state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
+state_dict.pop("vquantizer.codebook.weight")
+vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
+vqmodel.load_state_dict(state_dict)
+
+# Clip Text encoder and tokenizer
+text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
+
+# Generator
+gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
+gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+
+orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
+state_dict = {}
+for key in orig_state_dict.keys():
+ if key.endswith("in_proj_weight"):
+ weights = orig_state_dict[key].chunk(3, 0)
+ state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
+ state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
+ state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
+ elif key.endswith("in_proj_bias"):
+ weights = orig_state_dict[key].chunk(3, 0)
+ state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
+ state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
+ state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
+ elif key.endswith("out_proj.weight"):
+ weights = orig_state_dict[key]
+ state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
+ elif key.endswith("out_proj.bias"):
+ weights = orig_state_dict[key]
+ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
+ else:
+ state_dict[key] = orig_state_dict[key]
+deocder = WuerstchenDiffNeXt()
+deocder.load_state_dict(state_dict)
+
+# Prior
+orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
+state_dict = {}
+for key in orig_state_dict.keys():
+ if key.endswith("in_proj_weight"):
+ weights = orig_state_dict[key].chunk(3, 0)
+ state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
+ state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
+ state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
+ elif key.endswith("in_proj_bias"):
+ weights = orig_state_dict[key].chunk(3, 0)
+ state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
+ state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
+ state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
+ elif key.endswith("out_proj.weight"):
+ weights = orig_state_dict[key]
+ state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
+ elif key.endswith("out_proj.bias"):
+ weights = orig_state_dict[key]
+ state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
+ else:
+ state_dict[key] = orig_state_dict[key]
+prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
+prior_model.load_state_dict(state_dict)
+
+# scheduler
+scheduler = DDPMWuerstchenScheduler()
+
+# Prior pipeline
+prior_pipeline = WuerstchenPriorPipeline(
+ prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
+)
+
+prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior")
+
+decoder_pipeline = WuerstchenDecoderPipeline(
+ text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler
+)
+decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen")
+
+# Wuerstchen pipeline
+wuerstchen_pipeline = WuerstchenCombinedPipeline(
+ # Decoder
+ text_encoder=gen_text_encoder,
+ tokenizer=gen_tokenizer,
+ decoder=deocder,
+ scheduler=scheduler,
+ vqgan=vqmodel,
+ # Prior
+ prior_tokenizer=tokenizer,
+ prior_text_encoder=text_encoder,
+ prior=prior_model,
+ prior_scheduler=scheduler,
+)
+wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline")
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index 0d562c9eaaf4..d72c671671c1 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -88,6 +88,7 @@
DDIMScheduler,
DDPMParallelScheduler,
DDPMScheduler,
+ DDPMWuerstchenScheduler,
DEISMultistepScheduler,
DPMSolverMultistepInverseScheduler,
DPMSolverMultistepScheduler,
@@ -216,6 +217,9 @@
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VQDiffusionPipeline,
+ WuerstchenCombinedPipeline,
+ WuerstchenDecoderPipeline,
+ WuerstchenPriorPipeline,
)
try:
diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index 1ec89a5ff68e..49fc2c638620 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -188,7 +188,7 @@ def set_use_memory_efficient_attention_xformers(
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and (is_lora or is_custom_diffusion):
raise NotImplementedError(
- f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
+ f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
)
if not is_xformers_available():
raise ModuleNotFoundError(
diff --git a/src/diffusers/models/vae.py b/src/diffusers/models/vae.py
index d3c512b07390..220c0ce990c8 100644
--- a/src/diffusers/models/vae.py
+++ b/src/diffusers/models/vae.py
@@ -52,7 +52,7 @@ def __init__(
super().__init__()
self.layers_per_block = layers_per_block
- self.conv_in = torch.nn.Conv2d(
+ self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,
diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py
index 687449e8c755..393a638d483b 100644
--- a/src/diffusers/models/vq_model.py
+++ b/src/diffusers/models/vq_model.py
@@ -132,7 +132,7 @@ def decode(
) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer
if not force_not_quantize:
- quant, emb_loss, info = self.quantize(h)
+ quant, _, _ = self.quantize(h)
else:
quant = h
quant2 = self.post_quant_conv(quant)
diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py
index 39d4b9375dbb..28f42ce9fae9 100644
--- a/src/diffusers/pipelines/__init__.py
+++ b/src/diffusers/pipelines/__init__.py
@@ -132,6 +132,7 @@
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
+ from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline
try:
diff --git a/src/diffusers/pipelines/auto_pipeline.py b/src/diffusers/pipelines/auto_pipeline.py
index 9d9d49f700de..13f12e75fb31 100644
--- a/src/diffusers/pipelines/auto_pipeline.py
+++ b/src/diffusers/pipelines/auto_pipeline.py
@@ -52,6 +52,7 @@
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
+from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -63,6 +64,7 @@
("kandinsky22", KandinskyV22CombinedPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
+ ("wuerstchen", WuerstchenCombinedPipeline),
]
)
@@ -93,6 +95,7 @@
[
("kandinsky", KandinskyPipeline),
("kandinsky22", KandinskyV22Pipeline),
+ ("wuerstchen", WuerstchenDecoderPipeline),
]
)
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
@@ -305,8 +308,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
@@ -316,8 +317,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
- "subfolder": subfolder,
- "user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
@@ -580,8 +579,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
@@ -591,8 +588,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
- "subfolder": subfolder,
- "user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
@@ -856,8 +851,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
use_auth_token = kwargs.pop("use_auth_token", None)
local_files_only = kwargs.pop("local_files_only", False)
revision = kwargs.pop("revision", None)
- subfolder = kwargs.pop("subfolder", None)
- user_agent = kwargs.pop("user_agent", {})
load_config_kwargs = {
"cache_dir": cache_dir,
@@ -867,8 +860,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
"use_auth_token": use_auth_token,
"local_files_only": local_files_only,
"revision": revision,
- "subfolder": subfolder,
- "user_agent": user_agent,
}
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
diff --git a/src/diffusers/pipelines/wuerstchen/__init__.py b/src/diffusers/pipelines/wuerstchen/__init__.py
new file mode 100644
index 000000000000..a6f6321b048a
--- /dev/null
+++ b/src/diffusers/pipelines/wuerstchen/__init__.py
@@ -0,0 +1,10 @@
+from ...utils import is_torch_available, is_transformers_available
+
+
+if is_transformers_available() and is_torch_available():
+ from .modeling_paella_vq_model import PaellaVQModel
+ from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
+ from .modeling_wuerstchen_prior import WuerstchenPrior
+ from .pipeline_wuerstchen import WuerstchenDecoderPipeline
+ from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
+ from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py b/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py
new file mode 100644
index 000000000000..09bdd16592df
--- /dev/null
+++ b/src/diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2022 Dominic Rampas MIT License
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+from ...models.vae import DecoderOutput, VectorQuantizer
+from ...models.vq_model import VQEncoderOutput
+from ...utils import apply_forward_hook
+
+
+class MixingResidualBlock(nn.Module):
+ """
+ Residual block with mixing used by Paella's VQ-VAE.
+ """
+
+ def __init__(self, inp_channels, embed_dim):
+ super().__init__()
+ # depthwise
+ self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
+ self.depthwise = nn.Sequential(
+ nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels)
+ )
+
+ # channelwise
+ self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels)
+ )
+
+ self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
+
+ def forward(self, x):
+ mods = self.gammas
+ x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1]
+ x = x + self.depthwise(x_temp) * mods[2]
+ x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4]
+ x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
+ return x
+
+
+class PaellaVQModel(ModelMixin, ConfigMixin):
+ r"""VQ-VAE model from Paella model.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image.
+ levels (int, *optional*, defaults to 2): Number of levels in the model.
+ bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model.
+ embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model.
+ latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model.
+ num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE.
+ scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ up_down_scale_factor: int = 2,
+ levels: int = 2,
+ bottleneck_blocks: int = 12,
+ embed_dim: int = 384,
+ latent_channels: int = 4,
+ num_vq_embeddings: int = 8192,
+ scale_factor: float = 0.3764,
+ ):
+ super().__init__()
+
+ c_levels = [embed_dim // (2**i) for i in reversed(range(levels))]
+ # Encoder blocks
+ self.in_block = nn.Sequential(
+ nn.PixelUnshuffle(up_down_scale_factor),
+ nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1),
+ )
+ down_blocks = []
+ for i in range(levels):
+ if i > 0:
+ down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+ block = MixingResidualBlock(c_levels[i], c_levels[i] * 4)
+ down_blocks.append(block)
+ down_blocks.append(
+ nn.Sequential(
+ nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False),
+ nn.BatchNorm2d(latent_channels), # then normalize them to have mean 0 and std 1
+ )
+ )
+ self.down_blocks = nn.Sequential(*down_blocks)
+
+ # Vector Quantizer
+ self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25)
+
+ # Decoder blocks
+ up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))]
+ for i in range(levels):
+ for j in range(bottleneck_blocks if i == 0 else 1):
+ block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
+ up_blocks.append(block)
+ if i < levels - 1:
+ up_blocks.append(
+ nn.ConvTranspose2d(
+ c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1
+ )
+ )
+ self.up_blocks = nn.Sequential(*up_blocks)
+ self.out_block = nn.Sequential(
+ nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1),
+ nn.PixelShuffle(up_down_scale_factor),
+ )
+
+ @apply_forward_hook
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
+ h = self.in_block(x)
+ h = self.down_blocks(h)
+
+ if not return_dict:
+ return (h,)
+
+ return VQEncoderOutput(latents=h)
+
+ @apply_forward_hook
+ def decode(
+ self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ if not force_not_quantize:
+ quant, _, _ = self.vquantizer(h)
+ else:
+ quant = h
+
+ x = self.up_blocks(quant)
+ dec = self.out_block(x)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ h = self.encode(x).latents
+ dec = self.decode(h).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
new file mode 100644
index 000000000000..b3aac39386bc
--- /dev/null
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2023 Dominic Rampas MIT License
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+
+from ...models.attention_processor import Attention
+
+
+class WuerstchenLayerNorm(nn.LayerNorm):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x):
+ x = x.permute(0, 2, 3, 1)
+ x = super().forward(x)
+ return x.permute(0, 3, 1, 2)
+
+
+class TimestepBlock(nn.Module):
+ def __init__(self, c, c_timestep):
+ super().__init__()
+ self.mapper = nn.Linear(c_timestep, c * 2)
+
+ def forward(self, x, t):
+ a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
+ return x * (1 + a) + b
+
+
+class ResBlock(nn.Module):
+ def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
+ super().__init__()
+ self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+ self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
+ )
+
+ def forward(self, x, x_skip=None):
+ x_res = x
+ if x_skip is not None:
+ x = torch.cat([x, x_skip], dim=1)
+ x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
+ x = self.channelwise(x).permute(0, 3, 1, 2)
+ return x + x_res
+
+
+# from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
+class GlobalResponseNorm(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, x):
+ agg_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ stand_div_norm = agg_norm / (agg_norm.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * stand_div_norm) + self.beta + x
+
+
+class AttnBlock(nn.Module):
+ def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
+ super().__init__()
+ self.self_attn = self_attn
+ self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
+ self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
+
+ def forward(self, x, kv):
+ kv = self.kv_mapper(kv)
+ norm_x = self.norm(x)
+ if self.self_attn:
+ batch_size, channel, _, _ = x.shape
+ kv = torch.cat([norm_x.view(batch_size, channel, -1).transpose(1, 2), kv], dim=1)
+ x = x + self.attention(norm_x, encoder_hidden_states=kv)
+ return x
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py
new file mode 100644
index 000000000000..d22eb7b7c991
--- /dev/null
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py
@@ -0,0 +1,254 @@
+# Copyright (c) 2023 Dominic Rampas MIT License
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+from .modeling_wuerstchen_common import AttnBlock, GlobalResponseNorm, TimestepBlock, WuerstchenLayerNorm
+
+
+class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ c_in=4,
+ c_out=4,
+ c_r=64,
+ patch_size=2,
+ c_cond=1024,
+ c_hidden=[320, 640, 1280, 1280],
+ nhead=[-1, 10, 20, 20],
+ blocks=[4, 4, 14, 4],
+ level_config=["CT", "CTA", "CTA", "CTA"],
+ inject_effnet=[False, True, True, True],
+ effnet_embd=16,
+ clip_embd=1024,
+ kernel_size=3,
+ dropout=0.1,
+ ):
+ super().__init__()
+ self.c_r = c_r
+ self.c_cond = c_cond
+ if not isinstance(dropout, list):
+ dropout = [dropout] * len(c_hidden)
+
+ # CONDITIONING
+ self.clip_mapper = nn.Linear(clip_embd, c_cond)
+ self.effnet_mappers = nn.ModuleList(
+ [
+ nn.Conv2d(effnet_embd, c_cond, kernel_size=1) if inject else None
+ for inject in inject_effnet + list(reversed(inject_effnet))
+ ]
+ )
+ self.seq_norm = nn.LayerNorm(c_cond, elementwise_affine=False, eps=1e-6)
+
+ self.embedding = nn.Sequential(
+ nn.PixelUnshuffle(patch_size),
+ nn.Conv2d(c_in * (patch_size**2), c_hidden[0], kernel_size=1),
+ WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6),
+ )
+
+ def get_block(block_type, c_hidden, nhead, c_skip=0, dropout=0):
+ if block_type == "C":
+ return ResBlockStageB(c_hidden, c_skip, kernel_size=kernel_size, dropout=dropout)
+ elif block_type == "A":
+ return AttnBlock(c_hidden, c_cond, nhead, self_attn=True, dropout=dropout)
+ elif block_type == "T":
+ return TimestepBlock(c_hidden, c_r)
+ else:
+ raise ValueError(f"Block type {block_type} not supported")
+
+ # BLOCKS
+ # -- down blocks
+ self.down_blocks = nn.ModuleList()
+ for i in range(len(c_hidden)):
+ down_block = nn.ModuleList()
+ if i > 0:
+ down_block.append(
+ nn.Sequential(
+ WuerstchenLayerNorm(c_hidden[i - 1], elementwise_affine=False, eps=1e-6),
+ nn.Conv2d(c_hidden[i - 1], c_hidden[i], kernel_size=2, stride=2),
+ )
+ )
+ for _ in range(blocks[i]):
+ for block_type in level_config[i]:
+ c_skip = c_cond if inject_effnet[i] else 0
+ down_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i]))
+ self.down_blocks.append(down_block)
+
+ # -- up blocks
+ self.up_blocks = nn.ModuleList()
+ for i in reversed(range(len(c_hidden))):
+ up_block = nn.ModuleList()
+ for j in range(blocks[i]):
+ for k, block_type in enumerate(level_config[i]):
+ c_skip = c_hidden[i] if i < len(c_hidden) - 1 and j == k == 0 else 0
+ c_skip += c_cond if inject_effnet[i] else 0
+ up_block.append(get_block(block_type, c_hidden[i], nhead[i], c_skip=c_skip, dropout=dropout[i]))
+ if i > 0:
+ up_block.append(
+ nn.Sequential(
+ WuerstchenLayerNorm(c_hidden[i], elementwise_affine=False, eps=1e-6),
+ nn.ConvTranspose2d(c_hidden[i], c_hidden[i - 1], kernel_size=2, stride=2),
+ )
+ )
+ self.up_blocks.append(up_block)
+
+ # OUTPUT
+ self.clf = nn.Sequential(
+ WuerstchenLayerNorm(c_hidden[0], elementwise_affine=False, eps=1e-6),
+ nn.Conv2d(c_hidden[0], 2 * c_out * (patch_size**2), kernel_size=1),
+ nn.PixelShuffle(patch_size),
+ )
+
+ # --- WEIGHT INIT ---
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ # General init
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ nn.init.xavier_uniform_(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ for mapper in self.effnet_mappers:
+ if mapper is not None:
+ nn.init.normal_(mapper.weight, std=0.02) # conditionings
+ nn.init.normal_(self.clip_mapper.weight, std=0.02) # conditionings
+ nn.init.xavier_uniform_(self.embedding[1].weight, 0.02) # inputs
+ nn.init.constant_(self.clf[1].weight, 0) # outputs
+
+ # blocks
+ for level_block in self.down_blocks + self.up_blocks:
+ for block in level_block:
+ if isinstance(block, ResBlockStageB):
+ block.channelwise[-1].weight.data *= np.sqrt(1 / sum(self.config.blocks))
+ elif isinstance(block, TimestepBlock):
+ nn.init.constant_(block.mapper.weight, 0)
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
+ return emb.to(dtype=r.dtype)
+
+ def gen_c_embeddings(self, clip):
+ clip = self.clip_mapper(clip)
+ clip = self.seq_norm(clip)
+ return clip
+
+ def _down_encode(self, x, r_embed, effnet, clip=None):
+ level_outputs = []
+ for i, down_block in enumerate(self.down_blocks):
+ effnet_c = None
+ for block in down_block:
+ if isinstance(block, ResBlockStageB):
+ if effnet_c is None and self.effnet_mappers[i] is not None:
+ dtype = effnet.dtype
+ effnet_c = self.effnet_mappers[i](
+ nn.functional.interpolate(
+ effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
+ ).to(dtype)
+ )
+ skip = effnet_c if self.effnet_mappers[i] is not None else None
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ level_outputs.insert(0, x)
+ return level_outputs
+
+ def _up_decode(self, level_outputs, r_embed, effnet, clip=None):
+ x = level_outputs[0]
+ for i, up_block in enumerate(self.up_blocks):
+ effnet_c = None
+ for j, block in enumerate(up_block):
+ if isinstance(block, ResBlockStageB):
+ if effnet_c is None and self.effnet_mappers[len(self.down_blocks) + i] is not None:
+ dtype = effnet.dtype
+ effnet_c = self.effnet_mappers[len(self.down_blocks) + i](
+ nn.functional.interpolate(
+ effnet.float(), size=x.shape[-2:], mode="bicubic", antialias=True, align_corners=True
+ ).to(dtype)
+ )
+ skip = level_outputs[i] if j == 0 and i > 0 else None
+ if effnet_c is not None:
+ if skip is not None:
+ skip = torch.cat([skip, effnet_c], dim=1)
+ else:
+ skip = effnet_c
+ x = block(x, skip)
+ elif isinstance(block, AttnBlock):
+ x = block(x, clip)
+ elif isinstance(block, TimestepBlock):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ return x
+
+ def forward(self, x, r, effnet, clip=None, x_cat=None, eps=1e-3, return_noise=True):
+ if x_cat is not None:
+ x = torch.cat([x, x_cat], dim=1)
+ # Process the conditioning embeddings
+ r_embed = self.gen_r_embedding(r)
+ if clip is not None:
+ clip = self.gen_c_embeddings(clip)
+
+ # Model Blocks
+ x_in = x
+ x = self.embedding(x)
+ level_outputs = self._down_encode(x, r_embed, effnet, clip)
+ x = self._up_decode(level_outputs, r_embed, effnet, clip)
+ a, b = self.clf(x).chunk(2, dim=1)
+ b = b.sigmoid() * (1 - eps * 2) + eps
+ if return_noise:
+ return (x_in - a) / b
+ else:
+ return a, b
+
+
+class ResBlockStageB(nn.Module):
+ def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0):
+ super().__init__()
+ self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
+ self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
+ self.channelwise = nn.Sequential(
+ nn.Linear(c + c_skip, c * 4),
+ nn.GELU(),
+ GlobalResponseNorm(c * 4),
+ nn.Dropout(dropout),
+ nn.Linear(c * 4, c),
+ )
+
+ def forward(self, x, x_skip=None):
+ x_res = x
+ x = self.norm(self.depthwise(x))
+ if x_skip is not None:
+ x = torch.cat([x, x_skip], dim=1)
+ x = self.channelwise(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ return x + x_res
diff --git a/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
new file mode 100644
index 000000000000..9bd29b59b3af
--- /dev/null
+++ b/src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -0,0 +1,72 @@
+# Copyright (c) 2023 Dominic Rampas MIT License
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.modeling_utils import ModelMixin
+from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
+
+
+class WuerstchenPrior(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
+ super().__init__()
+ self.c_r = c_r
+ self.projection = nn.Conv2d(c_in, c, kernel_size=1)
+ self.cond_mapper = nn.Sequential(
+ nn.Linear(c_cond, c),
+ nn.LeakyReLU(0.2),
+ nn.Linear(c, c),
+ )
+
+ self.blocks = nn.ModuleList()
+ for _ in range(depth):
+ self.blocks.append(ResBlock(c, dropout=dropout))
+ self.blocks.append(TimestepBlock(c, c_r))
+ self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
+ self.out = nn.Sequential(
+ WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
+ nn.Conv2d(c, c_in * 2, kernel_size=1),
+ )
+
+ def gen_r_embedding(self, r, max_positions=10000):
+ r = r * max_positions
+ half_dim = self.c_r // 2
+ emb = math.log(max_positions) / (half_dim - 1)
+ emb = torch.arange(half_dim, device=r.device).float().mul(-emb).exp()
+ emb = r[:, None] * emb[None, :]
+ emb = torch.cat([emb.sin(), emb.cos()], dim=1)
+ if self.c_r % 2 == 1: # zero pad
+ emb = nn.functional.pad(emb, (0, 1), mode="constant")
+ return emb.to(dtype=r.dtype)
+
+ def forward(self, x, r, c):
+ x_in = x
+ x = self.projection(x)
+ c_embed = self.cond_mapper(c)
+ r_embed = self.gen_r_embedding(r)
+ for block in self.blocks:
+ if isinstance(block, AttnBlock):
+ x = block(x, c_embed)
+ elif isinstance(block, TimestepBlock):
+ x = block(x, r_embed)
+ else:
+ x = block(x)
+ a, b = self.out(x).chunk(2, dim=1)
+ return (x_in - a) / ((1 - b).abs() + 1e-5)
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
new file mode 100644
index 000000000000..78aeebed7943
--- /dev/null
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -0,0 +1,399 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...schedulers import DDPMWuerstchenScheduler
+from ...utils import is_accelerate_available, is_accelerate_version, logging, randn_tensor, replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
+from .modeling_paella_vq_model import PaellaVQModel
+from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import WuerstchenPriorPipeline, WuerstchenDecoderPipeline
+
+ >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
+ ... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
+ ... ).to("cuda")
+ >>> gen_pipe = WuerstchenDecoderPipeline.from_pretrain(
+ ... "warp-diffusion/wuerstchen", torch_dtype=torch.float16
+ ... ).to("cuda")
+
+ >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
+ >>> prior_output = pipe(prompt)
+ >>> images = gen_pipe(prior_output.image_embeddings, prompt=prompt)
+ ```
+"""
+
+
+class WuerstchenDecoderPipeline(DiffusionPipeline):
+ """
+ Pipeline for generating images from the Wuerstchen model.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ tokenizer (`CLIPTokenizer`):
+ The CLIP tokenizer.
+ text_encoder (`CLIPTextModel`):
+ The CLIP text encoder.
+ decoder ([`WuerstchenDiffNeXt`]):
+ The WuerstchenDiffNeXt unet decoder.
+ vqgan ([`PaellaVQModel`]):
+ The VQGAN model.
+ scheduler ([`DDPMWuerstchenScheduler`]):
+ A scheduler to be used in combination with `prior` to generate image embedding.
+ latent_dim_scale (float, `optional`, defaults to 10.67):
+ Multiplier to determine the VQ latent space size from the image embeddings. If the image embeddings are
+ height=24 and width=24, the VQ latent shape needs to be height=int(24*10.67)=256 and
+ width=int(24*10.67)=256 in order to match the training conditions.
+ """
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModel,
+ decoder: WuerstchenDiffNeXt,
+ scheduler: DDPMWuerstchenScheduler,
+ vqgan: PaellaVQModel,
+ latent_dim_scale: float = 10.67,
+ ) -> None:
+ super().__init__()
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ decoder=decoder,
+ scheduler=scheduler,
+ vqgan=vqgan,
+ )
+ self.register_to_config(latent_dim_scale=latent_dim_scale)
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.decoder]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.prior_hook = hook
+
+ _, hook = cpu_offload_with_hook(self.vqgan, device, prev_module_hook=self.prior_hook)
+
+ self.final_offload_hook = hook
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ ):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+
+ uncond_text_encoder_hidden_states = None
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
+ )
+
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ # done duplicates
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ return text_encoder_hidden_states, uncond_text_encoder_hidden_states
+
+ def check_inputs(
+ self,
+ image_embeddings,
+ prompt,
+ negative_prompt,
+ num_inference_steps,
+ do_classifier_free_guidance,
+ device,
+ dtype,
+ ):
+ if not isinstance(prompt, list):
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ else:
+ raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
+
+ if do_classifier_free_guidance:
+ if negative_prompt is not None and not isinstance(negative_prompt, list):
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ else:
+ raise TypeError(
+ f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
+ )
+
+ if isinstance(image_embeddings, list):
+ image_embeddings = torch.cat(image_embeddings, dim=0)
+ if isinstance(image_embeddings, np.ndarray):
+ image_embeddings = torch.Tensor(image_embeddings, device=device).to(dtype=dtype)
+ if not isinstance(image_embeddings, torch.Tensor):
+ raise TypeError(
+ f"'image_embeddings' must be of type 'torch.Tensor' or 'np.array', but got {type(image_embeddings)}."
+ )
+
+ if not isinstance(num_inference_steps, int):
+ raise TypeError(
+ f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\
+ In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
+ )
+
+ return image_embeddings, prompt, negative_prompt, num_inference_steps
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
+ prompt: Union[str, List[str]] = None,
+ num_inference_steps: int = 12,
+ timesteps: Optional[List[float]] = None,
+ guidance_scale: float = 0.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: int = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ image_embedding (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+ Image Embeddings either extracted from an image or generated by a Prior Model.
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ num_inference_steps (`int`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
+ `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
+ linked to the text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `decoder_guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is a list with the generated image
+ embeddings.
+ """
+
+ # 0. Define commonly used variables
+ device = self._execution_device
+ dtype = self.decoder.dtype
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 1. Check inputs. Raise error if not correct
+ image_embeddings, prompt, negative_prompt, num_inference_steps = self.check_inputs(
+ image_embeddings, prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, device, dtype
+ )
+
+ # 2. Encode caption
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+ text_encoder_hidden_states = (
+ torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
+ )
+
+ # 3. Determine latent shape of latents
+ latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale)
+ latent_width = int(image_embeddings.size(3) * self.config.latent_dim_scale)
+ latent_features_shape = (image_embeddings.size(0) * num_images_per_prompt, 4, latent_height, latent_width)
+
+ # 4. Prepare and set timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latents
+ latents = self.prepare_latents(latent_features_shape, dtype, device, generator, latents, self.scheduler)
+
+ # 6. Run denoising loop
+ for t in self.progress_bar(timesteps[:-1]):
+ ratio = t.expand(latents.size(0)).to(dtype)
+ effnet = (
+ torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
+ if do_classifier_free_guidance
+ else image_embeddings
+ )
+ # 7. Denoise latents
+ predicted_latents = self.decoder(
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
+ r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio,
+ effnet=effnet,
+ clip=text_encoder_hidden_states,
+ )
+
+ # 8. Check for classifier free guidance and apply it
+ if do_classifier_free_guidance:
+ predicted_latents_text, predicted_latents_uncond = predicted_latents.chunk(2)
+ predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, guidance_scale)
+
+ # 9. Renoise latents to next timestep
+ latents = self.scheduler.step(
+ model_output=predicted_latents,
+ timestep=ratio,
+ sample=latents,
+ generator=generator,
+ ).prev_sample
+
+ # 10. Scale and decode the image latents with vq-vae
+ latents = self.vqgan.config.scale_factor * latents
+ images = self.vqgan.decode(latents).sample.clamp(0, 1)
+
+ if output_type not in ["pt", "np", "pil"]:
+ raise ValueError(f"Only the output types `pt`, `np` and `pil` are supported not output_type={output_type}")
+
+ if output_type == "np":
+ images = images.permute(0, 2, 3, 1).cpu().numpy()
+ elif output_type == "pil":
+ images = images.permute(0, 2, 3, 1).cpu().numpy()
+ images = self.numpy_to_pil(images)
+
+ if not return_dict:
+ return images
+ return ImagePipelineOutput(images)
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
new file mode 100644
index 000000000000..ff4c31686bf5
--- /dev/null
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
@@ -0,0 +1,248 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Callable, List, Optional, Union
+
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...schedulers import DDPMWuerstchenScheduler
+from ...utils import replace_example_docstring
+from ..pipeline_utils import DiffusionPipeline
+from .modeling_paella_vq_model import PaellaVQModel
+from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
+from .modeling_wuerstchen_prior import WuerstchenPrior
+from .pipeline_wuerstchen import WuerstchenDecoderPipeline
+from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
+
+
+TEXT2IMAGE_EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> from diffusions import WuerstchenCombinedPipeline
+
+ >>> pipe = WuerstchenCombinedPipeline.from_pretrained(
+ ... "warp-diffusion/Wuerstchen", torch_dtype=torch.float16
+ ... ).to("cuda")
+ >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
+ >>> images = pipe(prompt=prompt)
+ ```
+"""
+
+
+class WuerstchenCombinedPipeline(DiffusionPipeline):
+ """
+ Combined Pipeline for text-to-image generation using Wuerstchen
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ tokenizer (`CLIPTokenizer`):
+ The decoder tokenizer to be used for text inputs.
+ text_encoder (`CLIPTextModel`):
+ The decoder text encoder to be used for text inputs.
+ decoder (`WuerstchenDiffNeXt`):
+ The decoder model to be used for decoder image generation pipeline.
+ scheduler (`DDPMWuerstchenScheduler`):
+ The scheduler to be used for decoder image generation pipeline.
+ vqgan (`PaellaVQModel`):
+ The VQGAN model to be used for decoder image generation pipeline.
+ prior_tokenizer (`CLIPTokenizer`):
+ The prior tokenizer to be used for text inputs.
+ prior_text_encoder (`CLIPTextModel`):
+ The prior text encoder to be used for text inputs.
+ prior (`WuerstchenPrior`):
+ The prior model to be used for prior pipeline.
+ prior_scheduler (`DDPMWuerstchenScheduler`):
+ The scheduler to be used for prior pipeline.
+ """
+
+ _load_connected_pipes = True
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModel,
+ decoder: WuerstchenDiffNeXt,
+ scheduler: DDPMWuerstchenScheduler,
+ vqgan: PaellaVQModel,
+ prior_tokenizer: CLIPTokenizer,
+ prior_text_encoder: CLIPTextModel,
+ prior_prior: WuerstchenPrior,
+ prior_scheduler: DDPMWuerstchenScheduler,
+ ):
+ super().__init__()
+
+ self.register_modules(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ decoder=decoder,
+ scheduler=scheduler,
+ vqgan=vqgan,
+ prior_prior=prior_prior,
+ prior_text_encoder=prior_text_encoder,
+ prior_tokenizer=prior_tokenizer,
+ prior_scheduler=prior_scheduler,
+ )
+ self.prior_pipe = WuerstchenPriorPipeline(
+ prior=prior_prior,
+ text_encoder=prior_text_encoder,
+ tokenizer=prior_tokenizer,
+ scheduler=prior_scheduler,
+ )
+ self.decoder_pipe = WuerstchenDecoderPipeline(
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ decoder=decoder,
+ scheduler=scheduler,
+ vqgan=vqgan,
+ )
+
+ def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
+ self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ self.prior_pipe.enable_model_cpu_offload()
+ self.decoder_pipe.enable_model_cpu_offload()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
+ Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
+ GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
+ Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
+ """
+ self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+ self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+
+ def progress_bar(self, iterable=None, total=None):
+ self.prior_pipe.progress_bar(iterable=iterable, total=total)
+ self.decoder_pipe.progress_bar(iterable=iterable, total=total)
+
+ def set_progress_bar_config(self, **kwargs):
+ self.prior_pipe.set_progress_bar_config(**kwargs)
+ self.decoder_pipe.set_progress_bar_config(**kwargs)
+
+ @torch.no_grad()
+ @replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]],
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ guidance_scale: float = 4.0,
+ num_images_per_prompt: int = 1,
+ height: int = 512,
+ width: int = 512,
+ prior_guidance_scale: float = 4.0,
+ prior_num_inference_steps: int = 60,
+ num_inference_steps: int = 12,
+ prior_timesteps: Optional[List[float]] = None,
+ timesteps: Optional[List[float]] = None,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ prior_guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
+ `prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
+ to the text `prompt`, usually at the expense of lower image quality.
+ prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. For more specific timestep spacing, you can pass customized
+ `prior_timesteps`
+ num_inference_steps (`int`, *optional*, defaults to 12):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. For more specific timestep spacing, you can pass customized `timesteps`
+ prior_timesteps (`List[float]`, *optional*):
+ Custom timesteps to use for the denoising process for the prior. If not defined, equal spaced
+ `prior_num_inference_steps` timesteps are used. Must be in descending order.
+ timesteps (`List[float]`, *optional*):
+ Custom timesteps to use for the denoising process for the decoder. If not defined, equal spaced
+ `decoder_num_inference_steps` timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
+ """
+ prior_outputs = self.prior_pipe(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ width=width,
+ height=height,
+ num_images_per_prompt=num_images_per_prompt,
+ num_inference_steps=prior_num_inference_steps,
+ timesteps=prior_timesteps,
+ generator=generator,
+ latents=latents,
+ guidance_scale=prior_guidance_scale,
+ output_type="pt",
+ return_dict=False,
+ )
+ image_embeddings = prior_outputs[0]
+
+ outputs = self.decoder_pipe(
+ prompt=prompt,
+ image_embeddings=image_embeddings,
+ num_inference_steps=num_inference_steps,
+ timesteps=timesteps,
+ generator=generator,
+ guidance_scale=guidance_scale,
+ output_type=output_type,
+ return_dict=return_dict,
+ )
+ return outputs
diff --git a/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
new file mode 100644
index 000000000000..8b13d8fdf2b7
--- /dev/null
+++ b/src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
@@ -0,0 +1,402 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from math import ceil
+from typing import List, Optional, Union
+
+import numpy as np
+import torch
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from ...schedulers import DDPMWuerstchenScheduler
+from ...utils import (
+ BaseOutput,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ randn_tensor,
+ replace_example_docstring,
+)
+from ..pipeline_utils import DiffusionPipeline
+from .modeling_wuerstchen_prior import WuerstchenPrior
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> import torch
+ >>> from diffusers import WuerstchenPriorPipeline
+
+ >>> prior_pipe = WuerstchenPriorPipeline.from_pretrained(
+ ... "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16
+ ... ).to("cuda")
+
+ >>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
+ >>> prior_output = pipe(prompt)
+ ```
+"""
+
+
+@dataclass
+class WuerstchenPriorPipelineOutput(BaseOutput):
+ """
+ Output class for WuerstchenPriorPipeline.
+
+ Args:
+ image_embeddings (`torch.FloatTensor` or `np.ndarray`)
+ Prior image embeddings for text prompt
+
+ """
+
+ image_embeddings: Union[torch.FloatTensor, np.ndarray]
+
+
+class WuerstchenPriorPipeline(DiffusionPipeline):
+ """
+ Pipeline for generating image prior for Wuerstchen.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ prior ([`Prior`]):
+ The canonical unCLIP prior to approximate the image embedding from the text embedding.
+ text_encoder ([`CLIPTextModelWithProjection`]):
+ Frozen text-encoder.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+ scheduler ([`DDPMWuerstchenScheduler`]):
+ A scheduler to be used in combination with `prior` to generate image embedding.
+ """
+
+ def __init__(
+ self,
+ tokenizer: CLIPTokenizer,
+ text_encoder: CLIPTextModel,
+ prior: WuerstchenPrior,
+ scheduler: DDPMWuerstchenScheduler,
+ latent_mean: float = 42.0,
+ latent_std: float = 1.0,
+ resolution_multiple: float = 42.67,
+ ) -> None:
+ super().__init__()
+ self.register_modules(
+ tokenizer=tokenizer,
+ text_encoder=text_encoder,
+ prior=prior,
+ scheduler=scheduler,
+ )
+ self.register_to_config(
+ latent_mean=latent_mean, latent_std=latent_std, resolution_multiple=resolution_multiple
+ )
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.prior_hook = hook
+
+ _, hook = cpu_offload_with_hook(self.prior, device, prev_module_hook=self.prior_hook)
+
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.unclip.pipeline_unclip.UnCLIPPipeline.prepare_latents
+ def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ latents = latents * scheduler.init_noise_sigma
+ return latents
+
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ ):
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+ # get prompt text embeddings
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ attention_mask = text_inputs.attention_mask
+
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+ text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
+ attention_mask = attention_mask[:, : self.tokenizer.model_max_length]
+
+ text_encoder_output = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask.to(device))
+ text_encoder_hidden_states = text_encoder_output.last_hidden_state
+ text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+
+ uncond_text_encoder_hidden_states = None
+ if do_classifier_free_guidance:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ negative_prompt_embeds_text_encoder_output = self.text_encoder(
+ uncond_input.input_ids.to(device), attention_mask=uncond_input.attention_mask.to(device)
+ )
+
+ uncond_text_encoder_hidden_states = negative_prompt_embeds_text_encoder_output.last_hidden_state
+
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = uncond_text_encoder_hidden_states.shape[1]
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
+ uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+ # done duplicates
+
+ return text_encoder_hidden_states, uncond_text_encoder_hidden_states
+
+ def check_inputs(
+ self,
+ prompt,
+ negative_prompt,
+ num_inference_steps,
+ do_classifier_free_guidance,
+ batch_size,
+ ):
+ if not isinstance(prompt, list):
+ if isinstance(prompt, str):
+ prompt = [prompt]
+ else:
+ raise TypeError(f"'prompt' must be of type 'list' or 'str', but got {type(prompt)}.")
+
+ if do_classifier_free_guidance:
+ if negative_prompt is not None and not isinstance(negative_prompt, list):
+ if isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt]
+ else:
+ raise TypeError(
+ f"'negative_prompt' must be of type 'list' or 'str', but got {type(negative_prompt)}."
+ )
+
+ if not isinstance(num_inference_steps, int):
+ raise TypeError(
+ f"'num_inference_steps' must be of type 'int', but got {type(num_inference_steps)}\
+ In Case you want to provide explicit timesteps, please use the 'timesteps' argument."
+ )
+
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ return prompt, negative_prompt, num_inference_steps, batch_size
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: int = 1024,
+ width: int = 1024,
+ num_inference_steps: int = 30,
+ timesteps: List[float] = None,
+ guidance_scale: float = 8.0,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pt",
+ return_dict: bool = True,
+ ):
+ """
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`):
+ The prompt or prompts to guide the image generation.
+ height (`int`, *optional*, defaults to 512):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to 512):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 30):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps`
+ timesteps are used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 4.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `decoder_guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
+ `decoder_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely
+ linked to the text `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `decoder_guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will ge generated by sampling using the supplied random `generator`.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
+ (`np.array`) or `"pt"` (`torch.Tensor`).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
+
+ Examples:
+
+ Returns:
+ [`~pipelines.WuerstchenPriorPipelineOutput`] or `tuple` [`~pipelines.WuerstchenPriorPipelineOutput`] if
+ `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
+ generated image embeddings.
+ """
+
+ # 0. Define commonly used variables
+ device = self._execution_device
+ do_classifier_free_guidance = guidance_scale > 1.0
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
+
+ # 1. Check inputs. Raise error if not correct
+ prompt, negative_prompt, num_inference_steps, batch_size = self.check_inputs(
+ prompt, negative_prompt, num_inference_steps, do_classifier_free_guidance, batch_size
+ )
+
+ # 2. Encode caption
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ text_encoder_hidden_states = (
+ torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
+ )
+
+ # 3. Determine latent shape of image embeddings
+ dtype = text_encoder_hidden_states.dtype
+ latent_height = ceil(height / self.config.resolution_multiple)
+ latent_width = ceil(width / self.config.resolution_multiple)
+ num_channels = self.prior.config.c_in
+ effnet_features_shape = (num_images_per_prompt * batch_size, num_channels, latent_height, latent_width)
+
+ # 4. Prepare and set timesteps
+ if timesteps is not None:
+ self.scheduler.set_timesteps(timesteps=timesteps, device=device)
+ timesteps = self.scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latents
+ latents = self.prepare_latents(effnet_features_shape, dtype, device, generator, latents, self.scheduler)
+
+ # 6. Run denoising loop
+ for t in self.progress_bar(timesteps[:-1]):
+ ratio = t.expand(latents.size(0)).to(dtype)
+
+ # 7. Denoise image embeddings
+ predicted_image_embedding = self.prior(
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents,
+ r=torch.cat([ratio] * 2) if do_classifier_free_guidance else ratio,
+ c=text_encoder_hidden_states,
+ )
+
+ # 8. Check for classifier free guidance and apply it
+ if do_classifier_free_guidance:
+ predicted_image_embedding_text, predicted_image_embedding_uncond = predicted_image_embedding.chunk(2)
+ predicted_image_embedding = torch.lerp(
+ predicted_image_embedding_uncond, predicted_image_embedding_text, guidance_scale
+ )
+
+ # 9. Renoise latents to next timestep
+ latents = self.scheduler.step(
+ model_output=predicted_image_embedding,
+ timestep=ratio,
+ sample=latents,
+ generator=generator,
+ ).prev_sample
+
+ # 10. Denormalize the latents
+ latents = latents * self.config.latent_mean - self.config.latent_std
+
+ if output_type == "np":
+ latents = latents.cpu().numpy()
+
+ if not return_dict:
+ return (latents,)
+
+ return WuerstchenPriorPipelineOutput(latents)
diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py
index 0a07ce4baed2..84df4ffb84db 100644
--- a/src/diffusers/schedulers/__init__.py
+++ b/src/diffusers/schedulers/__init__.py
@@ -34,6 +34,7 @@
from .scheduling_ddim_parallel import DDIMParallelScheduler
from .scheduling_ddpm import DDPMScheduler
from .scheduling_ddpm_parallel import DDPMParallelScheduler
+ from .scheduling_ddpm_wuerstchen import DDPMWuerstchenScheduler
from .scheduling_deis_multistep import DEISMultistepScheduler
from .scheduling_dpmsolver_multistep import DPMSolverMultistepScheduler
from .scheduling_dpmsolver_multistep_inverse import DPMSolverMultistepInverseScheduler
diff --git a/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py
new file mode 100644
index 000000000000..28311fc03301
--- /dev/null
+++ b/src/diffusers/schedulers/scheduling_ddpm_wuerstchen.py
@@ -0,0 +1,238 @@
+# Copyright (c) 2022 Pablo Pernías MIT License
+# Copyright 2023 UC Berkeley Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..utils import BaseOutput, randn_tensor
+from .scheduling_utils import SchedulerMixin
+
+
+@dataclass
+class DDPMWuerstchenSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ """
+
+ prev_sample: torch.FloatTensor
+
+
+def betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ max_beta=0.999,
+ alpha_transform_type="cosine",
+):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar.
+ Choose from `cosine` or `exp`
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+ if alpha_transform_type == "cosine":
+
+ def alpha_bar_fn(t):
+ return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ elif alpha_transform_type == "exp":
+
+ def alpha_bar_fn(t):
+ return math.exp(t * -12.0)
+
+ else:
+ raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion probabilistic models (DDPMs) explores the connections between denoising score matching and
+ Langevin dynamics sampling.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2006.11239
+
+ Args:
+ scaler (`float`): ....
+ s (`float`): ....
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ scaler: float = 1.0,
+ s: float = 0.008,
+ ):
+ self.scaler = scaler
+ self.s = torch.tensor([s])
+ self._init_alpha_cumprod = torch.cos(self.s / (1 + self.s) * torch.pi * 0.5) ** 2
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ def _alpha_cumprod(self, t, device):
+ if self.scaler > 1:
+ t = 1 - (1 - t) ** self.scaler
+ elif self.scaler < 1:
+ t = t**self.scaler
+ alpha_cumprod = torch.cos(
+ (t + self.s.to(device)) / (1 + self.s.to(device)) * torch.pi * 0.5
+ ) ** 2 / self._init_alpha_cumprod.to(device)
+ return alpha_cumprod.clamp(0.0001, 0.9999)
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def set_timesteps(
+ self,
+ num_inference_steps: int = None,
+ timesteps: Optional[List[int]] = None,
+ device: Union[str, torch.device] = None,
+ ):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`Dict[float, int]`):
+ the number of diffusion steps used when generating samples with a pre-trained model. If passed, then
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, optional):
+ the device to which the timesteps are moved to. {2 / 3: 20, 0.0: 10}
+ """
+ if timesteps is None:
+ timesteps = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device)
+ if not isinstance(timesteps, torch.Tensor):
+ timesteps = torch.Tensor(timesteps).to(device)
+ self.timesteps = timesteps
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ generator=None,
+ return_dict: bool = True,
+ ) -> Union[DDPMWuerstchenSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ generator: random number generator.
+ return_dict (`bool`): option for returning tuple rather than DDPMWuerstchenSchedulerOutput class
+
+ Returns:
+ [`DDPMWuerstchenSchedulerOutput`] or `tuple`: [`DDPMWuerstchenSchedulerOutput`] if `return_dict` is True,
+ otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
+
+ """
+ dtype = model_output.dtype
+ device = model_output.device
+ t = timestep
+
+ prev_t = self.previous_timestep(t)
+
+ alpha_cumprod = self._alpha_cumprod(t, device).view(t.size(0), *[1 for _ in sample.shape[1:]])
+ alpha_cumprod_prev = self._alpha_cumprod(prev_t, device).view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
+ alpha = alpha_cumprod / alpha_cumprod_prev
+
+ mu = (1.0 / alpha).sqrt() * (sample - (1 - alpha) * model_output / (1 - alpha_cumprod).sqrt())
+
+ std_noise = randn_tensor(mu.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
+ std = ((1 - alpha) * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)).sqrt() * std_noise
+ pred = mu + std * (prev_t != 0).float().view(prev_t.size(0), *[1 for _ in sample.shape[1:]])
+
+ if not return_dict:
+ return (pred.to(dtype),)
+
+ return DDPMWuerstchenSchedulerOutput(prev_sample=pred.to(dtype))
+
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ def __len__(self):
+ return self.config.num_train_timesteps
+
+ def previous_timestep(self, timestep):
+ index = (self.timesteps - timestep[0]).abs().argmin().item()
+ prev_t = self.timesteps[index + 1][None].expand(timestep.shape[0])
+ return prev_t
diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py
index 4debac8b031a..0c7b3117fa47 100644
--- a/src/diffusers/utils/dummy_pt_objects.py
+++ b/src/diffusers/utils/dummy_pt_objects.py
@@ -615,6 +615,21 @@ def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
+class DDPMWuerstchenScheduler(metaclass=DummyObject):
+ _backends = ["torch"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch"])
+
+
class DEISMultistepScheduler(metaclass=DummyObject):
_backends = ["torch"]
diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
index 42a860a16483..5a123c1cd1ee 100644
--- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py
+++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py
@@ -1260,3 +1260,48 @@ def from_config(cls, *args, **kwargs):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
+
+
+class WuerstchenCombinedPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WuerstchenDecoderPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+
+class WuerstchenPriorPipeline(metaclass=DummyObject):
+ _backends = ["torch", "transformers"]
+
+ def __init__(self, *args, **kwargs):
+ requires_backends(self, ["torch", "transformers"])
+
+ @classmethod
+ def from_config(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
+
+ @classmethod
+ def from_pretrained(cls, *args, **kwargs):
+ requires_backends(cls, ["torch", "transformers"])
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 319dcb5aab32..a6f828443cb0 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -535,6 +535,8 @@ def test_dict_tuple_outputs_equivalent(self, expected_max_difference=1e-4):
def test_components_function(self):
init_components = self.get_dummy_components()
+ init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
+
pipe = self.pipeline_class(**init_components)
self.assertTrue(hasattr(pipe, "components"))
diff --git a/tests/pipelines/wuerstchen/__init__.py b/tests/pipelines/wuerstchen/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_combined.py b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
new file mode 100644
index 000000000000..7d2e98030b30
--- /dev/null
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_combined.py
@@ -0,0 +1,234 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DDPMWuerstchenScheduler, WuerstchenCombinedPipeline
+from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior
+from diffusers.utils import torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WuerstchenCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WuerstchenCombinedPipeline
+ params = ["prompt"]
+ batch_params = ["prompt", "negative_prompt"]
+ required_optional_params = [
+ "generator",
+ "height",
+ "width",
+ "latents",
+ "guidance_scale",
+ "negative_prompt",
+ "num_inference_steps",
+ "return_dict",
+ "prior_num_inference_steps",
+ "output_type",
+ "return_dict",
+ ]
+ test_xformers_attention = True
+
+ @property
+ def text_embedder_hidden_size(self):
+ return 32
+
+ @property
+ def dummy_prior(self):
+ torch.manual_seed(0)
+
+ model_kwargs = {"c_in": 2, "c": 8, "depth": 2, "c_cond": 32, "c_r": 8, "nhead": 2}
+ model = WuerstchenPrior(**model_kwargs)
+ return model.eval()
+
+ @property
+ def dummy_tokenizer(self):
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ return tokenizer
+
+ @property
+ def dummy_prior_text_encoder(self):
+ torch.manual_seed(0)
+ config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=self.text_embedder_hidden_size,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ return CLIPTextModel(config).eval()
+
+ @property
+ def dummy_text_encoder(self):
+ torch.manual_seed(0)
+ config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ projection_dim=self.text_embedder_hidden_size,
+ hidden_size=self.text_embedder_hidden_size,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ return CLIPTextModel(config).eval()
+
+ @property
+ def dummy_vqgan(self):
+ torch.manual_seed(0)
+
+ model_kwargs = {
+ "bottleneck_blocks": 1,
+ "num_vq_embeddings": 2,
+ }
+ model = PaellaVQModel(**model_kwargs)
+ return model.eval()
+
+ @property
+ def dummy_decoder(self):
+ torch.manual_seed(0)
+
+ model_kwargs = {
+ "c_cond": self.text_embedder_hidden_size,
+ "c_hidden": [320],
+ "nhead": [-1],
+ "blocks": [4],
+ "level_config": ["CT"],
+ "clip_embd": self.text_embedder_hidden_size,
+ "inject_effnet": [False],
+ }
+
+ model = WuerstchenDiffNeXt(**model_kwargs)
+ return model.eval()
+
+ def get_dummy_components(self):
+ prior = self.dummy_prior
+ prior_text_encoder = self.dummy_prior_text_encoder
+
+ scheduler = DDPMWuerstchenScheduler()
+ tokenizer = self.dummy_tokenizer
+
+ text_encoder = self.dummy_text_encoder
+ decoder = self.dummy_decoder
+ vqgan = self.dummy_vqgan
+
+ components = {
+ "tokenizer": tokenizer,
+ "text_encoder": text_encoder,
+ "decoder": decoder,
+ "vqgan": vqgan,
+ "scheduler": scheduler,
+ "prior_prior": prior,
+ "prior_text_encoder": prior_text_encoder,
+ "prior_tokenizer": tokenizer,
+ "prior_scheduler": scheduler,
+ }
+
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "horse",
+ "generator": generator,
+ "prior_guidance_scale": 4.0,
+ "guidance_scale": 4.0,
+ "num_inference_steps": 2,
+ "prior_num_inference_steps": 2,
+ "output_type": "np",
+ "height": 128,
+ "width": 128,
+ }
+ return inputs
+
+ def test_wuerstchen(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+
+ pipe.set_progress_bar_config(disable=None)
+
+ output = pipe(**self.get_dummy_inputs(device))
+ image = output.images
+
+ image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
+
+ assert image.shape == (1, 128, 128, 3)
+
+ expected_slice = np.array([0.7616304, 0.0, 1.0, 0.0, 1.0, 0.0, 0.05925313, 0.0, 0.951898])
+
+ assert (
+ np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
+ assert (
+ np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+ ), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
+
+ @require_torch_gpu
+ def test_offloads(self):
+ pipes = []
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components).to(torch_device)
+ pipes.append(sd_pipe)
+
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.enable_sequential_cpu_offload()
+ pipes.append(sd_pipe)
+
+ components = self.get_dummy_components()
+ sd_pipe = self.pipeline_class(**components)
+ sd_pipe.enable_model_cpu_offload()
+ pipes.append(sd_pipe)
+
+ image_slices = []
+ for pipe in pipes:
+ inputs = self.get_dummy_inputs(torch_device)
+ image = pipe(**inputs).images
+
+ image_slices.append(image[0, -3:, -3:, -1].flatten())
+
+ assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
+ assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
+
+ def test_inference_batch_single_identical(self):
+ super().test_inference_batch_single_identical(expected_max_diff=1e-2)
+
+ @unittest.skip(reason="flakey and float16 requires CUDA")
+ def test_float16_inference(self):
+ super().test_float16_inference()
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
new file mode 100644
index 000000000000..709e2c1a3436
--- /dev/null
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_decoder.py
@@ -0,0 +1,196 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DDPMWuerstchenScheduler, WuerstchenDecoderPipeline
+from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt
+from diffusers.utils import torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, skip_mps
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WuerstchenDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WuerstchenDecoderPipeline
+ params = ["prompt"]
+ batch_params = ["image_embeddings", "prompt", "negative_prompt"]
+ required_optional_params = [
+ "num_images_per_prompt",
+ "num_inference_steps",
+ "latents",
+ "negative_prompt",
+ "guidance_scale",
+ "output_type",
+ "return_dict",
+ ]
+ test_xformers_attention = False
+
+ @property
+ def text_embedder_hidden_size(self):
+ return 32
+
+ @property
+ def time_input_dim(self):
+ return 32
+
+ @property
+ def block_out_channels_0(self):
+ return self.time_input_dim
+
+ @property
+ def time_embed_dim(self):
+ return self.time_input_dim * 4
+
+ @property
+ def dummy_tokenizer(self):
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ return tokenizer
+
+ @property
+ def dummy_text_encoder(self):
+ torch.manual_seed(0)
+ config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ projection_dim=self.text_embedder_hidden_size,
+ hidden_size=self.text_embedder_hidden_size,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ return CLIPTextModel(config).eval()
+
+ @property
+ def dummy_vqgan(self):
+ torch.manual_seed(0)
+
+ model_kwargs = {
+ "bottleneck_blocks": 1,
+ "num_vq_embeddings": 2,
+ }
+ model = PaellaVQModel(**model_kwargs)
+ return model.eval()
+
+ @property
+ def dummy_decoder(self):
+ torch.manual_seed(0)
+
+ model_kwargs = {
+ "c_cond": self.text_embedder_hidden_size,
+ "c_hidden": [320],
+ "nhead": [-1],
+ "blocks": [4],
+ "level_config": ["CT"],
+ "clip_embd": self.text_embedder_hidden_size,
+ "inject_effnet": [False],
+ }
+
+ model = WuerstchenDiffNeXt(**model_kwargs)
+ return model.eval()
+
+ def get_dummy_components(self):
+ decoder = self.dummy_decoder
+ text_encoder = self.dummy_text_encoder
+ tokenizer = self.dummy_tokenizer
+ vqgan = self.dummy_vqgan
+
+ scheduler = DDPMWuerstchenScheduler()
+
+ components = {
+ "decoder": decoder,
+ "vqgan": vqgan,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "scheduler": scheduler,
+ "latent_dim_scale": 4.0,
+ }
+
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "image_embeddings": torch.ones((1, 4, 4, 4), device=device),
+ "prompt": "horse",
+ "generator": generator,
+ "guidance_scale": 1.0,
+ "num_inference_steps": 2,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_wuerstchen_decoder(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+
+ pipe.set_progress_bar_config(disable=None)
+
+ output = pipe(**self.get_dummy_inputs(device))
+ image = output.images
+
+ image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)
+
+ image_slice = image[0, -3:, -3:, -1]
+ image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
+
+ assert image.shape == (1, 64, 64, 3)
+
+ expected_slice = np.array([0.0000, 0.0000, 0.0089, 1.0000, 1.0000, 0.3927, 1.0000, 1.0000, 1.0000])
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @skip_mps
+ def test_inference_batch_single_identical(self):
+ test_max_difference = torch_device == "cpu"
+ relax_max_difference = True
+ test_mean_pixel_difference = False
+
+ self._test_inference_batch_single_identical(
+ test_max_difference=test_max_difference,
+ relax_max_difference=relax_max_difference,
+ test_mean_pixel_difference=test_mean_pixel_difference,
+ )
+
+ @skip_mps
+ def test_attention_slicing_forward_pass(self):
+ test_max_difference = torch_device == "cpu"
+ test_mean_pixel_difference = False
+
+ self._test_attention_slicing_forward_pass(
+ test_max_difference=test_max_difference,
+ test_mean_pixel_difference=test_mean_pixel_difference,
+ )
+
+ @unittest.skip(reason="bf16 not supported and requires CUDA")
+ def test_float16_inference(self):
+ super().test_float16_inference()
diff --git a/tests/pipelines/wuerstchen/test_wuerstchen_prior.py b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
new file mode 100644
index 000000000000..a255a665c48e
--- /dev/null
+++ b/tests/pipelines/wuerstchen/test_wuerstchen_prior.py
@@ -0,0 +1,194 @@
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import torch
+from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+
+from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
+from diffusers.pipelines.wuerstchen import WuerstchenPrior
+from diffusers.utils import torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, skip_mps
+
+from ..test_pipelines_common import PipelineTesterMixin
+
+
+enable_full_determinism()
+
+
+class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+ pipeline_class = WuerstchenPriorPipeline
+ params = ["prompt"]
+ batch_params = ["prompt", "negative_prompt"]
+ required_optional_params = [
+ "num_images_per_prompt",
+ "generator",
+ "num_inference_steps",
+ "latents",
+ "negative_prompt",
+ "guidance_scale",
+ "output_type",
+ "return_dict",
+ ]
+ test_xformers_attention = False
+
+ @property
+ def text_embedder_hidden_size(self):
+ return 32
+
+ @property
+ def time_input_dim(self):
+ return 32
+
+ @property
+ def block_out_channels_0(self):
+ return self.time_input_dim
+
+ @property
+ def time_embed_dim(self):
+ return self.time_input_dim * 4
+
+ @property
+ def dummy_tokenizer(self):
+ tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+ return tokenizer
+
+ @property
+ def dummy_text_encoder(self):
+ torch.manual_seed(0)
+ config = CLIPTextConfig(
+ bos_token_id=0,
+ eos_token_id=2,
+ hidden_size=self.text_embedder_hidden_size,
+ intermediate_size=37,
+ layer_norm_eps=1e-05,
+ num_attention_heads=4,
+ num_hidden_layers=5,
+ pad_token_id=1,
+ vocab_size=1000,
+ )
+ return CLIPTextModel(config).eval()
+
+ @property
+ def dummy_prior(self):
+ torch.manual_seed(0)
+
+ model_kwargs = {
+ "c_in": 2,
+ "c": 8,
+ "depth": 2,
+ "c_cond": 32,
+ "c_r": 8,
+ "nhead": 2,
+ }
+
+ model = WuerstchenPrior(**model_kwargs)
+ return model.eval()
+
+ def get_dummy_components(self):
+ prior = self.dummy_prior
+ text_encoder = self.dummy_text_encoder
+ tokenizer = self.dummy_tokenizer
+
+ scheduler = DDPMWuerstchenScheduler()
+
+ components = {
+ "prior": prior,
+ "text_encoder": text_encoder,
+ "tokenizer": tokenizer,
+ "scheduler": scheduler,
+ }
+
+ return components
+
+ def get_dummy_inputs(self, device, seed=0):
+ if str(device).startswith("mps"):
+ generator = torch.manual_seed(seed)
+ else:
+ generator = torch.Generator(device=device).manual_seed(seed)
+ inputs = {
+ "prompt": "horse",
+ "generator": generator,
+ "guidance_scale": 4.0,
+ "num_inference_steps": 2,
+ "output_type": "np",
+ }
+ return inputs
+
+ def test_wuerstchen_prior(self):
+ device = "cpu"
+
+ components = self.get_dummy_components()
+
+ pipe = self.pipeline_class(**components)
+ pipe = pipe.to(device)
+
+ pipe.set_progress_bar_config(disable=None)
+
+ output = pipe(**self.get_dummy_inputs(device))
+ image = output.image_embeddings
+
+ image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
+
+ image_slice = image[0, 0, 0, -10:]
+ image_from_tuple_slice = image_from_tuple[0, 0, 0, -10:]
+
+ assert image.shape == (1, 2, 24, 24)
+
+ expected_slice = np.array(
+ [
+ -7172.837,
+ -3438.855,
+ -1093.312,
+ 388.8835,
+ -7471.467,
+ -7998.1206,
+ -5328.259,
+ 218.00089,
+ -2731.5745,
+ -8056.734,
+ ],
+ )
+ assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+ assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
+
+ @skip_mps
+ def test_inference_batch_single_identical(self):
+ test_max_difference = torch_device == "cpu"
+ relax_max_difference = True
+ test_mean_pixel_difference = False
+
+ self._test_inference_batch_single_identical(
+ test_max_difference=test_max_difference,
+ relax_max_difference=relax_max_difference,
+ test_mean_pixel_difference=test_mean_pixel_difference,
+ expected_max_diff=1e-1,
+ )
+
+ @skip_mps
+ def test_attention_slicing_forward_pass(self):
+ test_max_difference = torch_device == "cpu"
+ test_mean_pixel_difference = False
+
+ self._test_attention_slicing_forward_pass(
+ test_max_difference=test_max_difference,
+ test_mean_pixel_difference=test_mean_pixel_difference,
+ )
+
+ @unittest.skip(reason="flaky for now")
+ def test_float16_inference(self):
+ super().test_float16_inference()