From ba93a84a83520e8983fab32846bee9656783a7ad Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 14 Dec 2023 10:41:39 -0600 Subject: [PATCH 1/4] load pipeline for inference only if validation prompt is used --- .../train_dreambooth_lora_sdxl_advanced.py | 59 +++++++++---------- 1 file changed, 29 insertions(+), 30 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index a46a1afcc145..f1564a116537 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -2010,43 +2010,42 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) + images = [] + if args.validation_prompt and args.num_validation_images > 0: + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" - scheduler_args["variance_type"] = variance_type + scheduler_args["variance_type"] = variance_type - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - # load attention processors - pipeline.load_lora_weights(args.output_dir) + # load attention processors + pipeline.load_lora_weights(args.output_dir) - # run inference - images = [] - if args.validation_prompt and args.num_validation_images > 0: + # run inference pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [ From b0f3d0c056b9c4d0adb09e7de159a68bdd43b9f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Thu, 14 Dec 2023 10:52:32 -0600 Subject: [PATCH 2/4] move things outside --- .../train_dreambooth_lora_sdxl_advanced.py | 161 +++++++++--------- 1 file changed, 80 insertions(+), 81 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index f1564a116537..31a3c3b5c638 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -2010,88 +2010,87 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) - images = [] - if args.validation_prompt and args.num_validation_images > 0: - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - - # load attention processors - pipeline.load_lora_weights(args.output_dir) - - # run inference - pipeline = pipeline.to(accelerator.device) - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{args.output_dir}/embeddings.safetensors", - ) - save_model_card( - model_id if not args.push_to_hub else repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - train_text_encoder=args.train_text_encoder, - train_text_encoder_ti=args.train_text_encoder_ti, - token_abstraction_dict=train_dataset.token_abstraction_dict, - instance_prompt=args.instance_prompt, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - vae_path=args.pretrained_vae_model_name_or_path, - ) - if args.push_to_hub: - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], - ) - accelerator.end_training() + images = [] + if args.validation_prompt and args.num_validation_images > 0: + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{args.output_dir}/embeddings.safetensors", + ) + save_model_card( + model_id if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) if __name__ == "__main__": From 9d8c2d1e5e1615e79e0595a502c2dbff78c77439 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 14 Dec 2023 11:09:01 -0600 Subject: [PATCH 3/4] load pipeline for inference only if validation prompt is used --- .../train_dreambooth_lora_sdxl_advanced.py | 161 +++++++++--------- 1 file changed, 81 insertions(+), 80 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 31a3c3b5c638..f1564a116537 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -2010,87 +2010,88 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_encoder_lora_layers=text_encoder_lora_layers, text_encoder_2_lora_layers=text_encoder_2_lora_layers, ) - accelerator.end_training() - images = [] - if args.validation_prompt and args.num_validation_images > 0: - # Final inference - # Load previous pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - pipeline = StableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) - - # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it - scheduler_args = {} - - if "variance_type" in pipeline.scheduler.config: - variance_type = pipeline.scheduler.config.variance_type - - if variance_type in ["learned", "learned_range"]: - variance_type = "fixed_small" - - scheduler_args["variance_type"] = variance_type - - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - - # load attention processors - pipeline.load_lora_weights(args.output_dir) - - # run inference - pipeline = pipeline.to(accelerator.device) - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) - - if args.train_text_encoder_ti: - embedding_handler.save_embeddings( - f"{args.output_dir}/embeddings.safetensors", - ) - save_model_card( - model_id if not args.push_to_hub else repo_id, - images=images, - base_model=args.pretrained_model_name_or_path, - train_text_encoder=args.train_text_encoder, - train_text_encoder_ti=args.train_text_encoder_ti, - token_abstraction_dict=train_dataset.token_abstraction_dict, - instance_prompt=args.instance_prompt, - validation_prompt=args.validation_prompt, - repo_folder=args.output_dir, - vae_path=args.pretrained_vae_model_name_or_path, - ) - if args.push_to_hub: - upload_folder( - repo_id=repo_id, - folder_path=args.output_dir, - commit_message="End of training", - ignore_patterns=["step_*", "epoch_*"], + images = [] + if args.validation_prompt and args.num_validation_images > 0: + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = StableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + + # We train on the simplified learning objective. 
If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.train_text_encoder_ti: + embedding_handler.save_embeddings( + f"{args.output_dir}/embeddings.safetensors", + ) + save_model_card( + model_id if not args.push_to_hub else repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + train_text_encoder_ti=args.train_text_encoder_ti, + token_abstraction_dict=train_dataset.token_abstraction_dict, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, ) + if args.push_to_hub: + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() if __name__ == "__main__": From f729e903487de1e97603b219176070af10a38e97 Mon Sep 17 00:00:00 2001 From: linoytsaban Date: Thu, 14 Dec 2023 11:21:19 -0600 Subject: [PATCH 4/4] fix readme when validation prompt is used --- .../train_dreambooth_lora_sdxl_advanced.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index f1564a116537..ad37363b7d30 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -112,7 +112,7 @@ def save_model_card( repo_folder=None, vae_path=None, ): - img_str = "widget:\n" if images else "" + img_str = "widget:\n" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f""" @@ -121,6 +121,10 @@ def save_model_card( url: "image_{i}.png" """ + if not images: + img_str += f""" + - text: '{instance_prompt}' + """ trigger_str = f"You should use {instance_prompt} to trigger the image generation." diffusers_imports_pivotal = "" @@ -157,8 +161,6 @@ def save_model_card( base_model: {base_model} instance_prompt: {instance_prompt} license: openrail++ -widget: - - text: '{validation_prompt if validation_prompt else instance_prompt}' --- """
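
---

For readers following the patches above, a minimal, self-contained sketch of the final-inference flow the series converges on: the pipeline is only built when a validation prompt is supplied, a learned variance is remapped so DPMSolverMultistep ignores it, and the freshly saved LoRA weights are loaded before sampling. This is an illustration, not the training script itself; the model id, output directory, prompt, and seed are hypothetical placeholders, and the separate fp16 VAE load from the patch is omitted for brevity.

```python
# Sketch only: mirrors the gated final-inference block from the patches above.
# Placeholder values; not taken from the patch.
import torch
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler

base_model = "stabilityai/stable-diffusion-xl-base-1.0"  # hypothetical base model
output_dir = "my-lora-output"                            # hypothetical folder with trained LoRA weights
validation_prompt = "a photo of TOK dog"                 # hypothetical validation prompt
num_validation_images = 2
seed = 42

images = []
if validation_prompt and num_validation_images > 0:
    pipeline = StableDiffusionXLPipeline.from_pretrained(base_model, torch_dtype=torch.float16)

    # Training uses the simplified objective, so if the base scheduler predicts a
    # variance ("learned"/"learned_range") it is remapped to a fixed one, as in the patch.
    scheduler_args = {}
    if "variance_type" in pipeline.scheduler.config:
        variance_type = pipeline.scheduler.config.variance_type
        if variance_type in ["learned", "learned_range"]:
            variance_type = "fixed_small"
        scheduler_args["variance_type"] = variance_type
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
        pipeline.scheduler.config, **scheduler_args
    )

    # Load the LoRA layers saved at the end of training, then sample.
    pipeline.load_lora_weights(output_dir)
    pipeline = pipeline.to("cuda")
    generator = torch.Generator(device="cuda").manual_seed(seed) if seed else None
    images = [
        pipeline(validation_prompt, num_inference_steps=25, generator=generator).images[0]
        for _ in range(num_validation_images)
    ]
```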
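The last patch also changes how the model card's `widget:` front matter is assembled when no validation images were generated. The sketch below restates that logic in isolation, assuming PIL-style images with a `.save()` method; the helper name and arguments are illustrative, not the script's actual signature.

```python
# Sketch of the widget front-matter logic from PATCH 4/4: always emit a
# `widget:` block, attach saved validation images when they exist, and fall
# back to the instance prompt alone when they don't.
import os

def build_widget_yaml(images, instance_prompt, validation_prompt, repo_folder):
    img_str = "widget:\n"
    for i, image in enumerate(images):
        # Save each validation image and reference it from the widget entry.
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
        img_str += f"""
        - text: '{validation_prompt if validation_prompt else instance_prompt}'
          output:
            url: "image_{i}.png"
        """
    if not images:
        # No validation run: keep a widget entry with just the instance prompt.
        img_str += f"""
        - text: '{instance_prompt}'
        """
    return img_str
```

The returned string would be spliced into the model card's YAML header in place of the hard-coded `widget:` lines that the patch removes.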