From d29d97b616debea86e058c1addfa7ca1c85066fd Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Fri, 1 Dec 2023 15:18:43 +0200 Subject: [PATCH 01/14] [examples/advanced_diffusion_training] bug fixes and improvements for LoRA Dreambooth SDXL advanced training script (#5935) * imports and readme bug fixes * bug fix - ensures text_encoder params are dtype==float32 (when using pivotal tuning) even if the rest of the model is loaded in fp16 * added pivotal tuning to readme * mapping token identifier to new inserted token in validation prompt (if used) * correct default value of --train_text_encoder_frac * change default value of --adam_weight_decay_text_encoder * validation prompt generations when using pivotal tuning bug fix * style fix * textual inversion embeddings name change * style fix * bug fix - stopping text encoder optimization halfway * readme - will include token abstraction and new inserted tokens when using pivotal tuning - added type to --num_new_tokens_per_abstraction * style fix --------- Co-authored-by: Linoy Tsaban --- .../train_dreambooth_lora_sdxl_advanced.py | 148 +++++++++++------- 1 file changed, 92 insertions(+), 56 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index f032634a11f0..3fccd1786be5 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -54,7 +54,7 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict +from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr, unet_lora_state_dict from diffusers.utils import check_min_version, is_wandb_available @@ -67,11 +67,46 @@ logger = get_logger(__name__) +# TODO: This function should be removed once training scripts are rewritten in PEFT +def text_encoder_lora_state_dict(text_encoder): + state_dict = {} + + def text_encoder_attn_modules(text_encoder): + from transformers import CLIPTextModel, CLIPTextModelWithProjection + + attn_modules = [] + + if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): + for i, layer in enumerate(text_encoder.text_model.encoder.layers): + name = f"text_model.encoder.layers.{i}.self_attn" + mod = layer.self_attn + attn_modules.append((name, mod)) + + return attn_modules + + for name, module in text_encoder_attn_modules(text_encoder): + for k, v in module.q_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v + + for k, v in module.k_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v + + for k, v in module.v_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v + + for k, v in module.out_proj.lora_linear_layer.state_dict().items(): + state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v + + return state_dict + + def save_model_card( repo_id: str, images=None, base_model=str, train_text_encoder=False, + train_text_encoder_ti=False, + token_abstraction_dict=None, instance_prompt=str, validation_prompt=str, repo_folder=None, @@ -83,10 +118,23 @@ def save_model_card( img_str += f""" - text: '{validation_prompt if validation_prompt else ' ' }' output: - url: >- + url: "image_{i}.png" """ + trigger_str = f"You should use {instance_prompt} to trigger the image generation." + if train_text_encoder_ti: + trigger_str = ( + "To trigger image generation of trained concept(or concepts) replace each concept identifier " + "in you prompt with the new inserted tokens:\n" + ) + if token_abstraction_dict: + for key, value in token_abstraction_dict.items(): + tokens = "".join(value) + trigger_str += f""" + to trigger concept {key}-> use {tokens} in your prompt \n + """ + yaml = f""" --- tags: @@ -96,9 +144,7 @@ def save_model_card( - diffusers - lora - template:sd-lora -widget: {img_str} ---- base_model: {base_model} instance_prompt: {instance_prompt} license: openrail++ @@ -112,14 +158,19 @@ def save_model_card( ## Model description -These are {repo_id} LoRA adaption weights for {base_model}. +### These are {repo_id} LoRA adaption weights for {base_model}. + The weights were trained using [DreamBooth](https://dreambooth.github.io/). + LoRA for the text encoder was enabled: {train_text_encoder}. + +Pivotal tuning was enabled: {train_text_encoder_ti}. + Special VAE used for training: {vae_path}. ## Trigger words -You should use {instance_prompt} to trigger the image generation. +{trigger_str} ## Download model @@ -244,6 +295,7 @@ def parse_args(input_args=None): parser.add_argument( "--num_new_tokens_per_abstraction", + type=int, default=2, help="number of new tokens inserted to the tokenizers per token_abstraction value when " "--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new " @@ -455,7 +507,7 @@ def parse_args(input_args=None): parser.add_argument( "--train_text_encoder_frac", type=float, - default=0.5, + default=1.0, help=("The percentage of epochs to perform text encoder tuning"), ) @@ -488,7 +540,7 @@ def parse_args(input_args=None): parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + "--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder" ) parser.add_argument( @@ -679,12 +731,19 @@ def initialize_new_tokens(self, inserting_toks: List[str]): def save_embeddings(self, file_path: str): assert self.train_ids is not None, "Initialize new tokens before saving embeddings." tensors = {} + # text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 + idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"} for idx, text_encoder in enumerate(self.text_encoders): assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len( self.tokenizers[0] ), "Tokenizers should be the same." new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids] - tensors[f"text_encoders_{idx}"] = new_token_embeddings + + # New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for + # text_encoder 1) to keep compatible with the ecosystem. + # Note: When loading with diffusers, any name can work - simply specify in inference + tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings + # tensors[f"text_encoders_{idx}"] = new_token_embeddings save_file(tensors, file_path) @@ -696,19 +755,6 @@ def dtype(self): def device(self): return self.text_encoders[0].device - # def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder): - # # Assuming new tokens are of the format - # self.inserting_toks = [f"" for i in range(loaded_embeddings.shape[0])] - # special_tokens_dict = {"additional_special_tokens": self.inserting_toks} - # tokenizer.add_special_tokens(special_tokens_dict) - # text_encoder.resize_token_embeddings(len(tokenizer)) - # - # self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) - # assert self.train_ids is not None, "New tokens could not be converted to IDs." - # text_encoder.text_model.embeddings.token_embedding.weight.data[ - # self.train_ids - # ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype) - @torch.no_grad() def retract_embeddings(self): for idx, text_encoder in enumerate(self.text_encoders): @@ -730,15 +776,6 @@ def retract_embeddings(self): new_embeddings = new_embeddings * (off_ratio**0.1) text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings - # def load_embeddings(self, file_path: str): - # with safe_open(file_path, framework="pt", device=self.device.type) as f: - # for idx in range(len(self.text_encoders)): - # text_encoder = self.text_encoders[idx] - # tokenizer = self.tokenizers[idx] - # - # loaded_embeddings = f.get_tensor(f"text_encoders_{idx}") - # self._load_embeddings(loaded_embeddings, tokenizer, text_encoder) - class DreamBoothDataset(Dataset): """ @@ -1216,6 +1253,8 @@ def main(args): text_lora_parameters_one = [] for name, param in text_encoder_one.named_parameters(): if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param = param.to(dtype=torch.float32) param.requires_grad = True text_lora_parameters_one.append(param) else: @@ -1223,6 +1262,8 @@ def main(args): text_lora_parameters_two = [] for name, param in text_encoder_two.named_parameters(): if "token_embedding" in name: + # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 + param = param.to(dtype=torch.float32) param.requires_grad = True text_lora_parameters_two.append(param) else: @@ -1309,12 +1350,16 @@ def load_model_hook(models, input_dir): # different learning rate for text encoder and unet text_lora_parameters_one_with_lr = { "params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } text_lora_parameters_two_with_lr = { "params": text_lora_parameters_two, - "weight_decay": args.adam_weight_decay_text_encoder, + "weight_decay": args.adam_weight_decay_text_encoder + if args.adam_weight_decay_text_encoder + else args.adam_weight_decay, "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, } params_to_optimize = [ @@ -1494,6 +1539,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + if args.train_text_encoder_ti and args.validation_prompt: + # replace instances of --token_abstraction in validation prompt with the new tokens: "" etc. + for token_abs, token_replacement in train_dataset.token_abstraction_dict.items(): + args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement)) + print("validation prompt:", args.validation_prompt) + # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) @@ -1593,27 +1644,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if epoch == num_train_epochs_text_encoder: print("PIVOT HALFWAY", epoch) # stopping optimization of text_encoder params - params_to_optimize = params_to_optimize[:1] - # reinitializing the optimizer to optimize only on unet params - if args.optimizer.lower() == "prodigy": - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - beta3=args.prodigy_beta3, - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - decouple=args.prodigy_decouple, - use_bias_correction=args.prodigy_use_bias_correction, - safeguard_warmup=args.prodigy_safeguard_warmup, - ) - else: # AdamW or 8-bit-AdamW - optimizer = optimizer_class( - params_to_optimize, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) + # re setting the optimizer to optimize only on unet params + optimizer.param_groups[1]["lr"] = 0.0 + optimizer.param_groups[2]["lr"] = 0.0 + else: # still optimizng the text encoder text_encoder_one.train() @@ -1628,7 +1662,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) prompts = batch["prompts"] - print(prompts) + # print(prompts) # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: if freeze_text_encoder: @@ -1801,7 +1835,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): f" {args.validation_prompt}." ) # create pipeline - if not args.train_text_encoder: + if freeze_text_encoder: text_encoder_one = text_encoder_cls_one.from_pretrained( args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision ) @@ -1948,6 +1982,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): images=images, base_model=args.pretrained_model_name_or_path, train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, instance_prompt=args.instance_prompt, validation_prompt=args.validation_prompt, repo_folder=args.output_dir, From c1e45295418af5de11b23c2995fe828e33cce5c9 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Fri, 1 Dec 2023 16:14:57 +0200 Subject: [PATCH 02/14] [advanced_dreambooth_lora_sdxl_tranining_script] readme fix (#6019) readme --- .../train_dreambooth_lora_sdxl_advanced.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 3fccd1786be5..821cc8b0d2cc 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -132,8 +132,8 @@ def save_model_card( for key, value in token_abstraction_dict.items(): tokens = "".join(value) trigger_str += f""" - to trigger concept {key}-> use {tokens} in your prompt \n - """ +to trigger concept `{key}->` use `{tokens}` in your prompt \n +""" yaml = f""" --- From 6ba4c5395fb2f1720176efda79d2d5053fc8f62f Mon Sep 17 00:00:00 2001 From: Steven Liu <59462357+stevhliu@users.noreply.github.com> Date: Fri, 1 Dec 2023 07:07:47 -0800 Subject: [PATCH 03/14] [docs] Fix SVD video (#6004) Update svd.md --- docs/source/en/using-diffusers/svd.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/en/using-diffusers/svd.md b/docs/source/en/using-diffusers/svd.md index 4fdb2608aa76..7fd29284cbd0 100644 --- a/docs/source/en/using-diffusers/svd.md +++ b/docs/source/en/using-diffusers/svd.md @@ -53,8 +53,9 @@ frames = pipe(image, decode_chunk_size=8, generator=generator).frames[0] export_to_video(frames, "generated.mp4", fps=7) ``` -