Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[advanced dreambooth lora sdxl training script] load pipeline for inference only if validation prompt is used #6171

Merged
merged 4 commits into from
Dec 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def save_model_card(
repo_folder=None,
vae_path=None,
):
img_str = "widget:\n" if images else ""
img_str = "widget:\n"
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
Expand All @@ -121,6 +121,10 @@ def save_model_card(
url:
"image_{i}.png"
"""
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""

trigger_str = f"You should use {instance_prompt} to trigger the image generation."
diffusers_imports_pivotal = ""
Expand Down Expand Up @@ -157,8 +161,6 @@ def save_model_card(
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
widget:
- text: '{validation_prompt if validation_prompt else instance_prompt}'
---
"""

Expand Down Expand Up @@ -2010,43 +2012,42 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)
images = []
if args.validation_prompt and args.num_validation_images > 0:
# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)

# Final inference
# Load previous pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)

# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
scheduler_args = {}

if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type
if "variance_type" in pipeline.scheduler.config:
variance_type = pipeline.scheduler.config.variance_type

if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"
if variance_type in ["learned", "learned_range"]:
variance_type = "fixed_small"

scheduler_args["variance_type"] = variance_type
scheduler_args["variance_type"] = variance_type

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)

# load attention processors
pipeline.load_lora_weights(args.output_dir)
# load attention processors
pipeline.load_lora_weights(args.output_dir)

# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
# run inference
pipeline = pipeline.to(accelerator.device)
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = [
Expand Down
Loading