[Training examples] Follow up of #6306 #6346

Merged (9 commits) on Dec 28, 2023
12 changes: 6 additions & 6 deletions examples/consistency_distillation/train_lcm_distill_lora_sdxl.py
@@ -51,7 +51,7 @@
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -113,7 +113,7 @@ def log_validation(vae, args, accelerator, weight_dtype, step, unet=None, is_fin
if unet is None:
raise ValueError("Must provide a `unet` when doing intermediate validation.")
unet = accelerator.unwrap_model(unet)
-state_dict = get_peft_model_state_dict(unet)
+state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
to_load = state_dict
else:
to_load = args.output_dir
@@ -819,7 +819,7 @@ def save_model_hook(models, weights, output_dir):
unet_ = accelerator.unwrap_model(unet)
# also save the checkpoints in native `diffusers` format so that it can be easily
# be independently loaded via `load_lora_weights()`.
-state_dict = get_peft_model_state_dict(unet_)
+state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet_))
StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=state_dict)

for _, model in enumerate(models):
@@ -1184,7 +1184,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
# solver timestep.

# With the adapters disabled, the `unet` is the regular teacher model.
-unet.disable_adapters()
+accelerator.unwrap_model(unet).disable_adapters()
with torch.no_grad():
# 1. Get teacher model prediction on noisy_model_input z_{t_{n + k}} and conditional embedding c
cond_teacher_output = unet(
@@ -1248,7 +1248,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
x_prev = solver.ddim_step(pred_x0, pred_noise, index).to(unet.dtype)

# re-enable unet adapters to turn the `unet` into a student unet.
-unet.enable_adapters()
+accelerator.unwrap_model(unet).enable_adapters()

# 9. Get target LCM prediction on x_prev, w, c, t_n (timesteps)
# Note that we do not use a separate target network for LCM-LoRA distillation.
@@ -1332,7 +1332,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
-unet_lora_state_dict = get_peft_model_state_dict(unet)
+unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
StableDiffusionXLPipeline.save_lora_weights(args.output_dir, unet_lora_layers=unet_lora_state_dict)

if args.push_to_hub:
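For context, the hunks above all apply the same two fixes: LoRA state dicts extracted with `get_peft_model_state_dict` are now passed through `convert_state_dict_to_diffusers` before saving, and the PEFT adapter toggles are called on the unwrapped model rather than on the accelerate/DDP wrapper. A minimal sketch of the resulting save path, assuming a UNet that already has a LoRA adapter attached; the helper name `save_unet_lora` is illustrative and not part of the script:

from peft.utils import get_peft_model_state_dict

from diffusers import StableDiffusionXLPipeline
from diffusers.utils import convert_state_dict_to_diffusers


def save_unet_lora(accelerator, unet, output_dir):
    # Adapter utilities and the PEFT state dict helpers live on the inner module,
    # so unwrap first when the model is wrapped by accelerate/DDP.
    unet = accelerator.unwrap_model(unet)
    # PEFT emits keys such as `...lora_A.weight`; convert them to the layout that
    # `load_lora_weights()` expects before writing the file.
    lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
    StableDiffusionXLPipeline.save_lora_weights(output_dir, unet_lora_layers=lora_state_dict)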
12 changes: 7 additions & 5 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -54,7 +54,7 @@
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -853,9 +853,11 @@ def save_model_hook(models, weights, output_dir):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
-unet_lora_layers_to_save = get_peft_model_state_dict(model)
+unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
-text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
+text_encoder_lora_layers_to_save = convert_state_dict_to_diffusers(
+    get_peft_model_state_dict(model)
+)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1285,11 +1287,11 @@ def compute_text_embeddings(prompt):
unet = accelerator.unwrap_model(unet)
unet = unet.to(torch.float32)

-unet_lora_state_dict = get_peft_model_state_dict(unet)
+unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

if args.train_text_encoder:
text_encoder = accelerator.unwrap_model(text_encoder)
-text_encoder_state_dict = get_peft_model_state_dict(text_encoder)
+text_encoder_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder))
else:
text_encoder_state_dict = None

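The reason every `get_peft_model_state_dict` call is wrapped, as far as the change suggests, is that PEFT and diffusers use different LoRA key names, so an unconverted file cannot be read back with `load_lora_weights()`. A small, hedged illustration of the renaming; the key paths and tensor shapes below are invented for the example:

import torch

from diffusers.utils import convert_state_dict_to_diffusers

# Toy PEFT-style state dict; only the key names matter here.
peft_state_dict = {
    "down_blocks.0.attentions.0.proj_in.lora_A.weight": torch.zeros(4, 320),
    "down_blocks.0.attentions.0.proj_in.lora_B.weight": torch.zeros(320, 4),
}
diffusers_state_dict = convert_state_dict_to_diffusers(peft_state_dict)
# Expected result: `lora_A`/`lora_B` renamed to the `lora.down`/`lora.up` style
# keys that the diffusers LoRA loader understands.
print(list(diffusers_state_dict))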
8 changes: 5 additions & 3 deletions examples/text_to_image/train_text_to_image_lora.py
@@ -44,7 +44,7 @@
from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -809,7 +809,9 @@ def collate_fn(examples):
accelerator.save_state(save_path)

unwrapped_unet = accelerator.unwrap_model(unet)
-unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
+unet_lora_state_dict = convert_state_dict_to_diffusers(
+    get_peft_model_state_dict(unwrapped_unet)
+)

StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
@@ -876,7 +878,7 @@ def collate_fn(examples):
unet = unet.to(torch.float32)

unwrapped_unet = accelerator.unwrap_model(unet)
-unet_lora_state_dict = get_peft_model_state_dict(unwrapped_unet)
+unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
18 changes: 11 additions & 7 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -52,7 +52,7 @@
from diffusers.loaders import LoraLoaderMixin
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -651,11 +651,15 @@ def save_model_hook(models, weights, output_dir):

for model in models:
if isinstance(model, type(accelerator.unwrap_model(unet))):
-unet_lora_layers_to_save = get_peft_model_state_dict(model)
+unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+    get_peft_model_state_dict(model)
+)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
+text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+    get_peft_model_state_dict(model)
+)
else:
raise ValueError(f"unexpected save model: {model.__class__}")

@@ -1160,14 +1164,14 @@ def compute_time_ids(original_size, crops_coords_top_left):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
-unet_lora_state_dict = get_peft_model_state_dict(unet)
+unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))

if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_two = accelerator.unwrap_model(text_encoder_two)

-text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one)
-text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two)
+text_encoder_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
+text_encoder_2_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
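As a usage note that is not part of the diff: because the saved state dicts are now in diffusers' native LoRA format, they can be loaded straight back into a pipeline. A short sketch, where the base model id, the output path, and the prompt are illustrative assumptions:

import torch

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
# Reads the file written by `save_lora_weights()` in the training scripts above.
pipe.load_lora_weights("path/to/output_dir")
image = pipe("a photo of a sks dog", num_inference_steps=25).images[0]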