fix: how training resume logs are printed (#5117)
* fix: how training resume logs are printed.

* propagate changes to text-to-image scripts.

* propagate changes to instructpix2pix.

* propagate changes to dreambooth

* propagate changes to custom diffusion and instructpix2pix

* propagate changes to kandinsky

* propagate changes to textual inversion.

* debug

* fix: checkpointing.

* debug

* debug

* debug

* back to square one

* debug

* debug

* change condition order.

* debug

* debug

* debug

* debug

* revert to original

* clean

---------

Co-authored-by: Patrick von Platen <[email protected]>
sayakpaul and patrickvonplaten authored Oct 2, 2023
1 parent cd1b8d7 commit d56825e
Showing 14 changed files with 180 additions and 164 deletions.
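
Every hunk below applies the same refactor: instead of fast-forwarding the progress bar by replaying skipped steps at the start of the resumed epoch, each script now derives an `initial_global_step` and hands it to `tqdm` through its `initial=` argument, and the in-loop batch-skipping block is deleted. Distilled into one place, the shared pattern looks roughly like this (a sketch, not a verbatim excerpt: the wrapper function and the flattened `path is not None` guard are editorial assumptions; the remaining names mirror the scripts):

import os

from tqdm.auto import tqdm


def setup_resume(args, accelerator, num_update_steps_per_epoch, path):
    # Sketch of the resume pattern shared by the updated scripts.
    # `path` is the checkpoint directory name, e.g. "checkpoint-500", or None.
    if args.resume_from_checkpoint and path is not None:
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])  # "checkpoint-500" -> 500
        initial_global_step = global_step
        first_epoch = global_step // num_update_steps_per_epoch
    else:
        global_step = 0  # assumption: the scripts initialize this earlier
        initial_global_step = 0
        first_epoch = 0
    # `initial=` pre-fills the bar, replacing the old manual
    # progress_bar.update(1) catch-up loop inside the first resumed epoch.
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=initial_global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )
    return progress_bar, global_step, first_epoch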
24 changes: 12 additions & 12 deletions examples/custom_diffusion/train_custom_diffusion.py
@@ -1075,30 +1075,30 @@ def main(args):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
else:
initial_global_step = 0

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.modifier_token is not None:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
24 changes: 12 additions & 12 deletions examples/dreambooth/train_dreambooth.py
@@ -1178,30 +1178,30 @@ def compute_text_embeddings(prompt):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
else:
initial_global_step = 0

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

24 changes: 12 additions & 12 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -1108,30 +1108,30 @@ def compute_text_embeddings(prompt):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
else:
initial_global_step = 0

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

23 changes: 12 additions & 11 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1048,31 +1048,32 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
else:
initial_global_step = 0

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder_one.train()
text_encoder_two.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)

28 changes: 15 additions & 13 deletions examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -726,6 +726,9 @@ def preprocess_images(examples):
     text_encoder_1.requires_grad_(False)
     text_encoder_2.requires_grad_(False)
 
+    # Set UNet to trainable.
+    unet.train()
+
     # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(text_encoders, tokenizers, prompt):
         prompt_embeds_list = []
@@ -933,29 +936,28 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
else:
initial_global_step = 0

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

with accelerator.accumulate(unet):
# We want to learn the denoising process w.r.t the edited images which
# are conditioned on the original image (which was edited) and the edit instruction.
26 changes: 15 additions & 11 deletions examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -512,6 +512,9 @@ def deepspeed_zero_init_disabled_context_manager():
     vae.requires_grad_(False)
     image_encoder.requires_grad_(False)
 
+    # Set unet to trainable.
+    unet.train()
+
     # Create EMA for the unet.
     if args.use_ema:
         ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
@@ -727,27 +730,28 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
else:
initial_global_step = 0

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

with accelerator.accumulate(unet):
# Convert images to latent space
images = batch["pixel_values"].to(weight_dtype)
@@ -579,29 +579,29 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
else:
initial_global_step = 0

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

with accelerator.accumulate(unet):
# Convert images to latent space
images = batch["pixel_values"].to(weight_dtype)
@@ -595,30 +595,33 @@ def collate_fn(examples):
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])

resume_global_step = global_step * args.gradient_accumulation_steps
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
else:
initial_global_step = 0

progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)

clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
clip_std = clip_std.to(weight_dtype).to(accelerator.device)

for epoch in range(first_epoch, args.num_train_epochs):
prior.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue

with accelerator.accumulate(prior):
# Convert images to latent space
text_input_ids, text_mask, clip_images = (
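
The tqdm behavior the fix relies on: `initial=` seeds the bar's counter, so a resumed run renders as `initial_global_step/max_train_steps` right away and only the remaining optimizer steps tick the bar. A self-contained toy example (the numbers are invented for illustration):

from tqdm.auto import tqdm

max_train_steps = 1000     # hypothetical total
initial_global_step = 250  # e.g. restored from "checkpoint-250"

# The bar starts at 250/1000; no catch-up update(1) calls are needed.
progress_bar = tqdm(range(0, max_train_steps), initial=initial_global_step, desc="Steps")
for _ in range(max_train_steps - initial_global_step):
    progress_bar.update(1)  # one tick per optimizer step, as in the scripts
progress_bar.close()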