diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 369823a56a03..c71e7c29b023 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -479,7 +479,7 @@ def main(): elif accelerator.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - # Move unet, vae and text_encoder to device and cast to weight_dtype + # Move vae and text_encoder to device and cast to weight_dtype vae.to(accelerator.device, dtype=weight_dtype) text_encoder.to(accelerator.device, dtype=weight_dtype)