diff --git a/train_network.py b/train_network.py index 7ba073855..a5ef36462 100644 --- a/train_network.py +++ b/train_network.py @@ -1028,6 +1028,8 @@ def remove_model(old_ckpt_name): self.sample_images(accelerator, args, None, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + clean_memory_on_device(accelerator.device) + # 指定ステップごとにモデルを保存 if args.save_every_n_steps is not None and global_step % args.save_every_n_steps == 0: accelerator.wait_for_everyone() @@ -1084,6 +1086,8 @@ def remove_model(old_ckpt_name): self.sample_images(accelerator, args, epoch + 1, global_step, accelerator.device, vae, tokenizer, text_encoder, unet) + clean_memory_on_device(accelerator.device) + # end of epoch # metadata["ss_epoch"] = str(num_train_epochs)