Train text encoder #84

Open · wants to merge 2 commits into base: master
224 changes: 193 additions & 31 deletions in train_scripts/train_pixart_lora_hf.py
@@ -54,8 +54,19 @@

logger = get_logger(__name__, log_level="INFO")


def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
if not isinstance(model, list):
model = [model]
for m in model:
for param in m.parameters():
# only upcast trainable parameters into fp32
if param.requires_grad:
param.data = param.to(dtype)

def get_trainable_parameters(optimizer: torch.optim.Optimizer) -> int:
return sum(sum(p.numel() for p in param_group['params'] if p.requires_grad) for param_group in optimizer.param_groups)
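
Not part of the diff, just for context: a minimal sketch of what these two new helpers do, using a throwaway two-layer module (all names and sizes below are hypothetical).

import torch

# Toy module: first layer frozen, second layer trainable; start in fp16.
demo = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2)).half()
demo[0].requires_grad_(False)

# Only the trainable (second) layer is upcast to fp32; the frozen layer stays fp16.
cast_training_params(demo, dtype=torch.float32)

opt = torch.optim.AdamW(filter(lambda p: p.requires_grad, demo.parameters()), lr=1e-4)
print(get_trainable_parameters(opt))  # 8*2 weights + 2 biases = 18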

def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None, train_text_encoder=False):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -77,6 +88,7 @@ def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str,
model_card = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
LoRA for text encoder was enabled: {train_text_encoder} \n
{img_str}
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
@@ -272,6 +284,28 @@ def parse_args():
help="Whether or not to use RS Lora. For more information, see"
" https://huggingface.co/docs/peft/package_reference/lora#peft.LoraConfig.use_rslora"
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether or not to also train the text encoder")
parser.add_argument(
"--text_encoder_lora_rank",
type=int,
default=4,
help=("The dimension of the LoRA update matrices for the text encoder training."),
)
parser.add_argument(
"--text_encoder_learning_rate",
type=float,
default=None,
help="learning rate for text encoder trainer, default is same as learning_rate",
)
parser.add_argument(
"--text_encoder_stop_at_percentage_steps",
type=float,
default=1.0,
help="the percentage of the total training steps at which the training of the text encoder should be halted. 1.0 means train for all steps.",
)
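
A quick, hypothetical illustration of how the new stop flag is interpreted later in the training loop (the step count below is made up):

# Illustration only, not part of the diff.
max_train_steps = 15000   # hypothetical args.max_train_steps
stop_fraction = 0.25      # args.text_encoder_stop_at_percentage_steps
# The text encoder is frozen once global_step >= max_train_steps * stop_fraction,
# i.e. from step 3750 onward only the transformer LoRA continues to train.
print(int(max_train_steps * stop_fraction))  # 3750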
parser.add_argument(
"--allow_tf32",
action="store_true",
@@ -502,6 +536,9 @@ def main():
for param in transformer.parameters():
param.requires_grad_(False)

# Move transformer, vae and text_encoder to device and cast to weight_dtype
transformer.to(accelerator.device)

lora_config = LoraConfig(
r=args.rank,
init_lora_weights="gaussian",
@@ -523,25 +560,49 @@
use_rslora=args.use_rslora
)

# Move transformer, vae and text_encoder to device and cast to weight_dtype
transformer.to(accelerator.device)

def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
if not isinstance(model, list):
model = [model]
for m in model:
for param in m.parameters():
# only upcast trainable parameters into fp32
if param.requires_grad:
param.data = param.to(dtype)

transformer = get_peft_model(transformer, lora_config)
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(transformer, dtype=torch.float32)

accelerator.print("Transformer:")
transformer.print_trainable_parameters()

if args.train_text_encoder:
if not 0 < args.text_encoder_stop_at_percentage_steps <= 1:
args.text_encoder_stop_at_percentage_steps = 1

if args.gradient_checkpointing:
# this needs to be done before adding the LoRA layers
# otherwise, enabling gradient checkpointing for the text encoder will generate the warning:
# "UserWarning: None of the inputs have requires_grad=True. Gradients will be None"
# more info:
# https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2235
text_encoder.gradient_checkpointing_enable()

# prepare the text_encoder for LoRA
lora_config_for_text_encoder = LoraConfig(
init_lora_weights="gaussian",
r=args.text_encoder_lora_rank,
# lora_alpha=args. ...,
# the dropout probability of the LoRA layers
lora_dropout=0.01,
target_modules=["k","q","v","o"]
)

text_encoder = get_peft_model(text_encoder, lora_config_for_text_encoder)
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(text_encoder, dtype=torch.float32)

accelerator.print("\033[91m")
accelerator.print("IMPORTANT !! Training the Text Encoder in fp16 might lead to NaNs in step_loss, if it does please use fp32 or bf16 for training the Text Encoder.")
accelerator.print(" more info: \n https://github.com/huggingface/transformers/issues/4586#issuecomment-639704855 \n https://github.com/huggingface/transformers/issues/17978#issuecomment-1173761651")
accelerator.print("\033[0m")

accelerator.print("Text Encoder:")
text_encoder.print_trainable_parameters()
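
For a rough sense of where the printed count comes from: each adapted linear layer W of shape (d_out, d_in) gains r * (d_in + d_out) LoRA parameters. The back-of-the-envelope below assumes T5-XXL-like dimensions; the print_trainable_parameters output above is authoritative.

# Back-of-the-envelope, assumed dimensions only (not read from the checkpoint).
r = 4                  # args.text_encoder_lora_rank default
d_model = 4096         # assumed hidden size of the T5 encoder used by PixArt
num_layers = 24        # assumed number of encoder blocks
adapted_per_block = 4  # k, q, v, o
print(f"{num_layers * adapted_per_block * r * (d_model + d_model):,}")  # 3,145,728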

# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
@@ -550,9 +611,19 @@ def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_ = accelerator.unwrap_model(transformer)
lora_state_dict = get_peft_model_state_dict(transformer_, adapter_name="default")
StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "transformer_lora"), lora_state_dict)

text_encoder_to_save = None
if args.train_text_encoder:
text_encoder_ = accelerator.unwrap_model(text_encoder)
text_encoder_to_save = get_peft_model_state_dict(text_encoder_)

StableDiffusionPipeline.save_lora_weights(os.path.join(output_dir, "transformer_lora_weights"), lora_state_dict,
text_encoder_lora_layers=text_encoder_to_save)

# save weights in peft format to be able to load them back
transformer_.save_pretrained(output_dir)
transformer_.save_pretrained(os.path.join(output_dir, "transformer"))
if args.train_text_encoder:
text_encoder_.save_pretrained(os.path.join(output_dir, "text_encoder"))

for _, model in enumerate(models):
# make sure to pop weight so that corresponding model is not saved again
@@ -563,6 +634,14 @@ def load_model_hook(models, input_dir):
transformer_ = accelerator.unwrap_model(transformer)
transformer_.load_adapter(input_dir, "default", is_trainable=True)

# raulc0399: it seems that this hook is not implemented!!
# check https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora_sdxl.py#L701
# todo: load and use transformer and text_encoder + cast training params
# if args.mixed_precision == "fp16":
# # only upcast trainable parameters (LoRA) into fp32
# cast_training_params(transformer_, dtype=torch.float32)
accelerator.print("load_model_hook NOT IMPLEMENTED!!!")

for _ in range(len(models)):
# pop models so that they are not loaded again
models.pop()
@@ -585,7 +664,24 @@ def load_model_hook(models, input_dir):
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

lora_layers = filter(lambda p: p.requires_grad, transformer.parameters())
# transformer params to optimize
params_to_optimize = list(filter(lambda p: p.requires_grad, transformer.parameters()))
params_to_clip = params_to_optimize

if args.train_text_encoder:
text_encoder_params_to_optimize = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
params_to_clip = params_to_optimize + text_encoder_params_to_optimize

# transformer and text encoder have the same learning rate
if args.text_encoder_learning_rate is None:
params_to_optimize = (
params_to_optimize + text_encoder_params_to_optimize
)
else:
params_to_optimize = [
{"params": params_to_optimize, "lr": args.learning_rate},
{"params": text_encoder_params_to_optimize, "lr": args.text_encoder_learning_rate},
]
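
A minimal, self-contained sketch of the two shapes params_to_optimize can take here: a flat parameter list sharing one learning rate, or explicit param groups with per-group learning rates (the tensors and values below are hypothetical).

# Illustration only, not part of the diff.
import torch

transformer_params = [torch.nn.Parameter(torch.zeros(4, 4))]
text_encoder_params = [torch.nn.Parameter(torch.zeros(4, 4))]

opt = torch.optim.AdamW(
    [
        {"params": transformer_params, "lr": 1e-4},   # args.learning_rate
        {"params": text_encoder_params, "lr": 5e-5},  # args.text_encoder_learning_rate
    ],
    lr=1e-4,  # default lr, only used by groups that do not set their own
)
print([group["lr"] for group in opt.param_groups])  # [0.0001, 5e-05]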

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
@@ -612,13 +708,15 @@ def load_model_hook(models, input_dir):
optimizer_cls = torch.optim.AdamW

optimizer = optimizer_cls(
lora_layers,
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)

accelerator.print(f"Total of trainable parameters: {get_trainable_parameters(optimizer):,}")

# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

@@ -744,6 +842,9 @@ def collate_fn(examples):
transformer, optimizer, train_dataloader, lr_scheduler = \
accelerator.prepare(transformer, optimizer, train_dataloader, lr_scheduler)

if args.train_text_encoder:
text_encoder = accelerator.prepare(text_encoder)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
@@ -766,6 +867,13 @@ def collate_fn(examples):
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")

if args.train_text_encoder:
training_rate_text_encoder = args.text_encoder_learning_rate if args.text_encoder_learning_rate is not None else f"{args.learning_rate} same as transformer"
logger.info(f" Training text encoder with rank {args.text_encoder_lora_rank}, learing rate {training_rate_text_encoder}")
if args.text_encoder_stop_at_percentage_steps < 1:
logger.info(f" Stop training text encoder at {args.text_encoder_stop_at_percentage_steps * 100}% of total training steps")

global_step = 0
first_epoch = 0

@@ -804,11 +912,16 @@ def collate_fn(examples):
disable=not accelerator.is_local_main_process,
)

models_for_accumulate = [transformer, text_encoder] if args.train_text_encoder else transformer

for epoch in range(first_epoch, args.num_train_epochs):
transformer.train()
if args.train_text_encoder:
text_encoder.train()

train_loss = 0.0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(transformer):
with accelerator.accumulate(models_for_accumulate):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
@@ -885,7 +998,6 @@ def collate_fn(examples):
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -923,14 +1035,21 @@ def collate_fn(examples):
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)

unwrapped_transformer = accelerator.unwrap_model(transformer, keep_fp32_wrapper=False)
transformer_lora_state_dict = get_peft_model_state_dict(unwrapped_transformer)
# raulc0399: not needed since they are already saved in save_state hook
# unwrapped_transformer = accelerator.unwrap_model(transformer, keep_fp32_wrapper=False)
# transformer_lora_state_dict = get_peft_model_state_dict(unwrapped_transformer)

StableDiffusionPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=transformer_lora_state_dict,
safe_serialization=True,
)
# text_encoder_to_save = None
# if args.train_text_encoder:
# text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
# text_encoder_to_save = get_peft_model_state_dict(text_encoder_)

# StableDiffusionPipeline.save_lora_weights(
# save_directory=os.path.join(save_path, "transformer_lora_weights_checkpoint"),
# unet_lora_layers=transformer_lora_state_dict,
# text_encoder_lora_layers=text_encoder_to_save,
# safe_serialization=True,
# )

logger.info(f"Saved state to {save_path}")

@@ -940,17 +1059,39 @@ def collate_fn(examples):
if global_step >= args.max_train_steps:
break

if args.train_text_encoder and args.text_encoder_stop_at_percentage_steps < 1 and global_step >= args.max_train_steps * args.text_encoder_stop_at_percentage_steps:
accelerator.print("\033[91mFreezing text encoder...")

accelerator.print(f"Number of trainable parameters before freeze: {get_trainable_parameters(optimizer):,}")

text_encoder.zero_grad()
text_encoder.requires_grad_(False)
params_to_clip = list(filter(lambda p: p.requires_grad, transformer.parameters()))
models_for_accumulate = transformer

args.train_text_encoder = False

text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
text_encoder.save_pretrained(os.path.join(args.output_dir, "text_encoder"))

accelerator.print(f"Number of trainable parameters after freeze: {get_trainable_parameters(optimizer):,}")
accelerator.print("\033[0m")

if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)

text_encoder_for_generation = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) if args.train_text_encoder else text_encoder

# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
transformer=accelerator.unwrap_model(transformer, keep_fp32_wrapper=False),
text_encoder=text_encoder, vae=vae,
text_encoder=text_encoder_for_generation,
vae=vae,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
@@ -986,9 +1127,18 @@ def collate_fn(examples):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
transformer = accelerator.unwrap_model(transformer, keep_fp32_wrapper=False)
transformer.save_pretrained(args.output_dir)
transformer.save_pretrained(os.path.join(args.output_dir, "transformer"))
lora_state_dict = get_peft_model_state_dict(transformer)
StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "transformer_lora"), lora_state_dict)

if args.train_text_encoder:
text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
text_encoder_.save_pretrained(os.path.join(args.output_dir, "text_encoder"))
text_encoder_to_save = get_peft_model_state_dict(text_encoder_)
else:
text_encoder_to_save = None

StableDiffusionPipeline.save_lora_weights(os.path.join(args.output_dir, "transformer_lora_weights"), lora_state_dict,
text_encoder_lora_layers=text_encoder_to_save)

if args.push_to_hub:
save_model_card(
@@ -997,6 +1147,7 @@ def collate_fn(examples):
base_model=args.pretrained_model_name_or_path,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
train_text_encoder=args.train_text_encoder,
)
upload_folder(
repo_id=repo_id,
@@ -1011,14 +1162,25 @@ def collate_fn(examples):
args.pretrained_model_name_or_path, subfolder='transformer', torch_dtype=weight_dtype
)
# load lora weight
transformer = PeftModel.from_pretrained(transformer, args.output_dir)
transformer = PeftModel.from_pretrained(transformer, os.path.join(args.output_dir, "transformer"))

if args.train_text_encoder:
# Load previous text_encoder
text_encoder = T5EncoderModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=weight_dtype
)
text_encoder = PeftModel.from_pretrained(text_encoder, os.path.join(args.output_dir, "text_encoder"))

# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, transformer=transformer, text_encoder=text_encoder, vae=vae,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)

if args.train_text_encoder:
del text_encoder

del transformer
torch.cuda.empty_cache()
