Skip to content

Commit

Permalink
fix placing of requires_grad_ of U-Net
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 1, 2023
1 parent 81419f7 commit 4cc9196
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

from tqdm import tqdm
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
Expand Down Expand Up @@ -428,8 +431,10 @@ def train(self, args):
# set top parameter requires_grad = True for gradient checkpointing works
if train_text_encoder:
t_enc.text_model.embeddings.requires_grad_(True)
else:
unet.parameters().__next__().requires_grad_(True)

# set top parameter requires_grad = True for gradient checkpointing works
if not train_text_encoder: # train U-Net only
unet.parameters().__next__().requires_grad_(True)
else:
unet.eval()
for t_enc in text_encoders:
Expand Down

0 comments on commit 4cc9196

Please sign in to comment.