diff --git a/examples/consistency_distillation/train_lcm_distill_sd_wds.py b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
index 54d05bb5ea26..9e8df77aacbd 100644
--- a/examples/consistency_distillation/train_lcm_distill_sd_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sd_wds.py
@@ -889,6 +889,7 @@ def main(args):
     # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
     if teacher_unet.config.time_cond_proj_dim is None:
         teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
+    time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
     unet = UNet2DConditionModel(**teacher_unet.config)
     # load teacher_unet weights into unet
     unet.load_state_dict(teacher_unet.state_dict(), strict=False)
@@ -1175,7 +1176,7 @@ def compute_embeddings(prompt_batch, proportion_empty_prompts, text_encoder, tok
                 # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
                 w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
-                w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
+                w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
                 w = w.reshape(bsz, 1, 1, 1)
                 # Move to U-Net device and dtype
                 w = w.to(device=latents.device, dtype=latents.dtype)
diff --git a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
index e58db46c9811..da49649c918c 100644
--- a/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
+++ b/examples/consistency_distillation/train_lcm_distill_sdxl_wds.py
@@ -948,6 +948,7 @@ def main(args):
     # Add `time_cond_proj_dim` to the student U-Net if `teacher_unet.config.time_cond_proj_dim` is None
     if teacher_unet.config.time_cond_proj_dim is None:
         teacher_unet.config["time_cond_proj_dim"] = args.unet_time_cond_proj_dim
+    time_cond_proj_dim = teacher_unet.config.time_cond_proj_dim
     unet = UNet2DConditionModel(**teacher_unet.config)
     # load teacher_unet weights into unet
     unet.load_state_dict(teacher_unet.state_dict(), strict=False)
@@ -1273,7 +1274,7 @@ def compute_embeddings(
                 # 5. Sample a random guidance scale w from U[w_min, w_max] and embed it
                 w = (args.w_max - args.w_min) * torch.rand((bsz,)) + args.w_min
-                w_embedding = guidance_scale_embedding(w, embedding_dim=unet.config.time_cond_proj_dim)
+                w_embedding = guidance_scale_embedding(w, embedding_dim=time_cond_proj_dim)
                 w = w.reshape(bsz, 1, 1, 1)
                 # Move to U-Net device and dtype
                 w = w.to(device=latents.device, dtype=latents.dtype)
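
Both files receive the same two-part change: the student U-Net's `time_cond_proj_dim` is captured in a local variable right after it is resolved from the teacher config, and the training loop reads that local instead of `unet.config.time_cond_proj_dim`. A plausible motivation, not stated in the diff itself, is that by the time the loop runs, `unet` has been passed through `accelerator.prepare`, which can return a wrapped module (e.g. `DistributedDataParallel`) that does not forward arbitrary attributes such as `.config`. The sketch below illustrates that failure mode; `ToyUNet` and `Wrapper` are hypothetical stand-ins, not code from these scripts:

```python
# Minimal sketch, assuming only torch: cache a plain config value while the
# model is still unwrapped, because nn.Module wrappers do not forward
# arbitrary attributes like `.config` to the module they hold.
import torch.nn as nn


class ToyUNet(nn.Module):
    """Hypothetical stand-in for the student UNet2DConditionModel."""

    def __init__(self, time_cond_proj_dim=256):
        super().__init__()
        self.config = {"time_cond_proj_dim": time_cond_proj_dim}
        self.proj = nn.Linear(time_cond_proj_dim, time_cond_proj_dim)

    def forward(self, x):
        return self.proj(x)


class Wrapper(nn.Module):
    """Stand-in for the wrapped module accelerator.prepare() may return."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


unet = ToyUNet()
time_cond_proj_dim = unet.config["time_cond_proj_dim"]  # cache while unwrapped

unet = Wrapper(unet)  # after "prepare", `unet` is no longer the bare model
print(time_cond_proj_dim)  # 256 -- the cached value is still usable
try:
    unet.config  # raises: the wrapper itself has no `config` attribute
except AttributeError as exc:
    print("config lookup on the wrapped model fails:", exc)
```

Reading the value via `accelerator.unwrap_model(unet).config` each step would presumably also work; caching the plain integer once is simply cheaper and keeps the hot loop free of unwrapping logic.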