diff --git a/train_network.py b/train_network.py index 0e2e0fa9f..6d98037ca 100644 --- a/train_network.py +++ b/train_network.py @@ -426,7 +426,10 @@ def train(self, args): t_enc.train() # set top parameter requires_grad = True for gradient checkpointing works - t_enc.text_model.embeddings.requires_grad_(True) + if train_text_encoder: + t_enc.text_model.embeddings.requires_grad_(True) + else: + unet.parameters().__next__().requires_grad_(True) else: unet.eval() for t_enc in text_encoders: