Skip to content

Commit

Permalink
Use DDPMScheduler instead of DDIMScheduler for noise_scheduler since PR
Browse files Browse the repository at this point in the history
huggingface#6305 has been merged.
  • Loading branch information
dg845 committed Dec 26, 2023
1 parent 100ef46 commit 60a9ea7
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions examples/add/train_add_distill_sd_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
import diffusers
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DDPMScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
Expand Down Expand Up @@ -1119,12 +1119,12 @@ def main(args):

# 1. Create the noise scheduler and the desired noise schedule.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
teacher_scheduler = DDIMScheduler.from_pretrained(
teacher_scheduler = DDPMScheduler.from_pretrained(
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
if not teacher_scheduler.config.rescale_betas_zero_snr:
teacher_scheduler.config["rescale_betas_zero_snr"] = True
noise_scheduler = DDIMScheduler(**teacher_scheduler.config)
noise_scheduler = DDPMScheduler(**teacher_scheduler.config)

# DDIMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
Expand Down
6 changes: 3 additions & 3 deletions examples/add/train_add_distill_sdxl_wds.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
import diffusers
from diffusers import (
AutoencoderKL,
DDIMScheduler,
DDPMScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
Expand Down Expand Up @@ -1177,12 +1177,12 @@ def main(args):

# 1. Create the noise scheduler and the desired noise schedule.
# Enforce zero terminal SNR (see section 3.1 of ADD paper)
teacher_scheduler = DDIMScheduler.from_pretrained(
teacher_scheduler = DDPMScheduler.from_pretrained(
args.pretrained_teacher_model, subfolder="scheduler", revision=args.teacher_revision
)
if not teacher_scheduler.config.rescale_betas_zero_snr:
teacher_scheduler.config["rescale_betas_zero_snr"] = True
noise_scheduler = DDIMScheduler(**teacher_scheduler.config)
noise_scheduler = DDPMScheduler(**teacher_scheduler.config)

# DDPMScheduler calculates the alpha and sigma noise schedules (based on the alpha bars) for us
alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
Expand Down

0 comments on commit 60a9ea7

Please sign in to comment.