Support new scheduler free optimizer #1315

Closed
wants to merge 4 commits
42 changes: 30 additions & 12 deletions fine_tune.py
@@ -250,23 +250,32 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
unet.to(weight_dtype)
text_encoder.to(weight_dtype)

use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
if args.train_text_encoder:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]
else:
# accelerator apparently takes care of this for us
if args.train_text_encoder:
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
)
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(unet, text_encoder, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)

# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

# Experimental feature: perform fp16 training including gradients; patch PyTorch to enable grad scaling in fp16
if args.full_fp16:
@@ -324,6 +333,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
m.train()

for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(*training_models):
with torch.no_grad():
Expand Down Expand Up @@ -354,7 +364,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

# Predict the noise residual
with accelerator.autocast():
@@ -368,7 +380,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])

if args.min_snr_gamma:
@@ -380,7 +394,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -390,9 +406,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
optimizer.zero_grad(set_to_none=True)

optimizer_eval_if_needed()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -471,7 +489,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

accelerator.end_training()

if is_main_process and (args.save_state or args.save_state_on_train_end):
train_util.save_state_on_train_end(args, accelerator)

del accelerator # delete this because memory is needed after this point
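The `optimizer_train_if_needed` / `optimizer_eval_if_needed` lambdas above implement the switching that schedule-free optimizers require: the `schedulefree` optimizers keep the weights at a training point during updates and at an averaged point for evaluation, so `optimizer.train()` must be called before gradient steps and `optimizer.eval()` before evaluation or checkpointing. A minimal sketch of that pattern outside this repo (plain PyTorch objects with no accelerate wrapping; `model`, `loader`, and `loss_fn` are placeholders, not objects from these scripts):

```python
# Minimal sketch of the schedule-free train/eval pattern (assumes `schedulefree` is installed).
import torch
import schedulefree

model = torch.nn.Linear(8, 1)
loader = [(torch.randn(4, 8), torch.randn(4, 1)) for _ in range(10)]
loss_fn = torch.nn.MSELoss()

optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)
# no external lr_scheduler is needed: the schedule is built into the optimizer

for epoch in range(2):
    model.train()
    optimizer.train()  # switch parameters to the training point before stepping
    for x, y in loader:
        optimizer.zero_grad(set_to_none=True)
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()

    model.eval()
    optimizer.eval()  # switch to the averaged point before evaluation or saving
    with torch.no_grad():
        val_loss = sum(loss_fn(model(x), y) for x, y in loader) / len(loader)
```

In the diff above, the same switching is reduced to no-op lambdas so that the non-schedule-free code path stays unchanged.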
27 changes: 25 additions & 2 deletions library/train_util.py
@@ -3087,7 +3087,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
)
parser.add_argument("--seed", type=int, default=None, help="random seed for training / 学習時の乱数のseed")
parser.add_argument(
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / grandient checkpointingを有効にする"
"--gradient_checkpointing", action="store_true", help="enable gradient checkpointing / gradient checkpointingを有効にする"
)
parser.add_argument(
"--gradient_accumulation_steps",
@@ -4088,6 +4088,21 @@ def get_optimizer(args, trainable_params):
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

elif optimizer_type.endswith("schedulefree".lower()):
try:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")
if optimizer_type == "AdamWScheduleFree".lower():
optimizer_class = sf.AdamWScheduleFree
logger.info(f"use AdamWScheduleFree optimizer | {optimizer_kwargs}")
elif optimizer_type == "SGDScheduleFree".lower():
optimizer_class = sf.SGDScheduleFree
logger.info(f"use SGDScheduleFree optimizer | {optimizer_kwargs}")
else:
raise ValueError(f"Unknown optimizer type: {optimizer_type}")
optimizer = optimizer_class(trainable_params, lr=lr, **optimizer_kwargs)

if optimizer is None:
# use an arbitrary optimizer
optimizer_type = args.optimizer_type # the non-lowercased one (a bit awkward)
@@ -4116,6 +4131,14 @@ def get_scheduler_fix(args, optimizer: Optimizer, num_processes: int):
"""
Unified API to get any scheduler from its name.
"""
# supports schedule free optimizer
if args.optimizer_type.lower().endswith("schedulefree"):
# return dummy scheduler: it has 'step' method but does nothing
logger.info("use dummy scheduler for schedule free optimizer / schedule free optimizer用のダミースケジューラを使用します")
lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[SchedulerType.CONSTANT](optimizer)
lr_scheduler.step = lambda: None
return lr_scheduler

name = args.lr_scheduler
num_warmup_steps: Optional[int] = args.lr_warmup_steps
num_training_steps = args.max_train_steps * num_processes # * args.gradient_accumulation_steps
@@ -4250,7 +4273,7 @@ def load_tokenizer(args: argparse.Namespace):
return tokenizer


def prepare_accelerator(args: argparse.Namespace):
def prepare_accelerator(args: argparse.Namespace) -> Accelerator:
"""
this function also prepares deepspeed plugin
"""
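Taken together, the two additions to `library/train_util.py` do name-based dispatch (any `--optimizer_type` ending in `schedulefree` resolves to the corresponding `schedulefree` class) and then return a scheduler whose `step()` does nothing, so the shared training loops can keep calling `lr_scheduler.step()` unconditionally. A simplified, self-contained sketch of both pieces (illustrative only; `torch.optim.lr_scheduler.ConstantLR` stands in for the diffusers constant schedule used in the diff, and kwargs handling and logging are omitted):

```python
# Simplified sketch of the optimizer dispatch and dummy scheduler added above.
import torch
from torch.optim.lr_scheduler import ConstantLR


def build_optimizer(optimizer_type: str, params, lr: float) -> torch.optim.Optimizer:
    name = optimizer_type.lower()
    if name.endswith("schedulefree"):
        import schedulefree as sf  # raises ImportError if the package is missing

        if name == "adamwschedulefree":
            return sf.AdamWScheduleFree(params, lr=lr)
        if name == "sgdschedulefree":
            return sf.SGDScheduleFree(params, lr=lr)
        raise ValueError(f"Unknown optimizer type: {optimizer_type}")
    return torch.optim.AdamW(params, lr=lr)  # fallback for this sketch


def build_scheduler(optimizer_type: str, optimizer: torch.optim.Optimizer):
    if optimizer_type.lower().endswith("schedulefree"):
        # dummy scheduler: it has a step() method, but step() does nothing
        scheduler = ConstantLR(optimizer, factor=1.0)
        scheduler.step = lambda: None
        return scheduler
    raise NotImplementedError("normal scheduler selection elided in this sketch")


# usage
params = [torch.nn.Parameter(torch.zeros(3))]
opt = build_optimizer("AdamWScheduleFree", params, lr=1e-3)
sched = build_scheduler("AdamWScheduleFree", opt)
sched.step()  # no-op, so the existing lr_scheduler.step() calls need no changes
```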
3 changes: 2 additions & 1 deletion requirements.txt
@@ -1,4 +1,4 @@
accelerate==0.25.0
accelerate==0.30.0
transformers==4.36.2
diffusers[torch]==0.25.0
ftfy==6.1.1
@@ -9,6 +9,7 @@ pytorch-lightning==1.9.0
bitsandbytes==0.43.0
prodigyopt==1.0
lion-pytorch==0.0.6
schedulefree==1.2.5
tensorboard
safetensors==0.4.2
# gradio==3.16.2
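A quick way to confirm the new pin resolves after installing the updated requirements (assumes `pip install -r requirements.txt` or an equivalent install of `schedulefree==1.2.5` has been run):

```python
# Sanity check that the classes referenced in library/train_util.py are importable.
import schedulefree

print(schedulefree.AdamWScheduleFree)
print(schedulefree.SGDScheduleFree)
```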
38 changes: 29 additions & 9 deletions sdxl_train.py
@@ -407,6 +407,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1.text_model.encoder.layers[-1].requires_grad_(False)
text_encoder1.text_model.final_layer_norm.requires_grad_(False)

use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
if args.deepspeed:
ds_model = deepspeed_utils.prepare_deepspeed_model(
args,
@@ -415,9 +416,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
ds_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
ds_model, optimizer, train_dataloader, lr_scheduler
)
ds_model, optimizer, train_dataloader = accelerator.prepare(ds_model, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
training_models = [ds_model]

else:
@@ -428,7 +429,17 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)

# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

# move to CPU when caching the TextEncoder outputs
if args.cache_text_encoder_outputs:
@@ -503,6 +514,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
m.train()

for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(*training_models):
if "latents" in batch and batch["latents"] is not None:
@@ -582,7 +594,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

@@ -600,7 +614,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
or args.masked_loss
):
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
if args.masked_loss:
loss = apply_masked_loss(loss, batch)
loss = loss.mean([1, 2, 3])
@@ -616,7 +632,9 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

loss = loss.mean() # mean over batch dimension
else:
loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="mean", loss_type=args.loss_type, huber_c=huber_c
)

accelerator.backward(loss)
if accelerator.sync_gradients and args.max_grad_norm != 0.0:
@@ -626,9 +644,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
lr_scheduler.step()
lr_scheduler.step() # if schedule-free optimizer is used, this is a no-op
optimizer.zero_grad(set_to_none=True)

optimizer_eval_if_needed()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
@@ -736,7 +756,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

accelerator.end_training()

if args.save_state or args.save_state_on_train_end:
train_util.save_state_on_train_end(args, accelerator)

del accelerator # delete this because memory is needed after this point
25 changes: 22 additions & 3 deletions sdxl_train_control_net_lllite.py
@@ -15,6 +15,7 @@

import torch
from library.device_utils import init_ipex, clean_memory_on_device

init_ipex()

from torch.nn.parallel import DistributedDataParallel as DDP
@@ -286,7 +287,18 @@ def train(args):
unet.to(weight_dtype)

# accelerator apparently takes care of this for us
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
use_schedule_free_optimizer = args.optimizer_type.lower().endswith("schedulefree")
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
if not use_schedule_free_optimizer:
lr_scheduler = accelerator.prepare(lr_scheduler)

# make lambda function for calling optimizer.train() and optimizer.eval() if schedule-free optimizer is used
if use_schedule_free_optimizer:
optimizer_train_if_needed = lambda: optimizer.train()
optimizer_eval_if_needed = lambda: optimizer.eval()
else:
optimizer_train_if_needed = lambda: None
optimizer_eval_if_needed = lambda: None

if args.gradient_checkpointing:
unet.train() # according to TI example in Diffusers, train is required -> this is done on the original U-Net, so it could actually be removed
@@ -390,6 +402,7 @@ def remove_model(old_ckpt_name):
current_epoch.value = epoch + 1

for step, batch in enumerate(train_dataloader):
optimizer_train_if_needed()
current_step.value = global_step
with accelerator.accumulate(unet):
with torch.no_grad():
@@ -439,7 +452,9 @@ def remove_model(old_ckpt_name):

# Sample noise, sample a random timestep for each image, and add noise to the latents,
# with noise offset and/or multires noise if specified
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents)
noise, noisy_latents, timesteps, huber_c = train_util.get_noise_noisy_latents_and_timesteps(
args, noise_scheduler, latents
)

noisy_latents = noisy_latents.to(weight_dtype) # TODO check why noisy_latents is not weight_dtype

@@ -458,7 +473,9 @@ def remove_model(old_ckpt_name):
else:
target = noise

loss = train_util.conditional_loss(noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c)
loss = train_util.conditional_loss(
noise_pred.float(), target.float(), reduction="none", loss_type=args.loss_type, huber_c=huber_c
)
loss = loss.mean([1, 2, 3])

loss_weights = batch["loss_weights"] # weight for each sample
@@ -484,6 +501,8 @@ def remove_model(old_ckpt_name):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

optimizer_eval_if_needed()

# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)