Skip to content

Commit

Permalink
Fix the bug where training stops after about 4 hours.
Browse files Browse the repository at this point in the history
Signed-off-by: lawrence-cj <[email protected]>
  • Loading branch information
lawrence-cj committed Dec 4, 2024
1 parent 1b2901b commit 747918b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions diffusion/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class TrainingConfig(BaseConfig):
load_mask_index: bool = False
snr_loss: bool = False
real_prompt_ratio: float = 1.0
+ training_hours: float = 10000.0
save_image_epochs: int = 1
save_model_epochs: int = 1
save_model_steps: int = 1000000
Expand Down
7 changes: 5 additions & 2 deletions train_scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,10 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal

if loss_nan_timer > 20:
raise ValueError("Loss is NaN too much times. Break here.")
- if global_step % config.train.save_model_steps == 0 or (time.time() - training_start_time) / 3600 > 3.8:
+ if (
+ global_step % config.train.save_model_steps == 0
+ or (time.time() - training_start_time) / 3600 > config.train.training_hours
+ ):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
os.umask(0o000)
Expand All @@ -469,7 +472,7 @@ def train(config, args, accelerator, model, optimizer, lr_scheduler, train_datal
f.write(osp.join(config.work_dir, "config.py") + "\n")
f.write(ckpt_saved_path)

- if (time.time() - training_start_time) / 3600 > 3.8:
+ if (time.time() - training_start_time) / 3600 > config.train.training_hours:
logger.info(f"Stopping training at epoch {epoch}, step {global_step} due to time limit.")
return
if config.train.visualize and (global_step % config.train.eval_sampling_steps == 0 or (step + 1) == 1):
Expand Down

0 comments on commit 747918b

Please sign in to comment.