Skip to content

Commit

Permalink
Merge pull request #846 from kohya-ss/dev
Browse files Browse the repository at this point in the history
Fix to work training U-Net only LoRA for SD1/2
  • Loading branch information
kohya-ss authored Oct 1, 2023
2 parents 6bd6cd9 + 4cc9196 commit 9315524
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
Expand Down Expand Up @@ -426,7 +429,12 @@ def train(self, args):
t_enc.train()

# set top parameter requires_grad = True for gradient checkpointing works
t_enc.text_model.embeddings.requires_grad_(True)
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)

# set top parameter requires_grad = True for gradient checkpointing works
if not train_text_encoder: # train U-Net only
unet.parameters().__next__().requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
Expand Down

0 comments on commit 9315524

Please sign in to comment.