Add PEFT to advanced training script #6294
Merged
Changes from all commits (14 commits, by apolinario):

- 6338ad5 Fix ProdigyOPT in SDXL Dreambooth script
- 565416c style
- 3338ce0 style
- d62076a Add PEFT to Advanced Training Script
- 965b40a Merge branch 'main' into add-peft-to-advanced-training-script
- 9b910bd style
- 566aaab style
- a837033 ✨ style ✨
- 2bfdcab change order for logic operation
- b03aa10 add lora alpha
- 38aece9 style
- daa7566 Align PEFT to new format
- f6844d3 Merge branch 'main' into add-peft-to-advanced-training-script
- 0f9427f Update train_dreambooth_lora_sdxl_advanced.py
File: train_dreambooth_lora_sdxl_advanced.py

@@ -37,6 +37,8 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from safetensors.torch import save_file
@@ -54,10 +56,9 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr, unet_lora_state_dict
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.training_utils import compute_snr
+from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -67,39 +68,6 @@
 logger = get_logger(__name__)


-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
-    state_dict = {}
-
-    def text_encoder_attn_modules(text_encoder):
-        from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
-        attn_modules = []
-
-        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
-            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
-                name = f"text_model.encoder.layers.{i}.self_attn"
-                mod = layer.self_attn
-                attn_modules.append((name, mod))
-
-        return attn_modules
-
-    for name, module in text_encoder_attn_modules(text_encoder):
-        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
-    return state_dict
-
-
 def save_model_card(
     repo_id: str,
     images=None,
@@ -161,8 +129,6 @@ def save_model_card(
 base_model: {base_model}
 instance_prompt: {instance_prompt}
 license: openrail++
-widget:
-- text: '{validation_prompt if validation_prompt else instance_prompt}'
 ---
 """
@@ -1264,54 +1230,25 @@ def main(args):
             text_encoder_two.gradient_checkpointing_enable()

-    # now we will add new LoRA weights to the attention layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+    unet.add_adapter(unet_lora_config)

     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
         # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_one, dtype=torch.float32, rank=args.rank
-        )
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_two, dtype=torch.float32, rank=args.rank
-        )
+        text_lora_config = LoraConfig(
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+        )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)

     # if we use textual inversion, we freeze all parameters except for the token embeddings
     # in text encoder

Review comment on the added `lora_alpha=args.rank` line: Very important!
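For readers following along, here is a minimal, self-contained sketch (not part of this PR) of the PEFT pattern the diff switches to: a `LoraConfig` plus `add_adapter` injects trainable LoRA weights into the attention projections of a diffusers `UNet2DConditionModel`. The checkpoint id and rank below are placeholders.

```python
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Placeholder checkpoint and rank, for illustration only.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)  # the base weights stay frozen

lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(lora_config)  # inject LoRA layers into the attention projections

# Only the injected LoRA parameters require grad after add_adapter.
trainable = [p for p in unet.parameters() if p.requires_grad]
print(f"{sum(p.numel() for p in trainable):,} trainable LoRA parameters")
```

Compared with the removed `LoRALinearLayer` loop, the adapter config covers the same four projection matrices (`to_q`, `to_k`, `to_v`, `to_out.0`) in a single declarative step.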
@@ -1335,6 +1272,17 @@
         else:
             param.requires_grad = False

+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:

Review comment on lines +1275 to +1284 (the float32 upcast block): Another important one!
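As context for the review comment above: under fp16 mixed precision only the trainable LoRA parameters are kept in float32, so optimizer updates stay numerically stable while the frozen base weights remain in half precision. A generic helper expressing the same pattern might look like the sketch below (an assumption-level illustration, not code from the script):

```python
import torch


def cast_trainable_params(models, dtype=torch.float32):
    """Upcast only the parameters that require grad (the LoRA weights),
    leaving the frozen base weights in their lower-precision dtype."""
    if not isinstance(models, (list, tuple)):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.to(dtype)
```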
@@ -1346,11 +1294,15 @@ def save_model_hook(models, weights, output_dir):
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_lora_state_dict(model)
+                    unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
+                        get_peft_model_state_dict(model)
+                    )
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
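To show where these converted state dicts typically end up, here is a rough sketch of serializing them in the diffusers LoRA format. It assumes the models already carry PEFT adapters as set up earlier in the diff; the `save_lora_weights` call mirrors what diffusers SDXL LoRA scripts generally do in this hook, and `output_dir` is a placeholder.

```python
from diffusers import StableDiffusionXLPipeline
from diffusers.utils import convert_state_dict_to_diffusers
from peft.utils import get_peft_model_state_dict


def save_sdxl_lora(unet, text_encoder_one, text_encoder_two, output_dir):
    # Extract only the LoRA weights from each PEFT-wrapped model and
    # rename the keys into the layout diffusers pipelines expect.
    unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet))
    te_one_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_one))
    te_two_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(text_encoder_two))

    # Classmethod call: writes the LoRA weights file into output_dir.
    StableDiffusionXLPipeline.save_lora_weights(
        output_dir,
        unet_lora_layers=unet_lora_layers,
        text_encoder_lora_layers=te_one_lora_layers,
        text_encoder_2_lora_layers=te_two_lora_layers,
    )
```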
@@ -1407,6 +1359,12 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

+    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+    if args.train_text_encoder:
+        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
     # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
     freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
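For context on why the script now collects parameters with `requires_grad` filters: those lists are what end up in the optimizer. A self-contained sketch of that idea follows; AdamW and the hyperparameters are illustrative placeholders (the actual script also supports Prodigy):

```python
import torch


def build_lora_optimizer(unet, text_encoders=(), lr=1e-4, weight_decay=1e-4):
    """Gather every parameter that still requires grad (i.e. the LoRA weights
    injected by add_adapter) and hand them to a single optimizer."""
    params_to_optimize = [p for p in unet.parameters() if p.requires_grad]
    for text_encoder in text_encoders:
        params_to_optimize.extend(p for p in text_encoder.parameters() if p.requires_grad)
    return torch.optim.AdamW(params_to_optimize, lr=lr, weight_decay=weight_decay)
```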
@@ -1997,13 +1955,17 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
+        unet_lora_layers = get_peft_model_state_dict(unet)

         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            )
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = convert_state_dict_to_diffusers(
+                get_peft_model_state_dict(text_encoder_two.to(torch.float32))
+            )
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
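Once training finishes and the weights are saved in the diffusers format, the resulting LoRA can be loaded back for inference. A rough sketch with placeholder paths and prompt (the instance prompt depends on how the model was trained):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Placeholder path: wherever the training run wrote its LoRA weights file.
pipe.load_lora_weights("path/to/output_dir")

image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
image.save("lora_sample.png")
```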
Review comment: Will need to add `peft` as a dependency in the `requirements.txt`.
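One way to surface that dependency at runtime, beyond listing it in requirements, is an import guard; a hedged sketch (the message text is illustrative, not from this PR):

```python
try:
    from peft import LoraConfig  # noqa: F401
    from peft.utils import get_peft_model_state_dict  # noqa: F401
except ImportError as exc:
    raise ImportError(
        "This training script requires the `peft` library. "
        "Install it with `pip install peft` and add it to requirements.txt."
    ) from exc
```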