From 4039815276215ea0ff8ce7ac6670dc9dbe08f817 Mon Sep 17 00:00:00 2001 From: Will Berman Date: Thu, 21 Dec 2023 11:40:55 -0800 Subject: [PATCH] open muse (#5437) amused rename Update docs/source/en/api/pipelines/amused.md Co-authored-by: Patrick von Platen AdaLayerNormContinuous default values custom micro conditioning micro conditioning docs put lookup from codebook in constructor fix conversion script remove manual fused flash attn kernel add training script temp remove training script add dummy gradient checkpointing func clarify temperatures is an instance variable by setting it remove additional SkipFF block args hardcode norm args rename tests folder fix paths and samples fix tests add training script training readme lora saving and loading non-lora saving/loading some readme fixes guards Update docs/source/en/api/pipelines/amused.md Co-authored-by: Suraj Patil Update examples/amused/README.md Co-authored-by: Suraj Patil Update examples/amused/train_amused.py Co-authored-by: Suraj Patil vae upcasting add fp16 integration tests use tuple for micro cond copyrights remove casts delegate to torch.nn.LayerNorm move temperature to pipeline call upsampling/downsampling changes --- docs/source/en/_toctree.yml | 2 + docs/source/en/api/pipelines/amused.md | 30 + examples/amused/README.md | 326 ++++++ examples/amused/train_amused.py | 972 ++++++++++++++++++ scripts/convert_amused.py | 523 ++++++++++ src/diffusers/__init__.py | 10 + src/diffusers/loaders/lora.py | 95 +- src/diffusers/models/__init__.py | 2 + src/diffusers/models/attention.py | 139 ++- src/diffusers/models/autoencoders/vae.py | 4 + src/diffusers/models/downsampling.py | 22 +- src/diffusers/models/embeddings.py | 5 +- src/diffusers/models/normalization.py | 106 ++ src/diffusers/models/upsampling.py | 40 +- src/diffusers/models/uvit_2d.py | 471 +++++++++ src/diffusers/models/vq_model.py | 9 +- src/diffusers/pipelines/__init__.py | 2 + src/diffusers/pipelines/amused/__init__.py | 62 ++ .../pipelines/amused/pipeline_amused.py | 328 ++++++ .../amused/pipeline_amused_img2img.py | 347 +++++++ .../amused/pipeline_amused_inpaint.py | 378 +++++++ src/diffusers/schedulers/__init__.py | 2 + src/diffusers/schedulers/scheduling_amused.py | 162 +++ src/diffusers/utils/dummy_pt_objects.py | 30 + .../dummy_torch_and_transformers_objects.py | 45 + tests/pipelines/amused/__init__.py | 0 tests/pipelines/amused/test_amused.py | 181 ++++ tests/pipelines/amused/test_amused_img2img.py | 239 +++++ tests/pipelines/amused/test_amused_inpaint.py | 277 +++++ tests/pipelines/test_pipelines_common.py | 4 +- 30 files changed, 4789 insertions(+), 24 deletions(-) create mode 100644 docs/source/en/api/pipelines/amused.md create mode 100644 examples/amused/README.md create mode 100644 examples/amused/train_amused.py create mode 100644 scripts/convert_amused.py create mode 100644 src/diffusers/models/uvit_2d.py create mode 100644 src/diffusers/pipelines/amused/__init__.py create mode 100644 src/diffusers/pipelines/amused/pipeline_amused.py create mode 100644 src/diffusers/pipelines/amused/pipeline_amused_img2img.py create mode 100644 src/diffusers/pipelines/amused/pipeline_amused_inpaint.py create mode 100644 src/diffusers/schedulers/scheduling_amused.py create mode 100644 tests/pipelines/amused/__init__.py create mode 100644 tests/pipelines/amused/test_amused.py create mode 100644 tests/pipelines/amused/test_amused_img2img.py create mode 100644 tests/pipelines/amused/test_amused_inpaint.py diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 62588bf4abb8..3e9e83e6512e 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -244,6 +244,8 @@ - sections: - local: api/pipelines/overview title: Overview + - local: api/pipelines/amused + title: aMUSEd - local: api/pipelines/animatediff title: AnimateDiff - local: api/pipelines/attend_and_excite diff --git a/docs/source/en/api/pipelines/amused.md b/docs/source/en/api/pipelines/amused.md new file mode 100644 index 000000000000..cb8693802173 --- /dev/null +++ b/docs/source/en/api/pipelines/amused.md @@ -0,0 +1,30 @@ + + +# aMUSEd + +Amused is a lightweight text to image model based off of the [muse](https://arxiv.org/pdf/2301.00704.pdf) architecture. Amused is particularly useful in applications that require a lightweight and fast model such as generating many images quickly at once. + +Amused is a vqvae token based transformer that can generate an image in fewer forward passes than many diffusion models. In contrast with muse, it uses the smaller text encoder CLIP-L/14 instead of t5-xxl. Due to its small parameter count and few forward pass generation process, amused can generate many images quickly. This benefit is seen particularly at larger batch sizes. + +| Model | Params | +|-------|--------| +| [amused-256](https://huggingface.co/huggingface/amused-256) | 603M | +| [amused-512](https://huggingface.co/huggingface/amused-512) | 608M | + +## AmusedPipeline + +[[autodoc]] AmusedPipeline + - __call__ + - all + - enable_xformers_memory_efficient_attention + - disable_xformers_memory_efficient_attention \ No newline at end of file diff --git a/examples/amused/README.md b/examples/amused/README.md new file mode 100644 index 000000000000..517c2d382f8e --- /dev/null +++ b/examples/amused/README.md @@ -0,0 +1,326 @@ +## Amused training + +Amused can be finetuned on simple datasets relatively cheaply and quickly. Using 8bit optimizers, lora, and gradient accumulation, amused can be finetuned with as little as 5.5 GB. Here are a set of examples for finetuning amused on some relatively simple datasets. These training recipies are aggressively oriented towards minimal resources and fast verification -- i.e. the batch sizes are quite low and the learning rates are quite high. For optimal quality, you will probably want to increase the batch sizes and decrease learning rates. + +All training examples use fp16 mixed precision and gradient checkpointing. We don't show 8 bit adam + lora as its about the same memory use as just using lora (bitsandbytes uses full precision optimizer states for weights below a minimum size). + +### Finetuning the 256 checkpoint + +These examples finetune on this [nouns](https://huggingface.co/datasets/m1guelpf/nouns) dataset. + +Example results: + +![noun1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun1.png) ![noun2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun2.png) ![noun3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/noun3.png) + + +#### Full finetuning + +Batch size: 8, Learning rate: 1e-4, Gives decent results in 750-1000 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 8 | 1 | 8 | 19.7 GB | +| 4 | 2 | 8 | 18.3 GB | +| 1 | 8 | 8 | 17.9 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 1e-4 \ + --pretrained_model_name_or_path huggingface/amused-256 \ + --instance_data_dataset 'm1guelpf/nouns' \ + --image_key image \ + --prompt_key text \ + --resolution 256 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \ + 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \ + 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \ + 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \ + 'a pixel art character with square red glasses' \ + 'a pixel art character' \ + 'square red glasses on a pixel art character' \ + 'square red glasses on a pixel art character with a baseball-shaped head' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +#### Full finetuning + 8 bit adam + +Note that this training config keeps the batch size low and the learning rate high to get results fast with low resources. However, due to 8 bit adam, it will diverge eventually. If you want to train for longer, you will have to up the batch size and lower the learning rate. + +Batch size: 16, Learning rate: 2e-5, Gives decent results in ~750 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 16 | 1 | 16 | 20.1 GB | +| 8 | 2 | 16 | 15.6 GB | +| 1 | 16 | 16 | 10.7 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 2e-5 \ + --use_8bit_adam \ + --pretrained_model_name_or_path huggingface/amused-256 \ + --instance_data_dataset 'm1guelpf/nouns' \ + --image_key image \ + --prompt_key text \ + --resolution 256 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \ + 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \ + 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \ + 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \ + 'a pixel art character with square red glasses' \ + 'a pixel art character' \ + 'square red glasses on a pixel art character' \ + 'square red glasses on a pixel art character with a baseball-shaped head' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +#### Full finetuning + lora + +Batch size: 16, Learning rate: 8e-4, Gives decent results in 1000-1250 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 16 | 1 | 16 | 14.1 GB | +| 8 | 2 | 16 | 10.1 GB | +| 1 | 16 | 16 | 6.5 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 8e-4 \ + --use_lora \ + --pretrained_model_name_or_path huggingface/amused-256 \ + --instance_data_dataset 'm1guelpf/nouns' \ + --image_key image \ + --prompt_key text \ + --resolution 256 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'a pixel art character with square red glasses, a baseball-shaped head and a orange-colored body on a dark background' \ + 'a pixel art character with square orange glasses, a lips-shaped head and a red-colored body on a light background' \ + 'a pixel art character with square blue glasses, a microwave-shaped head and a purple-colored body on a sunny background' \ + 'a pixel art character with square red glasses, a baseball-shaped head and a blue-colored body on an orange background' \ + 'a pixel art character with square red glasses' \ + 'a pixel art character' \ + 'square red glasses on a pixel art character' \ + 'square red glasses on a pixel art character with a baseball-shaped head' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +### Finetuning the 512 checkpoint + +These examples finetune on this [minecraft](https://huggingface.co/monadical-labs/minecraft-preview) dataset. + +Example results: + +![minecraft1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft1.png) ![minecraft2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft2.png) ![minecraft3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/minecraft3.png) + +#### Full finetuning + +Batch size: 8, Learning rate: 8e-5, Gives decent results in 500-1000 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 8 | 1 | 8 | 24.2 GB | +| 4 | 2 | 8 | 19.7 GB | +| 1 | 8 | 8 | 16.99 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 8e-5 \ + --pretrained_model_name_or_path huggingface/amused-512 \ + --instance_data_dataset 'monadical-labs/minecraft-preview' \ + --prompt_prefix 'minecraft ' \ + --image_key image \ + --prompt_key text \ + --resolution 512 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'minecraft Avatar' \ + 'minecraft character' \ + 'minecraft' \ + 'minecraft president' \ + 'minecraft pig' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +#### Full finetuning + 8 bit adam + +Batch size: 8, Learning rate: 5e-6, Gives decent results in 500-1000 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 8 | 1 | 8 | 21.2 GB | +| 4 | 2 | 8 | 13.3 GB | +| 1 | 8 | 8 | 9.9 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 5e-6 \ + --pretrained_model_name_or_path huggingface/amused-512 \ + --instance_data_dataset 'monadical-labs/minecraft-preview' \ + --prompt_prefix 'minecraft ' \ + --image_key image \ + --prompt_key text \ + --resolution 512 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'minecraft Avatar' \ + 'minecraft character' \ + 'minecraft' \ + 'minecraft president' \ + 'minecraft pig' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +#### Full finetuning + lora + +Batch size: 8, Learning rate: 1e-4, Gives decent results in 500-1000 steps + +| Batch Size | Gradient Accumulation Steps | Effective Total Batch Size | Memory Used | +|------------|-----------------------------|------------------|-------------| +| 8 | 1 | 8 | 12.7 GB | +| 4 | 2 | 8 | 9.0 GB | +| 1 | 8 | 8 | 5.6 GB | + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --train_batch_size \ + --gradient_accumulation_steps \ + --learning_rate 1e-4 \ + --use_lora \ + --pretrained_model_name_or_path huggingface/amused-512 \ + --instance_data_dataset 'monadical-labs/minecraft-preview' \ + --prompt_prefix 'minecraft ' \ + --image_key image \ + --prompt_key text \ + --resolution 512 \ + --mixed_precision fp16 \ + --lr_scheduler constant \ + --validation_prompts \ + 'minecraft Avatar' \ + 'minecraft character' \ + 'minecraft' \ + 'minecraft president' \ + 'minecraft pig' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 250 \ + --gradient_checkpointing +``` + +### Styledrop + +[Styledrop](https://arxiv.org/abs/2306.00983) is an efficient finetuning method for learning a new style from just one or very few images. It has an optional first stage to generate human picked additional training samples. The additional training samples can be used to augment the initial images. Our examples exclude the optional additional image selection stage and instead we just finetune on a single image. + +This is our example style image: +![example](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png) + +Download it to your local directory with +```sh +wget https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/A%20mushroom%20in%20%5BV%5D%20style.png +``` + +#### 256 + +Example results: + +![glowing_256_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_1.png) ![glowing_256_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_2.png) ![glowing_256_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_256_3.png) + +Learning rate: 4e-4, Gives decent results in 1500-2000 steps + +Memory used: 6.5 GB + +```sh +accelerate launch train_amused.py \ + --output_dir \ + --mixed_precision fp16 \ + --report_to wandb \ + --use_lora \ + --pretrained_model_name_or_path huggingface/amused-256 \ + --train_batch_size 1 \ + --lr_scheduler constant \ + --learning_rate 4e-4 \ + --validation_prompts \ + 'A chihuahua walking on the street in [V] style' \ + 'A banana on the table in [V] style' \ + 'A church on the street in [V] style' \ + 'A tabby cat walking in the forest in [V] style' \ + --instance_data_image 'A mushroom in [V] style.png' \ + --max_train_steps 10000 \ + --checkpointing_steps 500 \ + --validation_steps 100 \ + --resolution 256 +``` + +#### 512 + +Example results: + +![glowing_512_1](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_1.png) ![glowing_512_2](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_2.png) ![glowing_512_3](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/amused/glowing_512_3.png) + +Learning rate: 1e-3, Lora alpha 1, Gives decent results in 1500-2000 steps + +Memory used: 5.6 GB + +``` +accelerate launch train_amused.py \ + --output_dir \ + --mixed_precision fp16 \ + --report_to wandb \ + --use_lora \ + --pretrained_model_name_or_path huggingface/amused-512 \ + --train_batch_size 1 \ + --lr_scheduler constant \ + --learning_rate 1e-3 \ + --validation_prompts \ + 'A chihuahua walking on the street in [V] style' \ + 'A banana on the table in [V] style' \ + 'A church on the street in [V] style' \ + 'A tabby cat walking in the forest in [V] style' \ + --instance_data_image 'A mushroom in [V] style.png' \ + --max_train_steps 100000 \ + --checkpointing_steps 500 \ + --validation_steps 100 \ + --resolution 512 \ + --lora_alpha 1 +``` \ No newline at end of file diff --git a/examples/amused/train_amused.py b/examples/amused/train_amused.py new file mode 100644 index 000000000000..7ae7088d66d8 --- /dev/null +++ b/examples/amused/train_amused.py @@ -0,0 +1,972 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy +import logging +import math +import os +import shutil +from contextlib import nullcontext +from pathlib import Path + +import torch +import torch.nn.functional as F +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import DataLoader, Dataset, default_collate +from torchvision import transforms +from transformers import ( + CLIPTextModelWithProjection, + CLIPTokenizer, +) + +import diffusers.optimization +from diffusers import AmusedPipeline, AmusedScheduler, EMAModel, UVit2DModel, VQModel +from diffusers.loaders import LoraLoaderMixin +from diffusers.utils import is_wandb_available + + +if is_wandb_available(): + import wandb + +logger = get_logger(__name__, log_level="INFO") + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--instance_data_dataset", + type=str, + default=None, + required=False, + help="A Hugging Face dataset containing the training images", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--instance_data_image", type=str, default=None, required=False, help="A single training image" + ) + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") + parser.add_argument("--ema_decay", type=float, default=0.9999) + parser.add_argument("--ema_update_after_step", type=int, default=0) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument( + "--output_dir", + type=str, + default="muse_training", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--logging_steps", + type=int, + default=50, + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=( + "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." + " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" + " for more details" + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=0.0003, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="wandb", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument("--validation_prompts", type=str, nargs="*") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument("--split_vae_encode", type=int, required=False, default=None) + parser.add_argument("--min_masking_rate", type=float, default=0.0) + parser.add_argument("--cond_dropout_prob", type=float, default=0.0) + parser.add_argument("--max_grad_norm", default=None, type=float, help="Max gradient norm.", required=False) + parser.add_argument("--use_lora", action="store_true", help="Fine tune the model using LoRa") + parser.add_argument("--text_encoder_use_lora", action="store_true", help="Fine tune the model using LoRa") + parser.add_argument("--lora_r", default=16, type=int) + parser.add_argument("--lora_alpha", default=32, type=int) + parser.add_argument("--lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+") + parser.add_argument("--text_encoder_lora_r", default=16, type=int) + parser.add_argument("--text_encoder_lora_alpha", default=32, type=int) + parser.add_argument("--text_encoder_lora_target_modules", default=["to_q", "to_k", "to_v"], type=str, nargs="+") + parser.add_argument("--train_text_encoder", action="store_true") + parser.add_argument("--image_key", type=str, required=False) + parser.add_argument("--prompt_key", type=str, required=False) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument("--prompt_prefix", type=str, required=False, default=None) + + args = parser.parse_args() + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + num_datasources = sum( + [x is not None for x in [args.instance_data_dir, args.instance_data_image, args.instance_data_dataset]] + ) + + if num_datasources != 1: + raise ValueError( + "provide one and only one of `--instance_data_dir`, `--instance_data_image`, or `--instance_data_dataset`" + ) + + if args.instance_data_dir is not None: + if not os.path.exists(args.instance_data_dir): + raise ValueError(f"Does not exist: `--args.instance_data_dir` {args.instance_data_dir}") + + if args.instance_data_image is not None: + if not os.path.exists(args.instance_data_image): + raise ValueError(f"Does not exist: `--args.instance_data_image` {args.instance_data_image}") + + if args.instance_data_dataset is not None and (args.image_key is None or args.prompt_key is None): + raise ValueError("`--instance_data_dataset` requires setting `--image_key` and `--prompt_key`") + + return args + + +class InstanceDataRootDataset(Dataset): + def __init__( + self, + instance_data_root, + tokenizer, + size=512, + ): + self.size = size + self.tokenizer = tokenizer + self.instance_images_path = list(Path(instance_data_root).iterdir()) + + def __len__(self): + return len(self.instance_images_path) + + def __getitem__(self, index): + image_path = self.instance_images_path[index % len(self.instance_images_path)] + instance_image = Image.open(image_path) + rv = process_image(instance_image, self.size) + + prompt = os.path.splitext(os.path.basename(image_path))[0] + rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0] + return rv + + +class InstanceDataImageDataset(Dataset): + def __init__( + self, + instance_data_image, + train_batch_size, + size=512, + ): + self.value = process_image(Image.open(instance_data_image), size) + self.train_batch_size = train_batch_size + + def __len__(self): + # Needed so a full batch of the data can be returned. Otherwise will return + # batches of size 1 + return self.train_batch_size + + def __getitem__(self, index): + return self.value + + +class HuggingFaceDataset(Dataset): + def __init__( + self, + hf_dataset, + tokenizer, + image_key, + prompt_key, + prompt_prefix=None, + size=512, + ): + self.size = size + self.image_key = image_key + self.prompt_key = prompt_key + self.tokenizer = tokenizer + self.hf_dataset = hf_dataset + self.prompt_prefix = prompt_prefix + + def __len__(self): + return len(self.hf_dataset) + + def __getitem__(self, index): + item = self.hf_dataset[index] + + rv = process_image(item[self.image_key], self.size) + + prompt = item[self.prompt_key] + + if self.prompt_prefix is not None: + prompt = self.prompt_prefix + prompt + + rv["prompt_input_ids"] = tokenize_prompt(self.tokenizer, prompt)[0] + + return rv + + +def process_image(image, size): + image = exif_transpose(image) + + if not image.mode == "RGB": + image = image.convert("RGB") + + orig_height = image.height + orig_width = image.width + + image = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)(image) + + c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(size, size)) + image = transforms.functional.crop(image, c_top, c_left, size, size) + + image = transforms.ToTensor()(image) + + micro_conds = torch.tensor( + [orig_width, orig_height, c_top, c_left, 6.0], + ) + + return {"image": image, "micro_conds": micro_conds} + + +def tokenize_prompt(tokenizer, prompt): + return tokenizer( + prompt, + truncation=True, + padding="max_length", + max_length=77, + return_tensors="pt", + ).input_ids + + +def encode_prompt(text_encoder, input_ids): + outputs = text_encoder(input_ids, return_dict=True, output_hidden_states=True) + encoder_hidden_states = outputs.hidden_states[-2] + cond_embeds = outputs[0] + return encoder_hidden_states, cond_embeds + + +def main(args): + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + if accelerator.is_main_process: + os.makedirs(args.output_dir, exist_ok=True) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + + if accelerator.is_main_process: + accelerator.init_trackers("amused", config=vars(copy.deepcopy(args))) + + if args.seed is not None: + set_seed(args.seed) + + # TODO - will have to fix loading if training text encoder + text_encoder = CLIPTextModelWithProjection.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, variant=args.variant + ) + vq_model = VQModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vqvae", revision=args.revision, variant=args.variant + ) + + if args.train_text_encoder: + if args.text_encoder_use_lora: + lora_config = LoraConfig( + r=args.text_encoder_lora_r, + lora_alpha=args.text_encoder_lora_alpha, + target_modules=args.text_encoder_lora_target_modules, + ) + text_encoder.add_adapter(lora_config) + text_encoder.train() + text_encoder.requires_grad_(True) + else: + text_encoder.eval() + text_encoder.requires_grad_(False) + + vq_model.requires_grad_(False) + + model = UVit2DModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="transformer", + revision=args.revision, + variant=args.variant, + ) + + if args.use_lora: + lora_config = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + target_modules=args.lora_target_modules, + ) + model.add_adapter(lora_config) + + model.train() + + if args.gradient_checkpointing: + model.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder.gradient_checkpointing_enable() + + if args.use_ema: + ema = EMAModel( + model.parameters(), + decay=args.ema_decay, + update_after_step=args.ema_update_after_step, + model_cls=UVit2DModel, + model_config=model.config, + ) + + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + transformer_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + for model_ in models: + if isinstance(model_, type(accelerator.unwrap_model(model))): + if args.use_lora: + transformer_lora_layers_to_save = get_peft_model_state_dict(model_) + else: + model_.save_pretrained(os.path.join(output_dir, "transformer")) + elif isinstance(model_, type(accelerator.unwrap_model(text_encoder))): + if args.text_encoder_use_lora: + text_encoder_lora_layers_to_save = get_peft_model_state_dict(model_) + else: + model_.save_pretrained(os.path.join(output_dir, "text_encoder")) + else: + raise ValueError(f"unexpected save model: {model_.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + if transformer_lora_layers_to_save is not None or text_encoder_lora_layers_to_save is not None: + LoraLoaderMixin.save_lora_weights( + output_dir, + transformer_lora_layers=transformer_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + if args.use_ema: + ema.save_pretrained(os.path.join(output_dir, "ema_model")) + + def load_model_hook(models, input_dir): + transformer = None + text_encoder_ = None + + while len(models) > 0: + model_ = models.pop() + + if isinstance(model_, type(accelerator.unwrap_model(model))): + if args.use_lora: + transformer = model_ + else: + load_model = UVit2DModel.from_pretrained(os.path.join(input_dir, "transformer")) + model_.load_state_dict(load_model.state_dict()) + del load_model + elif isinstance(model, type(accelerator.unwrap_model(text_encoder))): + if args.text_encoder_use_lora: + text_encoder_ = model_ + else: + load_model = CLIPTextModelWithProjection.from_pretrained(os.path.join(input_dir, "text_encoder")) + model_.load_state_dict(load_model.state_dict()) + del load_model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + if transformer is not None or text_encoder_ is not None: + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + LoraLoaderMixin.load_lora_into_text_encoder( + lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_ + ) + LoraLoaderMixin.load_lora_into_transformer( + lora_state_dict, network_alphas=network_alphas, transformer=transformer + ) + + if args.use_ema: + load_from = EMAModel.from_pretrained(os.path.join(input_dir, "ema_model"), model_cls=UVit2DModel) + ema.load_state_dict(load_from.state_dict()) + del load_from + + accelerator.register_load_state_pre_hook(load_model_hook) + accelerator.register_save_state_pre_hook(save_model_hook) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + ) + + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + # no decay on bias and layernorm and embedding + no_decay = ["bias", "layer_norm.weight", "mlm_ln.weight", "embeddings.weight"] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], + "weight_decay": args.adam_weight_decay, + }, + { + "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], + "weight_decay": 0.0, + }, + ] + + if args.train_text_encoder: + optimizer_grouped_parameters.append( + {"params": text_encoder.parameters(), "weight_decay": args.adam_weight_decay} + ) + + optimizer = optimizer_cls( + optimizer_grouped_parameters, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + logger.info("Creating dataloaders and lr_scheduler") + + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + if args.instance_data_dir is not None: + dataset = InstanceDataRootDataset( + instance_data_root=args.instance_data_dir, + tokenizer=tokenizer, + size=args.resolution, + ) + elif args.instance_data_image is not None: + dataset = InstanceDataImageDataset( + instance_data_image=args.instance_data_image, + train_batch_size=args.train_batch_size, + size=args.resolution, + ) + elif args.instance_data_dataset is not None: + dataset = HuggingFaceDataset( + hf_dataset=load_dataset(args.instance_data_dataset, split="train"), + tokenizer=tokenizer, + image_key=args.image_key, + prompt_key=args.prompt_key, + prompt_prefix=args.prompt_prefix, + size=args.resolution, + ) + else: + assert False + + train_dataloader = DataLoader( + dataset, + batch_size=args.train_batch_size, + shuffle=True, + num_workers=args.dataloader_num_workers, + collate_fn=default_collate, + ) + train_dataloader.num_batches = len(train_dataloader) + + lr_scheduler = diffusers.optimization.get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + ) + + logger.info("Preparing model, optimizer and dataloaders") + + if args.train_text_encoder: + model, optimizer, lr_scheduler, train_dataloader, text_encoder = accelerator.prepare( + model, optimizer, lr_scheduler, train_dataloader, text_encoder + ) + else: + model, optimizer, lr_scheduler, train_dataloader = accelerator.prepare( + model, optimizer, lr_scheduler, train_dataloader + ) + + train_dataloader.num_batches = len(train_dataloader) + + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + if not args.train_text_encoder: + text_encoder.to(device=accelerator.device, dtype=weight_dtype) + + vq_model.to(device=accelerator.device) + + if args.use_ema: + ema.to(accelerator.device) + + with nullcontext() if args.train_text_encoder else torch.no_grad(): + empty_embeds, empty_clip_embeds = encode_prompt( + text_encoder, tokenize_prompt(tokenizer, "").to(text_encoder.device, non_blocking=True) + ) + + # There is a single image, we can just pre-encode the single prompt + if args.instance_data_image is not None: + prompt = os.path.splitext(os.path.basename(args.instance_data_image))[0] + encoder_hidden_states, cond_embeds = encode_prompt( + text_encoder, tokenize_prompt(tokenizer, prompt).to(text_encoder.device, non_blocking=True) + ) + encoder_hidden_states = encoder_hidden_states.repeat(args.train_batch_size, 1, 1) + cond_embeds = cond_embeds.repeat(args.train_batch_size, 1) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + # Afterwards we recalculate our number of training epochs. + # Note: We are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # Train! + logger.info("***** Running training *****") + logger.info(f" Num training steps = {args.max_train_steps}") + logger.info(f" Instantaneous batch size per device = { args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + + resume_from_checkpoint = args.resume_from_checkpoint + if resume_from_checkpoint: + if resume_from_checkpoint == "latest": + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + if len(dirs) > 0: + resume_from_checkpoint = os.path.join(args.output_dir, dirs[-1]) + else: + resume_from_checkpoint = None + + if resume_from_checkpoint is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + else: + accelerator.print(f"Resuming from checkpoint {resume_from_checkpoint}") + + if resume_from_checkpoint is None: + global_step = 0 + first_epoch = 0 + else: + accelerator.load_state(resume_from_checkpoint) + global_step = int(os.path.basename(resume_from_checkpoint).split("-")[1]) + first_epoch = global_step // num_update_steps_per_epoch + + # As stated above, we are not doing epoch based training here, but just using this for book keeping and being able to + # reuse the same training loop with other datasets/loaders. + for epoch in range(first_epoch, num_train_epochs): + for batch in train_dataloader: + with torch.no_grad(): + micro_conds = batch["micro_conds"].to(accelerator.device, non_blocking=True) + pixel_values = batch["image"].to(accelerator.device, non_blocking=True) + + batch_size = pixel_values.shape[0] + + split_batch_size = args.split_vae_encode if args.split_vae_encode is not None else batch_size + num_splits = math.ceil(batch_size / split_batch_size) + image_tokens = [] + for i in range(num_splits): + start_idx = i * split_batch_size + end_idx = min((i + 1) * split_batch_size, batch_size) + bs = pixel_values.shape[0] + image_tokens.append( + vq_model.quantize(vq_model.encode(pixel_values[start_idx:end_idx]).latents)[2][2].reshape( + bs, -1 + ) + ) + image_tokens = torch.cat(image_tokens, dim=0) + + batch_size, seq_len = image_tokens.shape + + timesteps = torch.rand(batch_size, device=image_tokens.device) + mask_prob = torch.cos(timesteps * math.pi * 0.5) + mask_prob = mask_prob.clip(args.min_masking_rate) + + num_token_masked = (seq_len * mask_prob).round().clamp(min=1) + batch_randperm = torch.rand(batch_size, seq_len, device=image_tokens.device).argsort(dim=-1) + mask = batch_randperm < num_token_masked.unsqueeze(-1) + + mask_id = accelerator.unwrap_model(model).config.vocab_size - 1 + input_ids = torch.where(mask, mask_id, image_tokens) + labels = torch.where(mask, image_tokens, -100) + + if args.cond_dropout_prob > 0.0: + assert encoder_hidden_states is not None + + batch_size = encoder_hidden_states.shape[0] + + mask = ( + torch.zeros((batch_size, 1, 1), device=encoder_hidden_states.device).float().uniform_(0, 1) + < args.cond_dropout_prob + ) + + empty_embeds_ = empty_embeds.expand(batch_size, -1, -1) + encoder_hidden_states = torch.where( + (encoder_hidden_states * mask).bool(), encoder_hidden_states, empty_embeds_ + ) + + empty_clip_embeds_ = empty_clip_embeds.expand(batch_size, -1) + cond_embeds = torch.where((cond_embeds * mask.squeeze(-1)).bool(), cond_embeds, empty_clip_embeds_) + + bs = input_ids.shape[0] + vae_scale_factor = 2 ** (len(vq_model.config.block_out_channels) - 1) + resolution = args.resolution // vae_scale_factor + input_ids = input_ids.reshape(bs, resolution, resolution) + + if "prompt_input_ids" in batch: + with nullcontext() if args.train_text_encoder else torch.no_grad(): + encoder_hidden_states, cond_embeds = encode_prompt( + text_encoder, batch["prompt_input_ids"].to(accelerator.device, non_blocking=True) + ) + + # Train Step + with accelerator.accumulate(model): + codebook_size = accelerator.unwrap_model(model).config.codebook_size + + logits = ( + model( + input_ids=input_ids, + encoder_hidden_states=encoder_hidden_states, + micro_conds=micro_conds, + pooled_text_emb=cond_embeds, + ) + .reshape(bs, codebook_size, -1) + .permute(0, 2, 1) + .reshape(-1, codebook_size) + ) + + loss = F.cross_entropy( + logits, + labels.view(-1), + ignore_index=-100, + reduction="mean", + ) + + # Gather the losses across all processes for logging (if we use distributed training). + avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() + avg_masking_rate = accelerator.gather(mask_prob.repeat(args.train_batch_size)).mean() + + accelerator.backward(loss) + + if args.max_grad_norm is not None and accelerator.sync_gradients: + accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + + optimizer.zero_grad(set_to_none=True) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + if args.use_ema: + ema.step(model.parameters()) + + if (global_step + 1) % args.logging_steps == 0: + logs = { + "step_loss": avg_loss.item(), + "lr": lr_scheduler.get_last_lr()[0], + "avg_masking_rate": avg_masking_rate.item(), + } + accelerator.log(logs, step=global_step + 1) + + logger.info( + f"Step: {global_step + 1} " + f"Loss: {avg_loss.item():0.4f} " + f"LR: {lr_scheduler.get_last_lr()[0]:0.6f}" + ) + + if (global_step + 1) % args.checkpointing_steps == 0: + save_checkpoint(args, accelerator, global_step + 1) + + if (global_step + 1) % args.validation_steps == 0 and accelerator.is_main_process: + if args.use_ema: + ema.store(model.parameters()) + ema.copy_to(model.parameters()) + + with torch.no_grad(): + logger.info("Generating images...") + + model.eval() + + if args.train_text_encoder: + text_encoder.eval() + + scheduler = AmusedScheduler.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="scheduler", + revision=args.revision, + variant=args.variant, + ) + + pipe = AmusedPipeline( + transformer=accelerator.unwrap_model(model), + tokenizer=tokenizer, + text_encoder=text_encoder, + vqvae=vq_model, + scheduler=scheduler, + ) + + pil_images = pipe(prompt=args.validation_prompts).images + wandb_images = [ + wandb.Image(image, caption=args.validation_prompts[i]) + for i, image in enumerate(pil_images) + ] + + wandb.log({"generated_images": wandb_images}, step=global_step + 1) + + model.train() + + if args.train_text_encoder: + text_encoder.train() + + if args.use_ema: + ema.restore(model.parameters()) + + global_step += 1 + + # Stop training if max steps is reached + if global_step >= args.max_train_steps: + break + # End for + + accelerator.wait_for_everyone() + + # Evaluate and save checkpoint at the end of training + save_checkpoint(args, accelerator, global_step) + + # Save the final trained checkpoint + if accelerator.is_main_process: + model = accelerator.unwrap_model(model) + if args.use_ema: + ema.copy_to(model.parameters()) + model.save_pretrained(args.output_dir) + + accelerator.end_training() + + +def save_checkpoint(args, accelerator, global_step): + output_dir = args.output_dir + + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if accelerator.is_main_process and args.checkpoints_total_limit is not None: + checkpoints = os.listdir(output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = Path(output_dir) / f"checkpoint-{global_step}" + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + +if __name__ == "__main__": + main(parse_args()) diff --git a/scripts/convert_amused.py b/scripts/convert_amused.py new file mode 100644 index 000000000000..fdddbef7cd65 --- /dev/null +++ b/scripts/convert_amused.py @@ -0,0 +1,523 @@ +import inspect +import os +from argparse import ArgumentParser + +import numpy as np +import torch +from muse import MaskGiTUViT, VQGANModel +from muse import PipelineMuse as OldPipelineMuse +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import VQModel +from diffusers.models.attention_processor import AttnProcessor +from diffusers.models.uvit_2d import UVit2DModel +from diffusers.pipelines.amused.pipeline_amused import AmusedPipeline +from diffusers.schedulers import AmusedScheduler + + +torch.backends.cuda.enable_flash_sdp(False) +torch.backends.cuda.enable_mem_efficient_sdp(False) +torch.backends.cuda.enable_math_sdp(True) + +os.environ["CUDA_LAUNCH_BLOCKING"] = "1" +os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" +torch.use_deterministic_algorithms(True) + +# Enable CUDNN deterministic mode +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False +torch.backends.cuda.matmul.allow_tf32 = False + +device = "cuda" + + +def main(): + args = ArgumentParser() + args.add_argument("--model_256", action="store_true") + args.add_argument("--write_to", type=str, required=False, default=None) + args.add_argument("--transformer_path", type=str, required=False, default=None) + args = args.parse_args() + + transformer_path = args.transformer_path + subfolder = "transformer" + + if transformer_path is None: + if args.model_256: + transformer_path = "openMUSE/muse-256" + else: + transformer_path = ( + "../research-run-512-checkpoints/research-run-512-with-downsample-checkpoint-554000/unwrapped_model/" + ) + subfolder = None + + old_transformer = MaskGiTUViT.from_pretrained(transformer_path, subfolder=subfolder) + + old_transformer.to(device) + + old_vae = VQGANModel.from_pretrained("openMUSE/muse-512", subfolder="vae") + old_vae.to(device) + + vqvae = make_vqvae(old_vae) + + tokenizer = CLIPTokenizer.from_pretrained("openMUSE/muse-512", subfolder="text_encoder") + + text_encoder = CLIPTextModelWithProjection.from_pretrained("openMUSE/muse-512", subfolder="text_encoder") + text_encoder.to(device) + + transformer = make_transformer(old_transformer, args.model_256) + + scheduler = AmusedScheduler(mask_token_id=old_transformer.config.mask_token_id) + + new_pipe = AmusedPipeline( + vqvae=vqvae, tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, scheduler=scheduler + ) + + old_pipe = OldPipelineMuse( + vae=old_vae, transformer=old_transformer, text_encoder=text_encoder, tokenizer=tokenizer + ) + old_pipe.to(device) + + if args.model_256: + transformer_seq_len = 256 + orig_size = (256, 256) + else: + transformer_seq_len = 1024 + orig_size = (512, 512) + + old_out = old_pipe( + "dog", + generator=torch.Generator(device).manual_seed(0), + transformer_seq_len=transformer_seq_len, + orig_size=orig_size, + timesteps=12, + )[0] + + new_out = new_pipe("dog", generator=torch.Generator(device).manual_seed(0)).images[0] + + old_out = np.array(old_out) + new_out = np.array(new_out) + + diff = np.abs(old_out.astype(np.float64) - new_out.astype(np.float64)) + + # assert diff diff.sum() == 0 + print("skipping pipeline full equivalence check") + + print(f"max diff: {diff.max()}, diff.sum() / diff.size {diff.sum() / diff.size}") + + if args.model_256: + assert diff.max() <= 3 + assert diff.sum() / diff.size < 0.7 + else: + assert diff.max() <= 1 + assert diff.sum() / diff.size < 0.4 + + if args.write_to is not None: + new_pipe.save_pretrained(args.write_to) + + +def make_transformer(old_transformer, model_256): + args = dict(old_transformer.config) + force_down_up_sample = args["force_down_up_sample"] + + signature = inspect.signature(UVit2DModel.__init__) + + args_ = { + "downsample": force_down_up_sample, + "upsample": force_down_up_sample, + "block_out_channels": args["block_out_channels"][0], + "sample_size": 16 if model_256 else 32, + } + + for s in list(signature.parameters.keys()): + if s in ["self", "downsample", "upsample", "sample_size", "block_out_channels"]: + continue + + args_[s] = args[s] + + new_transformer = UVit2DModel(**args_) + new_transformer.to(device) + + new_transformer.set_attn_processor(AttnProcessor()) + + state_dict = old_transformer.state_dict() + + state_dict["cond_embed.linear_1.weight"] = state_dict.pop("cond_embed.0.weight") + state_dict["cond_embed.linear_2.weight"] = state_dict.pop("cond_embed.2.weight") + + for i in range(22): + state_dict[f"transformer_layers.{i}.norm1.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.attn_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm1.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.self_attn_adaLN_modulation.mapper.weight" + ) + + state_dict[f"transformer_layers.{i}.attn1.to_q.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.query.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_k.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.key.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_v.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.value.weight" + ) + state_dict[f"transformer_layers.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"transformer_layers.{i}.attention.out.weight" + ) + + state_dict[f"transformer_layers.{i}.norm2.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm2.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.cross_attn_adaLN_modulation.mapper.weight" + ) + + state_dict[f"transformer_layers.{i}.attn2.to_q.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.query.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_k.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.key.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_v.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.value.weight" + ) + state_dict[f"transformer_layers.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"transformer_layers.{i}.crossattention.out.weight" + ) + + state_dict[f"transformer_layers.{i}.norm3.norm.weight"] = state_dict.pop( + f"transformer_layers.{i}.ffn.pre_mlp_layer_norm.weight" + ) + state_dict[f"transformer_layers.{i}.norm3.linear.weight"] = state_dict.pop( + f"transformer_layers.{i}.ffn.adaLN_modulation.mapper.weight" + ) + + wi_0_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_0.weight") + wi_1_weight = state_dict.pop(f"transformer_layers.{i}.ffn.wi_1.weight") + proj_weight = torch.concat([wi_1_weight, wi_0_weight], dim=0) + state_dict[f"transformer_layers.{i}.ff.net.0.proj.weight"] = proj_weight + + state_dict[f"transformer_layers.{i}.ff.net.2.weight"] = state_dict.pop(f"transformer_layers.{i}.ffn.wo.weight") + + if force_down_up_sample: + state_dict["down_block.downsample.norm.weight"] = state_dict.pop("down_blocks.0.downsample.0.norm.weight") + state_dict["down_block.downsample.conv.weight"] = state_dict.pop("down_blocks.0.downsample.1.weight") + + state_dict["up_block.upsample.norm.weight"] = state_dict.pop("up_blocks.0.upsample.0.norm.weight") + state_dict["up_block.upsample.conv.weight"] = state_dict.pop("up_blocks.0.upsample.1.weight") + + state_dict["mlm_layer.layer_norm.weight"] = state_dict.pop("mlm_layer.layer_norm.norm.weight") + + for i in range(3): + state_dict[f"down_block.res_blocks.{i}.norm.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.norm.norm.weight" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.0.weight" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.2.gamma" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.2.beta" + ) + state_dict[f"down_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.channelwise.4.weight" + ) + state_dict[f"down_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop( + f"down_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight" + ) + + state_dict[f"down_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attn_layer_norm.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.query.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.key.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.value.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.attention.out.weight" + ) + + state_dict[f"down_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.query.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.key.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.value.weight" + ) + state_dict[f"down_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"down_blocks.0.attention_blocks.{i}.crossattention.out.weight" + ) + + state_dict[f"up_block.res_blocks.{i}.norm.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.norm.norm.weight" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_linear_1.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.0.weight" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_norm.gamma"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.2.gamma" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_norm.beta"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.2.beta" + ) + state_dict[f"up_block.res_blocks.{i}.channelwise_linear_2.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.channelwise.4.weight" + ) + state_dict[f"up_block.res_blocks.{i}.cond_embeds_mapper.weight"] = state_dict.pop( + f"up_blocks.0.res_blocks.{i}.adaLN_modulation.mapper.weight" + ) + + state_dict[f"up_block.attention_blocks.{i}.norm1.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attn_layer_norm.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_q.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.query.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_k.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.key.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_v.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.value.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn1.to_out.0.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.attention.out.weight" + ) + + state_dict[f"up_block.attention_blocks.{i}.norm2.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattn_layer_norm.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_q.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.query.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_k.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.key.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_v.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.value.weight" + ) + state_dict[f"up_block.attention_blocks.{i}.attn2.to_out.0.weight"] = state_dict.pop( + f"up_blocks.0.attention_blocks.{i}.crossattention.out.weight" + ) + + for key in list(state_dict.keys()): + if key.startswith("up_blocks.0"): + key_ = "up_block." + ".".join(key.split(".")[2:]) + state_dict[key_] = state_dict.pop(key) + + if key.startswith("down_blocks.0"): + key_ = "down_block." + ".".join(key.split(".")[2:]) + state_dict[key_] = state_dict.pop(key) + + new_transformer.load_state_dict(state_dict) + + input_ids = torch.randint(0, 10, (1, 32, 32), device=old_transformer.device) + encoder_hidden_states = torch.randn((1, 77, 768), device=old_transformer.device) + cond_embeds = torch.randn((1, 768), device=old_transformer.device) + micro_conds = torch.tensor([[512, 512, 0, 0, 6]], dtype=torch.float32, device=old_transformer.device) + + old_out = old_transformer(input_ids.reshape(1, -1), encoder_hidden_states, cond_embeds, micro_conds) + old_out = old_out.reshape(1, 32, 32, 8192).permute(0, 3, 1, 2) + + new_out = new_transformer(input_ids, encoder_hidden_states, cond_embeds, micro_conds) + + # NOTE: these differences are solely due to using the geglu block that has a single linear layer of + # double output dimension instead of two different linear layers + max_diff = (old_out - new_out).abs().max() + total_diff = (old_out - new_out).abs().sum() + print(f"Transformer max_diff: {max_diff} total_diff: {total_diff}") + assert max_diff < 0.01 + assert total_diff < 1500 + + return new_transformer + + +def make_vqvae(old_vae): + new_vae = VQModel( + act_fn="silu", + block_out_channels=[128, 256, 256, 512, 768], + down_block_types=[ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=64, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=8192, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + new_vae.to(device) + + # fmt: off + + new_state_dict = {} + + old_state_dict = old_vae.state_dict() + + new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight") + new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias") + + convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0") + convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1") + convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2") + convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3") + convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4") + + new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight") + new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias") + new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight") + new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias") + new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight") + new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias") + new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight") + new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias") + new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight") + new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias") + new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight") + new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias") + new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight") + new_state_dict["encoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("encoder.mid.block_2.norm2.bias") + new_state_dict["encoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("encoder.mid.block_2.conv2.weight") + new_state_dict["encoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("encoder.mid.block_2.conv2.bias") + new_state_dict["encoder.conv_norm_out.weight"] = old_state_dict.pop("encoder.norm_out.weight") + new_state_dict["encoder.conv_norm_out.bias"] = old_state_dict.pop("encoder.norm_out.bias") + new_state_dict["encoder.conv_out.weight"] = old_state_dict.pop("encoder.conv_out.weight") + new_state_dict["encoder.conv_out.bias"] = old_state_dict.pop("encoder.conv_out.bias") + new_state_dict["quant_conv.weight"] = old_state_dict.pop("quant_conv.weight") + new_state_dict["quant_conv.bias"] = old_state_dict.pop("quant_conv.bias") + new_state_dict["quantize.embedding.weight"] = old_state_dict.pop("quantize.embedding.weight") + new_state_dict["post_quant_conv.weight"] = old_state_dict.pop("post_quant_conv.weight") + new_state_dict["post_quant_conv.bias"] = old_state_dict.pop("post_quant_conv.bias") + new_state_dict["decoder.conv_in.weight"] = old_state_dict.pop("decoder.conv_in.weight") + new_state_dict["decoder.conv_in.bias"] = old_state_dict.pop("decoder.conv_in.bias") + new_state_dict["decoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("decoder.mid.block_1.norm1.weight") + new_state_dict["decoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("decoder.mid.block_1.norm1.bias") + new_state_dict["decoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("decoder.mid.block_1.conv1.weight") + new_state_dict["decoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("decoder.mid.block_1.conv1.bias") + new_state_dict["decoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("decoder.mid.block_1.norm2.weight") + new_state_dict["decoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("decoder.mid.block_1.norm2.bias") + new_state_dict["decoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("decoder.mid.block_1.conv2.weight") + new_state_dict["decoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("decoder.mid.block_1.conv2.bias") + new_state_dict["decoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("decoder.mid.block_2.norm1.weight") + new_state_dict["decoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("decoder.mid.block_2.norm1.bias") + new_state_dict["decoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("decoder.mid.block_2.conv1.weight") + new_state_dict["decoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("decoder.mid.block_2.conv1.bias") + new_state_dict["decoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("decoder.mid.block_2.norm2.weight") + new_state_dict["decoder.mid_block.resnets.1.norm2.bias"] = old_state_dict.pop("decoder.mid.block_2.norm2.bias") + new_state_dict["decoder.mid_block.resnets.1.conv2.weight"] = old_state_dict.pop("decoder.mid.block_2.conv2.weight") + new_state_dict["decoder.mid_block.resnets.1.conv2.bias"] = old_state_dict.pop("decoder.mid.block_2.conv2.bias") + + convert_vae_block_state_dict(old_state_dict, "decoder.up.0", new_state_dict, "decoder.up_blocks.4") + convert_vae_block_state_dict(old_state_dict, "decoder.up.1", new_state_dict, "decoder.up_blocks.3") + convert_vae_block_state_dict(old_state_dict, "decoder.up.2", new_state_dict, "decoder.up_blocks.2") + convert_vae_block_state_dict(old_state_dict, "decoder.up.3", new_state_dict, "decoder.up_blocks.1") + convert_vae_block_state_dict(old_state_dict, "decoder.up.4", new_state_dict, "decoder.up_blocks.0") + + new_state_dict["decoder.conv_norm_out.weight"] = old_state_dict.pop("decoder.norm_out.weight") + new_state_dict["decoder.conv_norm_out.bias"] = old_state_dict.pop("decoder.norm_out.bias") + new_state_dict["decoder.conv_out.weight"] = old_state_dict.pop("decoder.conv_out.weight") + new_state_dict["decoder.conv_out.bias"] = old_state_dict.pop("decoder.conv_out.bias") + + # fmt: on + + assert len(old_state_dict.keys()) == 0 + + new_vae.load_state_dict(new_state_dict) + + input = torch.randn((1, 3, 512, 512), device=device) + input = input.clamp(-1, 1) + + old_encoder_output = old_vae.quant_conv(old_vae.encoder(input)) + new_encoder_output = new_vae.quant_conv(new_vae.encoder(input)) + assert (old_encoder_output == new_encoder_output).all() + + old_decoder_output = old_vae.decoder(old_vae.post_quant_conv(old_encoder_output)) + new_decoder_output = new_vae.decoder(new_vae.post_quant_conv(new_encoder_output)) + + # assert (old_decoder_output == new_decoder_output).all() + print("kipping vae decoder equivalence check") + print(f"vae decoder diff {(old_decoder_output - new_decoder_output).float().abs().sum()}") + + old_output = old_vae(input)[0] + new_output = new_vae(input)[0] + + # assert (old_output == new_output).all() + print("skipping full vae equivalence check") + print(f"vae full diff { (old_output - new_output).float().abs().sum()}") + + return new_vae + + +def convert_vae_block_state_dict(old_state_dict, prefix_from, new_state_dict, prefix_to): + # fmt: off + + new_state_dict[f"{prefix_to}.resnets.0.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.0.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.0.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.0.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.0.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.0.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.conv2.bias") + + if f"{prefix_from}.block.0.nin_shortcut.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.weight"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.weight") + new_state_dict[f"{prefix_to}.resnets.0.conv_shortcut.bias"] = old_state_dict.pop(f"{prefix_from}.block.0.nin_shortcut.bias") + + new_state_dict[f"{prefix_to}.resnets.1.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.1.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.1.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.1.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.1.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.1.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.1.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.1.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.1.conv2.bias") + + if f"{prefix_from}.downsample.conv.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.downsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.weight") + new_state_dict[f"{prefix_to}.downsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.downsample.conv.bias") + + if f"{prefix_from}.upsample.conv.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.upsamplers.0.conv.weight"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.weight") + new_state_dict[f"{prefix_to}.upsamplers.0.conv.bias"] = old_state_dict.pop(f"{prefix_from}.upsample.conv.bias") + + if f"{prefix_from}.block.2.norm1.weight" in old_state_dict: + new_state_dict[f"{prefix_to}.resnets.2.norm1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.weight") + new_state_dict[f"{prefix_to}.resnets.2.norm1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm1.bias") + new_state_dict[f"{prefix_to}.resnets.2.conv1.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.weight") + new_state_dict[f"{prefix_to}.resnets.2.conv1.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv1.bias") + new_state_dict[f"{prefix_to}.resnets.2.norm2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.weight") + new_state_dict[f"{prefix_to}.resnets.2.norm2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.norm2.bias") + new_state_dict[f"{prefix_to}.resnets.2.conv2.weight"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.weight") + new_state_dict[f"{prefix_to}.resnets.2.conv2.bias"] = old_state_dict.pop(f"{prefix_from}.block.2.conv2.bias") + + # fmt: on + + +if __name__ == "__main__": + main() diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index c43000e27b82..10c5b0f46565 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -95,6 +95,7 @@ "UNet3DConditionModel", "UNetMotionModel", "UNetSpatioTemporalConditionModel", + "UVit2DModel", "VQModel", ] ) @@ -131,6 +132,7 @@ ) _import_structure["schedulers"].extend( [ + "AmusedScheduler", "CMStochasticIterativeScheduler", "DDIMInverseScheduler", "DDIMParallelScheduler", @@ -202,6 +204,9 @@ [ "AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline", + "AmusedImg2ImgPipeline", + "AmusedInpaintPipeline", + "AmusedPipeline", "AnimateDiffPipeline", "AudioLDM2Pipeline", "AudioLDM2ProjectionModel", @@ -472,6 +477,7 @@ UNet3DConditionModel, UNetMotionModel, UNetSpatioTemporalConditionModel, + UVit2DModel, VQModel, ) from .optimization import ( @@ -506,6 +512,7 @@ ScoreSdeVePipeline, ) from .schedulers import ( + AmusedScheduler, CMStochasticIterativeScheduler, DDIMInverseScheduler, DDIMParallelScheduler, @@ -560,6 +567,9 @@ from .pipelines import ( AltDiffusionImg2ImgPipeline, AltDiffusionPipeline, + AmusedImg2ImgPipeline, + AmusedInpaintPipeline, + AmusedPipeline, AnimateDiffPipeline, AudioLDM2Pipeline, AudioLDM2ProjectionModel, diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index c1c3a260ec11..fc50c52e412b 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -59,6 +59,7 @@ TEXT_ENCODER_NAME = "text_encoder" UNET_NAME = "unet" +TRANSFORMER_NAME = "transformer" LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" @@ -74,6 +75,7 @@ class LoraLoaderMixin: text_encoder_name = TEXT_ENCODER_NAME unet_name = UNET_NAME + transformer_name = TRANSFORMER_NAME num_fused_loras = 0 def load_lora_weights( @@ -661,6 +663,89 @@ def load_lora_into_text_encoder( _pipeline.enable_sequential_cpu_offload() # Unsafe code /> + @classmethod + def load_lora_into_transformer( + cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None + ): + """ + This will load the LoRA layers specified in `state_dict` into `transformer`. + + Parameters: + state_dict (`dict`): + A standard state dict containing the lora layer parameters. The keys can either be indexed directly + into the unet or prefixed with an additional `unet` which can be used to distinguish between text + encoder lora layers. + network_alphas (`Dict[str, float]`): + See `LoRALinearLayer` for more details. + unet (`UNet2DConditionModel`): + The UNet model to load the LoRA layers into. + low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): + Speed up model loading only loading the pretrained weights and not initializing the weights. This also + tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. + Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this + argument to `True` will raise an error. + adapter_name (`str`, *optional*): + Adapter name to be used for referencing the loaded adapter model. If not specified, it will use + `default_{i}` where i is the total number of adapters being loaded. + """ + low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT + + keys = list(state_dict.keys()) + + transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] + state_dict = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys + } + + if network_alphas is not None: + alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.transformer_name)] + network_alphas = { + k.replace(f"{cls.transformer_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys + } + + if len(state_dict.keys()) > 0: + from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict + + if adapter_name in getattr(transformer, "peft_config", {}): + raise ValueError( + f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name." + ) + + rank = {} + for key, val in state_dict.items(): + if "lora_B" in key: + rank[key] = val.shape[1] + + lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict) + lora_config = LoraConfig(**lora_config_kwargs) + + # adapter_name + if adapter_name is None: + adapter_name = get_adapter_name(transformer) + + # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks + # otherwise loading LoRA weights will lead to an error + is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline) + + inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name) + incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + # Offload back. + if is_model_cpu_offload: + _pipeline.enable_model_cpu_offload() + elif is_sequential_cpu_offload: + _pipeline.enable_sequential_cpu_offload() + # Unsafe code /> + @property def lora_scale(self) -> float: # property function that returns the lora scale which can be set at run time by the pipeline. @@ -786,6 +871,7 @@ def save_lora_weights( save_directory: Union[str, os.PathLike], unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, + transformer_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, save_function: Callable = None, @@ -820,8 +906,10 @@ def pack_weights(layers, prefix): layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} return layers_state_dict - if not (unet_lora_layers or text_encoder_lora_layers): - raise ValueError("You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`.") + if not (unet_lora_layers or text_encoder_lora_layers or transformer_lora_layers): + raise ValueError( + "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, or `transformer_lora_layers`." + ) if unet_lora_layers: state_dict.update(pack_weights(unet_lora_layers, "unet")) @@ -829,6 +917,9 @@ def pack_weights(layers, prefix): if text_encoder_lora_layers: state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + if transformer_lora_layers: + state_dict.update(pack_weights(transformer_lora_layers, "transformer")) + # Save the model cls.write_lora_layers( state_dict=state_dict, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 7487bbf2f98e..6e7fe72bc949 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -47,6 +47,7 @@ _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] + _import_structure["uvit_2d"] = ["UVit2DModel"] _import_structure["vq_model"] = ["VQModel"] if is_flax_available(): @@ -81,6 +82,7 @@ from .unet_kandinsky3 import Kandinsky3UNet from .unet_motion_model import MotionAdapter, UNetMotionModel from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel + from .uvit_2d import UVit2DModel from .vq_model import VQModel if is_flax_available(): diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 08faaaf3e5bf..a34d7421b4f9 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -14,6 +14,7 @@ from typing import Any, Dict, Optional import torch +import torch.nn.functional as F from torch import nn from ..utils import USE_PEFT_BACKEND @@ -22,7 +23,7 @@ from .attention_processor import Attention from .embeddings import SinusoidalPositionalEmbedding from .lora import LoRACompatibleLinear -from .normalization import AdaLayerNorm, AdaLayerNormZero +from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm def _chunked_feed_forward( @@ -148,6 +149,11 @@ def __init__( attention_type: str = "default", positional_embeddings: Optional[str] = None, num_positional_embeddings: Optional[int] = None, + ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, + ada_norm_bias: Optional[int] = None, + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, ): super().__init__() self.only_cross_attention = only_cross_attention @@ -156,6 +162,7 @@ def __init__( self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" self.use_ada_layer_norm_single = norm_type == "ada_norm_single" self.use_layer_norm = norm_type == "layer_norm" + self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: raise ValueError( @@ -179,6 +186,15 @@ def __init__( self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) elif self.use_ada_layer_norm_zero: self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm1 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) else: self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) @@ -190,6 +206,7 @@ def __init__( bias=attention_bias, cross_attention_dim=cross_attention_dim if only_cross_attention else None, upcast_attention=upcast_attention, + out_bias=attention_out_bias, ) # 2. Cross-Attn @@ -197,11 +214,20 @@ def __init__( # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during # the second cross attention block. - self.norm2 = ( - AdaLayerNorm(dim, num_embeds_ada_norm) - if self.use_ada_layer_norm - else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) - ) + if self.use_ada_layer_norm: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) + elif self.use_ada_layer_norm_continuous: + self.norm2 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "rms_norm", + ) + else: + self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) + self.attn2 = Attention( query_dim=dim, cross_attention_dim=cross_attention_dim if not double_self_attention else None, @@ -210,20 +236,32 @@ def __init__( dropout=dropout, bias=attention_bias, upcast_attention=upcast_attention, + out_bias=attention_out_bias, ) # is self-attn if encoder_hidden_states is none else: self.norm2 = None self.attn2 = None # 3. Feed-forward - if not self.use_ada_layer_norm_single: - self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) + if self.use_ada_layer_norm_continuous: + self.norm3 = AdaLayerNormContinuous( + dim, + ada_norm_continous_conditioning_embedding_dim, + norm_elementwise_affine, + norm_eps, + ada_norm_bias, + "layer_norm", + ) + elif not self.use_ada_layer_norm_single: + self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, ) # 4. Fuser @@ -252,6 +290,7 @@ def forward( timestep: Optional[torch.LongTensor] = None, cross_attention_kwargs: Dict[str, Any] = None, class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, ) -> torch.FloatTensor: # Notice that normalization is always applied before the real computation in the following blocks. # 0. Self-Attention @@ -265,6 +304,8 @@ def forward( ) elif self.use_layer_norm: norm_hidden_states = self.norm1(hidden_states) + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) elif self.use_ada_layer_norm_single: shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) @@ -314,6 +355,8 @@ def forward( # For PixArt norm2 isn't applied here: # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 norm_hidden_states = hidden_states + elif self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) else: raise ValueError("Incorrect norm") @@ -329,7 +372,9 @@ def forward( hidden_states = attn_output + hidden_states # 4. Feed-forward - if not self.use_ada_layer_norm_single: + if self.use_ada_layer_norm_continuous: + norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) + elif not self.use_ada_layer_norm_single: norm_hidden_states = self.norm3(hidden_states) if self.use_ada_layer_norm_zero: @@ -490,6 +535,78 @@ def forward( return hidden_states +class SkipFFTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + kv_input_dim: int, + kv_input_dim_proj_use_bias: bool, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + attention_out_bias: bool = True, + ): + super().__init__() + if kv_input_dim != dim: + self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) + else: + self.kv_mapper = None + + self.norm1 = RMSNorm(dim, 1e-06) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim, + out_bias=attention_out_bias, + ) + + self.norm2 = RMSNorm(dim, 1e-06) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + out_bias=attention_out_bias, + ) + + def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): + cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + + if self.kv_mapper is not None: + encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) + + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + **cross_attention_kwargs, + ) + + hidden_states = attn_output + hidden_states + + return hidden_states + + class FeedForward(nn.Module): r""" A feed-forward layer. @@ -512,10 +629,12 @@ def __init__( dropout: float = 0.0, activation_fn: str = "geglu", final_dropout: bool = False, + inner_dim=None, bias: bool = True, ): super().__init__() - inner_dim = int(dim * mult) + if inner_dim is None: + inner_dim = int(dim * mult) dim_out = dim_out if dim_out is not None else dim linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear diff --git a/src/diffusers/models/autoencoders/vae.py b/src/diffusers/models/autoencoders/vae.py index 9ed0232e6983..3f1643bc50ef 100644 --- a/src/diffusers/models/autoencoders/vae.py +++ b/src/diffusers/models/autoencoders/vae.py @@ -77,6 +77,7 @@ def __init__( norm_num_groups: int = 32, act_fn: str = "silu", double_z: bool = True, + mid_block_add_attention=True, ): super().__init__() self.layers_per_block = layers_per_block @@ -124,6 +125,7 @@ def __init__( attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=None, + add_attention=mid_block_add_attention, ) # out @@ -213,6 +215,7 @@ def __init__( norm_num_groups: int = 32, act_fn: str = "silu", norm_type: str = "group", # group, spatial + mid_block_add_attention=True, ): super().__init__() self.layers_per_block = layers_per_block @@ -240,6 +243,7 @@ def __init__( attention_head_dim=block_out_channels[-1], resnet_groups=norm_num_groups, temb_channels=temb_channels, + add_attention=mid_block_add_attention, ) # up diff --git a/src/diffusers/models/downsampling.py b/src/diffusers/models/downsampling.py index d39bae22e831..ecab1fffe2f0 100644 --- a/src/diffusers/models/downsampling.py +++ b/src/diffusers/models/downsampling.py @@ -20,6 +20,7 @@ from ..utils import USE_PEFT_BACKEND from .lora import LoRACompatibleConv +from .normalization import RMSNorm from .upsampling import upfirdn2d_native @@ -89,6 +90,11 @@ def __init__( out_channels: Optional[int] = None, padding: int = 1, name: str = "conv", + kernel_size=3, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, ): super().__init__() self.channels = channels @@ -99,8 +105,19 @@ def __init__( self.name = name conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + if use_conv: - conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding) + conv = conv_cls( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) else: assert self.channels == self.out_channels conv = nn.AvgPool2d(kernel_size=stride, stride=stride) @@ -117,6 +134,9 @@ def __init__( def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + if self.use_conv and self.padding == 0: pad = (0, 1, 0, 1) hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index db68591bdb44..7e98f77baf26 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -197,11 +197,12 @@ def __init__( out_dim: int = None, post_act_fn: Optional[str] = None, cond_proj_dim=None, + sample_proj_bias=True, ): super().__init__() linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear - self.linear_1 = linear_cls(in_channels, time_embed_dim) + self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias) if cond_proj_dim is not None: self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) @@ -214,7 +215,7 @@ def __init__( time_embed_dim_out = out_dim else: time_embed_dim_out = time_embed_dim - self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out) + self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias) if post_act_fn is None: self.post_act = None diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py index 25af4d853b86..7f6e2c145435 100644 --- a/src/diffusers/models/normalization.py +++ b/src/diffusers/models/normalization.py @@ -13,12 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import numbers from typing import Dict, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F +from ..utils import is_torch_version from .activations import get_activation from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings @@ -146,3 +148,107 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor: x = F.group_norm(x, self.num_groups, eps=self.eps) x = x * (1 + scale) + shift return x + + +class AdaLayerNormContinuous(nn.Module): + def __init__( + self, + embedding_dim: int, + conditioning_embedding_dim: int, + # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters + # because the output is immediately scaled and shifted by the projected conditioning embeddings. + # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters. + # However, this is how it was implemented in the original code, and it's rather likely you should + # set `elementwise_affine` to False. + elementwise_affine=True, + eps=1e-5, + bias=True, + norm_type="layer_norm", + ): + super().__init__() + self.silu = nn.SiLU() + self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias) + if norm_type == "layer_norm": + self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias) + elif norm_type == "rms_norm": + self.norm = RMSNorm(embedding_dim, eps, elementwise_affine) + else: + raise ValueError(f"unknown norm_type {norm_type}") + + def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: + emb = self.linear(self.silu(conditioning_embedding)) + scale, shift = torch.chunk(emb, 2, dim=1) + x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :] + return x + + +if is_torch_version(">=", "2.1.0"): + LayerNorm = nn.LayerNorm +else: + # Has optional bias parameter compared to torch layer norm + # TODO: replace with torch layernorm once min required torch version >= 2.1 + class LayerNorm(nn.Module): + def __init__(self, dim, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + self.bias = nn.Parameter(torch.zeros(dim)) if bias else None + else: + self.weight = None + self.bias = None + + def forward(self, input): + return F.layer_norm(input, self.dim, self.weight, self.bias, self.eps) + + +class RMSNorm(nn.Module): + def __init__(self, dim, eps: float, elementwise_affine: bool = True): + super().__init__() + + self.eps = eps + + if isinstance(dim, numbers.Integral): + dim = (dim,) + + self.dim = torch.Size(dim) + + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + else: + self.weight = None + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.eps) + + if self.weight is not None: + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + hidden_states = hidden_states * self.weight + else: + hidden_states = hidden_states.to(input_dtype) + + return hidden_states + + +class GlobalResponseNorm(nn.Module): + # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105 + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * nx) + self.beta + x diff --git a/src/diffusers/models/upsampling.py b/src/diffusers/models/upsampling.py index 542a5d9d1eb0..1e4e61201059 100644 --- a/src/diffusers/models/upsampling.py +++ b/src/diffusers/models/upsampling.py @@ -20,6 +20,7 @@ from ..utils import USE_PEFT_BACKEND from .lora import LoRACompatibleConv +from .normalization import RMSNorm class Upsample1D(nn.Module): @@ -95,6 +96,13 @@ def __init__( use_conv_transpose: bool = False, out_channels: Optional[int] = None, name: str = "conv", + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + interpolate=True, ): super().__init__() self.channels = channels @@ -102,13 +110,29 @@ def __init__( self.use_conv = use_conv self.use_conv_transpose = use_conv_transpose self.name = name + self.interpolate = interpolate conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(channels, eps, elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + conv = None if use_conv_transpose: - conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) + if kernel_size is None: + kernel_size = 4 + conv = nn.ConvTranspose2d( + channels, self.out_channels, kernel_size=kernel_size, stride=2, padding=padding, bias=bias + ) elif use_conv: - conv = conv_cls(self.channels, self.out_channels, 3, padding=1) + if kernel_size is None: + kernel_size = 3 + conv = conv_cls(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed if name == "conv": @@ -124,6 +148,9 @@ def forward( ) -> torch.FloatTensor: assert hidden_states.shape[1] == self.channels + if self.norm is not None: + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + if self.use_conv_transpose: return self.conv(hidden_states) @@ -140,10 +167,11 @@ def forward( # if `output_size` is passed we force the interpolation output # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + if self.interpolate: + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") # If the input is bfloat16, we cast back to bfloat16 if dtype == torch.bfloat16: diff --git a/src/diffusers/models/uvit_2d.py b/src/diffusers/models/uvit_2d.py new file mode 100644 index 000000000000..14dd8aee8e89 --- /dev/null +++ b/src/diffusers/models/uvit_2d.py @@ -0,0 +1,471 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Union + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint + +from ..configuration_utils import ConfigMixin, register_to_config +from .attention import BasicTransformerBlock, SkipFFTransformerBlock +from .attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from .embeddings import TimestepEmbedding, get_timestep_embedding +from .modeling_utils import ModelMixin +from .normalization import GlobalResponseNorm, RMSNorm +from .resnet import Downsample2D, Upsample2D + + +class UVit2DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + # global config + hidden_size: int = 1024, + use_bias: bool = False, + hidden_dropout: float = 0.0, + # conditioning dimensions + cond_embed_dim: int = 768, + micro_cond_encode_dim: int = 256, + micro_cond_embed_dim: int = 1280, + encoder_hidden_size: int = 768, + # num tokens + vocab_size: int = 8256, # codebook_size + 1 (for the mask token) rounded + codebook_size: int = 8192, + # `UVit2DConvEmbed` + in_channels: int = 768, + block_out_channels: int = 768, + num_res_blocks: int = 3, + downsample: bool = False, + upsample: bool = False, + block_num_heads: int = 12, + # `TransformerLayer` + num_hidden_layers: int = 22, + num_attention_heads: int = 16, + # `Attention` + attention_dropout: float = 0.0, + # `FeedForward` + intermediate_size: int = 2816, + # `Norm` + layer_norm_eps: float = 1e-6, + ln_elementwise_affine: bool = True, + sample_size: int = 64, + ): + super().__init__() + + self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias) + self.encoder_proj_layer_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + + self.embed = UVit2DConvEmbed( + in_channels, block_out_channels, vocab_size, ln_elementwise_affine, layer_norm_eps, use_bias + ) + + self.cond_embed = TimestepEmbedding( + micro_cond_embed_dim + cond_embed_dim, hidden_size, sample_proj_bias=use_bias + ) + + self.down_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample, + False, + ) + + self.project_to_hidden_norm = RMSNorm(block_out_channels, layer_norm_eps, ln_elementwise_affine) + self.project_to_hidden = nn.Linear(block_out_channels, hidden_size, bias=use_bias) + + self.transformer_layers = nn.ModuleList( + [ + BasicTransformerBlock( + dim=hidden_size, + num_attention_heads=num_attention_heads, + attention_head_dim=hidden_size // num_attention_heads, + dropout=hidden_dropout, + cross_attention_dim=hidden_size, + attention_bias=use_bias, + norm_type="ada_norm_continuous", + ada_norm_continous_conditioning_embedding_dim=hidden_size, + norm_elementwise_affine=ln_elementwise_affine, + norm_eps=layer_norm_eps, + ada_norm_bias=use_bias, + ff_inner_dim=intermediate_size, + ff_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_hidden_layers) + ] + ) + + self.project_from_hidden_norm = RMSNorm(hidden_size, layer_norm_eps, ln_elementwise_affine) + self.project_from_hidden = nn.Linear(hidden_size, block_out_channels, bias=use_bias) + + self.up_block = UVitBlock( + block_out_channels, + num_res_blocks, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample=False, + upsample=upsample, + ) + + self.mlm_layer = ConvMlmLayer( + block_out_channels, in_channels, use_bias, ln_elementwise_affine, layer_norm_eps, codebook_size + ) + + self.gradient_checkpointing = False + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + pass + + def forward(self, input_ids, encoder_hidden_states, pooled_text_emb, micro_conds, cross_attention_kwargs=None): + encoder_hidden_states = self.encoder_proj(encoder_hidden_states) + encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states) + + micro_cond_embeds = get_timestep_embedding( + micro_conds.flatten(), self.config.micro_cond_encode_dim, flip_sin_to_cos=True, downscale_freq_shift=0 + ) + + micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1)) + + pooled_text_emb = torch.cat([pooled_text_emb, micro_cond_embeds], dim=1) + pooled_text_emb = pooled_text_emb.to(dtype=self.dtype) + pooled_text_emb = self.cond_embed(pooled_text_emb).to(encoder_hidden_states.dtype) + + hidden_states = self.embed(input_ids) + + hidden_states = self.down_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + batch_size, channels, height, width = hidden_states.shape + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels) + + hidden_states = self.project_to_hidden_norm(hidden_states) + hidden_states = self.project_to_hidden(hidden_states) + + for layer in self.transformer_layers: + if self.training and self.gradient_checkpointing: + + def layer_(*args): + return checkpoint(layer, *args) + + else: + layer_ = layer + + hidden_states = layer_( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs={"pooled_text_emb": pooled_text_emb}, + ) + + hidden_states = self.project_from_hidden_norm(hidden_states) + hidden_states = self.project_from_hidden(hidden_states) + + hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2) + + hidden_states = self.up_block( + hidden_states, + pooled_text_emb=pooled_text_emb, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + logits = self.mlm_layer(hidden_states) + + return logits + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + +class UVit2DConvEmbed(nn.Module): + def __init__(self, in_channels, block_out_channels, vocab_size, elementwise_affine, eps, bias): + super().__init__() + self.embeddings = nn.Embedding(vocab_size, in_channels) + self.layer_norm = RMSNorm(in_channels, eps, elementwise_affine) + self.conv = nn.Conv2d(in_channels, block_out_channels, kernel_size=1, bias=bias) + + def forward(self, input_ids): + embeddings = self.embeddings(input_ids) + embeddings = self.layer_norm(embeddings) + embeddings = embeddings.permute(0, 3, 1, 2) + embeddings = self.conv(embeddings) + return embeddings + + +class UVitBlock(nn.Module): + def __init__( + self, + channels, + num_res_blocks: int, + hidden_size, + hidden_dropout, + ln_elementwise_affine, + layer_norm_eps, + use_bias, + block_num_heads, + attention_dropout, + downsample: bool, + upsample: bool, + ): + super().__init__() + + if downsample: + self.downsample = Downsample2D( + channels, + use_conv=True, + padding=0, + name="Conv2d_0", + kernel_size=2, + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + ) + else: + self.downsample = None + + self.res_blocks = nn.ModuleList( + [ + ConvNextBlock( + channels, + layer_norm_eps, + ln_elementwise_affine, + use_bias, + hidden_dropout, + hidden_size, + ) + for i in range(num_res_blocks) + ] + ) + + self.attention_blocks = nn.ModuleList( + [ + SkipFFTransformerBlock( + channels, + block_num_heads, + channels // block_num_heads, + hidden_size, + use_bias, + attention_dropout, + channels, + attention_bias=use_bias, + attention_out_bias=use_bias, + ) + for _ in range(num_res_blocks) + ] + ) + + if upsample: + self.upsample = Upsample2D( + channels, + use_conv_transpose=True, + kernel_size=2, + padding=0, + name="conv", + norm_type="rms_norm", + eps=layer_norm_eps, + elementwise_affine=ln_elementwise_affine, + bias=use_bias, + interpolate=False, + ) + else: + self.upsample = None + + def forward(self, x, pooled_text_emb, encoder_hidden_states, cross_attention_kwargs): + if self.downsample is not None: + x = self.downsample(x) + + for res_block, attention_block in zip(self.res_blocks, self.attention_blocks): + x = res_block(x, pooled_text_emb) + + batch_size, channels, height, width = x.shape + x = x.view(batch_size, channels, height * width).permute(0, 2, 1) + x = attention_block( + x, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=cross_attention_kwargs + ) + x = x.permute(0, 2, 1).view(batch_size, channels, height, width) + + if self.upsample is not None: + x = self.upsample(x) + + return x + + +class ConvNextBlock(nn.Module): + def __init__( + self, channels, layer_norm_eps, ln_elementwise_affine, use_bias, hidden_dropout, hidden_size, res_ffn_factor=4 + ): + super().__init__() + self.depthwise = nn.Conv2d( + channels, + channels, + kernel_size=3, + padding=1, + groups=channels, + bias=use_bias, + ) + self.norm = RMSNorm(channels, layer_norm_eps, ln_elementwise_affine) + self.channelwise_linear_1 = nn.Linear(channels, int(channels * res_ffn_factor), bias=use_bias) + self.channelwise_act = nn.GELU() + self.channelwise_norm = GlobalResponseNorm(int(channels * res_ffn_factor)) + self.channelwise_linear_2 = nn.Linear(int(channels * res_ffn_factor), channels, bias=use_bias) + self.channelwise_dropout = nn.Dropout(hidden_dropout) + self.cond_embeds_mapper = nn.Linear(hidden_size, channels * 2, use_bias) + + def forward(self, x, cond_embeds): + x_res = x + + x = self.depthwise(x) + + x = x.permute(0, 2, 3, 1) + x = self.norm(x) + + x = self.channelwise_linear_1(x) + x = self.channelwise_act(x) + x = self.channelwise_norm(x) + x = self.channelwise_linear_2(x) + x = self.channelwise_dropout(x) + + x = x.permute(0, 3, 1, 2) + + x = x + x_res + + scale, shift = self.cond_embeds_mapper(F.silu(cond_embeds)).chunk(2, dim=1) + x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None] + + return x + + +class ConvMlmLayer(nn.Module): + def __init__( + self, + block_out_channels: int, + in_channels: int, + use_bias: bool, + ln_elementwise_affine: bool, + layer_norm_eps: float, + codebook_size: int, + ): + super().__init__() + self.conv1 = nn.Conv2d(block_out_channels, in_channels, kernel_size=1, bias=use_bias) + self.layer_norm = RMSNorm(in_channels, layer_norm_eps, ln_elementwise_affine) + self.conv2 = nn.Conv2d(in_channels, codebook_size, kernel_size=1, bias=use_bias) + + def forward(self, hidden_states): + hidden_states = self.conv1(hidden_states) + hidden_states = self.layer_norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + logits = self.conv2(hidden_states) + return logits diff --git a/src/diffusers/models/vq_model.py b/src/diffusers/models/vq_model.py index bfe62ec863b3..5695d7258f2e 100644 --- a/src/diffusers/models/vq_model.py +++ b/src/diffusers/models/vq_model.py @@ -88,6 +88,9 @@ def __init__( vq_embed_dim: Optional[int] = None, scaling_factor: float = 0.18215, norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + lookup_from_codebook=False, + force_upcast=False, ): super().__init__() @@ -101,6 +104,7 @@ def __init__( act_fn=act_fn, norm_num_groups=norm_num_groups, double_z=False, + mid_block_add_attention=mid_block_add_attention, ) vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels @@ -119,6 +123,7 @@ def __init__( act_fn=act_fn, norm_num_groups=norm_num_groups, norm_type=norm_type, + mid_block_add_attention=mid_block_add_attention, ) @apply_forward_hook @@ -133,11 +138,13 @@ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOut @apply_forward_hook def decode( - self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True + self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None ) -> Union[DecoderOutput, torch.FloatTensor]: # also go through quantization layer if not force_not_quantize: quant, _, _ = self.quantize(h) + elif self.config.lookup_from_codebook: + quant = self.quantize.get_codebook_entry(h, shape) else: quant = h quant2 = self.post_quant_conv(quant) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 92839e596978..3bf67dfc1cdc 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -108,6 +108,7 @@ "VersatileDiffusionTextToImagePipeline", ] ) + _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"] _import_structure["animatediff"] = ["AnimateDiffPipeline"] _import_structure["audioldm"] = ["AudioLDMPipeline"] _import_structure["audioldm2"] = [ @@ -342,6 +343,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_objects import * else: + from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline from .animatediff import AnimateDiffPipeline from .audioldm import AudioLDMPipeline from .audioldm2 import ( diff --git a/src/diffusers/pipelines/amused/__init__.py b/src/diffusers/pipelines/amused/__init__.py new file mode 100644 index 000000000000..3c4d07a426b5 --- /dev/null +++ b/src/diffusers/pipelines/amused/__init__.py @@ -0,0 +1,62 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + is_torch_available, + is_transformers_available, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedImg2ImgPipeline, + AmusedInpaintPipeline, + AmusedPipeline, + ) + + _dummy_objects.update( + { + "AmusedPipeline": AmusedPipeline, + "AmusedImg2ImgPipeline": AmusedImg2ImgPipeline, + "AmusedInpaintPipeline": AmusedInpaintPipeline, + } + ) +else: + _import_structure["pipeline_amused"] = ["AmusedPipeline"] + _import_structure["pipeline_amused_img2img"] = ["AmusedImg2ImgPipeline"] + _import_structure["pipeline_amused_inpaint"] = ["AmusedInpaintPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available()): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import ( + AmusedPipeline, + ) + else: + from .pipeline_amused import AmusedPipeline + from .pipeline_amused_img2img import AmusedImg2ImgPipeline + from .pipeline_amused_inpaint import AmusedInpaintPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/amused/pipeline_amused.py b/src/diffusers/pipelines/amused/pipeline_amused.py new file mode 100644 index 000000000000..e93569c2302f --- /dev/null +++ b/src/diffusers/pipelines/amused/pipeline_amused.py @@ -0,0 +1,328 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedPipeline + + >>> pipe = AmusedPipeline.from_pretrained( + ... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class AmusedPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + latents: Optional[torch.IntTensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.transformer.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.IntTensor`, *optional*): + Pre-generated tokens representing latent vectors in `self.vqvae`, to be used as inputs for image + gneration. If not provided, the starting latents will be completely masked. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.FloatTensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/ + and the micro-conditioning section of https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if height is None: + height = self.transformer.config.sample_size * self.vae_scale_factor + + if width is None: + width = self.transformer.config.sample_size * self.vae_scale_factor + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + shape = (batch_size, height // self.vae_scale_factor, width // self.vae_scale_factor) + + if latents is None: + latents = torch.full( + shape, self.scheduler.config.mask_token_id, dtype=torch.long, device=self._execution_device + ) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + + num_warmup_steps = len(self.scheduler.timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, timestep in enumerate(self.scheduler.timesteps): + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py new file mode 100644 index 000000000000..694b7c2229f3 --- /dev/null +++ b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py @@ -0,0 +1,347 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedImg2ImgPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedImg2ImgPipeline.from_pretrained( + ... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "winter mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> image = pipe(prompt, input_image).images[0] + ``` +""" + + +class AmusedImg2ImgPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + image: PipelineImageInput = None, + strength: float = 0.5, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 0.5): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.FloatTensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/ + and the micro-conditioning section of https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + latents = self.scheduler.add_noise( + latents, self.scheduler.timesteps[start_timestep_idx - 1], generator=generator + ) + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py new file mode 100644 index 000000000000..a4c5644c961c --- /dev/null +++ b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py @@ -0,0 +1,378 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...models import UVit2DModel, VQModel +from ...schedulers import AmusedScheduler +from ...utils import replace_example_docstring +from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import AmusedInpaintPipeline + >>> from diffusers.utils import load_image + + >>> pipe = AmusedInpaintPipeline.from_pretrained( + ... "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "fall mountains" + >>> input_image = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg" + ... ) + ... .resize((512, 512)) + ... .convert("RGB") + ... ) + >>> mask = ( + ... load_image( + ... "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ... ) + ... .resize((512, 512)) + ... .convert("L") + ... ) + >>> pipe(prompt, input_image, mask).images[0].save("out.png") + ``` +""" + + +class AmusedInpaintPipeline(DiffusionPipeline): + image_processor: VaeImageProcessor + vqvae: VQModel + tokenizer: CLIPTokenizer + text_encoder: CLIPTextModelWithProjection + transformer: UVit2DModel + scheduler: AmusedScheduler + + model_cpu_offload_seq = "text_encoder->transformer->vqvae" + + # TODO - when calling self.vqvae.quantize, it uses self.vqvae.quantize.embedding.weight before + # the forward method of self.vqvae.quantize, so the hook doesn't get called to move the parameter + # off the meta device. There should be a way to fix this instead of just not offloading it + _exclude_from_cpu_offload = ["vqvae"] + + def __init__( + self, + vqvae: VQModel, + tokenizer: CLIPTokenizer, + text_encoder: CLIPTextModelWithProjection, + transformer: UVit2DModel, + scheduler: AmusedScheduler, + ): + super().__init__() + + self.register_modules( + vqvae=vqvae, + tokenizer=tokenizer, + text_encoder=text_encoder, + transformer=transformer, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vqvae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False) + self.mask_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, + do_normalize=False, + do_binarize=True, + do_convert_grayscale=True, + do_resize=True, + ) + self.scheduler.register_to_config(masking_schedule="linear") + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Optional[Union[List[str], str]] = None, + image: PipelineImageInput = None, + mask_image: PipelineImageInput = None, + strength: float = 1.0, + num_inference_steps: int = 12, + guidance_scale: float = 10.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + generator: Optional[torch.Generator] = None, + prompt_embeds: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + negative_encoder_hidden_states: Optional[torch.Tensor] = None, + output_type="pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + micro_conditioning_aesthetic_score: int = 6, + micro_conditioning_crop_coord: Tuple[int, int] = (0, 0), + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + ): + """ + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask + are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a + single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one + color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B, + H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W, + 1)`, or `(H, W)`. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 16): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 10.0): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + generator (`torch.Generator`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. A single vector from the + pooled and projected final hidden states. + encoder_hidden_states (`torch.FloatTensor`, *optional*): + Pre-generated penultimate hidden states from the text encoder providing additional text conditioning. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + negative_encoder_hidden_states (`torch.FloatTensor`, *optional*): + Analogous to `encoder_hidden_states` for the positive prompt. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + micro_conditioning_aesthetic_score (`int`, *optional*, defaults to 6): + The targeted aesthetic score according to the laion aesthetic classifier. See https://laion.ai/blog/laion-aesthetics/ + and the micro-conditioning section of https://arxiv.org/abs/2307.01952. + micro_conditioning_crop_coord (`Tuple[int]`, *optional*, defaults to (0, 0)): + The targeted height, width crop coordinates. See the micro-conditioning section of https://arxiv.org/abs/2307.01952. + temperature (`Union[int, Tuple[int, int], List[int]]`, *optional*, defaults to (2, 0)): + Configures the temperature scheduler on `self.scheduler` see `AmusedScheduler#set_timesteps`. + + Examples: + + Returns: + [`~pipelines.pipeline_utils.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.pipeline_utils.ImagePipelineOutput`] is returned, otherwise a + `tuple` is returned where the first element is a list with the generated images. + """ + + if (prompt_embeds is not None and encoder_hidden_states is None) or ( + prompt_embeds is None and encoder_hidden_states is not None + ): + raise ValueError("pass either both `prompt_embeds` and `encoder_hidden_states` or neither") + + if (negative_prompt_embeds is not None and negative_encoder_hidden_states is None) or ( + negative_prompt_embeds is None and negative_encoder_hidden_states is not None + ): + raise ValueError( + "pass either both `negatve_prompt_embeds` and `negative_encoder_hidden_states` or neither" + ) + + if (prompt is None and prompt_embeds is None) or (prompt is not None and prompt_embeds is not None): + raise ValueError("pass only one of `prompt` or `prompt_embeds`") + + if isinstance(prompt, str): + prompt = [prompt] + + if prompt is not None: + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + batch_size = batch_size * num_images_per_prompt + + if prompt_embeds is None: + input_ids = self.tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + prompt_embeds = outputs.text_embeds + encoder_hidden_states = outputs.hidden_states[-2] + + prompt_embeds = prompt_embeds.repeat(num_images_per_prompt, 1) + encoder_hidden_states = encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + if guidance_scale > 1.0: + if negative_prompt_embeds is None: + if negative_prompt is None: + negative_prompt = [""] * len(prompt) + + if isinstance(negative_prompt, str): + negative_prompt = [negative_prompt] + + input_ids = self.tokenizer( + negative_prompt, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + ).input_ids.to(self._execution_device) + + outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True) + negative_prompt_embeds = outputs.text_embeds + negative_encoder_hidden_states = outputs.hidden_states[-2] + + negative_prompt_embeds = negative_prompt_embeds.repeat(num_images_per_prompt, 1) + negative_encoder_hidden_states = negative_encoder_hidden_states.repeat(num_images_per_prompt, 1, 1) + + prompt_embeds = torch.concat([negative_prompt_embeds, prompt_embeds]) + encoder_hidden_states = torch.concat([negative_encoder_hidden_states, encoder_hidden_states]) + + image = self.image_processor.preprocess(image) + + height, width = image.shape[-2:] + + # Note that the micro conditionings _do_ flip the order of width, height for the original size + # and the crop coordinates. This is how it was done in the original code base + micro_conds = torch.tensor( + [ + width, + height, + micro_conditioning_crop_coord[0], + micro_conditioning_crop_coord[1], + micro_conditioning_aesthetic_score, + ], + device=self._execution_device, + dtype=encoder_hidden_states.dtype, + ) + + micro_conds = micro_conds.unsqueeze(0) + micro_conds = micro_conds.expand(2 * batch_size if guidance_scale > 1.0 else batch_size, -1) + + self.scheduler.set_timesteps(num_inference_steps, temperature, self._execution_device) + num_inference_steps = int(len(self.scheduler.timesteps) * strength) + start_timestep_idx = len(self.scheduler.timesteps) - num_inference_steps + + needs_upcasting = self.vqvae.dtype == torch.float16 and self.vqvae.config.force_upcast + + if needs_upcasting: + self.vqvae.float() + + latents = self.vqvae.encode(image.to(dtype=self.vqvae.dtype, device=self._execution_device)).latents + latents_bsz, channels, latents_height, latents_width = latents.shape + latents = self.vqvae.quantize(latents)[2][2].reshape(latents_bsz, latents_height, latents_width) + + mask = self.mask_processor.preprocess( + mask_image, height // self.vae_scale_factor, width // self.vae_scale_factor + ) + mask = mask.reshape(mask.shape[0], latents_height, latents_width).bool().to(latents.device) + latents[mask] = self.scheduler.config.mask_token_id + + starting_mask_ratio = mask.sum() / latents.numel() + + latents = latents.repeat(num_images_per_prompt, 1, 1) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i in range(start_timestep_idx, len(self.scheduler.timesteps)): + timestep = self.scheduler.timesteps[i] + + if guidance_scale > 1.0: + model_input = torch.cat([latents] * 2) + else: + model_input = latents + + model_output = self.transformer( + model_input, + micro_conds=micro_conds, + pooled_text_emb=prompt_embeds, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + ) + + if guidance_scale > 1.0: + uncond_logits, cond_logits = model_output.chunk(2) + model_output = uncond_logits + guidance_scale * (cond_logits - uncond_logits) + + latents = self.scheduler.step( + model_output=model_output, + timestep=timestep, + sample=latents, + generator=generator, + starting_mask_ratio=starting_mask_ratio, + ).prev_sample + + if i == len(self.scheduler.timesteps) - 1 or ((i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, timestep, latents) + + if output_type == "latent": + output = latents + else: + output = self.vqvae.decode( + latents, + force_not_quantize=True, + shape=( + batch_size, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + self.vqvae.config.latent_channels, + ), + ).sample.clip(0, 1) + output = self.image_processor.postprocess(output, output_type) + + if needs_upcasting: + self.vqvae.half() + + self.maybe_free_model_hooks() + + if not return_dict: + return (output,) + + return ImagePipelineOutput(output) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index 40c435dd5637..e908ba87acdd 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -39,6 +39,7 @@ else: _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"] + _import_structure["scheduling_amused"] = ["AmusedScheduler"] _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"] _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"] _import_structure["scheduling_ddim"] = ["DDIMScheduler"] @@ -129,6 +130,7 @@ from ..utils.dummy_pt_objects import * # noqa F403 else: from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler + from .scheduling_amused import AmusedScheduler from .scheduling_consistency_decoder import ConsistencyDecoderScheduler from .scheduling_consistency_models import CMStochasticIterativeScheduler from .scheduling_ddim import DDIMScheduler diff --git a/src/diffusers/schedulers/scheduling_amused.py b/src/diffusers/schedulers/scheduling_amused.py new file mode 100644 index 000000000000..51fbe6a4dc7d --- /dev/null +++ b/src/diffusers/schedulers/scheduling_amused.py @@ -0,0 +1,162 @@ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import BaseOutput +from .scheduling_utils import SchedulerMixin + + +def gumbel_noise(t, generator=None): + device = generator.device if generator is not None else t.device + noise = torch.zeros_like(t, device=device).uniform_(0, 1, generator=generator).to(t.device) + return -torch.log((-torch.log(noise.clamp(1e-20))).clamp(1e-20)) + + +def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None): + confidence = torch.log(probs.clamp(1e-20)) + temperature * gumbel_noise(probs, generator=generator) + sorted_confidence = torch.sort(confidence, dim=-1).values + cut_off = torch.gather(sorted_confidence, 1, mask_len.long()) + masking = confidence < cut_off + return masking + + +@dataclass +class AmusedSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample `(x_{0})` based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: torch.FloatTensor = None + + +class AmusedScheduler(SchedulerMixin, ConfigMixin): + order = 1 + + temperatures: torch.Tensor + + @register_to_config + def __init__( + self, + mask_token_id: int, + masking_schedule: str = "cosine", + ): + self.temperatures = None + self.timesteps = None + + def set_timesteps( + self, + num_inference_steps: int, + temperature: Union[int, Tuple[int, int], List[int]] = (2, 0), + device: Union[str, torch.device] = None, + ): + self.timesteps = torch.arange(num_inference_steps, device=device).flip(0) + + if isinstance(temperature, (tuple, list)): + self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device) + else: + self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device) + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.long, + sample: torch.LongTensor, + starting_mask_ratio: int = 1, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[AmusedSchedulerOutput, Tuple]: + two_dim_input = sample.ndim == 3 and model_output.ndim == 4 + + if two_dim_input: + batch_size, codebook_size, height, width = model_output.shape + sample = sample.reshape(batch_size, height * width) + model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1) + + unknown_map = sample == self.config.mask_token_id + + probs = model_output.softmax(dim=-1) + + device = probs.device + probs_ = probs.to(generator.device) if generator is not None else probs # handles when generator is on CPU + if probs_.device.type == "cpu" and probs_.dtype != torch.float32: + probs_ = probs_.float() # multinomial is not implemented for cpu half precision + probs_ = probs_.reshape(-1, probs.size(-1)) + pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device) + pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1]) + pred_original_sample = torch.where(unknown_map, pred_original_sample, sample) + + if timestep == 0: + prev_sample = pred_original_sample + else: + seq_len = sample.shape[1] + step_idx = (self.timesteps == timestep).nonzero() + ratio = (step_idx + 1) / len(self.timesteps) + + if self.config.masking_schedule == "cosine": + mask_ratio = torch.cos(ratio * math.pi / 2) + elif self.config.masking_schedule == "linear": + mask_ratio = 1 - ratio + else: + raise ValueError(f"unknown masking schedule {self.config.masking_schedule}") + + mask_ratio = starting_mask_ratio * mask_ratio + + mask_len = (seq_len * mask_ratio).floor() + # do not mask more than amount previously masked + mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len) + # mask at least one + mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len) + + selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0] + # Ignores the tokens given in the input by overwriting their confidence. + selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max) + + masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator) + + # Masks tokens with lower confidence. + prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample) + + if two_dim_input: + prev_sample = prev_sample.reshape(batch_size, height, width) + pred_original_sample = pred_original_sample.reshape(batch_size, height, width) + + if not return_dict: + return (prev_sample, pred_original_sample) + + return AmusedSchedulerOutput(prev_sample, pred_original_sample) + + def add_noise(self, sample, timesteps, generator=None): + step_idx = (self.timesteps == timesteps).nonzero() + ratio = (step_idx + 1) / len(self.timesteps) + + if self.config.masking_schedule == "cosine": + mask_ratio = torch.cos(ratio * math.pi / 2) + elif self.config.masking_schedule == "linear": + mask_ratio = 1 - ratio + else: + raise ValueError(f"unknown masking schedule {self.config.masking_schedule}") + + mask_indices = ( + torch.rand( + sample.shape, device=generator.device if generator is not None else sample.device, generator=generator + ).to(sample.device) + < mask_ratio + ) + + masked_sample = sample.clone() + + masked_sample[mask_indices] = self.config.mask_token_id + + return masked_sample diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 263bcaea5a8d..5bd2f493ce08 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -317,6 +317,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class UVit2DModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class VQModel(metaclass=DummyObject): _backends = ["torch"] @@ -660,6 +675,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AmusedScheduler(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class CMStochasticIterativeScheduler(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 89fa03e57287..ae6c6c916065 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -32,6 +32,51 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class AmusedImg2ImgPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AmusedInpaintPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class AmusedPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class AnimateDiffPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/pipelines/amused/__init__.py b/tests/pipelines/amused/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/amused/test_amused.py b/tests/pipelines/amused/test_amused.py new file mode 100644 index 000000000000..38159cf2ac15 --- /dev/null +++ b/tests/pipelines/amused/test_amused.py @@ -0,0 +1,181 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import AmusedPipeline, AmusedScheduler, UVit2DModel, VQModel +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device + +from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AmusedPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AmusedPipeline + params = TEXT_TO_IMAGE_PARAMS | {"encoder_hidden_states", "negative_encoder_hidden_states"} + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = UVit2DModel( + hidden_size=32, + use_bias=False, + hidden_dropout=0.0, + cond_embed_dim=32, + micro_cond_encode_dim=2, + micro_cond_embed_dim=10, + encoder_hidden_size=32, + vocab_size=32, + codebook_size=32, + in_channels=32, + block_out_channels=32, + num_res_blocks=1, + downsample=True, + upsample=True, + block_num_heads=1, + num_hidden_layers=1, + num_attention_heads=1, + attention_dropout=0.0, + intermediate_size=32, + layer_norm_eps=1e-06, + ln_elementwise_affine=True, + ) + scheduler = AmusedScheduler(mask_token_id=31) + torch.manual_seed(0) + vqvae = VQModel( + act_fn="silu", + block_out_channels=[32], + down_block_types=[ + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=32, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=32, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=64, + layer_norm_eps=1e-05, + num_attention_heads=8, + num_hidden_layers=3, + pad_token_id=1, + vocab_size=1000, + projection_dim=32, + ) + text_encoder = CLIPTextModelWithProjection(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "scheduler": scheduler, + "vqvae": vqvae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "output_type": "np", + "height": 4, + "width": 4, + } + return inputs + + def test_inference_batch_consistent(self, batch_sizes=[2]): + self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) + + @unittest.skip("aMUSEd does not support lists of generators") + def test_inference_batch_single_identical(self): + ... + + +@slow +@require_torch_gpu +class AmusedPipelineSlowTests(unittest.TestCase): + def test_amused_256(self): + pipe = AmusedPipeline.from_pretrained("huggingface/amused-256") + pipe.to(torch_device) + + image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.4011, 0.3992, 0.3790, 0.3856, 0.3772, 0.3711, 0.3919, 0.3850, 0.3625]) + assert np.abs(image_slice - expected_slice).max() < 3e-3 + + def test_amused_256_fp16(self): + pipe = AmusedPipeline.from_pretrained("huggingface/amused-256", variant="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + + image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.0554, 0.05129, 0.0344, 0.0452, 0.0476, 0.0271, 0.0495, 0.0527, 0.0158]) + assert np.abs(image_slice - expected_slice).max() < 7e-3 + + def test_amused_512(self): + pipe = AmusedPipeline.from_pretrained("huggingface/amused-512") + pipe.to(torch_device) + + image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.9960, 0.9960, 0.9946, 0.9980, 0.9947, 0.9932, 0.9960, 0.9961, 0.9947]) + assert np.abs(image_slice - expected_slice).max() < 3e-3 + + def test_amused_512_fp16(self): + pipe = AmusedPipeline.from_pretrained("huggingface/amused-512", variant="fp16", torch_dtype=torch.float16) + pipe.to(torch_device) + + image = pipe("dog", generator=torch.Generator().manual_seed(0), num_inference_steps=2, output_type="np").images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.9983, 1.0, 1.0, 1.0, 1.0, 0.9989, 0.9994, 0.9976, 0.9977]) + assert np.abs(image_slice - expected_slice).max() < 3e-3 diff --git a/tests/pipelines/amused/test_amused_img2img.py b/tests/pipelines/amused/test_amused_img2img.py new file mode 100644 index 000000000000..dcd29ae88e5b --- /dev/null +++ b/tests/pipelines/amused/test_amused_img2img.py @@ -0,0 +1,239 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import AmusedImg2ImgPipeline, AmusedScheduler, UVit2DModel, VQModel +from diffusers.utils import load_image +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device + +from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AmusedImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AmusedImg2ImgPipeline + params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "latents"} + batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - { + "latents", + } + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = UVit2DModel( + hidden_size=32, + use_bias=False, + hidden_dropout=0.0, + cond_embed_dim=32, + micro_cond_encode_dim=2, + micro_cond_embed_dim=10, + encoder_hidden_size=32, + vocab_size=32, + codebook_size=32, + in_channels=32, + block_out_channels=32, + num_res_blocks=1, + downsample=True, + upsample=True, + block_num_heads=1, + num_hidden_layers=1, + num_attention_heads=1, + attention_dropout=0.0, + intermediate_size=32, + layer_norm_eps=1e-06, + ln_elementwise_affine=True, + ) + scheduler = AmusedScheduler(mask_token_id=31) + torch.manual_seed(0) + vqvae = VQModel( + act_fn="silu", + block_out_channels=[32], + down_block_types=[ + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=32, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=32, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=64, + layer_norm_eps=1e-05, + num_attention_heads=8, + num_hidden_layers=3, + pad_token_id=1, + vocab_size=1000, + projection_dim=32, + ) + text_encoder = CLIPTextModelWithProjection(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "scheduler": scheduler, + "vqvae": vqvae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device) + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "output_type": "np", + "image": image, + } + return inputs + + def test_inference_batch_consistent(self, batch_sizes=[2]): + self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) + + @unittest.skip("aMUSEd does not support lists of generators") + def test_inference_batch_single_identical(self): + ... + + +@slow +@require_torch_gpu +class AmusedImg2ImgPipelineSlowTests(unittest.TestCase): + def test_amused_256(self): + pipe = AmusedImg2ImgPipeline.from_pretrained("huggingface/amused-256") + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg") + .resize((256, 256)) + .convert("RGB") + ) + + image = pipe( + "winter mountains", + image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.9993, 1.0, 0.9996, 1.0, 0.9995, 0.9925, 0.9990, 0.9954, 1.0]) + + assert np.abs(image_slice - expected_slice).max() < 1e-2 + + def test_amused_256_fp16(self): + pipe = AmusedImg2ImgPipeline.from_pretrained( + "huggingface/amused-256", torch_dtype=torch.float16, variant="fp16" + ) + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg") + .resize((256, 256)) + .convert("RGB") + ) + + image = pipe( + "winter mountains", + image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.9980, 0.9980, 0.9940, 0.9944, 0.9960, 0.9908, 1.0, 1.0, 0.9986]) + + assert np.abs(image_slice - expected_slice).max() < 1e-2 + + def test_amused_512(self): + pipe = AmusedImg2ImgPipeline.from_pretrained("huggingface/amused-512") + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg") + .resize((512, 512)) + .convert("RGB") + ) + + image = pipe( + "winter mountains", + image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.1344, 0.0985, 0.0, 0.1194, 0.1809, 0.0765, 0.0854, 0.1371, 0.0933]) + assert np.abs(image_slice - expected_slice).max() < 0.1 + + def test_amused_512_fp16(self): + pipe = AmusedImg2ImgPipeline.from_pretrained( + "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains.jpg") + .resize((512, 512)) + .convert("RGB") + ) + + image = pipe( + "winter mountains", + image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.1536, 0.1767, 0.0227, 0.1079, 0.2400, 0.1427, 0.1511, 0.1564, 0.1542]) + assert np.abs(image_slice - expected_slice).max() < 0.1 diff --git a/tests/pipelines/amused/test_amused_inpaint.py b/tests/pipelines/amused/test_amused_inpaint.py new file mode 100644 index 000000000000..014485d7b9e4 --- /dev/null +++ b/tests/pipelines/amused/test_amused_inpaint.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import numpy as np +import torch +from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import AmusedInpaintPipeline, AmusedScheduler, UVit2DModel, VQModel +from diffusers.utils import load_image +from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device + +from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class AmusedInpaintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = AmusedInpaintPipeline + params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"width", "height"} + batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS + required_optional_params = PipelineTesterMixin.required_optional_params - { + "latents", + } + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = UVit2DModel( + hidden_size=32, + use_bias=False, + hidden_dropout=0.0, + cond_embed_dim=32, + micro_cond_encode_dim=2, + micro_cond_embed_dim=10, + encoder_hidden_size=32, + vocab_size=32, + codebook_size=32, + in_channels=32, + block_out_channels=32, + num_res_blocks=1, + downsample=True, + upsample=True, + block_num_heads=1, + num_hidden_layers=1, + num_attention_heads=1, + attention_dropout=0.0, + intermediate_size=32, + layer_norm_eps=1e-06, + ln_elementwise_affine=True, + ) + scheduler = AmusedScheduler(mask_token_id=31) + torch.manual_seed(0) + vqvae = VQModel( + act_fn="silu", + block_out_channels=[32], + down_block_types=[ + "DownEncoderBlock2D", + ], + in_channels=3, + latent_channels=32, + layers_per_block=2, + norm_num_groups=32, + num_vq_embeddings=32, + out_channels=3, + sample_size=32, + up_block_types=[ + "UpDecoderBlock2D", + ], + mid_block_add_attention=False, + lookup_from_codebook=True, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=64, + layer_norm_eps=1e-05, + num_attention_heads=8, + num_hidden_layers=3, + pad_token_id=1, + vocab_size=1000, + projection_dim=32, + ) + text_encoder = CLIPTextModelWithProjection(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + components = { + "transformer": transformer, + "scheduler": scheduler, + "vqvae": vqvae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + image = torch.full((1, 3, 4, 4), 1.0, dtype=torch.float32, device=device) + mask_image = torch.full((1, 1, 4, 4), 1.0, dtype=torch.float32, device=device) + mask_image[0, 0, 0, 0] = 0 + mask_image[0, 0, 0, 1] = 0 + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "output_type": "np", + "image": image, + "mask_image": mask_image, + } + return inputs + + def test_inference_batch_consistent(self, batch_sizes=[2]): + self._test_inference_batch_consistent(batch_sizes=batch_sizes, batch_generator=False) + + @unittest.skip("aMUSEd does not support lists of generators") + def test_inference_batch_single_identical(self): + ... + + +@slow +@require_torch_gpu +class AmusedInpaintPipelineSlowTests(unittest.TestCase): + def test_amused_256(self): + pipe = AmusedInpaintPipeline.from_pretrained("huggingface/amused-256") + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg") + .resize((256, 256)) + .convert("RGB") + ) + + mask_image = ( + load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ) + .resize((256, 256)) + .convert("L") + ) + + image = pipe( + "winter mountains", + image, + mask_image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.0699, 0.0716, 0.0608, 0.0715, 0.0797, 0.0638, 0.0802, 0.0924, 0.0634]) + assert np.abs(image_slice - expected_slice).max() < 0.1 + + def test_amused_256_fp16(self): + pipe = AmusedInpaintPipeline.from_pretrained( + "huggingface/amused-256", variant="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg") + .resize((256, 256)) + .convert("RGB") + ) + + mask_image = ( + load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ) + .resize((256, 256)) + .convert("L") + ) + + image = pipe( + "winter mountains", + image, + mask_image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 256, 256, 3) + expected_slice = np.array([0.0735, 0.0749, 0.0650, 0.0739, 0.0805, 0.0667, 0.0802, 0.0923, 0.0622]) + assert np.abs(image_slice - expected_slice).max() < 0.1 + + def test_amused_512(self): + pipe = AmusedInpaintPipeline.from_pretrained("huggingface/amused-512") + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg") + .resize((512, 512)) + .convert("RGB") + ) + + mask_image = ( + load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ) + .resize((512, 512)) + .convert("L") + ) + + image = pipe( + "winter mountains", + image, + mask_image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0005, 0.0]) + assert np.abs(image_slice - expected_slice).max() < 0.05 + + def test_amused_512_fp16(self): + pipe = AmusedInpaintPipeline.from_pretrained( + "huggingface/amused-512", variant="fp16", torch_dtype=torch.float16 + ) + pipe.to(torch_device) + + image = ( + load_image("https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1.jpg") + .resize((512, 512)) + .convert("RGB") + ) + + mask_image = ( + load_image( + "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/open_muse/mountains_1_mask.png" + ) + .resize((512, 512)) + .convert("L") + ) + + image = pipe( + "winter mountains", + image, + mask_image, + generator=torch.Generator().manual_seed(0), + num_inference_steps=2, + output_type="np", + ).images + + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 512, 512, 3) + expected_slice = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0025, 0.0]) + assert np.abs(image_slice - expected_slice).max() < 3e-3 diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index cac5ee442ae6..ed2920cb0c73 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -437,7 +437,7 @@ def test_inference_batch_consistent(self, batch_sizes=[2]): self._test_inference_batch_consistent(batch_sizes=batch_sizes) def _test_inference_batch_consistent( - self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"] + self, batch_sizes=[2], additional_params_copy_to_batched_inputs=["num_inference_steps"], batch_generator=True ): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -472,7 +472,7 @@ def _test_inference_batch_consistent( else: batched_input[name] = batch_size * [value] - if "generator" in inputs: + if batch_generator and "generator" in inputs: batched_input["generator"] = [self.get_generator(i) for i in range(batch_size)] if "batch_size" in inputs: