From 2be64b3655e9211010cce610ce0d18a61cc147f1 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 21 Apr 2024 17:41:32 +0900 Subject: [PATCH] disable main process check for deepspeed #1247 --- train_network.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index c99d37247..3a5255160 100644 --- a/train_network.py +++ b/train_network.py @@ -474,7 +474,8 @@ def train(self, args): # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))):