From d56825e4b45e9c1002e5b39c2bc585fa231b918d Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Mon, 2 Oct 2023 18:29:52 +0200
Subject: [PATCH] fix: how we print training resume logs. (#5117)

* fix: how we print training resume logs.

* propagate changes to text-to-image scripts.

* propagate changes to instructpix2pix.

* propagate changes to dreambooth

* propagate changes to custom diffusion and instructpix2pix

* propagate changes to kandinsky

* propagate changes to textual inv.

* debug

* fix: checkpointing.

* debug

* debug

* debug

* back to the square

* debug

* debug

* change condition order.

* debug

* debug

* debug

* debug

* revert to original

* clean

---------

Co-authored-by: Patrick von Platen
---
 .../train_custom_diffusion.py                 | 24 ++++++++--------
 examples/dreambooth/train_dreambooth.py       | 24 ++++++++--------
 examples/dreambooth/train_dreambooth_lora.py  | 24 ++++++++--------
 .../dreambooth/train_dreambooth_lora_sdxl.py  | 23 +++++++--------
 .../train_instruct_pix2pix_sdxl.py            | 28 ++++++++++---------
 .../train_text_to_image_decoder.py            | 26 +++++++++--------
 .../train_text_to_image_lora_decoder.py       | 22 +++++++--------
 .../train_text_to_image_lora_prior.py         | 25 +++++++++--------
 .../train_text_to_image_prior.py              | 26 +++++++++--------
 examples/text_to_image/train_text_to_image.py | 27 +++++++++---------
 .../text_to_image/train_text_to_image_lora.py | 23 ++++++++-------
 .../train_text_to_image_lora_sdxl.py          | 23 +++++++--------
 .../text_to_image/train_text_to_image_sdxl.py | 26 +++++++++--------
 .../textual_inversion/textual_inversion.py    | 23 +++++++--------
 14 files changed, 180 insertions(+), 164 deletions(-)

diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py
index 60d8d6723dcf..3288fe3258ac 100644
--- a/examples/custom_diffusion/train_custom_diffusion.py
+++ b/examples/custom_diffusion/train_custom_diffusion.py
@@ -1075,30 +1075,30 @@ def main(args):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.modifier_token is not None:
             text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()

diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 71510b18c8a3..606cc5c6cfdd 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -1178,30 +1178,30 @@ def compute_text_embeddings(prompt):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index dc90d10f2b26..5bb6c78b7b74 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -1108,30 +1108,30 @@ def compute_text_embeddings(prompt):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         if args.train_text_encoder:
             text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 24dbf4313662..ac59bba6c847 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1048,18 +1048,25 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
@@ -1067,12 +1074,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
             text_encoder_one.train()
             text_encoder_two.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)

diff --git a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
index 4d0b9bef55f1..e2d9b2105160 100644
--- a/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
+++ b/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py
@@ -726,6 +726,9 @@ def preprocess_images(examples):
     text_encoder_1.requires_grad_(False)
     text_encoder_2.requires_grad_(False)
 
+    # Set UNet to trainable.
+    unet.train()
+
     # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(text_encoders, tokenizers, prompt):
         prompt_embeds_list = []
@@ -933,29 +936,28 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
-
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
-        unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 # We want to learn the denoising process w.r.t the edited images which
                 # are conditioned on the original image (which was edited) and the edit instruction.

diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
index dd79c88f8a76..0938d38ba487 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py
@@ -512,6 +512,9 @@ def deepspeed_zero_init_disabled_context_manager():
     vae.requires_grad_(False)
     image_encoder.requires_grad_(False)
 
+    # Set unet to trainable.
+    unet.train()
+
     # Create EMA for the unet.
     if args.use_ema:
         ema_unet = UNet2DConditionModel.from_pretrained(args.pretrained_decoder_model_name_or_path, subfolder="unet")
@@ -727,27 +730,28 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
     for epoch in range(first_epoch, args.num_train_epochs):
-        unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 images = batch["pixel_values"].to(weight_dtype)

diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
index 5e5f4b9cbf5d..91091f4d80fb 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py
@@ -579,29 +579,29 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 images = batch["pixel_values"].to(weight_dtype)

diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
index d2aabb948969..3099c613df73 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py
@@ -595,30 +595,33 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
+
     clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
     clip_std = clip_std.to(weight_dtype).to(accelerator.device)
+
     for epoch in range(first_epoch, args.num_train_epochs):
         prior.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(prior):
                 # Convert images to latent space
                 text_input_ids, text_mask, clip_images = (

diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
index b86df4de600c..74fa504345fe 100644
--- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
+++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py
@@ -517,6 +517,9 @@ def deepspeed_zero_init_disabled_context_manager():
     text_encoder.requires_grad_(False)
     image_encoder.requires_grad_(False)
 
+    # Set prior to trainable.
+    prior.train()
+
     # Create EMA for the prior.
     if args.use_ema:
         ema_prior = PriorTransformer.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
@@ -741,32 +744,31 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     clip_mean = clip_mean.to(weight_dtype).to(accelerator.device)
     clip_std = clip_std.to(weight_dtype).to(accelerator.device)
 
     for epoch in range(first_epoch, args.num_train_epochs):
-        prior.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(prior):
                 # Convert images to latent space
                 text_input_ids, text_mask, clip_images = (

diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 535942629314..03efc3fa13c5 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -577,9 +577,10 @@ def deepspeed_zero_init_disabled_context_manager():
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
 
-    # Freeze vae and text_encoder
+    # Freeze vae and text_encoder and set unet to trainable
     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
+    unet.train()
 
     # Create EMA for the unet.
    if args.use_ema:
@@ -854,29 +855,29 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
-        unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index d4d13a144f38..6ae157bce9b3 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -429,7 +429,6 @@ def main():
     # freeze parameters of models to save more memory
     unet.requires_grad_(False)
     vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
 
     # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -690,29 +689,29 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
+    else:
+        initial_global_step = 0
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()

diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index 991f8a84a243..d685d468db4d 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -947,18 +947,25 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
@@ -967,12 +974,6 @@ def collate_fn(examples):
             text_encoder_one.train()
             text_encoder_two.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 if args.pretrained_vae_model_name_or_path is not None:

diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py
index 649b82ed3baa..4cf966af77cd 100644
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -657,6 +657,8 @@ def main(args):
     vae.requires_grad_(False)
     text_encoder_one.requires_grad_(False)
     text_encoder_two.requires_grad_(False)
+    # Set unet as trainable.
+    unet.train()
 
     # For mixed precision training we cast all non-trainable weigths to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -967,29 +969,29 @@ def collate_fn(examples):
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     for epoch in range(first_epoch, args.num_train_epochs):
-        unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(unet):
                 # Sample noise that we'll add to the latents
                 model_input = batch["model_input"].to(accelerator.device)

diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index 2e6f9a7d9522..01830751ffe2 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -809,18 +809,25 @@ def main():
                 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
             )
             args.resume_from_checkpoint = None
+            initial_global_step = 0
         else:
             accelerator.print(f"Resuming from checkpoint {path}")
             accelerator.load_state(os.path.join(args.output_dir, path))
             global_step = int(path.split("-")[1])
 
-            resume_global_step = global_step * args.gradient_accumulation_steps
+            initial_global_step = global_step
             first_epoch = global_step // num_update_steps_per_epoch
-            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
 
-    # Only show the progress bar once on each machine.
-    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Steps")
+    else:
+        initial_global_step = 0
+
+    progress_bar = tqdm(
+        range(0, args.max_train_steps),
+        initial=initial_global_step,
+        desc="Steps",
+        # Only show the progress bar once on each machine.
+        disable=not accelerator.is_local_main_process,
+    )
 
     # keep original embeddings as reference
     orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
 
@@ -828,12 +835,6 @@ def main():
     for epoch in range(first_epoch, args.num_train_epochs):
         text_encoder.train()
         for step, batch in enumerate(train_dataloader):
-            # Skip steps until we reach the resumed step
-            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
-                if step % args.gradient_accumulation_steps == 0:
-                    progress_bar.update(1)
-                continue
-
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
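
Every script in the patch receives the same treatment: the per-batch fast-forward loop that skipped already-seen batches while manually ticking the progress bar (and that derived resume_step by multiplying global_step by args.gradient_accumulation_steps) is deleted, and the step count parsed from the checkpoint directory name is passed to tqdm's `initial` argument instead; in scripts that called unet.train() or prior.train() once per epoch, that call is also hoisted to model setup. A minimal, self-contained sketch of the tqdm behavior the new code relies on; the step counts below are illustrative values, not taken from the scripts:

    from tqdm.auto import tqdm

    max_train_steps = 1000
    initial_global_step = 500  # e.g. parsed from a "checkpoint-500" directory name

    # `initial=` makes the bar render 500/1000 immediately, with no need to
    # replay the first 500 steps just to advance the counter.
    progress_bar = tqdm(range(0, max_train_steps), initial=initial_global_step, desc="Steps")
    for _ in range(initial_global_step, max_train_steps):
        progress_bar.update(1)  # the training loop advances the bar one optimizer step at a time
    progress_bar.close()

One consequence of dropping the skip loop is that a resumed run starts from the beginning of the current epoch's dataloader rather than fast-forwarding to the exact batch, which trades exact batch alignment for a resume that no longer stalls on large datasets.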