Würstchen model (huggingface#3849)
* initial

* initial

* added initial convert script for paella vqmodel

* initial wuerstchen pipeline

* add LayerNorm2d

* added modules

* fix typo

* use model_v2

* embed clip caption and negative_caption

* fixed name of var

* initial modules in one place

* WuerstchenPriorPipeline

* initial shape

* initial denoising prior loop

* fix output

* add WuerstchenPriorPipeline to __init__.py

* use the noise ratio in the Prior

* try to save pipeline

* save_pretrained working

* Few additions

* add _execution_device

* shape is int

* fix batch size

* fix shape of ratio

* fix shape of ratio

* fix output dataclass

* tests folder

* fix formatting

* fix float16 + started with generator

* Update pipeline_wuerstchen.py

* removed vqgan code

* add WuerstchenGeneratorPipeline

* fix WuerstchenGeneratorPipeline

* fix docstrings

* fix imports

* convert generator pipeline

* fix convert

* Work on Generator Pipeline. WIP

* Pipeline works with our diffuzz code

* apply scale factor

* removed vqgan.py

* use cosine schedule

* redo the denoising loop

* Update src/diffusers/models/resnet.py

Co-authored-by: Patrick von Platen <[email protected]>

* use torch.lerp

* use warp-diffusion org

* clip_sample=False,

* some refactoring

* use model_v3_stage_c

* c_cond size

* use clip-bigG

* allow stage b clip to be None

* add dummy

* würstchen scheduler

* minor changes

* set clip=None in the pipeline

* fix attention mask

* add attention_masks to text_encoder

* make fix-copies

* add back clip

* add text_encoder

* gen_text_encoder and tokenizer

* fix import

* updated pipeline test

* undo changes to pipeline test

* nip

* fix typo

* fix output name

* set guidance_scale=0 and remove diffuzz

* fix doc strings

* make style

* nip

* removed unused

* initial docs

* rename

* toc

* cleanup

* remove test script

* fix-copies

* fix multi images

* remove dup

* remove unused modules

* undo changes for debugging

* no new line

* remove dup conversion script

* fix doc string

* cleanup

* pass default args

* dup permute

* fix some tests

* fix prepare_latents

* move Prior class to modules

* offload only the text encoder and vqgan

* fix resolution calculation for prior

* nip

* removed testing script

* fix shape

* fix argument to set_timesteps

* do not change .gitignore

* fix resolution calculations + readme

* resolution calculation fix + readme

* small fixes

* Add combined pipeline

* rename generator -> decoder

* Update .gitignore

Co-authored-by: Patrick von Platen <[email protected]>

* removed efficient_net

* create combined WuerstchenPipeline

* make arguments consistent with VQ model

* fix var names

* no need to return text_encoder_hidden_states

* add latent_dim_scale to config

* split model into its own file

* add WuerstchenPipeline to docs

* remove unused latent_size

* register latent_dim_scale

* update script

* update docstring

* use Attention preprocessor

* concat with normed input

* fix-copies

* add docs

* fix test

* fix style

* add to cpu_offloaded_model

* updated type

* remove 1-line func

* updated type

* initial decoder test

* formatting

* formatting

* fix autodoc link

* num_inference_steps is int

* remove comments

* fix example in docs

* Update src/diffusers/pipelines/wuerstchen/diffnext.py

Co-authored-by: Patrick von Platen <[email protected]>

* rename layernorm to WuerstchenLayerNorm

* rename DiffNext to WuerstchenDiffNeXt

* added comment about MixingResidualBlock

* move paella vq-vae to pipelines' folder

* initial decoder test

* increased test_float16_inference expected diff

* self_attn is always true

* more passing decoder tests

* batch image_embeds

* fix failing tests

* set the correct dtype

* relax inference test

* update prior

* added combined pipeline test

* faster test

* faster test

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <[email protected]>

* fix issues from review

* update wuerstchen.md + change generator name

* resolve issues

* fix copied from usage and add back batch_size

* fix API

* fix arguments

* fix combined test

* Added timesteps argument + fixes

* Update tests/pipelines/test_pipelines_common.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update tests/pipelines/wuerstchen/test_wuerstchen_prior.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

Co-authored-by: Patrick von Platen <[email protected]>

* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py

* up

* Fix more

* failing tests

* up

* up

* correct naming

* correct docs

* correct docs

* fix test params

* correct docs

* fix classifier free guidance

* fix classifier free guidance

* fix more

* fix all

* make tests faster

---------

Co-authored-by: Dominic Rampas <[email protected]>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Dominic Rampas <[email protected]>
4 people authored Sep 6, 2023
1 parent bea99fc commit 12d09bc
Showing 18 changed files with 1,955 additions and 15 deletions.
4 changes: 4 additions & 0 deletions __init__.py
@@ -88,6 +88,7 @@
     DDIMScheduler,
     DDPMParallelScheduler,
     DDPMScheduler,
+    DDPMWuerstchenScheduler,
     DEISMultistepScheduler,
     DPMSolverMultistepInverseScheduler,
     DPMSolverMultistepScheduler,
@@ -216,6 +217,9 @@
     VersatileDiffusionTextToImagePipeline,
     VideoToVideoSDPipeline,
     VQDiffusionPipeline,
+    WuerstchenCombinedPipeline,
+    WuerstchenDecoderPipeline,
+    WuerstchenPriorPipeline,
 )

 try:
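
With these exports in place, everything needed for text-to-image generation is importable directly from diffusers. A minimal sketch, assuming a local checkpoint saved via save_pretrained (the path below is a placeholder, not something from this commit):

import torch
from diffusers import WuerstchenCombinedPipeline

# Hypothetical checkpoint location; any directory produced by
# WuerstchenCombinedPipeline.save_pretrained() would work here.
pipe = WuerstchenCombinedPipeline.from_pretrained(
    "path/to/wuerstchen-checkpoint", torch_dtype=torch.float16
).to("cuda")

# The combined pipeline chains the prior and decoder stages in a single call.
image = pipe(prompt="an astronaut riding a horse, photorealistic").images[0]
image.save("wuerstchen.png")
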
2 changes: 1 addition & 1 deletion models/attention_processor.py
@@ -188,7 +188,7 @@ def set_use_memory_efficient_attention_xformers(
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and (is_lora or is_custom_diffusion):
                 raise NotImplementedError(
-                    f"Memory efficient attention is currently not supported for LoRA or custom diffuson for attention processor type {self.processor}"
+                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
                 )
             if not is_xformers_available():
                 raise ModuleNotFoundError(
2 changes: 1 addition & 1 deletion models/vae.py
@@ -52,7 +52,7 @@ def __init__(
         super().__init__()
         self.layers_per_block = layers_per_block

-        self.conv_in = torch.nn.Conv2d(
+        self.conv_in = nn.Conv2d(
             in_channels,
             block_out_channels[0],
             kernel_size=3,
2 changes: 1 addition & 1 deletion models/vq_model.py
@@ -132,7 +132,7 @@ def decode(
     ) -> Union[DecoderOutput, torch.FloatTensor]:
         # also go through quantization layer
         if not force_not_quantize:
-            quant, emb_loss, info = self.quantize(h)
+            quant, _, _ = self.quantize(h)
         else:
             quant = h
         quant2 = self.post_quant_conv(quant)
1 change: 1 addition & 0 deletions pipelines/__init__.py
@@ -132,6 +132,7 @@
     VersatileDiffusionTextToImagePipeline,
 )
 from .vq_diffusion import VQDiffusionPipeline
+from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline, WuerstchenPriorPipeline


 try:
15 changes: 3 additions & 12 deletions pipelines/auto_pipeline.py
@@ -52,6 +52,7 @@
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
 )
+from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline


 AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
@@ -63,6 +64,7 @@
         ("kandinsky22", KandinskyV22CombinedPipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
+        ("wuerstchen", WuerstchenCombinedPipeline),
     ]
 )

@@ -93,6 +95,7 @@
     [
         ("kandinsky", KandinskyPipeline),
         ("kandinsky22", KandinskyV22Pipeline),
+        ("wuerstchen", WuerstchenDecoderPipeline),
     ]
 )
 _AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
@@ -305,8 +308,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        subfolder = kwargs.pop("subfolder", None)
-        user_agent = kwargs.pop("user_agent", {})

         load_config_kwargs = {
             "cache_dir": cache_dir,
@@ -316,8 +317,6 @@
             "use_auth_token": use_auth_token,
             "local_files_only": local_files_only,
             "revision": revision,
-            "subfolder": subfolder,
-            "user_agent": user_agent,
         }

         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
@@ -580,8 +579,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        subfolder = kwargs.pop("subfolder", None)
-        user_agent = kwargs.pop("user_agent", {})

         load_config_kwargs = {
             "cache_dir": cache_dir,
@@ -591,8 +588,6 @@
             "use_auth_token": use_auth_token,
             "local_files_only": local_files_only,
             "revision": revision,
-            "subfolder": subfolder,
-            "user_agent": user_agent,
         }

         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
@@ -856,8 +851,6 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
         use_auth_token = kwargs.pop("use_auth_token", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        subfolder = kwargs.pop("subfolder", None)
-        user_agent = kwargs.pop("user_agent", {})

         load_config_kwargs = {
             "cache_dir": cache_dir,
@@ -867,8 +860,6 @@
             "use_auth_token": use_auth_token,
             "local_files_only": local_files_only,
             "revision": revision,
-            "subfolder": subfolder,
-            "user_agent": user_agent,
         }

         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
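
The new "wuerstchen" entries above let the auto-pipeline API resolve a Würstchen checkpoint without the caller naming a class. A hedged sketch of that resolution (the checkpoint path is again a placeholder):

from diffusers import AutoPipelineForText2Image

# from_pretrained reads the checkpoint's model_index.json, matches the stored
# pipeline class name against AUTO_TEXT2IMAGE_PIPELINES_MAPPING, and
# instantiates a WuerstchenCombinedPipeline for a "wuerstchen" checkpoint.
pipe = AutoPipelineForText2Image.from_pretrained("path/to/wuerstchen-checkpoint")
image = pipe("a red fox in fresh snow").images[0]
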
10 changes: 10 additions & 0 deletions pipelines/wuerstchen/__init__.py
@@ -0,0 +1,10 @@
from ...utils import is_torch_available, is_transformers_available


if is_transformers_available() and is_torch_available():
from .modeling_paella_vq_model import PaellaVQModel
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
from .modeling_wuerstchen_prior import WuerstchenPrior
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
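
The same module also exposes the two stages separately, which mirrors what the combined pipeline does internally: the prior denoises text-conditioned image embeddings, and the decoder turns those embeddings into pixels. A rough sketch, with both repository ids assumed for illustration (the log only mentions the warp-diffusion org):

import torch
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline

prior = WuerstchenPriorPipeline.from_pretrained(
    "warp-diffusion/wuerstchen-prior", torch_dtype=torch.float16  # assumed repo id
).to("cuda")
decoder = WuerstchenDecoderPipeline.from_pretrained(
    "warp-diffusion/wuerstchen", torch_dtype=torch.float16  # assumed repo id
).to("cuda")

prompt = "an antique wooden chest in a sunlit attic"
prior_output = prior(prompt=prompt, height=1024, width=1024)
image = decoder(
    image_embeddings=prior_output.image_embeddings,
    prompt=prompt,
    guidance_scale=0.0,  # the log sets guidance_scale=0 for the decoder stage
).images[0]
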
172 changes: 172 additions & 0 deletions pipelines/wuerstchen/modeling_paella_vq_model.py
@@ -0,0 +1,172 @@
# Copyright (c) 2022 Dominic Rampas MIT License
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
from ...models.vae import DecoderOutput, VectorQuantizer
from ...models.vq_model import VQEncoderOutput
from ...utils import apply_forward_hook


class MixingResidualBlock(nn.Module):
"""
Residual block with mixing used by Paella's VQ-VAE.
"""

def __init__(self, inp_channels, embed_dim):
super().__init__()
# depthwise
self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels)
)

# channelwise
self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels)
)

self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)

def forward(self, x):
mods = self.gammas
x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1]
x = x + self.depthwise(x_temp) * mods[2]
x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4]
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5]
return x


class PaellaVQModel(ModelMixin, ConfigMixin):
r"""VQ-VAE model from Paella model.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image.
levels (int, *optional*, defaults to 2): Number of levels in the model.
bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model.
embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model.
latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model.
num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE.
scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space.
"""

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_down_scale_factor: int = 2,
        levels: int = 2,
        bottleneck_blocks: int = 12,
        embed_dim: int = 384,
        latent_channels: int = 4,
        num_vq_embeddings: int = 8192,
        scale_factor: float = 0.3764,
    ):
        super().__init__()

        c_levels = [embed_dim // (2**i) for i in reversed(range(levels))]
        # Encoder blocks
        self.in_block = nn.Sequential(
            nn.PixelUnshuffle(up_down_scale_factor),
            nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1),
        )
        down_blocks = []
        for i in range(levels):
            if i > 0:
                down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
            block = MixingResidualBlock(c_levels[i], c_levels[i] * 4)
            down_blocks.append(block)
        down_blocks.append(
            nn.Sequential(
                nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(latent_channels),  # then normalize them to have mean 0 and std 1
            )
        )
        self.down_blocks = nn.Sequential(*down_blocks)

        # Vector Quantizer
        self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25)

        # Decoder blocks
        up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))]
        for i in range(levels):
            for j in range(bottleneck_blocks if i == 0 else 1):
                block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4)
                up_blocks.append(block)
            if i < levels - 1:
                up_blocks.append(
                    nn.ConvTranspose2d(
                        c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1
                    )
                )
        self.up_blocks = nn.Sequential(*up_blocks)
        self.out_block = nn.Sequential(
            nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1),
            nn.PixelShuffle(up_down_scale_factor),
        )

    @apply_forward_hook
    def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
        h = self.in_block(x)
        h = self.down_blocks(h)

        if not return_dict:
            return (h,)

        return VQEncoderOutput(latents=h)

    @apply_forward_hook
    def decode(
        self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        if not force_not_quantize:
            quant, _, _ = self.vquantizer(h)
        else:
            quant = h

        x = self.up_blocks(quant)
        dec = self.out_block(x)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        h = self.encode(x).latents
        dec = self.decode(h).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
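
As a quick sanity check of the architecture above: with the default config (levels=2, up_down_scale_factor=2), the encoder reduces spatial resolution 4x into 4 latent channels, and decoding inverts it. A minimal sketch with random data standing in for real images:

import torch

vqgan = PaellaVQModel()  # default config from the class above
images = torch.randn(1, 3, 256, 256)  # stand-in for a batch of RGB images

latents = vqgan.encode(images).latents  # -> (1, 4, 64, 64)
recon = vqgan.decode(latents).sample  # -> (1, 3, 256, 256); quantization is skipped by default
assert recon.shape == images.shape
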