Add PEFT to advanced training script #6294
```diff
@@ -37,6 +37,8 @@
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
+from peft import LoraConfig
+from peft.utils import get_peft_model_state_dict
 from PIL import Image
 from PIL.ImageOps import exif_transpose
 from safetensors.torch import save_file
```
```diff
@@ -54,9 +56,8 @@
     UNet2DConditionModel,
 )
 from diffusers.loaders import LoraLoaderMixin
-from diffusers.models.lora import LoRALinearLayer
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import compute_snr, unet_lora_state_dict
+from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
```
```diff
@@ -67,39 +68,6 @@
 logger = get_logger(__name__)
 
 
-# TODO: This function should be removed once training scripts are rewritten in PEFT
-def text_encoder_lora_state_dict(text_encoder):
-    state_dict = {}
-
-    def text_encoder_attn_modules(text_encoder):
-        from transformers import CLIPTextModel, CLIPTextModelWithProjection
-
-        attn_modules = []
-
-        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
-            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
-                name = f"text_model.encoder.layers.{i}.self_attn"
-                mod = layer.self_attn
-                attn_modules.append((name, mod))
-
-        return attn_modules
-
-    for name, module in text_encoder_attn_modules(text_encoder):
-        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
-
-        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
-            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
-
-    return state_dict
-
-
 def save_model_card(
     repo_id: str,
     images=None,
```
```diff
@@ -1262,54 +1230,25 @@ def main(args):
             text_encoder_two.gradient_checkpointing_enable()
 
     # now we will add new LoRA weights to the attention layers
-    # Set correct lora layers
-    unet_lora_parameters = []
-    for attn_processor_name, attn_processor in unet.attn_processors.items():
-        # Parse the attention module.
-        attn_module = unet
-        for n in attn_processor_name.split(".")[:-1]:
-            attn_module = getattr(attn_module, n)
-
-        # Set the `lora_layer` attribute of the attention-related matrices.
-        attn_module.to_q.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_k.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_v.set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
-            )
-        )
-        attn_module.to_out[0].set_lora_layer(
-            LoRALinearLayer(
-                in_features=attn_module.to_out[0].in_features,
-                out_features=attn_module.to_out[0].out_features,
-                rank=args.rank,
-            )
-        )
-
-        # Accumulate the LoRA params to optimize.
-        unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
-        unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
+    unet_lora_config = LoraConfig(
+        r=args.rank,
+        lora_alpha=args.rank,
+        init_lora_weights="gaussian",
+        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+    )
+    unet.add_adapter(unet_lora_config)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
     # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
-        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_one, dtype=torch.float32, rank=args.rank
-        )
-        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
-            text_encoder_two, dtype=torch.float32, rank=args.rank
-        )
+        text_lora_config = LoraConfig(
+            r=args.rank,
+            lora_alpha=args.rank,
+            init_lora_weights="gaussian",
+            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
+        )
+        text_encoder_one.add_adapter(text_lora_config)
+        text_encoder_two.add_adapter(text_lora_config)
 
     # if we use textual inversion, we freeze all parameters except for the token embeddings
     # in text encoder
```

Inline review comment on the `lora_alpha=args.rank,` line: Very important!
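As a quick sanity check of the new PEFT path, the snippet below (a sketch, not part of the PR; the tiny UNet config, rank, and variable names are arbitrary) builds a small `UNet2DConditionModel`, attaches the same kind of `LoraConfig`, and confirms that only the injected LoRA matrices end up trainable:

```python
import torch
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Arbitrary tiny UNet so the example runs quickly; the real script uses the SDXL UNet.
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=1,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
unet.requires_grad_(False)  # base weights stay frozen, as in the training script

unet_lora_config = LoraConfig(
    r=4,
    lora_alpha=4,  # alpha == r keeps the effective LoRA scaling (alpha / r) at 1.0
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
unet.add_adapter(unet_lora_config)

trainable = [name for name, param in unet.named_parameters() if param.requires_grad]
print(len(trainable), trainable[:2])  # only lora_A / lora_B weights are trainable
```

Setting `lora_alpha` equal to `r` (the line flagged as "Very important!" above) keeps the LoRA scaling factor `lora_alpha / r` at 1.0, which matches how the removed `LoRALinearLayer` path behaved.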
```diff
@@ -1333,6 +1272,17 @@
             else:
                 param.requires_grad = False
 
+    # Make sure the trainable params are in float32.
+    if args.mixed_precision == "fp16":
+        models = [unet]
+        if args.train_text_encoder:
+            models.extend([text_encoder_one, text_encoder_two])
+        for model in models:
+            for param in model.parameters():
+                # only upcast trainable parameters (LoRA) into fp32
+                if param.requires_grad:
+                    param.data = param.to(torch.float32)
+
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
```

Review comment on lines +1275 to +1284: Another important one!
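The upcast matters because training the LoRA weights directly in fp16 can be numerically unstable; only the small set of trainable parameters is moved to fp32, while the frozen base stays in half precision. A self-contained sketch of the same pattern (the toy modules below are stand-ins, not the script's models):

```python
import torch

# Frozen "base" module kept in fp16, plus a trainable "adapter-like" module.
base = torch.nn.Linear(4, 4).to(torch.float16)
base.requires_grad_(False)
adapter = torch.nn.Linear(4, 4, bias=False).to(torch.float16)

for model in (base, adapter):
    for param in model.parameters():
        # only upcast trainable parameters into fp32
        if param.requires_grad:
            param.data = param.to(torch.float32)

print(base.weight.dtype, adapter.weight.dtype)  # torch.float16 torch.float32
```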
```diff
@@ -1344,11 +1294,11 @@ def save_model_hook(models, weights, output_dir):
 
             for model in models:
                 if isinstance(model, type(accelerator.unwrap_model(unet))):
-                    unet_lora_layers_to_save = unet_lora_state_dict(model)
+                    unet_lora_layers_to_save = get_peft_model_state_dict(model)
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
-                    text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                 elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
-                    text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
+                    text_encoder_two_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
```
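For reference, `get_peft_model_state_dict` returns only the adapter tensors, so the hook keeps producing small, LoRA-only checkpoints. A minimal sketch on a toy module (names and config are illustrative and unrelated to the training script):

```python
import torch
from peft import LoraConfig, get_peft_model
from peft.utils import get_peft_model_state_dict


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


peft_model = get_peft_model(Toy(), LoraConfig(r=4, lora_alpha=4, target_modules=["proj"]))
state = get_peft_model_state_dict(peft_model)
print(sorted(state))  # only lora_A / lora_B entries; the frozen base weight is not included
```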
```diff
@@ -1405,6 +1355,12 @@ def load_model_hook(models, input_dir):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )
 
+    unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
+
+    if args.train_text_encoder:
+        text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters()))
+        text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters()))
+
     # If neither --train_text_encoder nor --train_text_encoder_ti, text_encoders remain frozen during training
     freeze_text_encoder = not (args.train_text_encoder or args.train_text_encoder_ti)
 
```
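With PEFT the trainable parameters no longer need to be accumulated while the LoRA layers are created; filtering on `requires_grad` after `add_adapter` yields the same set. Continuing the tiny-UNet sketch from earlier (the learning rate and optimizer choice are placeholders, not the script's values):

```python
import torch

# `unet` here is the adapter-equipped tiny UNet from the earlier sketch.
unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
optimizer = torch.optim.AdamW([{"params": unet_lora_parameters, "lr": 1e-4}])
print(sum(p.numel() for p in unet_lora_parameters), "trainable LoRA parameters")
```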
```diff
@@ -1995,13 +1951,13 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
     if accelerator.is_main_process:
         unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = unet_lora_state_dict(unet)
+        unet_lora_layers = get_peft_model_state_dict(unet)
 
         if args.train_text_encoder:
             text_encoder_one = accelerator.unwrap_model(text_encoder_one)
-            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
+            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
             text_encoder_two = accelerator.unwrap_model(text_encoder_two)
-            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
+            text_encoder_2_lora_layers = get_peft_model_state_dict(text_encoder_two.to(torch.float32))
         else:
             text_encoder_lora_layers = None
             text_encoder_2_lora_layers = None
```

Review thread on the `text_encoder_lora_layers` line:

- If you train extra parameters that you keep unfrozen for the text encoder, you need to add them in
- Got it! But I'm not training extra parameters with that operation. This is where it is taking place - and was working prior to adding PEFT elsewhere: https://github.com/huggingface/diffusers/pull/6294/files#diff-24abe8b0339a563b68e03c979ee9e498ab7c49f3fd749ffb784156f4e2d54d90R1249
- So, only thing that is outside of the
- Yes, which makes sense because it is not training an adapter per se
- Thanks @apolinario for explaining! makes sense
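To make the thread above concrete: parameters that are unfrozen by hand (such as the token embeddings used for textual inversion / pivotal tuning) are not part of the PEFT adapter, so `get_peft_model_state_dict` does not return them and they have to be collected and saved separately. A hedged sketch with a randomly initialized tiny CLIP text encoder (all names and sizes are illustrative, not the script's code):

```python
import torch
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from transformers import CLIPTextConfig, CLIPTextModel

text_encoder = CLIPTextModel(
    CLIPTextConfig(hidden_size=32, intermediate_size=64, num_attention_heads=4, num_hidden_layers=2)
)
text_encoder.requires_grad_(False)
text_encoder.add_adapter(
    LoraConfig(r=4, lora_alpha=4, target_modules=["q_proj", "k_proj", "v_proj", "out_proj"])
)

# Roughly what --train_text_encoder_ti does: unfreeze the token embeddings by hand.
text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)

adapter_state = get_peft_model_state_dict(text_encoder)
print(any("token_embedding" in key for key in adapter_state))  # False: embeddings are outside the adapter
embeddings_to_save = text_encoder.get_input_embeddings().weight.detach().cpu()  # saved separately
```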
Review comment: Will need to add `peft` as a dependency in the `requirements.txt`.
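A possible runtime guard to go along with that requirements change (a sketch only; no minimum version is stated in the thread, so none is pinned here):

```python
import importlib.metadata

try:
    peft_version = importlib.metadata.version("peft")
except importlib.metadata.PackageNotFoundError as err:
    raise ImportError("This training script requires `peft`; install it with `pip install peft`.") from err
```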