diff --git a/train_network.py b/train_network.py index c99d37247..3a5255160 100644 --- a/train_network.py +++ b/train_network.py @@ -474,7 +474,8 @@ def train(self, args): # before resuming make hook for saving/loading to save/load the network weights only def save_model_hook(models, weights, output_dir): # pop weights of other models than network to save only network weights - if accelerator.is_main_process: + # only main process or deepspeed https://github.com/huggingface/diffusers/issues/2606 + if accelerator.is_main_process or args.deepspeed: remove_indices = [] for i, model in enumerate(models): if not isinstance(model, type(accelerator.unwrap_model(network))):