From ede94ebdfe48d1799cffa51661624c37fdb55fdd Mon Sep 17 00:00:00 2001
From: Mohamad Zeina
Date: Tue, 5 Dec 2023 13:34:36 +0000
Subject: [PATCH] Remove redundant move of unet to device

---
 examples/text_to_image/train_text_to_image_lora.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py
index 275f4aac0294..369823a56a03 100644
--- a/examples/text_to_image/train_text_to_image_lora.py
+++ b/examples/text_to_image/train_text_to_image_lora.py
@@ -480,7 +480,6 @@ def main():
         weight_dtype = torch.bfloat16
 
     # Move unet, vae and text_encoder to device and cast to weight_dtype
-    unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
@@ -536,6 +535,7 @@ def main():
             unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
             unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
 
+    # Move unet and lora to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
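
The first unet.to() call is redundant because the script moves the unet again
after the LoRA layers are attached, and nn.Module.to() only affects parameters
that exist at call time, so it is the later call that covers both the base
weights and the adapters. A minimal sketch of that behavior, with toy `base`
and `lora_layer` modules standing in for the script's unet and LoRA adapter
(these names are illustrative, not taken from the script):

    import torch
    import torch.nn as nn

    device = "cuda" if torch.cuda.is_available() else "cpu"
    weight_dtype = torch.bfloat16

    # Toy stand-in for the frozen unet.
    base = nn.Linear(4, 4)
    base.to(device, dtype=weight_dtype)  # analogue of the call the patch removes

    # Attach an adapter afterwards, as set_lora_layer() does in the script.
    base.lora_layer = nn.Linear(4, 4)
    print(base.lora_layer.weight.device, base.lora_layer.weight.dtype)
    # -> cpu torch.float32: the adapter missed the earlier .to()

    base.to(device, dtype=weight_dtype)  # analogue of the call the patch keeps
    print(base.lora_layer.weight.device, base.lora_layer.weight.dtype)
    # -> moved and cast together with the base weights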