Commit

fix
gesen2egee committed Apr 18, 2024
1 parent 631af43 commit 8dae925
Showing 3 changed files with 10 additions and 10 deletions.
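
All three files contain the same one-line fix, applied at each place the optimizer type is checked: `args.optimizer_type.lower()` returns an all-lowercase string, so the old suffix check against the mixed-case `"scheduleFree"` could never match and the schedule-free code path was unreachable. A minimal sketch of the behavior, using a hypothetical optimizer name for illustration:

```python
optimizer_type = "AdamWScheduleFree"  # hypothetical value of args.optimizer_type

# .lower() yields "adamwschedulefree", so a mixed-case suffix never matches
print(optimizer_type.lower().endswith("scheduleFree"))   # False: the old, broken check
print(optimizer_type.lower().endswith("schedulefree"))   # True: the fixed check
```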
6 changes: 3 additions & 3 deletions sdxl_train.py
@@ -415,7 +415,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder2=text_encoder2 if train_text_encoder2 else None,
)
# most of ZeRO stage uses optimizer partitioning, so we have to prepare optimizer and ds_model at the same time. # pull/1139#issuecomment-1986790007
- if args.optimizer_type.lower().endswith("scheduleFree"):
+ if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
@@ -433,7 +433,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
text_encoder1 = accelerator.prepare(text_encoder1)
if train_text_encoder2:
text_encoder2 = accelerator.prepare(text_encoder2)
- if args.optimizer_type.lower().endswith("scheduleFree"):
+ if args.optimizer_type.lower().endswith("schedulefree"):
optimizer, train_dataloader = accelerator.prepare(optimizer, train_dataloader)
else:
optimizer, train_dataloader, lr_scheduler = accelerator.prepare(optimizer, train_dataloader, lr_scheduler)
@@ -636,7 +636,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
- if not args.optimizer_type.lower().endswith("scheduleFree"):
+ if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

8 changes: 4 additions & 4 deletions train_db.py
@@ -230,7 +230,7 @@ def train(args):
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet, text_encoder=text_encoder)
else:
ds_model = deepspeed_utils.prepare_deepspeed_model(args, unet=unet)
- if args.optimizer_type.lower().endswith("scheduleFree"):
+ if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
@@ -242,7 +242,7 @@ def train(args):

else:
if train_text_encoder:
- if args.optimizer_type.lower().endswith("scheduleFree"):
+ if args.optimizer_type.lower().endswith("schedulefree"):
unet, text_encoder, optimizer, train_dataloader = accelerator.prepare(
unet, text_encoder, optimizer, train_dataloader
)
@@ -252,7 +252,7 @@ def train(args):
)
training_models = [unet, text_encoder]
else:
- if args.optimizer_type.lower().endswith("scheduleFree"):
+ if args.optimizer_type.lower().endswith("schedulefree"):
unet, optimizer, train_dataloader = accelerator.prepare(unet, optimizer, train_dataloader)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
@@ -402,7 +402,7 @@ def train(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
- if not args.optimizer_type.lower().endswith("scheduleFree"):
+ if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

6 changes: 3 additions & 3 deletions train_network.py
@@ -594,7 +594,7 @@ def train(self, args):
text_encoder2=text_encoders[1] if train_text_encoder and len(text_encoders) > 1 else None,
network=network,
)
- if args.optimizer_type.lower().endswith("scheduleFree"):
+ if args.optimizer_type.lower().endswith("schedulefree"):
ds_model, optimizer, train_dataloader = accelerator.prepare(
ds_model, optimizer, train_dataloader
)
@@ -617,7 +617,7 @@ def train(self, args):
else:
pass # if text_encoder is not trained, no need to prepare. and device and dtype are already set

- if args.optimizer_type.lower().endswith("scheduleFree"):
+ if args.optimizer_type.lower().endswith("schedulefree"):
network, optimizer, train_dataloader = accelerator.prepare(
network, optimizer, train_dataloader
)
@@ -1162,7 +1162,7 @@ def remove_model(old_ckpt_name):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)

optimizer.step()
- if not args.optimizer_type.lower().endswith("scheduleFree"):
+ if not args.optimizer_type.lower().endswith("schedulefree"):
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)

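For background on why these branches exist at all: schedule-free optimizers replace the external learning-rate schedule with internal iterate averaging, so the code above neither prepares nor steps an `lr_scheduler` for them. A minimal sketch of that loop shape, assuming the `schedulefree` package and a stand-in model (both are illustrative, not taken from this repository):

```python
import torch
import schedulefree  # assumed available: pip install schedulefree

model = torch.nn.Linear(10, 1)  # stand-in model for illustration
optimizer = schedulefree.AdamWScheduleFree(model.parameters(), lr=1e-3)

optimizer.train()  # schedule-free optimizers must be switched to train mode
for _ in range(100):
    x = torch.randn(8, 10)
    loss = model(x).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()               # note: no lr_scheduler.step() afterwards
    optimizer.zero_grad(set_to_none=True)

optimizer.eval()  # switch to averaged weights before evaluation or saving
```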
