From bb7ffde2d3ea1f3f09b1aaf9006a6a5e9debe234 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Thu, 21 Dec 2023 15:21:08 -0800 Subject: [PATCH] Fix bug when creating the guidance embeddings using multiple GPUs. --- examples/consistency_distillation/train_lcm_distill_sd_wds.py | 3 ++- .../consistency_distillation/train_lcm_distill_sdxl_wds.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py index 54d05bb5ea26..9e8df77aacbd 100644 --- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py @@ -889,6 +889,7 @@ def main(args): # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None if teacher_unet.config.time_cond_proj_dim is None: teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim + time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim unet = UNet2DConditionModel(**teacher_unet.config) # load teacher_unet weights into unet unet.load_state_dict(teacher_unet.state_dict(), strict=False) @@ -1175,7 +1176,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min - w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) + w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) # Move to U-Net device and dtype w = w.to(device=latents.device, dtype=latents.dtype) diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py index e58db46c9811..da49649c918c 100644 --- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py +++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py @@ -948,6 +948,7 @@ def main(args): # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None if teacher_unet.config.time_cond_proj_dim is None: teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim + time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim unet = UNet2DConditionModel(**teacher_unet.config) # load teacher_unet weights into unet unet.load_state_dict(teacher_unet.state_dict(), strict=False) @@ -1273,7 +1274,7 @@ def compute_embeddings( # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min - w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim) + w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim) w = w.reshape(bsz, 1, 1, 1) # Move to U-Net device and dtype w = w.to(device=latents.device, dtype=latents.dtype)