
Commit

Remove redundant move unet to device
MohamadZeina committed Dec 5, 2023
1 parent 2f97b67 commit ede94eb
Showing 1 changed file with 1 addition and 1 deletion.
examples/text_to_image/train_text_to_image_lora.py (1 addition, 1 deletion)
@@ -480,7 +480,6 @@ def main():
         weight_dtype = torch.bfloat16
 
     # Move unet, vae and text_encoder to device and cast to weight_dtype
-    unet.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
 
@@ -536,6 +535,7 @@ def main():
         unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
         unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
 
+    # Move unet and lora to device and cast to weight_dtype
     unet.to(accelerator.device, dtype=weight_dtype)
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
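For context (not part of the commit): the move removed in the first hunk is redundant because the call that remains after the LoRA layers are attached already covers everything. nn.Module.to() recurses over every submodule registered at the time it is called, so a single unet.to(accelerator.device, dtype=weight_dtype) placed after the lora_layer modules are set moves and casts both the base UNet weights and the adapters, while the earlier call could only reach the base weights. The toy sketch below illustrates this with plain torch.nn modules standing in for the diffusers UNet; the names and shapes are illustrative assumptions, not code from the repository.

import torch
import torch.nn as nn

# Stand-in for a UNet attention projection (toy module, not the real UNet).
base = nn.Linear(8, 8)

# Early move/cast, before any LoRA layer exists: only parameters present now are affected.
base.to(dtype=torch.bfloat16)

# Stand-in for a LoRA layer attached afterwards (the script attaches LoRA layers
# to the attention projections, e.g. attn_module.to_v.lora_layer).
base.add_module("lora_layer", nn.Linear(8, 8, bias=False))

print(base.weight.dtype)             # torch.bfloat16
print(base.lora_layer.weight.dtype)  # torch.float32, the early cast never reached it

# Calling .to() again after the adapter exists reaches every registered submodule,
# which is why the earlier call can simply be dropped.
base.to(dtype=torch.bfloat16)
print(base.lora_layer.weight.dtype)  # torch.bfloat16

In the script this corresponds to the surviving unet.to(accelerator.device, dtype=weight_dtype) call after the unet_lora_parameters loop, which the new comment documents as moving both the UNet and its LoRA layers.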
