[examples/advanced_diffusion_training] bug fixes and improvements for LoRA Dreambooth SDXL advanced training script #5935

Merged: 14 commits merged on Dec 1, 2023
Changes from 11 commits
@@ -54,7 +54,7 @@
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr, unet_lora_state_dict
from diffusers.utils import check_min_version, is_wandb_available
@@ -67,11 +67,45 @@
logger = get_logger(__name__)


# TODO: This function should be removed once training scripts are rewritten in PEFT
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}

def text_encoder_attn_modules(text_encoder):
from transformers import CLIPTextModel, CLIPTextModelWithProjection

attn_modules = []

if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))

return attn_modules

for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

return state_dict
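# Illustrative sketch, not part of this diff: with the monkey-patched LoRA layers this script
# attaches, the helper above returns keys of the form
#   "text_model.encoder.layers.<i>.self_attn.q_proj.lora_linear_layer.down.weight"
# (and likewise for k_proj/v_proj/out_proj and the "up" weight). A hypothetical use alongside
# the unet helper imported above could be:
#   unet_lora_layers = unet_lora_state_dict(unet)
#   text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one)
#   StableDiffusionXLPipeline.save_lora_weights(
#       save_directory=args.output_dir,
#       unet_lora_layers=unet_lora_layers,
#       text_encoder_lora_layers=text_encoder_lora_layers,
#   )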


def save_model_card(
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
train_text_encoder_ti=False,
instance_prompt=str,
validation_prompt=str,
repo_folder=None,
@@ -83,7 +117,7 @@ def save_model_card(
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url: >-
url:
"image_{i}.png"
"""

@@ -96,9 +130,7 @@ def save_model_card(
- diffusers
- lora
- template:sd-lora
widget:
{img_str}
---
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
@@ -112,9 +144,14 @@ def save_model_card(

## Model description

These are {repo_id} LoRA adaption weights for {base_model}.
### These are {repo_id} LoRA adaption weights for {base_model}.

The weights were trained using [DreamBooth](https://dreambooth.github.io/).

LoRA for the text encoder was enabled: {train_text_encoder}.

Pivotal tuning was enabled: {train_text_encoder_ti}.

Special VAE used for training: {vae_path}.

## Trigger words
@@ -455,7 +492,7 @@ def parse_args(input_args=None):
parser.add_argument(
"--train_text_encoder_frac",
type=float,
default=0.5,
default=1.0,
help=("The percentage of epochs to perform text encoder tuning"),
)

@@ -488,7 +525,7 @@ def parse_args(input_args=None):
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
"--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
)

parser.add_argument(
@@ -679,12 +716,19 @@ def initialize_new_tokens(self, inserting_toks: List[str]):
def save_embeddings(self, file_path: str):
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
tensors = {}
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
for idx, text_encoder in enumerate(self.text_encoders):
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
self.tokenizers[0]
), "Tokenizers should be the same."
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
tensors[f"text_encoders_{idx}"] = new_token_embeddings

# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0) and "clip_g"
# (for text_encoder 1) to stay compatible with the ecosystem.
# Note: when loading with diffusers, any name can work - simply specify it at inference.
tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
# tensors[f"text_encoders_{idx}"] = new_token_embeddings

save_file(tensors, file_path)
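# Illustrative sketch, not part of this diff: embeddings saved this way can later be loaded
# into an SDXL pipeline, once per text encoder. "embedding_path" and the "<s0>"/"<s1>" tokens
# below are hypothetical placeholders.
#   from safetensors.torch import load_file
#   state_dict = load_file(embedding_path)
#   pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"],
#                               text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
#   pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"],
#                               text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)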

@@ -696,19 +740,6 @@ def dtype(self):
def device(self):
return self.text_encoders[0].device

# def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
# # Assuming new tokens are of the format <s_i>
# self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
# special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
# tokenizer.add_special_tokens(special_tokens_dict)
# text_encoder.resize_token_embeddings(len(tokenizer))
#
# self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
# assert self.train_ids is not None, "New tokens could not be converted to IDs."
# text_encoder.text_model.embeddings.token_embedding.weight.data[
# self.train_ids
# ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)

@torch.no_grad()
def retract_embeddings(self):
for idx, text_encoder in enumerate(self.text_encoders):
@@ -730,15 +761,6 @@ def retract_embeddings():
new_embeddings = new_embeddings * (off_ratio**0.1)
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings

# def load_embeddings(self, file_path: str):
# with safe_open(file_path, framework="pt", device=self.device.type) as f:
# for idx in range(len(self.text_encoders)):
# text_encoder = self.text_encoders[idx]
# tokenizer = self.tokenizers[idx]
#
# loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
# self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)


class DreamBoothDataset(Dataset):
"""
@@ -1216,13 +1238,17 @@ def main(args):
text_lora_parameters_one = []
for name, param in text_encoder_one.named_parameters():
if "token_embedding" in name:
# ensure the dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_one.append(param)
else:
param.requires_grad = False
text_lora_parameters_two = []
for name, param in text_encoder_two.named_parameters():
if "token_embedding" in name:
# ensure the dtype is float32, even if the rest of the model (which isn't trained) is loaded in fp16
param = param.to(dtype=torch.float32)
param.requires_grad = True
text_lora_parameters_two.append(param)
else:
@@ -1309,12 +1335,16 @@ def load_model_hook(models, input_dir):
# different learning rate for text encoder and unet
text_lora_parameters_one_with_lr = {
"params": text_lora_parameters_one,
"weight_decay": args.adam_weight_decay_text_encoder,
"weight_decay": args.adam_weight_decay_text_encoder
if args.adam_weight_decay_text_encoder
else args.adam_weight_decay,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
text_lora_parameters_two_with_lr = {
"params": text_lora_parameters_two,
"weight_decay": args.adam_weight_decay_text_encoder,
"weight_decay": args.adam_weight_decay_text_encoder
if args.adam_weight_decay_text_encoder
else args.adam_weight_decay,
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
}
params_to_optimize = [
@@ -1494,6 +1524,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

if args.train_text_encoder_ti and args.validation_prompt:
# replace instances of --token_abstraction in the validation prompt with the new tokens: "<si><si+1>" etc.
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
print("validation prompt:", args.validation_prompt)

# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1593,27 +1629,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
if epoch == num_train_epochs_text_encoder:
print("PIVOT HALFWAY", epoch)
# stopping optimization of text_encoder params
params_to_optimize = params_to_optimize[:1]
# reinitializing the optimizer to optimize only on unet params
if args.optimizer.lower() == "prodigy":
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
else: # AdamW or 8-bit-AdamW
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# resetting the optimizer so it only optimizes the unet params
optimizer.param_groups[1]["lr"] = 0.0
optimizer.param_groups[2]["lr"] = 0.0
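# Clarifying note, an assumption based on the param groups built earlier (index 0 = unet,
# indices 1 and 2 = the two text encoders): zeroing these learning rates freezes the text
# encoder parameters for the remaining epochs while keeping the optimizer state accumulated
# for the unet group, so the optimizer no longer has to be re-instantiated at the pivot.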

else:
# still optimizing the text encoder
text_encoder_one.train()
@@ -1628,7 +1647,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
prompts = batch["prompts"]
print(prompts)
# print(prompts)
# encode batch prompts when custom prompts are provided for each image -
if train_dataset.custom_instance_prompts:
if freeze_text_encoder:
@@ -1801,7 +1820,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
f" {args.validation_prompt}."
)
# create pipeline
if not args.train_text_encoder:
if freeze_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
@@ -1948,6 +1967,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
train_text_encoder_ti=args.train_text_encoder_ti,
instance_prompt=args.instance_prompt,
validation_prompt=args.validation_prompt,
repo_folder=args.output_dir,