From 81419f7f320930a0e3deb330f374fca6dffc671f Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Oct 2023 16:37:23 +0900 Subject: [PATCH 1/2] Fix to work training U-Net only LoRA for SD1/2 --- train_network.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train_network.py b/train_network.py index 0e2e0fa9f..6d98037ca 100644 --- a/train_network.py +++ b/train_network.py @@ -426,7 +426,10 @@ def train(self, args): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - t_enc.text_model.embeddings.requires_grad_(True) + if train_text_encoder: + t_enc.text_model.embeddings.requires_grad_(True) + else: + unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: From 4cc919607a55efa038cc1f4866ee09dcca39e6b5 Mon Sep 17 00:00:00 2001 From: Kohya S Date: Sun, 1 Oct 2023 16:41:48 +0900 Subject: [PATCH 2/2] fix placing of requires_grad_ of U-Net --- train_network.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/train_network.py b/train_network.py index 6d98037ca..1a1713259 100644 --- a/train_network.py +++ b/train_network.py @@ -12,10 +12,13 @@ from tqdm import tqdm import torch + try: import intel_extension_for_pytorch as ipex + if torch.xpu.is_available(): from library.ipex import ipex_init + ipex_init() except Exception: pass @@ -428,8 +431,10 @@ def train(self, args): # set top parameter requires_grad = True for gradient checkpointing works if train_text_encoder: t_enc.text_model.embeddings.requires_grad_(True) - else: - unet.parameters().__next__().requires_grad_(True) + + # set top parameter requires_grad = True for gradient checkpointing works + if not train_text_encoder: # train U-Net only + unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: