diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index a02f8772e28b4..71600054c149b 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -37,6 +37,8 @@ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from huggingface_hub import create_repo, upload_folder from packaging import version +from peft import LoraConfig +from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose from safetensors.torch import save_file @@ -54,10 +56,9 @@ UNet2DConditionModel, ) from diffusers.loaders import LoraLoaderMixin -from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler -from diffusers.training_utils import compute_snr, unet_lora_state_dict -from diffusers.utils import check_min_version, is_wandb_available +from diffusers.training_utils import compute_snr +from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -67,39 +68,6 @@ logger = get_logger(__name__) -# TODO: This function should be removed once training scripts are rewritten in PEFT -def text_encoder_lora_state_dict(text_encoder): - state_dict = {} - - def text_encoder_attn_modules(text_encoder): - from transformers import CLIPTextModel, CLIPTextModelWithProjection - - attn_modules = [] - - if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)): - for i, layer in enumerate(text_encoder.text_model.encoder.layers): - name = f"text_model.encoder.layers.{i}.self_attn" - mod = layer.self_attn - attn_modules.append((name, mod)) - - return attn_modules - - for name, module in text_encoder_attn_modules(text_encoder): - for k, v in module.q_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v - - for k, v in module.k_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v - - for k, v in module.v_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v - - for k, v in module.out_proj.lora_linear_layer.state_dict().items(): - state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v - - return state_dict - - def save_model_card( repo_id: str, images=None, @@ -161,8 +129,6 @@ def save_model_card( base_model: {base_model} instance_prompt: {instance_prompt} license: openrail++ -widget: - - text: '{validation_prompt if validation_prompt else instance_prompt}' --- """ @@ -1264,54 +1230,25 @@ def main(args): text_encoder_two.gradient_checkpointing_enable() # now we will add new LoRA weights to the attention layers - # Set correct lora layers - unet_lora_parameters = [] - for attn_processor_name, attn_processor in unet.attn_processors.items(): - # Parse the attention module. - attn_module = unet - for n in attn_processor_name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=args.rank, - ) - ) - - # Accumulate the LoRA params to optimize. - unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters()) - unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + unet.add_adapter(unet_lora_config) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. if args.train_text_encoder: - # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 - text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder( - text_encoder_one, dtype=torch.float32, rank=args.rank - ) - text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder( - text_encoder_two, dtype=torch.float32, rank=args.rank + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], ) + text_encoder_one.add_adapter(text_lora_config) + text_encoder_two.add_adapter(text_lora_config) # if we use textual inversion, we freeze all parameters except for the token embeddings # in text encoder @@ -1335,6 +1272,17 @@ def main(args): else: param.requires_grad = False + # Make sure the trainable params are in float32. + if args.mixed_precision == "fp16": + models = [unet] + if args.train_text_encoder: + models.extend([text_encoder_one, text_encoder_two]) + for model in models: + for param in model.parameters(): + # only upcast trainable parameters (LoRA) into fp32 + if param.requires_grad: + param.data = param.to(torch.float32) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): if accelerator.is_main_process: @@ -1346,11 +1294,15 @@ def save_model_hook(models, weights, output_dir): for model in models: if isinstance(model, type(accelerator.unwrap_model(unet))): - unet_lora_layers_to_save = unet_lora_state_dict(model) + unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))): - text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))): - text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model) + text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) else: raise ValueError(f"unexpected save model: {model.__class__}") @@ -1407,6 +1359,12 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) + + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti) @@ -1997,13 +1955,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if accelerator.is_main_process: unet = accelerator.unwrap_model(unet) unet = unet.to(torch.float32) - unet_lora_layers = unet_lora_state_dict(unet) + unet_lora_layers = get_peft_model_state_dict(unet) if args.train_text_encoder: text_encoder_one = accelerator.unwrap_model(text_encoder_one) - text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32)) + text_encoder_lora_layers = convert_state_dict_to_diffusers( + get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + ) text_encoder_two = accelerator.unwrap_model(text_encoder_two) - text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32)) + text_encoder_2_lora_layers = convert_state_dict_to_diffusers( + get_peft_model_state_dict(text_encoder_two.to(torch.float32)) + ) else: text_encoder_lora_layers = None text_encoder_2_lora_layers = None