From ecfadbceb2a90ad60cc26eb7a4440d407ec0baa7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 23 Dec 2023 15:00:37 -0800 Subject: [PATCH 01/53] Initial commit for ADD SD distillation script. --- examples/add/requirements.txt | 8 + examples/add/train_add_distill_sd_wds.py | 1674 ++++++++++++++++++++++ 2 files changed, 1682 insertions(+) create mode 100644 examples/add/requirements.txt create mode 100644 examples/add/train_add_distill_sd_wds.py diff --git a/examples/add/requirements.txt b/examples/add/requirements.txt new file mode 100644 index 000000000000..9136a0dbc774 --- /dev/null +++ b/examples/add/requirements.txt @@ -0,0 +1,8 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +ftfy +tensorboard +Jinja2 +webdataset +timm \ No newline at end of file diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py new file mode 100644 index 000000000000..0f49b271dd19 --- /dev/null +++ b/examples/add/train_add_distill_sd_wds.py @@ -0,0 +1,1674 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import functools +import gc +import itertools +import json +import logging +import math +import os +import random +import shutil +import types +from pathlib import Path +from typing import Callable, List, Union + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import timm +import transformers +import webdataset as wds +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from braceexpand import braceexpand +from huggingface_hub import create_repo +from packaging import version +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import default_collate +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig +from webdataset.tariterators import ( + base_plus_ext, + tar_file_expander, + url_opener, + valid_sample, +) + +import diffusers +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import BaseOutput, check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +MAX_SEQ_LENGTH = 77 + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) + + +def filter_keys(key_set): + def _f(dictionary): + return {k: v for k, v in dictionary.items() if k in key_set} + + return _f + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to + lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = {"__key__": prefix, "__url__": filesample["__url__"]} + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +class WebdatasetFilter: + def __init__(self, min_size=1024, max_pwatermark=0.5): + self.min_size = min_size + self.max_pwatermark = max_pwatermark + + def __call__(self, x): + try: + if "json" in x: + x_json = json.loads(x["json"]) + filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get( + "original_height", 0 + ) >= self.min_size + filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark + return filter_size and filter_watermark + else: + return False + except Exception: + return False + + +class SDText2ImageDataset: + def __init__( + self, + train_shards_path_or_url: Union[str, List[str]], + num_train_examples: int, + per_gpu_batch_size: int, + global_batch_size: int, + num_workers: int, + resolution: int = 512, + shuffle_buffer_size: int = 1000, + pin_memory: bool = False, + persistent_workers: bool = False, + ): + if not isinstance(train_shards_path_or_url, str): + train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] + # flatten list using itertools + train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + + def transform(example): + # resize image + image = example["image"] + image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + + # get crop coordinates and crop image + c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) + image = TF.crop(image, c_top, c_left, resolution, resolution) + image = TF.to_tensor(image) + image = TF.normalize(image, [0.5], [0.5]) + + example["image"] = image + return example + + processing_pipeline = [ + wds.decode("pil", handler=wds.ignore_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue), + wds.map(filter_keys({"image", "text"})), + wds.map(transform), + wds.to_tuple("image", "text"), + ] + + # Create train dataset and loader + pipeline = [ + wds.ResampledShards(train_shards_path_or_url), + tarfile_to_samples_nothrow, + wds.shuffle(shuffle_buffer_size), + *processing_pipeline, + wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate), + ] + + num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + + # each worker is iterating over this + self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) + self._train_dataloader = wds.WebLoader( + self._train_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + # add meta-data to dataloader instance for convenience + self._train_dataloader.num_batches = num_batches + self._train_dataloader.num_samples = num_samples + + @property + def train_dataset(self): + return self._train_dataset + + @property + def train_dataloader(self): + return self._train_dataloader + + +class Denoiser: + def __init__(self, alphas, sigmas, prediction_type="epsilon"): + self.alphas = alphas + self.sigmas = sigmas + self.prediction_type = prediction_type + + def to(self, device): + self.alphas = self.alphas.to(device) + self.sigmas = self.sigmas.to(device) + return self + + def __call__(self, model_output, timesteps, sample): + alphas = extract_into_tensor(self.alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape) + if self.prediction_type == "epsilon": + pred_x_0 = (sample - sigmas * model_output) / alphas + elif self.prediction_type == "sample": + pred_x_0 = model_output + elif self.prediction_type == "v_prediction": + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError( + f"Prediction type {self.prediction_type} is not supported; currently, `epsilon`, `sample`, and" + f" `v_prediction` are supported." + ) + + return pred_x_0 + + +# Based on SpectralConv1d from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L29 +class SpectralConv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + torch.nn.utils.SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12) + + +# Based on ResidualBlock from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/shared.py#L20 +class ResidualBlock(torch.nn.Module): + def __init__(self, fn: Callable): + super().__init__() + self.fn = fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (self.fn(x) + x) / np.sqrt(2) + + +# Based on make_block from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L64 +class DiscHeadBlock(torch.nn.Module): + """ + StyleGAN-T block: SpectralConv1d => BatchNormLocal => LeakyReLU + """ + def __init__( + self, + channels: int, + kernel_size: int, + num_groups: int = 8, + leaky_relu_neg_slope: float = 0.2, + ): + super().__init__() + self.channels = channels + + self.conv = SpectralConv1d( + channels, + channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode="circular", + ) + self.batch_norm = torch.nn.GroupNorm(num_groups, channels) + self.act_fn = torch.nn.LeakyReLU(leaky_relu_neg_slope, inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.batch_norm(x) + x = self.act_fn(x) + return x + + +# Based on DiscHead in the official StyleGAN-T implementation +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L78 +# TODO: implement image conditioning (see Section 3.2 of paper) +class DiscriminatorHead(torch.nn.Module): + """ + Implements a StyleGAN-T-style discriminator head. + """ + def __init__( + self, + channels: int, + cond_embedding_dim: int, + cond_map_dim: int = 64, + ): + super().__init__() + self.channels = channels + self.cond_embedding_dim = cond_embedding_dim + self.cond_map_dim = cond_map_dim + + self.input_block = DiscHeadBlock(channels, kernel_size=1) + self.resblock = ResidualBlock(DiscHeadBlock(channels, kernel_size=9)) + + if self.cond_embedding_dim > 0: + self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) + self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) + else: + self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + hidden_states = self.input_block(x) + hidden_states = self.resblock(hidden_states) + out = self.cls(hidden_states) + + if self.cond_embedding_dim > 0: + c = self.conditioning_map(c).squeeze(-1) + out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + + return out + + +activations = {} + + +# Based on get_activation from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L111 +def get_activation(name: str) -> Callable: + def hook(model, input, output): + activations[name] = output + return hook + + +# Based on _resize_pos_embed from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L66 +def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor: + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +# Based on forward_flex from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L83 +def forward_flex(self, x: torch.Tensor) -> torch.Tensor: + # patch proj and dynamically resize + B, C, H, W = x.size() + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + pos_embed = self._resize_pos_embed( + self.pos_embed, H // self.patch_size[1], W // self.patch_size[0] + ) + + # add cls token + cls_tokens = self.cls_token.expand( + x.size(0), -1, -1 + ) + x = torch.cat((cls_tokens, x), dim=1) + + # forward pass + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x + + +# Based on forward_vit from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L60 +def forward_vit(pretrained: torch.nn.Module, x: torch.Tensor) -> dict: + _ = pretrained.model.forward_flex(x) + return {k: pretrained.rearrange(v) for k, v in activations.items()} + + +# Based on AddReadout from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L36 +class AddReadout(torch.nn.Module): + def __init__(self, start_index: int = 1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index:] + readout.unsqueeze(1) + + +# Based on Transpose from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L49 +class Transpose(torch.nn.Module): + def __init__(self, dim0: int, dim1: int): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.transpose(self.dim0, self.dim1) + return x.contiguous() + + +# Based on DINO from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L107 +class FeatureNetwork(torch.nn.Module): + """ + DINO ViT model to act as feature extractor for the discriminator. + """ + def __init__( + self, + pretrained_feature_network: str = "vit_small_patch16_224_dino", + patch_size: List[int] = [16, 16], + hooks: List[int] = [2,5,8,11], + start_index: int = 1, + ): + super().__init__() + self.num_hooks = len(hooks) + 1 + + pretrained_model = timm.create_model(pretrained_feature_network, pretrained=True) + + # Based on make_vit_backbone from the official StyleGAN-T code + # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L117 + # which I believe is itself based on https://github.com/isl-org/DPT + model_with_hooks = torch.nn.Module() + model_with_hooks.model = pretrained_model + + # Add hooks + model_with_hooks.model.blocks[hooks[0]].register_forward_hook(get_activation('0')) + model_with_hooks.model.blocks[hooks[1]].register_forward_hook(get_activation('1')) + model_with_hooks.model.blocks[hooks[2]].register_forward_hook(get_activation('2')) + model_with_hooks.model.blocks[hooks[3]].register_forward_hook(get_activation('3')) + model_with_hooks.model.pos_drop.register_forward_hook(get_activation('4')) + + # Configure readout + model_with_hooks.rearrange = torch.nn.Sequential(AddReadout(start_index), Transpose(1, 2)) + model_with_hooks.model.start_index = start_index + model_with_hooks.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + model_with_hooks.model.forward_flex = types.MethodType(forward_flex, model_with_hooks.model) + model_with_hooks.model._resize_pos_embed = types.MethodType(_resize_pos_embed, model_with_hooks.model) + + self.model = model_with_hooks + # Freeze pretrained model with hooks + self.model = self.model.eval().requires_grad_(False) + + self.img_resolution = self.model.model.patch_embed.img_size[0] + self.embed_dim = self.model.model.embed_dim + self.norm = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + def forward(self, x: torch.Tensor): + """ + Forward pass consisting of interpolation, ImageNet normalization, and a forward pass of self.model. + + Args: + x (`torch.Tensor`): + Image with pixel values in [0, 1]. + + Returns: + `Dict[Any]`: dict of activations + """ + x = F.interpolate(x, self.img_resolution, mode="area") + x = self.norm(x) + + activation_dict = forward_vit(self.model, x) + return activation_dict + + +class DiscriminatorOutput(BaseOutput): + """ + Output class for the Discriminator module. + """ + logits: torch.FloatTensor + + +# Based on ProjectedDiscriminator from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 +# TODO: implement image conditioning (see Section 3.2 of paper) +class Discriminator(torch.nn.Module): + """ + StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). + """ + def __init__( + self, + pretrained_feature_network: str = "vit_small_patch16_224_dino", + cond_embedding_dim: int = 512, + patch_size: List[int] = [16, 16], + hooks: List[int] = [2,5,8,11], + start_index: int = 1, + ): + super().__init__() + self.cond_embedding_dim = cond_embedding_dim + + # Frozen feature network, e.g. DINO + self.feature_network = FeatureNetwork( + pretrained_feature_network=pretrained_feature_network, + patch_size=patch_size, + hooks=hooks, + start_index=start_index, + ) + + # Trainable discriminator heads + heads = [] + for i in range(self.feature_network.num_hooks): + heads += [str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)], + self.heads = torch.nn.ModuleDict(heads) + + def train(self, mode: bool = True): + self.feature_network = self.feature_network.train(False) + self.heads = self.heads.train(mode) + return self + + def eval(self): + return self.train(False) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, + transform_positive = True, + return_dict: bool = True, + ): + # TODO: do we need the augmentations from the original StyleGAN-T code? + if transform_positive: + # Transform to [0, 1]. + x = x.add(1).div(2) + + # Forward pass through feature network. + features = self.feature_network(x) + + # Apply discriminator heads. + logits = [] + for k, head in self.heads.items(): + logits.append(head(features[k], c).view(x.size(0), -1)) + logits = torch.cat(logits, dim=1) + + if not return_dict: + return (logits,) + + return DiscriminatorOutput(logits=logits) + + +def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): + logger.info("Running validation... ") + + unet = accelerator.unwrap_model(unet) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_teacher_model, + vae=vae, + unet=unet, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + validation_prompts = [ + "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", + "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", + ] + + image_logs = [] + + for _, prompt in enumerate(validation_prompts): + images = [] + with torch.autocast("cuda"): + images = pipeline( + prompt=prompt, + num_inference_steps=4, + num_images_per_prompt=4, + generator=generator, + ).images + image_logs.append({"validation_prompt": prompt, "images": images}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({f"validation/{name}": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +@torch.no_grad() +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + # ----------Model Checkpoint Loading Arguments---------- + parser.add_argument( + "--pretrained_teacher_model", + type=str, + default=None, + required=True, + help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--teacher_revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM model identifier from huggingface.co/models.", + ) + # ----------Training Arguments---------- + # ----General Training Arguments---- + parser.add_argument( + "--output_dir", + type=str, + default="lcm-xl-distilled", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + # ----Logging---- + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + # ----Checkpointing---- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + # ----Image Processing---- + parser.add_argument( + "--train_shards_path_or_url", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + # ----Dataloader---- + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + # ----Batch Size and Training Steps---- + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + # ----Learning Rate---- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + # ----Optimizer (Adam)---- + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--discriminator_adam_beta1", type=float, default=0.0, help="The beta1 parameter for the Adam optimizer." + ) + parser.add_argument( + "--discriminator_adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer." + ) + parser.add_argument("--discriminator_adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument( + "--discriminator_adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer" + ) + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + # ----Adversarial Diffusion Distillation (ADD) Specific Arguments---- + parser.add_argument( + "--pretrained_feature_network", + type=str, + default="vit_small_patch16_224_dino", + help=( + "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" + " the DINO objective. The given identifier should be compatible with `timm.create_model`." + ), + ) + parser.add_argument( + "--weight_schedule", + type=str, + default="exponential", + help=( + "The time-dependent weighting function gamma used for scaling the distillation loss Choose between" + " `uniform`, `exponential`, `sds`, and `nfsd`." + ), + ) + parser.add_argument( + "--student_distillation_steps", + type=int, + default=4, + help="The number of student timesteps N used during distillation.", + ) + parser.add_argument( + "--student_timestep_schedule", + type=str, + default="uniform", + help="The method by which the student timestep schedule is determined. Currently, only `uniform` is implemented.", + ) + parser.add_argument( + "--student_custom_timesteps", + type=str, + default=None, + help=( + "A comma-separated list of timesteps which will override the timestep schedule set in" + " `student_timestep_schedule` if set." + ), + ) + parser.add_argument( + "--discriminator_r1_strength", + type=float, + default=1e-05, + help="The discriminator R1 gradient penalty strength gamma.", + ) + parser.add_argument( + "--distillation_weight_factor", + type=float, + default=2.5, + help="Multiplicative weight factor lambda for the distillation loss on the student generator U-Net.", + ) + # ----Exponential Moving Average (EMA)---- + parser.add_argument( + "--ema_decay", + type=float, + default=0.95, + required=False, + help="The exponential moving average (EMA) rate or decay factor.", + ) + # ----Mixed Precision---- + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cast_teacher_unet", + action="store_true", + help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.", + ) + # ----Training Optimizations---- + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + # ----Distributed Training---- + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + # ----------Validation Arguments---------- + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + # ----------Huggingface Hub Arguments----------- + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + # ----------Accelerate Arguments---------- + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + parser.add_argument( + "--train_shards_path_or_url", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +# Adapted from pipelines.StableDiffusionPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True): + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0] + + return prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # 1. Create the noise scheduler and the desired noise schedule. + # Enforce zero terminal SNR (see section 3.1 of ADD paper) + # TODO: is there a better way to implement this? + teacher_scheduler = DDIMScheduler.from_pretrained( + args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision + ) + if teacher_scheduler.config.rescale_betas_zero_snr == False: + teacher_scheduler.config["rescale_betas_zero_snr"] = True + noise_scheduler = DDIMScheduler(**teacher_scheduler.config) + + # DDIMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules + denoiser = Denoiser(alpha_schedule, sigma_schedule) + + # Create time-dependent weighting schedule c(t) for scaling the GAN generator reconstruction loss term. + if args.weight_schedule == "uniform": + train_weight_schedule = torch.ones_like(noise_scheduler.alphas_cumprod) + elif args.weight_schedule == "exponential": + # Set gamma schedule equal to alpha_bar (`alphas_cumprod`) schedule. Higher timesteps have less weight. + train_weight_schedule = noise_scheduler.alphas_cumprod + elif args.weight_schedule == "sds": + # Score distillation sampling weighting + # Introduced in the DreamFusion paper: https://arxiv.org/pdf/2209.14988.pdf. + raise NotImplementedError("SDS distillation weighting is not yet implemented.") + elif args.weight_schedule == "nfsd": + # Noise-free score distillation weighting + # Introduced in "Noise-Free Score Distillation": https://arxiv.org/pdf/2310.17590.pdf. + raise NotImplementedError("NFSD distillation weighting is not yet implemented.") + else: + raise ValueError( + f"Weight schedule {args.weight_schedule} is not currently supported. Supported schedules are `uniform`," + f" `exponential`, `sds`, and `nfsd`." + ) + + # Create student timestep schedule tau_1, ..., tau_N. + if args.student_custom_timesteps is not None: + student_timestep_schedule = np.asarray( + sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(',')]) + ) + elif args.student_timestep_schedule == "uniform": + student_timestep_schedule = np.linspace( + 0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps + ).round() + else: + raise ValueError( + f"Student timestep schedule {args.student_timestep_schedule} was not recognized and custom student" + f" timesteps have not been provided. Either use one of `uniform` for `student_timestep_schedule` or" + f" provide custom timesteps via `student_custom_timesteps`." + ) + student_distillation_steps = student_timestep_schedule.shape[0] + + # 2. Load tokenizers from SD 1.X/2.X checkpoint. + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) + + # 3. Load text encoders from SD 1.X/2.X checkpoint. + # import correct text encoder classes + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision + ) + + # 4. Load VAE from SD 1.X/2.X checkpoint + vae = AutoencoderKL.from_pretrained( + args.pretrained_teacher_model, + subfolder="vae", + revision=args.teacher_revision, + ) + + # 5. Load teacher U-Net from SD 1.X/2.X checkpoint + teacher_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # 6. Initialize GAN generator U-Net from SD 1.X/2.X checkpoint with the teacher U-Net's pretrained weights + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # 7. Initialize GAN discriminator. + discriminator = Discriminator( + pretrained_feature_network=args.pretrained_feature_network, + cond_embedding_dim=text_encoder.config.projection_dim, + ) + + # 8. Freeze teacher vae, text_encoder, and teacher_unet + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + teacher_unet.requires_grad_(False) + + unet.train() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + # 9. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device) + if args.pretrained_vae_model_name_or_path is not None: + vae.to(dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # Move teacher_unet to device, optionally cast to weight_dtype + teacher_unet.to(accelerator.device) + if args.cast_teacher_unet: + teacher_unet.to(dtype=weight_dtype) + + # Also move the denoiser and schedules to accelerator.device + denoiser.to(accelerator.device) + train_weight_schedule = train_weight_schedule.to(accelerator.device) + student_timestep_schedule = torch.from_numpy(student_timestep_schedule).to(accelerator.device) + + # 10. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # target_unet.save_pretrained(os.path.join(output_dir, "unet_target")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target")) + # target_unet.load_state_dict(load_model.state_dict()) + # target_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # 11. Enable optimizations + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # 12. Optimizer creation for generator and discriminator + optimizer = optimizer_class( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + discriminator_optimizer = optimizer_class( + discriminator.parameters(), + lr=args.discriminator_learning_rate, + betas=(args.discriminator_adam_beta1, args.discriminator_adam_beta2), + weight_decay=args.discriminator_adam_weight_decay, + eps=args.discriminator_adam_epsilon, + ) + + # 13. Dataset creation and data processing + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): + prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) + return {"prompt_embeds": prompt_embeds} + + dataset = SDText2ImageDataset( + train_shards_path_or_url=args.train_shards_path_or_url, + num_train_examples=args.max_train_samples, + per_gpu_batch_size=args.train_batch_size, + global_batch_size=args.train_batch_size * accelerator.num_processes, + num_workers=args.dataloader_num_workers, + resolution=args.resolution, + shuffle_buffer_size=1000, + pin_memory=True, + persistent_workers=True, + ) + train_dataloader = dataset.train_dataloader + + compute_embeddings_fn = functools.partial( + compute_embeddings, + proportion_empty_prompts=0, + text_encoder=text_encoder, + tokenizer=tokenizer, + ) + + # 14. Create learning rate scheduler for generator and discriminator + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + discriminator_lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=discriminator_optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # 15. Prepare for training + # Prepare everything with our `accelerator`. + ( + unet, + discriminator, + optimizer, + discriminator_optimizer, + lr_scheduler, + discriminator_lr_scheduler, + ) = accelerator.prepare( + unet, + discriminator, + optimizer, + discriminator_optimizer, + lr_scheduler, + discriminator_lr_scheduler, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # 16. Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num batches each epoch = {train_dataloader.num_batches}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # 1. Load and process the image and text conditioning + image, text = batch + + image = image.to(accelerator.device, non_blocking=True) + encoded_text = compute_embeddings_fn(text) + + pixel_values = image.to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) + + # encode pixel values with batch size of at most 32 + latents = [] + for i in range(0, pixel_values.shape[0], 32): + latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) + latents = torch.cat(latents, dim=0) + + latents = latents * vae.config.scaling_factor + latents = latents.to(weight_dtype) + bsz = latents.shape[0] + + # 2. Sample random student timesteps s uniformly in `student_timestep_schedule` and sample random + # teacher timesteps t uniformly in [0, ..., noise_scheduler.config.num_train_timesteps - 1]. + student_index = torch.randint(0, student_distillation_steps, (bsz,), device=latents.device).long() + student_timesteps = student_timestep_schedule[student_index] + teacher_timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + + # 3. Sample noise and add it to the latents according to the noise magnitude at each student timestep + # (that is, run the forward process on the student model) + student_noise = torch.randn_like(latents) + noisy_student_input = noise_scheduler.add_noise(latents, student_noise, student_timesteps) + + # 4. Prepare prompt embeds and unet_added_conditions + prompt_embeds = encoded_text.pop("prompt_embeds") + + # 5. Get the student model predicted original sample `student_x_0`. + student_noise_pred = unet( + noisy_student_input, + student_timesteps, + encoder_hidden_states=prompt_embeds.float(), + ).sample + student_x_0 = denoiser(student_noise_pred, student_timesteps, noisy_student_input) + + # 6. Sample noise and add it to the student's predicted original sample according to the noise + # magnitude at each teacher timestep (that is, run the forward process on the teacher model, but + # using `student_x_0` instead of latents sampled from the prior). + teacher_noise = torch.randn_like(student_x_0) + noisy_teacher_input = noise_scheduler.add_noise(student_x_0, teacher_noise, teacher_timesteps) + + # 7. Get teacher model predicted original sample `teacher_x_0`. + with torch.no_grad(): + with torch.autocast("cuda"): + teacher_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=prompt_embeds.float(), + ).sample + teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) + + ############################ + # 8. Discriminator Loss + ############################ + discriminator_optimizer.zero_grad(set_to_none=True) + + # 1. Decode real and fake (generated) latents back to pixel space. + # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the + # pretrained feature network for the discriminator operates in pixel space rather than latent space. + real_image = vae.decode(latents).sample + student_gen_image = vae.decode(student_x_0).sample + + # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. + disc_output_real = discriminator(real_image, prompt_embeds) + disc_output_fake = discriminator(student_gen_image.detach(), prompt_embeds) + + # 3. Calculate the discriminator real adversarial loss terms. + d_logits_real = disc_output_real.logits + # Use hinge loss (see section 3.2, Equation 3 of paper) + d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real)) + + # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real + # data. + d_r1_regularizer = 0 + for k, head in discriminator.heads.items(): + head_grad_params = torch.autograd.grad( + outputs=d_adv_loss_real, inputs=head.params(), create_graph=True + ) + head_grad_norm = 0 + for grad in head_grad_params: + head_grad_norm += grad.abs().sum() + d_r1_regularizer += head_grad_norm + + d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer + accelerator.backward(d_loss_real, retain_graph=True) + + # 5. Calculate the discriminator fake adversarial loss terms. + d_logits_fake = disc_output_fake.logits + # Use hinge loss (see section 3.2, Equation 3 of paper) + d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake)) + accelerator.backward(d_adv_loss_fake) + + d_total_loss = d_loss_real + d_adv_loss_fake + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm) + discriminator_optimizer.step() + discriminator_lr_scheduler.step() + + ############################ + # 9. Generator Loss + ############################ + optimizer.zero_grad(set_to_none=True) + + # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator + disc_output_fake = discriminator(student_gen_image, prompt_embeds) + + # 2. Calculate generator adversarial loss term + g_logits_fake = disc_output_fake.logits + g_adv_loss = torch.mean(-g_logits_fake) + + ############################ + # 10. Distillation Loss + ############################ + # Calculate distillation loss in pixel space rather than latent space (see section 3.1) + teacher_gen_image = vae.decode(teacher_x_0).sample + per_instance_distillation_loss = F.mse_loss( + student_gen_image.float(), teacher_gen_image.float(), reduction="none" + ) + # Note that we use the teacher timesteps t when getting the loss weights. + c_t = extract_into_tensor( + train_weight_schedule, teacher_timesteps, per_instance_distillation_loss.shape + ) + g_distillation_loss = torch.mean(c_t * per_instance_distillation_loss) + + g_total_loss = g_adv_loss + args.distillation_weight_factor * g_distillation_loss + + # Backprop on the generator total loss + accelerator.backward(g_total_loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + # log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target") + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "student") + + logs = { + "d_total_loss": d_total_loss.detach().item(), + "g_total_loss": g_total_loss.detach().item(), + "g_adv_loss": g_adv_loss.detach().item(), + "g_distill_loss": g_distillation_loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + # Write out additional values for accelerator to report. + logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item() + logs["d_adv_loss_real"] = d_adv_loss_real.detach().item() + logs["d_r1_regularizer"] = d_r1_regularizer.detach().item() + logs["d_loss_real"] = d_loss_real.detach().item() + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet.save_pretrained(os.path.join(args.output_dir, "unet")) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From d2316dac36671a71b8029323e6f0a8ce496382f7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 23 Dec 2023 15:18:00 -0800 Subject: [PATCH 02/53] make style --- examples/add/train_add_distill_sd_wds.py | 68 ++++++++++++------------ 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 0f49b271dd19..b034ed862781 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -29,11 +29,11 @@ import accelerate import numpy as np +import timm import torch import torch.nn.functional as F import torch.utils.checkpoint import torchvision.transforms.functional as TF -import timm import transformers import webdataset as wds from accelerate import Accelerator @@ -221,12 +221,12 @@ def __init__(self, alphas, sigmas, prediction_type="epsilon"): self.alphas = alphas self.sigmas = sigmas self.prediction_type = prediction_type - + def to(self, device): self.alphas = self.alphas.to(device) self.sigmas = self.sigmas.to(device) return self - + def __call__(self, model_output, timesteps, sample): alphas = extract_into_tensor(self.alphas, timesteps, sample.shape) sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape) @@ -250,7 +250,7 @@ def __call__(self, model_output, timesteps, sample): class SpectralConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - torch.nn.utils.SpectralNorm.apply(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12) + torch.nn.utils.SpectralNorm.apply(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12) # Based on ResidualBlock from the official StyleGAN-T code @@ -270,6 +270,7 @@ class DiscHeadBlock(torch.nn.Module): """ StyleGAN-T block: SpectralConv1d => BatchNormLocal => LeakyReLU """ + def __init__( self, channels: int, @@ -289,7 +290,7 @@ def __init__( ) self.batch_norm = torch.nn.GroupNorm(num_groups, channels) self.act_fn = torch.nn.LeakyReLU(leaky_relu_neg_slope, inplace=True) - + def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.batch_norm(x) @@ -304,6 +305,7 @@ class DiscriminatorHead(torch.nn.Module): """ Implements a StyleGAN-T-style discriminator head. """ + def __init__( self, channels: int, @@ -323,7 +325,7 @@ def __init__( self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) else: self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0) - + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: hidden_states = self.input_block(x) hidden_states = self.resblock(hidden_states) @@ -332,7 +334,7 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: if self.cond_embedding_dim > 0: c = self.conditioning_map(c).squeeze(-1) out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) - + return out @@ -344,6 +346,7 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: def get_activation(name: str) -> Callable: def hook(model, input, output): activations[name] = output + return hook @@ -372,14 +375,10 @@ def forward_flex(self, x: torch.Tensor) -> torch.Tensor: # patch proj and dynamically resize B, C, H, W = x.size() x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) - pos_embed = self._resize_pos_embed( - self.pos_embed, H // self.patch_size[1], W // self.patch_size[0] - ) + pos_embed = self._resize_pos_embed(self.pos_embed, H // self.patch_size[1], W // self.patch_size[0]) # add cls token - cls_tokens = self.cls_token.expand( - x.size(0), -1, -1 - ) + cls_tokens = self.cls_token.expand(x.size(0), -1, -1) x = torch.cat((cls_tokens, x), dim=1) # forward pass @@ -412,7 +411,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: readout = (x[:, 0] + x[:, 1]) / 2 else: readout = x[:, 0] - return x[:, self.start_index:] + readout.unsqueeze(1) + return x[:, self.start_index :] + readout.unsqueeze(1) # Based on Transpose from the official StyleGAN-T code @@ -434,16 +433,17 @@ class FeatureNetwork(torch.nn.Module): """ DINO ViT model to act as feature extractor for the discriminator. """ + def __init__( self, pretrained_feature_network: str = "vit_small_patch16_224_dino", patch_size: List[int] = [16, 16], - hooks: List[int] = [2,5,8,11], + hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, ): super().__init__() self.num_hooks = len(hooks) + 1 - + pretrained_model = timm.create_model(pretrained_feature_network, pretrained=True) # Based on make_vit_backbone from the official StyleGAN-T code @@ -453,11 +453,11 @@ def __init__( model_with_hooks.model = pretrained_model # Add hooks - model_with_hooks.model.blocks[hooks[0]].register_forward_hook(get_activation('0')) - model_with_hooks.model.blocks[hooks[1]].register_forward_hook(get_activation('1')) - model_with_hooks.model.blocks[hooks[2]].register_forward_hook(get_activation('2')) - model_with_hooks.model.blocks[hooks[3]].register_forward_hook(get_activation('3')) - model_with_hooks.model.pos_drop.register_forward_hook(get_activation('4')) + model_with_hooks.model.blocks[hooks[0]].register_forward_hook(get_activation("0")) + model_with_hooks.model.blocks[hooks[1]].register_forward_hook(get_activation("1")) + model_with_hooks.model.blocks[hooks[2]].register_forward_hook(get_activation("2")) + model_with_hooks.model.blocks[hooks[3]].register_forward_hook(get_activation("3")) + model_with_hooks.model.pos_drop.register_forward_hook(get_activation("4")) # Configure readout model_with_hooks.rearrange = torch.nn.Sequential(AddReadout(start_index), Transpose(1, 2)) @@ -476,7 +476,7 @@ def __init__( self.img_resolution = self.model.model.patch_embed.img_size[0] self.embed_dim = self.model.model.embed_dim self.norm = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) - + def forward(self, x: torch.Tensor): """ Forward pass consisting of interpolation, ImageNet normalization, and a forward pass of self.model. @@ -484,13 +484,13 @@ def forward(self, x: torch.Tensor): Args: x (`torch.Tensor`): Image with pixel values in [0, 1]. - + Returns: `Dict[Any]`: dict of activations """ x = F.interpolate(x, self.img_resolution, mode="area") x = self.norm(x) - + activation_dict = forward_vit(self.model, x) return activation_dict @@ -499,6 +499,7 @@ class DiscriminatorOutput(BaseOutput): """ Output class for the Discriminator module. """ + logits: torch.FloatTensor @@ -509,17 +510,18 @@ class Discriminator(torch.nn.Module): """ StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). """ + def __init__( self, pretrained_feature_network: str = "vit_small_patch16_224_dino", cond_embedding_dim: int = 512, patch_size: List[int] = [16, 16], - hooks: List[int] = [2,5,8,11], + hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, ): super().__init__() self.cond_embedding_dim = cond_embedding_dim - + # Frozen feature network, e.g. DINO self.feature_network = FeatureNetwork( pretrained_feature_network=pretrained_feature_network, @@ -531,9 +533,9 @@ def __init__( # Trainable discriminator heads heads = [] for i in range(self.feature_network.num_hooks): - heads += [str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)], + heads += [str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)] self.heads = torch.nn.ModuleDict(heads) - + def train(self, mode: bool = True): self.feature_network = self.feature_network.train(False) self.heads = self.heads.train(mode) @@ -541,12 +543,12 @@ def train(self, mode: bool = True): def eval(self): return self.train(False) - + def forward( self, x: torch.Tensor, c: torch.Tensor, - transform_positive = True, + transform_positive=True, return_dict: bool = True, ): # TODO: do we need the augmentations from the original StyleGAN-T code? @@ -1121,7 +1123,7 @@ def main(args): teacher_scheduler = DDIMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if teacher_scheduler.config.rescale_betas_zero_snr == False: + if not teacher_scheduler.config.rescale_betas_zero_snr: teacher_scheduler.config["rescale_betas_zero_snr"] = True noise_scheduler = DDIMScheduler(**teacher_scheduler.config) @@ -1150,11 +1152,11 @@ def main(args): f"Weight schedule {args.weight_schedule} is not currently supported. Supported schedules are `uniform`," f" `exponential`, `sds`, and `nfsd`." ) - + # Create student timestep schedule tau_1, ..., tau_N. if args.student_custom_timesteps is not None: student_timestep_schedule = np.asarray( - sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(',')]) + sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]) ) elif args.student_timestep_schedule == "uniform": student_timestep_schedule = np.linspace( From ca90b96dbe9e5068693b49ffd2a5713ef36d67fa Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 23 Dec 2023 17:49:10 -0800 Subject: [PATCH 03/53] Fix bug where train_shards_path_or_url arg is duplicated. --- examples/add/train_add_distill_sd_wds.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index b034ed862781..9cd571490334 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1025,16 +1025,6 @@ def parse_args(): " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) - parser.add_argument( - "--train_shards_path_or_url", - type=str, - default=None, - help=( - "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," - " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," - " or to a folder containing files that 🤗 Datasets can understand." - ), - ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) From 29706fd78f3ff9aeeaff4777749ea475e8e50783 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 24 Dec 2023 09:15:25 -0800 Subject: [PATCH 04/53] Add SD-XL version of ADD script. --- examples/add/train_add_distill_sdxl_wds.py | 1774 ++++++++++++++++++++ 1 file changed, 1774 insertions(+) create mode 100644 examples/add/train_add_distill_sdxl_wds.py diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py new file mode 100644 index 000000000000..f25b0217f909 --- /dev/null +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -0,0 +1,1774 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import functools +import gc +import itertools +import json +import logging +import math +import os +import random +import shutil +import types +from pathlib import Path +from typing import Callable, List, Union + +import accelerate +import numpy as np +import timm +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import transformers +import webdataset as wds +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from braceexpand import braceexpand +from huggingface_hub import create_repo +from packaging import version +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import default_collate +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig +from webdataset.tariterators import ( + base_plus_ext, + tar_file_expander, + url_opener, + valid_sample, +) + +import diffusers +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import BaseOutput, check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +MAX_SEQ_LENGTH = 77 + +# Adjust for your dataset +WDS_JSON_WIDTH = "width" # original_width for LAION +WDS_JSON_HEIGHT = "height" # original_height for LAION +MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) + + +def filter_keys(key_set): + def _f(dictionary): + return {k: v for k, v in dictionary.items() if k in key_set} + + return _f + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to + lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = {"__key__": prefix, "__url__": filesample["__url__"]} + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +class WebdatasetFilter: + def __init__(self, min_size=1024, max_pwatermark=0.5): + self.min_size = min_size + self.max_pwatermark = max_pwatermark + + def __call__(self, x): + try: + if "json" in x: + x_json = json.loads(x["json"]) + filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get( + WDS_JSON_HEIGHT, 0 + ) >= self.min_size + filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark + return filter_size and filter_watermark + else: + return False + except Exception: + return False + + +class SDXLText2ImageDataset: + def __init__( + self, + train_shards_path_or_url: Union[str, List[str]], + num_train_examples: int, + per_gpu_batch_size: int, + global_batch_size: int, + num_workers: int, + resolution: int = 1024, + shuffle_buffer_size: int = 1000, + pin_memory: bool = False, + persistent_workers: bool = False, + use_fix_crop_and_size: bool = False, + ): + if not isinstance(train_shards_path_or_url, str): + train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] + # flatten list using itertools + train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + + def get_orig_size(json): + if use_fix_crop_and_size: + return (resolution, resolution) + else: + return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + + def transform(example): + # resize image + image = example["image"] + image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + + # get crop coordinates and crop image + c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) + image = TF.crop(image, c_top, c_left, resolution, resolution) + image = TF.to_tensor(image) + image = TF.normalize(image, [0.5], [0.5]) + + example["image"] = image + example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0) + return example + + processing_pipeline = [ + wds.decode("pil", handler=wds.ignore_and_continue), + wds.rename( + image="jpg;png;jpeg;webp", text="text;txt;caption", orig_size="json", handler=wds.warn_and_continue + ), + wds.map(filter_keys({"image", "text", "orig_size"})), + wds.map_dict(orig_size=get_orig_size), + wds.map(transform), + wds.to_tuple("image", "text", "orig_size", "crop_coords"), + ] + + # Create train dataset and loader + pipeline = [ + wds.ResampledShards(train_shards_path_or_url), + tarfile_to_samples_nothrow, + wds.select(WebdatasetFilter(min_size=MIN_SIZE)), + wds.shuffle(shuffle_buffer_size), + *processing_pipeline, + wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate), + ] + + num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + + # each worker is iterating over this + self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) + self._train_dataloader = wds.WebLoader( + self._train_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + # add meta-data to dataloader instance for convenience + self._train_dataloader.num_batches = num_batches + self._train_dataloader.num_samples = num_samples + + @property + def train_dataset(self): + return self._train_dataset + + @property + def train_dataloader(self): + return self._train_dataloader + + +class Denoiser: + def __init__(self, alphas, sigmas, prediction_type="epsilon"): + self.alphas = alphas + self.sigmas = sigmas + self.prediction_type = prediction_type + + def to(self, device): + self.alphas = self.alphas.to(device) + self.sigmas = self.sigmas.to(device) + return self + + def __call__(self, model_output, timesteps, sample): + alphas = extract_into_tensor(self.alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape) + if self.prediction_type == "epsilon": + pred_x_0 = (sample - sigmas * model_output) / alphas + elif self.prediction_type == "sample": + pred_x_0 = model_output + elif self.prediction_type == "v_prediction": + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError( + f"Prediction type {self.prediction_type} is not supported; currently, `epsilon`, `sample`, and" + f" `v_prediction` are supported." + ) + + return pred_x_0 + + +# Based on SpectralConv1d from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L29 +class SpectralConv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + torch.nn.utils.SpectralNorm.apply(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12) + + +# Based on ResidualBlock from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/shared.py#L20 +class ResidualBlock(torch.nn.Module): + def __init__(self, fn: Callable): + super().__init__() + self.fn = fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (self.fn(x) + x) / np.sqrt(2) + + +# Based on make_block from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L64 +class DiscHeadBlock(torch.nn.Module): + """ + StyleGAN-T block: SpectralConv1d => BatchNormLocal => LeakyReLU + """ + + def __init__( + self, + channels: int, + kernel_size: int, + num_groups: int = 8, + leaky_relu_neg_slope: float = 0.2, + ): + super().__init__() + self.channels = channels + + self.conv = SpectralConv1d( + channels, + channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode="circular", + ) + self.batch_norm = torch.nn.GroupNorm(num_groups, channels) + self.act_fn = torch.nn.LeakyReLU(leaky_relu_neg_slope, inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.batch_norm(x) + x = self.act_fn(x) + return x + + +# Based on DiscHead in the official StyleGAN-T implementation +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L78 +# TODO: implement image conditioning (see Section 3.2 of paper) +class DiscriminatorHead(torch.nn.Module): + """ + Implements a StyleGAN-T-style discriminator head. + """ + + def __init__( + self, + channels: int, + cond_embedding_dim: int, + cond_map_dim: int = 64, + ): + super().__init__() + self.channels = channels + self.cond_embedding_dim = cond_embedding_dim + self.cond_map_dim = cond_map_dim + + self.input_block = DiscHeadBlock(channels, kernel_size=1) + self.resblock = ResidualBlock(DiscHeadBlock(channels, kernel_size=9)) + + if self.cond_embedding_dim > 0: + self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) + self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) + else: + self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + hidden_states = self.input_block(x) + hidden_states = self.resblock(hidden_states) + out = self.cls(hidden_states) + + if self.cond_embedding_dim > 0: + c = self.conditioning_map(c).squeeze(-1) + out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + + return out + + +activations = {} + + +# Based on get_activation from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L111 +def get_activation(name: str) -> Callable: + def hook(model, input, output): + activations[name] = output + + return hook + + +# Based on _resize_pos_embed from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L66 +def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor: + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +# Based on forward_flex from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L83 +def forward_flex(self, x: torch.Tensor) -> torch.Tensor: + # patch proj and dynamically resize + B, C, H, W = x.size() + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + pos_embed = self._resize_pos_embed(self.pos_embed, H // self.patch_size[1], W // self.patch_size[0]) + + # add cls token + cls_tokens = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # forward pass + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x + + +# Based on forward_vit from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L60 +def forward_vit(pretrained: torch.nn.Module, x: torch.Tensor) -> dict: + _ = pretrained.model.forward_flex(x) + return {k: pretrained.rearrange(v) for k, v in activations.items()} + + +# Based on AddReadout from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L36 +class AddReadout(torch.nn.Module): + def __init__(self, start_index: int = 1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +# Based on Transpose from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L49 +class Transpose(torch.nn.Module): + def __init__(self, dim0: int, dim1: int): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.transpose(self.dim0, self.dim1) + return x.contiguous() + + +# Based on DINO from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L107 +class FeatureNetwork(torch.nn.Module): + """ + DINO ViT model to act as feature extractor for the discriminator. + """ + + def __init__( + self, + pretrained_feature_network: str = "vit_small_patch16_224_dino", + patch_size: List[int] = [16, 16], + hooks: List[int] = [2, 5, 8, 11], + start_index: int = 1, + ): + super().__init__() + self.num_hooks = len(hooks) + 1 + + pretrained_model = timm.create_model(pretrained_feature_network, pretrained=True) + + # Based on make_vit_backbone from the official StyleGAN-T code + # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L117 + # which I believe is itself based on https://github.com/isl-org/DPT + model_with_hooks = torch.nn.Module() + model_with_hooks.model = pretrained_model + + # Add hooks + model_with_hooks.model.blocks[hooks[0]].register_forward_hook(get_activation("0")) + model_with_hooks.model.blocks[hooks[1]].register_forward_hook(get_activation("1")) + model_with_hooks.model.blocks[hooks[2]].register_forward_hook(get_activation("2")) + model_with_hooks.model.blocks[hooks[3]].register_forward_hook(get_activation("3")) + model_with_hooks.model.pos_drop.register_forward_hook(get_activation("4")) + + # Configure readout + model_with_hooks.rearrange = torch.nn.Sequential(AddReadout(start_index), Transpose(1, 2)) + model_with_hooks.model.start_index = start_index + model_with_hooks.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + model_with_hooks.model.forward_flex = types.MethodType(forward_flex, model_with_hooks.model) + model_with_hooks.model._resize_pos_embed = types.MethodType(_resize_pos_embed, model_with_hooks.model) + + self.model = model_with_hooks + # Freeze pretrained model with hooks + self.model = self.model.eval().requires_grad_(False) + + self.img_resolution = self.model.model.patch_embed.img_size[0] + self.embed_dim = self.model.model.embed_dim + self.norm = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + def forward(self, x: torch.Tensor): + """ + Forward pass consisting of interpolation, ImageNet normalization, and a forward pass of self.model. + + Args: + x (`torch.Tensor`): + Image with pixel values in [0, 1]. + + Returns: + `Dict[Any]`: dict of activations + """ + x = F.interpolate(x, self.img_resolution, mode="area") + x = self.norm(x) + + activation_dict = forward_vit(self.model, x) + return activation_dict + + +class DiscriminatorOutput(BaseOutput): + """ + Output class for the Discriminator module. + """ + + logits: torch.FloatTensor + + +# Based on ProjectedDiscriminator from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 +# TODO: implement image conditioning (see Section 3.2 of paper) +class Discriminator(torch.nn.Module): + """ + StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). + """ + + def __init__( + self, + pretrained_feature_network: str = "vit_small_patch16_224_dino", + cond_embedding_dim: int = 512, + patch_size: List[int] = [16, 16], + hooks: List[int] = [2, 5, 8, 11], + start_index: int = 1, + ): + super().__init__() + self.cond_embedding_dim = cond_embedding_dim + + # Frozen feature network, e.g. DINO + self.feature_network = FeatureNetwork( + pretrained_feature_network=pretrained_feature_network, + patch_size=patch_size, + hooks=hooks, + start_index=start_index, + ) + + # Trainable discriminator heads + heads = [] + for i in range(self.feature_network.num_hooks): + heads += [str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)] + self.heads = torch.nn.ModuleDict(heads) + + def train(self, mode: bool = True): + self.feature_network = self.feature_network.train(False) + self.heads = self.heads.train(mode) + return self + + def eval(self): + return self.train(False) + + def forward( + self, + x: torch.Tensor, + c: torch.Tensor, + transform_positive=True, + return_dict: bool = True, + ): + # TODO: do we need the augmentations from the original StyleGAN-T code? + if transform_positive: + # Transform to [0, 1]. + x = x.add(1).div(2) + + # Forward pass through feature network. + features = self.feature_network(x) + + # Apply discriminator heads. + logits = [] + for k, head in self.heads.items(): + logits.append(head(features[k], c).view(x.size(0), -1)) + logits = torch.cat(logits, dim=1) + + if not return_dict: + return (logits,) + + return DiscriminatorOutput(logits=logits) + + +def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): + logger.info("Running validation... ") + + unet = accelerator.unwrap_model(unet) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_teacher_model, + vae=vae, + unet=unet, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + validation_prompts = [ + "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", + "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", + "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", + ] + + image_logs = [] + + for _, prompt in enumerate(validation_prompts): + images = [] + with torch.autocast("cuda"): + images = pipeline( + prompt=prompt, + num_inference_steps=4, + num_images_per_prompt=4, + generator=generator, + ).images + image_logs.append({"validation_prompt": prompt, "images": images}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({f"validation/{name}": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +@torch.no_grad() +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + # ----------Model Checkpoint Loading Arguments---------- + parser.add_argument( + "--pretrained_teacher_model", + type=str, + default=None, + required=True, + help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--teacher_revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM model identifier from huggingface.co/models.", + ) + # ----------Training Arguments---------- + # ----General Training Arguments---- + parser.add_argument( + "--output_dir", + type=str, + default="lcm-xl-distilled", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + # ----Logging---- + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + # ----Checkpointing---- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + # ----Image Processing---- + parser.add_argument( + "--train_shards_path_or_url", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--use_fix_crop_and_size", + action="store_true", + help="Whether or not to use the fixed crop and size for the teacher model.", + default=False, + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + # ----Dataloader---- + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + # ----Batch Size and Training Steps---- + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + # ----Learning Rate---- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--discriminator_learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + # ----Optimizer (Adam)---- + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--discriminator_adam_beta1", type=float, default=0.0, help="The beta1 parameter for the Adam optimizer." + ) + parser.add_argument( + "--discriminator_adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer." + ) + parser.add_argument("--discriminator_adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument( + "--discriminator_adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer" + ) + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + # ----Adversarial Diffusion Distillation (ADD) Specific Arguments---- + parser.add_argument( + "--pretrained_feature_network", + type=str, + default="vit_small_patch16_224_dino", + help=( + "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" + " the DINO objective. The given identifier should be compatible with `timm.create_model`." + ), + ) + parser.add_argument( + "--weight_schedule", + type=str, + default="exponential", + help=( + "The time-dependent weighting function gamma used for scaling the distillation loss Choose between" + " `uniform`, `exponential`, `sds`, and `nfsd`." + ), + ) + parser.add_argument( + "--student_distillation_steps", + type=int, + default=4, + help="The number of student timesteps N used during distillation.", + ) + parser.add_argument( + "--student_timestep_schedule", + type=str, + default="uniform", + help="The method by which the student timestep schedule is determined. Currently, only `uniform` is implemented.", + ) + parser.add_argument( + "--student_custom_timesteps", + type=str, + default=None, + help=( + "A comma-separated list of timesteps which will override the timestep schedule set in" + " `student_timestep_schedule` if set." + ), + ) + parser.add_argument( + "--discriminator_r1_strength", + type=float, + default=1e-05, + help="The discriminator R1 gradient penalty strength gamma.", + ) + parser.add_argument( + "--distillation_weight_factor", + type=float, + default=2.5, + help="Multiplicative weight factor lambda for the distillation loss on the student generator U-Net.", + ) + # ----Exponential Moving Average (EMA)---- + parser.add_argument( + "--ema_decay", + type=float, + default=0.95, + required=False, + help="The exponential moving average (EMA) rate or decay factor.", + ) + # ----Mixed Precision---- + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cast_teacher_unet", + action="store_true", + help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.", + ) + # ----Training Optimizations---- + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + # ----Distributed Training---- + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + # ----------Validation Arguments---------- + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + # ----------Huggingface Hub Arguments----------- + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + # ----------Accelerate Arguments---------- + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # 1. Create the noise scheduler and the desired noise schedule. + # Enforce zero terminal SNR (see section 3.1 of ADD paper) + # TODO: is there a better way to implement this? + teacher_scheduler = DDIMScheduler.from_pretrained( + args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision + ) + if not teacher_scheduler.config.rescale_betas_zero_snr: + teacher_scheduler.config["rescale_betas_zero_snr"] = True + noise_scheduler = DDIMScheduler(**teacher_scheduler.config) + + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules + denoiser = Denoiser(alpha_schedule, sigma_schedule) + + # Create time-dependent weighting schedule c(t) for scaling the GAN generator reconstruction loss term. + if args.weight_schedule == "uniform": + train_weight_schedule = torch.ones_like(noise_scheduler.alphas_cumprod) + elif args.weight_schedule == "exponential": + # Set gamma schedule equal to alpha_bar (`alphas_cumprod`) schedule. Higher timesteps have less weight. + train_weight_schedule = noise_scheduler.alphas_cumprod + elif args.weight_schedule == "sds": + # Score distillation sampling weighting + # Introduced in the DreamFusion paper: https://arxiv.org/pdf/2209.14988.pdf. + raise NotImplementedError("SDS distillation weighting is not yet implemented.") + elif args.weight_schedule == "nfsd": + # Noise-free score distillation weighting + # Introduced in "Noise-Free Score Distillation": https://arxiv.org/pdf/2310.17590.pdf. + raise NotImplementedError("NFSD distillation weighting is not yet implemented.") + else: + raise ValueError( + f"Weight schedule {args.weight_schedule} is not currently supported. Supported schedules are `uniform`," + f" `exponential`, `sds`, and `nfsd`." + ) + + # Create student timestep schedule tau_1, ..., tau_N. + if args.student_custom_timesteps is not None: + student_timestep_schedule = np.asarray( + sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]) + ) + elif args.student_timestep_schedule == "uniform": + student_timestep_schedule = np.linspace( + 0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps + ).round() + else: + raise ValueError( + f"Student timestep schedule {args.student_timestep_schedule} was not recognized and custom student" + f" timesteps have not been provided. Either use one of `uniform` for `student_timestep_schedule` or" + f" provide custom timesteps via `student_custom_timesteps`." + ) + student_distillation_steps = student_timestep_schedule.shape[0] + + # 2. Load tokenizers from SD-XL checkpoint. + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False + ) + + # 3. Load text encoders from SD-XL checkpoint. + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2" + ) + + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision + ) + + # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + vae_path = ( + args.pretrained_teacher_model + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.teacher_revision, + ) + + # 5. Load teacher U-Net from SD-XL checkpoint + teacher_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # 6. Initialize GAN generator U-Net from SD-XL checkpoint with the teacher U-Net's pretrained weights + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # 7. Initialize GAN discriminator. + # TODO: Confirm that using text_encoder_one here is correct + discriminator = Discriminator( + pretrained_feature_network=args.pretrained_feature_network, + cond_embedding_dim=text_encoder_one.config.projection_dim, + ) + + # 8. Freeze teacher vae, text_encoders, and teacher_unet + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + teacher_unet.requires_grad_(False) + + unet.train() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + # 9. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device) + if args.pretrained_vae_model_name_or_path is not None: + vae.to(dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Move teacher_unet to device, optionally cast to weight_dtype + teacher_unet.to(accelerator.device) + if args.cast_teacher_unet: + teacher_unet.to(dtype=weight_dtype) + + # Also move the denoiser and schedules to accelerator.device + denoiser.to(accelerator.device) + train_weight_schedule = train_weight_schedule.to(accelerator.device) + student_timestep_schedule = torch.from_numpy(student_timestep_schedule).to(accelerator.device) + + # 10. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # target_unet.save_pretrained(os.path.join(output_dir, "unet_target")) + + for i, model in enumerate(models): + model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + # load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target")) + # target_unet.load_state_dict(load_model.state_dict()) + # target_unet.to(accelerator.device) + del load_model + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # 11. Enable optimizations + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + teacher_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # 12. Optimizer creation for generator and discriminator + optimizer = optimizer_class( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + discriminator_optimizer = optimizer_class( + discriminator.parameters(), + lr=args.discriminator_learning_rate, + betas=(args.discriminator_adam_beta1, args.discriminator_adam_beta2), + weight_decay=args.discriminator_adam_weight_decay, + eps=args.discriminator_adam_epsilon, + ) + + # 13. Dataset creation and data processing + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings( + prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True + ): + target_size = (args.resolution, args.resolution) + original_sizes = list(map(list, zip(*original_sizes))) + crops_coords_top_left = list(map(list, zip(*crop_coords))) + + original_sizes = torch.tensor(original_sizes, dtype=torch.long) + crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long) + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + dataset = SDXLText2ImageDataset( + train_shards_path_or_url=args.train_shards_path_or_url, + num_train_examples=args.max_train_samples, + per_gpu_batch_size=args.train_batch_size, + global_batch_size=args.train_batch_size * accelerator.num_processes, + num_workers=args.dataloader_num_workers, + resolution=args.resolution, + shuffle_buffer_size=1000, + pin_memory=True, + persistent_workers=True, + use_fix_crop_and_size=args.use_fix_crop_and_size, + ) + train_dataloader = dataset.train_dataloader + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + + compute_embeddings_fn = functools.partial( + compute_embeddings, + proportion_empty_prompts=0, + text_encoders=text_encoders, + tokenizers=tokenizers, + ) + + # 14. Create learning rate scheduler for generator and discriminator + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + discriminator_lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=discriminator_optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # 15. Prepare for training + # Prepare everything with our `accelerator`. + ( + unet, + discriminator, + optimizer, + discriminator_optimizer, + lr_scheduler, + discriminator_lr_scheduler, + ) = accelerator.prepare( + unet, + discriminator, + optimizer, + discriminator_optimizer, + lr_scheduler, + discriminator_lr_scheduler, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # 16. Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num batches each epoch = {train_dataloader.num_batches}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates) + image, text, orig_size, crop_coords = batch + + image = image.to(accelerator.device, non_blocking=True) + encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) + + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = image.to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) + else: + pixel_values = image + + # encode pixel values with batch size of at most 8 + latents = [] + for i in range(0, pixel_values.shape[0], 8): + latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + latents = torch.cat(latents, dim=0) + + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + bsz = latents.shape[0] + + # 2. Sample random student timesteps s uniformly in `student_timestep_schedule` and sample random + # teacher timesteps t uniformly in [0, ..., noise_scheduler.config.num_train_timesteps - 1]. + student_index = torch.randint(0, student_distillation_steps, (bsz,), device=latents.device).long() + student_timesteps = student_timestep_schedule[student_index] + teacher_timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + + # 3. Sample noise and add it to the latents according to the noise magnitude at each student timestep + # (that is, run the forward process on the student model) + student_noise = torch.randn_like(latents) + noisy_student_input = noise_scheduler.add_noise(latents, student_noise, student_timesteps) + + # 4. Prepare prompt embeds and unet_added_conditions + prompt_embeds = encoded_text.pop("prompt_embeds") + + # 5. Get the student model predicted original sample `student_x_0`. + student_noise_pred = unet( + noisy_student_input, + student_timesteps, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample + student_x_0 = denoiser(student_noise_pred, student_timesteps, noisy_student_input) + + # 6. Sample noise and add it to the student's predicted original sample according to the noise + # magnitude at each teacher timestep (that is, run the forward process on the teacher model, but + # using `student_x_0` instead of latents sampled from the prior). + teacher_noise = torch.randn_like(student_x_0) + noisy_teacher_input = noise_scheduler.add_noise(student_x_0, teacher_noise, teacher_timesteps) + + # 7. Get teacher model predicted original sample `teacher_x_0`. + with torch.no_grad(): + with torch.autocast("cuda"): + teacher_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + ).sample + teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) + + ############################ + # 8. Discriminator Loss + ############################ + discriminator_optimizer.zero_grad(set_to_none=True) + + # 1. Decode real and fake (generated) latents back to pixel space. + # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the + # pretrained feature network for the discriminator operates in pixel space rather than latent space. + real_image = vae.decode(latents).sample + student_gen_image = vae.decode(student_x_0).sample + + # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. + disc_output_real = discriminator(real_image, prompt_embeds) + disc_output_fake = discriminator(student_gen_image.detach(), prompt_embeds) + + # 3. Calculate the discriminator real adversarial loss terms. + d_logits_real = disc_output_real.logits + # Use hinge loss (see section 3.2, Equation 3 of paper) + d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real)) + + # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real + # data. + d_r1_regularizer = 0 + for k, head in discriminator.heads.items(): + head_grad_params = torch.autograd.grad( + outputs=d_adv_loss_real, inputs=head.params(), create_graph=True + ) + head_grad_norm = 0 + for grad in head_grad_params: + head_grad_norm += grad.abs().sum() + d_r1_regularizer += head_grad_norm + + d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer + accelerator.backward(d_loss_real, retain_graph=True) + + # 5. Calculate the discriminator fake adversarial loss terms. + d_logits_fake = disc_output_fake.logits + # Use hinge loss (see section 3.2, Equation 3 of paper) + d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake)) + accelerator.backward(d_adv_loss_fake) + + d_total_loss = d_loss_real + d_adv_loss_fake + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm) + discriminator_optimizer.step() + discriminator_lr_scheduler.step() + + ############################ + # 9. Generator Loss + ############################ + optimizer.zero_grad(set_to_none=True) + + # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator + disc_output_fake = discriminator(student_gen_image, prompt_embeds) + + # 2. Calculate generator adversarial loss term + g_logits_fake = disc_output_fake.logits + g_adv_loss = torch.mean(-g_logits_fake) + + ############################ + # 10. Distillation Loss + ############################ + # Calculate distillation loss in pixel space rather than latent space (see section 3.1) + teacher_gen_image = vae.decode(teacher_x_0).sample + per_instance_distillation_loss = F.mse_loss( + student_gen_image.float(), teacher_gen_image.float(), reduction="none" + ) + # Note that we use the teacher timesteps t when getting the loss weights. + c_t = extract_into_tensor( + train_weight_schedule, teacher_timesteps, per_instance_distillation_loss.shape + ) + g_distillation_loss = torch.mean(c_t * per_instance_distillation_loss) + + g_total_loss = g_adv_loss + args.distillation_weight_factor * g_distillation_loss + + # Backprop on the generator total loss + accelerator.backward(g_total_loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + # log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target") + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "student") + + logs = { + "d_total_loss": d_total_loss.detach().item(), + "g_total_loss": g_total_loss.detach().item(), + "g_adv_loss": g_adv_loss.detach().item(), + "g_distill_loss": g_distillation_loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + # Write out additional values for accelerator to report. + logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item() + logs["d_adv_loss_real"] = d_adv_loss_real.detach().item() + logs["d_r1_regularizer"] = d_r1_regularizer.detach().item() + logs["d_loss_real"] = d_loss_real.detach().item() + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet.save_pretrained(os.path.join(args.output_dir, "unet")) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From 1944a1c2119ba8f12fd9450f1ca6e342507e53e5 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 24 Dec 2023 09:16:03 -0800 Subject: [PATCH 05/53] Fix bugs in ADD SD distill script. --- examples/add/train_add_distill_sd_wds.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 9cd571490334..a72a2df25790 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -853,6 +853,12 @@ def parse_args(): default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) + parser.add_argument( + "--discriminator_learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) parser.add_argument( "--scale_lr", action="store_true", @@ -1253,7 +1259,7 @@ def save_model_hook(models, weights, output_dir): weights.pop() def load_model_hook(models, input_dir): - load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target")) + # load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target")) # target_unet.load_state_dict(load_model.state_dict()) # target_unet.to(accelerator.device) del load_model @@ -1283,6 +1289,7 @@ def load_model_hook(models, input_dir): "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() + teacher_unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") From 281e72a9dad32fed7285ff7e5ddd9477d09b4ec0 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 24 Dec 2023 09:18:49 -0800 Subject: [PATCH 06/53] make style --- examples/add/train_add_distill_sd_wds.py | 2 +- examples/add/train_add_distill_sdxl_wds.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index a72a2df25790..53ac741c8f12 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1262,7 +1262,7 @@ def load_model_hook(models, input_dir): # load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target")) # target_unet.load_state_dict(load_model.state_dict()) # target_unet.to(accelerator.device) - del load_model + # del load_model for i in range(len(models)): # pop models so that they are not loaded again diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index f25b0217f909..9c15ea0d1378 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and import argparse -import copy import functools import gc import itertools @@ -1322,7 +1321,7 @@ def load_model_hook(models, input_dir): # load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target")) # target_unet.load_state_dict(load_model.state_dict()) # target_unet.to(accelerator.device) - del load_model + # del load_model for i in range(len(models)): # pop models so that they are not loaded again @@ -1622,7 +1621,7 @@ def compute_embeddings( added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, ).sample teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) - + ############################ # 8. Discriminator Loss ############################ From a925613ce11cdfcc62d4e8a9ffbecf3dcbee05d7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 25 Dec 2023 15:44:37 -0800 Subject: [PATCH 07/53] apply suggestions from review --- examples/add/train_add_distill_sd_wds.py | 78 +++++++++++----------- examples/add/train_add_distill_sdxl_wds.py | 60 +++++++++++------ 2 files changed, 81 insertions(+), 57 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 53ac741c8f12..c3e59609e612 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -149,6 +149,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 512, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -158,10 +159,24 @@ def __init__( # flatten list using itertools train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + if interpolation_type == "bilinear": + self.interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + self.interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + self.interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + self.interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=self.interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -268,7 +283,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L64 class DiscHeadBlock(torch.nn.Module): """ - StyleGAN-T block: SpectralConv1d => BatchNormLocal => LeakyReLU + StyleGAN-T block: SpectralConv1d => GroupNorm => LeakyReLU """ def __init__( @@ -288,12 +303,12 @@ def __init__( padding=kernel_size // 2, padding_mode="circular", ) - self.batch_norm = torch.nn.GroupNorm(num_groups, channels) + self.norm = torch.nn.GroupNorm(num_groups, channels) self.act_fn = torch.nn.LeakyReLU(leaky_relu_neg_slope, inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) - x = self.batch_norm(x) + x = self.norm(x) x = self.act_fn(x) return x @@ -607,7 +622,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="stude with torch.autocast("cuda"): images = pipeline( prompt=prompt, - num_inference_steps=4, + num_inference_steps=1, num_images_per_prompt=4, generator=generator, ).images @@ -674,26 +689,6 @@ def update_ema(target_params, source_params, rate=0.99): targ.detach().mul_(rate).add_(src, alpha=1 - rate) -def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" -): - text_encoder_config = PretrainedConfig.from_pretrained( - pretrained_model_name_or_path, subfolder=subfolder, revision=revision - ) - model_class = text_encoder_config.architectures[0] - - if model_class == "CLIPTextModel": - from transformers import CLIPTextModel - - return CLIPTextModel - elif model_class == "CLIPTextModelWithProjection": - from transformers import CLIPTextModelWithProjection - - return CLIPTextModelWithProjection - else: - raise ValueError(f"{model_class} is not supported.") - - def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") # ----------Model Checkpoint Loading Arguments---------- @@ -803,6 +798,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." + ), + ) parser.add_argument( "--center_crop", default=False, @@ -1115,7 +1119,6 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - # TODO: is there a better way to implement this? teacher_scheduler = DDIMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) @@ -1227,17 +1230,13 @@ def main(args): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move vae and text_encoder to device and cast to weight_dtype + # Move vae, text_encoder, and teacher_unet to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. vae.to(accelerator.device) if args.pretrained_vae_model_name_or_path is not None: vae.to(dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) - - # Move teacher_unet to device, optionally cast to weight_dtype - teacher_unet.to(accelerator.device) - if args.cast_teacher_unet: - teacher_unet.to(dtype=weight_dtype) + teacher_unet.to(accelerator.device, dtype=weight_dtype) # Also move the denoiser and schedules to accelerator.device denoiser.to(accelerator.device) @@ -1296,6 +1295,7 @@ def load_model_hook(models, input_dir): # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: + torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True if args.gradient_checkpointing: @@ -1513,12 +1513,11 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok noisy_teacher_input = noise_scheduler.add_noise(student_x_0, teacher_noise, teacher_timesteps) # 7. Get teacher model predicted original sample `teacher_x_0`. - with torch.no_grad(): - with torch.autocast("cuda"): + with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): teacher_noise_pred = teacher_unet( noisy_teacher_input.detach(), teacher_timesteps, - encoder_hidden_states=prompt_embeds.float(), + encoder_hidden_states=prompt_embeds, ).sample teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) @@ -1530,8 +1529,10 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 1. Decode real and fake (generated) latents back to pixel space. # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the # pretrained feature network for the discriminator operates in pixel space rather than latent space. - real_image = vae.decode(latents).sample - student_gen_image = vae.decode(student_x_0).sample + unscaled_latents = (1 / vae.config.scaling_factor) * latents + unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 + real_image = vae.decode(unscaled_latents).sample + student_gen_image = vae.decode(unscaled_student_x_0).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. disc_output_real = discriminator(real_image, prompt_embeds) @@ -1586,7 +1587,8 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 10. Distillation Loss ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) - teacher_gen_image = vae.decode(teacher_x_0).sample + unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 + teacher_gen_image = vae.decode(unscaled_teacher_x_0).sample per_instance_distillation_loss = F.mse_loss( student_gen_image.float(), teacher_gen_image.float(), reduction="none" ) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 9c15ea0d1378..937625e320df 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -154,6 +154,7 @@ def __init__( global_batch_size: int, num_workers: int, resolution: int = 1024, + interpolation_type: str = "bilinear", shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, @@ -170,10 +171,24 @@ def get_orig_size(json): else: return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + if interpolation_type == "bilinear": + self.interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + self.interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + self.interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + self.interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=transforms.InterpolationMode.BILINEAR) + image = TF.resize(image, resolution, interpolation=self.interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -285,7 +300,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L64 class DiscHeadBlock(torch.nn.Module): """ - StyleGAN-T block: SpectralConv1d => BatchNormLocal => LeakyReLU + StyleGAN-T block: SpectralConv1d => GroupNorm => LeakyReLU """ def __init__( @@ -305,12 +320,12 @@ def __init__( padding=kernel_size // 2, padding_mode="circular", ) - self.batch_norm = torch.nn.GroupNorm(num_groups, channels) + self.norm = torch.nn.GroupNorm(num_groups, channels) self.act_fn = torch.nn.LeakyReLU(leaky_relu_neg_slope, inplace=True) def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) - x = self.batch_norm(x) + x = self.norm(x) x = self.act_fn(x) return x @@ -624,7 +639,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="stude with torch.autocast("cuda"): images = pipeline( prompt=prompt, - num_inference_steps=4, + num_inference_steps=1, num_images_per_prompt=4, generator=generator, ).images @@ -820,6 +835,15 @@ def parse_args(): " resolution" ), ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." + ), + ) parser.add_argument( "--use_fix_crop_and_size", action="store_true", @@ -1153,7 +1177,6 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - # TODO: is there a better way to implement this? teacher_scheduler = DDIMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) @@ -1285,18 +1308,14 @@ def main(args): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move unet, vae and text_encoder to device and cast to weight_dtype + # Move vae, text_encoders, and teacher_unet to device and cast to weight_dtype # The VAE is in float32 to avoid NaN losses. vae.to(accelerator.device) if args.pretrained_vae_model_name_or_path is not None: vae.to(dtype=weight_dtype) text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) - - # Move teacher_unet to device, optionally cast to weight_dtype - teacher_unet.to(accelerator.device) - if args.cast_teacher_unet: - teacher_unet.to(dtype=weight_dtype) + teacher_unet.to(accelerator.device, dtype=weight_dtype) # Also move the denoiser and schedules to accelerator.device denoiser.to(accelerator.device) @@ -1355,6 +1374,7 @@ def load_model_hook(models, input_dir): # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: + torch.backends.cudnn.allow_tf32 = True torch.backends.cuda.matmul.allow_tf32 = True if args.gradient_checkpointing: @@ -1612,13 +1632,12 @@ def compute_embeddings( noisy_teacher_input = noise_scheduler.add_noise(student_x_0, teacher_noise, teacher_timesteps) # 7. Get teacher model predicted original sample `teacher_x_0`. - with torch.no_grad(): - with torch.autocast("cuda"): + with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): teacher_noise_pred = teacher_unet( noisy_teacher_input.detach(), teacher_timesteps, - encoder_hidden_states=prompt_embeds.float(), - added_cond_kwargs={k: v.to(weight_dtype) for k, v in encoded_text.items()}, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=encoded_text, ).sample teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) @@ -1630,8 +1649,10 @@ def compute_embeddings( # 1. Decode real and fake (generated) latents back to pixel space. # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the # pretrained feature network for the discriminator operates in pixel space rather than latent space. - real_image = vae.decode(latents).sample - student_gen_image = vae.decode(student_x_0).sample + unscaled_latents = (1 / vae.config.scaling_factor) * latents + unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 + real_image = vae.decode(unscaled_latents).sample + student_gen_image = vae.decode(unscaled_student_x_0).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. disc_output_real = discriminator(real_image, prompt_embeds) @@ -1686,7 +1707,8 @@ def compute_embeddings( # 10. Distillation Loss ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) - teacher_gen_image = vae.decode(teacher_x_0).sample + unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 + teacher_gen_image = vae.decode(unscaled_teacher_x_0).sample per_instance_distillation_loss = F.mse_loss( student_gen_image.float(), teacher_gen_image.float(), reduction="none" ) From 6d028376c2f8c4f2dbe118084be6c59683525522 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 25 Dec 2023 15:47:40 -0800 Subject: [PATCH 08/53] make style --- examples/add/train_add_distill_sd_wds.py | 14 +++++++------- examples/add/train_add_distill_sdxl_wds.py | 14 +++++++------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index c3e59609e612..0bcd3825a981 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -46,7 +46,7 @@ from torch.utils.data import default_collate from torchvision import transforms from tqdm.auto import tqdm -from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig +from transformers import AutoTokenizer, CLIPTextModel from webdataset.tariterators import ( base_plus_ext, tar_file_expander, @@ -1514,12 +1514,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 7. Get teacher model predicted original sample `teacher_x_0`. with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): - teacher_noise_pred = teacher_unet( - noisy_teacher_input.detach(), - teacher_timesteps, - encoder_hidden_states=prompt_embeds, - ).sample - teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) + teacher_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=prompt_embeds, + ).sample + teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) ############################ # 8. Discriminator Loss diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 937625e320df..3d8a4bf04724 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1633,13 +1633,13 @@ def compute_embeddings( # 7. Get teacher model predicted original sample `teacher_x_0`. with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): - teacher_noise_pred = teacher_unet( - noisy_teacher_input.detach(), - teacher_timesteps, - encoder_hidden_states=prompt_embeds, - added_cond_kwargs=encoded_text, - ).sample - teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) + teacher_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=encoded_text, + ).sample + teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) ############################ # 8. Discriminator Loss From 60a9ea76f4c5b524e21c18d05151c4167149a31c Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 26 Dec 2023 13:17:30 -0800 Subject: [PATCH 09/53] Use DDPMScheduler instead of DDIMScheduler for noise_scheduler since PR #6305 has been merged. --- examples/add/train_add_distill_sd_wds.py | 6 +++--- examples/add/train_add_distill_sdxl_wds.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 0bcd3825a981..4927b79b2d2b 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -57,7 +57,7 @@ import diffusers from diffusers import ( AutoencoderKL, - DDIMScheduler, + DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, ) @@ -1119,12 +1119,12 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDIMScheduler.from_pretrained( + teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) if not teacher_scheduler.config.rescale_betas_zero_snr: teacher_scheduler.config["rescale_betas_zero_snr"] = True - noise_scheduler = DDIMScheduler(**teacher_scheduler.config) + noise_scheduler = DDPMScheduler(**teacher_scheduler.config) # DDIMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 3d8a4bf04724..8eda39ce47a2 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -57,7 +57,7 @@ import diffusers from diffusers import ( AutoencoderKL, - DDIMScheduler, + DDPMScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) @@ -1177,12 +1177,12 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDIMScheduler.from_pretrained( + teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) if not teacher_scheduler.config.rescale_betas_zero_snr: teacher_scheduler.config["rescale_betas_zero_snr"] = True - noise_scheduler = DDIMScheduler(**teacher_scheduler.config) + noise_scheduler = DDPMScheduler(**teacher_scheduler.config) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) From c7062440bb0b33706ce3f0a6159fff76a81885fe Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 26 Dec 2023 14:45:04 -0800 Subject: [PATCH 10/53] Implement SDS weighting and fix bug in exponential weighting function. --- examples/add/train_add_distill_sd_wds.py | 12 +++++++----- examples/add/train_add_distill_sdxl_wds.py | 10 ++++++---- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 4927b79b2d2b..c1f1d4c7891f 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1126,7 +1126,8 @@ def main(args): teacher_scheduler.config["rescale_betas_zero_snr"] = True noise_scheduler = DDPMScheduler(**teacher_scheduler.config) - # DDIMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us + # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules @@ -1136,12 +1137,13 @@ def main(args): if args.weight_schedule == "uniform": train_weight_schedule = torch.ones_like(noise_scheduler.alphas_cumprod) elif args.weight_schedule == "exponential": - # Set gamma schedule equal to alpha_bar (`alphas_cumprod`) schedule. Higher timesteps have less weight. - train_weight_schedule = noise_scheduler.alphas_cumprod + # Set weight schedule equal to alpha_schedule. Higher timesteps have less weight. + train_weight_schedule = alpha_schedule elif args.weight_schedule == "sds": - # Score distillation sampling weighting + # Score distillation sampling weighting: alpha_t / (2 * sigma_t) * w(t) + # NOTE: choose w(t) = 1 # Introduced in the DreamFusion paper: https://arxiv.org/pdf/2209.14988.pdf. - raise NotImplementedError("SDS distillation weighting is not yet implemented.") + train_weight_schedule = alpha_schedule / (2 * sigma_schedule) elif args.weight_schedule == "nfsd": # Noise-free score distillation weighting # Introduced in "Noise-Free Score Distillation": https://arxiv.org/pdf/2310.17590.pdf. diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 8eda39ce47a2..69f66425ba5f 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1185,6 +1185,7 @@ def main(args): noise_scheduler = DDPMScheduler(**teacher_scheduler.config) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us + # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules @@ -1194,12 +1195,13 @@ def main(args): if args.weight_schedule == "uniform": train_weight_schedule = torch.ones_like(noise_scheduler.alphas_cumprod) elif args.weight_schedule == "exponential": - # Set gamma schedule equal to alpha_bar (`alphas_cumprod`) schedule. Higher timesteps have less weight. - train_weight_schedule = noise_scheduler.alphas_cumprod + # Set weight schedule equal to alpha_schedule. Higher timesteps have less weight. + train_weight_schedule = alpha_schedule elif args.weight_schedule == "sds": - # Score distillation sampling weighting + # Score distillation sampling weighting: alpha_t / (2 * sigma_t) * w(t) + # NOTE: choose w(t) = 1 # Introduced in the DreamFusion paper: https://arxiv.org/pdf/2209.14988.pdf. - raise NotImplementedError("SDS distillation weighting is not yet implemented.") + train_weight_schedule = alpha_schedule / (2 * sigma_schedule) elif args.weight_schedule == "nfsd": # Noise-free score distillation weighting # Introduced in "Noise-Free Score Distillation": https://arxiv.org/pdf/2310.17590.pdf. From 76fa403e337596df24eb3964de7f19d5bb8d30d8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 26 Dec 2023 18:51:55 -0800 Subject: [PATCH 11/53] Use teacher CFG estimate as distillation target. --- examples/add/train_add_distill_sd_wds.py | 54 +++++++++++++++++--- examples/add/train_add_distill_sdxl_wds.py | 58 +++++++++++++++++++--- 2 files changed, 100 insertions(+), 12 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index c1f1d4c7891f..a276528fa9c8 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -965,6 +965,27 @@ def parse_args(): default=2.5, help="Multiplicative weight factor lambda for the distillation loss on the student generator U-Net.", ) + parser.add_argument( + "--w_min", + type=float, + default=1.0, + required=False, + help=( + "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation." + ), + ) + parser.add_argument( + "--w_max", + type=float, + default=15.0, + required=False, + help=( + "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1413,6 +1434,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok tracker_config = dict(vars(args)) accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + # Prepare unconditional text embedding for CFG. + uncond_input_ids = tokenizer( + [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=MAX_SEQ_LENGTH + ).input_ids.to(accelerator.device) + uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] + # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1514,17 +1541,32 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok teacher_noise = torch.randn_like(student_x_0) noisy_teacher_input = noise_scheduler.add_noise(student_x_0, teacher_noise, teacher_timesteps) - # 7. Get teacher model predicted original sample `teacher_x_0`. + # 7. Sample random guidance scales w ~ U[w_min, w_max] for CFG. + w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w = w.reshape(bsz, 1, 1, 1) + # Move to U-Net device and dtype + w = w.to(device=latents.device, dtype=latents.dtype) + + # 8. Get teacher model predicted original sample `teacher_x_0`. with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): - teacher_noise_pred = teacher_unet( + teacher_cond_noise_pred = teacher_unet( noisy_teacher_input.detach(), teacher_timesteps, encoder_hidden_states=prompt_embeds, ).sample - teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) + + teacher_uncond_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=uncond_prompt_embeds, + ).sample + + # Get the teacher's CFG estimate of x_0. + teacher_cfg_noise_pred = w * teacher_cond_noise_pred + (1 - w) * teacher_uncond_noise_pred + teacher_x_0 = denoiser(teacher_cfg_noise_pred, teacher_timesteps, noisy_teacher_input) ############################ - # 8. Discriminator Loss + # 9. Discriminator Loss ############################ discriminator_optimizer.zero_grad(set_to_none=True) @@ -1574,7 +1616,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok discriminator_lr_scheduler.step() ############################ - # 9. Generator Loss + # 10. Generator Loss ############################ optimizer.zero_grad(set_to_none=True) @@ -1586,7 +1628,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok g_adv_loss = torch.mean(-g_logits_fake) ############################ - # 10. Distillation Loss + # 11. Distillation Loss ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 69f66425ba5f..0462bef787f6 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import copy import functools import gc import itertools @@ -1008,6 +1009,27 @@ def parse_args(): default=2.5, help="Multiplicative weight factor lambda for the distillation loss on the student generator U-Net.", ) + parser.add_argument( + "--w_min", + type=float, + default=1.0, + required=False, + help=( + "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation." + ), + ) + parser.add_argument( + "--w_max", + type=float, + default=8.0, + required=False, + help=( + "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--ema_decay", @@ -1527,6 +1549,12 @@ def compute_embeddings( tracker_config = dict(vars(args)) accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + # Create uncond embeds for classifier free guidance + uncond_prompt_embeds = torch.zeros(args.train_batch_size, MAX_SEQ_LENGTH, 2048).to(accelerator.device) + uncond_pooled_prompt_embeds = torch.zeros( + args.train_batch_size, text_encoder_two.config.projection_dim + ).to(accelerator.device) + # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1633,18 +1661,36 @@ def compute_embeddings( teacher_noise = torch.randn_like(student_x_0) noisy_teacher_input = noise_scheduler.add_noise(student_x_0, teacher_noise, teacher_timesteps) - # 7. Get teacher model predicted original sample `teacher_x_0`. + # 7. Sample random guidance scales w ~ U[w_min, w_max] for CFG. + w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w = w.reshape(bsz, 1, 1, 1) + # Move to U-Net device and dtype + w = w.to(device=latents.device, dtype=latents.dtype) + + # 8. Get teacher model predicted original sample `teacher_x_0`. with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): - teacher_noise_pred = teacher_unet( + teacher_cond_noise_pred = teacher_unet( noisy_teacher_input.detach(), teacher_timesteps, encoder_hidden_states=prompt_embeds, added_cond_kwargs=encoded_text, ).sample - teacher_x_0 = denoiser(teacher_noise_pred, teacher_timesteps, noisy_teacher_input) + + uncond_added_conditions = copy.deepcopy(encoded_text) + uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + teacher_uncond_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=uncond_prompt_embeds, + added_cond_kwargs=uncond_added_conditions, + ).sample + + # Get the teacher's CFG estimate of x_0. + teacher_cfg_noise_pred = w * teacher_cond_noise_pred + (1 - w) * teacher_uncond_noise_pred + teacher_x_0 = denoiser(teacher_cfg_noise_pred, teacher_timesteps, noisy_teacher_input) ############################ - # 8. Discriminator Loss + # 9. Discriminator Loss ############################ discriminator_optimizer.zero_grad(set_to_none=True) @@ -1694,7 +1740,7 @@ def compute_embeddings( discriminator_lr_scheduler.step() ############################ - # 9. Generator Loss + # 10. Generator Loss ############################ optimizer.zero_grad(set_to_none=True) @@ -1706,7 +1752,7 @@ def compute_embeddings( g_adv_loss = torch.mean(-g_logits_fake) ############################ - # 10. Distillation Loss + # 11. Distillation Loss ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 From 840f7ddc7aedcdcb1a8bd7248a3b0919784eb920 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 26 Dec 2023 18:58:04 -0800 Subject: [PATCH 12/53] make style --- examples/add/train_add_distill_sdxl_wds.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 0462bef787f6..3b7cab5d6ae8 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1551,9 +1551,9 @@ def compute_embeddings( # Create uncond embeds for classifier free guidance uncond_prompt_embeds = torch.zeros(args.train_batch_size, MAX_SEQ_LENGTH, 2048).to(accelerator.device) - uncond_pooled_prompt_embeds = torch.zeros( - args.train_batch_size, text_encoder_two.config.projection_dim - ).to(accelerator.device) + uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, text_encoder_two.config.projection_dim).to( + accelerator.device + ) # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps From 1a6569f931878482b394f0b09180fda2b40e215c Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 26 Dec 2023 19:45:04 -0800 Subject: [PATCH 13/53] Add option to also maintain EMA version of student U-Net parameters. --- examples/add/train_add_distill_sd_wds.py | 57 +++++++++++++++++++--- examples/add/train_add_distill_sdxl_wds.py | 57 +++++++++++++++++++--- 2 files changed, 102 insertions(+), 12 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index a276528fa9c8..866b3fce6d6e 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -62,6 +62,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel from diffusers.utils import BaseOutput, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -987,6 +988,11 @@ def parse_args(): ), ) # ----Exponential Moving Average (EMA)---- + parser.add_argument( + "--use_ema", + action="store_true", + help="Whether to also maintain an EMA version of the student U-Net weights." + ) parser.add_argument( "--ema_decay", type=float, @@ -1220,6 +1226,18 @@ def main(args): args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) + # Make exponential moving average (EMA) version of the student unet weights, if using. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + ema_unet = EMAModel( + ema_unet.parameters(), + decay=args.ema_decay, + model_cls=UNet2DConditionModel, + model_config=ema_unet.config, + ) + # 7. Initialize GAN discriminator. discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, @@ -1261,6 +1279,10 @@ def main(args): text_encoder.to(accelerator.device, dtype=weight_dtype) teacher_unet.to(accelerator.device, dtype=weight_dtype) + # Move target (EMA) unet to device but keep in full precision + if args.use_ema: + ema_unet.to(accelerator.device) + # Also move the denoiser and schedules to accelerator.device denoiser.to(accelerator.device) train_weight_schedule = train_weight_schedule.to(accelerator.device) @@ -1272,7 +1294,8 @@ def main(args): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: - # target_unet.save_pretrained(os.path.join(output_dir, "unet_target")) + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) for i, model in enumerate(models): model.save_pretrained(os.path.join(output_dir, "unet")) @@ -1281,10 +1304,11 @@ def save_model_hook(models, weights, output_dir): weights.pop() def load_model_hook(models, input_dir): - # load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target")) - # target_unet.load_state_dict(load_model.state_dict()) - # target_unet.to(accelerator.device) - # del load_model + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model for i in range(len(models)): # pop models so that they are not loaded again @@ -1653,6 +1677,9 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: + # 12. Perform an EMA update on the EMA version of the student U-Net weights. + if args.use_ema: + ema_unet.step(unet.parameters()) progress_bar.update(1) global_step += 1 @@ -1683,7 +1710,16 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok logger.info(f"Saved state to {save_path}") if global_step % args.validation_steps == 0: - # log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target") + if args.use_ema: + # Store the student unet weights and load the EMA weights. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + log_validation(vae, ema_unet, args, accelerator, weight_dtype, global_step, "ema_student") + + # Restore student unet weights + ema_unet.restore(unet.parameters()) + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "student") logs = { @@ -1711,6 +1747,15 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok unet = accelerator.unwrap_model(unet) unet.save_pretrained(os.path.join(args.output_dir, "unet")) + # If using EMA, save EMA weights as well. + if args.use_ema: + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + unet.save_pretrained(os.path.join(args.output_dir, "ema_unet")) + + ema_unet.restore(unet.parameters()) + accelerator.end_training() diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 3b7cab5d6ae8..c4b2329ab0fc 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -63,6 +63,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel from diffusers.utils import BaseOutput, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -1031,6 +1032,11 @@ def parse_args(): ), ) # ----Exponential Moving Average (EMA)---- + parser.add_argument( + "--use_ema", + action="store_true", + help="Whether to also maintain an EMA version of the student U-Net weights." + ) parser.add_argument( "--ema_decay", type=float, @@ -1297,6 +1303,18 @@ def main(args): args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) + # Make exponential moving average (EMA) version of the student unet weights, if using. + if args.use_ema: + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + ema_unet = EMAModel( + ema_unet.parameters(), + decay=args.ema_decay, + model_cls=UNet2DConditionModel, + model_config=ema_unet.config, + ) + # 7. Initialize GAN discriminator. # TODO: Confirm that using text_encoder_one here is correct discriminator = Discriminator( @@ -1341,6 +1359,10 @@ def main(args): text_encoder_two.to(accelerator.device, dtype=weight_dtype) teacher_unet.to(accelerator.device, dtype=weight_dtype) + # Move target (EMA) unet to device but keep in full precision + if args.use_ema: + ema_unet.to(accelerator.device) + # Also move the denoiser and schedules to accelerator.device denoiser.to(accelerator.device) train_weight_schedule = train_weight_schedule.to(accelerator.device) @@ -1352,7 +1374,8 @@ def main(args): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: - # target_unet.save_pretrained(os.path.join(output_dir, "unet_target")) + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) for i, model in enumerate(models): model.save_pretrained(os.path.join(output_dir, "unet")) @@ -1361,10 +1384,11 @@ def save_model_hook(models, weights, output_dir): weights.pop() def load_model_hook(models, input_dir): - # load_model = UNet2DConditionModel.from_pretrained(os.path.join(input_dir, "unet_target")) - # target_unet.load_state_dict(load_model.state_dict()) - # target_unet.to(accelerator.device) - # del load_model + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model for i in range(len(models)): # pop models so that they are not loaded again @@ -1777,6 +1801,9 @@ def compute_embeddings( # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: + # 12. Perform an EMA update on the EMA version of the student U-Net weights. + if args.use_ema: + ema_unet.step(unet.parameters()) progress_bar.update(1) global_step += 1 @@ -1807,7 +1834,16 @@ def compute_embeddings( logger.info(f"Saved state to {save_path}") if global_step % args.validation_steps == 0: - # log_validation(vae, target_unet, args, accelerator, weight_dtype, global_step, "target") + if args.use_ema: + # Store the student unet weights and load the EMA weights. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + log_validation(vae, ema_unet, args, accelerator, weight_dtype, global_step, "ema_student") + + # Restore student unet weights + ema_unet.restore(unet.parameters()) + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "student") logs = { @@ -1835,6 +1871,15 @@ def compute_embeddings( unet = accelerator.unwrap_model(unet) unet.save_pretrained(os.path.join(args.output_dir, "unet")) + # If using EMA, save EMA weights as well. + if args.use_ema: + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + unet.save_pretrained(os.path.join(args.output_dir, "ema_unet")) + + ema_unet.restore(unet.parameters()) + accelerator.end_training() From af31aa7f10cee291174ee79303ac6e213f076fdf Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Tue, 26 Dec 2023 19:45:51 -0800 Subject: [PATCH 14/53] make style --- examples/add/train_add_distill_sd_wds.py | 4 +--- examples/add/train_add_distill_sdxl_wds.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 866b3fce6d6e..2bbabe6cfab9 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -989,9 +989,7 @@ def parse_args(): ) # ----Exponential Moving Average (EMA)---- parser.add_argument( - "--use_ema", - action="store_true", - help="Whether to also maintain an EMA version of the student U-Net weights." + "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." ) parser.add_argument( "--ema_decay", diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index c4b2329ab0fc..1aa91f6de71a 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1033,9 +1033,7 @@ def parse_args(): ) # ----Exponential Moving Average (EMA)---- parser.add_argument( - "--use_ema", - action="store_true", - help="Whether to also maintain an EMA version of the student U-Net weights." + "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." ) parser.add_argument( "--ema_decay", From 969823f88786d5b2641bbb81584b192518db9249 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 27 Dec 2023 11:52:50 -0800 Subject: [PATCH 15/53] apply suggestions from review --- examples/add/train_add_distill_sd_wds.py | 17 ++++++----------- examples/add/train_add_distill_sdxl_wds.py | 17 ++++++----------- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 2bbabe6cfab9..8e5a7174e53d 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -336,20 +336,17 @@ def __init__( self.input_block = DiscHeadBlock(channels, kernel_size=1) self.resblock = ResidualBlock(DiscHeadBlock(channels, kernel_size=9)) - if self.cond_embedding_dim > 0: - self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) - self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) - else: - self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0) + # Map the feature network token embeddings and conditioning embedding to a common dimension cond_map_dim. + self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) + self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: hidden_states = self.input_block(x) hidden_states = self.resblock(hidden_states) out = self.cls(hidden_states) - if self.cond_embedding_dim > 0: - c = self.conditioning_map(c).squeeze(-1) - out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + c = self.conditioning_map(c).squeeze(-1) + out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out @@ -1713,7 +1710,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) - log_validation(vae, ema_unet, args, accelerator, weight_dtype, global_step, "ema_student") + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "ema_student") # Restore student unet weights ema_unet.restore(unet.parameters()) @@ -1752,8 +1749,6 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok unet.save_pretrained(os.path.join(args.output_dir, "ema_unet")) - ema_unet.restore(unet.parameters()) - accelerator.end_training() diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 1aa91f6de71a..a38e28fd2f73 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -354,20 +354,17 @@ def __init__( self.input_block = DiscHeadBlock(channels, kernel_size=1) self.resblock = ResidualBlock(DiscHeadBlock(channels, kernel_size=9)) - if self.cond_embedding_dim > 0: - self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) - self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) - else: - self.cls = SpectralConv1d(channels, 1, kernel_size=1, padding=0) + # Map the feature network token embeddings and conditioning embedding to a common dimension cond_map_dim. + self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) + self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: hidden_states = self.input_block(x) hidden_states = self.resblock(hidden_states) out = self.cls(hidden_states) - if self.cond_embedding_dim > 0: - c = self.conditioning_map(c).squeeze(-1) - out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + c = self.conditioning_map(c).squeeze(-1) + out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out @@ -1837,7 +1834,7 @@ def compute_embeddings( ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) - log_validation(vae, ema_unet, args, accelerator, weight_dtype, global_step, "ema_student") + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "ema_student") # Restore student unet weights ema_unet.restore(unet.parameters()) @@ -1876,8 +1873,6 @@ def compute_embeddings( unet.save_pretrained(os.path.join(args.output_dir, "ema_unet")) - ema_unet.restore(unet.parameters()) - accelerator.end_training() From ddbcd7dc501b866c132dd4f8d031385ae77565d2 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 27 Dec 2023 13:41:13 -0800 Subject: [PATCH 16/53] Fix some bugs. --- examples/add/train_add_distill_sd_wds.py | 24 ++++++++++------------ examples/add/train_add_distill_sdxl_wds.py | 24 ++++++++++------------ 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 8e5a7174e53d..3718e51343e8 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -266,7 +266,7 @@ def __call__(self, model_output, timesteps, sample): class SpectralConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - torch.nn.utils.SpectralNorm.apply(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12) + torch.nn.utils.parametrizations.spectral_norm(self, name="weight", n_power_iterations=1, eps=1e-12, dim=0) # Based on ResidualBlock from the official StyleGAN-T code @@ -449,7 +449,7 @@ class FeatureNetwork(torch.nn.Module): def __init__( self, - pretrained_feature_network: str = "vit_small_patch16_224_dino", + pretrained_feature_network: str = "vit_small_patch16_224.dino", patch_size: List[int] = [16, 16], hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, @@ -546,7 +546,7 @@ def __init__( # Trainable discriminator heads heads = [] for i in range(self.feature_network.num_hooks): - heads += [str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)] + heads.append([str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)]) self.heads = torch.nn.ModuleDict(heads) def train(self, mode: bool = True): @@ -915,7 +915,7 @@ def parse_args(): parser.add_argument( "--pretrained_feature_network", type=str, - default="vit_small_patch16_224_dino", + default="vit_small_patch16_224.dino", help=( "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" " the DINO objective. The given identifier should be compatible with `timm.create_model`." @@ -1179,12 +1179,12 @@ def main(args): # Create student timestep schedule tau_1, ..., tau_N. if args.student_custom_timesteps is not None: student_timestep_schedule = np.asarray( - sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]) + sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]), dtype=np.int64 ) elif args.student_timestep_schedule == "uniform": student_timestep_schedule = np.linspace( 0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps - ).round() + ).round().astype(np.int64) else: raise ValueError( f"Student timestep schedule {args.student_timestep_schedule} was not recognized and custom student" @@ -1236,7 +1236,7 @@ def main(args): # 7. Initialize GAN discriminator. discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, - cond_embedding_dim=text_encoder.config.projection_dim, + cond_embedding_dim=text_encoder.config.hidden_size, ) # 8. Freeze teacher vae, text_encoder, and teacher_unet @@ -1592,14 +1592,12 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 1. Decode real and fake (generated) latents back to pixel space. # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the # pretrained feature network for the discriminator operates in pixel space rather than latent space. - unscaled_latents = (1 / vae.config.scaling_factor) * latents unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 - real_image = vae.decode(unscaled_latents).sample - student_gen_image = vae.decode(unscaled_student_x_0).sample + student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(real_image, prompt_embeds) - disc_output_fake = discriminator(student_gen_image.detach(), prompt_embeds) + disc_output_real = discriminator(pixel_values.float(), prompt_embeds) + disc_output_fake = discriminator(student_gen_image.detach().float(), prompt_embeds) # 3. Calculate the discriminator real adversarial loss terms. d_logits_real = disc_output_real.logits @@ -1651,7 +1649,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 - teacher_gen_image = vae.decode(unscaled_teacher_x_0).sample + teacher_gen_image = vae.decode(unscaled_teacher_x_0.to(dtype=weight_dtype)).sample per_instance_distillation_loss = F.mse_loss( student_gen_image.float(), teacher_gen_image.float(), reduction="none" ) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index a38e28fd2f73..b51bd0e88e9a 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -284,7 +284,7 @@ def __call__(self, model_output, timesteps, sample): class SpectralConv1d(torch.nn.Conv1d): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - torch.nn.utils.SpectralNorm.apply(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12) + torch.nn.utils.parametrizations.spectral_norm(self, name="weight", n_power_iterations=1, eps=1e-12, dim=0) # Based on ResidualBlock from the official StyleGAN-T code @@ -467,7 +467,7 @@ class FeatureNetwork(torch.nn.Module): def __init__( self, - pretrained_feature_network: str = "vit_small_patch16_224_dino", + pretrained_feature_network: str = "vit_small_patch16_224.dino", patch_size: List[int] = [16, 16], hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, @@ -564,7 +564,7 @@ def __init__( # Trainable discriminator heads heads = [] for i in range(self.feature_network.num_hooks): - heads += [str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)] + heads.append([str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)]) self.heads = torch.nn.ModuleDict(heads) def train(self, mode: bool = True): @@ -959,7 +959,7 @@ def parse_args(): parser.add_argument( "--pretrained_feature_network", type=str, - default="vit_small_patch16_224_dino", + default="vit_small_patch16_224.dino", help=( "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" " the DINO objective. The given identifier should be compatible with `timm.create_model`." @@ -1238,12 +1238,12 @@ def main(args): # Create student timestep schedule tau_1, ..., tau_N. if args.student_custom_timesteps is not None: student_timestep_schedule = np.asarray( - sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]) + sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]), dtype=np.int64 ) elif args.student_timestep_schedule == "uniform": student_timestep_schedule = np.linspace( 0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps - ).round() + ).round().astype(np.int64) else: raise ValueError( f"Student timestep schedule {args.student_timestep_schedule} was not recognized and custom student" @@ -1314,7 +1314,7 @@ def main(args): # TODO: Confirm that using text_encoder_one here is correct discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, - cond_embedding_dim=text_encoder_one.config.projection_dim, + cond_embedding_dim=text_encoder_one.config.hidden_size, ) # 8. Freeze teacher vae, text_encoders, and teacher_unet @@ -1716,14 +1716,12 @@ def compute_embeddings( # 1. Decode real and fake (generated) latents back to pixel space. # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the # pretrained feature network for the discriminator operates in pixel space rather than latent space. - unscaled_latents = (1 / vae.config.scaling_factor) * latents unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 - real_image = vae.decode(unscaled_latents).sample - student_gen_image = vae.decode(unscaled_student_x_0).sample + student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(real_image, prompt_embeds) - disc_output_fake = discriminator(student_gen_image.detach(), prompt_embeds) + disc_output_real = discriminator(pixel_values.float(), prompt_embeds) + disc_output_fake = discriminator(student_gen_image.detach().float(), prompt_embeds) # 3. Calculate the discriminator real adversarial loss terms. d_logits_real = disc_output_real.logits @@ -1775,7 +1773,7 @@ def compute_embeddings( ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 - teacher_gen_image = vae.decode(unscaled_teacher_x_0).sample + teacher_gen_image = vae.decode(unscaled_teacher_x_0.to(dtype=weight_dtype)).sample per_instance_distillation_loss = F.mse_loss( student_gen_image.float(), teacher_gen_image.float(), reduction="none" ) From 9fa6133111c0c69cbffffdb0d951c687f2ea2e65 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 27 Dec 2023 13:41:50 -0800 Subject: [PATCH 17/53] make style --- examples/add/train_add_distill_sd_wds.py | 8 +++++--- examples/add/train_add_distill_sdxl_wds.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 3718e51343e8..6367d1305202 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1182,9 +1182,11 @@ def main(args): sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]), dtype=np.int64 ) elif args.student_timestep_schedule == "uniform": - student_timestep_schedule = np.linspace( - 0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps - ).round().astype(np.int64) + student_timestep_schedule = ( + np.linspace(0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps) + .round() + .astype(np.int64) + ) else: raise ValueError( f"Student timestep schedule {args.student_timestep_schedule} was not recognized and custom student" diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index b51bd0e88e9a..1d0fc4607bcc 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1241,9 +1241,11 @@ def main(args): sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]), dtype=np.int64 ) elif args.student_timestep_schedule == "uniform": - student_timestep_schedule = np.linspace( - 0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps - ).round().astype(np.int64) + student_timestep_schedule = ( + np.linspace(0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps) + .round() + .astype(np.int64) + ) else: raise ValueError( f"Student timestep schedule {args.student_timestep_schedule} was not recognized and custom student" From 080e8578edd9d983a2e6053ebe8d9930bfce5831 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 27 Dec 2023 19:20:16 -0800 Subject: [PATCH 18/53] Add ema_min_decay argument, which defaults to ema_decay, resulting in a fixed EMA decay rate. --- examples/add/train_add_distill_sd_wds.py | 14 ++++++++++++++ examples/add/train_add_distill_sdxl_wds.py | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 6367d1305202..2e30245ae762 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -995,6 +995,16 @@ def parse_args(): required=False, help="The exponential moving average (EMA) rate or decay factor.", ) + parser.add_argument( + "--ema_min_decay", + type=float, + default=None, + help=( + "The minimum EMA decay rate, which the effective EMA decay rate (e.g. if warmup is used) will never go" + " below. If not set, the value set for `ema_decay` will be used, which results in a fixed EMA decay rate" + " equal to that value." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -1225,12 +1235,16 @@ def main(args): # Make exponential moving average (EMA) version of the student unet weights, if using. if args.use_ema: + if args.ema_min_decay is None: + # Default to `args.ema_decay`, which results in a fixed EMA decay rate throughout distillation. + args.ema_min_decay = args.ema_decay ema_unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) ema_unet = EMAModel( ema_unet.parameters(), decay=args.ema_decay, + min_decay=args.ema_min_decay, model_cls=UNet2DConditionModel, model_config=ema_unet.config, ) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 1d0fc4607bcc..fa4616dab95d 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1039,6 +1039,16 @@ def parse_args(): required=False, help="The exponential moving average (EMA) rate or decay factor.", ) + parser.add_argument( + "--ema_min_decay", + type=float, + default=None, + help=( + "The minimum EMA decay rate, which the effective EMA decay rate (e.g. if warmup is used) will never go" + " below. If not set, the value set for `ema_decay` will be used, which results in a fixed EMA decay rate" + " equal to that value." + ), + ) # ----Mixed Precision---- parser.add_argument( "--mixed_precision", @@ -1302,12 +1312,16 @@ def main(args): # Make exponential moving average (EMA) version of the student unet weights, if using. if args.use_ema: + if args.ema_min_decay is None: + # Default to `args.ema_decay`, which results in a fixed EMA decay rate throughout distillation. + args.ema_min_decay = args.ema_decay ema_unet = UNet2DConditionModel.from_pretrained( args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision ) ema_unet = EMAModel( ema_unet.parameters(), decay=args.ema_decay, + min_decay=args.ema_min_decay, model_cls=UNet2DConditionModel, model_config=ema_unet.config, ) From a1913d575d15e2c91bd150244cf4211428ab1ab4 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 28 Dec 2023 21:43:03 -0800 Subject: [PATCH 19/53] Fix shape bug in DiscriminatorHead. --- examples/add/train_add_distill_sd_wds.py | 26 ++++++++++++++++++---- examples/add/train_add_distill_sdxl_wds.py | 26 ++++++++++++++++++---- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 2e30245ae762..9b2b4027a7f9 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -319,7 +319,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # TODO: implement image conditioning (see Section 3.2 of paper) class DiscriminatorHead(torch.nn.Module): """ - Implements a StyleGAN-T-style discriminator head. + Implements a StyleGAN-T-style discriminator head. The discriminator head takes in a (possibly intermediate) 1D + sequence of tokens from the feature network, processes it, and combines it with conditioning information to output + per-token logits. """ def __init__( @@ -341,11 +343,26 @@ def __init__( self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + """ + Maps a 1D sequence of tokens from a feature network (e.g. ViT trained with DINO) and a conditioning embedding + to per-token logits. + + Args: + x (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + A sequence of 1D tokens (possibly intermediate) from a ViT feature neetwork. Note that the channels dim + should be the same as the feature network's embedding dim. + c (`torch.Tensor` of shape `(batch_size, cond_embedding_dim)`): + A conditioning embedding representing conditioning (e.g. text) information. + + Returns: + `torch.Tensor` of shape `(batch_size, sequence_length)`: batched 1D sequence of per-token logits. + """ hidden_states = self.input_block(x) hidden_states = self.resblock(hidden_states) out = self.cls(hidden_states) - c = self.conditioning_map(c).squeeze(-1) + # Project conditioning embeddings to cond_map_dim and unsqueeze in the sequence length dimension. + c = self.conditioning_map(c).unsqueeze(-1) out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out @@ -499,7 +516,8 @@ def forward(self, x: torch.Tensor): Image with pixel values in [0, 1]. Returns: - `Dict[Any]`: dict of activations + `Dict[Any]`: dict of activations which are intermediate features from the feature network. The dict values + (feature embeddings) have shape `(batch_size, embed_dim, sequence_length)`. """ x = F.interpolate(x, self.img_resolution, mode="area") x = self.norm(x) @@ -526,7 +544,7 @@ class Discriminator(torch.nn.Module): def __init__( self, - pretrained_feature_network: str = "vit_small_patch16_224_dino", + pretrained_feature_network: str = "vit_small_patch16_224.dino", cond_embedding_dim: int = 512, patch_size: List[int] = [16, 16], hooks: List[int] = [2, 5, 8, 11], diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index fa4616dab95d..617ce4c3654e 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -337,7 +337,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # TODO: implement image conditioning (see Section 3.2 of paper) class DiscriminatorHead(torch.nn.Module): """ - Implements a StyleGAN-T-style discriminator head. + Implements a StyleGAN-T-style discriminator head. The discriminator head takes in a (possibly intermediate) 1D + sequence of tokens from the feature network, processes it, and combines it with conditioning information to output + per-token logits. """ def __init__( @@ -359,11 +361,26 @@ def __init__( self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + """ + Maps a 1D sequence of tokens from a feature network (e.g. ViT trained with DINO) and a conditioning embedding + to per-token logits. + + Args: + x (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + A sequence of 1D tokens (possibly intermediate) from a ViT feature neetwork. Note that the channels dim + should be the same as the feature network's embedding dim. + c (`torch.Tensor` of shape `(batch_size, cond_embedding_dim)`): + A conditioning embedding representing conditioning (e.g. text) information. + + Returns: + `torch.Tensor` of shape `(batch_size, sequence_length)`: batched 1D sequence of per-token logits. + """ hidden_states = self.input_block(x) hidden_states = self.resblock(hidden_states) out = self.cls(hidden_states) - c = self.conditioning_map(c).squeeze(-1) + # Project conditioning embeddings to cond_map_dim and unsqueeze in the sequence length dimension. + c = self.conditioning_map(c).unsqueeze(-1) out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out @@ -517,7 +534,8 @@ def forward(self, x: torch.Tensor): Image with pixel values in [0, 1]. Returns: - `Dict[Any]`: dict of activations + `Dict[Any]`: dict of activations which are intermediate features from the feature network. The dict values + (feature embeddings) have shape `(batch_size, embed_dim, sequence_length)`. """ x = F.interpolate(x, self.img_resolution, mode="area") x = self.norm(x) @@ -544,7 +562,7 @@ class Discriminator(torch.nn.Module): def __init__( self, - pretrained_feature_network: str = "vit_small_patch16_224_dino", + pretrained_feature_network: str = "vit_small_patch16_224.dino", cond_embedding_dim: int = 512, patch_size: List[int] = [16, 16], hooks: List[int] = [2, 5, 8, 11], From f2e5f6d5ec7312983032c231600222a3f3d581b3 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 29 Dec 2023 15:07:41 -0800 Subject: [PATCH 20/53] Use CLIP pooled outputs instead of last hidden states (prompt_embeds) in discriminator. --- examples/add/train_add_distill_sd_wds.py | 44 ++++++++++++++++------ examples/add/train_add_distill_sdxl_wds.py | 23 +++++++---- 2 files changed, 47 insertions(+), 20 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 9b2b4027a7f9..83dee3e0ec62 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -337,10 +337,11 @@ def __init__( self.input_block = DiscHeadBlock(channels, kernel_size=1) self.resblock = ResidualBlock(DiscHeadBlock(channels, kernel_size=9)) + # Project each token embedding from channels dimensions to cond_map_dim dimensions. + self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) - # Map the feature network token embeddings and conditioning embedding to a common dimension cond_map_dim. + # Also project the feature network token embeddings to dimension cond_map_dim. self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) - self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: """ @@ -363,6 +364,8 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # Project conditioning embeddings to cond_map_dim and unsqueeze in the sequence length dimension. c = self.conditioning_map(c).unsqueeze(-1) + + # Combine image features with conditioning embedding via a product. out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out @@ -1120,7 +1123,24 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt text_input_ids = text_inputs.input_ids prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0] - return prompt_embeds + # Get pooled output from prompt_embeds for use in the discriminator. + # https://github.com/huggingface/transformers/blob/3cefac1d974db5e2825a0cb2b842883a628be7a0/src/transformers/models/clip/modeling_clip.py#L715-L734 + if text_encoder.config.eos_token_id == 2: + pooled_output = prompt_embeds[ + torch.arange(prompt_embeds.shape[0], device=prompt_embeds.device), + text_input_ids.to(dtype=torch.int, device=prompt_embeds.device).argmax(dim=-1), + ] + else: + # The config gets updated `eos_token_id` from transformers PR #24773 (so the use of exta new tokens is possible) + pooled_output = prompt_embeds[ + torch.arange(prompt_embeds.shape[0], device=prompt_embeds.device), + # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) + (text_input_ids.to(dtype=torch.int, device=prompt_embeds.device) == text_encoder.config.eos_token_id) + .int() + .argmax(dim=-1), + ] + + return prompt_embeds, pooled_output def main(args): @@ -1229,7 +1249,6 @@ def main(args): ) # 3. Load text encoders from SD 1.X/2.X checkpoint. - # import correct text encoder classes text_encoder = CLIPTextModel.from_pretrained( args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision ) @@ -1408,11 +1427,11 @@ def load_model_hook(models, input_dir): ) # 13. Dataset creation and data processing - # Here, we compute not just the text embeddings but also the additional embeddings - # needed for the SD XL UNet to operate. + # Compute the text encoder last hidden states `prompt_embeds` for use in the teacher/student U-Nets and pooled + # output `text_embedding` for use in the discriminator. def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): - prompt_embeds = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) - return {"prompt_embeds": prompt_embeds} + prompt_embeds, text_embedding = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) + return {"prompt_embeds": prompt_embeds, "text_embedding": text_embedding} dataset = SDText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, @@ -1577,8 +1596,9 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok student_noise = torch.randn_like(latents) noisy_student_input = noise_scheduler.add_noise(latents, student_noise, student_timesteps) - # 4. Prepare prompt embeds and unet_added_conditions + # 4. Prepare prompt embeds (for teacher/student U-Net) and text embedding (for discriminator). prompt_embeds = encoded_text.pop("prompt_embeds") + text_embedding = encoded_text.pop("text_embedding") # 5. Get the student model predicted original sample `student_x_0`. student_noise_pred = unet( @@ -1630,8 +1650,8 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(pixel_values.float(), prompt_embeds) - disc_output_fake = discriminator(student_gen_image.detach().float(), prompt_embeds) + disc_output_real = discriminator(pixel_values.float(), text_embedding) + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding) # 3. Calculate the discriminator real adversarial loss terms. d_logits_real = disc_output_real.logits @@ -1672,7 +1692,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok optimizer.zero_grad(set_to_none=True) # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator - disc_output_fake = discriminator(student_gen_image, prompt_embeds) + disc_output_fake = discriminator(student_gen_image, text_embedding) # 2. Calculate generator adversarial loss term g_logits_fake = disc_output_fake.logits diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 617ce4c3654e..d31778bf3f1b 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -355,10 +355,11 @@ def __init__( self.input_block = DiscHeadBlock(channels, kernel_size=1) self.resblock = ResidualBlock(DiscHeadBlock(channels, kernel_size=9)) + # Project each token embedding from channels dimensions to cond_map_dim dimensions. + self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) - # Map the feature network token embeddings and conditioning embedding to a common dimension cond_map_dim. + # Also project the feature network token embeddings to dimension cond_map_dim. self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) - self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: """ @@ -381,6 +382,8 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # Project conditioning embeddings to cond_map_dim and unsqueeze in the sequence length dimension. c = self.conditioning_map(c).unsqueeze(-1) + + # Combine image features with conditioning embedding via a product. out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out @@ -1345,10 +1348,13 @@ def main(args): ) # 7. Initialize GAN discriminator. - # TODO: Confirm that using text_encoder_one here is correct + # Use text_encoder_two here since it already projects the CLIP embedding to a fixed length vector (e.g. it's + # already a ClipTextModelWithProjection) + # TODO: what if there's no text_encoder_two? I think we already assume text_encoder_two exists in Step 3 above so + # it might be fine? discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, - cond_embedding_dim=text_encoder_one.config.hidden_size, + cond_embedding_dim=text_encoder_two.config.projection_dim, ) # 8. Freeze teacher vae, text_encoders, and teacher_unet @@ -1696,8 +1702,9 @@ def compute_embeddings( student_noise = torch.randn_like(latents) noisy_student_input = noise_scheduler.add_noise(latents, student_noise, student_timesteps) - # 4. Prepare prompt embeds and unet_added_conditions + # 4. Prepare prompt embeds (for teacher/student U-Net) and text embedding (for discriminator). prompt_embeds = encoded_text.pop("prompt_embeds") + text_embedding = encoded_text["text_embeds"] # 5. Get the student model predicted original sample `student_x_0`. student_noise_pred = unet( @@ -1754,8 +1761,8 @@ def compute_embeddings( student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(pixel_values.float(), prompt_embeds) - disc_output_fake = discriminator(student_gen_image.detach().float(), prompt_embeds) + disc_output_real = discriminator(pixel_values.float(), text_embedding) + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding) # 3. Calculate the discriminator real adversarial loss terms. d_logits_real = disc_output_real.logits @@ -1796,7 +1803,7 @@ def compute_embeddings( optimizer.zero_grad(set_to_none=True) # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator - disc_output_fake = discriminator(student_gen_image, prompt_embeds) + disc_output_fake = discriminator(student_gen_image, text_embedding) # 2. Calculate generator adversarial loss term g_logits_fake = disc_output_fake.logits From 97223e1e3521d50e913bb203e18f54b289b27618 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 29 Dec 2023 15:11:20 -0800 Subject: [PATCH 21/53] Fix bug in discriminator R1 penalty implementation. --- examples/add/train_add_distill_sd_wds.py | 2 +- examples/add/train_add_distill_sdxl_wds.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 83dee3e0ec62..7fc93e3f7428 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1663,7 +1663,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok d_r1_regularizer = 0 for k, head in discriminator.heads.items(): head_grad_params = torch.autograd.grad( - outputs=d_adv_loss_real, inputs=head.params(), create_graph=True + outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True ) head_grad_norm = 0 for grad in head_grad_params: diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index d31778bf3f1b..710f7a25d958 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1774,7 +1774,7 @@ def compute_embeddings( d_r1_regularizer = 0 for k, head in discriminator.heads.items(): head_grad_params = torch.autograd.grad( - outputs=d_adv_loss_real, inputs=head.params(), create_graph=True + outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True ) head_grad_norm = 0 for grad in head_grad_params: From e9a1db8ea0de68632439b7f56bf49b5f6523c305 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 29 Dec 2023 15:12:05 -0800 Subject: [PATCH 22/53] make style --- examples/add/train_add_distill_sd_wds.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 7fc93e3f7428..030391117d2a 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1430,7 +1430,9 @@ def load_model_hook(models, input_dir): # Compute the text encoder last hidden states `prompt_embeds` for use in the teacher/student U-Nets and pooled # output `text_embedding` for use in the discriminator. def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): - prompt_embeds, text_embedding = encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train) + prompt_embeds, text_embedding = encode_prompt( + prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train + ) return {"prompt_embeds": prompt_embeds, "text_embedding": text_embedding} dataset = SDText2ImageDataset( From 9f85686e879b5d1dbe59c27d09702564cfe58c68 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 29 Dec 2023 15:43:31 -0800 Subject: [PATCH 23/53] Fix bugs in SD-XL ADD script. --- examples/add/train_add_distill_sdxl_wds.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 710f7a25d958..7fb4e4c1edde 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1608,12 +1608,6 @@ def compute_embeddings( tracker_config = dict(vars(args)) accelerator.init_trackers(args.tracker_project_name, config=tracker_config) - # Create uncond embeds for classifier free guidance - uncond_prompt_embeds = torch.zeros(args.train_batch_size, MAX_SEQ_LENGTH, 2048).to(accelerator.device) - uncond_pooled_prompt_embeds = torch.zeros(args.train_batch_size, text_encoder_two.config.projection_dim).to( - accelerator.device - ) - # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps @@ -1736,6 +1730,8 @@ def compute_embeddings( added_cond_kwargs=encoded_text, ).sample + uncond_prompt_embeds = torch.zeros_like(prompt_embeds) + uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"]) uncond_added_conditions = copy.deepcopy(encoded_text) uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds teacher_uncond_noise_pred = teacher_unet( @@ -1758,7 +1754,11 @@ def compute_embeddings( # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the # pretrained feature network for the discriminator operates in pixel space rather than latent space. unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 - student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample + if args.pretrained_vae_model_name_or_path is not None: + student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample + else: + # VAE is in full precision due to possible NaN issues + student_gen_image = vae.decode(unscaled_student_x_0).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. disc_output_real = discriminator(pixel_values.float(), text_embedding) @@ -1814,7 +1814,12 @@ def compute_embeddings( ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 - teacher_gen_image = vae.decode(unscaled_teacher_x_0.to(dtype=weight_dtype)).sample + if args.pretrained_vae_model_name_or_path is not None: + teacher_gen_image = vae.decode(unscaled_teacher_x_0.to(dtype=weight_dtype)).sample + else: + # VAE is in full precision due to possible NaN issues + teacher_gen_image = vae.decode(unscaled_teacher_x_0).sample + per_instance_distillation_loss = F.mse_loss( student_gen_image.float(), teacher_gen_image.float(), reduction="none" ) From 9cda52ef0f2f181dbf663a253533eea408a91ad7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 11:40:34 -0800 Subject: [PATCH 24/53] Add option to use CLIPTextModelWithProjection model in SD ADD script. --- examples/add/train_add_distill_sd_wds.py | 65 +++++++++++++++--------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 030391117d2a..9635565872eb 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -46,7 +46,7 @@ from torch.utils.data import default_collate from torchvision import transforms from tqdm.auto import tqdm -from transformers import AutoTokenizer, CLIPTextModel +from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection from webdataset.tariterators import ( base_plus_ext, tar_file_expander, @@ -932,6 +932,24 @@ def parse_args(): default=0, help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", ) + parser.add_argument( + "--use_pretrained_projection", + action="store_true", + help=( + "Whether to use a text encoder which projects the CLIP text embedding to a fixed length vector (that is, a" + " `CLIPTextModelWithProjection` rather than the `CLIPTextModel` usually used by SD 1.X/2.X.). If set, the" + " text encoder will be loaded from the model id or path given in `text_encoder_with_proj`." + ), + ) + parser.add_argument( + "--text_encoder_with_proj", + type=str, + default="openai/clip-vit-large-patch14", + help=( + "The text encoder with projection that will be used if `use_pretrained_projection` is set. Note that the" + " default value of `openai/clip-vit-large-patch14` is the CLIP model used by Stable Diffusion v1.5." + ), + ) # ----Adversarial Diffusion Distillation (ADD) Specific Arguments---- parser.add_argument( "--pretrained_feature_network", @@ -1101,7 +1119,7 @@ def parse_args(): # Adapted from pipelines.StableDiffusionPipeline.encode_prompt -def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True): +def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, has_projection, is_train=True): captions = [] for caption in prompt_batch: if random.random() < proportion_empty_prompts: @@ -1121,24 +1139,14 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt return_tensors="pt", ) text_input_ids = text_inputs.input_ids - prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0] - - # Get pooled output from prompt_embeds for use in the discriminator. - # https://github.com/huggingface/transformers/blob/3cefac1d974db5e2825a0cb2b842883a628be7a0/src/transformers/models/clip/modeling_clip.py#L715-L734 - if text_encoder.config.eos_token_id == 2: - pooled_output = prompt_embeds[ - torch.arange(prompt_embeds.shape[0], device=prompt_embeds.device), - text_input_ids.to(dtype=torch.int, device=prompt_embeds.device).argmax(dim=-1), - ] + text_model_output = text_encoder(text_input_ids.to(text_encoder.device)) + # Get last hidden states for use in conditioning the student and teacher U-Nets + prompt_embeds = text_model_output.last_hidden_state + # Get text embedding (if model has projection) or pooled output for use in conditioning the discriminator + if has_projection: + pooled_output = text_model_output.text_embeds else: - # The config gets updated `eos_token_id` from transformers PR #24773 (so the use of exta new tokens is possible) - pooled_output = prompt_embeds[ - torch.arange(prompt_embeds.shape[0], device=prompt_embeds.device), - # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) - (text_input_ids.to(dtype=torch.int, device=prompt_embeds.device) == text_encoder.config.eos_token_id) - .int() - .argmax(dim=-1), - ] + pooled_output = text_model_output.pooler_output return prompt_embeds, pooled_output @@ -1249,9 +1257,12 @@ def main(args): ) # 3. Load text encoders from SD 1.X/2.X checkpoint. - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision - ) + if args.use_pretrained_projection: + text_encoder = CLIPTextModelWithProjection.from_pretrained(args.text_encoder_with_proj) + else: + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision + ) # 4. Load VAE from SD 1.X/2.X checkpoint vae = AutoencoderKL.from_pretrained( @@ -1287,9 +1298,10 @@ def main(args): ) # 7. Initialize GAN discriminator. + conditioning_dim = text_encoder.config.projection_dim if args.use_pretrained_projection else text_encoder.config.hidden_size discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, - cond_embedding_dim=text_encoder.config.hidden_size, + cond_embedding_dim=conditioning_dim, ) # 8. Freeze teacher vae, text_encoder, and teacher_unet @@ -1429,9 +1441,11 @@ def load_model_hook(models, input_dir): # 13. Dataset creation and data processing # Compute the text encoder last hidden states `prompt_embeds` for use in the teacher/student U-Nets and pooled # output `text_embedding` for use in the discriminator. - def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, is_train=True): + def compute_embeddings( + prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, has_projection, is_train=True + ): prompt_embeds, text_embedding = encode_prompt( - prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train + prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, has_projection, is_train ) return {"prompt_embeds": prompt_embeds, "text_embedding": text_embedding} @@ -1453,6 +1467,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok proportion_empty_prompts=0, text_encoder=text_encoder, tokenizer=tokenizer, + has_projection=args.use_pretrained_projection, ) # 14. Create learning rate scheduler for generator and discriminator From 7d7f98b8f86c9fe0d43782e656873ce6c0cf29e5 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 11:45:58 -0800 Subject: [PATCH 25/53] make style --- examples/add/train_add_distill_sd_wds.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 9635565872eb..9f6dd0ab3803 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1298,7 +1298,9 @@ def main(args): ) # 7. Initialize GAN discriminator. - conditioning_dim = text_encoder.config.projection_dim if args.use_pretrained_projection else text_encoder.config.hidden_size + conditioning_dim = ( + text_encoder.config.projection_dim if args.use_pretrained_projection else text_encoder.config.hidden_size + ) discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, cond_embedding_dim=conditioning_dim, From 8675850a530f9dff22479fbc22f10c52a8b5494a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 12:22:32 -0800 Subject: [PATCH 26/53] Allow Discriminator to optionally take in image conditioning information. --- examples/add/train_add_distill_sd_wds.py | 82 +++++++++++++++------- examples/add/train_add_distill_sdxl_wds.py | 70 ++++++++++++------ 2 files changed, 104 insertions(+), 48 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 9f6dd0ab3803..d108f3444424 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -25,7 +25,7 @@ import shutil import types from pathlib import Path -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import accelerate import numpy as np @@ -316,7 +316,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Based on DiscHead in the official StyleGAN-T implementation # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L78 -# TODO: implement image conditioning (see Section 3.2 of paper) class DiscriminatorHead(torch.nn.Module): """ Implements a StyleGAN-T-style discriminator head. The discriminator head takes in a (possibly intermediate) 1D @@ -327,12 +326,14 @@ class DiscriminatorHead(torch.nn.Module): def __init__( self, channels: int, - cond_embedding_dim: int, + c_text_embedding_dim: int, + c_img_embedding_dim: Optional[int] = None, cond_map_dim: int = 64, ): super().__init__() self.channels = channels - self.cond_embedding_dim = cond_embedding_dim + self.c_text_embedding_dim = c_text_embedding_dim + self.c_img_embedding_dim = c_img_embedding_dim self.cond_map_dim = cond_map_dim self.input_block = DiscHeadBlock(channels, kernel_size=1) @@ -340,10 +341,12 @@ def __init__( # Project each token embedding from channels dimensions to cond_map_dim dimensions. self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) - # Also project the feature network token embeddings to dimension cond_map_dim. - self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) + # Also project the text conditioning embeddings to dimension cond_map_dim. + self.c_text_map = torch.nn.Linear(self.c_text_embedding_dim, cond_map_dim) + if self.c_img_embedding_dim is not None: + self.c_img_map = torch.nn.Linear(self.c_img_embedding_dim, cond_map_dim) - def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, c_text: torch.Tensor, c_img: Optional[torch.Tensor] = None) -> torch.Tensor: """ Maps a 1D sequence of tokens from a feature network (e.g. ViT trained with DINO) and a conditioning embedding to per-token logits. @@ -352,8 +355,10 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: x (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): A sequence of 1D tokens (possibly intermediate) from a ViT feature neetwork. Note that the channels dim should be the same as the feature network's embedding dim. - c (`torch.Tensor` of shape `(batch_size, cond_embedding_dim)`): - A conditioning embedding representing conditioning (e.g. text) information. + c_text (`torch.Tensor` of shape `(batch_size, c_text_embedding_dim)`): + A conditioning embedding representing text conditioning information. + c_img (`torch.Tensor` of shape `(batch_size, c_img_embedding_dim)`): + A conditioning embedding representing image conditioning information. Returns: `torch.Tensor` of shape `(batch_size, sequence_length)`: batched 1D sequence of per-token logits. @@ -362,11 +367,15 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: hidden_states = self.resblock(hidden_states) out = self.cls(hidden_states) - # Project conditioning embeddings to cond_map_dim and unsqueeze in the sequence length dimension. - c = self.conditioning_map(c).unsqueeze(-1) + # Project text conditioning embedding to cond_map_dim and unsqueeze in the sequence length dimension. + c_text = self.c_text_map(c_text).unsqueeze(-1) - # Combine image features with conditioning embedding via a product. - out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + # Combine image features with text conditioning embedding via a product. + out = (out * c_text).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + + if self.c_img_embedding_dim is not None: + c_img = self.c_img_map(c_img).unsqueeze(-1) + out = (out * c_img).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out @@ -539,7 +548,6 @@ class DiscriminatorOutput(BaseOutput): # Based on ProjectedDiscriminator from the official StyleGAN-T code # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 -# TODO: implement image conditioning (see Section 3.2 of paper) class Discriminator(torch.nn.Module): """ StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). @@ -548,13 +556,17 @@ class Discriminator(torch.nn.Module): def __init__( self, pretrained_feature_network: str = "vit_small_patch16_224.dino", - cond_embedding_dim: int = 512, + c_text_embedding_dim: int = 768, + c_img_embedding_dim: Optional[int] = None, + cond_map_dim: int = 64, patch_size: List[int] = [16, 16], hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, ): super().__init__() - self.cond_embedding_dim = cond_embedding_dim + self.c_text_embedding_dim = c_text_embedding_dim + self.c_img_embedding_dim = c_img_embedding_dim + self.cond_map_dim = cond_map_dim # Frozen feature network, e.g. DINO self.feature_network = FeatureNetwork( @@ -567,7 +579,12 @@ def __init__( # Trainable discriminator heads heads = [] for i in range(self.feature_network.num_hooks): - heads.append([str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)]) + heads.append([ + str(i), + DiscriminatorHead( + self.feature_network.embed_dim, c_text_embedding_dim, c_img_embedding_dim, cond_map_dim + ) + ]) self.heads = torch.nn.ModuleDict(heads) def train(self, mode: bool = True): @@ -581,8 +598,9 @@ def eval(self): def forward( self, x: torch.Tensor, - c: torch.Tensor, - transform_positive=True, + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + transform_positive: bool = True, return_dict: bool = True, ): # TODO: do we need the augmentations from the original StyleGAN-T code? @@ -596,7 +614,7 @@ def forward( # Apply discriminator heads. logits = [] for k, head in self.heads.items(): - logits.append(head(features[k], c).view(x.size(0), -1)) + logits.append(head(features[k], c_text, c_img).view(x.size(0), -1)) logits = torch.cat(logits, dim=1) if not return_dict: @@ -932,6 +950,16 @@ def parse_args(): default=0, help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", ) + # ----Adversarial Diffusion Distillation (ADD) Specific Arguments---- + parser.add_argument( + "--pretrained_feature_network", + type=str, + default="vit_small_patch16_224.dino", + help=( + "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" + " the DINO objective. The given identifier should be compatible with `timm.create_model`." + ), + ) parser.add_argument( "--use_pretrained_projection", action="store_true", @@ -950,14 +978,13 @@ def parse_args(): " default value of `openai/clip-vit-large-patch14` is the CLIP model used by Stable Diffusion v1.5." ), ) - # ----Adversarial Diffusion Distillation (ADD) Specific Arguments---- parser.add_argument( - "--pretrained_feature_network", - type=str, - default="vit_small_patch16_224.dino", + "--cond_map_dim", + type=int, + default=64, help=( - "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" - " the DINO objective. The given identifier should be compatible with `timm.create_model`." + "The common dimension to which the discriminator feature network features and conditioning embeddings will" + " be projected to in the discriminator heads." ), ) parser.add_argument( @@ -1303,7 +1330,8 @@ def main(args): ) discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, - cond_embedding_dim=conditioning_dim, + c_text_embedding_dim=conditioning_dim, + cond_map_dim=args.cond_map_dim, ) # 8. Freeze teacher vae, text_encoder, and teacher_unet diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 7fb4e4c1edde..c2fc25818ac0 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -26,7 +26,7 @@ import shutil import types from pathlib import Path -from typing import Callable, List, Union +from typing import Callable, List, Optional, Union import accelerate import numpy as np @@ -334,7 +334,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Based on DiscHead in the official StyleGAN-T implementation # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L78 -# TODO: implement image conditioning (see Section 3.2 of paper) class DiscriminatorHead(torch.nn.Module): """ Implements a StyleGAN-T-style discriminator head. The discriminator head takes in a (possibly intermediate) 1D @@ -345,12 +344,14 @@ class DiscriminatorHead(torch.nn.Module): def __init__( self, channels: int, - cond_embedding_dim: int, + c_text_embedding_dim: int, + c_img_embedding_dim: Optional[int] = None, cond_map_dim: int = 64, ): super().__init__() self.channels = channels - self.cond_embedding_dim = cond_embedding_dim + self.c_text_embedding_dim = c_text_embedding_dim + self.c_img_embedding_dim = c_img_embedding_dim self.cond_map_dim = cond_map_dim self.input_block = DiscHeadBlock(channels, kernel_size=1) @@ -358,10 +359,12 @@ def __init__( # Project each token embedding from channels dimensions to cond_map_dim dimensions. self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) - # Also project the feature network token embeddings to dimension cond_map_dim. - self.conditioning_map = torch.nn.Linear(self.cond_embedding_dim, cond_map_dim) + # Also project the text conditioning embeddings to dimension cond_map_dim. + self.c_text_map = torch.nn.Linear(self.c_text_embedding_dim, cond_map_dim) + if self.c_img_embedding_dim is not None: + self.c_img_map = torch.nn.Linear(self.c_img_embedding_dim, cond_map_dim) - def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, c_text: torch.Tensor, c_img: Optional[torch.Tensor] = None) -> torch.Tensor: """ Maps a 1D sequence of tokens from a feature network (e.g. ViT trained with DINO) and a conditioning embedding to per-token logits. @@ -370,8 +373,10 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: x (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): A sequence of 1D tokens (possibly intermediate) from a ViT feature neetwork. Note that the channels dim should be the same as the feature network's embedding dim. - c (`torch.Tensor` of shape `(batch_size, cond_embedding_dim)`): - A conditioning embedding representing conditioning (e.g. text) information. + c_text (`torch.Tensor` of shape `(batch_size, c_text_embedding_dim)`): + A conditioning embedding representing text conditioning information. + c_img (`torch.Tensor` of shape `(batch_size, c_img_embedding_dim)`): + A conditioning embedding representing image conditioning information. Returns: `torch.Tensor` of shape `(batch_size, sequence_length)`: batched 1D sequence of per-token logits. @@ -380,11 +385,15 @@ def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: hidden_states = self.resblock(hidden_states) out = self.cls(hidden_states) - # Project conditioning embeddings to cond_map_dim and unsqueeze in the sequence length dimension. - c = self.conditioning_map(c).unsqueeze(-1) + # Project text conditioning embedding to cond_map_dim and unsqueeze in the sequence length dimension. + c_text = self.c_text_map(c_text).unsqueeze(-1) - # Combine image features with conditioning embedding via a product. - out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + # Combine image features with text conditioning embedding via a product. + out = (out * c_text).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + + if self.c_img_embedding_dim is not None: + c_img = self.c_img_map(c_img).unsqueeze(-1) + out = (out * c_img).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out @@ -557,7 +566,6 @@ class DiscriminatorOutput(BaseOutput): # Based on ProjectedDiscriminator from the official StyleGAN-T code # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 -# TODO: implement image conditioning (see Section 3.2 of paper) class Discriminator(torch.nn.Module): """ StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). @@ -566,13 +574,17 @@ class Discriminator(torch.nn.Module): def __init__( self, pretrained_feature_network: str = "vit_small_patch16_224.dino", - cond_embedding_dim: int = 512, + c_text_embedding_dim: int = 768, + c_img_embedding_dim: Optional[int] = None, + cond_map_dim: int = 64, patch_size: List[int] = [16, 16], hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, ): super().__init__() - self.cond_embedding_dim = cond_embedding_dim + self.c_text_embedding_dim = c_text_embedding_dim + self.c_img_embedding_dim = c_img_embedding_dim + self.cond_map_dim = cond_map_dim # Frozen feature network, e.g. DINO self.feature_network = FeatureNetwork( @@ -585,7 +597,12 @@ def __init__( # Trainable discriminator heads heads = [] for i in range(self.feature_network.num_hooks): - heads.append([str(i), DiscriminatorHead(self.feature_network.embed_dim, cond_embedding_dim)]) + heads.append([ + str(i), + DiscriminatorHead( + self.feature_network.embed_dim, c_text_embedding_dim, c_img_embedding_dim, cond_map_dim + ) + ]) self.heads = torch.nn.ModuleDict(heads) def train(self, mode: bool = True): @@ -599,8 +616,9 @@ def eval(self): def forward( self, x: torch.Tensor, - c: torch.Tensor, - transform_positive=True, + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + transform_positive: bool = True, return_dict: bool = True, ): # TODO: do we need the augmentations from the original StyleGAN-T code? @@ -614,7 +632,7 @@ def forward( # Apply discriminator heads. logits = [] for k, head in self.heads.items(): - logits.append(head(features[k], c).view(x.size(0), -1)) + logits.append(head(features[k], c_text, c_img).view(x.size(0), -1)) logits = torch.cat(logits, dim=1) if not return_dict: @@ -986,6 +1004,15 @@ def parse_args(): " the DINO objective. The given identifier should be compatible with `timm.create_model`." ), ) + parser.add_argument( + "--cond_map_dim", + type=int, + default=64, + help=( + "The common dimension to which the discriminator feature network features and conditioning embeddings will" + " be projected to in the discriminator heads." + ), + ) parser.add_argument( "--weight_schedule", type=str, @@ -1354,7 +1381,8 @@ def main(args): # it might be fine? discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, - cond_embedding_dim=text_encoder_two.config.projection_dim, + c_text_embedding_dim=text_encoder_two.config.projection_dim, + cond_map_dim=args.cond_map_dim, ) # 8. Freeze teacher vae, text_encoders, and teacher_unet From fd39891c047a565a17b4a6a54963c1a746804dc8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 12:23:15 -0800 Subject: [PATCH 27/53] make style --- examples/add/train_add_distill_sd_wds.py | 14 ++++++++------ examples/add/train_add_distill_sdxl_wds.py | 14 ++++++++------ 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index d108f3444424..e9a7bff1c80c 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -579,12 +579,14 @@ def __init__( # Trainable discriminator heads heads = [] for i in range(self.feature_network.num_hooks): - heads.append([ - str(i), - DiscriminatorHead( - self.feature_network.embed_dim, c_text_embedding_dim, c_img_embedding_dim, cond_map_dim - ) - ]) + heads.append( + [ + str(i), + DiscriminatorHead( + self.feature_network.embed_dim, c_text_embedding_dim, c_img_embedding_dim, cond_map_dim + ), + ] + ) self.heads = torch.nn.ModuleDict(heads) def train(self, mode: bool = True): diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index c2fc25818ac0..b42e499ca727 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -597,12 +597,14 @@ def __init__( # Trainable discriminator heads heads = [] for i in range(self.feature_network.num_hooks): - heads.append([ - str(i), - DiscriminatorHead( - self.feature_network.embed_dim, c_text_embedding_dim, c_img_embedding_dim, cond_map_dim - ) - ]) + heads.append( + [ + str(i), + DiscriminatorHead( + self.feature_network.embed_dim, c_text_embedding_dim, c_img_embedding_dim, cond_map_dim + ), + ] + ) self.heads = torch.nn.ModuleDict(heads) def train(self, mode: bool = True): From 35e45e64a23f8b87ee0a0d3471421592d594c2a8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 14:19:40 -0800 Subject: [PATCH 28/53] Change default feature network to DINOv2 ViT-S (see section 4.1/Table 1 of paper) and make feature network patch size configurable. --- examples/add/train_add_distill_sd_wds.py | 17 ++++++++++++----- examples/add/train_add_distill_sdxl_wds.py | 17 ++++++++++++----- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index e9a7bff1c80c..b672e79a1ff8 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -478,8 +478,8 @@ class FeatureNetwork(torch.nn.Module): def __init__( self, - pretrained_feature_network: str = "vit_small_patch16_224.dino", - patch_size: List[int] = [16, 16], + pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", + patch_size: List[int] = [14, 14], hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, ): @@ -555,11 +555,11 @@ class Discriminator(torch.nn.Module): def __init__( self, - pretrained_feature_network: str = "vit_small_patch16_224.dino", + pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", c_text_embedding_dim: int = 768, c_img_embedding_dim: Optional[int] = None, cond_map_dim: int = 64, - patch_size: List[int] = [16, 16], + patch_size: List[int] = [14, 14], hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, ): @@ -956,12 +956,18 @@ def parse_args(): parser.add_argument( "--pretrained_feature_network", type=str, - default="vit_small_patch16_224.dino", + default="vit_small_patch14_dinov2.lvd142m", help=( "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" " the DINO objective. The given identifier should be compatible with `timm.create_model`." ), ) + parser.add_argument( + "--feature_network_patch_size", + type=int, + default=14, + help="The patch size of the `pretrained_feature_network`." + ) parser.add_argument( "--use_pretrained_projection", action="store_true", @@ -1334,6 +1340,7 @@ def main(args): pretrained_feature_network=args.pretrained_feature_network, c_text_embedding_dim=conditioning_dim, cond_map_dim=args.cond_map_dim, + patch_size=[args.feature_network_patch_size, args.feature_network_patch_size], ) # 8. Freeze teacher vae, text_encoder, and teacher_unet diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index b42e499ca727..69ff196bcf3f 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -496,8 +496,8 @@ class FeatureNetwork(torch.nn.Module): def __init__( self, - pretrained_feature_network: str = "vit_small_patch16_224.dino", - patch_size: List[int] = [16, 16], + pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", + patch_size: List[int] = [14, 14], hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, ): @@ -573,11 +573,11 @@ class Discriminator(torch.nn.Module): def __init__( self, - pretrained_feature_network: str = "vit_small_patch16_224.dino", + pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", c_text_embedding_dim: int = 768, c_img_embedding_dim: Optional[int] = None, cond_map_dim: int = 64, - patch_size: List[int] = [16, 16], + patch_size: List[int] = [14, 14], hooks: List[int] = [2, 5, 8, 11], start_index: int = 1, ): @@ -1000,12 +1000,18 @@ def parse_args(): parser.add_argument( "--pretrained_feature_network", type=str, - default="vit_small_patch16_224.dino", + default="vit_small_patch14_dinov2.lvd142m", help=( "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" " the DINO objective. The given identifier should be compatible with `timm.create_model`." ), ) + parser.add_argument( + "--feature_network_patch_size", + type=int, + default=14, + help="The patch size of the `pretrained_feature_network`." + ) parser.add_argument( "--cond_map_dim", type=int, @@ -1385,6 +1391,7 @@ def main(args): pretrained_feature_network=args.pretrained_feature_network, c_text_embedding_dim=text_encoder_two.config.projection_dim, cond_map_dim=args.cond_map_dim, + patch_size=[args.feature_network_patch_size, args.feature_network_patch_size], ) # 8. Freeze teacher vae, text_encoders, and teacher_unet From d7459b07c432f4767498dc7f4dc04f0ff2a2361e Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 14:20:12 -0800 Subject: [PATCH 29/53] make style --- examples/add/train_add_distill_sd_wds.py | 2 +- examples/add/train_add_distill_sdxl_wds.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index b672e79a1ff8..ba2be711890b 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -966,7 +966,7 @@ def parse_args(): "--feature_network_patch_size", type=int, default=14, - help="The patch size of the `pretrained_feature_network`." + help="The patch size of the `pretrained_feature_network`.", ) parser.add_argument( "--use_pretrained_projection", diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 69ff196bcf3f..db3e6fd99e91 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1010,7 +1010,7 @@ def parse_args(): "--feature_network_patch_size", type=int, default=14, - help="The patch size of the `pretrained_feature_network`." + help="The patch size of the `pretrained_feature_network`.", ) parser.add_argument( "--cond_map_dim", From 266474e8c58fcde585971881787e0461370cc374 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 18:37:39 -0800 Subject: [PATCH 30/53] Add option to use image conditioning in discriminator. --- examples/add/train_add_distill_sd_wds.py | 155 +++++++++++++++++---- examples/add/train_add_distill_sdxl_wds.py | 146 ++++++++++++++++--- 2 files changed, 255 insertions(+), 46 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index ba2be711890b..b28b77085a21 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and import argparse +import copy import functools import gc import itertools @@ -121,6 +122,24 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples +def resolve_interpolation_mode(interpolation_type): + if interpolation_type == "bilinear": + interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + + return interpolation_mode + + class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -154,30 +173,26 @@ def __init__( shuffle_buffer_size: int = 1000, pin_memory: bool = False, persistent_workers: bool = False, + use_image_conditioning: bool = False, + cond_resolution: Optional[int] = None, + cond_interpolation_type: Optional[str] = None, ): if not isinstance(train_shards_path_or_url, str): train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] # flatten list using itertools train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) - if interpolation_type == "bilinear": - self.interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - self.interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - self.interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - self.interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + if use_image_conditioning: + cond_interpolation_mode = resolve_interpolation_mode(cond_interpolation_type) def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=self.interpolation_mode) + if use_image_conditioning: + cond_image = copy.deepcopy(image) + + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -186,6 +201,14 @@ def transform(example): image = TF.normalize(image, [0.5], [0.5]) example["image"] = image + + if use_image_conditioning: + # Prepare a separate image for image conditioning since the preprocessing pipelines are different. + cond_image = TF.resize(cond_image, cond_resolution, interpolation=cond_interpolation_mode) + cond_image = TF.center_crop(cond_image, output_size=(cond_resolution, cond_resolution)) + cond_image = TF.normalize(cond_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + example["cond_image"] = cond_image + return example processing_pipeline = [ @@ -193,9 +216,13 @@ def transform(example): wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue), wds.map(filter_keys({"image", "text"})), wds.map(transform), - wds.to_tuple("image", "text"), ] + if use_image_conditioning: + processing_pipeline.append(wds.to_tuple("image", "text", "cond_image")) + else: + processing_pipeline.append(wds.to_tuple("image", "text")) + # Create train dataset and loader pipeline = [ wds.ResampledShards(train_shards_path_or_url), @@ -968,6 +995,15 @@ def parse_args(): default=14, help="The patch size of the `pretrained_feature_network`.", ) + parser.add_argument( + "--cond_map_dim", + type=int, + default=64, + help=( + "The common dimension to which the discriminator feature network features and conditioning embeddings will" + " be projected to in the discriminator heads." + ), + ) parser.add_argument( "--use_pretrained_projection", action="store_true", @@ -987,12 +1023,36 @@ def parse_args(): ), ) parser.add_argument( - "--cond_map_dim", + "--use_image_conditioning", + action="store_true", + help=( + "Whether to also use an image encoder to calculate image conditioning embeddings for the discriminator. If" + " set, the model at the timm model id given in `image_encoder_with_proj` will be used." + ), + ) + parser.add_argument( + "--pretrained_image_encoder", + type=str, + default="vit_large_patch14_dinov2.lvd142m", + help=( + "An optional image encoder to add image conditioning information to the discriminator. Is used if" + " `use_image_conditioning` is set. The model id should be loadable by `timm.create_model`. Note that ADD" + " uses a DINOv2 ViT-L encoder (see section 4 of the paper)." + ), + ) + parser.add_argument( + "--cond_resolution", type=int, - default=64, + default=518, + help="Resolution to resize the original images to for image conditioning." + ) + parser.add_argument( + "--cond_interpolation_type", + type=str, + default="bicubic", help=( - "The common dimension to which the discriminator feature network features and conditioning embeddings will" - " be projected to in the discriminator heads." + "The interpolation function used when resizing the image for conditioning. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." ), ) parser.add_argument( @@ -1186,6 +1246,12 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt return prompt_embeds, pooled_output +def encode_images(image_batch, image_encoder): + # image_encoder pre-processing is done in SDText2ImageDataset + image_embeds = image_encoder(image_batch) + return image_embeds + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -1299,6 +1365,11 @@ def main(args): args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision ) + # Optionally load a image encoder model for image conditioning of the discriminator. + if args.use_image_conditioning: + # Set num_classes=0 so that we get image embeddings from image_encoder forward pass + image_encoder = timm.create_model(args.pretrained_image_encoder, pretrained=True, num_classes=0) + # 4. Load VAE from SD 1.X/2.X checkpoint vae = AutoencoderKL.from_pretrained( args.pretrained_teacher_model, @@ -1333,12 +1404,14 @@ def main(args): ) # 7. Initialize GAN discriminator. - conditioning_dim = ( + text_conditioning_dim = ( text_encoder.config.projection_dim if args.use_pretrained_projection else text_encoder.config.hidden_size ) + img_conditioning_dim = image_encoder.num_features if args.use_image_conditioning else None discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, - c_text_embedding_dim=conditioning_dim, + c_text_embedding_dim=text_conditioning_dim, + c_img_embedding_dim=img_conditioning_dim, cond_map_dim=args.cond_map_dim, patch_size=[args.feature_network_patch_size, args.feature_network_patch_size], ) @@ -1347,6 +1420,9 @@ def main(args): vae.requires_grad_(False) text_encoder.requires_grad_(False) teacher_unet.requires_grad_(False) + if args.use_image_conditioning: + image_encoder.eval() + image_encoder.requires_grad_(False) unet.train() @@ -1377,6 +1453,8 @@ def main(args): vae.to(dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype) teacher_unet.to(accelerator.device, dtype=weight_dtype) + if args.use_image_conditioning: + image_encoder.to(accelerator.device, dtype=weight_dtype) # Move target (EMA) unet to device but keep in full precision if args.use_ema: @@ -1488,6 +1566,10 @@ def compute_embeddings( ) return {"prompt_embeds": prompt_embeds, "text_embedding": text_embedding} + def compute_image_embeddings(image_batch, image_encoder): + image_embeds = encode_images(image_batch, image_encoder) + return {"image_embeds": image_embeds} + dataset = SDText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, @@ -1498,6 +1580,9 @@ def compute_embeddings( shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, + use_image_conditioning=args.use_image_conditioning, + cond_resolution=args.cond_resolution, + cond_interpolation_type=args.cond_interpolation_type, ) train_dataloader = dataset.train_dataloader @@ -1509,6 +1594,12 @@ def compute_embeddings( has_projection=args.use_pretrained_projection, ) + if args.use_image_conditioning: + compute_image_embeddings_fn = functools.partial( + compute_image_embeddings, + image_encoder=image_encoder, + ) + # 14. Create learning rate scheduler for generator and discriminator # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -1620,10 +1711,15 @@ def compute_embeddings( for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # 1. Load and process the image and text conditioning - image, text = batch + if args.use_image_conditioning: + image, text, cond_image = batch + else: + image, text = batch image = image.to(accelerator.device, non_blocking=True) encoded_text = compute_embeddings_fn(text) + if args.use_image_conditioning: + encoded_image = compute_image_embeddings_fn(cond_image) pixel_values = image.to(dtype=weight_dtype) if vae.dtype != weight_dtype: @@ -1655,6 +1751,15 @@ def compute_embeddings( # 4. Prepare prompt embeds (for teacher/student U-Net) and text embedding (for discriminator). prompt_embeds = encoded_text.pop("prompt_embeds") text_embedding = encoded_text.pop("text_embedding") + image_embedding = None + if args.use_image_conditioning: + image_embedding = encoded_image.pop("image_embeds") + # Only supply image conditioning when student timestep is not last training timestep T. + image_embedding = torch.where( + student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1, + torch.zeros_like(image_embedding), + image_embedding, + ) # 5. Get the student model predicted original sample `student_x_0`. student_noise_pred = unet( @@ -1706,8 +1811,8 @@ def compute_embeddings( student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(pixel_values.float(), text_embedding) - disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding) + disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) # 3. Calculate the discriminator real adversarial loss terms. d_logits_real = disc_output_real.logits @@ -1748,7 +1853,7 @@ def compute_embeddings( optimizer.zero_grad(set_to_none=True) # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator - disc_output_fake = discriminator(student_gen_image, text_embedding) + disc_output_fake = discriminator(student_gen_image, text_embedding, image_embedding) # 2. Calculate generator adversarial loss term g_logits_fake = disc_output_fake.logits diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index db3e6fd99e91..29245f669e94 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -127,6 +127,24 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples +def resolve_interpolation_mode(interpolation_type): + if interpolation_type == "bilinear": + interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + + return interpolation_mode + + class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -161,6 +179,9 @@ def __init__( pin_memory: bool = False, persistent_workers: bool = False, use_fix_crop_and_size: bool = False, + use_image_conditioning: bool = False, + cond_resolution: Optional[int] = None, + cond_interpolation_type: Optional[str] = None, ): if not isinstance(train_shards_path_or_url, str): train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] @@ -173,24 +194,17 @@ def get_orig_size(json): else: return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) - if interpolation_type == "bilinear": - self.interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - self.interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - self.interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - self.interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) + interpolation_mode = resolve_interpolation_mode(interpolation_type) + if use_image_conditioning: + cond_interpolation_mode = resolve_interpolation_mode(cond_interpolation_type) def transform(example): # resize image image = example["image"] - image = TF.resize(image, resolution, interpolation=self.interpolation_mode) + if use_image_conditioning: + cond_image = copy.deepcopy(image) + + image = TF.resize(image, resolution, interpolation=interpolation_mode) # get crop coordinates and crop image c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) @@ -200,6 +214,14 @@ def transform(example): example["image"] = image example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0) + + if use_image_conditioning: + # Prepare a separate image for image conditioning since the preprocessing pipelines are different. + cond_image = TF.resize(cond_image, cond_resolution, interpolation=cond_interpolation_mode) + cond_image = TF.center_crop(cond_image, output_size=(cond_resolution, cond_resolution)) + cond_image = TF.normalize(cond_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + example["cond_image"] = cond_image + return example processing_pipeline = [ @@ -210,9 +232,13 @@ def transform(example): wds.map(filter_keys({"image", "text", "orig_size"})), wds.map_dict(orig_size=get_orig_size), wds.map(transform), - wds.to_tuple("image", "text", "orig_size", "crop_coords"), ] + if use_image_conditioning: + processing_pipeline.append(wds.to_tuple("image", "text", "orig_size", "crop_coords", "cond_image")) + else: + processing_pipeline.append(wds.to_tuple("image", "text", "orig_size", "crop_coords")) + # Create train dataset and loader pipeline = [ wds.ResampledShards(train_shards_path_or_url), @@ -1021,6 +1047,39 @@ def parse_args(): " be projected to in the discriminator heads." ), ) + parser.add_argument( + "--use_image_conditioning", + action="store_true", + help=( + "Whether to also use an image encoder to calculate image conditioning embeddings for the discriminator. If" + " set, the model at the timm model id given in `image_encoder_with_proj` will be used." + ), + ) + parser.add_argument( + "--pretrained_image_encoder", + type=str, + default="vit_large_patch14_dinov2.lvd142m", + help=( + "An optional image encoder to add image conditioning information to the discriminator. Is used if" + " `use_image_conditioning` is set. The model id should be loadable by `timm.create_model`. Note that ADD" + " uses a DINOv2 ViT-L encoder (see section 4 of the paper)." + ), + ) + parser.add_argument( + "--cond_resolution", + type=int, + default=518, + help="Resolution to resize the original images to for image conditioning." + ) + parser.add_argument( + "--cond_interpolation_type", + type=str, + default="bicubic", + help=( + "The interpolation function used when resizing the image for conditioning. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." + ), + ) parser.add_argument( "--weight_schedule", type=str, @@ -1220,6 +1279,12 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom return prompt_embeds, pooled_prompt_embeds +def encode_images(image_batch, image_encoder): + # image_encoder pre-processing is done in SDText2ImageDataset + image_embeds = image_encoder(image_batch) + return image_embeds + + def main(args): logging_dir = Path(args.output_dir, args.logging_dir) @@ -1344,6 +1409,11 @@ def main(args): args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision ) + # Optionally load a image encoder model for image conditioning of the discriminator. + if args.use_image_conditioning: + # Set num_classes=0 so that we get image embeddings from image_encoder forward pass + image_encoder = timm.create_model(args.pretrained_image_encoder, pretrained=True, num_classes=0) + # 4. Load VAE from SD-XL checkpoint (or more stable VAE) vae_path = ( args.pretrained_teacher_model @@ -1387,10 +1457,12 @@ def main(args): # already a ClipTextModelWithProjection) # TODO: what if there's no text_encoder_two? I think we already assume text_encoder_two exists in Step 3 above so # it might be fine? + text_conditioning_dim = text_encoder_two.config.projection_dim + img_conditioning_dim = image_encoder.num_features if args.use_image_conditioning else None discriminator = Discriminator( pretrained_feature_network=args.pretrained_feature_network, - c_text_embedding_dim=text_encoder_two.config.projection_dim, - cond_map_dim=args.cond_map_dim, + c_text_embedding_dim=text_conditioning_dim, + c_img_embedding_dim=img_conditioning_dim, patch_size=[args.feature_network_patch_size, args.feature_network_patch_size], ) @@ -1399,6 +1471,9 @@ def main(args): text_encoder_one.requires_grad_(False) text_encoder_two.requires_grad_(False) teacher_unet.requires_grad_(False) + if args.use_image_conditioning: + image_encoder.eval() + image_encoder.requires_grad_(False) unet.train() @@ -1430,6 +1505,8 @@ def main(args): text_encoder_one.to(accelerator.device, dtype=weight_dtype) text_encoder_two.to(accelerator.device, dtype=weight_dtype) teacher_unet.to(accelerator.device, dtype=weight_dtype) + if args.use_image_conditioning: + image_encoder.to(accelerator.device, dtype=weight_dtype) # Move target (EMA) unet to device but keep in full precision if args.use_ema: @@ -1561,6 +1638,10 @@ def compute_embeddings( return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + def compute_image_embeddings(image_batch, image_encoder): + image_embeds = encode_images(image_batch, image_encoder) + return {"image_embeds": image_embeds} + dataset = SDXLText2ImageDataset( train_shards_path_or_url=args.train_shards_path_or_url, num_train_examples=args.max_train_samples, @@ -1572,6 +1653,9 @@ def compute_embeddings( pin_memory=True, persistent_workers=True, use_fix_crop_and_size=args.use_fix_crop_and_size, + use_image_conditioning=args.use_image_conditioning, + cond_resolution=args.cond_resolution, + cond_interpolation_type=args.cond_interpolation_type, ) train_dataloader = dataset.train_dataloader @@ -1587,6 +1671,12 @@ def compute_embeddings( tokenizers=tokenizers, ) + if args.use_image_conditioning: + compute_image_embeddings_fn = functools.partial( + compute_image_embeddings, + image_encoder=image_encoder, + ) + # 14. Create learning rate scheduler for generator and discriminator # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -1697,10 +1787,15 @@ def compute_embeddings( for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates) - image, text, orig_size, crop_coords = batch + if args.use_image_conditioning: + image, text, orig_size, crop_coords, cond_image = batch + else: + image, text, orig_size, crop_coords = batch image = image.to(accelerator.device, non_blocking=True) encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) + if args.use_image_conditioning: + encoded_image = compute_image_embeddings_fn(cond_image) if args.pretrained_vae_model_name_or_path is not None: pixel_values = image.to(dtype=weight_dtype) @@ -1736,6 +1831,15 @@ def compute_embeddings( # 4. Prepare prompt embeds (for teacher/student U-Net) and text embedding (for discriminator). prompt_embeds = encoded_text.pop("prompt_embeds") text_embedding = encoded_text["text_embeds"] + image_embedding = None + if args.use_image_conditioning: + image_embedding = encoded_image.pop("image_embeds") + # Only supply image conditioning when student timestep is not last training timestep T. + image_embedding = torch.where( + student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1, + torch.zeros_like(image_embedding), + image_embedding, + ) # 5. Get the student model predicted original sample `student_x_0`. student_noise_pred = unet( @@ -1798,8 +1902,8 @@ def compute_embeddings( student_gen_image = vae.decode(unscaled_student_x_0).sample # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(pixel_values.float(), text_embedding) - disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding) + disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) # 3. Calculate the discriminator real adversarial loss terms. d_logits_real = disc_output_real.logits @@ -1840,7 +1944,7 @@ def compute_embeddings( optimizer.zero_grad(set_to_none=True) # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator - disc_output_fake = discriminator(student_gen_image, text_embedding) + disc_output_fake = discriminator(student_gen_image, text_embedding, image_embedding) # 2. Calculate generator adversarial loss term g_logits_fake = disc_output_fake.logits From 1773de31e25adb8265043f53e4e331001182e45a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 18:38:07 -0800 Subject: [PATCH 31/53] make style --- examples/add/train_add_distill_sd_wds.py | 2 +- examples/add/train_add_distill_sdxl_wds.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index b28b77085a21..2b65b74cb7de 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1044,7 +1044,7 @@ def parse_args(): "--cond_resolution", type=int, default=518, - help="Resolution to resize the original images to for image conditioning." + help="Resolution to resize the original images to for image conditioning.", ) parser.add_argument( "--cond_interpolation_type", diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 29245f669e94..7c2e4b233e9e 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1069,7 +1069,7 @@ def parse_args(): "--cond_resolution", type=int, default=518, - help="Resolution to resize the original images to for image conditioning." + help="Resolution to resize the original images to for image conditioning.", ) parser.add_argument( "--cond_interpolation_type", From d5b96194d1f528a81fc91464291db7e61bce000a Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 18:59:20 -0800 Subject: [PATCH 32/53] In discriminator heads, concatenate conditioning embeddings and map and take product once rather than having separate maps and mapping and taking the product twice. --- examples/add/train_add_distill_sd_wds.py | 23 +++++++++++----------- examples/add/train_add_distill_sdxl_wds.py | 23 +++++++++++----------- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 2b65b74cb7de..6246043338d8 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -368,10 +368,11 @@ def __init__( # Project each token embedding from channels dimensions to cond_map_dim dimensions. self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) - # Also project the text conditioning embeddings to dimension cond_map_dim. - self.c_text_map = torch.nn.Linear(self.c_text_embedding_dim, cond_map_dim) + # Also project the concatenated conditioning embeddings to dimension cond_map_dim. + c_map_input_dim = self.c_text_embedding_dim if self.c_img_embedding_dim is not None: - self.c_img_map = torch.nn.Linear(self.c_img_embedding_dim, cond_map_dim) + c_map_input_dim += self.c_img_embedding_dim + self.c_map = torch.nn.Linear(c_map_input_dim, cond_map_dim) def forward(self, x: torch.Tensor, c_text: torch.Tensor, c_img: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -394,15 +395,15 @@ def forward(self, x: torch.Tensor, c_text: torch.Tensor, c_img: Optional[torch.T hidden_states = self.resblock(hidden_states) out = self.cls(hidden_states) - # Project text conditioning embedding to cond_map_dim and unsqueeze in the sequence length dimension. - c_text = self.c_text_map(c_text).unsqueeze(-1) - - # Combine image features with text conditioning embedding via a product. - out = (out * c_text).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) - if self.c_img_embedding_dim is not None: - c_img = self.c_img_map(c_img).unsqueeze(-1) - out = (out * c_img).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + c = torch.cat([c_text, c_img], dim=1) + else: + c = c_text + # Project conditioning embedding to cond_map_dim and unsqueeze in the sequence length dimension. + c = self.c_map(c).unsqueeze(-1) + + # Combine image features with projected conditioning embedding via a product. + out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 7c2e4b233e9e..2a5c52d75229 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -385,10 +385,11 @@ def __init__( # Project each token embedding from channels dimensions to cond_map_dim dimensions. self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) - # Also project the text conditioning embeddings to dimension cond_map_dim. - self.c_text_map = torch.nn.Linear(self.c_text_embedding_dim, cond_map_dim) + # Also project the concatenated conditioning embeddings to dimension cond_map_dim. + c_map_input_dim = self.c_text_embedding_dim if self.c_img_embedding_dim is not None: - self.c_img_map = torch.nn.Linear(self.c_img_embedding_dim, cond_map_dim) + c_map_input_dim += self.c_img_embedding_dim + self.c_map = torch.nn.Linear(c_map_input_dim, cond_map_dim) def forward(self, x: torch.Tensor, c_text: torch.Tensor, c_img: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -411,15 +412,15 @@ def forward(self, x: torch.Tensor, c_text: torch.Tensor, c_img: Optional[torch.T hidden_states = self.resblock(hidden_states) out = self.cls(hidden_states) - # Project text conditioning embedding to cond_map_dim and unsqueeze in the sequence length dimension. - c_text = self.c_text_map(c_text).unsqueeze(-1) - - # Combine image features with text conditioning embedding via a product. - out = (out * c_text).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) - if self.c_img_embedding_dim is not None: - c_img = self.c_img_map(c_img).unsqueeze(-1) - out = (out * c_img).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + c = torch.cat([c_text, c_img], dim=1) + else: + c = c_text + # Project conditioning embedding to cond_map_dim and unsqueeze in the sequence length dimension. + c = self.c_map(c).unsqueeze(-1) + + # Combine image features with projected conditioning embedding via a product. + out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) return out From b54c664818a150cb7e945ba6ae233a571582142e Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 30 Dec 2023 19:43:46 -0800 Subject: [PATCH 33/53] Fix bug when calculating image conditioning embeddings. --- examples/add/train_add_distill_sd_wds.py | 2 +- examples/add/train_add_distill_sdxl_wds.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 6246043338d8..49aaa12a9460 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1758,8 +1758,8 @@ def compute_image_embeddings(image_batch, image_encoder): # Only supply image conditioning when student timestep is not last training timestep T. image_embedding = torch.where( student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1, - torch.zeros_like(image_embedding), image_embedding, + torch.zeros_like(image_embedding), ) # 5. Get the student model predicted original sample `student_x_0`. diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 2a5c52d75229..9f0b16b87abe 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1838,8 +1838,8 @@ def compute_image_embeddings(image_batch, image_encoder): # Only supply image conditioning when student timestep is not last training timestep T. image_embedding = torch.where( student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1, - torch.zeros_like(image_embedding), image_embedding, + torch.zeros_like(image_embedding), ) # 5. Get the student model predicted original sample `student_x_0`. From 275cc8a3ece8feae0f2cd228c499083acc35fe17 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 1 Jan 2024 11:15:24 -0800 Subject: [PATCH 34/53] Add batched VAE decoding and make VAE encoding/decoding batch size configurable. --- examples/add/train_add_distill_sd_wds.py | 33 +++++++++++--- examples/add/train_add_distill_sdxl_wds.py | 52 ++++++++++++++++------ 2 files changed, 67 insertions(+), 18 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 49aaa12a9460..c271788c9ef2 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1119,6 +1119,16 @@ def parse_args(): " compared to the original paper." ), ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=32, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." @@ -1726,10 +1736,10 @@ def compute_image_embeddings(image_batch, image_encoder): if vae.dtype != weight_dtype: vae.to(dtype=weight_dtype) - # encode pixel values with batch size of at most 32 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 32): - latents.append(vae.encode(pixel_values[i : i + 32]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1809,7 +1819,13 @@ def compute_image_embeddings(image_batch, image_encoder): # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the # pretrained feature network for the discriminator operates in pixel space rather than latent space. unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 - student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample + # Perform batched decode with batch size of at most args.vae_encode_batch_size + student_gen_image = [] + for i in range(0, unscaled_student_x_0.shape[0], args.vae_encode_batch_size): + student_gen_image.append( + vae.decode(unscaled_student_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype)).sample + ) + student_gen_image = torch.cat(student_gen_image, dim=0) # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) @@ -1865,7 +1881,14 @@ def compute_image_embeddings(image_batch, image_encoder): ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 - teacher_gen_image = vae.decode(unscaled_teacher_x_0.to(dtype=weight_dtype)).sample + # Perform batched decode with batch size of at most args.vae_encode_batch_size + teacher_gen_image = [] + for i in range(0, unscaled_teacher_x_0.shape[0], args.vae_encode_batch_size): + teacher_gen_image.append( + vae.decode(unscaled_teacher_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype)).sample + ) + teacher_gen_image = torch.cat(teacher_gen_image, dim=0) + per_instance_distillation_loss = F.mse_loss( student_gen_image.float(), teacher_gen_image.float(), reduction="none" ) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 9f0b16b87abe..c3270730b8df 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1144,6 +1144,16 @@ def parse_args(): " compared to the original paper." ), ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." @@ -1805,10 +1815,10 @@ def compute_image_embeddings(image_batch, image_encoder): else: pixel_values = image - # encode pixel values with batch size of at most 8 + # encode pixel values with batch size of at most args.vae_encode_batch_size latents = [] - for i in range(0, pixel_values.shape[0], 8): - latents.append(vae.encode(pixel_values[i : i + 8]).latent_dist.sample()) + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) latents = torch.cat(latents, dim=0) latents = latents * vae.config.scaling_factor @@ -1896,11 +1906,19 @@ def compute_image_embeddings(image_batch, image_encoder): # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the # pretrained feature network for the discriminator operates in pixel space rather than latent space. unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 - if args.pretrained_vae_model_name_or_path is not None: - student_gen_image = vae.decode(unscaled_student_x_0.to(dtype=weight_dtype)).sample - else: - # VAE is in full precision due to possible NaN issues - student_gen_image = vae.decode(unscaled_student_x_0).sample + student_gen_image = [] + # Perform batched decode with batch size of at most args.vae_encode_batch_size + for i in range(0, unscaled_student_x_0.shape[0], args.vae_encode_batch_size): + if args.pretrained_vae_model_name_or_path: + student_gen_image.append( + vae.decode(unscaled_student_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype)).sample + ) + else: + # VAE is in full precision due to possible NaN issues + student_gen_image.append( + vae.decode(unscaled_student_x_0[i : i + args.vae_encode_batch_size]).sample + ) + student_gen_image = torch.cat(student_gen_image, dim=0) # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) @@ -1956,11 +1974,19 @@ def compute_image_embeddings(image_batch, image_encoder): ############################ # Calculate distillation loss in pixel space rather than latent space (see section 3.1) unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 - if args.pretrained_vae_model_name_or_path is not None: - teacher_gen_image = vae.decode(unscaled_teacher_x_0.to(dtype=weight_dtype)).sample - else: - # VAE is in full precision due to possible NaN issues - teacher_gen_image = vae.decode(unscaled_teacher_x_0).sample + teacher_gen_image = [] + # Perform batched decode with batch size of at most args.vae_encode_batch_size + for i in range(0, unscaled_teacher_x_0.shape[0], args.vae_encode_batch_size): + if args.pretrained_vae_model_name_or_path: + teacher_gen_image.append( + vae.decode(unscaled_teacher_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype)).sample + ) + else: + # VAE is in full precision due to possible NaN issues + teacher_gen_image.append( + vae.decode(unscaled_teacher_x_0[i : i + args.vae_encode_batch_size]).sample + ) + teacher_gen_image = torch.cat(teacher_gen_image, dim=0) per_instance_distillation_loss = F.mse_loss( student_gen_image.float(), teacher_gen_image.float(), reduction="none" From aa9cc411e83495abbe7616a42c54a710c707a3bd Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 1 Jan 2024 11:16:11 -0800 Subject: [PATCH 35/53] make style --- examples/add/train_add_distill_sd_wds.py | 8 ++++++-- examples/add/train_add_distill_sdxl_wds.py | 8 ++++++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index c271788c9ef2..cb6ada9d018f 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1823,7 +1823,9 @@ def compute_image_embeddings(image_batch, image_encoder): student_gen_image = [] for i in range(0, unscaled_student_x_0.shape[0], args.vae_encode_batch_size): student_gen_image.append( - vae.decode(unscaled_student_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype)).sample + vae.decode( + unscaled_student_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype) + ).sample ) student_gen_image = torch.cat(student_gen_image, dim=0) @@ -1885,7 +1887,9 @@ def compute_image_embeddings(image_batch, image_encoder): teacher_gen_image = [] for i in range(0, unscaled_teacher_x_0.shape[0], args.vae_encode_batch_size): teacher_gen_image.append( - vae.decode(unscaled_teacher_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype)).sample + vae.decode( + unscaled_teacher_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype) + ).sample ) teacher_gen_image = torch.cat(teacher_gen_image, dim=0) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index c3270730b8df..54394a76213d 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1911,7 +1911,9 @@ def compute_image_embeddings(image_batch, image_encoder): for i in range(0, unscaled_student_x_0.shape[0], args.vae_encode_batch_size): if args.pretrained_vae_model_name_or_path: student_gen_image.append( - vae.decode(unscaled_student_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype)).sample + vae.decode( + unscaled_student_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype) + ).sample ) else: # VAE is in full precision due to possible NaN issues @@ -1979,7 +1981,9 @@ def compute_image_embeddings(image_batch, image_encoder): for i in range(0, unscaled_teacher_x_0.shape[0], args.vae_encode_batch_size): if args.pretrained_vae_model_name_or_path: teacher_gen_image.append( - vae.decode(unscaled_teacher_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype)).sample + vae.decode( + unscaled_teacher_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype) + ).sample ) else: # VAE is in full precision due to possible NaN issues From da6c5acd317b7078d76b2ff99c97612f37e08e0b Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 1 Jan 2024 11:33:44 -0800 Subject: [PATCH 36/53] Change validation prompts to example prompts from the ADD paper. --- examples/add/train_add_distill_sd_wds.py | 8 ++++---- examples/add/train_add_distill_sdxl_wds.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index cb6ada9d018f..47e6a4960072 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -676,10 +676,10 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="stude generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) validation_prompts = [ - "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", - "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", - "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", - "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", + "A cinematic shot of robot with colorful feathers.", + "Teddy bears working on new AI research on the moon in the 1980s.", + "A robot is playing the guitar at a rock concert in front of a large crowd.", + "A photo of an astronaut riding a horse in the forest. There is a river in front of them with water lilies.", ] image_logs = [] diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 54394a76213d..e6e923b2cc50 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -693,10 +693,10 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="stude generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) validation_prompts = [ - "portrait photo of a girl, photograph, highly detailed face, depth of field, moody light, golden hour, style by Dan Winters, Russell James, Steve McCurry, centered, extremely detailed, Nikon D850, award winning photography", - "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", - "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", - "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece", + "A cinematic shot of robot with colorful feathers.", + "Teddy bears working on new AI research on the moon in the 1980s.", + "A robot is playing the guitar at a rock concert in front of a large crowd.", + "A photo of an astronaut riding a horse in the forest. There is a river in front of them with water lilies.", ] image_logs = [] From db358dd6aa2d840220867f57a10459b061578a81 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 1 Jan 2024 12:35:18 -0800 Subject: [PATCH 37/53] Fix interpolation type bug. --- examples/add/train_add_distill_sd_wds.py | 1 + examples/add/train_add_distill_sdxl_wds.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 47e6a4960072..fc0123f2a6af 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1588,6 +1588,7 @@ def compute_image_embeddings(image_batch, image_encoder): global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index e6e923b2cc50..02a3f69ba36a 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1660,6 +1660,7 @@ def compute_image_embeddings(image_batch, image_encoder): global_batch_size=args.train_batch_size * accelerator.num_processes, num_workers=args.dataloader_num_workers, resolution=args.resolution, + interpolation_type=args.interpolation_type, shuffle_buffer_size=1000, pin_memory=True, persistent_workers=True, From 9d462592303033e180c3b32fc5d3b067c970e140 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 6 Jan 2024 18:45:07 -0800 Subject: [PATCH 38/53] [Experimental] add scripts to use ADD to distill a LoRA rather than a full model --- examples/add/train_add_distill_lora_sd_wds.py | 2093 ++++++++++++++++ .../add/train_add_distill_lora_sdxl_wds.py | 2192 +++++++++++++++++ 2 files changed, 4285 insertions(+) create mode 100644 examples/add/train_add_distill_lora_sd_wds.py create mode 100644 examples/add/train_add_distill_lora_sdxl_wds.py diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py new file mode 100644 index 000000000000..cadb9c2de64d --- /dev/null +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -0,0 +1,2093 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import functools +import gc +import itertools +import json +import logging +import math +import os +import random +import shutil +import types +from pathlib import Path +from typing import Callable, List, Optional, Union + +import accelerate +import numpy as np +import timm +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import transformers +import webdataset as wds +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from braceexpand import braceexpand +from huggingface_hub import create_repo +from packaging import version +from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import default_collate +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from webdataset.tariterators import ( + base_plus_ext, + tar_file_expander, + url_opener, + valid_sample, +) + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import BaseOutput, check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +MAX_SEQ_LENGTH = 77 + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) + + +def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"): + kohya_ss_state_dict = {} + for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items(): + kohya_key = peft_key.replace("base_model.model", prefix) + kohya_key = kohya_key.replace("lora_A", "lora_down") + kohya_key = kohya_key.replace("lora_B", "lora_up") + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_ss_state_dict[kohya_key] = weight.to(dtype) + + # Set alpha parameter + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) + + return kohya_ss_state_dict + + +def filter_keys(key_set): + def _f(dictionary): + return {k: v for k, v in dictionary.items() if k in key_set} + + return _f + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to + lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = {"__key__": prefix, "__url__": filesample["__url__"]} + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +def resolve_interpolation_mode(interpolation_type): + if interpolation_type == "bilinear": + interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + + return interpolation_mode + + +class WebdatasetFilter: + def __init__(self, min_size=1024, max_pwatermark=0.5): + self.min_size = min_size + self.max_pwatermark = max_pwatermark + + def __call__(self, x): + try: + if "json" in x: + x_json = json.loads(x["json"]) + filter_size = (x_json.get("original_width", 0.0) or 0.0) >= self.min_size and x_json.get( + "original_height", 0 + ) >= self.min_size + filter_watermark = (x_json.get("pwatermark", 1.0) or 1.0) <= self.max_pwatermark + return filter_size and filter_watermark + else: + return False + except Exception: + return False + + +class SDText2ImageDataset: + def __init__( + self, + train_shards_path_or_url: Union[str, List[str]], + num_train_examples: int, + per_gpu_batch_size: int, + global_batch_size: int, + num_workers: int, + resolution: int = 512, + interpolation_type: str = "bilinear", + shuffle_buffer_size: int = 1000, + pin_memory: bool = False, + persistent_workers: bool = False, + use_image_conditioning: bool = False, + cond_resolution: Optional[int] = None, + cond_interpolation_type: Optional[str] = None, + ): + if not isinstance(train_shards_path_or_url, str): + train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] + # flatten list using itertools + train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + + interpolation_mode = resolve_interpolation_mode(interpolation_type) + if use_image_conditioning: + cond_interpolation_mode = resolve_interpolation_mode(cond_interpolation_type) + + def transform(example): + # resize image + image = example["image"] + if use_image_conditioning: + cond_image = copy.deepcopy(image) + + image = TF.resize(image, resolution, interpolation=interpolation_mode) + + # get crop coordinates and crop image + c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) + image = TF.crop(image, c_top, c_left, resolution, resolution) + image = TF.to_tensor(image) + image = TF.normalize(image, [0.5], [0.5]) + + example["image"] = image + + if use_image_conditioning: + # Prepare a separate image for image conditioning since the preprocessing pipelines are different. + cond_image = TF.resize(cond_image, cond_resolution, interpolation=cond_interpolation_mode) + cond_image = TF.center_crop(cond_image, output_size=(cond_resolution, cond_resolution)) + cond_image = TF.normalize(cond_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + example["cond_image"] = cond_image + + return example + + processing_pipeline = [ + wds.decode("pil", handler=wds.ignore_and_continue), + wds.rename(image="jpg;png;jpeg;webp", text="text;txt;caption", handler=wds.warn_and_continue), + wds.map(filter_keys({"image", "text"})), + wds.map(transform), + ] + + if use_image_conditioning: + processing_pipeline.append(wds.to_tuple("image", "text", "cond_image")) + else: + processing_pipeline.append(wds.to_tuple("image", "text")) + + # Create train dataset and loader + pipeline = [ + wds.ResampledShards(train_shards_path_or_url), + tarfile_to_samples_nothrow, + wds.shuffle(shuffle_buffer_size), + *processing_pipeline, + wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate), + ] + + num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + + # each worker is iterating over this + self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) + self._train_dataloader = wds.WebLoader( + self._train_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + # add meta-data to dataloader instance for convenience + self._train_dataloader.num_batches = num_batches + self._train_dataloader.num_samples = num_samples + + @property + def train_dataset(self): + return self._train_dataset + + @property + def train_dataloader(self): + return self._train_dataloader + + +class Denoiser: + def __init__(self, alphas, sigmas, prediction_type="epsilon"): + self.alphas = alphas + self.sigmas = sigmas + self.prediction_type = prediction_type + + def to(self, device): + self.alphas = self.alphas.to(device) + self.sigmas = self.sigmas.to(device) + return self + + def __call__(self, model_output, timesteps, sample): + alphas = extract_into_tensor(self.alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape) + if self.prediction_type == "epsilon": + pred_x_0 = (sample - sigmas * model_output) / alphas + elif self.prediction_type == "sample": + pred_x_0 = model_output + elif self.prediction_type == "v_prediction": + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError( + f"Prediction type {self.prediction_type} is not supported; currently, `epsilon`, `sample`, and" + f" `v_prediction` are supported." + ) + + return pred_x_0 + + +# Based on SpectralConv1d from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L29 +class SpectralConv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + torch.nn.utils.parametrizations.spectral_norm(self, name="weight", n_power_iterations=1, eps=1e-12, dim=0) + + +# Based on ResidualBlock from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/shared.py#L20 +class ResidualBlock(torch.nn.Module): + def __init__(self, fn: Callable): + super().__init__() + self.fn = fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (self.fn(x) + x) / np.sqrt(2) + + +# Based on make_block from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L64 +class DiscHeadBlock(torch.nn.Module): + """ + StyleGAN-T block: SpectralConv1d => GroupNorm => LeakyReLU + """ + + def __init__( + self, + channels: int, + kernel_size: int, + num_groups: int = 8, + leaky_relu_neg_slope: float = 0.2, + ): + super().__init__() + self.channels = channels + + self.conv = SpectralConv1d( + channels, + channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode="circular", + ) + self.norm = torch.nn.GroupNorm(num_groups, channels) + self.act_fn = torch.nn.LeakyReLU(leaky_relu_neg_slope, inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.norm(x) + x = self.act_fn(x) + return x + + +# Based on DiscHead in the official StyleGAN-T implementation +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L78 +class DiscriminatorHead(torch.nn.Module): + """ + Implements a StyleGAN-T-style discriminator head. The discriminator head takes in a (possibly intermediate) 1D + sequence of tokens from the feature network, processes it, and combines it with conditioning information to output + per-token logits. + """ + + def __init__( + self, + channels: int, + c_text_embedding_dim: int, + c_img_embedding_dim: Optional[int] = None, + cond_map_dim: int = 64, + ): + super().__init__() + self.channels = channels + self.c_text_embedding_dim = c_text_embedding_dim + self.c_img_embedding_dim = c_img_embedding_dim + self.cond_map_dim = cond_map_dim + + self.input_block = DiscHeadBlock(channels, kernel_size=1) + self.resblock = ResidualBlock(DiscHeadBlock(channels, kernel_size=9)) + # Project each token embedding from channels dimensions to cond_map_dim dimensions. + self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) + + # Also project the concatenated conditioning embeddings to dimension cond_map_dim. + c_map_input_dim = self.c_text_embedding_dim + if self.c_img_embedding_dim is not None: + c_map_input_dim += self.c_img_embedding_dim + self.c_map = torch.nn.Linear(c_map_input_dim, cond_map_dim) + + def forward(self, x: torch.Tensor, c_text: torch.Tensor, c_img: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Maps a 1D sequence of tokens from a feature network (e.g. ViT trained with DINO) and a conditioning embedding + to per-token logits. + + Args: + x (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + A sequence of 1D tokens (possibly intermediate) from a ViT feature neetwork. Note that the channels dim + should be the same as the feature network's embedding dim. + c_text (`torch.Tensor` of shape `(batch_size, c_text_embedding_dim)`): + A conditioning embedding representing text conditioning information. + c_img (`torch.Tensor` of shape `(batch_size, c_img_embedding_dim)`): + A conditioning embedding representing image conditioning information. + + Returns: + `torch.Tensor` of shape `(batch_size, sequence_length)`: batched 1D sequence of per-token logits. + """ + hidden_states = self.input_block(x) + hidden_states = self.resblock(hidden_states) + out = self.cls(hidden_states) + + if self.c_img_embedding_dim is not None: + c = torch.cat([c_text, c_img], dim=1) + else: + c = c_text + # Project conditioning embedding to cond_map_dim and unsqueeze in the sequence length dimension. + c = self.c_map(c).unsqueeze(-1) + + # Combine image features with projected conditioning embedding via a product. + out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + + return out + + +activations = {} + + +# Based on get_activation from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L111 +def get_activation(name: str) -> Callable: + def hook(model, input, output): + activations[name] = output + + return hook + + +# Based on _resize_pos_embed from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L66 +def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor: + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +# Based on forward_flex from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L83 +def forward_flex(self, x: torch.Tensor) -> torch.Tensor: + # patch proj and dynamically resize + B, C, H, W = x.size() + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + pos_embed = self._resize_pos_embed(self.pos_embed, H // self.patch_size[1], W // self.patch_size[0]) + + # add cls token + cls_tokens = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # forward pass + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x + + +# Based on forward_vit from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L60 +def forward_vit(pretrained: torch.nn.Module, x: torch.Tensor) -> dict: + _ = pretrained.model.forward_flex(x) + return {k: pretrained.rearrange(v) for k, v in activations.items()} + + +# Based on AddReadout from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L36 +class AddReadout(torch.nn.Module): + def __init__(self, start_index: int = 1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +# Based on Transpose from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L49 +class Transpose(torch.nn.Module): + def __init__(self, dim0: int, dim1: int): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.transpose(self.dim0, self.dim1) + return x.contiguous() + + +# Based on DINO from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L107 +class FeatureNetwork(torch.nn.Module): + """ + DINO ViT model to act as feature extractor for the discriminator. + """ + + def __init__( + self, + pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", + patch_size: List[int] = [14, 14], + hooks: List[int] = [2, 5, 8, 11], + start_index: int = 1, + ): + super().__init__() + self.num_hooks = len(hooks) + 1 + + pretrained_model = timm.create_model(pretrained_feature_network, pretrained=True) + + # Based on make_vit_backbone from the official StyleGAN-T code + # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L117 + # which I believe is itself based on https://github.com/isl-org/DPT + model_with_hooks = torch.nn.Module() + model_with_hooks.model = pretrained_model + + # Add hooks + model_with_hooks.model.blocks[hooks[0]].register_forward_hook(get_activation("0")) + model_with_hooks.model.blocks[hooks[1]].register_forward_hook(get_activation("1")) + model_with_hooks.model.blocks[hooks[2]].register_forward_hook(get_activation("2")) + model_with_hooks.model.blocks[hooks[3]].register_forward_hook(get_activation("3")) + model_with_hooks.model.pos_drop.register_forward_hook(get_activation("4")) + + # Configure readout + model_with_hooks.rearrange = torch.nn.Sequential(AddReadout(start_index), Transpose(1, 2)) + model_with_hooks.model.start_index = start_index + model_with_hooks.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + model_with_hooks.model.forward_flex = types.MethodType(forward_flex, model_with_hooks.model) + model_with_hooks.model._resize_pos_embed = types.MethodType(_resize_pos_embed, model_with_hooks.model) + + self.model = model_with_hooks + # Freeze pretrained model with hooks + self.model = self.model.eval().requires_grad_(False) + + self.img_resolution = self.model.model.patch_embed.img_size[0] + self.embed_dim = self.model.model.embed_dim + self.norm = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + def forward(self, x: torch.Tensor): + """ + Forward pass consisting of interpolation, ImageNet normalization, and a forward pass of self.model. + + Args: + x (`torch.Tensor`): + Image with pixel values in [0, 1]. + + Returns: + `Dict[Any]`: dict of activations which are intermediate features from the feature network. The dict values + (feature embeddings) have shape `(batch_size, embed_dim, sequence_length)`. + """ + x = F.interpolate(x, self.img_resolution, mode="area") + x = self.norm(x) + + activation_dict = forward_vit(self.model, x) + return activation_dict + + +class DiscriminatorOutput(BaseOutput): + """ + Output class for the Discriminator module. + """ + + logits: torch.FloatTensor + + +# Based on ProjectedDiscriminator from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 +class Discriminator(torch.nn.Module): + """ + StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). + """ + + def __init__( + self, + pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", + c_text_embedding_dim: int = 768, + c_img_embedding_dim: Optional[int] = None, + cond_map_dim: int = 64, + patch_size: List[int] = [14, 14], + hooks: List[int] = [2, 5, 8, 11], + start_index: int = 1, + ): + super().__init__() + self.c_text_embedding_dim = c_text_embedding_dim + self.c_img_embedding_dim = c_img_embedding_dim + self.cond_map_dim = cond_map_dim + + # Frozen feature network, e.g. DINO + self.feature_network = FeatureNetwork( + pretrained_feature_network=pretrained_feature_network, + patch_size=patch_size, + hooks=hooks, + start_index=start_index, + ) + + # Trainable discriminator heads + heads = [] + for i in range(self.feature_network.num_hooks): + heads.append( + [ + str(i), + DiscriminatorHead( + self.feature_network.embed_dim, c_text_embedding_dim, c_img_embedding_dim, cond_map_dim + ), + ] + ) + self.heads = torch.nn.ModuleDict(heads) + + def train(self, mode: bool = True): + self.feature_network = self.feature_network.train(False) + self.heads = self.heads.train(mode) + return self + + def eval(self): + return self.train(False) + + def forward( + self, + x: torch.Tensor, + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + transform_positive: bool = True, + return_dict: bool = True, + ): + # TODO: do we need the augmentations from the original StyleGAN-T code? + if transform_positive: + # Transform to [0, 1]. + x = x.add(1).div(2) + + # Forward pass through feature network. + features = self.feature_network(x) + + # Apply discriminator heads. + logits = [] + for k, head in self.heads.items(): + logits.append(head(features[k], c_text, c_img).view(x.size(0), -1)) + logits = torch.cat(logits, dim=1) + + if not return_dict: + return (logits,) + + return DiscriminatorOutput(logits=logits) + + +def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): + logger.info("Running validation... ") + + unet = accelerator.unwrap_model(unet) + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_teacher_model, + vae=vae, + unet=unet, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype) + pipeline.load_lora_weights(lora_state_dict) + pipeline.fuse_lora() + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + validation_prompts = [ + "A cinematic shot of robot with colorful feathers.", + "Teddy bears working on new AI research on the moon in the 1980s.", + "A robot is playing the guitar at a rock concert in front of a large crowd.", + "A photo of an astronaut riding a horse in the forest. There is a river in front of them with water lilies.", + ] + + image_logs = [] + + for _, prompt in enumerate(validation_prompts): + images = [] + with torch.autocast("cuda"): + images = pipeline( + prompt=prompt, + num_inference_steps=1, + num_images_per_prompt=4, + generator=generator, + ).images + image_logs.append({"validation_prompt": prompt, "images": images}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({f"validation/{name}": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +@torch.no_grad() +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + # ----------Model Checkpoint Loading Arguments---------- + parser.add_argument( + "--pretrained_teacher_model", + type=str, + default=None, + required=True, + help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--teacher_revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM model identifier from huggingface.co/models.", + ) + # ----------Training Arguments---------- + # ----General Training Arguments---- + parser.add_argument( + "--output_dir", + type=str, + default="lcm-xl-distilled", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + # ----Logging---- + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + # ----Checkpointing---- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + # ----Image Processing---- + parser.add_argument( + "--train_shards_path_or_url", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + # ----Dataloader---- + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + # ----Batch Size and Training Steps---- + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + # ----Learning Rate---- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--discriminator_learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + # ----Optimizer (Adam)---- + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--discriminator_adam_beta1", type=float, default=0.0, help="The beta1 parameter for the Adam optimizer." + ) + parser.add_argument( + "--discriminator_adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer." + ) + parser.add_argument("--discriminator_adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument( + "--discriminator_adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer" + ) + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + # ----Adversarial Diffusion Distillation (ADD) Specific Arguments---- + parser.add_argument( + "--pretrained_feature_network", + type=str, + default="vit_small_patch14_dinov2.lvd142m", + help=( + "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" + " the DINO objective. The given identifier should be compatible with `timm.create_model`." + ), + ) + parser.add_argument( + "--feature_network_patch_size", + type=int, + default=14, + help="The patch size of the `pretrained_feature_network`.", + ) + parser.add_argument( + "--cond_map_dim", + type=int, + default=64, + help=( + "The common dimension to which the discriminator feature network features and conditioning embeddings will" + " be projected to in the discriminator heads." + ), + ) + parser.add_argument( + "--use_pretrained_projection", + action="store_true", + help=( + "Whether to use a text encoder which projects the CLIP text embedding to a fixed length vector (that is, a" + " `CLIPTextModelWithProjection` rather than the `CLIPTextModel` usually used by SD 1.X/2.X.). If set, the" + " text encoder will be loaded from the model id or path given in `text_encoder_with_proj`." + ), + ) + parser.add_argument( + "--text_encoder_with_proj", + type=str, + default="openai/clip-vit-large-patch14", + help=( + "The text encoder with projection that will be used if `use_pretrained_projection` is set. Note that the" + " default value of `openai/clip-vit-large-patch14` is the CLIP model used by Stable Diffusion v1.5." + ), + ) + parser.add_argument( + "--use_image_conditioning", + action="store_true", + help=( + "Whether to also use an image encoder to calculate image conditioning embeddings for the discriminator. If" + " set, the model at the timm model id given in `image_encoder_with_proj` will be used." + ), + ) + parser.add_argument( + "--pretrained_image_encoder", + type=str, + default="vit_large_patch14_dinov2.lvd142m", + help=( + "An optional image encoder to add image conditioning information to the discriminator. Is used if" + " `use_image_conditioning` is set. The model id should be loadable by `timm.create_model`. Note that ADD" + " uses a DINOv2 ViT-L encoder (see section 4 of the paper)." + ), + ) + parser.add_argument( + "--cond_resolution", + type=int, + default=518, + help="Resolution to resize the original images to for image conditioning.", + ) + parser.add_argument( + "--cond_interpolation_type", + type=str, + default="bicubic", + help=( + "The interpolation function used when resizing the image for conditioning. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." + ), + ) + parser.add_argument( + "--weight_schedule", + type=str, + default="exponential", + help=( + "The time-dependent weighting function gamma used for scaling the distillation loss Choose between" + " `uniform`, `exponential`, `sds`, and `nfsd`." + ), + ) + parser.add_argument( + "--student_distillation_steps", + type=int, + default=4, + help="The number of student timesteps N used during distillation.", + ) + parser.add_argument( + "--student_timestep_schedule", + type=str, + default="uniform", + help="The method by which the student timestep schedule is determined. Currently, only `uniform` is implemented.", + ) + parser.add_argument( + "--student_custom_timesteps", + type=str, + default=None, + help=( + "A comma-separated list of timesteps which will override the timestep schedule set in" + " `student_timestep_schedule` if set." + ), + ) + parser.add_argument( + "--discriminator_r1_strength", + type=float, + default=1e-05, + help="The discriminator R1 gradient penalty strength gamma.", + ) + parser.add_argument( + "--distillation_weight_factor", + type=float, + default=2.5, + help="Multiplicative weight factor lambda for the distillation loss on the student generator U-Net.", + ) + parser.add_argument( + "--w_min", + type=float, + default=1.0, + required=False, + help=( + "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation." + ), + ) + parser.add_argument( + "--w_max", + type=float, + default=15.0, + required=False, + help=( + "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=32, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + # LoRA Arguments + parser.add_argument( + "--lora_rank", + type=int, + default=64, + help="The rank of the LoRA projection matrix.", + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + # ----Exponential Moving Average (EMA)---- + parser.add_argument( + "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." + ) + parser.add_argument( + "--ema_decay", + type=float, + default=0.95, + required=False, + help="The exponential moving average (EMA) rate or decay factor.", + ) + parser.add_argument( + "--ema_min_decay", + type=float, + default=None, + help=( + "The minimum EMA decay rate, which the effective EMA decay rate (e.g. if warmup is used) will never go" + " below. If not set, the value set for `ema_decay` will be used, which results in a fixed EMA decay rate" + " equal to that value." + ), + ) + # ----Mixed Precision---- + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cast_teacher_unet", + action="store_true", + help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.", + ) + # ----Training Optimizations---- + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + # ----Distributed Training---- + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + # ----------Validation Arguments---------- + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + # ----------Huggingface Hub Arguments----------- + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + # ----------Accelerate Arguments---------- + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +# Adapted from pipelines.StableDiffusionPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, has_projection, is_train=True): + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + text_model_output = text_encoder(text_input_ids.to(text_encoder.device)) + # Get last hidden states for use in conditioning the student and teacher U-Nets + prompt_embeds = text_model_output.last_hidden_state + # Get text embedding (if model has projection) or pooled output for use in conditioning the discriminator + if has_projection: + pooled_output = text_model_output.text_embeds + else: + pooled_output = text_model_output.pooler_output + + return prompt_embeds, pooled_output + + +def encode_images(image_batch, image_encoder): + # image_encoder pre-processing is done in SDText2ImageDataset + image_embeds = image_encoder(image_batch) + return image_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # 1. Create the noise scheduler and the desired noise schedule. + # Enforce zero terminal SNR (see section 3.1 of ADD paper) + teacher_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision + ) + if not teacher_scheduler.config.rescale_betas_zero_snr: + teacher_scheduler.config["rescale_betas_zero_snr"] = True + noise_scheduler = DDPMScheduler(**teacher_scheduler.config) + + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us + # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules + denoiser = Denoiser(alpha_schedule, sigma_schedule) + + # Create time-dependent weighting schedule c(t) for scaling the GAN generator reconstruction loss term. + if args.weight_schedule == "uniform": + train_weight_schedule = torch.ones_like(noise_scheduler.alphas_cumprod) + elif args.weight_schedule == "exponential": + # Set weight schedule equal to alpha_schedule. Higher timesteps have less weight. + train_weight_schedule = alpha_schedule + elif args.weight_schedule == "sds": + # Score distillation sampling weighting: alpha_t / (2 * sigma_t) * w(t) + # NOTE: choose w(t) = 1 + # Introduced in the DreamFusion paper: https://arxiv.org/pdf/2209.14988.pdf. + train_weight_schedule = alpha_schedule / (2 * sigma_schedule) + elif args.weight_schedule == "nfsd": + # Noise-free score distillation weighting + # Introduced in "Noise-Free Score Distillation": https://arxiv.org/pdf/2310.17590.pdf. + raise NotImplementedError("NFSD distillation weighting is not yet implemented.") + else: + raise ValueError( + f"Weight schedule {args.weight_schedule} is not currently supported. Supported schedules are `uniform`," + f" `exponential`, `sds`, and `nfsd`." + ) + + # Create student timestep schedule tau_1, ..., tau_N. + if args.student_custom_timesteps is not None: + student_timestep_schedule = np.asarray( + sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]), dtype=np.int64 + ) + elif args.student_timestep_schedule == "uniform": + student_timestep_schedule = ( + np.linspace(0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps) + .round() + .astype(np.int64) + ) + else: + raise ValueError( + f"Student timestep schedule {args.student_timestep_schedule} was not recognized and custom student" + f" timesteps have not been provided. Either use one of `uniform` for `student_timestep_schedule` or" + f" provide custom timesteps via `student_custom_timesteps`." + ) + student_distillation_steps = student_timestep_schedule.shape[0] + + # 2. Load tokenizers from SD 1.X/2.X checkpoint. + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) + + # 3. Load text encoders from SD 1.X/2.X checkpoint. + if args.use_pretrained_projection: + text_encoder = CLIPTextModelWithProjection.from_pretrained(args.text_encoder_with_proj) + else: + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision + ) + + # Optionally load a image encoder model for image conditioning of the discriminator. + if args.use_image_conditioning: + # Set num_classes=0 so that we get image embeddings from image_encoder forward pass + image_encoder = timm.create_model(args.pretrained_image_encoder, pretrained=True, num_classes=0) + + # 4. Load VAE from SD 1.X/2.X checkpoint + vae = AutoencoderKL.from_pretrained( + args.pretrained_teacher_model, + subfolder="vae", + revision=args.teacher_revision, + ) + + # 5. Load teacher U-Net from SD 1.X/2.X checkpoint + teacher_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # 6. Initialize GAN generator U-Net from SD 1.X/2.X checkpoint with the teacher U-Net's pretrained weights + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # Make exponential moving average (EMA) version of the student unet weights, if using. + if args.use_ema: + if args.ema_min_decay is None: + # Default to `args.ema_decay`, which results in a fixed EMA decay rate throughout distillation. + args.ema_min_decay = args.ema_decay + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + ema_unet = EMAModel( + ema_unet.parameters(), + decay=args.ema_decay, + min_decay=args.ema_min_decay, + model_cls=UNet2DConditionModel, + model_config=ema_unet.config, + ) + + # 7. Initialize GAN discriminator. + text_conditioning_dim = ( + text_encoder.config.projection_dim if args.use_pretrained_projection else text_encoder.config.hidden_size + ) + img_conditioning_dim = image_encoder.num_features if args.use_image_conditioning else None + discriminator = Discriminator( + pretrained_feature_network=args.pretrained_feature_network, + c_text_embedding_dim=text_conditioning_dim, + c_img_embedding_dim=img_conditioning_dim, + cond_map_dim=args.cond_map_dim, + patch_size=[args.feature_network_patch_size, args.feature_network_patch_size], + ) + + # 8. Freeze teacher vae, text_encoder, and teacher_unet + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + teacher_unet.requires_grad_(False) + if args.use_image_conditioning: + image_encoder.eval() + image_encoder.requires_grad_(False) + + unet.train() + + # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ + "to_q", + "to_k", + "to_v", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + ) + unet = get_peft_model(unet, lora_config) + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + # 10. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, text_encoder, and teacher_unet to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device) + if args.pretrained_vae_model_name_or_path is not None: + vae.to(dtype=weight_dtype) + text_encoder.to(accelerator.device, dtype=weight_dtype) + teacher_unet.to(accelerator.device, dtype=weight_dtype) + if args.use_image_conditioning: + image_encoder.to(accelerator.device, dtype=weight_dtype) + + # Move target (EMA) unet to device but keep in full precision + if args.use_ema: + ema_unet.to(accelerator.device) + + # Also move the denoiser and schedules to accelerator.device + denoiser.to(accelerator.device) + train_weight_schedule = train_weight_schedule.to(accelerator.device) + student_timestep_schedule = torch.from_numpy(student_timestep_schedule).to(accelerator.device) + + # 11. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + unet_ = accelerator.unwrap_model(unet) + lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default") + StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict) + # save weights in peft format to be able to load them back + unet_.save_pretrained(output_dir) + + for i, model in enumerate(models): + # model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + # load the LoRA into the model + unet_ = accelerator.unwrap_model(unet) + unet_.load_adapter(input_dir, "default", is_trainable=True) + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + # load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + # model.register_to_config(**load_model.config) + + # model.load_state_dict(load_model.state_dict()) + # del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # 12. Enable optimizations + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + teacher_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # 13. Optimizer creation for generator and discriminator + optimizer = optimizer_class( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + discriminator_optimizer = optimizer_class( + discriminator.parameters(), + lr=args.discriminator_learning_rate, + betas=(args.discriminator_adam_beta1, args.discriminator_adam_beta2), + weight_decay=args.discriminator_adam_weight_decay, + eps=args.discriminator_adam_epsilon, + ) + + # 14. Dataset creation and data processing + # Compute the text encoder last hidden states `prompt_embeds` for use in the teacher/student U-Nets and pooled + # output `text_embedding` for use in the discriminator. + def compute_embeddings( + prompt_batch, proportion_empty_prompts, text_encoder, tokenizer, has_projection, is_train=True + ): + prompt_embeds, text_embedding = encode_prompt( + prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, has_projection, is_train + ) + return {"prompt_embeds": prompt_embeds, "text_embedding": text_embedding} + + def compute_image_embeddings(image_batch, image_encoder): + image_embeds = encode_images(image_batch, image_encoder) + return {"image_embeds": image_embeds} + + dataset = SDText2ImageDataset( + train_shards_path_or_url=args.train_shards_path_or_url, + num_train_examples=args.max_train_samples, + per_gpu_batch_size=args.train_batch_size, + global_batch_size=args.train_batch_size * accelerator.num_processes, + num_workers=args.dataloader_num_workers, + resolution=args.resolution, + interpolation_type=args.interpolation_type, + shuffle_buffer_size=1000, + pin_memory=True, + persistent_workers=True, + use_image_conditioning=args.use_image_conditioning, + cond_resolution=args.cond_resolution, + cond_interpolation_type=args.cond_interpolation_type, + ) + train_dataloader = dataset.train_dataloader + + compute_embeddings_fn = functools.partial( + compute_embeddings, + proportion_empty_prompts=0, + text_encoder=text_encoder, + tokenizer=tokenizer, + has_projection=args.use_pretrained_projection, + ) + + if args.use_image_conditioning: + compute_image_embeddings_fn = functools.partial( + compute_image_embeddings, + image_encoder=image_encoder, + ) + + # 15. Create learning rate scheduler for generator and discriminator + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + discriminator_lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=discriminator_optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # 16. Prepare for training + # Prepare everything with our `accelerator`. + ( + unet, + discriminator, + optimizer, + discriminator_optimizer, + lr_scheduler, + discriminator_lr_scheduler, + ) = accelerator.prepare( + unet, + discriminator, + optimizer, + discriminator_optimizer, + lr_scheduler, + discriminator_lr_scheduler, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Prepare unconditional text embedding for CFG. + uncond_input_ids = tokenizer( + [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=MAX_SEQ_LENGTH + ).input_ids.to(accelerator.device) + uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] + + # 17. Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num batches each epoch = {train_dataloader.num_batches}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # 1. Load and process the image and text conditioning + if args.use_image_conditioning: + image, text, cond_image = batch + else: + image, text = batch + + image = image.to(accelerator.device, non_blocking=True) + encoded_text = compute_embeddings_fn(text) + if args.use_image_conditioning: + encoded_image = compute_image_embeddings_fn(cond_image) + + pixel_values = image.to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) + + # encode pixel values with batch size of at most args.vae_encode_batch_size + latents = [] + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) + latents = torch.cat(latents, dim=0) + + latents = latents * vae.config.scaling_factor + latents = latents.to(weight_dtype) + bsz = latents.shape[0] + + # 2. Sample random student timesteps s uniformly in `student_timestep_schedule` and sample random + # teacher timesteps t uniformly in [0, ..., noise_scheduler.config.num_train_timesteps - 1]. + student_index = torch.randint(0, student_distillation_steps, (bsz,), device=latents.device).long() + student_timesteps = student_timestep_schedule[student_index] + teacher_timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + + # 3. Sample noise and add it to the latents according to the noise magnitude at each student timestep + # (that is, run the forward process on the student model) + student_noise = torch.randn_like(latents) + noisy_student_input = noise_scheduler.add_noise(latents, student_noise, student_timesteps) + + # 4. Prepare prompt embeds (for teacher/student U-Net) and text embedding (for discriminator). + prompt_embeds = encoded_text.pop("prompt_embeds") + text_embedding = encoded_text.pop("text_embedding") + image_embedding = None + if args.use_image_conditioning: + image_embedding = encoded_image.pop("image_embeds") + # Only supply image conditioning when student timestep is not last training timestep T. + image_embedding = torch.where( + student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1, + image_embedding, + torch.zeros_like(image_embedding), + ) + + # 5. Get the student model predicted original sample `student_x_0`. + student_noise_pred = unet( + noisy_student_input, + student_timesteps, + encoder_hidden_states=prompt_embeds.float(), + ).sample + student_x_0 = denoiser(student_noise_pred, student_timesteps, noisy_student_input) + + # 6. Sample noise and add it to the student's predicted original sample according to the noise + # magnitude at each teacher timestep (that is, run the forward process on the teacher model, but + # using `student_x_0` instead of latents sampled from the prior). + teacher_noise = torch.randn_like(student_x_0) + noisy_teacher_input = noise_scheduler.add_noise(student_x_0, teacher_noise, teacher_timesteps) + + # 7. Sample random guidance scales w ~ U[w_min, w_max] for CFG. + w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w = w.reshape(bsz, 1, 1, 1) + # Move to U-Net device and dtype + w = w.to(device=latents.device, dtype=latents.dtype) + + # 8. Get teacher model predicted original sample `teacher_x_0`. + with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): + teacher_cond_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=prompt_embeds, + ).sample + + teacher_uncond_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=uncond_prompt_embeds, + ).sample + + # Get the teacher's CFG estimate of x_0. + teacher_cfg_noise_pred = w * teacher_cond_noise_pred + (1 - w) * teacher_uncond_noise_pred + teacher_x_0 = denoiser(teacher_cfg_noise_pred, teacher_timesteps, noisy_teacher_input) + + ############################ + # 9. Discriminator Loss + ############################ + discriminator_optimizer.zero_grad(set_to_none=True) + + # 1. Decode real and fake (generated) latents back to pixel space. + # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the + # pretrained feature network for the discriminator operates in pixel space rather than latent space. + unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 + # Perform batched decode with batch size of at most args.vae_encode_batch_size + student_gen_image = [] + for i in range(0, unscaled_student_x_0.shape[0], args.vae_encode_batch_size): + student_gen_image.append( + vae.decode( + unscaled_student_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype) + ).sample + ) + student_gen_image = torch.cat(student_gen_image, dim=0) + + # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. + disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) + + # 3. Calculate the discriminator real adversarial loss terms. + d_logits_real = disc_output_real.logits + # Use hinge loss (see section 3.2, Equation 3 of paper) + d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real)) + + # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real + # data. + d_r1_regularizer = 0 + for k, head in discriminator.heads.items(): + head_grad_params = torch.autograd.grad( + outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True + ) + head_grad_norm = 0 + for grad in head_grad_params: + head_grad_norm += grad.abs().sum() + d_r1_regularizer += head_grad_norm + + d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer + accelerator.backward(d_loss_real, retain_graph=True) + + # 5. Calculate the discriminator fake adversarial loss terms. + d_logits_fake = disc_output_fake.logits + # Use hinge loss (see section 3.2, Equation 3 of paper) + d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake)) + accelerator.backward(d_adv_loss_fake) + + d_total_loss = d_loss_real + d_adv_loss_fake + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm) + discriminator_optimizer.step() + discriminator_lr_scheduler.step() + + ############################ + # 10. Generator Loss + ############################ + optimizer.zero_grad(set_to_none=True) + + # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator + disc_output_fake = discriminator(student_gen_image, text_embedding, image_embedding) + + # 2. Calculate generator adversarial loss term + g_logits_fake = disc_output_fake.logits + g_adv_loss = torch.mean(-g_logits_fake) + + ############################ + # 11. Distillation Loss + ############################ + # Calculate distillation loss in pixel space rather than latent space (see section 3.1) + unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 + # Perform batched decode with batch size of at most args.vae_encode_batch_size + teacher_gen_image = [] + for i in range(0, unscaled_teacher_x_0.shape[0], args.vae_encode_batch_size): + teacher_gen_image.append( + vae.decode( + unscaled_teacher_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype) + ).sample + ) + teacher_gen_image = torch.cat(teacher_gen_image, dim=0) + + per_instance_distillation_loss = F.mse_loss( + student_gen_image.float(), teacher_gen_image.float(), reduction="none" + ) + # Note that we use the teacher timesteps t when getting the loss weights. + c_t = extract_into_tensor( + train_weight_schedule, teacher_timesteps, per_instance_distillation_loss.shape + ) + g_distillation_loss = torch.mean(c_t * per_instance_distillation_loss) + + g_total_loss = g_adv_loss + args.distillation_weight_factor * g_distillation_loss + + # Backprop on the generator total loss + accelerator.backward(g_total_loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + # 12. Perform an EMA update on the EMA version of the student U-Net weights. + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + if args.use_ema: + # Store the student unet weights and load the EMA weights. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "ema_student") + + # Restore student unet weights + ema_unet.restore(unet.parameters()) + + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "student") + + logs = { + "d_total_loss": d_total_loss.detach().item(), + "g_total_loss": g_total_loss.detach().item(), + "g_adv_loss": g_adv_loss.detach().item(), + "g_distill_loss": g_distillation_loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + # Write out additional values for accelerator to report. + logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item() + logs["d_adv_loss_real"] = d_adv_loss_real.detach().item() + logs["d_r1_regularizer"] = d_r1_regularizer.detach().item() + logs["d_loss_real"] = d_loss_real.detach().item() + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet.save_pretrained(os.path.join(args.output_dir, "unet")) + + lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default") + StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict) + + # If using EMA, save EMA weights as well. + if args.use_ema: + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + unet.save_pretrained(os.path.join(args.output_dir, "ema_unet")) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py new file mode 100644 index 000000000000..672b3ee94bf8 --- /dev/null +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -0,0 +1,2192 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import copy +import functools +import gc +import itertools +import json +import logging +import math +import os +import random +import shutil +import types +from pathlib import Path +from typing import Callable, List, Optional, Union + +import accelerate +import numpy as np +import timm +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torchvision.transforms.functional as TF +import transformers +import webdataset as wds +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from braceexpand import braceexpand +from huggingface_hub import create_repo +from packaging import version +from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import default_collate +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig +from webdataset.tariterators import ( + base_plus_ext, + tar_file_expander, + url_opener, + valid_sample, +) + +import diffusers +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionXLPipeline, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel +from diffusers.utils import BaseOutput, check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +MAX_SEQ_LENGTH = 77 + +# Adjust for your dataset +WDS_JSON_WIDTH = "width" # original_width for LAION +WDS_JSON_HEIGHT = "height" # original_height for LAION +MIN_SIZE = 700 # ~960 for LAION, ideal: 1024 if the dataset contains large images + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.25.0.dev0") + +logger = get_logger(__name__) + + +def get_module_kohya_state_dict(module, prefix: str, dtype: torch.dtype, adapter_name: str = "default"): + kohya_ss_state_dict = {} + for peft_key, weight in get_peft_model_state_dict(module, adapter_name=adapter_name).items(): + kohya_key = peft_key.replace("base_model.model", prefix) + kohya_key = kohya_key.replace("lora_A", "lora_down") + kohya_key = kohya_key.replace("lora_B", "lora_up") + kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2) + kohya_ss_state_dict[kohya_key] = weight.to(dtype) + + # Set alpha parameter + if "lora_down" in kohya_key: + alpha_key = f'{kohya_key.split(".")[0]}.alpha' + kohya_ss_state_dict[alpha_key] = torch.tensor(module.peft_config[adapter_name].lora_alpha).to(dtype) + + return kohya_ss_state_dict + + +def filter_keys(key_set): + def _f(dictionary): + return {k: v for k, v in dictionary.items() if k in key_set} + + return _f + + +def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None): + """Return function over iterator that groups key, value pairs into samples. + + :param keys: function that splits the key into key and extension (base_plus_ext) :param lcase: convert suffixes to + lower case (Default value = True) + """ + current_sample = None + for filesample in data: + assert isinstance(filesample, dict) + fname, value = filesample["fname"], filesample["data"] + prefix, suffix = keys(fname) + if prefix is None: + continue + if lcase: + suffix = suffix.lower() + # FIXME webdataset version throws if suffix in current_sample, but we have a potential for + # this happening in the current LAION400m dataset if a tar ends with same prefix as the next + # begins, rare, but can happen since prefix aren't unique across tar files in that dataset + if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample: + if valid_sample(current_sample): + yield current_sample + current_sample = {"__key__": prefix, "__url__": filesample["__url__"]} + if suffixes is None or suffix in suffixes: + current_sample[suffix] = value + if valid_sample(current_sample): + yield current_sample + + +def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): + # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw + streams = url_opener(src, handler=handler) + files = tar_file_expander(streams, handler=handler) + samples = group_by_keys_nothrow(files, handler=handler) + return samples + + +def resolve_interpolation_mode(interpolation_type): + if interpolation_type == "bilinear": + interpolation_mode = TF.InterpolationMode.BILINEAR + elif interpolation_type == "bicubic": + interpolation_mode = TF.InterpolationMode.BICUBIC + elif interpolation_type == "nearest": + interpolation_mode = TF.InterpolationMode.NEAREST + elif interpolation_type == "lanczos": + interpolation_mode = TF.InterpolationMode.LANCZOS + else: + raise ValueError( + f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" + f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." + ) + + return interpolation_mode + + +class WebdatasetFilter: + def __init__(self, min_size=1024, max_pwatermark=0.5): + self.min_size = min_size + self.max_pwatermark = max_pwatermark + + def __call__(self, x): + try: + if "json" in x: + x_json = json.loads(x["json"]) + filter_size = (x_json.get(WDS_JSON_WIDTH, 0.0) or 0.0) >= self.min_size and x_json.get( + WDS_JSON_HEIGHT, 0 + ) >= self.min_size + filter_watermark = (x_json.get("pwatermark", 0.0) or 0.0) <= self.max_pwatermark + return filter_size and filter_watermark + else: + return False + except Exception: + return False + + +class SDXLText2ImageDataset: + def __init__( + self, + train_shards_path_or_url: Union[str, List[str]], + num_train_examples: int, + per_gpu_batch_size: int, + global_batch_size: int, + num_workers: int, + resolution: int = 1024, + interpolation_type: str = "bilinear", + shuffle_buffer_size: int = 1000, + pin_memory: bool = False, + persistent_workers: bool = False, + use_fix_crop_and_size: bool = False, + use_image_conditioning: bool = False, + cond_resolution: Optional[int] = None, + cond_interpolation_type: Optional[str] = None, + ): + if not isinstance(train_shards_path_or_url, str): + train_shards_path_or_url = [list(braceexpand(urls)) for urls in train_shards_path_or_url] + # flatten list using itertools + train_shards_path_or_url = list(itertools.chain.from_iterable(train_shards_path_or_url)) + + def get_orig_size(json): + if use_fix_crop_and_size: + return (resolution, resolution) + else: + return (int(json.get(WDS_JSON_WIDTH, 0.0)), int(json.get(WDS_JSON_HEIGHT, 0.0))) + + interpolation_mode = resolve_interpolation_mode(interpolation_type) + if use_image_conditioning: + cond_interpolation_mode = resolve_interpolation_mode(cond_interpolation_type) + + def transform(example): + # resize image + image = example["image"] + if use_image_conditioning: + cond_image = copy.deepcopy(image) + + image = TF.resize(image, resolution, interpolation=interpolation_mode) + + # get crop coordinates and crop image + c_top, c_left, _, _ = transforms.RandomCrop.get_params(image, output_size=(resolution, resolution)) + image = TF.crop(image, c_top, c_left, resolution, resolution) + image = TF.to_tensor(image) + image = TF.normalize(image, [0.5], [0.5]) + + example["image"] = image + example["crop_coords"] = (c_top, c_left) if not use_fix_crop_and_size else (0, 0) + + if use_image_conditioning: + # Prepare a separate image for image conditioning since the preprocessing pipelines are different. + cond_image = TF.resize(cond_image, cond_resolution, interpolation=cond_interpolation_mode) + cond_image = TF.center_crop(cond_image, output_size=(cond_resolution, cond_resolution)) + cond_image = TF.normalize(cond_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + example["cond_image"] = cond_image + + return example + + processing_pipeline = [ + wds.decode("pil", handler=wds.ignore_and_continue), + wds.rename( + image="jpg;png;jpeg;webp", text="text;txt;caption", orig_size="json", handler=wds.warn_and_continue + ), + wds.map(filter_keys({"image", "text", "orig_size"})), + wds.map_dict(orig_size=get_orig_size), + wds.map(transform), + ] + + if use_image_conditioning: + processing_pipeline.append(wds.to_tuple("image", "text", "orig_size", "crop_coords", "cond_image")) + else: + processing_pipeline.append(wds.to_tuple("image", "text", "orig_size", "crop_coords")) + + # Create train dataset and loader + pipeline = [ + wds.ResampledShards(train_shards_path_or_url), + tarfile_to_samples_nothrow, + wds.select(WebdatasetFilter(min_size=MIN_SIZE)), + wds.shuffle(shuffle_buffer_size), + *processing_pipeline, + wds.batched(per_gpu_batch_size, partial=False, collation_fn=default_collate), + ] + + num_worker_batches = math.ceil(num_train_examples / (global_batch_size * num_workers)) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + + # each worker is iterating over this + self._train_dataset = wds.DataPipeline(*pipeline).with_epoch(num_worker_batches) + self._train_dataloader = wds.WebLoader( + self._train_dataset, + batch_size=None, + shuffle=False, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=persistent_workers, + ) + # add meta-data to dataloader instance for convenience + self._train_dataloader.num_batches = num_batches + self._train_dataloader.num_samples = num_samples + + @property + def train_dataset(self): + return self._train_dataset + + @property + def train_dataloader(self): + return self._train_dataloader + + +class Denoiser: + def __init__(self, alphas, sigmas, prediction_type="epsilon"): + self.alphas = alphas + self.sigmas = sigmas + self.prediction_type = prediction_type + + def to(self, device): + self.alphas = self.alphas.to(device) + self.sigmas = self.sigmas.to(device) + return self + + def __call__(self, model_output, timesteps, sample): + alphas = extract_into_tensor(self.alphas, timesteps, sample.shape) + sigmas = extract_into_tensor(self.sigmas, timesteps, sample.shape) + if self.prediction_type == "epsilon": + pred_x_0 = (sample - sigmas * model_output) / alphas + elif self.prediction_type == "sample": + pred_x_0 = model_output + elif self.prediction_type == "v_prediction": + pred_x_0 = alphas * sample - sigmas * model_output + else: + raise ValueError( + f"Prediction type {self.prediction_type} is not supported; currently, `epsilon`, `sample`, and" + f" `v_prediction` are supported." + ) + + return pred_x_0 + + +# Based on SpectralConv1d from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L29 +class SpectralConv1d(torch.nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + torch.nn.utils.parametrizations.spectral_norm(self, name="weight", n_power_iterations=1, eps=1e-12, dim=0) + + +# Based on ResidualBlock from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/shared.py#L20 +class ResidualBlock(torch.nn.Module): + def __init__(self, fn: Callable): + super().__init__() + self.fn = fn + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (self.fn(x) + x) / np.sqrt(2) + + +# Based on make_block from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L64 +class DiscHeadBlock(torch.nn.Module): + """ + StyleGAN-T block: SpectralConv1d => GroupNorm => LeakyReLU + """ + + def __init__( + self, + channels: int, + kernel_size: int, + num_groups: int = 8, + leaky_relu_neg_slope: float = 0.2, + ): + super().__init__() + self.channels = channels + + self.conv = SpectralConv1d( + channels, + channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + padding_mode="circular", + ) + self.norm = torch.nn.GroupNorm(num_groups, channels) + self.act_fn = torch.nn.LeakyReLU(leaky_relu_neg_slope, inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.conv(x) + x = self.norm(x) + x = self.act_fn(x) + return x + + +# Based on DiscHead in the official StyleGAN-T implementation +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L78 +class DiscriminatorHead(torch.nn.Module): + """ + Implements a StyleGAN-T-style discriminator head. The discriminator head takes in a (possibly intermediate) 1D + sequence of tokens from the feature network, processes it, and combines it with conditioning information to output + per-token logits. + """ + + def __init__( + self, + channels: int, + c_text_embedding_dim: int, + c_img_embedding_dim: Optional[int] = None, + cond_map_dim: int = 64, + ): + super().__init__() + self.channels = channels + self.c_text_embedding_dim = c_text_embedding_dim + self.c_img_embedding_dim = c_img_embedding_dim + self.cond_map_dim = cond_map_dim + + self.input_block = DiscHeadBlock(channels, kernel_size=1) + self.resblock = ResidualBlock(DiscHeadBlock(channels, kernel_size=9)) + # Project each token embedding from channels dimensions to cond_map_dim dimensions. + self.cls = SpectralConv1d(channels, cond_map_dim, kernel_size=1, padding=0) + + # Also project the concatenated conditioning embeddings to dimension cond_map_dim. + c_map_input_dim = self.c_text_embedding_dim + if self.c_img_embedding_dim is not None: + c_map_input_dim += self.c_img_embedding_dim + self.c_map = torch.nn.Linear(c_map_input_dim, cond_map_dim) + + def forward(self, x: torch.Tensor, c_text: torch.Tensor, c_img: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + Maps a 1D sequence of tokens from a feature network (e.g. ViT trained with DINO) and a conditioning embedding + to per-token logits. + + Args: + x (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`): + A sequence of 1D tokens (possibly intermediate) from a ViT feature neetwork. Note that the channels dim + should be the same as the feature network's embedding dim. + c_text (`torch.Tensor` of shape `(batch_size, c_text_embedding_dim)`): + A conditioning embedding representing text conditioning information. + c_img (`torch.Tensor` of shape `(batch_size, c_img_embedding_dim)`): + A conditioning embedding representing image conditioning information. + + Returns: + `torch.Tensor` of shape `(batch_size, sequence_length)`: batched 1D sequence of per-token logits. + """ + hidden_states = self.input_block(x) + hidden_states = self.resblock(hidden_states) + out = self.cls(hidden_states) + + if self.c_img_embedding_dim is not None: + c = torch.cat([c_text, c_img], dim=1) + else: + c = c_text + # Project conditioning embedding to cond_map_dim and unsqueeze in the sequence length dimension. + c = self.c_map(c).unsqueeze(-1) + + # Combine image features with projected conditioning embedding via a product. + out = (out * c).sum(1, keepdim=True) * (1 / np.sqrt(self.cond_map_dim)) + + return out + + +activations = {} + + +# Based on get_activation from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L111 +def get_activation(name: str) -> Callable: + def hook(model, input, output): + activations[name] = output + + return hook + + +# Based on _resize_pos_embed from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L66 +def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor: + posemb_tok, posemb_grid = ( + posemb[:, : self.start_index], + posemb[0, self.start_index :], + ) + + gs_old = int(math.sqrt(len(posemb_grid))) + + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear", align_corners=False) + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1) + + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + + return posemb + + +# Based on forward_flex from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L83 +def forward_flex(self, x: torch.Tensor) -> torch.Tensor: + # patch proj and dynamically resize + B, C, H, W = x.size() + x = self.patch_embed.proj(x).flatten(2).transpose(1, 2) + pos_embed = self._resize_pos_embed(self.pos_embed, H // self.patch_size[1], W // self.patch_size[0]) + + # add cls token + cls_tokens = self.cls_token.expand(x.size(0), -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + + # forward pass + x = x + pos_embed + x = self.pos_drop(x) + + for blk in self.blocks: + x = blk(x) + + x = self.norm(x) + return x + + +# Based on forward_vit from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L60 +def forward_vit(pretrained: torch.nn.Module, x: torch.Tensor) -> dict: + _ = pretrained.model.forward_flex(x) + return {k: pretrained.rearrange(v) for k, v in activations.items()} + + +# Based on AddReadout from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L36 +class AddReadout(torch.nn.Module): + def __init__(self, start_index: int = 1): + super(AddReadout, self).__init__() + self.start_index = start_index + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.start_index == 2: + readout = (x[:, 0] + x[:, 1]) / 2 + else: + readout = x[:, 0] + return x[:, self.start_index :] + readout.unsqueeze(1) + + +# Based on Transpose from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L49 +class Transpose(torch.nn.Module): + def __init__(self, dim0: int, dim1: int): + super(Transpose, self).__init__() + self.dim0 = dim0 + self.dim1 = dim1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.transpose(self.dim0, self.dim1) + return x.contiguous() + + +# Based on DINO from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L107 +class FeatureNetwork(torch.nn.Module): + """ + DINO ViT model to act as feature extractor for the discriminator. + """ + + def __init__( + self, + pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", + patch_size: List[int] = [14, 14], + hooks: List[int] = [2, 5, 8, 11], + start_index: int = 1, + ): + super().__init__() + self.num_hooks = len(hooks) + 1 + + pretrained_model = timm.create_model(pretrained_feature_network, pretrained=True) + + # Based on make_vit_backbone from the official StyleGAN-T code + # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/vit_utils.py#L117 + # which I believe is itself based on https://github.com/isl-org/DPT + model_with_hooks = torch.nn.Module() + model_with_hooks.model = pretrained_model + + # Add hooks + model_with_hooks.model.blocks[hooks[0]].register_forward_hook(get_activation("0")) + model_with_hooks.model.blocks[hooks[1]].register_forward_hook(get_activation("1")) + model_with_hooks.model.blocks[hooks[2]].register_forward_hook(get_activation("2")) + model_with_hooks.model.blocks[hooks[3]].register_forward_hook(get_activation("3")) + model_with_hooks.model.pos_drop.register_forward_hook(get_activation("4")) + + # Configure readout + model_with_hooks.rearrange = torch.nn.Sequential(AddReadout(start_index), Transpose(1, 2)) + model_with_hooks.model.start_index = start_index + model_with_hooks.model.patch_size = patch_size + + # We inject this function into the VisionTransformer instances so that + # we can use it with interpolated position embeddings without modifying the library source. + model_with_hooks.model.forward_flex = types.MethodType(forward_flex, model_with_hooks.model) + model_with_hooks.model._resize_pos_embed = types.MethodType(_resize_pos_embed, model_with_hooks.model) + + self.model = model_with_hooks + # Freeze pretrained model with hooks + self.model = self.model.eval().requires_grad_(False) + + self.img_resolution = self.model.model.patch_embed.img_size[0] + self.embed_dim = self.model.model.embed_dim + self.norm = transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) + + def forward(self, x: torch.Tensor): + """ + Forward pass consisting of interpolation, ImageNet normalization, and a forward pass of self.model. + + Args: + x (`torch.Tensor`): + Image with pixel values in [0, 1]. + + Returns: + `Dict[Any]`: dict of activations which are intermediate features from the feature network. The dict values + (feature embeddings) have shape `(batch_size, embed_dim, sequence_length)`. + """ + x = F.interpolate(x, self.img_resolution, mode="area") + x = self.norm(x) + + activation_dict = forward_vit(self.model, x) + return activation_dict + + +class DiscriminatorOutput(BaseOutput): + """ + Output class for the Discriminator module. + """ + + logits: torch.FloatTensor + + +# Based on ProjectedDiscriminator from the official StyleGAN-T code +# https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 +class Discriminator(torch.nn.Module): + """ + StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). + """ + + def __init__( + self, + pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", + c_text_embedding_dim: int = 768, + c_img_embedding_dim: Optional[int] = None, + cond_map_dim: int = 64, + patch_size: List[int] = [14, 14], + hooks: List[int] = [2, 5, 8, 11], + start_index: int = 1, + ): + super().__init__() + self.c_text_embedding_dim = c_text_embedding_dim + self.c_img_embedding_dim = c_img_embedding_dim + self.cond_map_dim = cond_map_dim + + # Frozen feature network, e.g. DINO + self.feature_network = FeatureNetwork( + pretrained_feature_network=pretrained_feature_network, + patch_size=patch_size, + hooks=hooks, + start_index=start_index, + ) + + # Trainable discriminator heads + heads = [] + for i in range(self.feature_network.num_hooks): + heads.append( + [ + str(i), + DiscriminatorHead( + self.feature_network.embed_dim, c_text_embedding_dim, c_img_embedding_dim, cond_map_dim + ), + ] + ) + self.heads = torch.nn.ModuleDict(heads) + + def train(self, mode: bool = True): + self.feature_network = self.feature_network.train(False) + self.heads = self.heads.train(mode) + return self + + def eval(self): + return self.train(False) + + def forward( + self, + x: torch.Tensor, + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + transform_positive: bool = True, + return_dict: bool = True, + ): + # TODO: do we need the augmentations from the original StyleGAN-T code? + if transform_positive: + # Transform to [0, 1]. + x = x.add(1).div(2) + + # Forward pass through feature network. + features = self.feature_network(x) + + # Apply discriminator heads. + logits = [] + for k, head in self.heads.items(): + logits.append(head(features[k], c_text, c_img).view(x.size(0), -1)) + logits = torch.cat(logits, dim=1) + + if not return_dict: + return (logits,) + + return DiscriminatorOutput(logits=logits) + + +def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): + logger.info("Running validation... ") + + unet = accelerator.unwrap_model(unet) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_teacher_model, + vae=vae, + unet=unet, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype) + pipeline.load_lora_weights(lora_state_dict) + pipeline.fuse_lora() + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + validation_prompts = [ + "A cinematic shot of robot with colorful feathers.", + "Teddy bears working on new AI research on the moon in the 1980s.", + "A robot is playing the guitar at a rock concert in front of a large crowd.", + "A photo of an astronaut riding a horse in the forest. There is a river in front of them with water lilies.", + ] + + image_logs = [] + + for _, prompt in enumerate(validation_prompts): + images = [] + with torch.autocast("cuda"): + images = pipeline( + prompt=prompt, + num_inference_steps=1, + num_images_per_prompt=4, + generator=generator, + ).images + image_logs.append({"validation_prompt": prompt, "images": images}) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + formatted_images = [] + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({f"validation/{name}": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less") + return x[(...,) + (None,) * dims_to_append] + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +@torch.no_grad() +def update_ema(target_params, source_params, rate=0.99): + """ + Update target parameters to be closer to those of source parameters using + an exponential moving average. + + :param target_params: the target parameter sequence. + :param source_params: the source parameter sequence. + :param rate: the EMA rate (closer to 1 means slower). + """ + for targ, src in zip(target_params, source_params): + targ.detach().mul_(rate).add_(src, alpha=1 - rate) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + # ----------Model Checkpoint Loading Arguments---------- + parser.add_argument( + "--pretrained_teacher_model", + type=str, + default=None, + required=True, + help="Path to pretrained LDM teacher model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--teacher_revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM teacher model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained LDM model identifier from huggingface.co/models.", + ) + # ----------Training Arguments---------- + # ----General Training Arguments---- + parser.add_argument( + "--output_dir", + type=str, + default="lcm-xl-distilled", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + # ----Logging---- + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + # ----Checkpointing---- + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + # ----Image Processing---- + parser.add_argument( + "--train_shards_path_or_url", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--interpolation_type", + type=str, + default="bilinear", + help=( + "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." + ), + ) + parser.add_argument( + "--use_fix_crop_and_size", + action="store_true", + help="Whether or not to use the fixed crop and size for the teacher model.", + default=False, + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + # ----Dataloader---- + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + # ----Batch Size and Training Steps---- + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=100) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + # ----Learning Rate---- + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--discriminator_learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + # ----Optimizer (Adam)---- + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument( + "--discriminator_adam_beta1", type=float, default=0.0, help="The beta1 parameter for the Adam optimizer." + ) + parser.add_argument( + "--discriminator_adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer." + ) + parser.add_argument("--discriminator_adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument( + "--discriminator_adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer" + ) + # ----Diffusion Training Arguments---- + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + # ----Adversarial Diffusion Distillation (ADD) Specific Arguments---- + parser.add_argument( + "--pretrained_feature_network", + type=str, + default="vit_small_patch14_dinov2.lvd142m", + help=( + "The pretrained feature network used in the discriminator, typically a vision transformer (ViT) trained" + " the DINO objective. The given identifier should be compatible with `timm.create_model`." + ), + ) + parser.add_argument( + "--feature_network_patch_size", + type=int, + default=14, + help="The patch size of the `pretrained_feature_network`.", + ) + parser.add_argument( + "--cond_map_dim", + type=int, + default=64, + help=( + "The common dimension to which the discriminator feature network features and conditioning embeddings will" + " be projected to in the discriminator heads." + ), + ) + parser.add_argument( + "--use_image_conditioning", + action="store_true", + help=( + "Whether to also use an image encoder to calculate image conditioning embeddings for the discriminator. If" + " set, the model at the timm model id given in `image_encoder_with_proj` will be used." + ), + ) + parser.add_argument( + "--pretrained_image_encoder", + type=str, + default="vit_large_patch14_dinov2.lvd142m", + help=( + "An optional image encoder to add image conditioning information to the discriminator. Is used if" + " `use_image_conditioning` is set. The model id should be loadable by `timm.create_model`. Note that ADD" + " uses a DINOv2 ViT-L encoder (see section 4 of the paper)." + ), + ) + parser.add_argument( + "--cond_resolution", + type=int, + default=518, + help="Resolution to resize the original images to for image conditioning.", + ) + parser.add_argument( + "--cond_interpolation_type", + type=str, + default="bicubic", + help=( + "The interpolation function used when resizing the image for conditioning. Choose between `bilinear`," + " `bicubic`, `lanczos`, and `nearest`." + ), + ) + parser.add_argument( + "--weight_schedule", + type=str, + default="exponential", + help=( + "The time-dependent weighting function gamma used for scaling the distillation loss Choose between" + " `uniform`, `exponential`, `sds`, and `nfsd`." + ), + ) + parser.add_argument( + "--student_distillation_steps", + type=int, + default=4, + help="The number of student timesteps N used during distillation.", + ) + parser.add_argument( + "--student_timestep_schedule", + type=str, + default="uniform", + help="The method by which the student timestep schedule is determined. Currently, only `uniform` is implemented.", + ) + parser.add_argument( + "--student_custom_timesteps", + type=str, + default=None, + help=( + "A comma-separated list of timesteps which will override the timestep schedule set in" + " `student_timestep_schedule` if set." + ), + ) + parser.add_argument( + "--discriminator_r1_strength", + type=float, + default=1e-05, + help="The discriminator R1 gradient penalty strength gamma.", + ) + parser.add_argument( + "--distillation_weight_factor", + type=float, + default=2.5, + help="Multiplicative weight factor lambda for the distillation loss on the student generator U-Net.", + ) + parser.add_argument( + "--w_min", + type=float, + default=1.0, + required=False, + help=( + "The minimum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation." + ), + ) + parser.add_argument( + "--w_max", + type=float, + default=8.0, + required=False, + help=( + "The maximum guidance scale value for guidance scale sampling. Note that we are using the Imagen CFG" + " formulation rather than the LCM formulation, which means all guidance scales have 1 added to them as" + " compared to the original paper." + ), + ) + parser.add_argument( + "--vae_encode_batch_size", + type=int, + default=8, + required=False, + help=( + "The batch size used when encoding (and decoding) images to latents (and vice versa) using the VAE." + " Encoding or decoding the whole batch at once may run into OOM issues." + ), + ) + # LoRA Arguments + parser.add_argument( + "--lora_rank", + type=int, + default=64, + help="The rank of the LoRA projection matrix.", + ) + parser.add_argument( + "--lora_alpha", + type=int, + default=64, + help=( + "The value of the LoRA alpha parameter, which controls the scaling factor in front of the LoRA weight" + " update delta_W. No scaling will be performed if this value is equal to `lora_rank`." + ), + ) + parser.add_argument( + "--lora_dropout", + type=float, + default=0.0, + help="The dropout probability for the dropout layer added before applying the LoRA to each layer input.", + ) + parser.add_argument( + "--lora_target_modules", + type=str, + default=None, + help=( + "A comma-separated string of target module keys to add LoRA to. If not set, a default list of modules will" + " be used. By default, LoRA will be applied to all conv and linear layers." + ), + ) + # ----Exponential Moving Average (EMA)---- + parser.add_argument( + "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." + ) + parser.add_argument( + "--ema_decay", + type=float, + default=0.95, + required=False, + help="The exponential moving average (EMA) rate or decay factor.", + ) + parser.add_argument( + "--ema_min_decay", + type=float, + default=None, + help=( + "The minimum EMA decay rate, which the effective EMA decay rate (e.g. if warmup is used) will never go" + " below. If not set, the value set for `ema_decay` will be used, which results in a fixed EMA decay rate" + " equal to that value." + ), + ) + # ----Mixed Precision---- + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--cast_teacher_unet", + action="store_true", + help="Whether to cast the teacher U-Net to the precision specified by `--mixed_precision`.", + ) + # ----Training Optimizations---- + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + # ----Distributed Training---- + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + # ----------Validation Arguments---------- + parser.add_argument( + "--validation_steps", + type=int, + default=200, + help="Run validation every X steps.", + ) + # ----------Huggingface Hub Arguments----------- + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + # ----------Accelerate Arguments---------- + parser.add_argument( + "--tracker_project_name", + type=str, + default="text2image-fine-tune", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + return args + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def encode_images(image_batch, image_encoder): + # image_encoder pre-processing is done in SDText2ImageDataset + image_embeds = image_encoder(image_batch) + return image_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + split_batches=True, # It's important to set this to True when using webdataset to get the right number of steps for lr scheduling. If set to False, the number of steps will be devide by the number of processes assuming batches are multiplied by the number of processes + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # 1. Create the noise scheduler and the desired noise schedule. + # Enforce zero terminal SNR (see section 3.1 of ADD paper) + teacher_scheduler = DDPMScheduler.from_pretrained( + args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision + ) + if not teacher_scheduler.config.rescale_betas_zero_snr: + teacher_scheduler.config["rescale_betas_zero_snr"] = True + noise_scheduler = DDPMScheduler(**teacher_scheduler.config) + + # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us + # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps + alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) + sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules + denoiser = Denoiser(alpha_schedule, sigma_schedule) + + # Create time-dependent weighting schedule c(t) for scaling the GAN generator reconstruction loss term. + if args.weight_schedule == "uniform": + train_weight_schedule = torch.ones_like(noise_scheduler.alphas_cumprod) + elif args.weight_schedule == "exponential": + # Set weight schedule equal to alpha_schedule. Higher timesteps have less weight. + train_weight_schedule = alpha_schedule + elif args.weight_schedule == "sds": + # Score distillation sampling weighting: alpha_t / (2 * sigma_t) * w(t) + # NOTE: choose w(t) = 1 + # Introduced in the DreamFusion paper: https://arxiv.org/pdf/2209.14988.pdf. + train_weight_schedule = alpha_schedule / (2 * sigma_schedule) + elif args.weight_schedule == "nfsd": + # Noise-free score distillation weighting + # Introduced in "Noise-Free Score Distillation": https://arxiv.org/pdf/2310.17590.pdf. + raise NotImplementedError("NFSD distillation weighting is not yet implemented.") + else: + raise ValueError( + f"Weight schedule {args.weight_schedule} is not currently supported. Supported schedules are `uniform`," + f" `exponential`, `sds`, and `nfsd`." + ) + + # Create student timestep schedule tau_1, ..., tau_N. + if args.student_custom_timesteps is not None: + student_timestep_schedule = np.asarray( + sorted([int(timestep.strip()) for timestep in args.student_custom_timesteps.split(",")]), dtype=np.int64 + ) + elif args.student_timestep_schedule == "uniform": + student_timestep_schedule = ( + np.linspace(0, noise_scheduler.config.num_train_timesteps - 1, args.student_distillation_steps) + .round() + .astype(np.int64) + ) + else: + raise ValueError( + f"Student timestep schedule {args.student_timestep_schedule} was not recognized and custom student" + f" timesteps have not been provided. Either use one of `uniform` for `student_timestep_schedule` or" + f" provide custom timesteps via `student_custom_timesteps`." + ) + student_distillation_steps = student_timestep_schedule.shape[0] + + # 2. Load tokenizers from SD-XL checkpoint. + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer_2", revision=args.teacher_revision, use_fast=False + ) + + # 3. Load text encoders from SD-XL checkpoint. + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_teacher_model, args.teacher_revision, subfolder="text_encoder_2" + ) + + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder", revision=args.teacher_revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_teacher_model, subfolder="text_encoder_2", revision=args.teacher_revision + ) + + # Optionally load a image encoder model for image conditioning of the discriminator. + if args.use_image_conditioning: + # Set num_classes=0 so that we get image embeddings from image_encoder forward pass + image_encoder = timm.create_model(args.pretrained_image_encoder, pretrained=True, num_classes=0) + + # 4. Load VAE from SD-XL checkpoint (or more stable VAE) + vae_path = ( + args.pretrained_teacher_model + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.teacher_revision, + ) + + # 5. Load teacher U-Net from SD-XL checkpoint + teacher_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # 6. Initialize GAN generator U-Net from SD-XL checkpoint with the teacher U-Net's pretrained weights + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + + # Make exponential moving average (EMA) version of the student unet weights, if using. + if args.use_ema: + if args.ema_min_decay is None: + # Default to `args.ema_decay`, which results in a fixed EMA decay rate throughout distillation. + args.ema_min_decay = args.ema_decay + ema_unet = UNet2DConditionModel.from_pretrained( + args.pretrained_teacher_model, subfolder="unet", revision=args.teacher_revision + ) + ema_unet = EMAModel( + ema_unet.parameters(), + decay=args.ema_decay, + min_decay=args.ema_min_decay, + model_cls=UNet2DConditionModel, + model_config=ema_unet.config, + ) + + # 7. Initialize GAN discriminator. + # Use text_encoder_two here since it already projects the CLIP embedding to a fixed length vector (e.g. it's + # already a ClipTextModelWithProjection) + # TODO: what if there's no text_encoder_two? I think we already assume text_encoder_two exists in Step 3 above so + # it might be fine? + text_conditioning_dim = text_encoder_two.config.projection_dim + img_conditioning_dim = image_encoder.num_features if args.use_image_conditioning else None + discriminator = Discriminator( + pretrained_feature_network=args.pretrained_feature_network, + c_text_embedding_dim=text_conditioning_dim, + c_img_embedding_dim=img_conditioning_dim, + patch_size=[args.feature_network_patch_size, args.feature_network_patch_size], + ) + + # 8. Freeze teacher vae, text_encoders, and teacher_unet + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + teacher_unet.requires_grad_(False) + if args.use_image_conditioning: + image_encoder.eval() + image_encoder.requires_grad_(False) + + unet.train() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) + + # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. + if args.lora_target_modules is not None: + lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] + else: + lora_target_modules = [ + "to_q", + "to_k", + "to_v", + "to_out.0", + "proj_in", + "proj_out", + "ff.net.0.proj", + "ff.net.2", + "conv1", + "conv2", + "conv_shortcut", + "downsamplers.0.conv", + "upsamplers.0.conv", + "time_emb_proj", + ] + lora_config = LoraConfig( + r=args.lora_rank, + target_modules=lora_target_modules, + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + ) + unet = get_peft_model(unet, lora_config) + + # 10. Handle mixed precision and device placement + # For mixed precision training we cast all non-trainable weigths to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, text_encoders, and teacher_unet to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + vae.to(accelerator.device) + if args.pretrained_vae_model_name_or_path is not None: + vae.to(dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + teacher_unet.to(accelerator.device, dtype=weight_dtype) + if args.use_image_conditioning: + image_encoder.to(accelerator.device, dtype=weight_dtype) + + # Move target (EMA) unet to device but keep in full precision + if args.use_ema: + ema_unet.to(accelerator.device) + + # Also move the denoiser and schedules to accelerator.device + denoiser.to(accelerator.device) + train_weight_schedule = train_weight_schedule.to(accelerator.device) + student_timestep_schedule = torch.from_numpy(student_timestep_schedule).to(accelerator.device) + + # 11. Handle saving and loading of checkpoints + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + if args.use_ema: + ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) + + unet_ = accelerator.unwrap_model(unet) + lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default") + StableDiffusionXLPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict) + # save weights in peft format to be able to load them back + unet_.save_pretrained(output_dir) + + for i, model in enumerate(models): + # model.save_pretrained(os.path.join(output_dir, "unet")) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + if args.use_ema: + load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) + ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) + del load_model + + # load the LoRA into the model + unet_ = accelerator.unwrap_model(unet) + unet_.load_adapter(input_dir, "default", is_trainable=True) + + for i in range(len(models)): + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + # load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + # model.register_to_config(**load_model.config) + + # model.load_state_dict(load_model.state_dict()) + # del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # 12. Enable optimizations + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + teacher_unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # 13. Optimizer creation for generator and discriminator + optimizer = optimizer_class( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + discriminator_optimizer = optimizer_class( + discriminator.parameters(), + lr=args.discriminator_learning_rate, + betas=(args.discriminator_adam_beta1, args.discriminator_adam_beta2), + weight_decay=args.discriminator_adam_weight_decay, + eps=args.discriminator_adam_epsilon, + ) + + # 14. Dataset creation and data processing + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings( + prompt_batch, original_sizes, crop_coords, proportion_empty_prompts, text_encoders, tokenizers, is_train=True + ): + target_size = (args.resolution, args.resolution) + original_sizes = list(map(list, zip(*original_sizes))) + crops_coords_top_left = list(map(list, zip(*crop_coords))) + + original_sizes = torch.tensor(original_sizes, dtype=torch.long) + crops_coords_top_left = torch.tensor(crops_coords_top_left, dtype=torch.long) + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = torch.cat([original_sizes, crops_coords_top_left, add_time_ids], dim=-1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + def compute_image_embeddings(image_batch, image_encoder): + image_embeds = encode_images(image_batch, image_encoder) + return {"image_embeds": image_embeds} + + dataset = SDXLText2ImageDataset( + train_shards_path_or_url=args.train_shards_path_or_url, + num_train_examples=args.max_train_samples, + per_gpu_batch_size=args.train_batch_size, + global_batch_size=args.train_batch_size * accelerator.num_processes, + num_workers=args.dataloader_num_workers, + resolution=args.resolution, + interpolation_type=args.interpolation_type, + shuffle_buffer_size=1000, + pin_memory=True, + persistent_workers=True, + use_fix_crop_and_size=args.use_fix_crop_and_size, + use_image_conditioning=args.use_image_conditioning, + cond_resolution=args.cond_resolution, + cond_interpolation_type=args.cond_interpolation_type, + ) + train_dataloader = dataset.train_dataloader + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + + compute_embeddings_fn = functools.partial( + compute_embeddings, + proportion_empty_prompts=0, + text_encoders=text_encoders, + tokenizers=tokenizers, + ) + + if args.use_image_conditioning: + compute_image_embeddings_fn = functools.partial( + compute_image_embeddings, + image_encoder=image_encoder, + ) + + # 15. Create learning rate scheduler for generator and discriminator + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + discriminator_lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=discriminator_optimizer, + num_warmup_steps=args.lr_warmup_steps, + num_training_steps=args.max_train_steps, + ) + + # 16. Prepare for training + # Prepare everything with our `accelerator`. + ( + unet, + discriminator, + optimizer, + discriminator_optimizer, + lr_scheduler, + discriminator_lr_scheduler, + ) = accelerator.prepare( + unet, + discriminator, + optimizer, + discriminator_optimizer, + lr_scheduler, + discriminator_lr_scheduler, + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # 17. Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num batches each epoch = {train_dataloader.num_batches}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + # 1. Load and process the image, text, and micro-conditioning (original image size, crop coordinates) + if args.use_image_conditioning: + image, text, orig_size, crop_coords, cond_image = batch + else: + image, text, orig_size, crop_coords = batch + + image = image.to(accelerator.device, non_blocking=True) + encoded_text = compute_embeddings_fn(text, orig_size, crop_coords) + if args.use_image_conditioning: + encoded_image = compute_image_embeddings_fn(cond_image) + + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = image.to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) + else: + pixel_values = image + + # encode pixel values with batch size of at most args.vae_encode_batch_size + latents = [] + for i in range(0, pixel_values.shape[0], args.vae_encode_batch_size): + latents.append(vae.encode(pixel_values[i : i + args.vae_encode_batch_size]).latent_dist.sample()) + latents = torch.cat(latents, dim=0) + + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + bsz = latents.shape[0] + + # 2. Sample random student timesteps s uniformly in `student_timestep_schedule` and sample random + # teacher timesteps t uniformly in [0, ..., noise_scheduler.config.num_train_timesteps - 1]. + student_index = torch.randint(0, student_distillation_steps, (bsz,), device=latents.device).long() + student_timesteps = student_timestep_schedule[student_index] + teacher_timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + + # 3. Sample noise and add it to the latents according to the noise magnitude at each student timestep + # (that is, run the forward process on the student model) + student_noise = torch.randn_like(latents) + noisy_student_input = noise_scheduler.add_noise(latents, student_noise, student_timesteps) + + # 4. Prepare prompt embeds (for teacher/student U-Net) and text embedding (for discriminator). + prompt_embeds = encoded_text.pop("prompt_embeds") + text_embedding = encoded_text["text_embeds"] + image_embedding = None + if args.use_image_conditioning: + image_embedding = encoded_image.pop("image_embeds") + # Only supply image conditioning when student timestep is not last training timestep T. + image_embedding = torch.where( + student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1, + image_embedding, + torch.zeros_like(image_embedding), + ) + + # 5. Get the student model predicted original sample `student_x_0`. + student_noise_pred = unet( + noisy_student_input, + student_timesteps, + encoder_hidden_states=prompt_embeds.float(), + added_cond_kwargs=encoded_text, + ).sample + student_x_0 = denoiser(student_noise_pred, student_timesteps, noisy_student_input) + + # 6. Sample noise and add it to the student's predicted original sample according to the noise + # magnitude at each teacher timestep (that is, run the forward process on the teacher model, but + # using `student_x_0` instead of latents sampled from the prior). + teacher_noise = torch.randn_like(student_x_0) + noisy_teacher_input = noise_scheduler.add_noise(student_x_0, teacher_noise, teacher_timesteps) + + # 7. Sample random guidance scales w ~ U[w_min, w_max] for CFG. + w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min + w = w.reshape(bsz, 1, 1, 1) + # Move to U-Net device and dtype + w = w.to(device=latents.device, dtype=latents.dtype) + + # 8. Get teacher model predicted original sample `teacher_x_0`. + with torch.no_grad(), torch.autocast("cuda", dtype=weight_dtype): + teacher_cond_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=prompt_embeds, + added_cond_kwargs=encoded_text, + ).sample + + uncond_prompt_embeds = torch.zeros_like(prompt_embeds) + uncond_pooled_prompt_embeds = torch.zeros_like(encoded_text["text_embeds"]) + uncond_added_conditions = copy.deepcopy(encoded_text) + uncond_added_conditions["text_embeds"] = uncond_pooled_prompt_embeds + teacher_uncond_noise_pred = teacher_unet( + noisy_teacher_input.detach(), + teacher_timesteps, + encoder_hidden_states=uncond_prompt_embeds, + added_cond_kwargs=uncond_added_conditions, + ).sample + + # Get the teacher's CFG estimate of x_0. + teacher_cfg_noise_pred = w * teacher_cond_noise_pred + (1 - w) * teacher_uncond_noise_pred + teacher_x_0 = denoiser(teacher_cfg_noise_pred, teacher_timesteps, noisy_teacher_input) + + ############################ + # 9. Discriminator Loss + ############################ + discriminator_optimizer.zero_grad(set_to_none=True) + + # 1. Decode real and fake (generated) latents back to pixel space. + # NOTE: the paper doesn't mention this explicitly AFAIK but I think this makes sense since the + # pretrained feature network for the discriminator operates in pixel space rather than latent space. + unscaled_student_x_0 = (1 / vae.config.scaling_factor) * student_x_0 + student_gen_image = [] + # Perform batched decode with batch size of at most args.vae_encode_batch_size + for i in range(0, unscaled_student_x_0.shape[0], args.vae_encode_batch_size): + if args.pretrained_vae_model_name_or_path: + student_gen_image.append( + vae.decode( + unscaled_student_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype) + ).sample + ) + else: + # VAE is in full precision due to possible NaN issues + student_gen_image.append( + vae.decode(unscaled_student_x_0[i : i + args.vae_encode_batch_size]).sample + ) + student_gen_image = torch.cat(student_gen_image, dim=0) + + # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. + disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) + + # 3. Calculate the discriminator real adversarial loss terms. + d_logits_real = disc_output_real.logits + # Use hinge loss (see section 3.2, Equation 3 of paper) + d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real)) + + # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real + # data. + d_r1_regularizer = 0 + for k, head in discriminator.heads.items(): + head_grad_params = torch.autograd.grad( + outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True + ) + head_grad_norm = 0 + for grad in head_grad_params: + head_grad_norm += grad.abs().sum() + d_r1_regularizer += head_grad_norm + + d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer + accelerator.backward(d_loss_real, retain_graph=True) + + # 5. Calculate the discriminator fake adversarial loss terms. + d_logits_fake = disc_output_fake.logits + # Use hinge loss (see section 3.2, Equation 3 of paper) + d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake)) + accelerator.backward(d_adv_loss_fake) + + d_total_loss = d_loss_real + d_adv_loss_fake + + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(discriminator.parameters(), args.max_grad_norm) + discriminator_optimizer.step() + discriminator_lr_scheduler.step() + + ############################ + # 10. Generator Loss + ############################ + optimizer.zero_grad(set_to_none=True) + + # 1. Rerun the disc on generated image, but this time allow gradients to flow through the generator + disc_output_fake = discriminator(student_gen_image, text_embedding, image_embedding) + + # 2. Calculate generator adversarial loss term + g_logits_fake = disc_output_fake.logits + g_adv_loss = torch.mean(-g_logits_fake) + + ############################ + # 11. Distillation Loss + ############################ + # Calculate distillation loss in pixel space rather than latent space (see section 3.1) + unscaled_teacher_x_0 = (1 / vae.config.scaling_factor) * teacher_x_0 + teacher_gen_image = [] + # Perform batched decode with batch size of at most args.vae_encode_batch_size + for i in range(0, unscaled_teacher_x_0.shape[0], args.vae_encode_batch_size): + if args.pretrained_vae_model_name_or_path: + teacher_gen_image.append( + vae.decode( + unscaled_teacher_x_0[i : i + args.vae_encode_batch_size].to(dtype=weight_dtype) + ).sample + ) + else: + # VAE is in full precision due to possible NaN issues + teacher_gen_image.append( + vae.decode(unscaled_teacher_x_0[i : i + args.vae_encode_batch_size]).sample + ) + teacher_gen_image = torch.cat(teacher_gen_image, dim=0) + + per_instance_distillation_loss = F.mse_loss( + student_gen_image.float(), teacher_gen_image.float(), reduction="none" + ) + # Note that we use the teacher timesteps t when getting the loss weights. + c_t = extract_into_tensor( + train_weight_schedule, teacher_timesteps, per_instance_distillation_loss.shape + ) + g_distillation_loss = torch.mean(c_t * per_instance_distillation_loss) + + g_total_loss = g_adv_loss + args.distillation_weight_factor * g_distillation_loss + + # Backprop on the generator total loss + accelerator.backward(g_total_loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + # 12. Perform an EMA update on the EMA version of the student U-Net weights. + if args.use_ema: + ema_unet.step(unet.parameters()) + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % args.validation_steps == 0: + if args.use_ema: + # Store the student unet weights and load the EMA weights. + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "ema_student") + + # Restore student unet weights + ema_unet.restore(unet.parameters()) + + log_validation(vae, unet, args, accelerator, weight_dtype, global_step, "student") + + logs = { + "d_total_loss": d_total_loss.detach().item(), + "g_total_loss": g_total_loss.detach().item(), + "g_adv_loss": g_adv_loss.detach().item(), + "g_distill_loss": g_distillation_loss.detach().item(), + "lr": lr_scheduler.get_last_lr()[0], + } + progress_bar.set_postfix(**logs) + + # Write out additional values for accelerator to report. + logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item() + logs["d_adv_loss_real"] = d_adv_loss_real.detach().item() + logs["d_r1_regularizer"] = d_r1_regularizer.detach().item() + logs["d_loss_real"] = d_loss_real.detach().item() + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + unet.save_pretrained(os.path.join(args.output_dir, "unet")) + + lora_state_dict = get_peft_model_state_dict(unet, adapter_name="default") + StableDiffusionXLPipeline.save_lora_weights(os.path.join(args.output_dir, "unet_lora"), lora_state_dict) + + # If using EMA, save EMA weights as well. + if args.use_ema: + ema_unet.store(unet.parameters()) + ema_unet.copy_to(unet.parameters()) + + unet.save_pretrained(os.path.join(args.output_dir, "ema_unet")) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From c8c1e83e4e5235bfb8026dc176cb54da20f8ad5d Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 6 Jan 2024 18:47:24 -0800 Subject: [PATCH 39/53] make style --- examples/add/train_add_distill_lora_sd_wds.py | 7 ++++--- examples/add/train_add_distill_lora_sdxl_wds.py | 9 +++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index cadb9c2de64d..fbb6ffdc9d56 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -1565,7 +1565,7 @@ def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: if args.use_ema: ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) - + unet_ = accelerator.unwrap_model(unet) lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default") StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict) @@ -1584,14 +1584,15 @@ def load_model_hook(models, input_dir): ema_unet.load_state_dict(load_model.state_dict()) ema_unet.to(accelerator.device) del load_model - + # load the LoRA into the model unet_ = accelerator.unwrap_model(unet) unet_.load_adapter(input_dir, "default", is_trainable=True) for i in range(len(models)): # pop models so that they are not loaded again - model = models.pop() + # model = models.pop() + models.pop() # load diffusers style into model # load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index 672b3ee94bf8..8751c3bf9077 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -1551,7 +1551,7 @@ def main(args): raise ValueError( f"Controlnet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" ) - + # 9. Add LoRA to the student U-Net, only the LoRA projection matrix will be updated by the optimizer. if args.lora_target_modules is not None: lora_target_modules = [module_key.strip() for module_key in args.lora_target_modules.split(",")] @@ -1617,7 +1617,7 @@ def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: if args.use_ema: ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) - + unet_ = accelerator.unwrap_model(unet) lora_state_dict = get_peft_model_state_dict(unet_, adapter_name="default") StableDiffusionXLPipeline.save_lora_weights(os.path.join(output_dir, "unet_lora"), lora_state_dict) @@ -1636,14 +1636,15 @@ def load_model_hook(models, input_dir): ema_unet.load_state_dict(load_model.state_dict()) ema_unet.to(accelerator.device) del load_model - + # load the LoRA into the model unet_ = accelerator.unwrap_model(unet) unet_.load_adapter(input_dir, "default", is_trainable=True) for i in range(len(models)): # pop models so that they are not loaded again - model = models.pop() + # model = models.pop() + models.pop() # load diffusers style into model # load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") From b64b700e558375dafb6af81ac5fa8f51d73f5efe Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 6 Jan 2024 19:07:58 -0800 Subject: [PATCH 40/53] Use diffusers.training_utils.resolve_interpolation_mode after PR #6420 is merged. --- examples/add/train_add_distill_lora_sd_wds.py | 22 ++----------------- .../add/train_add_distill_lora_sdxl_wds.py | 22 ++----------------- examples/add/train_add_distill_sd_wds.py | 22 ++----------------- examples/add/train_add_distill_sdxl_wds.py | 22 ++----------------- 4 files changed, 8 insertions(+), 80 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index fbb6ffdc9d56..ee5e981c48bd 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -64,7 +64,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel +from diffusers.training_utils import EMAModel, resolve_interpolation_mode from diffusers.utils import BaseOutput, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -140,24 +140,6 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples -def resolve_interpolation_mode(interpolation_type): - if interpolation_type == "bilinear": - interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) - - return interpolation_mode - - class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -893,7 +875,7 @@ def parse_args(): default="bilinear", help=( "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," - " `bicubic`, `lanczos`, and `nearest`." + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." ), ) parser.add_argument( diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index 8751c3bf9077..f25316a6664f 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -64,7 +64,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel +from diffusers.training_utils import EMAModel, resolve_interpolation_mode from diffusers.utils import BaseOutput, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -145,24 +145,6 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples -def resolve_interpolation_mode(interpolation_type): - if interpolation_type == "bilinear": - interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) - - return interpolation_mode - - class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -930,7 +912,7 @@ def parse_args(): default="bilinear", help=( "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," - " `bicubic`, `lanczos`, and `nearest`." + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." ), ) parser.add_argument( diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index fc0123f2a6af..60e7f4fc17c9 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -63,7 +63,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel +from diffusers.training_utils import EMAModel, resolve_interpolation_mode from diffusers.utils import BaseOutput, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -122,24 +122,6 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples -def resolve_interpolation_mode(interpolation_type): - if interpolation_type == "bilinear": - interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) - - return interpolation_mode - - class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -871,7 +853,7 @@ def parse_args(): default="bilinear", help=( "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," - " `bicubic`, `lanczos`, and `nearest`." + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." ), ) parser.add_argument( diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 02a3f69ba36a..a3d51cc453b5 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -63,7 +63,7 @@ UNet2DConditionModel, ) from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel +from diffusers.training_utils import EMAModel, resolve_interpolation_mode from diffusers.utils import BaseOutput, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -127,24 +127,6 @@ def tarfile_to_samples_nothrow(src, handler=wds.warn_and_continue): return samples -def resolve_interpolation_mode(interpolation_type): - if interpolation_type == "bilinear": - interpolation_mode = TF.InterpolationMode.BILINEAR - elif interpolation_type == "bicubic": - interpolation_mode = TF.InterpolationMode.BICUBIC - elif interpolation_type == "nearest": - interpolation_mode = TF.InterpolationMode.NEAREST - elif interpolation_type == "lanczos": - interpolation_mode = TF.InterpolationMode.LANCZOS - else: - raise ValueError( - f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation" - f" modes are `bilinear`, `bicubic`, `lanczos`, and `nearest`." - ) - - return interpolation_mode - - class WebdatasetFilter: def __init__(self, min_size=1024, max_pwatermark=0.5): self.min_size = min_size @@ -908,7 +890,7 @@ def parse_args(): default="bilinear", help=( "The interpolation function used when resizing images to the desired resolution. Choose between `bilinear`," - " `bicubic`, `lanczos`, and `nearest`." + " `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`." ), ) parser.add_argument( From bb0619fef98b071579b315ca74bb0301d9d58186 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 7 Jan 2024 14:40:04 -0800 Subject: [PATCH 41/53] Add allow_nonzero_terminal_snr argument to disable enforcing zero terminal SNR. --- examples/add/train_add_distill_lora_sd_wds.py | 11 ++++++++++- examples/add/train_add_distill_lora_sdxl_wds.py | 11 ++++++++++- examples/add/train_add_distill_sd_wds.py | 11 ++++++++++- examples/add/train_add_distill_sdxl_wds.py | 11 ++++++++++- 4 files changed, 40 insertions(+), 4 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index ee5e981c48bd..b68c100468a9 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -1133,6 +1133,15 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." ), ) + parser.add_argument( + "--allow_nonzero_terminal_snr", + action="store_true", + help=( + "Option to turn off enforcing zero terminal SNR. The ADD paper states that they enforce zero terminal SNR" + " during training, but this may lead to numerical instability issues during training at the last training" + " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." + ), + ) # LoRA Arguments parser.add_argument( "--lora_rank", @@ -1347,7 +1356,7 @@ def main(args): teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if not teacher_scheduler.config.rescale_betas_zero_snr: + if not teacher_scheduler.config.rescale_betas_zero_snr and not args.allow_nonzero_terminal_snr: teacher_scheduler.config["rescale_betas_zero_snr"] = True noise_scheduler = DDPMScheduler(**teacher_scheduler.config) diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index f25316a6664f..bb6f82b6e2f7 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -1158,6 +1158,15 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." ), ) + parser.add_argument( + "--allow_nonzero_terminal_snr", + action="store_true", + help=( + "Option to turn off enforcing zero terminal SNR. The ADD paper states that they enforce zero terminal SNR" + " during training, but this may lead to numerical instability issues during training at the last training" + " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." + ), + ) # LoRA Arguments parser.add_argument( "--lora_rank", @@ -1380,7 +1389,7 @@ def main(args): teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if not teacher_scheduler.config.rescale_betas_zero_snr: + if not teacher_scheduler.config.rescale_betas_zero_snr and not args.allow_nonzero_terminal_snr: teacher_scheduler.config["rescale_betas_zero_snr"] = True noise_scheduler = DDPMScheduler(**teacher_scheduler.config) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 60e7f4fc17c9..b397b41bc51e 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1111,6 +1111,15 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." ), ) + parser.add_argument( + "--allow_nonzero_terminal_snr", + action="store_true", + help=( + "Option to turn off enforcing zero terminal SNR. The ADD paper states that they enforce zero terminal SNR" + " during training, but this may lead to numerical instability issues during training at the last training" + " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." @@ -1294,7 +1303,7 @@ def main(args): teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if not teacher_scheduler.config.rescale_betas_zero_snr: + if not teacher_scheduler.config.rescale_betas_zero_snr and not args.allow_nonzero_terminal_snr: teacher_scheduler.config["rescale_betas_zero_snr"] = True noise_scheduler = DDPMScheduler(**teacher_scheduler.config) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index a3d51cc453b5..251fcc106deb 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1136,6 +1136,15 @@ def parse_args(): " Encoding or decoding the whole batch at once may run into OOM issues." ), ) + parser.add_argument( + "--allow_nonzero_terminal_snr", + action="store_true", + help=( + "Option to turn off enforcing zero terminal SNR. The ADD paper states that they enforce zero terminal SNR" + " during training, but this may lead to numerical instability issues during training at the last training" + " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." @@ -1327,7 +1336,7 @@ def main(args): teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if not teacher_scheduler.config.rescale_betas_zero_snr: + if not teacher_scheduler.config.rescale_betas_zero_snr and not args.allow_nonzero_terminal_snr: teacher_scheduler.config["rescale_betas_zero_snr"] = True noise_scheduler = DDPMScheduler(**teacher_scheduler.config) From 5cfd13720d84194b415242c0bc219a5c0e7b659d Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 7 Jan 2024 16:20:12 -0800 Subject: [PATCH 42/53] Use ConfigMixin.from_config with rescale_betas_zero_snr kwarg when creating noise scheduler. --- examples/add/train_add_distill_lora_sd_wds.py | 5 ++--- examples/add/train_add_distill_lora_sdxl_wds.py | 5 ++--- examples/add/train_add_distill_sd_wds.py | 5 ++--- examples/add/train_add_distill_sdxl_wds.py | 5 ++--- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index b68c100468a9..1ff40ac595a6 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -1356,9 +1356,8 @@ def main(args): teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if not teacher_scheduler.config.rescale_betas_zero_snr and not args.allow_nonzero_terminal_snr: - teacher_scheduler.config["rescale_betas_zero_snr"] = True - noise_scheduler = DDPMScheduler(**teacher_scheduler.config) + enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True + noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index bb6f82b6e2f7..d8d528badfab 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -1389,9 +1389,8 @@ def main(args): teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if not teacher_scheduler.config.rescale_betas_zero_snr and not args.allow_nonzero_terminal_snr: - teacher_scheduler.config["rescale_betas_zero_snr"] = True - noise_scheduler = DDPMScheduler(**teacher_scheduler.config) + enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True + noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index b397b41bc51e..32bb432f43f5 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1303,9 +1303,8 @@ def main(args): teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if not teacher_scheduler.config.rescale_betas_zero_snr and not args.allow_nonzero_terminal_snr: - teacher_scheduler.config["rescale_betas_zero_snr"] = True - noise_scheduler = DDPMScheduler(**teacher_scheduler.config) + enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True + noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 251fcc106deb..97805eb3534d 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1336,9 +1336,8 @@ def main(args): teacher_scheduler = DDPMScheduler.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) - if not teacher_scheduler.config.rescale_betas_zero_snr and not args.allow_nonzero_terminal_snr: - teacher_scheduler.config["rescale_betas_zero_snr"] = True - noise_scheduler = DDPMScheduler(**teacher_scheduler.config) + enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True + noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps From acf5175c6f745910166b265976540dc3ec3542af Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 7 Jan 2024 16:33:40 -0800 Subject: [PATCH 43/53] Allow noise_scheduler to be configured between ddpm and euler. --- examples/add/train_add_distill_lora_sd_wds.py | 25 +++++++++++++++++-- .../add/train_add_distill_lora_sdxl_wds.py | 25 +++++++++++++++++-- examples/add/train_add_distill_sd_wds.py | 25 +++++++++++++++++-- examples/add/train_add_distill_sdxl_wds.py | 25 +++++++++++++++++-- 4 files changed, 92 insertions(+), 8 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index 1ff40ac595a6..188fb42ba192 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -60,6 +60,7 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, + EulerDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel, ) @@ -1142,6 +1143,15 @@ def parse_args(): " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." ), ) + parser.add_argument( + "--noise_scheduler_type", + type=str, + default="ddpm", + help=( + "The scheduler class to use for the noise scheduler during training. This affects how noise is added to" + " the latents (the forward process). Choose between `ddpm` and `euler`." + ), + ) # LoRA Arguments parser.add_argument( "--lora_rank", @@ -1353,11 +1363,22 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDPMScheduler.from_pretrained( + if args.noise_scheduler_type == "ddpm": + noise_scheduler_cls = DDPMScheduler + elif args.noise_scheduler_type == "euler": + noise_scheduler_cls = EulerDiscreteScheduler + else: + raise ValueError( + f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`" + f" and `euler`." + ) + teacher_scheduler = noise_scheduler_cls.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True - noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) + noise_scheduler = noise_scheduler_cls.from_config( + teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr + ) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index d8d528badfab..17310c31c95e 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -60,6 +60,7 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, + EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) @@ -1167,6 +1168,15 @@ def parse_args(): " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." ), ) + parser.add_argument( + "--noise_scheduler_type", + type=str, + default="ddpm", + help=( + "The scheduler class to use for the noise scheduler during training. This affects how noise is added to" + " the latents (the forward process). Choose between `ddpm` and `euler`." + ), + ) # LoRA Arguments parser.add_argument( "--lora_rank", @@ -1386,11 +1396,22 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDPMScheduler.from_pretrained( + if args.noise_scheduler_type == "ddpm": + noise_scheduler_cls = DDPMScheduler + elif args.noise_scheduler_type == "euler": + noise_scheduler_cls = EulerDiscreteScheduler + else: + raise ValueError( + f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`" + f" and `euler`." + ) + teacher_scheduler = noise_scheduler_cls.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True - noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) + noise_scheduler = noise_scheduler_cls.from_config( + teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr + ) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 32bb432f43f5..2bfb833e9ddf 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -59,6 +59,7 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, + EulerDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel, ) @@ -1120,6 +1121,15 @@ def parse_args(): " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." ), ) + parser.add_argument( + "--noise_scheduler_type", + type=str, + default="ddpm", + help=( + "The scheduler class to use for the noise scheduler during training. This affects how noise is added to" + " the latents (the forward process). Choose between `ddpm` and `euler`." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." @@ -1300,11 +1310,22 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDPMScheduler.from_pretrained( + if args.noise_scheduler_type == "ddpm": + noise_scheduler_cls = DDPMScheduler + elif args.noise_scheduler_type == "euler": + noise_scheduler_cls = EulerDiscreteScheduler + else: + raise ValueError( + f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`" + f" and `euler`." + ) + teacher_scheduler = noise_scheduler_cls.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True - noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) + noise_scheduler = noise_scheduler_cls.from_config( + teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr + ) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 97805eb3534d..95f7adf403fd 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -59,6 +59,7 @@ from diffusers import ( AutoencoderKL, DDPMScheduler, + EulerDiscreteScheduler, StableDiffusionXLPipeline, UNet2DConditionModel, ) @@ -1145,6 +1146,15 @@ def parse_args(): " timestep T (`noise_scheduler.config.num_train_timesteps - 1`)." ), ) + parser.add_argument( + "--noise_scheduler_type", + type=str, + default="ddpm", + help=( + "The scheduler class to use for the noise scheduler during training. This affects how noise is added to" + " the latents (the forward process). Choose between `ddpm` and `euler`." + ), + ) # ----Exponential Moving Average (EMA)---- parser.add_argument( "--use_ema", action="store_true", help="Whether to also maintain an EMA version of the student U-Net weights." @@ -1333,11 +1343,22 @@ def main(args): # 1. Create the noise scheduler and the desired noise schedule. # Enforce zero terminal SNR (see section 3.1 of ADD paper) - teacher_scheduler = DDPMScheduler.from_pretrained( + if args.noise_scheduler_type == "ddpm": + noise_scheduler_cls = DDPMScheduler + elif args.noise_scheduler_type == "euler": + noise_scheduler_cls = EulerDiscreteScheduler + else: + raise ValueError( + f"Noise scheduler type {args.noise_scheduler_type} is not supported. Supported scheduler types are `ddpm`" + f" and `euler`." + ) + teacher_scheduler = noise_scheduler_cls.from_pretrained( args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision ) enforce_zero_snr = teacher_scheduler.config.rescale_betas_zero_snr if args.allow_nonzero_terminal_snr else True - noise_scheduler = DDPMScheduler.from_config(teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr) + noise_scheduler = noise_scheduler_cls.from_config( + teacher_scheduler.config, rescale_betas_zero_snr=enforce_zero_snr + ) # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps From cd825657b27a26547a10d94102fb2ec487276f63 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 14 Jan 2024 03:49:03 -0800 Subject: [PATCH 44/53] Return features from Discriminator and use L2 gradient penalty (haven't fixed autograd call yet). --- examples/add/train_add_distill_lora_sd_wds.py | 10 ++++++---- examples/add/train_add_distill_lora_sdxl_wds.py | 10 ++++++---- examples/add/train_add_distill_sd_wds.py | 10 ++++++---- examples/add/train_add_distill_sdxl_wds.py | 10 ++++++---- 4 files changed, 24 insertions(+), 16 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index 188fb42ba192..e9675933553b 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -26,7 +26,7 @@ import shutil import types from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import accelerate import numpy as np @@ -573,6 +573,7 @@ class DiscriminatorOutput(BaseOutput): """ logits: torch.FloatTensor + features: Optional[Dict[str, torch.FloatTensor]] = None # Based on ProjectedDiscriminator from the official StyleGAN-T code @@ -651,7 +652,7 @@ def forward( if not return_dict: return (logits,) - return DiscriminatorOutput(logits=logits) + return DiscriminatorOutput(logits=logits, features=features) def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): @@ -1951,7 +1952,8 @@ def compute_image_embeddings(image_batch, image_encoder): ) head_grad_norm = 0 for grad in head_grad_params: - head_grad_norm += grad.abs().sum() + head_grad_norm += grad.pow(2).sum() + head_grad_norm = head_grad_norm.sqrt() d_r1_regularizer += head_grad_norm d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer @@ -2074,7 +2076,7 @@ def compute_image_embeddings(image_batch, image_encoder): # Write out additional values for accelerator to report. logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item() logs["d_adv_loss_real"] = d_adv_loss_real.detach().item() - logs["d_r1_regularizer"] = d_r1_regularizer.detach().item() + logs["d_r1_penalty_scaled"] = args.discriminator_r1_strength * d_r1_regularizer.detach().item() logs["d_loss_real"] = d_loss_real.detach().item() accelerator.log(logs, step=global_step) diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index 17310c31c95e..16a1b7d6dcb7 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -26,7 +26,7 @@ import shutil import types from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import accelerate import numpy as np @@ -590,6 +590,7 @@ class DiscriminatorOutput(BaseOutput): """ logits: torch.FloatTensor + features: Optional[Dict[str, torch.FloatTensor]] = None # Based on ProjectedDiscriminator from the official StyleGAN-T code @@ -668,7 +669,7 @@ def forward( if not return_dict: return (logits,) - return DiscriminatorOutput(logits=logits) + return DiscriminatorOutput(logits=logits, features=features) def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): @@ -2044,7 +2045,8 @@ def compute_image_embeddings(image_batch, image_encoder): ) head_grad_norm = 0 for grad in head_grad_params: - head_grad_norm += grad.abs().sum() + head_grad_norm += grad.pow(2).sum() + head_grad_norm = head_grad_norm.sqrt() d_r1_regularizer += head_grad_norm d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer @@ -2173,7 +2175,7 @@ def compute_image_embeddings(image_batch, image_encoder): # Write out additional values for accelerator to report. logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item() logs["d_adv_loss_real"] = d_adv_loss_real.detach().item() - logs["d_r1_regularizer"] = d_r1_regularizer.detach().item() + logs["d_r1_penalty_scaled"] = args.discriminator_r1_strength * d_r1_regularizer.detach().item() logs["d_loss_real"] = d_loss_real.detach().item() accelerator.log(logs, step=global_step) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 2bfb833e9ddf..6cb39e1dc19e 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -26,7 +26,7 @@ import shutil import types from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import accelerate import numpy as np @@ -555,6 +555,7 @@ class DiscriminatorOutput(BaseOutput): """ logits: torch.FloatTensor + features: Optional[Dict[str, torch.FloatTensor]] = None # Based on ProjectedDiscriminator from the official StyleGAN-T code @@ -633,7 +634,7 @@ def forward( if not return_dict: return (logits,) - return DiscriminatorOutput(logits=logits) + return DiscriminatorOutput(logits=logits, features=features) def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): @@ -1859,7 +1860,8 @@ def compute_image_embeddings(image_batch, image_encoder): ) head_grad_norm = 0 for grad in head_grad_params: - head_grad_norm += grad.abs().sum() + head_grad_norm += grad.pow(2).sum() + head_grad_norm = head_grad_norm.sqrt() d_r1_regularizer += head_grad_norm d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer @@ -1982,7 +1984,7 @@ def compute_image_embeddings(image_batch, image_encoder): # Write out additional values for accelerator to report. logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item() logs["d_adv_loss_real"] = d_adv_loss_real.detach().item() - logs["d_r1_regularizer"] = d_r1_regularizer.detach().item() + logs["d_r1_penalty_scaled"] = args.discriminator_r1_strength * d_r1_regularizer.detach().item() logs["d_loss_real"] = d_loss_real.detach().item() accelerator.log(logs, step=global_step) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 95f7adf403fd..d6029618b640 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -26,7 +26,7 @@ import shutil import types from pathlib import Path -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import accelerate import numpy as np @@ -572,6 +572,7 @@ class DiscriminatorOutput(BaseOutput): """ logits: torch.FloatTensor + features: Optional[Dict[str, torch.FloatTensor]] = None # Based on ProjectedDiscriminator from the official StyleGAN-T code @@ -650,7 +651,7 @@ def forward( if not return_dict: return (logits,) - return DiscriminatorOutput(logits=logits) + return DiscriminatorOutput(logits=logits, features=features) def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): @@ -1952,7 +1953,8 @@ def compute_image_embeddings(image_batch, image_encoder): ) head_grad_norm = 0 for grad in head_grad_params: - head_grad_norm += grad.abs().sum() + head_grad_norm += grad.pow(2).sum() + head_grad_norm = head_grad_norm.sqrt() d_r1_regularizer += head_grad_norm d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer @@ -2081,7 +2083,7 @@ def compute_image_embeddings(image_batch, image_encoder): # Write out additional values for accelerator to report. logs["d_adv_loss_fake"] = d_adv_loss_fake.detach().item() logs["d_adv_loss_real"] = d_adv_loss_real.detach().item() - logs["d_r1_regularizer"] = d_r1_regularizer.detach().item() + logs["d_r1_penalty_scaled"] = args.discriminator_r1_strength * d_r1_regularizer.detach().item() logs["d_loss_real"] = d_loss_real.detach().item() accelerator.log(logs, step=global_step) From ab46142de967b21e031e2f7cdac82898e8a7e5f9 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 14 Jan 2024 17:09:48 -0800 Subject: [PATCH 45/53] [not sure if correct] Add tentative fixed implementation of discriminator R1 gradient penalty. --- examples/add/train_add_distill_sd_wds.py | 77 ++++++++++++++---------- 1 file changed, 46 insertions(+), 31 deletions(-) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 6cb39e1dc19e..c5b2245be777 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -609,26 +609,27 @@ def train(self, mode: bool = True): def eval(self): return self.train(False) - def forward( - self, - x: torch.Tensor, - c_text: torch.Tensor, - c_img: Optional[torch.Tensor] = None, - transform_positive: bool = True, - return_dict: bool = True, - ): - # TODO: do we need the augmentations from the original StyleGAN-T code? + def get_features(self, image: torch.Tensor, transform_positive: bool = True) -> Dict[str, torch.Tensor]: if transform_positive: # Transform to [0, 1]. - x = x.add(1).div(2) + image = image.add(1).div(2) # Forward pass through feature network. - features = self.feature_network(x) + features = self.feature_network(image) + return features + def forward_features( + self, + features: Dict[str, torch.Tensor], + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + batch_size = features["0"].size(0) # Apply discriminator heads. logits = [] for k, head in self.heads.items(): - logits.append(head(features[k], c_text, c_img).view(x.size(0), -1)) + logits.append(head(features[k], c_text, c_img).view(batch_size, -1)) logits = torch.cat(logits, dim=1) if not return_dict: @@ -636,6 +637,18 @@ def forward( return DiscriminatorOutput(logits=logits, features=features) + def forward( + self, + image: torch.Tensor, + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + transform_positive: bool = True, + return_dict: bool = True, + ): + features = self.get_features(image, transform_positive=transform_positive) + d_output = self.forward_features(features, c_text, c_img=c_img, return_dict=return_dict) + return d_output + def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): logger.info("Running validation... ") @@ -1777,7 +1790,7 @@ def compute_image_embeddings(image_batch, image_encoder): text_embedding = encoded_text.pop("text_embedding") image_embedding = None if args.use_image_conditioning: - image_embedding = encoded_image.pop("image_embeds") + image_embedding = encoded_image.pop("image_embeds").float() # Only supply image conditioning when student timestep is not last training timestep T. image_embedding = torch.where( student_timesteps.unsqueeze(1) < noise_scheduler.config.num_train_timesteps - 1, @@ -1842,32 +1855,34 @@ def compute_image_embeddings(image_batch, image_encoder): ) student_gen_image = torch.cat(student_gen_image, dim=0) - # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) - disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) - - # 3. Calculate the discriminator real adversarial loss terms. - d_logits_real = disc_output_real.logits + # 2. Calculate the discriminator real adversarial loss terms. + features_real = discriminator.get_features(pixel_values.float()) + for k, feature in features_real.items(): + # Required so that the torch.autograd.grad call below works properly? + feature.requires_grad_(True) + d_logits_real = discriminator.forward_features( + features_real, text_embedding.float(), image_embedding, return_dict=False + )[0] # Use hinge loss (see section 3.2, Equation 3 of paper) d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real)) - # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real - # data. + # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the + # discriminator head input features from the real data. d_r1_regularizer = 0 - for k, head in discriminator.heads.items(): - head_grad_params = torch.autograd.grad( - outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True - ) - head_grad_norm = 0 - for grad in head_grad_params: - head_grad_norm += grad.pow(2).sum() - head_grad_norm = head_grad_norm.sqrt() - d_r1_regularizer += head_grad_norm + grad_params = torch.autograd.grad( + outputs=d_adv_loss_real, + inputs=features_real.values(), + create_graph=True, + ) + for grad in grad_params: + d_r1_regularizer += grad.pow(2).sum() + d_r1_regularizer = d_r1_regularizer.sqrt() d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer accelerator.backward(d_loss_real, retain_graph=True) - # 5. Calculate the discriminator fake adversarial loss terms. + # 4. Calculate the discriminator fake adversarial loss terms. + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) d_logits_fake = disc_output_fake.logits # Use hinge loss (see section 3.2, Equation 3 of paper) d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake)) From 51213422d31aa8c9ce987620cd1fbe884b4fa81b Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 14 Jan 2024 18:17:25 -0800 Subject: [PATCH 46/53] Fix checkpointing bug where discriminator was not being properly saved/loaded. --- examples/add/train_add_distill_lora_sd_wds.py | 20 ++++++++++++------- .../add/train_add_distill_lora_sdxl_wds.py | 20 ++++++++++++------- examples/add/train_add_distill_sd_wds.py | 15 +++++++++++--- examples/add/train_add_distill_sdxl_wds.py | 15 +++++++++++--- 4 files changed, 50 insertions(+), 20 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index e9675933553b..bbb951d71f29 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -44,6 +44,7 @@ from huggingface_hub import create_repo from packaging import version from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from peft.peft_model import PeftModel from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torch.utils.data import default_collate from torchvision import transforms @@ -64,6 +65,8 @@ StableDiffusionPipeline, UNet2DConditionModel, ) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, resolve_interpolation_mode from diffusers.utils import BaseOutput, check_min_version, is_wandb_available @@ -578,11 +581,12 @@ class DiscriminatorOutput(BaseOutput): # Based on ProjectedDiscriminator from the official StyleGAN-T code # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 -class Discriminator(torch.nn.Module): +class Discriminator(ModelMixin, ConfigMixin): """ StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). """ + @register_to_config def __init__( self, pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", @@ -1586,6 +1590,8 @@ def save_model_hook(models, weights, output_dir): for i, model in enumerate(models): # model.save_pretrained(os.path.join(output_dir, "unet")) + if not isinstance(model, PeftModel): + model.save_pretrained(os.path.join(output_dir, "discriminator")) # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -1603,15 +1609,15 @@ def load_model_hook(models, input_dir): for i in range(len(models)): # pop models so that they are not loaded again - # model = models.pop() - models.pop() + model = models.pop() # load diffusers style into model - # load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - # model.register_to_config(**load_model.config) + if not isinstance(model, PeftModel): + load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator") + model.register_to_config(**load_model.config) - # model.load_state_dict(load_model.state_dict()) - # del load_model + model.load_state_dict(load_model.state_dict()) + del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index 16a1b7d6dcb7..b14a813f51e0 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -44,6 +44,7 @@ from huggingface_hub import create_repo from packaging import version from peft import LoraConfig, get_peft_model, get_peft_model_state_dict +from peft.peft_model import PeftModel from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torch.utils.data import default_collate from torchvision import transforms @@ -64,6 +65,8 @@ StableDiffusionXLPipeline, UNet2DConditionModel, ) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, resolve_interpolation_mode from diffusers.utils import BaseOutput, check_min_version, is_wandb_available @@ -595,11 +598,12 @@ class DiscriminatorOutput(BaseOutput): # Based on ProjectedDiscriminator from the official StyleGAN-T code # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 -class Discriminator(torch.nn.Module): +class Discriminator(ModelMixin, ConfigMixin): """ StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). """ + @register_to_config def __init__( self, pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", @@ -1638,6 +1642,8 @@ def save_model_hook(models, weights, output_dir): for i, model in enumerate(models): # model.save_pretrained(os.path.join(output_dir, "unet")) + if not isinstance(model, PeftModel): + model.save_pretrained(os.path.join(output_dir, "discriminator")) # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -1655,15 +1661,15 @@ def load_model_hook(models, input_dir): for i in range(len(models)): # pop models so that they are not loaded again - # model = models.pop() - models.pop() + model = models.pop() # load diffusers style into model - # load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - # model.register_to_config(**load_model.config) + if not isinstance(model, PeftModel): + load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator") + model.register_to_config(**load_model.config) - # model.load_state_dict(load_model.state_dict()) - # del load_model + model.load_state_dict(load_model.state_dict()) + del load_model accelerator.register_save_state_pre_hook(save_model_hook) accelerator.register_load_state_pre_hook(load_model_hook) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index c5b2245be777..039e7e97673d 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -63,6 +63,8 @@ StableDiffusionPipeline, UNet2DConditionModel, ) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, resolve_interpolation_mode from diffusers.utils import BaseOutput, check_min_version, is_wandb_available @@ -560,11 +562,12 @@ class DiscriminatorOutput(BaseOutput): # Based on ProjectedDiscriminator from the official StyleGAN-T code # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 -class Discriminator(torch.nn.Module): +class Discriminator(ModelMixin, ConfigMixin): """ StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). """ + @register_to_config def __init__( self, pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", @@ -1511,7 +1514,10 @@ def save_model_hook(models, weights, output_dir): ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) for i, model in enumerate(models): - model.save_pretrained(os.path.join(output_dir, "unet")) + if isinstance(model, UNet2DConditionModel): + model.save_pretrained(os.path.join(output_dir, "unet")) + else: + model.save_pretrained(os.path.join(output_dir, "discriminator")) # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -1528,7 +1534,10 @@ def load_model_hook(models, input_dir): model = models.pop() # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + if isinstance(model, UNet2DConditionModel): + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + else: + load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index d6029618b640..bb9affe0949f 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -63,6 +63,8 @@ StableDiffusionXLPipeline, UNet2DConditionModel, ) +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel, resolve_interpolation_mode from diffusers.utils import BaseOutput, check_min_version, is_wandb_available @@ -577,11 +579,12 @@ class DiscriminatorOutput(BaseOutput): # Based on ProjectedDiscriminator from the official StyleGAN-T code # https://github.com/autonomousvision/stylegan-t/blob/36ab80ce76237fefe03e65e9b3161c040ae888e3/networks/discriminator.py#L130 -class Discriminator(torch.nn.Module): +class Discriminator(ModelMixin, ConfigMixin): """ StyleGAN-T-style discriminator for adversarial diffusion distillation (ADD). """ + @register_to_config def __init__( self, pretrained_feature_network: str = "vit_small_patch14_dinov2.lvd142m", @@ -1550,7 +1553,10 @@ def save_model_hook(models, weights, output_dir): ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) for i, model in enumerate(models): - model.save_pretrained(os.path.join(output_dir, "unet")) + if isinstance(model, UNet2DConditionModel): + model.save_pretrained(os.path.join(output_dir, "unet")) + else: + model.save_pretrained(os.path.join(output_dir, "discriminator")) # make sure to pop weight so that corresponding model is not saved again weights.pop() @@ -1567,7 +1573,10 @@ def load_model_hook(models, input_dir): model = models.pop() # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + if isinstance(model, UNet2DConditionModel): + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + else: + load_model = Discriminator.from_pretrained(input_dir, subfolder="discriminator") model.register_to_config(**load_model.config) model.load_state_dict(load_model.state_dict()) From 286b8c5b9735464a5ebb74deedeaf15b2b9b81a7 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 14 Jan 2024 18:32:01 -0800 Subject: [PATCH 47/53] Propagate discriminator R1 penalty fix to other scripts. --- examples/add/train_add_distill_lora_sd_wds.py | 75 +++++++++++-------- .../add/train_add_distill_lora_sdxl_wds.py | 75 +++++++++++-------- examples/add/train_add_distill_sdxl_wds.py | 75 +++++++++++-------- 3 files changed, 135 insertions(+), 90 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index bbb951d71f29..8a6aa22502f2 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -631,26 +631,27 @@ def train(self, mode: bool = True): def eval(self): return self.train(False) - def forward( - self, - x: torch.Tensor, - c_text: torch.Tensor, - c_img: Optional[torch.Tensor] = None, - transform_positive: bool = True, - return_dict: bool = True, - ): - # TODO: do we need the augmentations from the original StyleGAN-T code? + def get_features(self, image: torch.Tensor, transform_positive: bool = True) -> Dict[str, torch.Tensor]: if transform_positive: # Transform to [0, 1]. - x = x.add(1).div(2) + image = image.add(1).div(2) # Forward pass through feature network. - features = self.feature_network(x) + features = self.feature_network(image) + return features + def forward_features( + self, + features: Dict[str, torch.Tensor], + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + batch_size = features["0"].size(0) # Apply discriminator heads. logits = [] for k, head in self.heads.items(): - logits.append(head(features[k], c_text, c_img).view(x.size(0), -1)) + logits.append(head(features[k], c_text, c_img).view(batch_size, -1)) logits = torch.cat(logits, dim=1) if not return_dict: @@ -658,6 +659,18 @@ def forward( return DiscriminatorOutput(logits=logits, features=features) + def forward( + self, + image: torch.Tensor, + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + transform_positive: bool = True, + return_dict: bool = True, + ): + features = self.get_features(image, transform_positive=transform_positive) + d_output = self.forward_features(features, c_text, c_img=c_img, return_dict=return_dict) + return d_output + def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): logger.info("Running validation... ") @@ -1940,32 +1953,34 @@ def compute_image_embeddings(image_batch, image_encoder): ) student_gen_image = torch.cat(student_gen_image, dim=0) - # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) - disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) - - # 3. Calculate the discriminator real adversarial loss terms. - d_logits_real = disc_output_real.logits + # 2. Calculate the discriminator real adversarial loss terms. + features_real = discriminator.get_features(pixel_values.float()) + for k, feature in features_real.items(): + # Required so that the torch.autograd.grad call below works properly? + feature.requires_grad_(True) + d_logits_real = discriminator.forward_features( + features_real, text_embedding.float(), image_embedding, return_dict=False + )[0] # Use hinge loss (see section 3.2, Equation 3 of paper) d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real)) - # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real - # data. + # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the + # discriminator head input features from the real data. d_r1_regularizer = 0 - for k, head in discriminator.heads.items(): - head_grad_params = torch.autograd.grad( - outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True - ) - head_grad_norm = 0 - for grad in head_grad_params: - head_grad_norm += grad.pow(2).sum() - head_grad_norm = head_grad_norm.sqrt() - d_r1_regularizer += head_grad_norm + grad_params = torch.autograd.grad( + outputs=d_adv_loss_real, + inputs=features_real.values(), + create_graph=True, + ) + for grad in grad_params: + d_r1_regularizer += grad.pow(2).sum() + d_r1_regularizer = d_r1_regularizer.sqrt() d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer accelerator.backward(d_loss_real, retain_graph=True) - # 5. Calculate the discriminator fake adversarial loss terms. + # 4. Calculate the discriminator fake adversarial loss terms. + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) d_logits_fake = disc_output_fake.logits # Use hinge loss (see section 3.2, Equation 3 of paper) d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake)) diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index b14a813f51e0..4f673c3beccb 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -648,26 +648,27 @@ def train(self, mode: bool = True): def eval(self): return self.train(False) - def forward( - self, - x: torch.Tensor, - c_text: torch.Tensor, - c_img: Optional[torch.Tensor] = None, - transform_positive: bool = True, - return_dict: bool = True, - ): - # TODO: do we need the augmentations from the original StyleGAN-T code? + def get_features(self, image: torch.Tensor, transform_positive: bool = True) -> Dict[str, torch.Tensor]: if transform_positive: # Transform to [0, 1]. - x = x.add(1).div(2) + image = image.add(1).div(2) # Forward pass through feature network. - features = self.feature_network(x) + features = self.feature_network(image) + return features + def forward_features( + self, + features: Dict[str, torch.Tensor], + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + batch_size = features["0"].size(0) # Apply discriminator heads. logits = [] for k, head in self.heads.items(): - logits.append(head(features[k], c_text, c_img).view(x.size(0), -1)) + logits.append(head(features[k], c_text, c_img).view(batch_size, -1)) logits = torch.cat(logits, dim=1) if not return_dict: @@ -675,6 +676,18 @@ def forward( return DiscriminatorOutput(logits=logits, features=features) + def forward( + self, + image: torch.Tensor, + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + transform_positive: bool = True, + return_dict: bool = True, + ): + features = self.get_features(image, transform_positive=transform_positive) + d_output = self.forward_features(features, c_text, c_img=c_img, return_dict=return_dict) + return d_output + def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): logger.info("Running validation... ") @@ -2033,32 +2046,34 @@ def compute_image_embeddings(image_batch, image_encoder): ) student_gen_image = torch.cat(student_gen_image, dim=0) - # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) - disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) - - # 3. Calculate the discriminator real adversarial loss terms. - d_logits_real = disc_output_real.logits + # 2. Calculate the discriminator real adversarial loss terms. + features_real = discriminator.get_features(pixel_values.float()) + for k, feature in features_real.items(): + # Required so that the torch.autograd.grad call below works properly? + feature.requires_grad_(True) + d_logits_real = discriminator.forward_features( + features_real, text_embedding.float(), image_embedding, return_dict=False + )[0] # Use hinge loss (see section 3.2, Equation 3 of paper) d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real)) - # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real - # data. + # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the + # discriminator head input features from the real data. d_r1_regularizer = 0 - for k, head in discriminator.heads.items(): - head_grad_params = torch.autograd.grad( - outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True - ) - head_grad_norm = 0 - for grad in head_grad_params: - head_grad_norm += grad.pow(2).sum() - head_grad_norm = head_grad_norm.sqrt() - d_r1_regularizer += head_grad_norm + grad_params = torch.autograd.grad( + outputs=d_adv_loss_real, + inputs=features_real.values(), + create_graph=True, + ) + for grad in grad_params: + d_r1_regularizer += grad.pow(2).sum() + d_r1_regularizer = d_r1_regularizer.sqrt() d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer accelerator.backward(d_loss_real, retain_graph=True) - # 5. Calculate the discriminator fake adversarial loss terms. + # 4. Calculate the discriminator fake adversarial loss terms. + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) d_logits_fake = disc_output_fake.logits # Use hinge loss (see section 3.2, Equation 3 of paper) d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake)) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index bb9affe0949f..e5dde8a1e2aa 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -629,26 +629,27 @@ def train(self, mode: bool = True): def eval(self): return self.train(False) - def forward( - self, - x: torch.Tensor, - c_text: torch.Tensor, - c_img: Optional[torch.Tensor] = None, - transform_positive: bool = True, - return_dict: bool = True, - ): - # TODO: do we need the augmentations from the original StyleGAN-T code? + def get_features(self, image: torch.Tensor, transform_positive: bool = True) -> Dict[str, torch.Tensor]: if transform_positive: # Transform to [0, 1]. - x = x.add(1).div(2) + image = image.add(1).div(2) # Forward pass through feature network. - features = self.feature_network(x) + features = self.feature_network(image) + return features + def forward_features( + self, + features: Dict[str, torch.Tensor], + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + batch_size = features["0"].size(0) # Apply discriminator heads. logits = [] for k, head in self.heads.items(): - logits.append(head(features[k], c_text, c_img).view(x.size(0), -1)) + logits.append(head(features[k], c_text, c_img).view(batch_size, -1)) logits = torch.cat(logits, dim=1) if not return_dict: @@ -656,6 +657,18 @@ def forward( return DiscriminatorOutput(logits=logits, features=features) + def forward( + self, + image: torch.Tensor, + c_text: torch.Tensor, + c_img: Optional[torch.Tensor] = None, + transform_positive: bool = True, + return_dict: bool = True, + ): + features = self.get_features(image, transform_positive=transform_positive) + d_output = self.forward_features(features, c_text, c_img=c_img, return_dict=return_dict) + return d_output + def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="student"): logger.info("Running validation... ") @@ -1944,32 +1957,34 @@ def compute_image_embeddings(image_batch, image_encoder): ) student_gen_image = torch.cat(student_gen_image, dim=0) - # 2. Get discriminator real/fake outputs on the real and fake (generated) images respectively. - disc_output_real = discriminator(pixel_values.float(), text_embedding, image_embedding) - disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) - - # 3. Calculate the discriminator real adversarial loss terms. - d_logits_real = disc_output_real.logits + # 2. Calculate the discriminator real adversarial loss terms. + features_real = discriminator.get_features(pixel_values.float()) + for k, feature in features_real.items(): + # Required so that the torch.autograd.grad call below works properly? + feature.requires_grad_(True) + d_logits_real = discriminator.forward_features( + features_real, text_embedding.float(), image_embedding, return_dict=False + )[0] # Use hinge loss (see section 3.2, Equation 3 of paper) d_adv_loss_real = torch.mean(F.relu(torch.ones_like(d_logits_real) - d_logits_real)) - # 4. Calculate the discriminator R1 gradient penalty term with respect to the gradients from the real - # data. + # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the + # discriminator head input features from the real data. d_r1_regularizer = 0 - for k, head in discriminator.heads.items(): - head_grad_params = torch.autograd.grad( - outputs=d_adv_loss_real, inputs=head.parameters(), create_graph=True - ) - head_grad_norm = 0 - for grad in head_grad_params: - head_grad_norm += grad.pow(2).sum() - head_grad_norm = head_grad_norm.sqrt() - d_r1_regularizer += head_grad_norm + grad_params = torch.autograd.grad( + outputs=d_adv_loss_real, + inputs=features_real.values(), + create_graph=True, + ) + for grad in grad_params: + d_r1_regularizer += grad.pow(2).sum() + d_r1_regularizer = d_r1_regularizer.sqrt() d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer accelerator.backward(d_loss_real, retain_graph=True) - # 5. Calculate the discriminator fake adversarial loss terms. + # 4. Calculate the discriminator fake adversarial loss terms. + disc_output_fake = discriminator(student_gen_image.detach().float(), text_embedding, image_embedding) d_logits_fake = disc_output_fake.logits # Use hinge loss (see section 3.2, Equation 3 of paper) d_adv_loss_fake = torch.mean(F.relu(torch.ones_like(d_logits_fake) + d_logits_fake)) From 686af03eb415953ab286a45a51844696a29dc55d Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sun, 14 Jan 2024 19:02:42 -0800 Subject: [PATCH 48/53] Fix R1 grad penalty L2 norm implementation. --- examples/add/train_add_distill_lora_sd_wds.py | 7 +++---- examples/add/train_add_distill_lora_sdxl_wds.py | 7 +++---- examples/add/train_add_distill_sd_wds.py | 7 +++---- examples/add/train_add_distill_sdxl_wds.py | 7 +++---- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index 8a6aa22502f2..47a8d9f5350e 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -1967,14 +1967,13 @@ def compute_image_embeddings(image_batch, image_encoder): # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the # discriminator head input features from the real data. d_r1_regularizer = 0 - grad_params = torch.autograd.grad( + feature_grads = torch.autograd.grad( outputs=d_adv_loss_real, inputs=features_real.values(), create_graph=True, ) - for grad in grad_params: - d_r1_regularizer += grad.pow(2).sum() - d_r1_regularizer = d_r1_regularizer.sqrt() + for grad in feature_grads: + d_r1_regularizer += torch.linalg.vector_norm(grad.view(grad.size(0), -1), dim=1).pow(2).mean() d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer accelerator.backward(d_loss_real, retain_graph=True) diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index 4f673c3beccb..d4269577161d 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -2060,14 +2060,13 @@ def compute_image_embeddings(image_batch, image_encoder): # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the # discriminator head input features from the real data. d_r1_regularizer = 0 - grad_params = torch.autograd.grad( + feature_grads = torch.autograd.grad( outputs=d_adv_loss_real, inputs=features_real.values(), create_graph=True, ) - for grad in grad_params: - d_r1_regularizer += grad.pow(2).sum() - d_r1_regularizer = d_r1_regularizer.sqrt() + for grad in feature_grads: + d_r1_regularizer += torch.linalg.vector_norm(grad.view(grad.size(0), -1), dim=1).pow(2).mean() d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer accelerator.backward(d_loss_real, retain_graph=True) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 039e7e97673d..1b4fa62ba3d4 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1878,14 +1878,13 @@ def compute_image_embeddings(image_batch, image_encoder): # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the # discriminator head input features from the real data. d_r1_regularizer = 0 - grad_params = torch.autograd.grad( + feature_grads = torch.autograd.grad( outputs=d_adv_loss_real, inputs=features_real.values(), create_graph=True, ) - for grad in grad_params: - d_r1_regularizer += grad.pow(2).sum() - d_r1_regularizer = d_r1_regularizer.sqrt() + for grad in feature_grads: + d_r1_regularizer += torch.linalg.vector_norm(grad.view(grad.size(0), -1), dim=1).pow(2).mean() d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer accelerator.backward(d_loss_real, retain_graph=True) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index e5dde8a1e2aa..29538b8ae2c4 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1971,14 +1971,13 @@ def compute_image_embeddings(image_batch, image_encoder): # 3. Calculate the discriminator R1 gradient penalty term on the gradients with respect to the # discriminator head input features from the real data. d_r1_regularizer = 0 - grad_params = torch.autograd.grad( + feature_grads = torch.autograd.grad( outputs=d_adv_loss_real, inputs=features_real.values(), create_graph=True, ) - for grad in grad_params: - d_r1_regularizer += grad.pow(2).sum() - d_r1_regularizer = d_r1_regularizer.sqrt() + for grad in feature_grads: + d_r1_regularizer += torch.linalg.vector_norm(grad.view(grad.size(0), -1), dim=1).pow(2).mean() d_loss_real = d_adv_loss_real + args.discriminator_r1_strength * d_r1_regularizer accelerator.backward(d_loss_real, retain_graph=True) From 807150c04b85732f440c2775f9f7ec5d5eb264e1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 20 Jan 2024 17:32:13 -0800 Subject: [PATCH 49/53] Fix bug when performing validation in the LoRA scripts. --- examples/add/train_add_distill_lora_sd_wds.py | 4 +--- examples/add/train_add_distill_lora_sdxl_wds.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index 47a8d9f5350e..e52bd5c50b30 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -679,11 +679,9 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="stude pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_teacher_model, vae=vae, - unet=unet, revision=args.revision, torch_dtype=weight_dtype, - ) - pipeline = pipeline.to(accelerator.device) + ).to(accelerator.device) pipeline.set_progress_bar_config(disable=True) lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype) diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index d4269577161d..c64ec2c447fa 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -696,11 +696,9 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="stude pipeline = StableDiffusionXLPipeline.from_pretrained( args.pretrained_teacher_model, vae=vae, - unet=unet, revision=args.revision, torch_dtype=weight_dtype, - ) - pipeline = pipeline.to(accelerator.device) + ).to(accelerator.device) pipeline.set_progress_bar_config(disable=True) lora_state_dict = get_module_kohya_state_dict(unet, "lora_unet", weight_dtype) From 8ff02b3d1f21295c93a926499bf6b3b4d7e23c77 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Sat, 20 Jan 2024 18:06:39 -0800 Subject: [PATCH 50/53] Fix bugs when using image conditioning for the discriminator. --- examples/add/train_add_distill_lora_sd_wds.py | 10 +++++++--- examples/add/train_add_distill_lora_sdxl_wds.py | 10 +++++++--- examples/add/train_add_distill_sd_wds.py | 10 +++++++--- examples/add/train_add_distill_sdxl_wds.py | 10 +++++++--- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index e52bd5c50b30..af38ce4ebdcf 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -210,6 +210,7 @@ def transform(example): # Prepare a separate image for image conditioning since the preprocessing pipelines are different. cond_image = TF.resize(cond_image, cond_resolution, interpolation=cond_interpolation_mode) cond_image = TF.center_crop(cond_image, output_size=(cond_resolution, cond_resolution)) + cond_image = TF.to_tensor(cond_image) cond_image = TF.normalize(cond_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) example["cond_image"] = cond_image @@ -1327,8 +1328,9 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt return prompt_embeds, pooled_output -def encode_images(image_batch, image_encoder): +def encode_images(image_batch, image_encoder, device, weight_dtype): # image_encoder pre-processing is done in SDText2ImageDataset + image_batch = image_batch.to(device=device, dtype=weight_dtype) image_embeds = image_encoder(image_batch) return image_embeds @@ -1698,8 +1700,8 @@ def compute_embeddings( ) return {"prompt_embeds": prompt_embeds, "text_embedding": text_embedding} - def compute_image_embeddings(image_batch, image_encoder): - image_embeds = encode_images(image_batch, image_encoder) + def compute_image_embeddings(image_batch, image_encoder, device, weight_dtype): + image_embeds = encode_images(image_batch, image_encoder, device, weight_dtype) return {"image_embeds": image_embeds} dataset = SDText2ImageDataset( @@ -1731,6 +1733,8 @@ def compute_image_embeddings(image_batch, image_encoder): compute_image_embeddings_fn = functools.partial( compute_image_embeddings, image_encoder=image_encoder, + device=accelerator.device, + weight_dtype=weight_dtype, ) # 15. Create learning rate scheduler for generator and discriminator diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index c64ec2c447fa..4db450a7b5b1 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -223,6 +223,7 @@ def transform(example): # Prepare a separate image for image conditioning since the preprocessing pipelines are different. cond_image = TF.resize(cond_image, cond_resolution, interpolation=cond_interpolation_mode) cond_image = TF.center_crop(cond_image, output_size=(cond_resolution, cond_resolution)) + cond_image = TF.to_tensor(cond_image) cond_image = TF.normalize(cond_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) example["cond_image"] = cond_image @@ -1360,8 +1361,9 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom return prompt_embeds, pooled_prompt_embeds -def encode_images(image_batch, image_encoder): +def encode_images(image_batch, image_encoder, device, weight_dtype): # image_encoder pre-processing is done in SDText2ImageDataset + image_batch = image_batch.to(device=device, dtype=weight_dtype) image_embeds = image_encoder(image_batch) return image_embeds @@ -1770,8 +1772,8 @@ def compute_embeddings( return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} - def compute_image_embeddings(image_batch, image_encoder): - image_embeds = encode_images(image_batch, image_encoder) + def compute_image_embeddings(image_batch, image_encoder, device, weight_dtype): + image_embeds = encode_images(image_batch, image_encoder, device, weight_dtype) return {"image_embeds": image_embeds} dataset = SDXLText2ImageDataset( @@ -1808,6 +1810,8 @@ def compute_image_embeddings(image_batch, image_encoder): compute_image_embeddings_fn = functools.partial( compute_image_embeddings, image_encoder=image_encoder, + device=accelerator.device, + weight_dtype=weight_dtype, ) # 15. Create learning rate scheduler for generator and discriminator diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 1b4fa62ba3d4..7d85300b4473 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -191,6 +191,7 @@ def transform(example): # Prepare a separate image for image conditioning since the preprocessing pipelines are different. cond_image = TF.resize(cond_image, cond_resolution, interpolation=cond_interpolation_mode) cond_image = TF.center_crop(cond_image, output_size=(cond_resolution, cond_resolution)) + cond_image = TF.to_tensor(cond_image) cond_image = TF.normalize(cond_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) example["cond_image"] = cond_image @@ -1275,8 +1276,9 @@ def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompt return prompt_embeds, pooled_output -def encode_images(image_batch, image_encoder): +def encode_images(image_batch, image_encoder, device, weight_dtype): # image_encoder pre-processing is done in SDText2ImageDataset + image_batch = image_batch.to(device=device, dtype=weight_dtype) image_embeds = image_encoder(image_batch) return image_embeds @@ -1611,8 +1613,8 @@ def compute_embeddings( ) return {"prompt_embeds": prompt_embeds, "text_embedding": text_embedding} - def compute_image_embeddings(image_batch, image_encoder): - image_embeds = encode_images(image_batch, image_encoder) + def compute_image_embeddings(image_batch, image_encoder, device, weight_dtype): + image_embeds = encode_images(image_batch, image_encoder, device, weight_dtype) return {"image_embeds": image_embeds} dataset = SDText2ImageDataset( @@ -1644,6 +1646,8 @@ def compute_image_embeddings(image_batch, image_encoder): compute_image_embeddings_fn = functools.partial( compute_image_embeddings, image_encoder=image_encoder, + device=accelerator.device, + weight_dtype=weight_dtype, ) # 14. Create learning rate scheduler for generator and discriminator diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index 29538b8ae2c4..e5d53def0cff 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -204,6 +204,7 @@ def transform(example): # Prepare a separate image for image conditioning since the preprocessing pipelines are different. cond_image = TF.resize(cond_image, cond_resolution, interpolation=cond_interpolation_mode) cond_image = TF.center_crop(cond_image, output_size=(cond_resolution, cond_resolution)) + cond_image = TF.to_tensor(cond_image) cond_image = TF.normalize(cond_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD) example["cond_image"] = cond_image @@ -1308,8 +1309,9 @@ def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prom return prompt_embeds, pooled_prompt_embeds -def encode_images(image_batch, image_encoder): +def encode_images(image_batch, image_encoder, device, weight_dtype): # image_encoder pre-processing is done in SDText2ImageDataset + image_batch = image_batch.to(device=device, dtype=weight_dtype) image_embeds = image_encoder(image_batch) return image_embeds @@ -1683,8 +1685,8 @@ def compute_embeddings( return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} - def compute_image_embeddings(image_batch, image_encoder): - image_embeds = encode_images(image_batch, image_encoder) + def compute_image_embeddings(image_batch, image_encoder, device, weight_dtype): + image_embeds = encode_images(image_batch, image_encoder, device, weight_dtype) return {"image_embeds": image_embeds} dataset = SDXLText2ImageDataset( @@ -1721,6 +1723,8 @@ def compute_image_embeddings(image_batch, image_encoder): compute_image_embeddings_fn = functools.partial( compute_image_embeddings, image_encoder=image_encoder, + device=accelerator.device, + weight_dtype=weight_dtype, ) # 14. Create learning rate scheduler for generator and discriminator From 15c292d3460bcf77117e6f7ab757acebf596eb00 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Wed, 7 Feb 2024 23:38:23 -0800 Subject: [PATCH 51/53] Disable SD 1.X/2.X safety checker when doing validation. --- examples/add/train_add_distill_lora_sd_wds.py | 1 + examples/add/train_add_distill_sd_wds.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index af38ce4ebdcf..90cc00691ef0 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -682,6 +682,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="stude vae=vae, revision=args.revision, torch_dtype=weight_dtype, + safety_checker=None, ).to(accelerator.device) pipeline.set_progress_bar_config(disable=True) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index 7d85300b4473..c702fc549b1b 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -664,6 +664,7 @@ def log_validation(vae, unet, args, accelerator, weight_dtype, step, name="stude unet=unet, revision=args.revision, torch_dtype=weight_dtype, + safety_checker=None, ) pipeline = pipeline.to(accelerator.device) pipeline.set_progress_bar_config(disable=True) From d9a037990ac08c6628525defda558ec74a4088e8 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 9 Feb 2024 19:27:52 -0800 Subject: [PATCH 52/53] When rescaling to zero terminal SNR replace last alpha with small positive value instead of zero following EulerDiscreteScheduler. --- examples/add/train_add_distill_lora_sd_wds.py | 9 +++++++-- examples/add/train_add_distill_lora_sdxl_wds.py | 9 +++++++-- examples/add/train_add_distill_sd_wds.py | 9 +++++++-- examples/add/train_add_distill_sdxl_wds.py | 9 +++++++-- 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index 90cc00691ef0..707492bc7478 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -1401,8 +1401,13 @@ def main(args): # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps - alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) - sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + scheduler_alphas = noise_scheduler.alphas_cumprod + if noise_scheduler.config.rescale_betas_zero_snr: + # When rescaling betas to zero terminal SNR, follow EulerDiscreteScheduler in setting the last alpha_cumprod + # (corresponding to the last training timestep) to a small positive value rather than 0 + scheduler_alphas[-1] = 2**-24 + alpha_schedule = torch.sqrt(scheduler_alphas) + sigma_schedule = torch.sqrt(1 - scheduler_alphas) # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules denoiser = Denoiser(alpha_schedule, sigma_schedule) diff --git a/examples/add/train_add_distill_lora_sdxl_wds.py b/examples/add/train_add_distill_lora_sdxl_wds.py index 4db450a7b5b1..e93cc33c42fe 100644 --- a/examples/add/train_add_distill_lora_sdxl_wds.py +++ b/examples/add/train_add_distill_lora_sdxl_wds.py @@ -1433,8 +1433,13 @@ def main(args): # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps - alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) - sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + scheduler_alphas = noise_scheduler.alphas_cumprod + if noise_scheduler.config.rescale_betas_zero_snr: + # When rescaling betas to zero terminal SNR, follow EulerDiscreteScheduler in setting the last alpha_cumprod + # (corresponding to the last training timestep) to a small positive value rather than 0 + scheduler_alphas[-1] = 2**-24 + alpha_schedule = torch.sqrt(scheduler_alphas) + sigma_schedule = torch.sqrt(1 - scheduler_alphas) # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules denoiser = Denoiser(alpha_schedule, sigma_schedule) diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index c702fc549b1b..e9fc6a472075 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1349,8 +1349,13 @@ def main(args): # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps - alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) - sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + scheduler_alphas = noise_scheduler.alphas_cumprod + if noise_scheduler.config.rescale_betas_zero_snr: + # When rescaling betas to zero terminal SNR, follow EulerDiscreteScheduler in setting the last alpha_cumprod + # (corresponding to the last training timestep) to a small positive value rather than 0 + scheduler_alphas[-1] = 2**-24 + alpha_schedule = torch.sqrt(scheduler_alphas) + sigma_schedule = torch.sqrt(1 - scheduler_alphas) # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules denoiser = Denoiser(alpha_schedule, sigma_schedule) diff --git a/examples/add/train_add_distill_sdxl_wds.py b/examples/add/train_add_distill_sdxl_wds.py index e5d53def0cff..cdd538768a1f 100644 --- a/examples/add/train_add_distill_sdxl_wds.py +++ b/examples/add/train_add_distill_sdxl_wds.py @@ -1381,8 +1381,13 @@ def main(args): # DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us # Note that the ADD paper parameterizes alpha and sigma as x_t = alpha_t * x_0 + sigma_t * eps - alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod) - sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod) + scheduler_alphas = noise_scheduler.alphas_cumprod + if noise_scheduler.config.rescale_betas_zero_snr: + # When rescaling betas to zero terminal SNR, follow EulerDiscreteScheduler in setting the last alpha_cumprod + # (corresponding to the last training timestep) to a small positive value rather than 0 + scheduler_alphas[-1] = 2**-24 + alpha_schedule = torch.sqrt(scheduler_alphas) + sigma_schedule = torch.sqrt(1 - scheduler_alphas) # denoiser gets predicted original sample x_0 from prediction_type using alpha and sigma noise schedules denoiser = Denoiser(alpha_schedule, sigma_schedule) From bc4b516b701b65ad9b35f75503a7fb0fd01ff513 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Fri, 9 Feb 2024 20:23:41 -0800 Subject: [PATCH 53/53] For SD 1.X/2.X get last_hidden_state of uncond_prompt_embeds correctly whether we use a CLIPTextModel or CLIPTextModelWithProjection (e.g. with --use_pretrained_projection). --- examples/add/train_add_distill_lora_sd_wds.py | 11 +++++++---- examples/add/train_add_distill_sd_wds.py | 11 +++++++---- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/examples/add/train_add_distill_lora_sd_wds.py b/examples/add/train_add_distill_lora_sd_wds.py index 707492bc7478..2d62954dd8f9 100644 --- a/examples/add/train_add_distill_lora_sd_wds.py +++ b/examples/add/train_add_distill_lora_sd_wds.py @@ -1452,9 +1452,12 @@ def main(args): student_distillation_steps = student_timestep_schedule.shape[0] # 2. Load tokenizers from SD 1.X/2.X checkpoint. - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False - ) + if args.use_pretrained_projection: + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_with_proj, use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) # 3. Load text encoders from SD 1.X/2.X checkpoint. if args.use_pretrained_projection: @@ -1800,7 +1803,7 @@ def compute_image_embeddings(image_batch, image_encoder, device, weight_dtype): uncond_input_ids = tokenizer( [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=MAX_SEQ_LENGTH ).input_ids.to(accelerator.device) - uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] + uncond_prompt_embeds = text_encoder(uncond_input_ids).last_hidden_state # 17. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps diff --git a/examples/add/train_add_distill_sd_wds.py b/examples/add/train_add_distill_sd_wds.py index e9fc6a472075..fdcfa0e8c719 100644 --- a/examples/add/train_add_distill_sd_wds.py +++ b/examples/add/train_add_distill_sd_wds.py @@ -1400,9 +1400,12 @@ def main(args): student_distillation_steps = student_timestep_schedule.shape[0] # 2. Load tokenizers from SD 1.X/2.X checkpoint. - tokenizer = AutoTokenizer.from_pretrained( - args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False - ) + if args.use_pretrained_projection: + tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_with_proj, use_fast=False) + else: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_teacher_model, subfolder="tokenizer", revision=args.teacher_revision, use_fast=False + ) # 3. Load text encoders from SD 1.X/2.X checkpoint. if args.use_pretrained_projection: @@ -1713,7 +1716,7 @@ def compute_image_embeddings(image_batch, image_encoder, device, weight_dtype): uncond_input_ids = tokenizer( [""] * args.train_batch_size, return_tensors="pt", padding="max_length", max_length=MAX_SEQ_LENGTH ).input_ids.to(accelerator.device) - uncond_prompt_embeds = text_encoder(uncond_input_ids)[0] + uncond_prompt_embeds = text_encoder(uncond_input_ids).last_hidden_state # 16. Train! total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps