Würstchen model #3849

Merged on Sep 6, 2023 · 216 commits
Changes from 180 commits

Commits (216)
119a451
initial
kashif Jun 7, 2023
0623199
initial
kashif Jun 7, 2023
8a6a92c
added initial convert script for paella vqmodel
kashif Jun 21, 2023
80713a4
initial wuerstchen pipeline
kashif Jun 22, 2023
8bd6cb8
add LayerNorm2d
kashif Jun 22, 2023
806ed12
added modules
kashif Jun 22, 2023
ff6139d
fix typo
kashif Jun 22, 2023
560da3b
use model_v2
kashif Jun 23, 2023
3acc9fa
embed clip caption amd negative_caption
kashif Jun 23, 2023
f84ac09
fixed name of var
kashif Jun 23, 2023
25de2c6
initial modules in one place
kashif Jun 23, 2023
30e41a5
WuerstchenPriorPipeline
kashif Jun 23, 2023
d563218
inital shape
kashif Jun 23, 2023
d328459
initial denoising prior loop
kashif Jun 23, 2023
f0cc379
fix output
kashif Jun 23, 2023
4c8a791
add WuerstchenPriorPipeline to __init__.py
kashif Jun 23, 2023
ad474b1
use the noise ratio in the Prior
kashif Jun 23, 2023
a79a9ad
try to save pipeline
kashif Jun 23, 2023
4c28f9c
save_pretrained working
kashif Jun 23, 2023
6e51d7e
Few additions
dome272 Jun 24, 2023
623c1e4
Merge branch 'main' into paella
kashif Jun 25, 2023
cd5ad04
add _execution_device
kashif Jun 26, 2023
92c46df
shape is int
kashif Jun 26, 2023
4665e48
fix batch size
kashif Jun 26, 2023
58c98b1
fix shape of ratio
kashif Jun 26, 2023
66cff25
fix shape of ratio
kashif Jun 27, 2023
d06276d
fix output dataclass
kashif Jun 27, 2023
95eb11e
tests folder
kashif Jun 27, 2023
2976dd8
Merge branch 'main' into paella
kashif Jun 27, 2023
896624e
fix formatting
kashif Jun 27, 2023
1ed7a58
fix float16 + started with generator
dome272 Jun 28, 2023
e809fd7
Update pipeline_wuerstchen.py
dome272 Jun 29, 2023
0ad3f79
removed vqgan code
kashif Jun 30, 2023
2edfc48
add WuerstchenGeneratorPipeline
kashif Jun 30, 2023
624c6d9
fix WuerstchenGeneratorPipeline
kashif Jun 30, 2023
0d3c3f3
fix docstrings
kashif Jun 30, 2023
38fa6d1
fix imports
kashif Jun 30, 2023
ea2c64e
convert generator pipeline
kashif Jun 30, 2023
96cb4de
fix convert
kashif Jun 30, 2023
2c6d0dd
Work on Generator Pipeline. WIP
dome272 Jul 1, 2023
3281798
Pipeline works with our diffuzz code
dome272 Jul 3, 2023
59667d9
apply scale factor
kashif Jul 4, 2023
156901f
removed vqgan.py
kashif Jul 4, 2023
ddc9daa
use cosine schedule
kashif Jul 4, 2023
30e4888
redo the denoising loop
kashif Jul 9, 2023
fe972d6
Update src/diffusers/models/resnet.py
kashif Jul 14, 2023
db5dd65
use torch.lerp
kashif Jul 14, 2023
44d6a04
use warp-diffusion org
kashif Jul 14, 2023
180bbae
clip_sample=False,
kashif Jul 14, 2023
715e999
solve merge conflict
patrickvonplaten Jul 19, 2023
368e113
some refactoring
patrickvonplaten Jul 19, 2023
0c0bedc
use model_v3_stage_c
kashif Jul 20, 2023
1247ae9
c_cond size
kashif Jul 20, 2023
b0435b3
use clip-bigG
kashif Jul 20, 2023
9bdc662
allow stage b clip to be None
kashif Jul 21, 2023
6f9cd37
Merge branch 'huggingface:main' into paella
kashif Jul 21, 2023
c05d6c5
add dummy
kashif Jul 21, 2023
2d9d85d
würstchen scheduler
dome272 Jul 23, 2023
b9c3468
minor changes
dome272 Jul 24, 2023
a385e66
set clip=None in the pipeline
kashif Jul 25, 2023
0db4f19
fix attention mask
dome272 Aug 4, 2023
41c47cc
nip
dome272 Aug 4, 2023
3e294c6
Merge remote-tracking branch 'upstream/main' into paella
kashif Aug 4, 2023
3ae3ea4
add attention_masks to text_encoder
kashif Aug 4, 2023
b3b2b60
make fix-copies
kashif Aug 4, 2023
8663037
add back clip
kashif Aug 4, 2023
9ec3f01
add text_encoder
kashif Aug 5, 2023
83f87ef
gen_text_encoder and tokenizer
kashif Aug 5, 2023
7a3639d
fix import
kashif Aug 5, 2023
3791b94
updated pipeline test
kashif Aug 6, 2023
29610d2
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
dome272 Aug 6, 2023
cee4feb
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
dome272 Aug 6, 2023
f91b12e
undo changes to pipeline test
kashif Aug 6, 2023
5b518a2
nip
dome272 Aug 6, 2023
44c0f93
test
dome272 Aug 6, 2023
55ba4db
fix typo
kashif Aug 6, 2023
48034a0
fix output name
kashif Aug 6, 2023
be2529a
set guidance_scale=0 and remove diffuze
kashif Aug 6, 2023
be1aa96
fix doc strings
kashif Aug 6, 2023
a1114e3
make style
kashif Aug 6, 2023
596c7f5
nip
dome272 Aug 6, 2023
e708a76
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
dome272 Aug 6, 2023
4ae05f3
removed unused
kashif Aug 6, 2023
71b7fa3
initial docs
kashif Aug 7, 2023
36e9722
rename
kashif Aug 7, 2023
f620d83
toc
kashif Aug 7, 2023
61c137c
cleanup
kashif Aug 7, 2023
e5127d5
remvoe test script
kashif Aug 7, 2023
61a5ebc
fix-copies
kashif Aug 7, 2023
c89b8a4
Merge branch 'main' into paella
kashif Aug 7, 2023
b122ddd
fix multi images
patrickvonplaten Aug 7, 2023
b50ce49
fix multi images
patrickvonplaten Aug 7, 2023
f24ee47
remove dup
kashif Aug 7, 2023
ce23ef7
remove unused modules
kashif Aug 7, 2023
163cd2b
undo changes for debugging
kashif Aug 7, 2023
59e5f15
no new line
kashif Aug 7, 2023
e6f0f75
remove dup conversion script
kashif Aug 7, 2023
35e55a7
fix doc string
kashif Aug 7, 2023
78cd405
cleanup
kashif Aug 7, 2023
3ddee34
pass default args
kashif Aug 7, 2023
fca022a
dup permute
kashif Aug 7, 2023
903ba6f
fix some tests
kashif Aug 7, 2023
67eaff6
fix prepare_latents
kashif Aug 7, 2023
3731131
move Prior class to modules
kashif Aug 7, 2023
09ca25b
offload only the text encoder and vqgan
kashif Aug 7, 2023
8f7a74a
fix resolution calculation for prior
dome272 Aug 7, 2023
9bbfb7c
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
dome272 Aug 7, 2023
cc80e2b
nip
dome272 Aug 7, 2023
f74f688
removed testing script
kashif Aug 7, 2023
829a394
fix shape
kashif Aug 7, 2023
170180a
fix argument to set_timesteps
kashif Aug 7, 2023
f9f34aa
Merge branch 'main' into paella
kashif Aug 7, 2023
687de06
do not change .gitignore
kashif Aug 7, 2023
0ca12ee
fix resolution calculations + readme
dome272 Aug 7, 2023
433bded
resolution calculation fix + readme
dome272 Aug 7, 2023
945295d
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
dome272 Aug 7, 2023
05b58bc
small fixes
dome272 Aug 7, 2023
b0dc35c
Add combined pipeline
patrickvonplaten Aug 7, 2023
3e08530
rename generator -> decoder
patrickvonplaten Aug 7, 2023
711246a
Update .gitignore
kashif Aug 7, 2023
5ca1fe0
removed efficient_net
kashif Aug 8, 2023
1f4bb0a
create combined WuerstchenPipeline
kashif Aug 8, 2023
474ec70
make arguments consistent with VQ model
kashif Aug 11, 2023
54d3397
fix var names
kashif Aug 14, 2023
98f4b54
no need to return text_encoder_hidden_states
kashif Aug 14, 2023
b43b463
add latent_dim_scale to config
kashif Aug 14, 2023
bb8c5b1
split model into its own file
kashif Aug 14, 2023
2ee6ed9
add WuerschenPipeline to docs
kashif Aug 14, 2023
d944bb1
remove unused latent_size
kashif Aug 15, 2023
59eb765
register latent_dim_scale
kashif Aug 15, 2023
b5ff681
update script
kashif Aug 15, 2023
7d6f2e0
update docstring
kashif Aug 15, 2023
f33dd22
use Attention preprocessor
kashif Aug 15, 2023
d230c8b
Merge branch 'main' into paella
kashif Aug 17, 2023
17d28e3
concat with normed input
kashif Aug 17, 2023
3872a96
Merge branch 'main' into paella
kashif Aug 17, 2023
3f49d52
fix-copies
kashif Aug 18, 2023
5ea91e1
add docs
kashif Aug 18, 2023
752d3f5
fix test
kashif Aug 18, 2023
eadd628
fix style
kashif Aug 18, 2023
fda1f68
add to cpu_offloaded_model
kashif Aug 18, 2023
f8ddabd
updated type
kashif Aug 18, 2023
ab22f62
Merge branch 'main' into paella
kashif Aug 18, 2023
45dcfe1
remove 1-line func
kashif Aug 18, 2023
cbf8780
updated type
kashif Aug 18, 2023
15b5f42
initial decoder test
kashif Aug 18, 2023
bd362df
formatting
kashif Aug 18, 2023
d20f8ff
formatting
kashif Aug 18, 2023
851705c
fix autodoc link
kashif Aug 18, 2023
cdf5109
num_inference_steps is int
kashif Aug 18, 2023
64cd513
Merge branch 'main' into paella
kashif Aug 19, 2023
7a24a7d
remove comments
dome272 Aug 21, 2023
2971765
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
dome272 Aug 21, 2023
11cb295
fix example in docs
dome272 Aug 22, 2023
23e8740
Update src/diffusers/pipelines/wuerstchen/diffnext.py
dome272 Aug 23, 2023
cc70ca5
rename layernorm to WuerstchenLayerNorm
dome272 Aug 23, 2023
32395d3
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
dome272 Aug 23, 2023
c2086cf
Merge branch 'main' into paella
kashif Aug 24, 2023
851115c
rename DiffNext to WuerstchenDiffNeXt
kashif Aug 24, 2023
ecd6ab3
added comment about MixingResidualBlock
kashif Aug 25, 2023
44ad4c3
move paella vq-vae to pipelines' folder
kashif Aug 28, 2023
fb45b37
Merge branch 'main' into paella
kashif Aug 28, 2023
d3e5919
initial decoder test
kashif Aug 28, 2023
2f325e6
increased test_float16_inference expected diff
kashif Aug 28, 2023
db8fae2
self_attn is always true
kashif Aug 28, 2023
a10f5d6
Merge branch 'main' into paella
kashif Aug 30, 2023
754b9ab
more passing decoder tests
kashif Aug 30, 2023
f162d77
batch image_embeds
kashif Aug 31, 2023
c74f9c6
fix failing tests
kashif Aug 31, 2023
741f6ef
set the correct dtype
kashif Aug 31, 2023
09781d2
relax inference test
kashif Aug 31, 2023
b8f8cff
Merge branch 'main' into paella
kashif Aug 31, 2023
06cd467
update prior
kashif Aug 31, 2023
7c15471
added combined pipeline test
kashif Aug 31, 2023
1d6615e
faster test
kashif Aug 31, 2023
ab586bd
Merge branch 'main' into paella
kashif Aug 31, 2023
8fdce9d
faster test
kashif Sep 1, 2023
56cfe6b
Merge branch 'main' into paella
kashif Sep 1, 2023
f9a9259
Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combine…
kashif Sep 1, 2023
75d4060
Merge branch 'main' into paella
kashif Sep 1, 2023
cf67355
Merge branch 'main' into paella
kashif Sep 4, 2023
d45550b
fix issues from review
kashif Sep 4, 2023
6f978ed
update wuerstchen.md + change generator name
dome272 Sep 4, 2023
50774c9
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
dome272 Sep 4, 2023
81f64de
resolve issues
dome272 Sep 5, 2023
83073b1
Merge branch 'main' into paella
patrickvonplaten Sep 5, 2023
35772f1
fix copied from usage and add back batch_size
kashif Sep 5, 2023
21068df
fix API
kashif Sep 5, 2023
7f39e0c
fix arguments
kashif Sep 5, 2023
0b97829
fix combined test
kashif Sep 5, 2023
b801a56
Added timesteps argument + fixes
dome272 Sep 6, 2023
1e9336c
Update tests/pipelines/test_pipelines_common.py
kashif Sep 6, 2023
9ca78d9
Update tests/pipelines/wuerstchen/test_wuerstchen_prior.py
kashif Sep 6, 2023
9d8ea07
Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combine…
kashif Sep 6, 2023
a09b4ef
Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combine…
kashif Sep 6, 2023
1c568ce
Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combine…
kashif Sep 6, 2023
500cb6e
Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combine…
patrickvonplaten Sep 6, 2023
692a1b7
Merge branch 'main' of https://github.com/huggingface/diffusers into …
patrickvonplaten Sep 6, 2023
ad98baf
Merge branch 'paella' of https://github.com/kashif/diffusers into paella
patrickvonplaten Sep 6, 2023
2d222ed
up
patrickvonplaten Sep 6, 2023
d8b62f2
Fix more
patrickvonplaten Sep 6, 2023
d4751ab
failing tests
kashif Sep 6, 2023
3b705e8
up
patrickvonplaten Sep 6, 2023
879c82c
up
patrickvonplaten Sep 6, 2023
8490804
up
patrickvonplaten Sep 6, 2023
88032f1
correct naming
patrickvonplaten Sep 6, 2023
fb33746
correct docs
patrickvonplaten Sep 6, 2023
6347957
correct docs
patrickvonplaten Sep 6, 2023
5489081
fix test params
kashif Sep 6, 2023
30bc6b6
correct docs
patrickvonplaten Sep 6, 2023
bc8a472
correct docs
patrickvonplaten Sep 6, 2023
09787b1
fix classifier free guidance
patrickvonplaten Sep 6, 2023
ed9f96a
fix classifier free guidance
patrickvonplaten Sep 6, 2023
30a86b3
fix more
patrickvonplaten Sep 6, 2023
3f04ada
fix all
patrickvonplaten Sep 6, 2023
c35f3f7
make tests faster
patrickvonplaten Sep 6, 2023
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -306,6 +306,8 @@
title: Versatile Diffusion
- local: api/pipelines/vq_diffusion
title: VQ Diffusion
- local: api/pipelines/wuerstchen
title: Wuerstchen
title: Pipelines
- sections:
- local: api/schedulers/overview
110 changes: 110 additions & 0 deletions docs/source/en/api/pipelines/wuerstchen.md
@@ -0,0 +1,110 @@
# Würstchen

[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville.

The abstract from the paper is:

*We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.*

## Würstchen v2 comes to Diffusers
> Review comment (Member): Where can the users find all the checkpoints? This is how we include this info (reference): [screenshot]

After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competitive with current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements:

- Higher resolution (1024x1024 up to 2048x2048)
- Faster inference
- Multi Aspect Resolution Sampling (see the sketch after this list)
- Better quality
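
A minimal sketch of the multi-aspect sampling mentioned above, reusing the `WuerstchenPipeline` checkpoint from the example in the next section; the non-square resolution and the other argument values are illustrative assumptions, not part of this PR:

```python
import torch
from diffusers import WuerstchenPipeline

# Illustrative only: request a landscape (non-square) resolution to exercise
# the multi-aspect resolution sampling listed above.
pipeline = WuerstchenPipeline.from_pretrained(
    "warp-diffusion/WuerstchenPipeline", torch_dtype=torch.float16
).to("cuda")

images = pipeline(
    prompt="A captivating artwork of a mysterious stone golem",
    height=1024,
    width=1536,  # landscape aspect ratio (illustrative values)
    guidance_scale=8.0,
    output_type="pil",
).images
```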

## Text-to-Image Generation

For the sake of usability, Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenPipeline` and can be used as follows:

```python
import torch
from diffusers import WuerstchenPipeline

device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2

pipeline = WuerstchenPipeline.from_pretrained(
"warp-diffusion/WuerstchenPipeline", torch_dtype=dtype
).to(device)

caption = "A captivating artwork of a mysterious stone golem"
negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated"

output = pipeline(
prompt=caption,
height=1024,
width=1024,
# Review comment (Member), on lines +51 to +52 (the height/width arguments above):
# it might be a good idea to comment about the limitation here a bit. How does this
# pipeline perform when the resolution is low (as low as 256x256) and high (as high
# as 1024x1024)? Also, are we mentioning somewhere that we don't need more than 10
# inference steps to get good-quality results, to demonstrate the competitive
# advantage of Wuerstchen?
# Reply (Contributor): I will do that in the documentation!
negative_prompt=negative_prompt,
guidance_scale=8.0,
num_images_per_prompt=num_images_per_prompt,
output_type="pil",
).images
```

For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, and Stage A. They each have a different job and only work together. When generating text-conditional images, Stage C first generates the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents are passed to Stage B, which decompresses the latents into the bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into pixel space. Stage B and Stage A are both encapsulated in the `decoder_pipeline`. For more details, take a look at the [paper](https://huggingface.co/papers/2306.00637).

```python
import torch
# Review comment (Contributor): we can update this code snippet with a shorter
# one as stated here: #3849 (comment)

from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline

device = "cuda"
dtype = torch.float16
num_images_per_prompt = 2

prior_pipeline = WuerstchenPriorPipeline.from_pretrained(
"warp-diffusion/WuerstchenPriorPipeline", torch_dtype=dtype
).to(device)
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained(
"warp-diffusion/WuerstchenDecoderPipeline", torch_dtype=dtype
).to(device)

caption = "A captivating artwork of a mysterious stone golem"
negative_prompt = "bad anatomy, blurry, fuzzy, extra arms, extra fingers, poorly drawn hands, disfigured, tiling, deformed, mutated"

prior_output = prior_pipeline(
prompt=caption,
height=1024,
width=1024,
negative_prompt=negative_prompt,
guidance_scale=8.0,
num_images_per_prompt=num_images_per_prompt,
)
decoder_output = decoder_pipeline(
predicted_image_embeddings=prior_output.image_embeds,
prompt=caption,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
guidance_scale=0.0,
output_type="pil",
).images

```

The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen).

## WuerstchenPipeline

[[autodoc]] WuerstchenPipeline
- all
- __call__

## WuerstchenPriorPipeline

[[autodoc]] WuerstchenPriorPipeline
- all
- __call__

## WuerstchenPriorPipelineOutput

[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput

## WuerstchenDecoderPipeline
> Review comment (Member): Maybe let's shift this above the output dataclass?


[[autodoc]] WuerstchenDecoderPipeline
- all
- __call__
115 changes: 115 additions & 0 deletions scripts/convert_wuerstchen.py
@@ -0,0 +1,115 @@
# Run inside root directory of official source code
# Review comment (Member): Maybe pass on a link?

import os

import torch
from transformers import AutoTokenizer, CLIPTextModel
from vqgan import VQModel

from diffusers import (
DDPMWuerstchenScheduler,
WuerstchenDecoderPipeline,
WuerstchenPipeline,
WuerstchenPriorPipeline,
)
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior


model_path = "models/"
device = "cpu"

paella_vqmodel = VQModel()
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"]
paella_vqmodel.load_state_dict(state_dict)

state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"]
state_dict.pop("vquantizer.codebook.weight")
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent)
vqmodel.load_state_dict(state_dict)

# Clip Text encoder and tokenizer
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")

# Generator
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu")
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")

orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
generator = WuerstchenDiffNeXt()
generator.load_state_dict(state_dict)

# Prior
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"]
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device)
prior_model.load_state_dict(state_dict)

# scheduler
scheduler = DDPMWuerstchenScheduler()

# Prior pipeline
prior_pipeline = WuerstchenPriorPipeline(
prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)

prior_pipeline.save_pretrained("warp-diffusion/WuerstchenPriorPipeline")

decoder_pipeline = WuerstchenDecoderPipeline(
text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, generator=generator, scheduler=scheduler
)
decoder_pipeline.save_pretrained("warp-diffusion/WuerstchenDecoderPipeline")

# Wuerstchen pipeline
wuerstchen_pipeline = WuerstchenPipeline(
# Decoder
text_encoder=gen_text_encoder,
tokenizer=gen_tokenizer,
generator=generator,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_tokenizer=tokenizer,
prior_text_encoder=text_encoder,
prior_prior=prior_model,
prior_scheduler=scheduler,
)
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenPipeline")
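
The two conversion loops above apply the same remapping of the fused `attn.in_proj_*` / `attn.out_proj.*` parameters onto diffusers' separate `to_q`/`to_k`/`to_v`/`to_out.0` keys. A minimal standalone sketch of that transformation (the helper name is illustrative and not part of the PR):

```python
def remap_attention_keys(orig_state_dict: dict) -> dict:
    """Split fused attention projections into diffusers-style q/k/v/out keys.

    Illustrative helper mirroring the conversion loops above; values are
    torch tensors. Not part of the PR itself.
    """
    state_dict = {}
    for key, value in orig_state_dict.items():
        if key.endswith("in_proj_weight"):
            q, k, v = value.chunk(3, 0)  # fused qkv weight -> three equal chunks
            state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = q
            state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = k
            state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = v
        elif key.endswith("in_proj_bias"):
            q, k, v = value.chunk(3, 0)  # fused qkv bias -> three equal chunks
            state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = q
            state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = k
            state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = v
        elif key.endswith("out_proj.weight"):
            state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = value
        elif key.endswith("out_proj.bias"):
            state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = value
        else:
            state_dict[key] = value
    return state_dict
```

Both loops in the script could then be replaced by a call to this helper before `load_state_dict` on the Stage B generator and the Stage C prior.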
4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -88,6 +88,7 @@
DDIMScheduler,
DDPMParallelScheduler,
DDPMScheduler,
DDPMWuerstchenScheduler,
DEISMultistepScheduler,
DPMSolverMultistepInverseScheduler,
DPMSolverMultistepScheduler,
@@ -215,6 +216,9 @@
VersatileDiffusionTextToImagePipeline,
VideoToVideoSDPipeline,
VQDiffusionPipeline,
WuerstchenDecoderPipeline,
WuerstchenPipeline,
WuerstchenPriorPipeline,
)

try:
2 changes: 1 addition & 1 deletion src/diffusers/models/attention_processor.py
@@ -188,7 +188,7 @@ def set_use_memory_efficient_attention_xformers(
if use_memory_efficient_attention_xformers:
if is_added_kv_processor and (is_lora or is_custom_diffusion):
raise NotImplementedError(
f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
)
if not is_xformers_available():
raise ModuleNotFoundError(
2 changes: 1 addition & 1 deletion src/diffusers/models/vae.py
@@ -52,7 +52,7 @@ def __init__(
super().__init__()
self.layers_per_block = layers_per_block

self.conv_in = torch.nn.Conv2d(
self.conv_in = nn.Conv2d(
in_channels,
block_out_channels[0],
kernel_size=3,
2 changes: 1 addition & 1 deletion src/diffusers/models/vq_model.py
@@ -132,7 +132,7 @@ def decode(
) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
quant, _, _ = self.quantize(h)
else:
quant = h
quant2 = self.post_quant_conv(quant)
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -131,6 +131,7 @@
VersatileDiffusionTextToImagePipeline,
)
from .vq_diffusion import VQDiffusionPipeline
from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenPipeline, WuerstchenPriorPipeline


try:
3 changes: 3 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
@@ -52,6 +52,7 @@
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from .wuerstchen import WuerstchenDecoderPipeline, WuerstchenPipeline


AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -63,6 +64,7 @@
("kandinsky22", KandinskyV22CombinedPipeline),
("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
("wuerstchen", WuerstchenPipeline),
]
)

@@ -93,6 +95,7 @@
[
("kandinsky", KandinskyPipeline),
("kandinsky22", KandinskyV22Pipeline),
("wuerstchen", WuerstchenDecoderPipeline),
]
)
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
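
The new `"wuerstchen"` entries above are what let the auto-pipeline API resolve a Würstchen checkpoint without naming the pipeline class. A hedged sketch of how that registration is typically exercised, assuming the `warp-diffusion/WuerstchenPipeline` checkpoint referenced in the docs added by this PR:

```python
import torch
from diffusers import AutoPipelineForText2Image

# With "wuerstchen" registered in AUTO_TEXT2IMAGE_PIPELINES_MAPPING above, the
# auto API resolves this checkpoint to WuerstchenPipeline automatically.
pipe = AutoPipelineForText2Image.from_pretrained(
    "warp-diffusion/WuerstchenPipeline", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="A captivating artwork of a mysterious stone golem",
    guidance_scale=8.0,
).images[0]
```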
10 changes: 10 additions & 0 deletions src/diffusers/pipelines/wuerstchen/__init__.py
@@ -0,0 +1,10 @@
from ...utils import is_torch_available, is_transformers_available


if is_transformers_available() and is_torch_available():
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
from .modeling_wuerstchen_prior import WuerstchenPrior
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline