Skip to content

Commit

Permalink
[advanced dreambooth lora sdxl training script] load pipeline for inf…
Browse files Browse the repository at this point in the history
…erence only if validation prompt is used (huggingface#6171)

* load pipeline for inference only if validation prompt is used

* move things outside

* load pipeline for inference only if validation prompt is used

* fix readme when validation prompt is used

---------

Co-authored-by: linoytsaban <[email protected]>
Co-authored-by: apolinário <[email protected]>
  • Loading branch information
3 people authored and Jimmy committed Apr 26, 2024
1 parent 2600ec8 commit a4e9719
Showing 1 changed file with 34 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def save_model_card(
repo_folder=None,
vae_path=None,
):
img_str = "widget:\n" if images else ""
img_str = "widget:\n"
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
Expand All @@ -121,6 +121,10 @@ def save_model_card(
url:
"image_{i}.png"
"""
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""

trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = ""
Expand Down Expand Up @@ -157,8 +161,6 @@ def save_model_card(
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""

Expand Down Expand Up @@ -2010,43 +2012,42 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)
images = []
if args.validation_prompt and args.num_validation_images > 0:
# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)

# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)

# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}

if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type

if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"

scheduler_args["variance_type"] = variance_type
scheduler_args["variance_type"] = variance_type

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

# load attention processors
pipeline.load_lora_weights(args.output_dir)
# load attention processors
pipeline.load_lora_weights(args.output_dir)

# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = [
Expand Down

0 comments on commit a4e9719

Please sign in to comment.