forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial * initial * added initial convert script for paella vqmodel * initial wuerstchen pipeline * add LayerNorm2d * added modules * fix typo * use model_v2 * embed clip caption amd negative_caption * fixed name of var * initial modules in one place * WuerstchenPriorPipeline * inital shape * initial denoising prior loop * fix output * add WuerstchenPriorPipeline to __init__.py * use the noise ratio in the Prior * try to save pipeline * save_pretrained working * Few additions * add _execution_device * shape is int * fix batch size * fix shape of ratio * fix shape of ratio * fix output dataclass * tests folder * fix formatting * fix float16 + started with generator * Update pipeline_wuerstchen.py * removed vqgan code * add WuerstchenGeneratorPipeline * fix WuerstchenGeneratorPipeline * fix docstrings * fix imports * convert generator pipeline * fix convert * Work on Generator Pipeline. WIP * Pipeline works with our diffuzz code * apply scale factor * removed vqgan.py * use cosine schedule * redo the denoising loop * Update src/diffusers/models/resnet.py Co-authored-by: Patrick von Platen <[email protected]> * use torch.lerp * use warp-diffusion org * clip_sample=False, * some refactoring * use model_v3_stage_c * c_cond size * use clip-bigG * allow stage b clip to be None * add dummy * würstchen scheduler * minor changes * set clip=None in the pipeline * fix attention mask * add attention_masks to text_encoder * make fix-copies * add back clip * add text_encoder * gen_text_encoder and tokenizer * fix import * updated pipeline test * undo changes to pipeline test * nip * fix typo * fix output name * set guidance_scale=0 and remove diffuze * fix doc strings * make style * nip * removed unused * initial docs * rename * toc * cleanup * remvoe test script * fix-copies * fix multi images * remove dup * remove unused modules * undo changes for debugging * no new line * remove dup conversion script * fix doc string * cleanup * pass default args * dup permute * fix some tests * fix prepare_latents * move Prior class to modules * offload only the text encoder and vqgan * fix resolution calculation for prior * nip * removed testing script * fix shape * fix argument to set_timesteps * do not change .gitignore * fix resolution calculations + readme * resolution calculation fix + readme * small fixes * Add combined pipeline * rename generator -> decoder * Update .gitignore Co-authored-by: Patrick von Platen <[email protected]> * removed efficient_net * create combined WuerstchenPipeline * make arguments consistent with VQ model * fix var names * no need to return text_encoder_hidden_states * add latent_dim_scale to config * split model into its own file * add WuerschenPipeline to docs * remove unused latent_size * register latent_dim_scale * update script * update docstring * use Attention preprocessor * concat with normed input * fix-copies * add docs * fix test * fix style * add to cpu_offloaded_model * updated type * remove 1-line func * updated type * initial decoder test * formatting * formatting * fix autodoc link * num_inference_steps is int * remove comments * fix example in docs * Update src/diffusers/pipelines/wuerstchen/diffnext.py Co-authored-by: Patrick von Platen <[email protected]> * rename layernorm to WuerstchenLayerNorm * rename DiffNext to WuerstchenDiffNeXt * added comment about MixingResidualBlock * move paella vq-vae to pipelines' folder * initial decoder test * increased test_float16_inference expected diff * self_attn is always true * more passing decoder tests * batch image_embeds * fix failing tests * set the correct dtype * relax inference test * update prior * added combined pipeline test * faster test * faster test * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen <[email protected]> * fix issues from review * update wuerstchen.md + change generator name * resolve issues * fix copied from usage and add back batch_size * fix API * fix arguments * fix combined test * Added timesteps argument + fixes * Update tests/pipelines/test_pipelines_common.py Co-authored-by: Patrick von Platen <[email protected]> * Update tests/pipelines/wuerstchen/test_wuerstchen_prior.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py * up * Fix more * failing tests * up * up * correct naming * correct docs * correct docs * fix test params * correct docs * fix classifier free guidance * fix classifier free guidance * fix more * fix all * make tests faster --------- Co-authored-by: Dominic Rampas <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Dominic Rampas <[email protected]>
- Loading branch information
1 parent
bea99fc
commit 12d09bc
Showing
18 changed files
with
1,955 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from ...utils import is_torch_available, is_transformers_available | ||
|
||
|
||
if is_transformers_available() and is_torch_available(): | ||
from .modeling_paella_vq_model import PaellaVQModel | ||
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt | ||
from .modeling_wuerstchen_prior import WuerstchenPrior | ||
from .pipeline_wuerstchen import WuerstchenDecoderPipeline | ||
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline | ||
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Copyright (c) 2022 Dominic Rampas MIT License | ||
# Copyright 2023 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Union | ||
|
||
import torch | ||
import torch.nn as nn | ||
|
||
from ...configuration_utils import ConfigMixin, register_to_config | ||
from ...models.modeling_utils import ModelMixin | ||
from ...models.vae import DecoderOutput, VectorQuantizer | ||
from ...models.vq_model import VQEncoderOutput | ||
from ...utils import apply_forward_hook | ||
|
||
|
||
class MixingResidualBlock(nn.Module): | ||
""" | ||
Residual block with mixing used by Paella's VQ-VAE. | ||
""" | ||
|
||
def __init__(self, inp_channels, embed_dim): | ||
super().__init__() | ||
# depthwise | ||
self.norm1 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) | ||
self.depthwise = nn.Sequential( | ||
nn.ReplicationPad2d(1), nn.Conv2d(inp_channels, inp_channels, kernel_size=3, groups=inp_channels) | ||
) | ||
|
||
# channelwise | ||
self.norm2 = nn.LayerNorm(inp_channels, elementwise_affine=False, eps=1e-6) | ||
self.channelwise = nn.Sequential( | ||
nn.Linear(inp_channels, embed_dim), nn.GELU(), nn.Linear(embed_dim, inp_channels) | ||
) | ||
|
||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) | ||
|
||
def forward(self, x): | ||
mods = self.gammas | ||
x_temp = self.norm1(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[0]) + mods[1] | ||
x = x + self.depthwise(x_temp) * mods[2] | ||
x_temp = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * (1 + mods[3]) + mods[4] | ||
x = x + self.channelwise(x_temp.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) * mods[5] | ||
return x | ||
|
||
|
||
class PaellaVQModel(ModelMixin, ConfigMixin): | ||
r"""VQ-VAE model from Paella model. | ||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library | ||
implements for all the model (such as downloading or saving, etc.) | ||
Parameters: | ||
in_channels (int, *optional*, defaults to 3): Number of channels in the input image. | ||
out_channels (int, *optional*, defaults to 3): Number of channels in the output. | ||
up_down_scale_factor (int, *optional*, defaults to 2): Up and Downscale factor of the input image. | ||
levels (int, *optional*, defaults to 2): Number of levels in the model. | ||
bottleneck_blocks (int, *optional*, defaults to 12): Number of bottleneck blocks in the model. | ||
embed_dim (int, *optional*, defaults to 384): Number of hidden channels in the model. | ||
latent_channels (int, *optional*, defaults to 4): Number of latent channels in the VQ-VAE model. | ||
num_vq_embeddings (int, *optional*, defaults to 8192): Number of codebook vectors in the VQ-VAE. | ||
scale_factor (float, *optional*, defaults to 0.3764): Scaling factor of the latent space. | ||
""" | ||
|
||
@register_to_config | ||
def __init__( | ||
self, | ||
in_channels: int = 3, | ||
out_channels: int = 3, | ||
up_down_scale_factor: int = 2, | ||
levels: int = 2, | ||
bottleneck_blocks: int = 12, | ||
embed_dim: int = 384, | ||
latent_channels: int = 4, | ||
num_vq_embeddings: int = 8192, | ||
scale_factor: float = 0.3764, | ||
): | ||
super().__init__() | ||
|
||
c_levels = [embed_dim // (2**i) for i in reversed(range(levels))] | ||
# Encoder blocks | ||
self.in_block = nn.Sequential( | ||
nn.PixelUnshuffle(up_down_scale_factor), | ||
nn.Conv2d(in_channels * up_down_scale_factor**2, c_levels[0], kernel_size=1), | ||
) | ||
down_blocks = [] | ||
for i in range(levels): | ||
if i > 0: | ||
down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1)) | ||
block = MixingResidualBlock(c_levels[i], c_levels[i] * 4) | ||
down_blocks.append(block) | ||
down_blocks.append( | ||
nn.Sequential( | ||
nn.Conv2d(c_levels[-1], latent_channels, kernel_size=1, bias=False), | ||
nn.BatchNorm2d(latent_channels), # then normalize them to have mean 0 and std 1 | ||
) | ||
) | ||
self.down_blocks = nn.Sequential(*down_blocks) | ||
|
||
# Vector Quantizer | ||
self.vquantizer = VectorQuantizer(num_vq_embeddings, vq_embed_dim=latent_channels, legacy=False, beta=0.25) | ||
|
||
# Decoder blocks | ||
up_blocks = [nn.Sequential(nn.Conv2d(latent_channels, c_levels[-1], kernel_size=1))] | ||
for i in range(levels): | ||
for j in range(bottleneck_blocks if i == 0 else 1): | ||
block = MixingResidualBlock(c_levels[levels - 1 - i], c_levels[levels - 1 - i] * 4) | ||
up_blocks.append(block) | ||
if i < levels - 1: | ||
up_blocks.append( | ||
nn.ConvTranspose2d( | ||
c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2, padding=1 | ||
) | ||
) | ||
self.up_blocks = nn.Sequential(*up_blocks) | ||
self.out_block = nn.Sequential( | ||
nn.Conv2d(c_levels[0], out_channels * up_down_scale_factor**2, kernel_size=1), | ||
nn.PixelShuffle(up_down_scale_factor), | ||
) | ||
|
||
@apply_forward_hook | ||
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput: | ||
h = self.in_block(x) | ||
h = self.down_blocks(h) | ||
|
||
if not return_dict: | ||
return (h,) | ||
|
||
return VQEncoderOutput(latents=h) | ||
|
||
@apply_forward_hook | ||
def decode( | ||
self, h: torch.FloatTensor, force_not_quantize: bool = True, return_dict: bool = True | ||
) -> Union[DecoderOutput, torch.FloatTensor]: | ||
if not force_not_quantize: | ||
quant, _, _ = self.vquantizer(h) | ||
else: | ||
quant = h | ||
|
||
x = self.up_blocks(quant) | ||
dec = self.out_block(x) | ||
if not return_dict: | ||
return (dec,) | ||
|
||
return DecoderOutput(sample=dec) | ||
|
||
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: | ||
r""" | ||
Args: | ||
sample (`torch.FloatTensor`): Input sample. | ||
return_dict (`bool`, *optional*, defaults to `True`): | ||
Whether or not to return a [`DecoderOutput`] instead of a plain tuple. | ||
""" | ||
x = sample | ||
h = self.encode(x).latents | ||
dec = self.decode(h).sample | ||
|
||
if not return_dict: | ||
return (dec,) | ||
|
||
return DecoderOutput(sample=dec) |
Oops, something went wrong.